import logging
import time
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from . import utils, metrics
from .model_wrapper import ModelWrapper

logger = logging.getLogger(__name__)


def get_embeddings(model, config):
    """Returns the wordpiece embedding module."""
    base_model = getattr(model, config.model_type)
    embeddings = base_model.embeddings.word_embeddings
    return embeddings
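
# NOTE (assumption): config.model_type is expected to name the backbone
# attribute of the wrapped model (e.g. "bert" -> model.bert), matching the
# HuggingFace convention of exposing word embeddings at
# <backbone>.embeddings.word_embeddings.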


def run_model(args):
    metric_key = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
    utils.set_seed(args.seed)
    device = args.device

    # Load model, tokenizer, and config.
    logger.info('-> Loading model, tokenizer, etc.')
    config, model, tokenizer = utils.load_pretrained(args, args.model_name)
    model.to(device)
    embedding_gradient = utils.OutputStorage(model, config)
    embeddings = embedding_gradient.embeddings
    predictor = ModelWrapper(model, tokenizer)
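    # NOTE (assumption): utils.OutputStorage is taken to hook the model's
    # word-embedding layer, exposing the embedding module (.embeddings) and,
    # after loss.backward(), the gradient of the loss w.r.t. the input
    # embeddings via .get().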
    if args.prompt:
        prompt_ids = list(args.prompt)
        assert len(prompt_ids) == tokenizer.num_prompt_tokens
    else:
        prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
    print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
    prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
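    # Shape (1, num_prompt_tokens): the leading dimension lets the predictor
    # apply the same prompt to every example in a batch.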
    # Load the dataset and evaluation function.
    evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
    collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
    datasets = utils.load_datasets(args, tokenizer)
    train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
    dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
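    # NOTE (assumption): utils.Collator is taken to pad each batch and emit the
    # 'prompt_mask' field used below to pick out prompt-token gradients.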
    # Record of the best results found so far, seeded with the run arguments.
    best_results = {
        "acc": -float('inf'),
        "F1Score": -float('inf'),
        "best_prompt_ids": None,
        "best_prompt_token": None,
    }
    for k, v in vars(args).items():
        v = str(v.tolist()) if isinstance(v, torch.Tensor) else str(v)
        best_results[str(k)] = v
    torch.save(best_results, args.output)
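    # The snapshot (with placeholder metrics) is written immediately, so an
    # interrupted run still leaves a readable result file at args.output.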
    train_iter = iter(train_loader)
    pharx = tqdm(range(args.iters))
    for iters in pharx:
        start = float(time.time())
        model.zero_grad()
        averaged_grad = None

        # Prompt optimization: accumulate prompt-token gradients over
        # args.accumulation_steps mini-batches.
        phar = tqdm(range(args.accumulation_steps))
        for step in phar:
            try:
                model_inputs = next(train_iter)
            except StopIteration:
                # The loader is exhausted; start a fresh epoch.
                train_iter = iter(train_loader)
                model_inputs = next(train_iter)
            c_labels = model_inputs["labels"].to(device)
            c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
            loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
            loss.backward()
            c_grad = embedding_gradient.get()
            bsz, _, emb_dim = c_grad.size()
            selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
            cp_grad = torch.masked_select(c_grad, selection_mask)
            cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)

            # Accumulate the batch-summed gradient, pre-divided by the number
            # of accumulation steps so the final tensor is an average.
            if averaged_grad is None:
                averaged_grad = cp_grad.sum(dim=0) / args.accumulation_steps
            else:
                averaged_grad += cp_grad.sum(dim=0) / args.accumulation_steps
            del model_inputs
            phar.set_description(f'-> Accumulate grad: [{iters + 1}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{averaged_grad.sum():0.8f}')
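
        # Candidate generation uses the HotFlip-style first-order
        # approximation: the loss change from swapping prompt position fidx to
        # vocabulary word w is estimated by the dot product of w's embedding
        # with the accumulated gradient. A minimal sketch of what
        # utils.hotflip_attack is assumed to compute (hypothetical, mirroring
        # the call below; the real implementation lives in utils):
        #
        #   def hotflip_attack(averaged_grad, embedding_matrix,
        #                      increase_loss=False, num_candidates=1, filter=None):
        #       with torch.no_grad():
        #           scores = torch.matmul(embedding_matrix, averaged_grad)
        #           if not increase_loss:
        #               scores *= -1  # prefer candidates that LOWER the loss
        #           return scores.topk(num_candidates).indices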
        # Pick up to two prompt positions to try flipping this iteration.
        size = min(tokenizer.num_prompt_tokens, 2)
        prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
        for fidx in prompt_flip_idx:
            prompt_candidates = utils.hotflip_attack(averaged_grad[fidx], embeddings.weight,
                                                     increase_loss=False, num_candidates=args.num_cand, filter=None)

            # Select the best prompt among the candidates.
            prompt_denom, prompt_current_score = 0, 0
            prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
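            # Score the incumbent prompt and all candidates on the same freshly
            # drawn batches so the comparison is like-for-like.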
            phar = tqdm(range(args.accumulation_steps))
            for step in phar:
                try:
                    model_inputs = next(train_iter)
                except StopIteration:
                    # The loader is exhausted; start a fresh epoch.
                    train_iter = iter(train_loader)
                    model_inputs = next(train_iter)
                c_labels = model_inputs["labels"].to(device)
                with torch.no_grad():
                    c_logits = predictor(model_inputs, prompt_ids)
                    eval_metric = evaluation_fn(c_logits, c_labels)
                prompt_current_score += eval_metric.sum()
                prompt_denom += c_labels.size(0)

                # Score each candidate on the same batch as the incumbent.
                for i, candidate in enumerate(prompt_candidates):
                    tmp_prompt = prompt_ids.clone()
                    tmp_prompt[:, fidx] = candidate
                    with torch.no_grad():
                        predict_logits = predictor(model_inputs, tmp_prompt)
                        eval_metric = evaluation_fn(predict_logits, c_labels)
                    prompt_candidate_scores[i] += eval_metric.sum()
                del model_inputs
            if (prompt_candidate_scores > prompt_current_score).any():
                best_candidate_score = prompt_candidate_scores.max()
                best_candidate_idx = prompt_candidate_scores.argmax()
                prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx]
                print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13):0.4f}')
                print(f'-> Current best prompt: {utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip: {fidx}')
        del averaged_grad
        # Evaluate the current prompt on clean (dev) samples.
        clean_metric = evaluation_fn.evaluate(dev_loader, prompt_ids)
        if clean_metric[metric_key] > best_results[metric_key]:
            prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
            best_results["best_prompt_ids"] = prompt_ids.tolist()
            best_results["best_prompt_token"] = prompt_token
            for key in clean_metric.keys():
                best_results[key] = clean_metric[key]
            print(f'-> [{iters + 1}/{args.iters}] [Eval] best CAcc: {clean_metric["acc"]}\n-> prompt_token:{prompt_token}\n')
        # Print progress.
        print(f'-> Epoch [{iters + 1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_token:{best_results["best_prompt_token"]}')
        print(f'-> Epoch [{iters + 1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_ids:{best_results["best_prompt_ids"]}\n\n')

        # Save results, stamped with the latest finished iteration so the
        # output file doubles as a progress log.
        cost_time = float(time.time()) - start
        pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time}s save results: {best_results}")
        best_results["curr_iters"] = iters
        best_results["curr_times"] = str(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S'))
        best_results["curr_cost"] = int(cost_time)
        torch.save(best_results, args.output)


if __name__ == '__main__':
    from .augments import get_args

    args = get_args()
    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=level)
    run_model(args)