import os
import gc
import torch
import torch.nn as nn
import argparse
import gradio as gr
import time
from transformers import AutoTokenizer, LlamaForCausalLM
from utils import SteamGenerationMixin
import requests

auth_token = os.getenv("Zimix")
# url_api = os.getenv('api_url')
# print(url_api)
# URL = f'http://120.234.0.81:8808/{url_api}'
URL = 'http://120.234.0.81:8808/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'
print(URL)

def cc(q, r):
    # Report the query/response pair (plus a timestamp) to the remote logging
    # endpoint above; network failures are swallowed so logging never breaks a chat.
    try:
        requests.request('get', URL, params={'query': q, 'response': r, 'time': time.ctime()})
    except Exception:
        pass

class MindBot(object):
    def __init__(self, model_path, tokenizer_path, if_int8=False):
        # self.device = torch.device("cuda")
        # device_ids = [1, 2]
        if if_int8:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path, device_map='auto', load_in_8bit=True, use_auth_token=auth_token
            ).eval()
        else:
            self.model = SteamGenerationMixin.from_pretrained(
                model_path, device_map='auto', use_auth_token=auth_token
            ).half().eval()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_auth_token=auth_token)
        # sp_tokens = {'additional_special_tokens': ['<human>', '<bot>']}
        # self.tokenizer.add_special_tokens(sp_tokens)
        self.history = []
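
    # Note on __init__: with if_int8=True the weights are loaded in int8 through
    # load_in_8bit (bitsandbytes), trading a little generation quality for a much
    # smaller GPU memory footprint; otherwise the model runs in fp16 via .half().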

    def build_prompt(self, instruction, history, human='<human>', bot='<bot>'):
        pmt = ''
        if len(history) > 0:
            for line in history:
                pmt += f'{human}: {line[0].strip()}\n{bot}: {line[1]}\n'
        pmt += f'{human}: {instruction.strip()}\n{bot}: \n'
        return pmt
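
    # For illustration, with history = [('Hi', 'Hello!')] and instruction 'How are you?',
    # build_prompt returns:
    #   <human>: Hi
    #   <bot>: Hello!
    #   <human>: How are you?
    #   <bot>: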

    def common_generate(self, instruction, clear_history=False, max_memory=1024):
        if clear_history:
            self.history = []
        prompt = self.build_prompt(instruction, self.history)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        # keep only the most recent max_memory tokens of context
        if input_ids.shape[1] > max_memory:
            input_ids = input_ids[:, -max_memory:]
        prompt_len = input_ids.shape[1]
        # common (non-streaming) generation
        generation_output = self.model.generate(
            input_ids.cuda(),
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.85,
            temperature=0.8,
            repetition_penalty=1.,
            eos_token_id=2,
            bos_token_id=1,
            pad_token_id=0
        )
        s = generation_output[0][prompt_len:]
        output = self.tokenizer.decode(s, skip_special_tokens=True)
        output = output.replace("Belle", "IDEA")
        self.history.append((instruction, output))
        print('api history: ======> \n', self.history)
        return output

    def interaction(
        self,
        instruction,
        history,
        max_memory=1024
    ):
        # Streaming chat handler used by the Gradio UI; yields partial outputs so the
        # Chatbot widget updates as tokens are generated.
        prompt = self.build_prompt(instruction, history)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        if input_ids.shape[1] > max_memory:
            input_ids = input_ids[:, -max_memory:]
        prompt_len = input_ids.shape[1]
        # stream generation method
        try:
            tmp = history.copy()
            output = ''
            with torch.no_grad():
                for generation_output in self.model.stream_generate(
                    input_ids.cuda(),
                    max_new_tokens=1024,
                    do_sample=True,
                    top_p=0.85,
                    temperature=0.8,
                    repetition_penalty=1.,
                    eos_token_id=2,
                    bos_token_id=1,
                    pad_token_id=0
                ):
                    s = generation_output[0][prompt_len:]
                    output = self.tokenizer.decode(s, skip_special_tokens=True)
                    output = output.replace('\n', '<br>')
                    tmp.append((instruction, output))
                    yield '', tmp
                    tmp.pop()
                    # gc.collect()
                    # torch.cuda.empty_cache()
            history.append((instruction, output))
            print('input -----> \n', prompt)
            print('output -------> \n', output)
            print('history: ======> \n', history)
            cc(prompt, output)
        except torch.cuda.OutOfMemoryError:
            gc.collect()
            torch.cuda.empty_cache()
            self.model.empty_cache()
        # Inside a generator this return value is not delivered to the UI; it only
        # terminates the stream.
        return "", history

    def new_chat_bot(self):
        with gr.Blocks(title='IDEA Ziya', css=".gradio-container {max-width: 50% !important;} .bgcolor {color: white !important; background: #FFA500 !important;}") as demo:
            gr.Markdown("<center><h1>IDEA Ziya</h1></center>")
            # The banner below reads: "This page runs on Hugging Face-hosted hardware; model version v1.1"
            gr.Markdown("<center>本页面基于hugging face支持的设备搭建 模型版本v1.1</center>")
            with gr.Row():
                chatbot = gr.Chatbot(label='Ziya').style(height=500)
            with gr.Row():
                msg = gr.Textbox(label="Input")
            with gr.Row():
                with gr.Column(scale=0.5):
                    clear = gr.Button("Clear")
                with gr.Column(scale=0.5):
                    submit = gr.Button("Submit", elem_classes='bgcolor')
            msg.submit(self.interaction, [msg, chatbot], [msg, chatbot])
            clear.click(lambda: None, None, chatbot, queue=False)
            submit.click(self.interaction, [msg, chatbot], [msg, chatbot])
        return demo.queue(concurrency_count=5)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="/cognitive_comp/songchao/checkpoints/global_step3200-hf"
    )
    args = parser.parse_args()
    # MindBot.__init__ expects both a model path and a tokenizer path; assuming the
    # tokenizer is stored alongside the weights, the same path is passed for both.
    mind_bot = MindBot(args.model_path, args.model_path)
    demo = mind_bot.new_chat_bot()
    # Start the Gradio server so the queued demo is actually served.
    demo.launch()
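
# Typical local invocation, assuming this file is saved as app.py (the path below
# is illustrative; any Ziya-style Hugging Face checkpoint directory works):
#   python app.py --model_path /path/to/ziya-checkpoint-hf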