# Trainer for SetFit ABSA (aspect-based sentiment analysis) models.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

from datasets import Dataset
from transformers.trainer_callback import TrainerCallback

from setfit.span.modeling import AbsaModel, AspectModel, PolarityModel
from setfit.training_args import TrainingArguments

from .. import logging
from ..trainer import ColumnMappingMixin, Trainer

# optuna is only needed for type annotations; importing it lazily avoids a
# hard runtime dependency on the hyperparameter-search extra.
if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)
class AbsaTrainer(ColumnMappingMixin):
    """Trainer to train a SetFit ABSA model.

    Args:
        model (`AbsaModel`):
            The AbsaModel model to train.
        args (`TrainingArguments`, *optional*):
            The training arguments to use. If `polarity_args` is not defined, then `args` is used for both
            the aspect and the polarity model.
        polarity_args (`TrainingArguments`, *optional*):
            The training arguments to use for the polarity model. If not defined, `args` is used for both
            the aspect and the polarity model.
        train_dataset (`Dataset`):
            The training dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
        eval_dataset (`Dataset`, *optional*):
            The evaluation dataset. The dataset must have "text", "span", "label" and "ordinal" columns.
        metric (`str` or `Callable`, *optional*, defaults to `"accuracy"`):
            The metric to use for evaluation. If a string is provided, we treat it as the metric
            name and load it with default settings.
            If a callable is provided, it must take two arguments (`y_pred`, `y_test`).
        metric_kwargs (`Dict[str, Any]`, *optional*):
            Keyword arguments passed to the evaluation function if `metric` is an evaluation string like "f1".
            For example useful for providing an averaging strategy for computing f1 in a multi-label setting.
        callbacks (`List[`[`~transformers.TrainerCallback`]`]`, *optional*):
            A list of callbacks to customize the training loop. Will add those to the list of default callbacks
            detailed in [here](https://huggingface.co/docs/transformers/main/en/main_classes/callback).
            If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
        column_mapping (`Dict[str, str]`, *optional*):
            A mapping from the column names in the dataset to the column names expected by the model.
            The expected format is a dictionary with the following format:
            `{"text_column_name": "text", "span_column_name": "span", "label_column_name": "label", "ordinal_column_name": "ordinal"}`.
    """

    # Columns every (possibly column-mapped) train/eval dataset must provide.
    _REQUIRED_COLUMNS = {"text", "span", "label", "ordinal"}
