update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import abc | |
| import logging | |
| from typing import Callable, List, Optional | |
| import torch | |
| import torch.nn.functional as F | |
| from adapters import AutoAdapterModel | |
| from pie_modules.models import SequencePairSimilarityModelWithPooler | |
| from pie_modules.models.sequence_classification_with_pooler import ( | |
| InputType, | |
| OutputType, | |
| SequenceClassificationModelWithPooler, | |
| SequenceClassificationModelWithPoolerBase, | |
| TargetType, | |
| separate_arguments_by_prefix, | |
| ) | |
| from pytorch_ie import PyTorchIEModel | |
| from torch import FloatTensor, Tensor | |
| from transformers import AutoConfig, PreTrainedModel | |
| from transformers.modeling_outputs import SequenceClassifierOutput | |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class SequenceClassificationModelWithPoolerAndAdapterBase(
    SequenceClassificationModelWithPoolerBase, abc.ABC
):
    """Pooled sequence-classification base whose backbone can optionally carry an adapter.

    Args:
        adapter_name_or_path: identifier or path of an adapter to load into the backbone.
            If None, the backbone is built exactly as by the parent class.
        **kwargs: forwarded unchanged to the parent constructor.
    """

    def __init__(self, adapter_name_or_path: Optional[str] = None, **kwargs):
        # Stored before delegating to the parent constructor. Presumably that constructor
        # calls setup_base_model(), which reads this attribute (TODO confirm in pie_modules).
        self.adapter_name_or_path = adapter_name_or_path
        super().__init__(**kwargs)

    def setup_base_model(self) -> PreTrainedModel:
        """Create the transformer backbone, loading and activating the adapter if configured.

        Returns:
            The parent's backbone when no adapter is configured. Otherwise an
            ``AutoAdapterModel`` with the adapter loaded and set active.
        """
        adapter = self.adapter_name_or_path
        if adapter is None:
            return super().setup_base_model()

        config = AutoConfig.from_pretrained(self.model_name_or_path)
        # With is_from_pretrained only the architecture is instantiated from the config;
        # otherwise the pretrained weights of model_name_or_path are downloaded/loaded.
        model = (
            AutoAdapterModel.from_config(config=config)
            if self.is_from_pretrained
            else AutoAdapterModel.from_pretrained(self.model_name_or_path, config=config)
        )
        # The adapter is loaded on both paths: it does not seem to be saved in, or
        # restored from, the serialized model state.
        logger.info(f"load adapter: {self.adapter_name_or_path}")
        model.load_adapter(adapter, source="hf", set_active=True)
        return model
