| from anonymous_demo import TADCheckpointManager | |
| from textattack.model_args import DEMO_MODELS | |
| from textattack.reactive_defense.reactive_defender import ReactiveDefender | |
class TADReactiveDefender(ReactiveDefender):
    """Reactive defender backed by a TAD text classifier.

    On construction, loads a classifier with
    ``TADCheckpointManager.get_tad_text_classifier``, using the checkpoint
    that ``DEMO_MODELS`` maps the ``ckpt`` name to. ``reactive_defense``
    forwards input text to that classifier's ``infer`` method. By default it
    uses the ``'pwws'`` defense with result printing disabled. It returns
    whatever ``infer`` returns, without post-processing.
    """

    def __init__(self, ckpt='tad-sst2', **kwargs):
        """Initialise the base defender and load the TAD classifier.

        Args:
            ckpt: Name of the checkpoint to load; must be a key of
                ``DEMO_MODELS``.
            **kwargs: Forwarded unchanged to ``ReactiveDefender.__init__``.

        Raises:
            KeyError: If ``ckpt`` is not a key of ``DEMO_MODELS``.
        """
        super().__init__(**kwargs)
        try:
            checkpoint = DEMO_MODELS[ckpt]
        except KeyError:
            # Still a KeyError, so callers catching it are unaffected, but
            # the message explains where valid names come from.
            raise KeyError(
                f"Unknown TAD checkpoint {ckpt!r}; it must be a key of DEMO_MODELS"
            ) from None
        self.tad_classifier = TADCheckpointManager.get_tad_text_classifier(
            checkpoint=checkpoint,
            auto_device=True,
        )

    def reactive_defense(self, text, **kwargs):
        """Run the TAD classifier's ``infer`` on ``text`` with a defense enabled.

        Args:
            text: Passed straight through to ``self.tad_classifier.infer``
                (presumably a string or batch of strings; confirm against
                the classifier's API).
            **kwargs: Extra keyword arguments for ``infer``. ``defense``
                defaults to ``'pwws'`` and ``print_result`` defaults to
                ``False``. Callers may override either one.

        Returns:
            The unmodified return value of ``self.tad_classifier.infer``.
        """
        # setdefault (rather than literal keywords alongside **kwargs) avoids
        # a "multiple values for keyword argument" TypeError when a caller
        # supplies its own defense/print_result. kwargs is a fresh per-call
        # dict, so mutating it has no effect outside this call.
        kwargs.setdefault('defense', 'pwws')
        kwargs.setdefault('print_result', False)
        return self.tad_classifier.infer(text, **kwargs)