Spaces:
Paused
Paused
| import warnings | |
| from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union | |
| import torch | |
| from datasets import Dataset | |
| from sentence_transformers import InputExample, losses, util | |
| from torch import nn | |
| from torch.utils.data import DataLoader | |
| from . import logging | |
| from .sampler import ContrastiveDistillationDataset | |
| from .trainer import Trainer | |
| from .training_args import TrainingArguments | |
| if TYPE_CHECKING: | |
| from .modeling import SetFitModel | |
# NOTE(review): `set_verbosity_info` comes from this package's own `logging` module and runs at
# import time. It presumably raises the package's logger level to INFO, so importing this module
# would override any verbosity a user configured earlier. TODO confirm this is intended.
logging.set_verbosity_info()
# Logger for this module, obtained through the package's logging helpers.
logger = logging.get_logger(__name__)
class DistillationTrainer(Trainer):
    """Trainer to compress a SetFit model with knowledge distillation.

    Sentence pairs drawn from the training texts are labelled with the cosine similarity
    that the *teacher* model body assigns to them. The *student* model body is trained on
    those pairs with `args.loss`. The student head is then fitted on labels predicted by
    the teacher, so the training dataset only needs a `"text"` column.

    Args:
        teacher_model (`SetFitModel`):
            The teacher model to mimic.
        student_model (`SetFitModel`, *optional*):
            The model to train. If not provided, a `model_init` must be passed.
        args (`TrainingArguments`, *optional*):
            The training arguments to use.
        train_dataset (`Dataset`, *optional*):
            The training dataset. Only its `"text"` column is used; labels are not required.
        eval_dataset (`Dataset`, *optional*):
            The evaluation dataset.
        model_init (`Callable[[], SetFitModel]`, *optional*):
            A function that instantiates the model to be used. If provided, each call to
            [`~DistillationTrainer.train`] will start from a new instance of the model as given by this
            function when a `trial` is passed.
        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
            The metric to use for evaluation. If a string is provided, we treat it as the metric
            name and load it with default settings.
            If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
        column_mapping (`Dict[str, str]`, *optional*):
            A mapping from the column names in the dataset to the column names expected by the model.
            The expected format is a dictionary with the following format:
            `{"text_column_name": "text", "label_column_name": "label"}`.
    """

    # Distillation never reads gold labels from the dataset, only the texts.
    _REQUIRED_COLUMNS = {"text"}

    def __init__(
        self,
        teacher_model: "SetFitModel",
        student_model: Optional["SetFitModel"] = None,
        args: Optional[TrainingArguments] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        model_init: Optional[Callable[[], "SetFitModel"]] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        column_mapping: Optional[Dict[str, str]] = None,
    ) -> None:
        # The student is the model the base `Trainer` optimises, so it is passed as `model`.
        super().__init__(
            model=student_model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_init=model_init,
            metric=metric,
            column_mapping=column_mapping,
        )
        self.teacher_model = teacher_model
        # `student_model` is a property aliasing `self.model`, so no copy is stored here.

    @property
    def student_model(self) -> "SetFitModel":
        """The model being trained; always the same object as `self.model`.

        This is an alias rather than a copied attribute so that it stays in sync when
        `self.model` is replaced, e.g. when `model_init` re-instantiates the model for a trial.
        """
        return self.model

    @student_model.setter
    def student_model(self, value: "SetFitModel") -> None:
        # Assigning the alias (as older code did in `__init__`) updates the trained model itself.
        self.model = value

    def dataset_to_parameters(self, dataset: Dataset) -> List[Iterable]:
        """Return the only dataset column distillation needs: the texts."""
        return [dataset["text"]]

    def get_dataloader(
        self,
        x: List[str],
        y: Optional[Union[List[int], List[List[int]]]],
        args: TrainingArguments,
        max_pairs: int = -1,
    ) -> Tuple[DataLoader, nn.Module, int]:
        """Build the dataloader and loss for the embedding (body) distillation phase.

        Args:
            x (`List[str]`): The training sentences.
            y: Ignored. It is kept for signature compatibility with `Trainer.get_dataloader`;
                pair targets come from the teacher's similarity matrix instead.
            args (`TrainingArguments`): Supplies `num_iterations`, `sampling_strategy`,
                `embedding_batch_size` and the `loss` class.
            max_pairs (`int`, defaults to -1): Forwarded unchanged to
                `ContrastiveDistillationDataset` (presumably -1 means "no cap"; confirm there).

        Returns:
            A `(dataloader, loss, batch_size)` tuple. The loss wraps the student model body.

        Raises:
            ValueError: If no sentence pairs were sampled, so no batch can be formed.
        """
        # Embed with the teacher: its pairwise similarities are the regression targets.
        x_embd_teacher = self.teacher_model.model_body.encode(
            x, convert_to_tensor=self.teacher_model.has_differentiable_head
        )
        cos_sim_matrix = util.cos_sim(x_embd_teacher, x_embd_teacher)

        input_data = [InputExample(texts=[text]) for text in x]
        data_sampler = ContrastiveDistillationDataset(
            input_data, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
        )
        num_pairs = len(data_sampler)
        if num_pairs == 0:
            # Without this check, batch_size would become 0 and DataLoader would fail with a less clear error.
            raise ValueError(
                f"The distillation dataset is empty: no sentence pairs were sampled from {len(x)} "
                "training sentence(s)."
            )

        # Only the "unique" strategy is shuffled; setting this to True always adds further randomisation.
        shuffle_sampler = args.sampling_strategy == "unique"
        batch_size = min(args.embedding_batch_size, num_pairs)
        dataloader = DataLoader(data_sampler, batch_size=batch_size, shuffle=shuffle_sampler, drop_last=False)
        # The loss is built around the student body, which is the model being optimised.
        loss = args.loss(self.student_model.model_body)
        return dataloader, loss, batch_size

    def train_classifier(self, x_train: List[str], args: Optional[TrainingArguments] = None) -> None:
        """
        Method to perform the classifier phase: fitting the student classifier head.

        The teacher predicts the labels for `x_train`, and the student head is fitted on them.

        Args:
            x_train (`List[str]`): A list of training sentences.
            args (`TrainingArguments`, *optional*):
                Temporarily change the training arguments for this training call.
        """
        # NumPy labels are requested only when the student head is not differentiable
        # (presumably a scikit-learn head); otherwise the teacher's non-NumPy output is kept.
        y_train = self.teacher_model.predict(x_train, as_numpy=not self.student_model.has_differentiable_head)
        return super().train_classifier(x_train, y_train, args)
