Spaces:
Sleeping
Sleeping
| import math | |
| import gradio as gr | |
| from datasets import concatenate_datasets | |
| from huggingface_hub import HfApi | |
| from huggingface_hub.errors import HFValidationError | |
| from requests.exceptions import HTTPError | |
| from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory | |
| from transformer_ranker.embedder import Embedder | |
# HTML markup for the page header: centered title, a short description of the
# tool, a row of badge links (code repo, license, package, tutorials), and an
# attribution line.
BANNER = """
<h1 align="center">🔥 TransformerRanker 🔥</h1>
<p align="center" style="max-width: 560px; margin: auto;">
Find the best language model for your downstream task.
Load a dataset, select models from the 🤗 Hub, and rank them by <strong>transferability</strong>.
</p>
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
<a href="https://github.com/flairNLP/transformer-ranker">
<img src="https://img.shields.io/badge/Code Repo-black?style=flat&logo=github" alt="repository">
</a>
<a href="https://opensource.org/licenses/MIT">
<img src="https://img.shields.io/badge/License-MIT-brightgreen?style=flat" alt="license">
</a>
<a href="https://pypi.org/project/transformer-ranker/">
<img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="package">
</a>
<a href="https://github.com/flairNLP/transformer-ranker/blob/main/docs/01-walkthrough.md">
<img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="tutorials">
</a>
</p>
<p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
"""
# Markdown-formatted footer text: demo note, authors, and a link to the issue tracker.
FOOTER = """
**Note:** CPU-only quick demo. **Built by:** @lukasgarbas & @plonerma
**Questions?** Open a [GitHub issue](https://github.com/flairNLP/transformer-ranker/issues) 🔫.
"""
# Stylesheet snippet: caps the `.gradio-container` width at 800px and centers it.
CSS = """
.gradio-container {
max-width: 800px;
margin: auto;
}
"""
# Placeholder value meaning "nothing selected / not detected" for the column
# and task-category settings handled in this module.
UNSET = "-"
# Shared Hub client; used in this module to look up dataset info.
hf_api = HfApi()
# Shared DatasetCleaner; used in this module to detect the text/label columns
# and the task category of a loaded dataset.
preprocessing = DatasetCleaner()
def validate_dataset(dataset_name):
    """Return a Gradio update reflecting whether ``dataset_name`` is on the Hub.

    If the dataset-info lookup succeeds, the update makes the target component
    interactive. If the lookup raises ``HTTPError`` or ``HFValidationError``,
    the update resets the component's value to "Load data" and disables it.
    """
    try:
        # Lightweight existence check; the returned info itself is not needed.
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        return gr.update(value="Load data", interactive=False)
    else:
        return gr.update(interactive=True)
def preprocess_dataset(dataset):
    """Auto-detect the text/label columns and task category of a dataset.

    All splits in ``dataset`` (a mapping of split name to dataset) are
    concatenated, then the shared ``preprocessing`` cleaner is asked for the
    text column, the label column and, if a label column was found, the task
    category. Each failed detection (``ValueError``) shows a Gradio warning
    and leaves the corresponding setting at ``UNSET`` so it can be chosen
    manually.

    Returns:
        A 5-tuple ``(task_category, text_column, text_pair, label_column,
        sample_size)``. The first four are ``gr.update`` objects that set the
        value and choices of the corresponding component and make it
        interactive; ``sample_size`` is the number of rows across all splits.
    """
    data = concatenate_datasets(list(dataset.values()))

    try:
        text_column = preprocessing._find_column(data, "text column")
    except ValueError:
        gr.Warning("Text column not auto-detected — select in settings.")
        text_column = UNSET

    try:
        label_column = preprocessing._find_column(data, "label column")
    except ValueError:
        gr.Warning("Label column not auto-detected — select in settings.")
        label_column = UNSET

    # The task category is derived from the label column, so detection is
    # skipped when no label column is known.
    task_category = UNSET
    if label_column != UNSET:
        try:
            task_category = preprocessing._find_task_category(data, label_column)
        except ValueError:
            gr.Warning("Task category not auto-detected — framework supports classification, regression.")

    columns = data.column_names
    task_category_update = gr.update(
        value=task_category,
        choices=[str(t) for t in TaskCategory],
        interactive=True,
    )
    text_column_update = gr.update(value=text_column, choices=columns, interactive=True)
    # The second text column is optional, so UNSET is offered as a choice.
    text_pair_update = gr.update(value=UNSET, choices=[UNSET, *columns], interactive=True)
    label_column_update = gr.update(value=label_column, choices=columns, interactive=True)
    sample_size = len(data)

    return (
        task_category_update,
        text_column_update,
        text_pair_update,
        label_column_update,
        sample_size,
    )
def compute_ratio(num_samples_to_use, num_samples):
    """Return ``num_samples_to_use / num_samples``, or 0.0 if ``num_samples`` is not positive."""
    return num_samples_to_use / num_samples if num_samples > 0 else 0.0
