import json
import os
import re
import sys

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import plotnine as p9
import seaborn as sns
import yaml
from huggingface_hub import HfApi

sys.path.append('./src')
sys.path.append('.')

repo_id = "HUBioDataLab/PROBE"
api = HfApi()

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
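
# Names used below but not defined in this file (e.g. get_baseline_df,
# benchmark_plot, update_metric_choices, benchmark_specific_metrics,
# download_from_hub, save_results, TASK_INFO and the *_options choice lists)
# are assumed to be provided by the star imports above.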
# ------------------------------------------------------------------
# Helper functions are defined at module level so that the UI
# callbacks below can see them.
# ------------------------------------------------------------------
def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    revision_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate inputs, run the evaluation and (optionally) save the results."""
    if any(task in benchmark_types for task in ['similarity', 'family', 'function']) and human_file is None:
        gr.Warning("Human representations are required for the similarity, family, and function benchmarks!")
        return -1
    if 'affinity' in benchmark_types and skempi_file is None:
        gr.Warning("SKEMPI representations are required for the affinity benchmark!")
        return -1
    gr.Info("Your submission is being processed…")
    representation_name = model_name_textbox if revision_name_textbox == '' else revision_name_textbox
    try:
        results = run_probe(
            benchmark_types,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
    except Exception:
        gr.Warning("Your submission could not be processed. Please check your representation files!")
        return -1
    if save:
        save_results(representation_name, benchmark_types, results)
        gr.Info("Your submission has been processed and the results are saved!")
    else:
        gr.Info("Your submission has been processed!")
    return 0
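
# NOTE: add_new_eval returns 0 on success and -1 on failure; the submit
# callback below registers no outputs, so user feedback is delivered through
# the gr.Info / gr.Warning toasts rather than the return value.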
def refresh_data():
    """Restart the space and pull fresh leaderboard CSVs from the HF Hub."""
    api.restart_space(repo_id=repo_id)
    benchmark_types = ["similarity", "function", "family", "affinity", "leaderboard"]
    for benchmark_type in benchmark_types:
        path = f"/tmp/{benchmark_type}_results.csv"
        if os.path.exists(path):
            os.remove(path)
    benchmark_types.remove("leaderboard")
    download_from_hub(benchmark_types)
    # The Refresh button wires this callback to the leaderboard table, so the
    # rebuilt, unfiltered dataframe must be returned for that output.
    return get_baseline_df(None, None)

# ------- Leaderboard helpers -------------------------------------------------
def update_metrics(selected_benchmarks):
    """Collect the metrics that belong to the chosen benchmark types."""
    updated_metrics = set()
    for benchmark in selected_benchmarks:
        updated_metrics.update(benchmark_metric_mapping.get(benchmark, []))
    return list(updated_metrics)
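
# Returning a plain list from update_metrics sets the selected *value* of the
# metric CheckboxGroup; because .change also fires on programmatic updates,
# this in turn refreshes the leaderboard table with the new metric selection.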

def update_leaderboard(selected_methods, selected_metrics):
    updated_df = get_baseline_df(selected_methods, selected_metrics)
    return updated_df

# ------- Visualisation helpers ----------------------------------------------
def get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric):
    """Return a short natural-language explanation for the produced plot."""
    if benchmark_type == "similarity":
        return (
            f"The scatter plot compares models on **{x_metric}** (x-axis) and "
            f"**{y_metric}** (y-axis). Points further to the upper right indicate better "
            "performance on both metrics."
        )
    elif benchmark_type == "function":
        return (
            f"The heatmap shows the performance of each model (columns) across GO terms "
            f"for the **{aspect.upper()}** aspect using the **{single_metric}** metric. "
            "Darker squares correspond to stronger performance; hierarchical clustering "
            "groups similar models and tasks together."
        )
    elif benchmark_type == "family":
        return (
            f"The horizontal box plots summarise cross-validation performance on the "
            f"**{dataset}** dataset. Higher median MCC values indicate better family-"
            "classification accuracy."
        )
    elif benchmark_type == "affinity":
        return (
            f"Each box plot shows the distribution of **{single_metric}** scores for every "
            "model when predicting binding affinity changes. Higher values are better."
        )
    return ""

def generate_plot_and_explanation(
    benchmark_type,
    methods_selected,
    x_metric,
    y_metric,
    aspect,
    dataset,
    single_metric,
):
    """Callback wrapper that returns both the plot image and a textual explanation."""
    plot_path = benchmark_plot(
        benchmark_type,
        methods_selected,
        x_metric,
        y_metric,
        aspect,
        dataset,
        single_metric,
    )
    explanation = get_plot_explanation(benchmark_type, x_metric, y_metric, aspect, dataset, single_metric)
    # plot_explanation is created with visible=False, so the component has to
    # be revealed when its text is filled in.
    return plot_path, gr.update(value=explanation, visible=True)

# ------------------------------------------------------------------
# UI definition
# ------------------------------------------------------------------
block = gr.Blocks()

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣ Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            leaderboard = get_baseline_df(None, None)  # baseline leaderboard without filtering
            method_names = leaderboard['Method'].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            metric_names.remove('Method')  # the only non-metric column
            benchmark_metric_mapping = {
                "similarity": [m for m in metric_names if m.startswith('sim_')],
                "function": [m for m in metric_names if m.startswith('func')],
                "family": [m for m in metric_names if m.startswith('fam_')],
                "affinity": [m for m in metric_names if m.startswith('aff_')],
            }
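            # Metrics are assigned to benchmark types purely by column-name
            # prefix ('sim_', 'func', 'fam_', 'aff_'), so new result columns
            # that follow this convention are picked up automatically.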
            # selectors -----------------------------------------------------
            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods for the Leaderboard",
                value=method_names,
                interactive=True,
            )
            benchmark_type_selector_lb = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True,
            )
            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics for the Leaderboard",
                value=None,
                interactive=True,
            )
            # leaderboard table ---------------------------------------------
            baseline_value = get_baseline_df(method_names, metric_names)
            baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
            baseline_header = ["Method"] + metric_names
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)
            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.Dataframe(
                    value=baseline_value,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    visible=True,
                )
            # callbacks -----------------------------------------------------
            leaderboard_method_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )
            benchmark_type_selector_lb.change(
                update_metrics,
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )
            leaderboard_metric_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )
        # ------------------------------------------------------------------
        # 2️⃣ Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
            # Intro / instructions
            gr.Markdown(
                """
                ## **Interactive Visualizations**

                Select a benchmark type first; context-specific options will appear automatically.
                Once your parameters are set, click **Plot** to generate the figure.

                **How to read the plots**

                * **Similarity (scatter)** – Each point is a model. Points nearer the top right perform well on both chosen similarity metrics.
                * **Function prediction (heatmap)** – Darker squares denote better scores. Rows/columns are clustered to reveal shared structure.
                * **Family / Affinity (box plots)** – Boxes summarise the distribution across folds/targets. Higher medians indicate stronger performance.
                """,
                elem_classes="markdown-text",
            )
            # ------------------------------------------------------------------
            # selectors specific to visualisation
            # ------------------------------------------------------------------
            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Select Benchmark Type",
                value=None,
            )
            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="Select X-axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Select Y-axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Select Aspect Type", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Select Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Select Metric", visible=False)
            vis_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select methods to visualize",
                interactive=True,
                value=method_names,
            )
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
            # textual explanation below the image
            plot_explanation = gr.Markdown(visible=False)
            # ------------------------------------------------------------------
            # callbacks for visualisation tab
            # ------------------------------------------------------------------
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )
            plot_button.click(
                generate_plot_and_explanation,
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output, plot_explanation],
            )
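            # update_metric_choices (imported from src.vis_utils) is assumed to
            # return one update per selector above, adjusting each dropdown's
            # choices and visibility to fit the selected benchmark type.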
        # ------------------------------------------------------------------
        # 3️⃣ About tab
        # ------------------------------------------------------------------
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )
        # ------------------------------------------------------------------
        # 4️⃣ Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("🚀 Submit here!", elem_id="probe-benchmark-tab-table", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    revision_name_textbox = gr.Textbox(label="Revision Method Name")
                    benchmark_types = gr.CheckboxGroup(
                        choices=TASK_INFO,
                        label="Benchmark Types",
                        interactive=True,
                    )
                    similarity_tasks = gr.CheckboxGroup(
                        choices=similarity_tasks_options,
                        label="Similarity Tasks",
                        interactive=True,
                    )
                    function_prediction_aspect = gr.Radio(
                        choices=function_prediction_aspect_options,
                        label="Function Prediction Aspects",
                        interactive=True,
                    )
                    family_prediction_dataset = gr.CheckboxGroup(
                        choices=family_prediction_dataset_options,
                        label="Family Prediction Datasets",
                        interactive=True,
                    )
                    function_dataset = gr.Textbox(
                        label="Function Prediction Datasets",
                        visible=False,
                        value="All_Data_Sets",
                    )
                    save_checkbox = gr.Checkbox(
                        label="Save results for leaderboard and visualization",
                        value=True,
                    )
            with gr.Row():
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')
            submit_button = gr.Button("Submit Eval")
            submission_result = gr.Markdown()
            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    revision_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )
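            # No outputs are registered for this click: add_new_eval reports
            # progress and errors via gr.Info / gr.Warning toasts, and
            # submission_result above remains an empty placeholder.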
    # ----------------------------------------------------------------------
    # global refresh button & citation accordion
    # ----------------------------------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])
    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )

# -----------------------------------------------------------------------------
block.launch()