Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # @Time : 2022/1/7 11:02 上午 | |
| # @Author : JianingWang | |
| # @File : adversarial.py | |
| import torch | |
class FGM:
    """Fast Gradient Method (FGM) adversarial perturbation of embeddings.

    ``attack`` shifts every trainable parameter whose name contains
    ``emb_name`` by ``epsilon * grad / ||grad||_2`` (using the gradient
    currently stored on the parameter). ``restore`` puts back the values
    saved before the first ``attack``.
    """

    def __init__(self, model):
        """Wrap ``model`` (anything exposing ``named_parameters()``)."""
        self.model = model
        # parameter name -> clean copy of its data taken before perturbation
        self.backup = {}

    def attack(self, epsilon=1., emb_name="word_embeddings"):
        """Perturb matching parameters along their normalised gradient.

        Args:
            epsilon: L2 norm of the perturbation added to each parameter.
            emb_name: substring identifying the embedding parameter(s);
                change it to match the embedding parameter name in your model.

        Matching parameters are always backed up. Parameters with no
        gradient (e.g. before ``backward()``) or with a zero/NaN gradient
        norm are left unchanged.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            # Keep the first (clean) copy if attack() is called again before
            # restore(), so restore() always returns the original weights.
            if name not in self.backup:
                self.backup[name] = param.data.clone()
            if param.grad is None:
                continue
            norm = torch.norm(param.grad)
            # A zero norm would divide by zero; a NaN norm would poison the
            # embedding weights with NaNs.
            if norm != 0 and not torch.isnan(norm):
                param.data.add_(epsilon * param.grad / norm)

    def restore(self, emb_name="word_embeddings"):
        """Restore matching parameters from the backup and clear it.

        Args:
            emb_name: must be the same substring that was passed to ``attack``.

        Raises:
            RuntimeError: if a matching parameter has no backup, e.g. when
                ``restore`` is called without a preceding ``attack``.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                if name not in self.backup:
                    raise RuntimeError(
                        f"FGM.restore: no backup for parameter {name!r}; "
                        "call attack() with the same emb_name first"
                    )
                param.data = self.backup[name]
        self.backup = {}