# %%
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from transformers import pipeline

from winogender_sentences import get_sentences
OWN_MODEL_NAME = 'add-a-model'
PICK_YOUR_OWN_LABEL = 'pick-your-own'

MODEL_NAME_DICT = {
    "roberta-large": "RoBERTa-large",
    "bert-large-uncased": "BERT-large",
    "roberta-base": "RoBERTa-base",
    "bert-base-uncased": "BERT-base",
    OWN_MODEL_NAME: "Your model's",
}
MODEL_NAMES = list(MODEL_NAME_DICT.keys())

DECIMAL_PLACES = 1
EPS = 1e-5  # to avoid /0 errors
NUM_PTS_TO_AVERAGE = 2
# Example date constants
DATE_SPLIT_KEY = "DATE"
START_YEAR = 1901
STOP_YEAR = 2016
NUM_PTS = 30
DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
DATES = [f'{d}' for d in DATES]
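
# The dates above serve as the independent variable for the detection method:
# each date is prepended to a sentence as "In <year>: ..." so that
# date-induced spurious correlations in the gendered predictions can be measured.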

GENDERED_LIST = [
    ['he', 'she'],
    ['him', 'her'],
    ['his', 'hers'],
    ["himself", "herself"],
    ['male', 'female'],
    # ['man', 'woman'] Explicitly added in winogender extended sentences
    ['men', 'women'],
    ["husband", "wife"],
    ['father', 'mother'],
    ['boyfriend', 'girlfriend'],
    ['brother', 'sister'],
    ["actor", "actress"],
]
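
# Each pair above lists the male variant first and the female variant second;
# get_gendered_token_ids() below splits them into the two gender buckets.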

# %%
# Fire up the models
models = {m: pipeline("fill-mask", model=m)
          for m in MODEL_NAMES if m != OWN_MODEL_NAME}

# %%
# Get the winogender sentences
winogender_sentences = get_sentences()
occs = sorted({sentence_id.split('_')[0]
               for sentence_id in winogender_sentences})

# %%
def get_gendered_token_ids():
    # Note: despite the name, this returns the token strings themselves,
    # not vocabulary ids.
    male_gendered_tokens = [pair[0] for pair in GENDERED_LIST]
    female_gendered_tokens = [pair[1] for pair in GENDERED_LIST]
    return male_gendered_tokens, female_gendered_tokens
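
# For example, get_gendered_token_ids() returns
# (['he', 'him', ...], ['she', 'her', ...]).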

def get_winogender_texts(occ):
    return [text for sentence_id, text in winogender_sentences.items()
            if sentence_id.split('_')[0] == occ]

def display_input_texts(occ, alt_text):
    if occ == PICK_YOUR_OWN_LABEL:
        texts = alt_text.split('\n')
    else:
        texts = get_winogender_texts(occ)

    display_texts = [f"{i+1}) {text}" for (i, text) in enumerate(texts)]
    return "\n".join(display_texts), texts

def get_avg_prob_from_pipeline_outputs(pipeline_preds, gendered_tokens, num_preds):
    pronoun_preds = [
        sum([
            pronoun["score"] if pronoun["token_str"].strip().lower() in gendered_tokens
            else 0.0
            for pronoun in top_preds
        ])
        for top_preds in pipeline_preds
    ]
    return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)


def is_top_pred_gendered(pipeline_preds, gendered_tokens):
    return pipeline_preds[0][0]['token_str'].strip().lower() in gendered_tokens
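
# Note: get_avg_prob_from_pipeline_outputs() sums the probability mass the model
# assigns to the given gendered tokens (as a percentage), while
# is_top_pred_gendered() checks only the single highest-scoring prediction.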

# %%
def get_figure(df, model_name, occ):
    xs = df[df.columns[0]]
    ys = df[df.columns[1]]
    fig, ax = plt.subplots()
    ax.bar(xs, ys)
    ax.axis('tight')
    ax.set_xlabel("Sentence number")
    ax.set_ylabel("Specification Metric")
    ax.set_title(
        f"Task Specification Metric on {MODEL_NAME_DICT[model_name]} for '{occ}' sentences")
    return fig

# %%
def predict_gender_pronouns(
    model_name,
    own_model_name,
    texts,
    occ,
):
    """Run inference on the input texts for the selected model type,
    returning Task Specification Metric results."""
    # TODO: make these selectable by user
    indie_vars = ', '.join(DATES)
    num_ave = NUM_PTS_TO_AVERAGE

    # For debugging
    print('input_texts', texts)

    if model_name is None or model_name == '':
        model_name = MODEL_NAMES[0]
        model = models[model_name]
    elif model_name == OWN_MODEL_NAME:
        model = pipeline("fill-mask", model=own_model_name)
    else:
        model = models[model_name]

    mask_token = model.tokenizer.mask_token
    # Strip the whitespace left over from the ', ' join above
    indie_vars_list = [v.strip() for v in indie_vars.split(',')]

    male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()
    masked_texts = [text.replace('MASK', mask_token) for text in texts]
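
    # Hypothetical example of the replacement above: an input sentence like
    # "The doctor told the patient MASK would call." becomes
    # "The doctor told the patient <mask> would call." for a RoBERTa tokenizer.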

    all_uncertainty_f = {}
    for i, text in enumerate(masked_texts):
        female_pronoun_preds = []
        male_pronoun_preds = []
        top_pred_gendered = True  # Assume true unless told otherwise
        print(f"{i+1}) {text}")
        for indie_var in indie_vars_list[:num_ave] + indie_vars_list[-num_ave:]:
            target_text = f"In {indie_var}: {text}"
            pipeline_preds = model(target_text)
            # The pipeline's return type depends on how many mask tokens are
            # in the text, so normalize single-mask results to a nested list
            if not isinstance(pipeline_preds[0], list):
                pipeline_preds = [pipeline_preds]
            # If the top prediction is not gendered, record it as such
            if not is_top_pred_gendered(pipeline_preds, female_gendered_tokens + male_gendered_tokens):
                top_pred_gendered = False
            num_preds = 1  # By design
            female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                pipeline_preds,
                female_gendered_tokens,
                num_preds
            ))
            male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
                pipeline_preds,
                male_gendered_tokens,
                num_preds
            ))

        # Normalize by the total gendered probability mass
        total_gendered_probs = np.add(female_pronoun_preds, male_pronoun_preds)
        norm_female_pronoun_preds = np.around(
            np.divide(female_pronoun_preds, total_gendered_probs + EPS) * 100,
            decimals=DECIMAL_PLACES
        )

        sent_idx = f"{i+1}" if top_pred_gendered else f"{i+1}*"
        # Specification metric: absolute difference between the summed
        # normalized female probabilities at the latest vs. earliest dates,
        # averaged over num_ave points. Near-zero means the predictions are
        # date-invariant, as expected for a well-specified task.
        all_uncertainty_f[sent_idx] = round(
            abs(sum(norm_female_pronoun_preds[-num_ave:])
                - sum(norm_female_pronoun_preds[:num_ave])) / num_ave,
            DECIMAL_PLACES
        )
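
    # All sentences scored; assemble the per-sentence specification metric
    # values into a DataFrame for display and plotting.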
    uncertain_df = pd.DataFrame.from_dict(
        all_uncertainty_f, orient='index', columns=['Specification Metric'])
    uncertain_df = uncertain_df.reset_index().rename(
        columns={'index': 'Sentence number'})

    return (
        target_text,
        uncertain_df,
        get_figure(uncertain_df, model_name, occ),
    )
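
# A minimal sketch of exercising the pipeline outside the Gradio UI, assuming
# 'doctor' is one of the occupations in the loaded Winogender sentences:
#   _, texts = display_input_texts('doctor', '')
#   sample_text, metric_df, fig = predict_gender_pronouns('roberta-base', '', texts, 'doctor')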

demo = gr.Blocks()
with demo:
    input_texts = gr.Variable([])
| gr.Markdown("**Detect Task Specification at Inference-time.**") | |
| gr.Markdown("""This method exploits the specification-induced spurious correlations demonstrated in this | |
| [Spurious Correlations Hugging Face Space](https://huggingface.co/spaces/anonymousauthorsanonymous/spurious) to detect task specification at inference-time. | |
| For this method, well-specified tasks should have a lower specification metric value, and unspecified tasks should have a higher specification metric value. | |
| """) | |
| gr.Markdown("""As an example, see the figure below with test sentences from the [Winogender schema](https://aclanthology.org/N18-2002/) for the occupation of `Doctor`. | |
| With a close read, you can see that only sentence numbers (3) and (4) are well-specified for the gendered pronoun resolution task: | |
| the masked pronoun is coreferent with the `man` or `woman`; the remainder are unspecfied: the masked pronoun is coreferent with a gender-unspecified person. | |
| In this example we have 100\% accurate detection with the specification metric near zero for only sentence (3) and (4). | |
| <p align="center"> | |
| <img src="file/spec_metric_result.png" alt="results" width="500"/> | |
| </p> | |
| """) | |
| gr.Markdown("**To test this for yourself, follow the numbered steps below to test one of the pre-loaded options.** Once you get the hang of it, you can load a new model and/or provide your own input texts.") | |
| gr.Markdown(f"""1) Pick a preloaded BERT-like model. | |
| *Note: RoBERTa-large performance is best.* | |
| 2) Pick an Occupation type from the Winogender Schemas evaluation set. | |
| *Or select '{PICK_YOUR_OWN_LABEL}' (it need not be about an occupation).* | |
| 3) Click the first button to load input texts. | |
| *Read the sentences to determine which two are well-specified for gendered pronoun coreference resolution. The rest are gender-unspecified.* | |
| 4) Click the second button to get Task Specification Metric results. | |
| """) | |

    with gr.Row():
        model_name = gr.Radio(
            MODEL_NAMES,
            type="value",
            label="1) Pick a preloaded BERT-like model (note: RoBERTa-large performs best).",
        )
        own_model_name = gr.Textbox(
            label=f"...Or, if you selected '{OWN_MODEL_NAME}', enter the name of any Hugging Face model that supports the `fill-mask` task (see the list at https://huggingface.co/models?pipeline_tag=fill-mask).",
        )
    with gr.Row():
        occ_box = gr.Radio(
            occs + [PICK_YOUR_OWN_LABEL],
            label=f"2) Pick an Occupation type from the Winogender Schemas evaluation set, or select '{PICK_YOUR_OWN_LABEL}' (it need not be about an occupation).")
    with gr.Row():
        alt_input_texts = gr.Textbox(
            lines=2,
            label=f"...Or, if you selected '{PICK_YOUR_OWN_LABEL}' above, add your own newline-delimited sentences here. Be sure each includes a single MASKed-out pronoun. If unsure of the required format, click an occupation above instead to see some example input texts.",
        )
    with gr.Row():
        get_text_btn = gr.Button("3) Click to load input texts.")

    get_text_btn.click(
        fn=display_input_texts,
        inputs=[occ_box, alt_input_texts],
        outputs=[gr.Textbox(
            label='Numbered sentences for evaluation. The numbers below correspond to the numbers on the x-axis of the plot.'),
            input_texts],
    )
    with gr.Row():
        uncertain_btn = gr.Button("4) Click to get Task Specification Metric results!")
        gr.Markdown(
            """We expect a lower specification metric value for well-specified tasks.
Note: If there is an * by a sentence number, then at least one top prediction for that sentence was non-gendered.""")

    with gr.Row():
        female_fig = gr.Plot(type="auto")
    with gr.Row():
        female_df = gr.Dataframe()
    with gr.Row():
        display_text = gr.Textbox(
            type="text", label="Sample of text fed to model")

    uncertain_btn.click(
        fn=predict_gender_pronouns,
        inputs=[model_name, own_model_name, input_texts, occ_box],
        # inputs=date_example,
        outputs=[display_text, female_df, female_fig]
    )

demo.launch(debug=True)
# %%