Spaces:
Running
Running
| import os | |
| import sys | |
| import time | |
| from functools import wraps | |
| from typing import Any, Literal | |
| from gradio import ChatMessage | |
| from gradio.components.chatbot import Message | |
# URL suffix for a "discussions" page (presumably appended to a model's
# Hugging Face URL — confirm against callers).
COMMUNITY_POSTFIX_URL = "/discussions"

# Debug switches: each is True only when the matching environment variable is
# set to the exact string "True".
DEBUG_MODE = os.getenv("DEBUG_MODE") == "True"
DEBUG_MODEL = os.getenv("DEBUG_MODEL") == "True"
# Registry of selectable models, keyed by the name passed to get_model_config().
# Endpoint, model-name and token values are read from environment variables
# when this module is imported; a variable that is not set yields None.
models_config = {
    "Apriel-1.5-15B-thinker": {
        "MODEL_DISPLAY_NAME": "Apriel-1.5-15B-thinker",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-1.5-15b-Thinker",
        "MODEL_NAME": os.getenv("MODEL_NAME_APRIEL_1_5_15B"),
        "VLLM_API_URL": os.getenv("VLLM_API_URL_APRIEL_1_5_15B"),
        "VLLM_API_URL_LIST": os.getenv("VLLM_API_URL_LIST_APRIEL_1_5_15B"),
        "AUTH_TOKEN": os.getenv("AUTH_TOKEN"),
        "REASONING": True,
        "MULTIMODAL": True,
    },
    # Disabled model entries, kept for reference:
    # "Apriel-Nemotron-15b-Thinker": {
    #     "MODEL_DISPLAY_NAME": "Apriel-Nemotron-15b-Thinker",
    #     "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
    #     "MODEL_NAME": os.environ.get("MODEL_NAME_NEMO_15B"),
    #     "VLLM_API_URL": os.environ.get("VLLM_API_URL_NEMO_15B"),
    #     "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
    #     "REASONING": True,
    #     "MULTIMODAL": False
    # },
    # "Apriel-5b": {
    #     "MODEL_DISPLAY_NAME": "Apriel-5b",
    #     "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
    #     "MODEL_NAME": os.environ.get("MODEL_NAME_5B"),
    #     "VLLM_API_URL": os.environ.get("VLLM_API_URL_5B"),
    #     "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
    #     "REASONING": False,
    #     "MULTIMODAL": False
    # }
}
def get_model_config(model_name: str) -> dict:
    """Return the validated configuration entry for ``model_name``.

    The entry is looked up in ``models_config`` and tagged in place with a
    ``MODEL_KEY`` field holding ``model_name`` before it is validated.

    Args:
        model_name: Key of the model in ``models_config``.

    Returns:
        The configuration dict stored in ``models_config`` (the same object,
        not a copy).

    Raises:
        ValueError: If ``model_name`` is not a key of ``models_config``, or if
            the entry's ``MODEL_NAME`` or ``VLLM_API_URL`` is missing or empty.
    """
    config = models_config.get(model_name)
    # Reject unknown models before touching the entry: assigning into a
    # missing (None) entry would raise TypeError and hide the intended error.
    if config is None:
        raise ValueError(f"Model {model_name} not found in models_config")
    config['MODEL_KEY'] = model_name
    if not config.get("MODEL_NAME"):
        raise ValueError(f"Model name not found in config for {model_name}")
    if not config.get("VLLM_API_URL"):
        raise ValueError(f"VLLM API URL not found in config for {model_name}")
    return config
| def _log_message(prefix, message, icon=""): | |
| timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |
| if len(icon) > 0: | |
| icon = f"{icon} " | |
| print(f"{timestamp}: {prefix} {icon}{message}") | |
def log_debug(message):
    """Emit ``message`` with the DEBUG prefix, only when ``DEBUG_MODE`` is exactly ``True``."""
    if DEBUG_MODE is not True:
        return
    _log_message("DEBUG", message)
def log_info(message):
    """Emit ``message`` with the INFO prefix (always printed)."""
    _log_message(prefix="INFO ", message=message)
def log_warning(message):
    """Emit ``message`` with the WARN prefix and a warning icon."""
    _log_message(prefix="WARN ", message=message, icon="⚠️")
def log_error(message):
    """Emit ``message`` with the ERROR prefix and an alert icon."""
    _log_message(prefix="ERROR", message=message, icon="‼️")
# Gradio 5.0.1 had issues with checking the message formats. 5.29.0 does not!
def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate chat history against the given chatbot data format.

    Does nothing unless ``DEBUG_MODE`` is truthy.

    Args:
        messages: Iterable of chat entries to validate.
        type: ``"messages"`` requires every entry to be a dict with ``role``
            and ``content`` keys, or a ``ChatMessage``/``Message`` instance.
            Any other value applies the tuples check: every entry must be a
            tuple or list of length 2. (The name shadows the builtin but is
            part of the public signature.)

    Raises:
        Exception: If an entry does not match the requested format. For the
            ``"messages"`` format the first offending entry and its index are
            printed to stderr before raising.
    """
    if not DEBUG_MODE:
        return

    if type == "messages":
        # Single pass: the first invalid entry is reported and raised on
        # immediately, so one-shot iterables are diagnosed correctly too.
        for index, message in enumerate(messages):
            is_role_content_dict = (
                isinstance(message, dict)
                and "role" in message
                and "content" in message
            )
            if is_role_content_dict or isinstance(message, (ChatMessage, Message)):
                continue
            # Display which message is not valid
            print(f"_check_format() --> Invalid message at index {index}: {message}\n", file=sys.stderr)
            raise Exception(
                "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
            )
        return

    if not all(
        isinstance(message, (tuple, list)) and len(message) == 2
        for message in messages
    ):
        raise Exception(
            "Data incompatible with tuples format. Each message should be a list of length 2."
        )
# Adds timing info for a gradio event handler (non-generator functions)
def logged_event_handler(log_msg='', event_handler=None, log_timer=None, clear_timer=False):
    """Wrap ``event_handler`` so each call is bracketed by debug logs and timer steps.

    Args:
        log_msg: Label used in the debug log lines and the timer step names.
        event_handler: Callable to wrap; it is invoked with the wrapper's own
            positional and keyword arguments.
        log_timer: Optional object exposing ``clear()`` and ``add_step(str)``.
            When truthy, a "Start" and a "Completed" step are recorded around
            the call.
        clear_timer: When true (and ``log_timer`` is truthy), the timer is
            cleared before the "Start" step is recorded.

    Returns:
        A function that calls ``event_handler`` and returns its result.
    """
    def wrapped_event_handler(*args, **kwargs):
        # Log before
        if log_timer:
            if clear_timer:
                log_timer.clear()
            # Label the step with the event message (not the log_debug function).
            log_timer.add_step(f"Start: {log_msg}")
        log_debug(f"::: Before event: {log_msg}")

        # Call the original event handler
        result = event_handler(*args, **kwargs)

        # Log after
        if log_timer:
            log_timer.add_step(f"Completed: {log_msg}")
        log_debug(f"::: After event: {log_msg}")

        return result

    return wrapped_event_handler