from itertools import zip_longest
from typing import Generator, Iterable, List, Optional

import numpy as np
import torch
from sentence_transformers import InputExample
from torch.utils.data import IterableDataset

from . import logging


logging.set_verbosity_info()
logger = logging.get_logger(__name__)
def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
    """Generates shuffled pair combinations for any iterable data provided.

    Args:
        iterable: data to generate pair combinations from
        replacement: set to True to include pairs of a sample with itself,
            equivalent to itertools.combinations_with_replacement

    Returns:
        Generator of shuffled pairs as a tuple
    """
    n = len(iterable)
    k = 1 if not replacement else 0
    idxs = np.stack(np.triu_indices(n, k), axis=-1)
    for i in np.random.RandomState(seed=42).permutation(len(idxs)):
        _idx, idx = idxs[i, :]
        yield iterable[_idx], iterable[idx]
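
# Usage sketch (illustrative, not part of the module): with replacement=False the generator
# walks every unordered pair exactly once, in an order fixed by the seeded permutation, e.g.
#
#     >>> pairs = list(shuffle_combinations(["a", "b", "c"], replacement=False))
#     >>> sorted(pairs)
#     [('a', 'b'), ('a', 'c'), ('b', 'c')]
#
# With replacement=True (the default), self-pairs such as ('a', 'a') are yielded as well.
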
class ContrastiveDataset(IterableDataset):
    def __init__(
        self,
        examples: List[InputExample],
        multilabel: bool,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        """Generates positive and negative text pairs for contrastive learning.

        Args:
            examples (List[InputExample]): texts and labels in a sentence-transformers dataclass
            multilabel: set to True to process multi-label label arrays
            sampling_strategy: "unique", "oversampling", or "undersampling"
            num_iterations: if provided, explicitly sets the number of pairs to be generated,
                where n_pairs = n_iterations * n_sentences * 2 (for positive & negative pairs)
            max_pairs: if not -1, pair generation stops once both the positive and the negative
                pair lists exceed max_pairs
        """
        super().__init__()

        self.pos_index = 0
        self.neg_index = 0
        self.pos_pairs = []
        self.neg_pairs = []
        self.sentences = np.array([s.texts[0] for s in examples])
        self.labels = np.array([s.label for s in examples])
        self.sentence_labels = list(zip(self.sentences, self.labels))
        self.max_pairs = max_pairs

        if multilabel:
            self.generate_multilabel_pairs()
        else:
            self.generate_pairs()

        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
            self.len_neg_pairs = num_iterations * len(self.sentences)
        elif sampling_strategy == "unique":
            self.len_pos_pairs = len(self.pos_pairs)
            self.len_neg_pairs = len(self.neg_pairs)
        elif sampling_strategy == "undersampling":
            self.len_pos_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = min(len(self.pos_pairs), len(self.neg_pairs))
        elif sampling_strategy == "oversampling":
            self.len_pos_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
            self.len_neg_pairs = max(len(self.pos_pairs), len(self.neg_pairs))
        else:
            raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")

    def generate_pairs(self) -> None:
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            if _label == label:
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
                break

    def generate_multilabel_pairs(self) -> None:
        for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
            # logical_and marks a positive pair if the two label arrays share at least one active class
            if any(np.logical_and(_label, label)):
                self.pos_pairs.append(InputExample(texts=[_text, text], label=1.0))
            else:
                self.neg_pairs.append(InputExample(texts=[_text, text], label=0.0))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs and len(self.neg_pairs) > self.max_pairs:
                break

    def get_positive_pairs(self) -> List[InputExample]:
        pairs = []
        for _ in range(self.len_pos_pairs):
            if self.pos_index >= len(self.pos_pairs):
                self.pos_index = 0
            pairs.append(self.pos_pairs[self.pos_index])
            self.pos_index += 1
        return pairs

    def get_negative_pairs(self) -> List[InputExample]:
        pairs = []
        for _ in range(self.len_neg_pairs):
            if self.neg_index >= len(self.neg_pairs):
                self.neg_index = 0
            pairs.append(self.neg_pairs[self.neg_index])
            self.neg_index += 1
        return pairs

    def __iter__(self):
        for pos_pair, neg_pair in zip_longest(self.get_positive_pairs(), self.get_negative_pairs()):
            if pos_pair is not None:
                yield pos_pair
            if neg_pair is not None:
                yield neg_pair

    def __len__(self) -> int:
        return self.len_pos_pairs + self.len_neg_pairs
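
# Usage sketch (illustrative, not part of the module): as an IterableDataset this class is
# typically wrapped in a torch DataLoader with a sentence-transformers collate function;
# `model` below is assumed to be a SentenceTransformer instance and `examples` a list of
# InputExample objects:
#
#     >>> from torch.utils.data import DataLoader
#     >>> dataset = ContrastiveDataset(examples, multilabel=False, sampling_strategy="oversampling")
#     >>> loader = DataLoader(dataset, batch_size=16, collate_fn=model.smart_batching_collate)
#
# Each pass over __iter__ interleaves one positive and one negative pair at a time until the
# longer of the two (length-capped) pair lists is exhausted.
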
class ContrastiveDistillationDataset(ContrastiveDataset):
    def __init__(
        self,
        examples: List[InputExample],
        cos_sim_matrix: torch.Tensor,
        num_iterations: Optional[int] = None,
        sampling_strategy: str = "oversampling",
        max_pairs: int = -1,
    ) -> None:
        self.cos_sim_matrix = cos_sim_matrix
        super().__init__(
            examples,
            multilabel=False,
            num_iterations=num_iterations,
            sampling_strategy=sampling_strategy,
            max_pairs=max_pairs,
        )
        # Internally we store all pairs in pos_pairs, regardless of sampling strategy.
        # After all, without labels, there isn't much of a strategy.
        self.sentence_labels = list(enumerate(self.sentences))

        self.len_neg_pairs = 0
        if num_iterations is not None and num_iterations > 0:
            self.len_pos_pairs = num_iterations * len(self.sentences)
        else:
            self.len_pos_pairs = len(self.pos_pairs)

    def generate_pairs(self) -> None:
        for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
            self.pos_pairs.append(InputExample(texts=[text_one, text_two], label=self.cos_sim_matrix[id_one][id_two]))
            if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs:
                break
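
# Distillation usage sketch (illustrative, not part of the module): a teacher model's
# cosine-similarity matrix supplies soft pair labels, so no gold labels are needed. The
# names `teacher_model` and `texts` are assumptions for the example; each example's
# `label` is set to its row index in the similarity matrix, which is how generate_pairs
# above indexes cos_sim_matrix:
#
#     >>> from sentence_transformers import util
#     >>> embeddings = teacher_model.encode(texts, convert_to_tensor=True)
#     >>> cos_sim_matrix = util.cos_sim(embeddings, embeddings)
#     >>> examples = [InputExample(texts=[text], label=idx) for idx, text in enumerate(texts)]
#     >>> dataset = ContrastiveDistillationDataset(examples, cos_sim_matrix)
#
# Every generated pair lands in pos_pairs with the teacher similarity as its (soft) label.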