def ensure_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Return a Gradio update that is interactive only when everything is set.

    "Everything" means a truthy ``dataset`` plus a text column, label column
    and task category that all differ from ``UNSET``.
    """
    required_settings = (text_column, label_column, task_category)
    ready = bool(dataset and all(setting != UNSET for setting in required_settings))
    return gr.update(interactive=ready)
def ensure_one_lm_selected(checkbox_values, previous_values):
    """Keep the previous selection if the new one has no truthy entry.

    Returns ``checkbox_values`` when at least one of its items is truthy,
    otherwise ``previous_values``.
    """
    return checkbox_values if any(checkbox_values) else previous_values
# Monkey patches on Embedder so that embedding reports progress to a tracker.
_old_embed = Embedder.embed


def _new_embed(embedder, sentences, batch_size: int = 32, **kwargs):
    """Announce the number of batches to the tracker, then run the original ``embed``.

    If the embedder has a tracker, it is told how many batches
    (``ceil(len(sentences) / batch_size)``) the upcoming call will process.
    All arguments are forwarded to the original method unchanged.
    """
    tracker = embedder.tracker
    if tracker is not None:
        num_batches = math.ceil(len(sentences) / batch_size)
        tracker.update_num_batches(num_batches)
    return _old_embed(embedder, sentences, batch_size=batch_size, **kwargs)


Embedder.embed = _new_embed
_old_embed_batch = Embedder.embed_batch


def _new_embed_batch(embedder, *args, **kwargs):
    """Run the original ``embed_batch`` and then report one completed batch.

    The tracker (if the embedder has one) is notified only after the original
    call has returned; the original return value is passed through unchanged.
    """
    result = _old_embed_batch(embedder, *args, **kwargs)
    tracker = embedder.tracker
    if tracker is not None:
        tracker.update_batch_complete()
    return result


Embedder.embed_batch = _new_embed_batch
_old_init = Embedder.__init__


def _new_init(embedder, *args, tracker=None, **kwargs):
    """Initialise the embedder as usual and attach an optional progress tracker.

    ``tracker`` is consumed here and not forwarded to the original
    ``__init__``; every other argument is passed through unchanged. The
    tracker (or ``None``) is stored on the instance as ``tracker``.
    """
    _old_init(embedder, *args, **kwargs)
    embedder.tracker = tracker


Embedder.__init__ = _new_init
class EmbeddingProgressTracker:
    """Map per-model, per-batch embedding progress onto one progress bar.

    Overall progress is ``(models finished + fraction of current model) /
    number of models``. ``update_num_batches`` marks the start of the next
    model and ``update_batch_complete`` marks one finished batch of the
    current model. Used as a context manager, the bar is set to "Done" or
    "Error" on exit; exceptions are never suppressed.

    Args:
        progress: Callable progress reporter, invoked as
            ``progress(value, desc=...)``. If ``None``, a fresh
            ``gr.Progress`` is created on ``__enter__``.
        model_names: Names of the models, in the order they are processed.
    """

    def __init__(self, *, progress, model_names):
        self.model_names = model_names
        self.progress_bar = progress
        # Counters are initialised here too, so the update methods also work
        # when the tracker is used without a ``with`` block.
        self._reset_counters()

    def _reset_counters(self):
        """Reset the state to "no model started yet"."""
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None

    def total(self):
        """Return the number of models being tracked."""
        return len(self.model_names)

    def __enter__(self):
        # Keep the progress object the caller passed in; previously it was
        # always overwritten here. NOTE(review): Gradio documents gr.Progress
        # as something injected via an event handler's default argument, so a
        # locally created instance may not drive the UI - confirm. It is kept
        # only as a fallback when no progress object was supplied.
        if self.progress_bar is None:
            self.progress_bar = gr.Progress(track_tqdm=False)
        self._reset_counters()
        return self

    def __exit__(self, typ, value, tb):
        if typ is None:
            self.progress_bar(1.0, desc="Done")
        else:
            self.progress_bar(1.0, desc="Error")
        # Do not suppress any errors
        return False

    def update_num_batches(self, total):
        """Start tracking the next model, which will process ``total`` batches."""
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        """Record one finished batch of the current model."""
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        """Push the current overall progress and description to the bar."""
        # ``total`` is a method: it must be called, otherwise the division
        # below raises TypeError and the label shows a bound-method repr.
        num_models = self.total()
        if num_models == 0:
            return  # nothing to report on
        # Clamp the index so an unexpected extra (or missing) start signal
        # cannot raise IndexError or produce negative progress.
        i = min(max(self.current_model, 0), num_models - 1)
        description = f"Running {self.model_names[i]} ({i + 1} / {num_models})"
        progress = i / num_models
        # Truthiness check skips both None (unknown) and 0 (empty input),
        # the latter of which would otherwise divide by zero.
        if self.batches_total:
            fraction_done = min(self.batches_complete / self.batches_total, 1.0)
            progress += fraction_done / num_models
        self.progress_bar(progress=progress, desc=description)