from transformers import AutoModelForCausalLM, GPT2Tokenizer
from easyeditor import apply_grace_to_model, GraceHyperParams
import torch
import gradio as gr
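# Gradio demo for GRACE (Hartvigsen et al., "Aging with GRACE: Lifelong Model
# Editing with Discrete Key-Value Adaptors") on a local GPT-2 checkpoint, via
# EasyEdit. `edit` applies one edit and caches the edited model in a global;
# `generate` contrasts the edited and unedited models on the same prompt.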
def edit(prompt, target_new, num_steps, replacement):
    """Apply a single GRACE edit to GPT-2 and cache the edited model globally."""
    request = {"prompt": prompt, "target_new": target_new}
    hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")

    model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map="cpu")
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    global edit_model
    edit_model = apply_grace_to_model(model, tok, request, hparams, num_steps, replacement)
    return prompt
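# Example call with hypothetical values (the `replacement` strategy name is an
# assumption here, taken from GRACE's hyperparameter options):
#   edit("Who is the CEO of Apple?", "Tim Cook", num_steps=40, replacement="replace_last")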
def generate(input_text, target_new=None):
    """Generate from both the edited and the original model for side-by-side display."""
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id
    global edit_model

    # Budget just enough new tokens to cover the edit target, else a fixed default.
    if target_new is None:
        max_new_tokens = 25
    else:
        max_new_tokens = len(tok.encode(target_new))

    prompt_len = len(input_text)
    input_ids = tok.encode(input_text, return_tensors="pt").to("cpu")

    edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens,
                                      pad_token_id=tok.eos_token_id)
    edit_reply = tok.decode(edit_output[0], skip_special_tokens=True)
    torch.cuda.empty_cache()  # no-op on CPU-only runs

    # Reload the unedited checkpoint for comparison.
    ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to("cpu")
    ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens,
                                    pad_token_id=tok.eos_token_id)
    ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)

    # Tag characters at or past the prompt length as generated output, in the
    # (token, label) format that gr.HighlightedText expects.
    ori_reply = [(ch, "output") if i >= prompt_len else (ch, None) for i, ch in enumerate(ori_reply)]
    edit_reply = [(ch, "output") if i >= prompt_len else (ch, None) for i, ch in enumerate(edit_reply)]
    return ori_reply, edit_reply
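
# A minimal sketch of how these two functions could be wired into a Gradio UI.
# The layout, labels, default values, and the `replacement` choices below are
# assumptions for illustration, not necessarily the original demo's interface.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Edit prompt")
        target_new = gr.Textbox(label="Edit target")
        num_steps = gr.Slider(1, 100, value=40, step=1, label="Optimization steps")
        replacement = gr.Dropdown(
            ["replace_last", "replace_all", "replace_prompt"],  # assumed GRACE strategies
            value="replace_last", label="Token replacement strategy",
        )
        edit_btn = gr.Button("Apply edit")
        gen_btn = gr.Button("Generate")
        ori_out = gr.HighlightedText(label="Pre-edit output")
        edit_out = gr.HighlightedText(label="Post-edit output")

        edit_btn.click(edit, inputs=[prompt, target_new, num_steps, replacement], outputs=prompt)
        gen_btn.click(generate, inputs=[prompt, target_new], outputs=[ori_out, edit_out])
    demo.launch()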