Spaces:
Running
Running
| import base64 | |
| import gradio as gr | |
| import json | |
| import mimetypes | |
| import os | |
| import requests | |
| import time | |
| import modelscope_studio.components.antd as antd | |
| import modelscope_studio.components.antdx as antdx | |
| import modelscope_studio.components.base as ms | |
| import modelscope_studio.components.pro as pro | |
| from modelscope_studio.components.pro.chatbot import ( | |
| ChatbotActionConfig, ChatbotBotConfig, ChatbotMarkdownConfig, | |
| ChatbotPromptsConfig, ChatbotUserConfig, ChatbotWelcomeConfig) | |
# Required deployment settings: a missing variable raises KeyError at import.
MODEL_VERSION = os.environ['MODEL_VERSION']
API_URL = os.environ['API_URL']
API_KEY = os.environ['API_KEY']
# Optional settings: None when the variable is not set.
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT')
MULTIMODAL_FLAG = os.environ.get('MULTIMODAL')
# Required JSON object holding the sampling defaults sent with each request
# (keys read by submit(): tokens_to_generate, temperature, top_p).
MODEL_CONTROL_DEFAULTS = json.loads(os.environ['MODEL_CONTROL_DEFAULTS'])
# Optional per-role display names; a None entry means "attach no name".
NAME_MAP = {
    role: os.environ.get(env_var)
    for role, env_var in (('system', 'SYSTEM_NAME'), ('user', 'USER_NAME'))
}
# Label shown as the bot header in the chat UI.
MODEL_NAME = 'MiniMax-M1'
def prompt_select(e: gr.EventData):
    """Copy the description of the clicked welcome prompt into the sender.

    Args:
        e: Event data from the chatbot's welcome-prompt selection; its first
            payload entry holds the selected prompt item under "value".

    Returns:
        A ``gr.update`` setting the bound component's value to the prompt's
        description text.
    """
    selected_item = e._data["payload"][0]["value"]
    prompt_text = selected_item["description"]
    return gr.update(value=prompt_text)
def clear():
    """Return an update that empties the bound component (the chatbot history)."""
    empty_value = None
    return gr.update(value=empty_value)
def retry(chatbot_value, e: gr.EventData):
    """Regenerate a reply by truncating the history and re-running submit.

    Args:
        chatbot_value: Current chatbot message list.
        e: Retry event data; its first payload entry carries the index of the
            message whose retry action was clicked.

    Yields:
        First a (sender, chatbot, clear_btn) update tuple that shows the
        truncated history and locks the input, then every update produced by
        ``submit`` for the truncated history.
    """
    cut_at = e._data["payload"][0]["index"]
    # Drop the retried message and everything after it.
    kept_history = chatbot_value[:cut_at]
    yield gr.update(loading=True), gr.update(value=kept_history), gr.update(disabled=True)
    for update in submit(None, kept_history):
        yield update
def cancel(chatbot_value):
    """Mark the in-flight assistant reply as stopped and unlock the input.

    Args:
        chatbot_value: Current chatbot message list (may be empty or None).
            When non-empty, its last message is updated in place.

    Returns:
        A (chatbot, sender, clear_btn) tuple of ``gr.update`` objects: the
        (possibly updated) history, the sender with loading turned off, and
        the clear button re-enabled.
    """
    # The cancel event is registered with queue=False, so it may run before
    # submit's first update has reached the browser; the history can then
    # still be empty and there is no reply to mark.
    if chatbot_value:
        last_message = chatbot_value[-1]
        last_message["loading"] = False
        last_message["status"] = "done"
        last_message["footer"] = "Chat completion paused"
    return gr.update(value=chatbot_value), gr.update(loading=False), gr.update(disabled=False)
def add_name_for_message(message):
    """Attach the configured display name for the message's role, in place.

    Looks up ``message['role']`` in NAME_MAP. When a non-None name is found it
    is stored under ``message['name']``; otherwise the message is untouched.
    Returns None.
    """
    role_name = NAME_MAP.get(message['role'])
    if role_name is None:
        return
    message['name'] = role_name