class SequencePairSimilarityModelWithPoolerAndAdapter(
    SequencePairSimilarityModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """Pair-similarity model with pooler whose backbone can optionally carry an adapter.

    This class adds no code of its own. It combines the pair-similarity model with the
    adapter base, which supplies the ``adapter_name_or_path`` constructor argument and
    the adapter-aware ``setup_base_model``.
    """
class SequenceClassificationModelWithPoolerAndAdapter(
    SequenceClassificationModelWithPooler, SequenceClassificationModelWithPoolerAndAdapterBase
):
    """Sequence-classification model with pooler whose backbone can optionally carry an adapter.

    This class adds no code of its own. It combines the classification model with the
    adapter base, which supplies the ``adapter_name_or_path`` constructor argument and
    the adapter-aware ``setup_base_model``.
    """
def get_max_cosine_sim(embeddings: Tensor, embeddings_pair: Tensor) -> Tensor:
    """Return the largest cosine similarity between any row of one matrix and any row of the other.

    Args:
        embeddings: matrix of shape ``(n, k)``.
        embeddings_pair: matrix of shape ``(m, k)``.

    Returns:
        A 0-dim tensor holding ``max_{i,j} cos(embeddings[i], embeddings_pair[j])``.
        torch raises if either input has no rows.
    """
    # After L2-normalising the rows, their dot products are exactly the cosine similarities.
    unit = F.normalize(embeddings, p=2, dim=1)
    unit_pair = F.normalize(embeddings_pair, p=2, dim=1)
    # (n, k) @ (k, m) -> (n, m) similarity matrix, reduced over all entries.
    return (unit @ unit_pair.T).max()
def get_span_embeddings(
    embeddings: FloatTensor, start_indices: Tensor, end_indices: Tensor
) -> List[FloatTensor]:
    """Cut one span of token embeddings out of every batch element.

    Args:
        embeddings: token embeddings, iterated per batch element. Presumably of shape
            ``(batch_size, seq_len, hidden_dim)`` (TODO confirm with callers).
        start_indices: per batch element, a row of span start indices. Only the first
            entry of each row is used.
        end_indices: per batch element, a row of exclusive span end indices. Only the
            first entry of each row is used.

    Returns:
        A list with one ``embeddings[b][start:end]`` slice per batch element. Spans may
        differ in length, so they are not stacked. The list length follows the shortest
        of the three inputs (zip semantics).
    """
    return [
        token_embeddings[span_starts[0] : span_ends[0]]
        for token_embeddings, span_starts, span_ends in zip(
            embeddings, start_indices, end_indices
        )
    ]
class SequencePairSimilarityModelWithMaxCosineSim(SequencePairSimilarityModelWithPooler):
    """Pair-similarity model that scores a pair by its maximum token-level cosine similarity.

    Each side's span is not pooled into a single vector. Instead, the token embeddings
    inside the (first) span of every input are kept. The logit for a pair is the highest
    cosine similarity between any token of one span and any token of the other span.
    """

    def get_pooled_output(self, model_inputs, pooler_inputs) -> List[FloatTensor]:
        """Encode the inputs and return the token embeddings of each example's span.

        Args:
            model_inputs: keyword arguments for the underlying transformer model.
            pooler_inputs: keyword arguments for ``get_span_embeddings``, i.e. exactly
                ``start_indices`` and ``end_indices``.

        Returns:
            One tensor of span token embeddings per batch element. Spans may differ in
            length, so they are returned as a list rather than a stacked tensor.
        """
        # The parent's pooler and dropout are not applied: similarity is computed on the
        # raw token embeddings of each span.
        hidden_state = self.model(**model_inputs).last_hidden_state
        return get_span_embeddings(hidden_state, **pooler_inputs)

    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
        """Score every pair in the batch by the maximum cosine similarity of its spans.

        Args:
            inputs: model inputs. ``encoding`` and ``encoding_pair`` hold the transformer
                inputs of the two sides. Arguments prefixed with ``pooler_`` and
                ``pooler_pair_`` hold the span indices of the respective side.
            targets: optional targets. If given, ``targets["scores"]`` is passed with the
                logits to ``self.loss_fct``.
            return_hidden_states: not supported.

        Returns:
            A ``SequenceClassifierOutput`` with ``logits`` of shape ``(batch_size,)``.
            It also holds ``loss`` if ``targets`` were given.

        Raises:
            NotImplementedError: if ``return_hidden_states`` is True.
        """
        # Fail fast: checking this up front avoids running the encoder on both sides
        # (and computing the loss) only to throw the results away.
        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        sanitized_inputs = separate_arguments_by_prefix(
            # Note that the order of the prefixes is important because one is a prefix of the other,
            # so we need to start with the longer!
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )
        span_embeddings = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        span_embeddings_pair = self.get_pooled_output(
            model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )
        # get_max_cosine_sim yields one 0-dim tensor per pair; stacking gives (batch_size,).
        logits = torch.stack(
            [
                get_max_cosine_sim(span_embeds, span_embeds_pair)
                for span_embeds, span_embeds_pair in zip(span_embeddings, span_embeddings_pair)
            ]
        )
        result = {"logits": logits}
        if targets is not None:
            result["loss"] = self.loss_fct(logits, targets["scores"])
        return SequenceClassifierOutput(**result)
class SequencePairSimilarityModelWithMaxCosineSimAndAdapter(
    SequencePairSimilarityModelWithMaxCosineSim, SequencePairSimilarityModelWithPoolerAndAdapter
):
    """Max-cosine-similarity pair model whose backbone can optionally carry an adapter.

    This class adds no code of its own. Scoring (``get_pooled_output`` / ``forward``)
    comes from the max-cosine-similarity model. Adapter handling comes from the adapter
    variant of the pair-similarity model.
    """
class SequencePairSimilarityModelDummy(SequencePairSimilarityModelWithPooler):
    """Baseline pair-similarity model that ignores the content of its inputs.

    It emits one logit per batch element: either random values (``method="random"``) or
    all zeros (``method="zero"``). It cannot be trained, because its loss function raises.

    Args:
        method: ``"random"`` or ``"zero"``.
        random_seed: seed for a private generator used by ``method="random"``. If None,
            torch's default (global) random number generator is used instead, so results
            follow ``torch.manual_seed`` and are not fixed otherwise.
        **kwargs: forwarded to ``SequencePairSimilarityModelWithPooler``.
    """

    def __init__(
        self,
        method: str = "random",
        random_seed: Optional[int] = None,
        **kwargs,
    ):
        # Stored before delegating to the parent constructor. Presumably that constructor
        # calls setup_classifier(), which reads these attributes (TODO confirm in pie_modules).
        self.method = method
        self.random_seed = random_seed
        super().__init__(**kwargs)

    def setup_classifier(
        self, pooler_output_dim: int
    ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
        """Return the dummy classifier selected by ``self.method``.

        Args:
            pooler_output_dim: unused. Kept for signature compatibility with the parent.

        Returns:
            A function ``(inputs, inputs_pair) -> logits`` that returns one value per row
            of ``inputs``.

        Raises:
            ValueError: if ``self.method`` is neither ``"random"`` nor ``"zero"``.
        """
        if self.method == "random":
            generator: Optional[torch.Generator] = None
            if self.random_seed is not None:
                # A freshly created torch.Generator always starts from the same fixed default
                # seed. So a private generator is only used when a seed is actually requested;
                # otherwise "unseeded" output would be identical on every run.
                generator = torch.Generator(device=self.device)
                generator.manual_seed(self.random_seed)

            def binary_classify_random(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Randomly classifies pairs of inputs as similar or not similar."""
                batch_size = inputs.size(0)
                # Random logits, uniformly sampled from [0, 1).
                if generator is None:
                    return torch.rand(batch_size, device=self.device)
                # The generator's device was fixed at setup time, but the model may have been
                # moved since (e.g. to GPU). torch.rand requires generator and output to share
                # a device type, so sample on the generator's device and move the result.
                logits = torch.rand(batch_size, device=generator.device, generator=generator)
                return logits.to(self.device)

            return binary_classify_random
        elif self.method == "zero":

            def binary_classify_zero(
                inputs: torch.FloatTensor,
                inputs_pair: torch.FloatTensor,
            ) -> torch.FloatTensor:
                """Classifies pairs of inputs as not similar (logit = 0)."""
                # Return a tensor of zeros with the same batch size
                return torch.zeros(inputs.size(0), device=self.device)

            return binary_classify_zero
        else:
            raise ValueError(
                f"Unknown method: {self.method}. Supported methods are 'random' and 'zero'."
            )

    def setup_loss_fct(self) -> Callable:
        """Return a loss function that always raises, because this model is not trainable.

        Raises (when the returned function is called):
            NotImplementedError: always.
        """

        def loss_fct(logits: FloatTensor, labels: FloatTensor) -> FloatTensor:
            raise NotImplementedError(
                "Dummy model does not support loss function, as it is not used for training."
            )

        return loss_fct

    def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
        """Return a zero vector with one entry per batch element, without running the encoder.

        The batch size is taken from ``pooler_inputs["start_indices"]``. The classifier only
        uses the size of the first dimension of this placeholder.
        """
        bs = pooler_inputs["start_indices"].size(0)
        return torch.zeros(bs, device=self.device)