Spaces:
Running
Running
| import base64 | |
| import gradio as gr | |
| import json | |
| import mimetypes | |
| import os | |
| import requests | |
| import time | |
# --- Required settings: a missing variable raises KeyError at import time. ---
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
# Parsed as JSON; the UI below reads the keys 'tokens_to_generate',
# 'temperature' and 'top_p' from it as slider start values.
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])

# --- Optional settings: each is None when the variable is unset. ---
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.environ.get('MULTIMODAL')
# Chat role -> display name sent as the message's 'name' field (None = omit).
NAME_MAP = {
    role: os.environ.get(env_var)
    for role, env_var in (
        ('system', 'SYSTEM_NAME'),
        ('user', 'USER_NAME'),
    )
}
def respond(
    message,
    history,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for *message* given the prior *history*.

    Builds the request message list (optional system prompt, converted
    history, then the new user turn), POSTs it to ``API_URL`` as a streaming
    request and yields the reply as it grows.

    Args:
        message: The new user turn; a plain string, or a value accepted by
            ``convert_content`` (e.g. a dict with ``text``/``files`` keys).
        history: Prior turns, each a dict with ``role`` and ``content`` keys.
        max_tokens: Forwarded as the request's ``max_tokens``.
        temperature: Forwarded as the request's ``temperature``.
        top_p: Forwarded as the request's ``top_p``.

    Yields:
        The accumulated reply text after each non-empty streamed delta, or
        the full message text when a chunk carries ``message`` instead of
        ``delta``.

    Raises:
        gr.Error: If the server answers with an HTTP error status, or a
            streamed event has no ``choices`` key.
    """
    payload = {
        'model': MODEL_VERSION,
        'messages': _build_messages(message, history),
        'stream': True,
        'max_tokens': max_tokens,
        'temperature': temperature,
        'top_p': top_p,
    }
    # The context manager closes the streamed connection even when the
    # consumer stops iterating early or an error is raised mid-stream.
    with requests.post(
        API_URL,
        headers={
            'Content-Type': 'application/json',
            'Authorization': 'Bearer {}'.format(API_KEY),
        },
        data=json.dumps(payload),
        stream=True,
    ) as response:
        # An error body is usually not an event stream; surface it instead of
        # silently producing no reply.
        if not response.ok:
            raise gr.Error(
                'request failed (HTTP {})'.format(response.status_code)
            )
        yield from _stream_reply(response)


def _build_messages(message, history):
    """Assemble the request message list: system prompt, history, new turn."""
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({'role': 'system', 'content': SYSTEM_PROMPT})
    for turn in history:
        messages.append({
            'role': turn['role'],
            'content': convert_content(turn['content']),
        })
    messages.append({'role': 'user', 'content': convert_content(message)})
    # Distinct loop name so the `message` argument is not shadowed.
    for entry in messages:
        add_name_for_message(entry)
    return messages


def _stream_reply(response):
    """Yield the growing reply text parsed from a ``data:``-line event stream."""
    reply = ''
    for row in response.iter_lines():
        if not row.startswith(b'data:'):
            continue
        body = row[5:].strip()
        # Empty payloads are keep-alives; `[DONE]` is the end-of-stream
        # sentinel sent by OpenAI-compatible APIs and is not valid JSON.
        if not body:
            continue
        if body == b'[DONE]':
            break
        event = json.loads(body)
        if 'choices' not in event:
            raise gr.Error('request failed')
        if not event['choices']:
            # e.g. a trailing usage-only chunk with an empty choices list.
            continue
        choice = event['choices'][0]
        if 'delta' in choice:
            # Role-only / finish-only deltas have `content` missing or null.
            piece = (choice['delta'] or {}).get('content')
            if piece:
                reply += piece
                yield reply
        elif 'message' in choice:
            yield choice['message']['content']
def add_name_for_message(message):
    """Attach the configured display name to *message* in place.

    Looks up ``message['role']`` in ``NAME_MAP``. When a non-None name is
    configured for that role, it is stored under the ``'name'`` key;
    otherwise the dict is left untouched. Returns None.
    """
    role_name = NAME_MAP.get(message['role'])
    if role_name is None:
        return
    message['name'] = role_name
def convert_content(content):
    """Convert chat content into a request ``content`` value.

    - ``str``: returned unchanged.
    - ``tuple``: its first element is treated as a file path; returns a
      one-item list holding an ``image_url`` part built by ``encode_base64``.
    - anything else: iterated as a mapping. A ``'text'`` key becomes a
      ``text`` part and each path under ``'files'`` becomes an ``image_url``
      part, in the mapping's key order; other keys are ignored.
    """
    if isinstance(content, str):
        return content

    def to_image_part(file_path):
        # Inline the file as a base64 data URL.
        return {
            'type': 'image_url',
            'image_url': {'url': encode_base64(file_path)},
        }

    if isinstance(content, tuple):
        return [to_image_part(content[0])]

    parts = []
    for field, value in content.items():
        if field == 'text':
            parts.append({'type': 'text', 'text': value})
        elif field == 'files':
            parts.extend(to_image_part(file_path) for file_path in value)
    return parts
def encode_base64(path):
    """Return the image file at *path* as a ``data:`` URL.

    The MIME type is guessed from the file name with
    ``mimetypes.guess_type``. The file is read in binary mode and its bytes
    are base64-encoded.

    Raises:
        gr.Error: If the guessed type is unknown or is not ``image/*``.
    """
    mime_type, _ = mimetypes.guess_type(path)
    # guess_type yields None for unrecognised extensions; report that as
    # "not an image" rather than crashing on None.startswith.
    if mime_type is None or not mime_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(mime_type, path))
    with open(path, 'rb') as handle:
        encoded = base64.b64encode(handle.read()).decode('ascii')
    return 'data:{};base64,{}'.format(mime_type, encoded)
# Chat UI. The additional inputs are passed to respond() after
# (message, history) in list order: max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    respond,
    type='messages',
    multimodal=(MULTIMODAL_FLAG == 'ON'),
    additional_inputs=[
        gr.Slider(
            label='Tokens to generate',
            minimum=1,
            maximum=1000000,
            step=1,
            value=MODEL_CONTROL_DEFAULTS['tokens_to_generate'],
        ),
        gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODEL_CONTROL_DEFAULTS['temperature'],
        ),
        gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODEL_CONTROL_DEFAULTS['top_p'],
        ),
    ],
)

if __name__ == '__main__':
    # Enable the request queue with a default per-event concurrency limit
    # of 50, then start the server.
    demo.queue(default_concurrency_limit=50).launch()