def convert_content(content):
    """Translate a chatbot message payload into the API ``content`` field.

    - ``str``: returned unchanged.
    - ``tuple``: its first element is an image path; the result is a
      one-element list holding an ``image_url`` part.
    - anything else is iterated with ``.items()``: a ``'text'`` entry becomes
      a text part and every path under ``'files'`` becomes an ``image_url``
      part, in the mapping's key order. Other keys are ignored.

    Image paths are encoded via ``encode_base64``.
    """
    if isinstance(content, str):
        return content

    def image_part(path):
        # Inline the image as a base64 data URL.
        return {
            'type': 'image_url',
            'image_url': {
                'url': encode_base64(path),
            },
        }

    if isinstance(content, tuple):
        return [image_part(content[0])]

    parts = []
    for key, val in content.items():
        if key == 'text':
            parts.append({'type': 'text', 'text': val})
        elif key == 'files':
            parts.extend(image_part(f) for f in val)
    return parts
def encode_base64(path):
    """Return the image file at *path* as a base64 ``data:`` URL.

    The MIME type is guessed from the path with ``mimetypes.guess_type``;
    only ``image/*`` types are accepted.

    Args:
        path: Filesystem path of the image to inline.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        gr.Error: if the type cannot be guessed or is not an image type.
    """
    guess_type = mimetypes.guess_type(path)[0]
    # guess_type is None for unknown or missing extensions; check it
    # explicitly so the user gets gr.Error rather than an AttributeError
    # from calling .startswith on None.
    if guess_type is None or not guess_type.startswith('image/'):
        raise gr.Error('not an image ({}): {}'.format(guess_type, path))
    with open(path, 'rb') as handle:
        data = handle.read()
    return 'data:{};base64,{}'.format(
        guess_type,
        base64.b64encode(data).decode(),
    )
def format_history(history):
    """Convert chatbot history format to API call format.

    Builds the ``messages`` list for the completion request:

    - a leading system message when SYSTEM_PROMPT is set;
    - each user item with its content passed through ``convert_content``;
    - each assistant item split into ``content`` (from its "text" part) and
      ``reasoning_content`` (from its "tool" part) when its content is a
      list, or with the whole content as ``content`` otherwise.

    Items with any other role are skipped.
    """
    messages = []
    if SYSTEM_PROMPT is not None:
        messages.append({'role': 'system', 'content': SYSTEM_PROMPT})

    for item in history:
        role = item["role"]
        if role == "user":
            messages.append({
                'role': 'user',
                'content': convert_content(item["content"]),
            })
            continue
        if role != "assistant":
            continue

        body = item["content"]
        reasoning, answer = "", ""
        if isinstance(body, list):
            # Later parts of the same type overwrite earlier ones.
            for part in body:
                part_type = part.get("type")
                if part_type == "tool":
                    reasoning = part.get("content", "")
                elif part_type == "text":
                    answer = part.get("content", "")
        else:
            answer = body

        messages.append({
            'role': 'assistant',
            'content': convert_content(answer),
            'reasoning_content': convert_content(reasoning),
        })
    return messages
# (connect, read) timeouts in seconds for the completion request. With
# stream=True the read timeout bounds each wait for the next chunk rather than
# the whole response, so long generations are not cut off.
_REQUEST_TIMEOUT = (30, 600)


def _iter_stream_events(response):
    """Yield the JSON object carried by each ``data:`` line of *response*.

    Lines without the ``data:`` prefix (e.g. blank separators) and empty
    ``data:`` lines are skipped; an OpenAI-style ``data: [DONE]`` sentinel
    ends the stream instead of being passed to ``json.loads``.
    """
    for row in response.iter_lines():
        if not row.startswith(b'data:'):
            continue
        payload = row[5:].strip()
        if not payload:
            continue
        if payload == b'[DONE]':
            return
        yield json.loads(payload)


def _elapsed_since(start):
    """Return seconds elapsed since *start*, or 0.0 when *start* is None."""
    return 0.0 if start is None else time.time() - start


