"""
USEMetric class:
-------------------------------------------------------
Class for calculating USE similarity on AttackResults
"""
from textattack.attack_results import FailedAttackResult, SkippedAttackResult
from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder
from textattack.metrics import Metric
class USEMetric(Metric):
    """Computes the average Universal Sentence Encoder (USE) similarity
    between each original text and its perturbed counterpart, over all
    successful attack results."""

    def __init__(self, **kwargs):
        # Encoder whose ``_sim_score`` scores original/perturbed pairs.
        self.use_obj = UniversalSentenceEncoder()
        # NOTE(review): the original code replaces the encoder's ``model``
        # attribute with a *second* UniversalSentenceEncoder instance.
        # Preserved as-is for compatibility — confirm against
        # UniversalSentenceEncoder's internals whether this double
        # instantiation is intentional or an accidental duplicate load.
        self.use_obj.model = UniversalSentenceEncoder()
        # Original texts of successfully attacked instances.
        self.original_candidates = []
        # Perturbed texts of successfully attacked instances.
        self.successful_candidates = []
        # Metric name -> value, populated by ``calculate``.
        self.all_metrics = {}

    def calculate(self, results):
        """Calculates average USE similarity on all successful attacks.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``{"avg_attack_use_score": float}`` rounded to two
            decimal places; ``0.0`` when there are no successful attacks
            (the previous implementation raised ``ZeroDivisionError`` in
            that case).

        Example::

            >> import textattack
            >> import transformers
            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
            >> attack_args = textattack.AttackArgs(
                num_examples=1,
                log_to_csv="log.csv",
                checkpoint_interval=5,
                checkpoint_dir="checkpoints",
                disable_stdout=True
            )
            >> attacker = textattack.Attacker(attack, dataset, attack_args)
            >> results = attacker.attack_dataset()
            >> usem = textattack.metrics.quality_metrics.USEMetric().calculate(results)
        """
        self.results = results
        for result in self.results:
            # Only successful attacks contribute to the similarity score.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(result.original_result.attacked_text)
            self.successful_candidates.append(result.perturbed_result.attacked_text)

        use_scores = [
            self.use_obj._sim_score(original, perturbed).item()
            for original, perturbed in zip(
                self.original_candidates, self.successful_candidates
            )
        ]

        # Guard against ZeroDivisionError when every attack failed or
        # was skipped (no candidate pairs collected).
        if use_scores:
            self.all_metrics["avg_attack_use_score"] = round(
                sum(use_scores) / len(use_scores), 2
            )
        else:
            self.all_metrics["avg_attack_use_score"] = 0.0
        return self.all_metrics