textattack/constraints/semantics/sentence_encoders/thought_vector.py
| """ | |
| Thought Vector Class | |
| --------------------- | |
| """ | |
| import functools | |
| import torch | |
| from textattack.shared import AbstractWordEmbedding, WordEmbedding, utils | |
| from .sentence_encoder import SentenceEncoder | |
class ThoughtVector(SentenceEncoder):
    """A constraint on the distance between two sentences' thought vectors.

    Args:
        embedding (textattack.shared.AbstractWordEmbedding): The word embedding
            to use. Defaults to the counter-fitted GloVe embedding.
    """

    def __init__(self, embedding=None, **kwargs):
        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )
        self.word_embedding = embedding
        super().__init__(**kwargs)
    def clear_cache(self):
        self._get_thought_vector.cache_clear()

    # Cache thought vectors so repeated texts are only embedded once; the
    # decorator is required for the ``cache_clear()`` call above to work.
    @functools.lru_cache(maxsize=2**10)
    def _get_thought_vector(self, text):
        """Averages the embeddings of all the words in ``text`` into a single
        "thought vector"."""
        embeddings = []
        for word in utils.words_from_text(text):
            embedding = self.word_embedding[word]
            if embedding is not None:  # out-of-vocab words do not have embeddings
                embeddings.append(embedding)
        embeddings = torch.tensor(embeddings)
        return torch.mean(embeddings, dim=0)
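
    # A worked example of the averaging above (illustrative values, not real
    # GloVe vectors): with embedding("the") = [0.1, 0.2] and
    # embedding("cat") = [0.3, 0.4], _get_thought_vector("the cat") returns
    # their mean, [0.2, 0.3]. Out-of-vocabulary words contribute nothing.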
    def encode(self, raw_text_list):
        """Stacks the thought vector of each text in ``raw_text_list`` into a
        single tensor."""
        return torch.stack([self._get_thought_vector(text) for text in raw_text_list])

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys."""
        return ["word_embedding"] + super().extra_repr_keys()
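

# A minimal usage sketch, not part of the original module. It assumes the
# standard TextAttack setup, where ``threshold`` and ``metric`` keyword
# arguments are forwarded to ``SentenceEncoder.__init__`` via ``**kwargs``;
# the values below are illustrative, not tuned.
if __name__ == "__main__":
    # Build the constraint with the default counter-fitted GloVe embedding,
    # requiring a similarity of at least 0.8 between the thought vectors of
    # the original and perturbed texts.
    constraint = ThoughtVector(threshold=0.8, metric="cosine")

    # ``encode`` maps raw strings to a (batch, embedding_dim) tensor of
    # averaged word embeddings.
    vectors = constraint.encode(["the movie was great", "the film was great"])
    print(vectors.shape)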