| def __init__( | |
| self, | |
| model: AbsaModel, | |
| args: Optional[TrainingArguments] = None, | |
| polarity_args: Optional[TrainingArguments] = None, | |
| train_dataset: Optional["Dataset"] = None, | |
| eval_dataset: Optional["Dataset"] = None, | |
| metric: Union[str, Callable[["Dataset", "Dataset"], Dict[str, float]]] = "accuracy", | |
| metric_kwargs: Optional[Dict[str, Any]] = None, | |
| callbacks: Optional[List[TrainerCallback]] = None, | |
| column_mapping: Optional[Dict[str, str]] = None, | |
| ) -> None: | |
| self.model = model | |
| self.aspect_extractor = model.aspect_extractor | |
| if train_dataset is not None and column_mapping: | |
| train_dataset = self._apply_column_mapping(train_dataset, column_mapping) | |
| aspect_train_dataset, polarity_train_dataset = self.preprocess_dataset( | |
| model.aspect_model, model.polarity_model, train_dataset | |
| ) | |
| if eval_dataset is not None and column_mapping: | |
| eval_dataset = self._apply_column_mapping(eval_dataset, column_mapping) | |
| aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset( | |
| model.aspect_model, model.polarity_model, eval_dataset | |
| ) | |
| self.aspect_trainer = Trainer( | |
| model.aspect_model, | |
| args=args, | |
| train_dataset=aspect_train_dataset, | |
| eval_dataset=aspect_eval_dataset, | |
| metric=metric, | |
| metric_kwargs=metric_kwargs, | |
| callbacks=callbacks, | |
| ) | |
| self.aspect_trainer._set_logs_mapper( | |
| { | |
| "eval_embedding_loss": "eval_aspect_embedding_loss", | |
| "embedding_loss": "aspect_embedding_loss", | |
| } | |
| ) | |
| self.polarity_trainer = Trainer( | |
| model.polarity_model, | |
| args=polarity_args or args, | |
| train_dataset=polarity_train_dataset, | |
| eval_dataset=polarity_eval_dataset, | |
| metric=metric, | |
| metric_kwargs=metric_kwargs, | |
| callbacks=callbacks, | |
| ) | |
| self.polarity_trainer._set_logs_mapper( | |
| { | |
| "eval_embedding_loss": "eval_polarity_embedding_loss", | |
| "embedding_loss": "polarity_embedding_loss", | |
| } | |
| ) | |
| def preprocess_dataset( | |
| self, aspect_model: AspectModel, polarity_model: PolarityModel, dataset: Dataset | |
| ) -> Dataset: | |
| if dataset is None: | |
| return dataset, dataset | |
| # Group by "text" | |
| grouped_data = defaultdict(list) | |
| for sample in dataset: | |
| text = sample.pop("text") | |
| grouped_data[text].append(sample) | |
| def index_ordinal(text: str, target: str, ordinal: int) -> Tuple[int, int]: | |
| find_from = 0 | |
| for _ in range(ordinal + 1): | |
| start_idx = text.index(target, find_from) | |
| find_from = start_idx + 1 | |
| return start_idx, start_idx + len(target) | |
| def overlaps(aspect: slice, aspects: List[slice]) -> bool: | |
| for test_aspect in aspects: | |
| overlapping_indices = set(range(aspect.start, aspect.stop + 1)) & set( | |
| range(test_aspect.start, test_aspect.stop + 1) | |
| ) | |
| if overlapping_indices: | |
| return True | |
| return False | |
| docs, aspects_list = self.aspect_extractor(grouped_data.keys()) | |
| aspect_aspect_list = [] | |
| aspect_labels = [] | |
| polarity_aspect_list = [] | |
| polarity_labels = [] | |
| for doc, aspects, text in zip(docs, aspects_list, grouped_data): | |
| # Collect all of the gold aspects | |
| gold_aspects = [] | |
| gold_polarity_labels = [] | |
| for annotation in grouped_data[text]: | |
| try: | |
| start, end = index_ordinal(text, annotation["span"], annotation["ordinal"]) | |
| except ValueError: | |
| logger.info( | |
| f"The ordinal of {annotation['ordinal']} for span {annotation['span']!r} in {text!r} is too high. " | |
| "Skipping this sample." | |
| ) | |
| continue | |
| gold_aspect_span = doc.char_span(start, end) | |
| if gold_aspect_span is None: | |
| continue | |
| gold_aspects.append(slice(gold_aspect_span.start, gold_aspect_span.end)) | |
| gold_polarity_labels.append(annotation["label"]) | |
| # The Aspect model uses all gold aspects as "True", and all non-overlapping predicted | |
| # aspects as "False" | |
| aspect_labels.extend([True] * len(gold_aspects)) | |
| aspect_aspect_list.append(gold_aspects[:]) | |
| for aspect in aspects: | |
| if not overlaps(aspect, gold_aspects): | |
| aspect_labels.append(False) | |
| aspect_aspect_list[-1].append(aspect) | |
| # The Polarity model uses only the gold aspects and labels | |
| polarity_labels.extend(gold_polarity_labels) | |
| polarity_aspect_list.append(gold_aspects) | |
| aspect_texts = list(aspect_model.prepend_aspects(docs, aspect_aspect_list)) | |
| polarity_texts = list(polarity_model.prepend_aspects(docs, polarity_aspect_list)) | |
| return Dataset.from_dict({"text": aspect_texts, "label": aspect_labels}), Dataset.from_dict( | |
| {"text": polarity_texts, "label": polarity_labels} | |
| ) | |
| def train( | |
| self, | |
| args: Optional[TrainingArguments] = None, | |
| polarity_args: Optional[TrainingArguments] = None, | |
| trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, | |
| **kwargs, | |
| ) -> None: | |
| """ | |
| Main training entry point. | |
| Args: | |
| args (`TrainingArguments`, *optional*): | |
| Temporarily change the aspect training arguments for this training call. | |
| polarity_args (`TrainingArguments`, *optional*): | |
| Temporarily change the polarity training arguments for this training call. | |
| trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): | |
| The trial run or the hyperparameter dictionary for hyperparameter search. | |
| """ | |
| self.train_aspect(args=args, trial=trial, **kwargs) | |
| self.train_polarity(args=polarity_args, trial=trial, **kwargs) | |
| def train_aspect( | |
| self, | |
| args: Optional[TrainingArguments] = None, | |
| trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None, | |
| **kwargs, | |
| ) -> None: | |
| """ | |
| Train the aspect model only. | |
| Args: | |
| args (`TrainingArguments`, *optional*): | |
| Temporarily change the aspect training arguments for this training call. | |
| trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): | |
| The trial run or the hyperparameter dictionary for hyperparameter search. | |
| """ | |
| self.aspect_trainer.train(args=args, trial=trial, **kwargs) | |
    def train_polarity(
        self,
        args: Optional[TrainingArguments] = None,
        trial: Optional[Union["optuna.Trial", Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """
        Train the polarity model only.

        Args:
            args (`TrainingArguments`, *optional*):
                Temporarily change the polarity training arguments for this training call.
            trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        self.polarity_trainer.train(args=args, trial=trial, **kwargs)
| def add_callback(self, callback: Union[type, TrainerCallback]) -> None: | |
| """ | |
| Add a callback to the current list of [`~transformers.TrainerCallback`]. | |
| Args: | |
| callback (`type` or [`~transformers.TrainerCallback`]): | |
| A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the | |
| first case, will instantiate a member of that class. | |
| """ | |
| self.aspect_trainer.add_callback(callback) | |
| self.polarity_trainer.add_callback(callback) | |
| def pop_callback(self, callback: Union[type, TrainerCallback]) -> Tuple[TrainerCallback, TrainerCallback]: | |
| """ | |
| Remove a callback from the current list of [`~transformers.TrainerCallback`] and returns it. | |
| If the callback is not found, returns `None` (and no error is raised). | |
| Args: | |
| callback (`type` or [`~transformers.TrainerCallback`]): | |
| A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the | |
| first case, will pop the first member of that class found in the list of callbacks. | |
| Returns: | |
| `Tuple[`[`~transformers.TrainerCallback`], [`~transformers.TrainerCallback`]`]`: The callbacks removed from the | |
| aspect and polarity trainers, if found. | |
| """ | |
| return self.aspect_trainer.pop_callback(callback), self.polarity_trainer.pop_callback(callback) | |
| def remove_callback(self, callback: Union[type, TrainerCallback]) -> None: | |
| """ | |
| Remove a callback from the current list of [`~transformers.TrainerCallback`]. | |
| Args: | |
| callback (`type` or [`~transformers.TrainerCallback`]): | |
| A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the | |
| first case, will remove the first member of that class found in the list of callbacks. | |
| """ | |
| self.aspect_trainer.remove_callback(callback) | |
| self.polarity_trainer.remove_callback(callback) | |
    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
        """Upload model checkpoint to the Hub using `huggingface_hub`.

        See the full list of parameters for your `huggingface_hub` version in the\
        [huggingface_hub documentation](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.ModelHubMixin.push_to_hub).

        Args:
            repo_id (`str`):
                The full repository ID to push the aspect model to, e.g. `"tomaarsen/setfit-aspect"`.
            polarity_repo_id (`str`, *optional*):
                The full repository ID to push the polarity model to, e.g. `"tomaarsen/setfit-polarity"`.
            config (`dict`, *optional*):
                Configuration object to be saved alongside the model weights.
            commit_message (`str`, *optional*):
                Message to commit while pushing.
            private (`bool`, *optional*, defaults to `False`):
                Whether the repository created should be private.
            api_endpoint (`str`, *optional*):
                The API endpoint to use when pushing the model to the hub.
            token (`str`, *optional*):
                The token to use as HTTP bearer authorization for remote files.
                If not set, will use the token set when logging in with
                `transformers-cli login` (stored in `~/.huggingface`).
            branch (`str`, *optional*):
                The git branch on which to push the model. This defaults to
                the default branch as specified in your repository, which
                defaults to `"main"`.
            create_pr (`boolean`, *optional*):
                Whether or not to create a Pull Request from `branch` with that commit.
                Defaults to `False`.
            allow_patterns (`List[str]` or `str`, *optional*):
                If provided, only files matching at least one pattern are pushed.
            ignore_patterns (`List[str]` or `str`, *optional*):
                If provided, files matching any of the patterns are not pushed.
        """
        # Delegates entirely to the model, which pushes the aspect and polarity models separately.
        return self.model.push_to_hub(repo_id=repo_id, polarity_repo_id=polarity_repo_id, **kwargs)
| def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, Dict[str, float]]: | |
| """ | |
| Computes the metrics for a given classifier. | |
| Args: | |
| dataset (`Dataset`, *optional*): | |
| The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed via | |
| the `eval_dataset` argument at `Trainer` initialization. | |
| Returns: | |
| `Dict[str, Dict[str, float]]`: The evaluation metrics. | |
| """ | |
| aspect_eval_dataset = polarity_eval_dataset = None | |
| if dataset: | |
| aspect_eval_dataset, polarity_eval_dataset = self.preprocess_dataset( | |
| self.model.aspect_model, self.model.polarity_model, dataset | |
| ) | |
| return { | |
| "aspect": self.aspect_trainer.evaluate(aspect_eval_dataset), | |
| "polarity": self.polarity_trainer.evaluate(polarity_eval_dataset), | |
| } | |