Spaces:
Runtime error
Runtime error
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from flask import Flask, request | |
| import argparse | |
| import logging | |
class LLMInstance:
    """Single-turn chat wrapper around a causal language model.

    ``__init__`` loads the model and tokenizer; ``query`` sends one user
    message through the tokenizer's chat template and returns the reply
    wrapped in a result dictionary.
    """

    # Text that separates the prompt from the reply in the decoded output.
    _PROMPT_END = "[/INST]"
    # End-of-sequence text that may trail the reply in the decoded output.
    _EOS = "</s>"

    def __init__(self, model_path: str, device: str = "cuda"):
        """Load model and tokenizer from *model_path*; move the model to *device*.

        Args:
            model_path: Value passed to both ``from_pretrained`` calls.
            device: Device the model and the encoded inputs are moved to.
        """
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model.to(device)
        self.device = device

    @classmethod
    def _extract_reply(cls, text, message):
        """Return the part of *text* that follows the prompt's closing marker.

        The user's own message may contain the marker, so the prompt's
        closing marker is the one that comes after all occurrences found in
        *message*. Splitting is bounded so a reply that itself contains the
        marker is kept whole.

        Raises:
            IndexError: If *text* contains no marker at all.
        """
        skip = message.count(cls._PROMPT_END)
        pieces = text.split(cls._PROMPT_END, skip + 1)
        if len(pieces) < skip + 2:
            # The decoded prompt does not reproduce the message's markers;
            # fall back to the first marker.
            skip = 0
            pieces = text.split(cls._PROMPT_END, 1)
        reply = pieces[skip + 1]
        # Drop the end-of-sequence text and anything after it.
        return reply.split(cls._EOS)[0]

    def query(self, message):
        """Generate a reply to *message*.

        Args:
            message: The user's text, used as the content of a single
                ``user`` chat turn.

        Returns:
            A dict with keys ``code``, ``ret``, ``error_msg`` and ``output``.
            On success ``code`` is 0, ``ret`` is True and ``output`` holds the
            reply text. On any exception ``code`` is 1, ``ret`` is False,
            ``error_msg`` holds ``str(exception)`` and ``output`` is None.
        """
        try:
            messages = [
                {"role": "user", "content": message},
            ]
            encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
            model_inputs = encodeds.to(self.device)
            generated_ids = self.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
            decoded = self.tokenizer.batch_decode(generated_ids)
            output = self._extract_reply(decoded[0], message)
        except Exception as e:
            # Failures are reported to the caller in the result envelope;
            # log the traceback so they are not lost on the server side.
            logging.getLogger(__name__).exception("LLM query failed")
            return {
                'code': 1,
                'ret': False,
                'error_msg': str(e),
                'output': None
            }
        return {
            'code': 0,
            'ret': True,
            'error_msg': None,
            'output': output
        }
def create_app(core):
    """Build the Flask application that exposes *core* over HTTP.

    Args:
        core: Object whose ``query(text)`` method returns the response dict
            for a user message.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__)

    # NOTE(review): the view was previously defined but never registered, so
    # the app served no endpoint. The path below is taken from the handler
    # name -- confirm it matches what clients call.
    @app.route('/ask_llm_for_answer', methods=['POST'])
    def ask_llm_for_answer():
        """Answer the ``user_text`` field of the posted JSON body."""
        # silent=True yields None instead of raising on a missing/invalid body.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or 'user_text' not in payload:
            # Same envelope shape as core.query's error result.
            return {
                'code': 1,
                'ret': False,
                'error_msg': "request body must be a JSON object with a 'user_text' field",
                'output': None
            }, 400
        return core.query(payload['user_text'])

    return app
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, load the
    # model and start the HTTP server.
    parser = argparse.ArgumentParser()
    # The flag is required, so a default could never take effect; it used to
    # name 'Mistral-7B-Instruct-v0.1' as an (unreachable) default.
    parser.add_argument('-m', '--model_path', required=True, help='the model path of reward model')
    parser.add_argument('--ip', default='0.0.0.0')
    # type=int so a port given on the command line is an int like the default.
    parser.add_argument('-p', '--port', type=int, default=8001)
    parser.add_argument('--debug', action='store_true')
    args = parser.parse_args()

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG if args.debug else logging.INFO)
    # Configure the new handler directly rather than via handlers[0], which
    # is a different handler if one was already attached to the root logger.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter("%(message)s"))
    root_logger.addHandler(stream_handler)

    core = LLMInstance(args.model_path)
    app = create_app(core)
    app.run(host=args.ip, port=args.port)