import types
from contextlib import contextmanager
from dataclasses import dataclass, field
from time import monotonic_ns
from typing import Any, Dict, List, NamedTuple, Optional, Tuple

from datasets import Dataset, DatasetDict, load_dataset
from sentence_transformers import losses
from transformers.utils import copy_func

from .data import create_fewshot_splits, create_fewshot_splits_multilabel
from .losses import SupConLoss

SEC_TO_NS_SCALE = 1000000000

DEV_DATASET_TO_METRIC = {
    "sst2": "accuracy",
    "imdb": "accuracy",
    "subj": "accuracy",
    "bbc-news": "accuracy",
    "enron_spam": "accuracy",
    "student-question-categories": "accuracy",
    "TREC-QC": "accuracy",
    "toxic_conversations": "matthews_correlation",
}

TEST_DATASET_TO_METRIC = {
    "emotion": "accuracy",
    "SentEval-CR": "accuracy",
    "sst5": "accuracy",
    "ag_news": "accuracy",
    "enron_spam": "accuracy",
    "amazon_counterfactual_en": "matthews_correlation",
}

MULTILINGUAL_DATASET_TO_METRIC = {
    f"amazon_reviews_multi_{lang}": "mae" for lang in ["en", "de", "es", "fr", "ja", "zh"]
}

LOSS_NAME_TO_CLASS = {
    "CosineSimilarityLoss": losses.CosineSimilarityLoss,
    "ContrastiveLoss": losses.ContrastiveLoss,
    "OnlineContrastiveLoss": losses.OnlineContrastiveLoss,
    "BatchSemiHardTripletLoss": losses.BatchSemiHardTripletLoss,
    "BatchAllTripletLoss": losses.BatchAllTripletLoss,
    "BatchHardTripletLoss": losses.BatchHardTripletLoss,
    "BatchHardSoftMarginTripletLoss": losses.BatchHardSoftMarginTripletLoss,
    "SupConLoss": SupConLoss,
}
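
# Illustrative lookup (a minimal sketch, not part of the original module): experiment configs
# typically refer to a loss by name, and this mapping resolves that name to the loss class.
# `sentence_transformer_model` below is an assumed, pre-loaded SentenceTransformer instance.
#
#     loss_class = LOSS_NAME_TO_CLASS["CosineSimilarityLoss"]
#     loss = loss_class(model=sentence_transformer_model)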


def default_hp_space_optuna(trial) -> Dict[str, Any]:
    from transformers.integrations import is_optuna_available

    assert is_optuna_available(), "This function needs Optuna installed: `pip install optuna`"
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_epochs": trial.suggest_int("num_epochs", 1, 5),
        "num_iterations": trial.suggest_categorical("num_iterations", [5, 10, 20]),
        "seed": trial.suggest_int("seed", 1, 40),
        "batch_size": trial.suggest_categorical("batch_size", [4, 8, 16, 32, 64]),
    }
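
# Illustrative usage sketch (assumes Optuna is installed; everything except
# `default_hp_space_optuna` below is an example name, not part of this module):
#
#     import optuna
#
#     def objective(trial):
#         params = default_hp_space_optuna(trial)
#         # ... train and evaluate a model with `params`, then return the metric to maximize ...
#         return 0.0
#
#     study = optuna.create_study(direction="maximize")
#     study.optimize(objective, n_trials=20)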


def load_data_splits(
    dataset: str, sample_sizes: List[int], add_data_augmentation: bool = False
) -> Tuple[DatasetDict, Dataset]:
    """Loads a dataset from the Hugging Face Hub and returns the test split and few-shot training splits."""
    print(f"\n\n\n============== {dataset} ============")
    # Load one of the SetFit training sets from the Hugging Face Hub
    train_split = load_dataset(f"SetFit/{dataset}", split="train")
    train_splits = create_fewshot_splits(train_split, sample_sizes, add_data_augmentation, f"SetFit/{dataset}")
    test_split = load_dataset(f"SetFit/{dataset}", split="test")
    print(f"Test set: {len(test_split)}")
    return train_splits, test_split
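
# Illustrative usage sketch (the dataset name and sample sizes are examples only):
#
#     train_splits, test_split = load_data_splits("sst2", sample_sizes=[8, 64])
#     for split_name, split in train_splits.items():
#         print(split_name, len(split))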


def load_data_splits_multilabel(dataset: str, sample_sizes: List[int]) -> Tuple[DatasetDict, Dataset]:
    """Loads a multilabel dataset from the Hugging Face Hub and returns the test split and few-shot training splits."""
    print(f"\n\n\n============== {dataset} ============")
    # Load the multilabel configuration of one of the SetFit training sets from the Hugging Face Hub
    train_split = load_dataset(f"SetFit/{dataset}", "multilabel", split="train")
    train_splits = create_fewshot_splits_multilabel(train_split, sample_sizes)
    test_split = load_dataset(f"SetFit/{dataset}", "multilabel", split="test")
    print(f"Test set: {len(test_split)}")
    return train_splits, test_split


@dataclass
class Benchmark:
    """
    Performs simple benchmarks of code portions (measures elapsed time).

    Typical usage example:

        bench = Benchmark()
        with bench.track("Foo function"):
            foo()
        with bench.track("Bar function"):
            bar()
        bench.summary()
    """

    out_path: Optional[str] = None
    summary_msg: str = field(default_factory=str)

    def print(self, msg: str) -> None:
        """
        Prints to system out and optionally to the specified out_path.
        """
        print(msg)
        if self.out_path is not None:
            with open(self.out_path, "a+") as f:
                f.write(msg + "\n")

    @contextmanager
    def track(self, step):
        """
        Measures the elapsed time for the given code context.
        """
        start = monotonic_ns()
        yield
        ns = monotonic_ns() - start
        msg = f"\n{'*' * 70}\n'{step}' took {ns / SEC_TO_NS_SCALE:.3f}s ({ns:,}ns)\n{'*' * 70}\n"
        print(msg)
        self.summary_msg += msg + "\n"

    def summary(self) -> None:
        """
        Prints a summary of all benchmarks performed.
        """
        self.print(f"\n{'#' * 30}\nBenchmark Summary:\n{'#' * 30}\n\n{self.summary_msg}")


class BestRun(NamedTuple):
    """
    The best run found by a hyperparameter search (see [`~Trainer.hyperparameter_search`]).

    Parameters:
        run_id (`str`):
            The id of the best run.
        objective (`float`):
            The objective that was obtained for this run.
        hyperparameters (`Dict[str, Any]`):
            The hyperparameters picked to get this run.
        backend (`Any`):
            The relevant internal object used for optimization. For Optuna, this is the `study` object.
    """

    run_id: str
    objective: float
    hyperparameters: Dict[str, Any]
    backend: Any = None
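
# Illustrative field access (a sketch; in practice `best_run` is returned by a hyperparameter
# search, and the values below are made up for the example):
#
#     best_run = BestRun(run_id="0", objective=0.91, hyperparameters={"learning_rate": 2e-5})
#     print(best_run.run_id, best_run.objective, best_run.hyperparameters)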


def set_docstring(method, docstring, cls=None):
    """Returns a copy of `method` with its docstring replaced by `docstring`, bound to `cls` or the method's own instance."""
    copied_function = copy_func(method)
    copied_function.__doc__ = docstring
    return types.MethodType(copied_function, cls or method.__self__)
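
# Illustrative usage sketch (the `Greeter` class is hypothetical, defined only for this example):
#
#     class Greeter:
#         def greet(self):
#             """Original docstring."""
#             return "hello"
#
#     greeter = Greeter()
#     greeter.greet = set_docstring(greeter.greet, "Says hello.")
#     print(greeter.greet.__doc__)  # -> "Says hello."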