# %%
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import random
from matplotlib.ticker import MaxNLocator
from transformers import pipeline

MODEL_NAMES = ["bert-base-uncased", "roberta-base", "bert-large-uncased", "roberta-large"]
OWN_MODEL_NAME = 'add-a-model'

DECIMAL_PLACES = 1
EPS = 1e-5  # to avoid /0 errors
# Example date consts
DATE_SPLIT_KEY = "DATE"
START_YEAR = 1801
STOP_YEAR = 1999
NUM_PTS = 20
DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
DATES = [f'{d}' for d in DATES]
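# DATES is now a list of 20 evenly spaced year strings between '1801' and '1999',
# used both as x-axis labels and as the values injected into the input text.
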
# Example place consts
# https://www3.weforum.org/docs/WEF_GGGR_2021.pdf
# Bottom 10 and top 10 Global Gender Gap ranked countries.
PLACE_SPLIT_KEY = "PLACE"
PLACES = [
    "Afghanistan",
    "Yemen",
    "Iraq",
    "Pakistan",
    "Syria",
    "Democratic Republic of Congo",
    "Iran",
    "Mali",
    "Chad",
    "Saudi Arabia",
    "Switzerland",
    "Ireland",
    "Lithuania",
    "Rwanda",
    "Namibia",
    "Sweden",
    "New Zealand",
    "Norway",
    "Finland",
    "Iceland"]
# Example Reddit interest consts
# in order of increasing self-identified female participation.
# See http://bburky.com/subredditgenderratios/ , Minimum subreddit size: 400000
SUBREDDITS = [
    "GlobalOffensive",
    "pcmasterrace",
    "nfl",
    "sports",
    "The_Donald",
    "leagueoflegends",
    "Overwatch",
    "gonewild",
    "Futurology",
    "space",
    "technology",
    "gaming",
    "Jokes",
    "dataisbeautiful",
    "woahdude",
    "askscience",
    "wow",
    "anime",
    "BlackPeopleTwitter",
    "politics",
    "pokemon",
    "worldnews",
    "reddit.com",
    "interestingasfuck",
    "videos",
    "nottheonion",
    "television",
    "science",
    "atheism",
    "movies",
    "gifs",
    "Music",
    "trees",
    "EarthPorn",
    "GetMotivated",
    "pokemongo",
    "news",
    # removing below subreddit as most of the tokens are taken up by it:
    # ['ff', '##ff', '##ff', '##fu', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', '##u', ...]
    # "fffffffuuuuuuuuuuuu",
    "Fitness",
    "Showerthoughts",
    "OldSchoolCool",
    "explainlikeimfive",
    "todayilearned",
    "gameofthrones",
    "AdviceAnimals",
    "DIY",
    "WTF",
    "IAmA",
    "cringepics",
    "tifu",
    "mildlyinteresting",
    "funny",
    "pics",
    "LifeProTips",
    "creepy",
    "personalfinance",
    "food",
    "AskReddit",
    "books",
    "aww",
    "sex",
    "relationships",
]
GENDERED_LIST = [
    ['he', 'she'],
    ['him', 'her'],
    ['his', 'hers'],
    ["himself", "herself"],
    ['male', 'female'],
    ['man', 'woman'],
    ['men', 'women'],
    ["husband", "wife"],
    ['father', 'mother'],
    ['boyfriend', 'girlfriend'],
    ['brother', 'sister'],
    ["actor", "actress"],
]
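# Each pair above is ordered (male term, female term); the helpers below rely on that
# column order.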

# %%
# Fire up the models
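# Note: this loads (and, on first run, downloads) all four pre-trained pipelines up
# front, so app start-up can be slow, but later inference requests don't pay that cost.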
models = dict()
for bert_like in MODEL_NAMES:
    models[bert_like] = pipeline("fill-mask", model=bert_like)

# %%
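# Despite its name, get_gendered_token_ids() returns the gendered token strings
# themselves (male terms from column 0 of GENDERED_LIST, female terms from column 1),
# not vocabulary ids.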
def get_gendered_token_ids():
    male_gendered_tokens = [pair[0] for pair in GENDERED_LIST]
    female_gendered_tokens = [pair[1] for pair in GENDERED_LIST]
    return male_gendered_tokens, female_gendered_tokens
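
# For example, with BERT's '[MASK]' token and split_key 'DATE', the function below
# turns 'She was a teenager in DATE.' into (['[MASK] was a teenager in ', '.'], 1):
# each gendered word (lower-cased exact match) is replaced by the mask token, then the
# text is split on the place-holder key so a value can later be injected between the
# segments.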
def prepare_text_for_masking(input_text, mask_token, gendered_tokens, split_key):
    text_w_masks_list = [
        mask_token if word.lower() in gendered_tokens else word for word in input_text.split()]
    num_masks = len([m for m in text_w_masks_list if m == mask_token])

    text_portions = ' '.join(text_w_masks_list).split(split_key)
    return text_portions, num_masks
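
# For each masked position, sum the scores of the top predictions that fall in the
# given gendered-token set, then average over the number of masks and convert to a
# percentage. EPS guards against division by zero when no mask tokens were found.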
def get_avg_prob_from_pipeline_outputs(mask_filled_text, gendered_token, num_preds):
    pronoun_preds = [sum([
        pronoun["score"] if pronoun["token_str"].strip().lower() in gendered_token else 0.0
        for pronoun in top_preds])
        for top_preds in mask_filled_text
    ]
    return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)

# %%
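# Plot the per-value probabilities as points, fit a degree-n_fit polynomial to them,
# and shade a +/- 1 standard-deviation band propagated from the fit's covariance matrix.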
def get_figure(df, gender, n_fit=1):
    df = df.set_index('x-axis')
    cols = df.columns
    xs = list(range(len(df)))
    ys = df[cols[0]]
    fig, ax = plt.subplots()
    # Trying small fig due to rendering issues on HF, not on VS Code
    fig.set_figheight(3)
    fig.set_figwidth(9)

    # find stackoverflow reference
    p, C_p = np.polyfit(xs, ys, n_fit, cov=1)
    t = np.linspace(min(xs)-1, max(xs)+1, 10*len(xs))
    TT = np.vstack([t**(n_fit-i) for i in range(n_fit+1)]).T

    # matrix multiplication calculates the polynomial values
    yi = np.dot(TT, p)
    C_yi = np.dot(TT, np.dot(C_p, TT.T))  # C_y = TT*C_z*TT.T
    sig_yi = np.sqrt(np.diag(C_yi))  # Standard deviations are sqrt of diagonal

    ax.fill_between(t, yi+sig_yi, yi-sig_yi, alpha=.25)
    ax.plot(t, yi, '-')
    ax.plot(df, 'ro')
    ax.legend(list(df.columns))

    ax.axis('tight')
    ax.set_xlabel("Value injected into input text")
    ax.set_title(
        f"Probability of predicting {gender} pronouns.")
    ax.set_ylabel("Softmax prob for pronouns")
    ax.xaxis.set_major_locator(MaxNLocator(6))
    ax.tick_params(axis='x', labelrotation=5)
    return fig

# %%
def predict_gender_pronouns(
        model_name,
        own_model_name,
        indie_vars,
        split_key,
        normalizing,
        n_fit,
        input_text,
):
    """Run inference on input_text for each model type, returning df and plots of percentage
    of gender pronouns predicted as female and male in each target text.
    """
    if model_name not in MODEL_NAMES:
        model = pipeline("fill-mask", model=own_model_name)
    else:
        model = models[model_name]

    mask_token = model.tokenizer.mask_token

    indie_vars_list = indie_vars.split(',')

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()

    text_segments, num_preds = prepare_text_for_masking(
        input_text, mask_token, male_gendered_tokens + female_gendered_tokens, split_key)

    male_pronoun_preds = []
    female_pronoun_preds = []
    for indie_var in indie_vars_list:

        target_text = f"{indie_var}".join(text_segments)
        mask_filled_text = model(target_text)
        # Quick hack: the fill-mask pipeline's return type depends on how many mask
        # tokens are in the text (a flat list of predictions for a single mask, a list
        # of lists for multiple masks), so normalize to the nested form here.
        if type(mask_filled_text[0]) is not list:
            mask_filled_text = [mask_filled_text]

        female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            female_gendered_tokens,
            num_preds
        ))
        male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
            mask_filled_text,
            male_gendered_tokens,
            num_preds
        ))

    if normalizing:
        total_gendered_probs = np.add(
            female_pronoun_preds, male_pronoun_preds)
        female_pronoun_preds = np.around(
            np.divide(female_pronoun_preds, total_gendered_probs+EPS)*100,
            decimals=DECIMAL_PLACES
        )
        male_pronoun_preds = np.around(
            np.divide(male_pronoun_preds, total_gendered_probs+EPS)*100,
            decimals=DECIMAL_PLACES
        )

    results_df = pd.DataFrame({'x-axis': indie_vars_list})
    results_df['female_pronouns'] = female_pronoun_preds
    results_df['male_pronouns'] = male_pronoun_preds
    female_fig = get_figure(results_df.drop(
        'male_pronouns', axis=1), 'female', n_fit,)
    male_fig = get_figure(results_df.drop(
        'female_pronouns', axis=1), 'male', n_fit,)
    display_text = f"{random.choice(indie_vars_list)}".join(text_segments)

    return (
        display_text,
        female_fig,
        male_fig,
        results_df,
    )
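

# For illustration only (kept commented out so it doesn't run in the deployed Space):
# a direct call that bypasses the Gradio UI, using the date example below. Note that
# `normalizing` arrives as 0/1 from the "D)" dropdown (type="index"), so 0 here means
# no normalization, and n_fit=1 requests a linear fit.
#
# predict_gender_pronouns(
#     "roberta-base", "", ', '.join(DATES), "DATE", 0, 1,
#     "She was a teenager in DATE.")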

# %%
title = "Causing Gender Pronouns"
description = """
## Intro
"""
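
# Each example list below is in the same order as the inputs wired up in the Blocks UI
# at the bottom of this file: [model_name, own_model_name, x_axis, place_holder,
# to_normalize, n_fit, input_text].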
date_example = [
    MODEL_NAMES[1],
    '',
    ', '.join(DATES),
    'DATE',
    "False",
    1,
    'She was a teenager in DATE.'
]

place_example = [
    MODEL_NAMES[0],
    '',
    ', '.join(PLACES),
    'PLACE',
    "False",
    1,
    'She became an adult in PLACE.'
]

subreddit_example = [
    MODEL_NAMES[3],
    '',
    ', '.join(SUBREDDITS),
    'SUBREDDIT',
    "False",
    1,
    'She was a kid. SUBREDDIT.'
]

own_model_example = [
    OWN_MODEL_NAME,
    'emilyalsentzer/Bio_ClinicalBERT',
    ', '.join(DATES),
    'DATE',
    "False",
    1,
    'She was exposed to the virus in DATE.'
]
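
# Zero-argument callbacks for the example buttons: each returns a full example row,
# which Gradio uses to pre-populate the input components listed in the .click() wiring
# at the bottom of this file.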
def date_fn():
    return date_example


def place_fn():
    return place_example


def reddit_fn():
    return subreddit_example


def your_fn():
    return own_model_example

# %%
demo = gr.Blocks()
with demo:
    gr.Markdown("# Spurious Correlation Evaluation for Pre-trained LLMs")
    gr.Markdown("Find spurious correlations between seemingly independent variables (for example between `gender` and `time`) in almost any BERT-like LLM on Hugging Face, below.")
    gr.Markdown("See how and why this happens in our paper, [Selection Bias Induced Spurious Correlations in Large Language Models](https://arxiv.org/pdf/2207.08982.pdf), presented at the [ICML 2022 Workshop on Spurious Correlations, Invariance, and Stability](https://sites.google.com/view/scis-workshop/home).")

    gr.Markdown("## Instructions for this Demo")
    gr.Markdown("1) Click on one of the examples below (where we sweep through a spectrum of `places`, `dates` and `subreddits`) to pre-populate the input fields.")
    gr.Markdown("2) Check out the pre-populated fields as you scroll down to the ['Hit Submit...'] button!")
    gr.Markdown("3) Repeat steps (1) and (2) with more pre-populated inputs or with your own values in the input fields!")

    gr.Markdown("## Example inputs")
    gr.Markdown("Click a button below to pre-populate input fields with example values. Then scroll down to Hit Submit to generate predictions.")
    with gr.Row():
        date_gen = gr.Button('Click for date example inputs')
        gr.Markdown("<-- x-axis sorted from older to more recent dates:")

        place_gen = gr.Button('Click for country example inputs')
        gr.Markdown(
            "<-- x-axis sorted by bottom 10 and top 10 [Global Gender Gap](https://www3.weforum.org/docs/WEF_GGGR_2021.pdf) ranked countries:")

        subreddit_gen = gr.Button('Click for Subreddit example inputs')
        gr.Markdown(
            "<-- x-axis sorted in order of increasing self-identified female participation (see [bburky](http://bburky.com/subredditgenderratios/)): ")

        your_gen = gr.Button('Add-a-model example inputs')
        gr.Markdown("<-- x-axis dates, with your own model loaded! (If this is your first visit, try another example first; it can take a while to load a new model.)")

    gr.Markdown("## Input fields")
    gr.Markdown(
        "A) Pick a spectrum of comma separated values for text injection and x-axis.")

    with gr.Row():
        x_axis = gr.Textbox(
            lines=3,
            label="A) Comma separated values for text injection and x-axis",
        )

    gr.Markdown("B) Pick a pre-loaded BERT-family model of interest on the right.")
    gr.Markdown(f"Or C) select `{OWN_MODEL_NAME}`, then add the name of any other Hugging Face model that supports the [fill-mask](https://huggingface.co/models?pipeline_tag=fill-mask) task on the right (note: this may take some time to load).")

    with gr.Row():
        model_name = gr.Radio(
            MODEL_NAMES + [OWN_MODEL_NAME],
            type="value",
            label="B) BERT-like model.",
        )
        own_model_name = gr.Textbox(
            label="C) If you selected an 'add-a-model' model, put any Hugging Face pipeline model name (that supports the fill-mask task) here.",
        )

    gr.Markdown("D) Pick whether you want the predictions normalized to these gendered terms only.")
    gr.Markdown("E) Also tell the demo what special token you will use in your input text, that you would like replaced with the spectrum of values you listed above.")
    gr.Markdown("And F) the degree of polynomial fit used for highlighting potential spurious associations.")

    with gr.Row():
        to_normalize = gr.Dropdown(
            ["False", "True"],
            label="D) Normalize model's predictions to only the gendered ones?",
            type="index",
        )
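        # With type="index", the callback receives 0 ("False") or 1 ("True"), so
        # `normalizing` in predict_gender_pronouns is truthy only when "True" is chosen.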
        place_holder = gr.Textbox(
            label="E) Special token place-holder",
        )
        n_fit = gr.Dropdown(
            list(range(1, 5)),
            label="F) Degree of polynomial fit",
            type="value",
        )

    gr.Markdown(
        "G) Finally, add input text that includes at least one gendered pronoun and the place-holder token specified above.")

    with gr.Row():
        input_text = gr.Textbox(
            lines=2,
            label="G) Input text with pronouns and place-holder token",
        )

    gr.Markdown("## Outputs!")
    #gr.Markdown("Scroll down and 'Hit Submit'!")
    with gr.Row():
        btn = gr.Button("Hit submit to generate predictions!")
    with gr.Row():
        sample_text = gr.Textbox(
            type="auto", label="Output text: Sample of text fed to model")
    with gr.Row():
        female_fig = gr.Plot(type="auto")
        male_fig = gr.Plot(type="auto")
    with gr.Row():
        df = gr.Dataframe(
            show_label=True,
            overflow_row_behaviour="show_ends",
            label="Table of softmax probability for pronouns predictions",
        )

    with gr.Row():
        date_gen.click(date_fn, inputs=[], outputs=[model_name, own_model_name,
                                                    x_axis, place_holder, to_normalize, n_fit, input_text])
        place_gen.click(place_fn, inputs=[], outputs=[
            model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
        subreddit_gen.click(reddit_fn, inputs=[], outputs=[
            model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])
        your_gen.click(your_fn, inputs=[], outputs=[
            model_name, own_model_name, x_axis, place_holder, to_normalize, n_fit, input_text])

        btn.click(
            predict_gender_pronouns,
            inputs=[model_name, own_model_name, x_axis, place_holder,
                    to_normalize, n_fit, input_text],
            outputs=[sample_text, female_fig, male_fig, df])

demo.launch(debug=True)

# %%