import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker

from demo.config import SAMPLE_SIZE, MAX_SAMPLE_SIZE, ALL_LMS, PRESELECTED_LMS, GRADIO_THEME
from demo.utils import (
    BANNER, FOOTER, CSS, UNSET,
    EmbeddingProgressTracker, compute_ratio,
    validate_dataset, preprocess_dataset, ensure_dataset_is_loaded
)

disable_caching()

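# The demo walks through three steps: load a downstream dataset from the Hugging
# Face Hub, select language models to compare, and rank them by transferability
# to the chosen task.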
with gr.Blocks(css=CSS, theme=GRADIO_THEME) as demo:
    gr.Markdown(BANNER)

    ##### 1. Load from datasets #####
    gr.Markdown("## Load Downstream Dataset")
    gr.Markdown(
        "Select a dataset from the Hugging Face Hub, such as `trec`. "
        "This defines your downstream task."
    )

    with gr.Group():
        dataset = gr.State(None)
        dataset_id = gr.Textbox(
            label="Dataset name",
            placeholder="try: trec, conll2003, ag_news",
            max_lines=1,
        )
        load_dataset_button = gr.Button(value="Load data", variant="primary", interactive=True)

        # Enable the load button only if the dataset exists on the Hub
        dataset_id.change(validate_dataset, inputs=dataset_id, outputs=load_dataset_button)

    gr.Markdown(
        "Settings are auto-configured. "
        "Adjust the downsampling ratio in Dataset Setup, "
        "or use the complete dataset with the [framework](https://github.com/flairNLP/transformer-ranker)."
    )

    ##### Data preprocessing #####
    with gr.Accordion("Dataset Setup", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_id_label = gr.Label("", label="Dataset")

            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Dataset size")
            num_samples.change(
                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
            )

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair")

        with gr.Row():
            label_column = gr.Dropdown("", label="Labels")
            task_category = gr.Dropdown("", label="Downstream Task")

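        # The slider selects an absolute number of samples; compute_ratio turns it
        # into a fraction of the loaded dataset and is re-run whenever either the
        # slider value or the dataset size changes.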
        with gr.Group():
            downsample_ratio = gr.State(0.0)
            sampling_rate = gr.Slider(
                20, MAX_SAMPLE_SIZE, label="Sampling rate", value=SAMPLE_SIZE, step=1
            )

            downsample_ratio_label = gr.Label("", label="Sampling rate")
            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )

            sampling_rate.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )
            num_samples.change(
                compute_ratio,
                inputs=[sampling_rate, num_samples],
                outputs=downsample_ratio,
            )

    # Load the dataset and show its details
    def load_hf_dataset(dataset_id):
        try:
            dataset = load_dataset(dataset_id, trust_remote_code=True)
            dataset_details = preprocess_dataset(dataset)
        except ValueError as e:
            # Dataset collections (multiple configs) are not supported here
            raise gr.Error("Collections are not supported. Load one dataset only.") from e

        return (
            gr.update(value="Loaded"),
            dataset_id,
            dataset,
            *dataset_details,
        )

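    # The outputs below must match the tuple returned by load_hf_dataset: the button
    # update, the dataset name, the dataset object, and then the fields presumably
    # produced by preprocess_dataset (task category, text columns, label column,
    # sample count).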
    load_dataset_button.click(
        load_hf_dataset,
        inputs=[dataset_id],
        outputs=[
            load_dataset_button,
            dataset_id_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )

    ########## 2. Select LMs ##########
    gr.Markdown("## Select Language Models")
    gr.Markdown(
        "Add two or more pretrained models for ranking. "
        "Go with small models since this demo runs on CPU."
    )

    with gr.Group():
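        # Show short model names in the checkbox labels, but pass full Hub handles
        # to the ranker.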
        model_options = [
            (model_handle.split("/")[-1], model_handle)
            for model_handle in ALL_LMS
        ]
        models = gr.CheckboxGroup(
            choices=model_options, label="Model List", value=PRESELECTED_LMS
        )

    ########## 3. Run ranking ##########
    gr.Markdown("## Rank Language Models")
    gr.Markdown(
        "Rank models by transferability to your downstream task. "
        "Adjust the metric and layer aggregation in Advanced Settings."
    )

    with gr.Group():
        submit_button = gr.Button("Run ranking", variant="primary", interactive=False)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_aggregator = gr.Dropdown(
                    choices=["lastlayer", "layermean", "bestlayer"],
                    label="Layer aggregation",
                    value="layermean",
                )

    # Enable the ranking button only after a dataset is loaded and its columns are set
    dataset.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )
    label_column.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )
    text_column.change(
        ensure_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )

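    # rank_models wires the UI state into TransformerRanker: it validates the column
    # selections, embeds the (downsampled) dataset with each selected model, and
    # returns rows of (rank, model, transferability score) for the results table.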
    def rank_models(
        dataset,
        downsample_ratio,
        selected_models,
        layer_aggregator,
        estimator,
        text_column,
        text_pair_column,
        label_column,
        task_category,
        progress=gr.Progress(),
    ):
        if text_column == UNSET:
            raise gr.Error("Text column is not set.")
        if label_column == UNSET:
            raise gr.Error("Label column is not set.")
        if task_category == UNSET:
            raise gr.Error(
                "Task category not set. Dataset must support classification or regression."
            )
        if text_pair_column == UNSET:
            text_pair_column = None

        progress(0.0, "Starting")
        with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
            try:
                ranker = TransformerRanker(
                    dataset,
                    dataset_downsample=downsample_ratio,
                    text_column=text_column,
                    text_pair_column=text_pair_column,
                    label_column=label_column,
                    task_category=task_category,
                )
                results = ranker.run(
                    models=selected_models,
                    layer_aggregator=layer_aggregator,
                    estimator=estimator,
                    batch_size=64,
                    tracker=tracker,
                )
                sorted_results = sorted(
                    results._results.items(), key=lambda item: item[1], reverse=True
                )
                return [
                    (i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
                ]
            except Exception as e:
                print(e)
                gr.Warning(f"Ranking issue: {e}")
                return []

| gr.Markdown("Ranking table → higher scores indicate better downstream performance.") | |
| ranking_results = gr.Dataframe( | |
| headers=["Rank", "Model", "Score"], | |
| datatype=["number", "str", "number"], | |
| value=[["-", "-", "-"]] | |
| ) | |
| submit_button.click( | |
| rank_models, | |
| inputs=[ | |
| dataset, | |
| downsample_ratio, | |
| models, | |
| layer_aggregator, | |
| estimator, | |
| text_column, | |
| text_pair_column, | |
| label_column, | |
| task_category, | |
| ], | |
| outputs=ranking_results, | |
| scroll_to_output=True, | |
| ) | |
    gr.Markdown(FOOTER)

if __name__ == "__main__":
    # run up to 3 requests at once
    demo.queue(default_concurrency_limit=3)
    # run with 6 worker threads
    demo.launch(max_threads=6)
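
# Note: the demo downsamples the dataset so it stays responsive on CPU. To rank
# models on a complete dataset, the transformer-ranker framework linked above
# (https://github.com/flairNLP/transformer-ranker) can be used directly. A minimal
# sketch, assuming the same TransformerRanker API used in rank_models; the model
# handles are just example small models from the Hub:
#
#     from datasets import load_dataset
#     from transformer_ranker import TransformerRanker
#
#     dataset = load_dataset("trec")
#     ranker = TransformerRanker(dataset)
#     results = ranker.run(models=["prajjwal1/bert-tiny", "google/electra-small-discriminator"])
#     print(results)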