class DistillationSetFitTrainer(DistillationTrainer):
    """
    `DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit.
    Please use `DistillationTrainer` instead.

    This compatibility shim accepts the legacy keyword arguments, packs the training
    hyperparameters into a `TrainingArguments` and forwards everything else to
    `DistillationTrainer`. A `DeprecationWarning` is emitted on construction.

    Args:
        teacher_model (`SetFitModel`): The teacher model to mimic.
        student_model (`SetFitModel`, *optional*): The model to train.
        train_dataset (`Dataset`, *optional*): The training dataset.
        eval_dataset (`Dataset`, *optional*): The evaluation dataset.
        model_init (`Callable[[], SetFitModel]`, *optional*): Factory for the student model.
        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`): The evaluation metric.
        loss_class (`Callable[..., nn.Module]`, defaults to `losses.CosineSimilarityLoss`):
            Loss class (or factory), forwarded as `TrainingArguments.loss`.
        num_iterations (`int`, defaults to 20): Forwarded as `TrainingArguments.num_iterations`.
        num_epochs (`int`, defaults to 1): Forwarded as `TrainingArguments.num_epochs`.
        learning_rate (`float`, defaults to 2e-5): Used for BOTH `body_learning_rate`
            and `head_learning_rate`.
        batch_size (`int`, defaults to 16): Forwarded as `TrainingArguments.batch_size`.
        seed (`int`, defaults to 42): Forwarded as `TrainingArguments.seed`.
        column_mapping (`Dict[str, str]`, *optional*): Dataset column renaming.
        use_amp (`bool`, defaults to `False`): Forwarded as `TrainingArguments.use_amp`.
        warmup_proportion (`float`, defaults to 0.1): Forwarded as `TrainingArguments.warmup_proportion`.
    """

    def __init__(
        self,
        teacher_model: "SetFitModel",
        student_model: Optional["SetFitModel"] = None,
        train_dataset: Optional["Dataset"] = None,
        eval_dataset: Optional["Dataset"] = None,
        model_init: Optional[Callable[[], "SetFitModel"]] = None,
        metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy",
        # The default is a loss *class*, which is later called with the student body, so this
        # is a callable producing an `nn.Module`, not a module instance.
        loss_class: Callable[..., torch.nn.Module] = losses.CosineSimilarityLoss,
        num_iterations: int = 20,
        num_epochs: int = 1,
        learning_rate: float = 2e-5,
        batch_size: int = 16,
        seed: int = 42,
        column_mapping: Optional[Dict[str, str]] = None,
        use_amp: bool = False,
        warmup_proportion: float = 0.1,
    ) -> None:
        # stacklevel=2 attributes the warning to the caller that constructs this trainer.
        warnings.warn(
            "`DistillationSetFitTrainer` has been deprecated and will be removed in v2.0.0 of SetFit. "
            "Please use `DistillationTrainer` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # The legacy API had a single learning rate, so it is applied to both body and head.
        args = TrainingArguments(
            num_iterations=num_iterations,
            num_epochs=num_epochs,
            body_learning_rate=learning_rate,
            head_learning_rate=learning_rate,
            batch_size=batch_size,
            seed=seed,
            use_amp=use_amp,
            warmup_proportion=warmup_proportion,
            loss=loss_class,
        )
        super().__init__(
            teacher_model=teacher_model,
            student_model=student_model,
            args=args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            model_init=model_init,
            metric=metric,
            column_mapping=column_mapping,
        )