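# Gradio chat demo for the Apriel models: streams completions from a
# vLLM-served endpoint through the OpenAI-compatible client, with an
# optional collapsible "Thought" view for reasoning models.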
import datetime

from openai import OpenAI
import gradio as gr

from theme import apriel
from utils import COMMUNITY_POSTFIX_URL, get_model_config, log_message, check_format, models_config

MODEL_TEMPERATURE = 0.8
BUTTON_WIDTH = 160

DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker"
# DEFAULT_MODEL_NAME = "Apriel-5b"

print(f"Gradio version: {gr.__version__}")

chat_start_count = 0
model_config = {}
openai_client = None
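# Point the global OpenAI client at the selected model's endpoint and build
# the feedback link shown above the chat. Returns the description HTML
# unless called with initial=True.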
def setup_model(model_name, initial=False):
    global model_config, openai_client
    model_config = get_model_config(model_name)
    log_message(f"setup_model() --> Model config: {model_config}")
    openai_client = OpenAI(
        api_key=model_config.get('AUTH_TOKEN'),
        base_url=model_config.get('VLLM_API_URL')
    )

    _model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
    _link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
    _description = f"We'd love to hear your thoughts on the model. Click here to provide feedback - {_link}"
    print(f"Switched to model {_model_hf_name}")

    if initial:
        return
    else:
        return _description
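# Generator that Gradio streams into the chatbot: yields the full updated
# message history on every chunk. Reasoning models emit their chain of
# thought first, then the answer after a [BEGIN FINAL RESPONSE] marker.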
def chat_fn(message, history):
    log_message(f"{'-' * 80}")
    log_message(f"chat_fn() --> Message: {message}")
    log_message(f"chat_fn() --> History: {history}")

    # Ignore empty submissions rather than sending a blank prompt to the model
    if not message.strip():
        gr.Warning("Please enter a message before sending.")
        yield history
        return

    global chat_start_count
    chat_start_count = chat_start_count + 1
    # Each reasoning turn stores three messages (user, thought, final answer), hence the division by 3
    print(
        f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, turns: {int(len(history if history else []) / 3)}")

    is_reasoning = model_config.get("REASONING")

    # Remove any assistant messages with metadata from history for multiple turns
    log_message(f"Initial History: {history}")
    check_format(history, "messages")
    history.append({"role": "user", "content": message})
    log_message(f"History with user message: {history}")
    check_format(history, "messages")
    # Create the streaming response
    try:
        # Strip prior "thought" messages (assistant messages tagged with metadata)
        # so only user turns and final answers are sent back to the model
        history_no_thoughts = [item for item in history if
                               not (isinstance(item, dict) and
                                    item.get("role") == "assistant" and
                                    isinstance(item.get("metadata"), dict) and
                                    item.get("metadata", {}).get("title") is not None)]
        log_message(f"Updated History: {history_no_thoughts}")
        check_format(history_no_thoughts, "messages")
        log_message(f"history_no_thoughts with user message: {history_no_thoughts}")

        stream = openai_client.chat.completions.create(
            model=model_config.get('MODEL_NAME'),
            messages=history_no_thoughts,
            temperature=MODEL_TEMPERATURE,
            stream=True
        )
    except Exception as e:
        print(f"Error: {e}")
        yield [{"role": "assistant", "content": "😔 The model is unavailable at the moment. Please try again later."}]
        return
    if is_reasoning:
        history.append(gr.ChatMessage(
            role="assistant",
            content="Thinking...",
            metadata={"title": "🧠 Thought"}
        ))
        log_message(f"History added thinking: {history}")
        check_format(history, "messages")
    else:
        history.append(gr.ChatMessage(
            role="assistant",
            content="",
        ))
        log_message(f"History added empty assistant: {history}")
        check_format(history, "messages")
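    # Accumulate the streamed tokens. For reasoning models, the text before
    # [BEGIN FINAL RESPONSE] stays in the "Thought" message and the text after
    # it becomes a separate assistant message appended to the history.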
| output = "" | |
| completion_started = False | |
| for chunk in stream: | |
| # Extract the new content from the delta field | |
| content = getattr(chunk.choices[0].delta, "content", "") | |
| output += content | |
| if is_reasoning: | |
| parts = output.split("[BEGIN FINAL RESPONSE]") | |
| if len(parts) > 1: | |
| if parts[1].endswith("[END FINAL RESPONSE]"): | |
| parts[1] = parts[1].replace("[END FINAL RESPONSE]", "") | |
| if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"): | |
| parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "") | |
| if parts[1].endswith("<|end|>"): | |
| parts[1] = parts[1].replace("<|end|>", "") | |
| history[-1 if not completion_started else -2] = gr.ChatMessage( | |
| role="assistant", | |
| content=parts[0], | |
| metadata={"title": "🧠 Thought"} | |
| ) | |
| if completion_started: | |
| history[-1] = gr.ChatMessage( | |
| role="assistant", | |
| content=parts[1] | |
| ) | |
| elif len(parts) > 1 and not completion_started: | |
| completion_started = True | |
| history.append(gr.ChatMessage( | |
| role="assistant", | |
| content=parts[1] | |
| )) | |
| else: | |
| if output.endswith("<|end|>"): | |
| output = output.replace("<|end|>", "") | |
| history[-1] = gr.ChatMessage( | |
| role="assistant", | |
| content=output | |
| ) | |
| # log_message(f"Yielding messages: {history}") | |
| yield history | |
| log_message(f"Final History: {history}") | |
| check_format(history, "messages") | |
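# Build the Gradio UI: model selector, chatbot, input row, and event wiring.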
title = None
description = None

# theme = gr.themes.Default(primary_hue="green")
# theme = gr.themes.Soft(primary_hue="gray", secondary_hue="slate", neutral_hue="slate",
#                        text_size=gr.themes.sizes.text_lg, font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])
# theme = gr.Theme.from_hub("earneleh/paris")
theme = apriel
with gr.Blocks(theme=theme) as demo:
    gr.HTML("""
    <style>
        .html-container:has(.css-styles) {
            padding: 0;
            margin: 0;
        }

        .css-styles { height: 0; }

        .model-message {
            text-align: end;
        }

        .model-dropdown-container {
            display: flex;
            align-items: center;
            gap: 10px;
            padding: 0;
        }

        .chatbot {
            max-height: 1400px;
        }

        @media (max-width: 800px) {
            .responsive-row {
                flex-direction: column;
            }
            .model-message {
                text-align: start;
                font-size: 10px !important;
            }
            .model-dropdown-container {
                flex-direction: column;
                align-items: flex-start;
            }
            .chatbot {
                max-height: 850px;
            }
        }

        @media (max-width: 400px) {
            .responsive-row {
                flex-direction: column;
            }
            .model-message {
                text-align: start;
                font-size: 10px !important;
            }
            .model-dropdown-container {
                flex-direction: column;
                align-items: flex-start;
            }
            .chatbot {
                max-height: 400px;
            }
        }
    """ + f"""
        @media (min-width: 1024px) {{
            .send-button-container, .clear-button-container {{
                max-width: {BUTTON_WIDTH}px;
            }}
        }}
    </style>
    """, elem_classes="css-styles")
    with gr.Row(variant="panel", elem_classes="responsive-row"):
        with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
            model_dropdown = gr.Dropdown(
                choices=[f"Model: {model}" for model in models_config.keys()],
                value=f"Model: {DEFAULT_MODEL_NAME}",
                label=None,
                interactive=True,
                container=False,
                scale=0,
                min_width=400
            )
        with gr.Column(scale=4, min_width=0):
            description_html = gr.HTML(description, elem_classes="model-message")

    chatbot = gr.Chatbot(
        type="messages",
        height="calc(100dvh - 280px)",
        elem_classes="chatbot",
    )
    # chat_interface = gr.ChatInterface(
    #     chat_fn,
    #     description="",
    #     type="messages",
    #     chatbot=chatbot,
    #     fill_height=True,
    # )
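    # Input row: textbox plus Send and "New Chat" buttons; Enter and Send both
    # stream chat_fn into the chatbot, then a second handler clears the textbox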
    with gr.Row():
        with gr.Column(scale=10, min_width=400, elem_classes="user-input-container"):
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Type your message here and press Enter",
                container=False,
            )
        with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
            with gr.Row():
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
                    send_btn = gr.Button("Send", variant="primary")
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
                    clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary")

    # On Enter: stream into the chatbot, then clear the textbox
    user_input.submit(
        fn=chat_fn,
        inputs=[user_input, chatbot],
        outputs=[chatbot]
    )
    user_input.submit(lambda: "", None, user_input, queue=False)

    send_btn.click(
        fn=chat_fn,
        inputs=[user_input, chatbot],
        outputs=[chatbot]
    )
    send_btn.click(lambda: "", None, user_input, queue=False)
    # Ensure the model is reset to default on page reload
    demo.load(lambda: setup_model(DEFAULT_MODEL_NAME, initial=False), [], [description_html])

    def update_model_and_clear(model_name):
        actual_model_name = model_name.replace("Model: ", "")
        desc = setup_model(actual_model_name)
        return desc, []

    model_dropdown.change(
        fn=update_model_and_clear,
        inputs=[model_dropdown],
        outputs=[description_html, chatbot]
    )

# ssr_mode=False keeps rendering client-side; show_api=False hides the API docs link
demo.launch(ssr_mode=False, show_api=False)