def _thought_title(duration):
    """Return the reasoning-part title shown once thinking has finished."""
    return "End of Thought ({:.2f}s)".format(duration)


def submit(sender_value, chatbot_value):
    """Send the conversation to the completion API and stream the reply.

    Generator used as a Gradio event handler: every yielded dict maps UI
    components (``sender``, ``clear_btn``, ``chatbot``) to ``gr.update``
    objects.

    Args:
        sender_value: New user input to append, or None to regenerate from
            the existing history (as ``retry`` does).
        chatbot_value: Chatbot message list, mutated in place. A trailing
            assistant message is appended and filled with a "tool" part
            (reasoning) followed by a "text" part (answer).

    Raises:
        Any error from the request or stream parsing is re-raised after the
        assistant message has been replaced by a failure notice and the
        inputs have been unlocked.
    """
    # The chatbot may hold None (the clear handler sets value=None); treat
    # that as an empty history.
    if chatbot_value is None:
        chatbot_value = []
    if sender_value is not None:
        chatbot_value.append({
            "role": "user",
            "content": sender_value,
        })
    api_messages = format_history(chatbot_value)
    for message in api_messages:
        add_name_for_message(message)
    # Placeholder for the streamed reply, shown as loading until data arrives.
    chatbot_value.append({
        "role": "assistant",
        "content": [],
        "loading": True,
        "status": "pending"
    })
    yield {
        sender: gr.update(value=None, loading=True),
        clear_btn: gr.update(disabled=True),
        chatbot: gr.update(value=chatbot_value)
    }
    try:
        request_body = {
            'model': MODEL_VERSION,
            'messages': api_messages,
            'stream': True,
            'max_tokens': MODEL_CONTROL_DEFAULTS['tokens_to_generate'],
            'temperature': MODEL_CONTROL_DEFAULTS['temperature'],
            'top_p': MODEL_CONTROL_DEFAULTS['top_p'],
        }
        # The context manager releases the streamed connection even when the
        # generator is closed early (e.g. the event is cancelled) or fails.
        with requests.post(
            API_URL,
            headers={
                'Content-Type': 'application/json',
                'Authorization': 'Bearer {}'.format(API_KEY),
            },
            data=json.dumps(request_body),
            stream=True,
            timeout=_REQUEST_TIMEOUT,
        ) as r:
            start_time = time.time()
            # Error responses carry no data: lines; without this check they
            # would silently finish as an empty "done" reply.
            r.raise_for_status()
            message_content = chatbot_value[-1]["content"]
            # Reasoning content (tool type); its title reports the thinking state.
            thinking_part = {
                "type": "tool",
                "content": "",
                "options": {
                    "title": "🤔 Thinking..."
                }
            }
            # Main content (text type).
            answer_part = {
                "type": "text",
                "content": "",
            }
            message_content.append(thinking_part)
            message_content.append(answer_part)
            thought_done = False
            reasoning_start_time = None
            reasoning_duration = None
            for event in _iter_stream_events(r):
                if 'choices' not in event:
                    raise gr.Error('request failed')
                if not event['choices']:
                    # A chunk with an empty choices list carries no text.
                    continue
                choice = event['choices'][0]
                if 'delta' in choice:
                    delta = choice['delta']
                    # Fields may be absent or null in an individual chunk.
                    reasoning_content = delta.get('reasoning_content') or ''
                    content = delta.get('content') or ''
                    chatbot_value[-1]["loading"] = False
                    if reasoning_content:
                        if reasoning_start_time is None:
                            reasoning_start_time = time.time()
                        thinking_part["content"] += reasoning_content
                    if content:
                        answer_part["content"] += content
                        # The first answer text ends the thinking phase.
                        if not thought_done:
                            thought_done = True
                            reasoning_duration = _elapsed_since(reasoning_start_time)
                            thinking_part["options"] = {"title": _thought_title(reasoning_duration)}
                    yield {chatbot: gr.update(value=chatbot_value)}
                elif 'message' in choice:
                    # A complete message replaces whatever was streamed so far.
                    message_data = choice['message']
                    reasoning_content = message_data.get('reasoning_content') or ''
                    main_content = message_data.get('content') or ''
                    thinking_part["content"] = reasoning_content
                    answer_part["content"] = main_content
                    if reasoning_content and main_content:
                        if reasoning_duration is None:
                            reasoning_duration = _elapsed_since(reasoning_start_time)
                        thinking_part["options"] = {"title": _thought_title(reasoning_duration)}
                    chatbot_value[-1]["loading"] = False
                    yield {chatbot: gr.update(value=chatbot_value)}
        chatbot_value[-1]["footer"] = "{:.2f}s".format(time.time() - start_time)
        chatbot_value[-1]["status"] = "done"
        yield {
            clear_btn: gr.update(disabled=False),
            sender: gr.update(loading=False),
            chatbot: gr.update(value=chatbot_value),
        }
    except Exception as e:
        chatbot_value[-1]["loading"] = False
        chatbot_value[-1]["status"] = "done"
        chatbot_value[-1]["content"] = "Request failed, please try again."
        yield {
            clear_btn: gr.update(disabled=False),
            sender: gr.update(loading=False),
            chatbot: gr.update(value=chatbot_value),
        }
        raise e
# Page layout: a chatbot above a sender input with a "clear" prefix button,
# followed by the event wiring for submit / retry / cancel / clear.
with gr.Blocks() as demo, ms.Application(), antdx.XProvider():
    with antd.Flex(vertical=True, gap="middle"):
        chatbot = pro.Chatbot(
            height="calc(100vh - 200px)",
            markdown_config=ChatbotMarkdownConfig(allow_tags=["think"]),
            welcome_config=ChatbotWelcomeConfig(
                variant="borderless",
                icon="./assets/minimax-logo.png",
                title="Hello, I'm MiniMax-M1",
                description="You can input text to get started.",
                prompts=ChatbotPromptsConfig(
                    title="How can I help you today?",
                    styles={
                        "list": {
                            "width": '100%',
                        },
                        "item": {
                            "flex": 1,
                        },
                    },
                    items=[{
                        "label": "🤔 Logical Reasoning",
                        "children": [{
                            "description": "A is taller than B, B is shorter than C. Who is taller, A or C?"
                        }, {
                            "description": "Alice put candy in the drawer and went out. Bob moved the candy to the cabinet. Where will Alice look for the candy when she returns?"
                        }]
                    }, {
                        "label": "📚 Knowledge Q&A",
                        "children": [{
                            "description": "Can you tell me about middle school mathematics?"
                        }, {
                            "description": "If Earth's gravity suddenly halved, what would happen to the height humans can jump?"
                        }]
                    }])),
            user_config=ChatbotUserConfig(actions=["copy", "edit"]),
            bot_config=ChatbotBotConfig(
                header=MODEL_NAME,
                avatar="./assets/minimax-logo.png",
                actions=["copy", "retry"]
            ),
        )
        with antdx.Sender() as sender:
            with ms.Slot("prefix"):
                with antd.Button(value=None, color="default", variant="text") as clear_btn:
                    with ms.Slot("icon"):
                        antd.Icon("ClearOutlined")
    clear_btn.click(fn=clear, outputs=[chatbot])
    submit_event = sender.submit(
        fn=submit,
        inputs=[sender, chatbot],
        outputs=[sender, chatbot, clear_btn]
    )
    # retry also puts the sender into its loading state, which exposes the
    # cancel action; keep its event handle so cancel can stop it as well.
    retry_event = chatbot.retry(
        fn=retry,
        inputs=[chatbot],
        outputs=[sender, chatbot, clear_btn]
    )
    sender.cancel(
        fn=cancel,
        inputs=[chatbot],
        outputs=[chatbot, sender, clear_btn],
        cancels=[submit_event, retry_event],
        queue=False
    )
    chatbot.welcome_prompt_select(fn=prompt_select, outputs=[sender])
if __name__ == '__main__':
    # Enable Gradio's request queue with a default per-event concurrency
    # limit of 50, then start the server with a public share link requested.
    queued_demo = demo.queue(default_concurrency_limit=50)
    queued_demo.launch(share=True)