import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import json
from typing import List, Dict, Any, Optional
import logging
import spaces
import os
import sys
import requests
import accelerate
# Set torch to use float16 on GPU for better performance, float32 on CPU for compatibility
if torch.cuda.is_available():
    torch.set_default_dtype(torch.float16)
else:
    torch.set_default_dtype(torch.float32)
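# Note: the default dtype only affects tensors created directly by torch in this
# process; the model weights' dtype is set explicitly via `torch_dtype` in
# load_model() below, so the two settings can differ (float16 here, bfloat16 there).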
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MAIN_MODEL_ID = "Tonic/petite-elle-L-aime-3-sft"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = None
tokenizer = None

DEFAULT_SYSTEM_PROMPT = "Tu es TonicIA, un assistant francophone rigoureux et bienveillant."

title = "# 🙋🏻♂️Welcome to 🌟Tonic's Petite Elle L'Aime 3"
description = "A fine-tuned version of SmolLM3-3B optimized for French conversations."

presentation1 = """
### 🎯 Features
- **Multilingual Support**: English, French, Italian, Portuguese, Chinese, Arabic
- **Full Fine-Tuned Model**: Maximum performance and quality with full precision
- **Interactive Chat Interface**: Real-time conversation with the model
- **Customizable System Prompt**: Define the assistant's personality and behavior
- **Thinking Mode**: Enable reasoning mode with thinking tags
- **Tool Calling**: Support for function calling with XML and Python tools
"""

presentation2 = """### 🎯 Fonctionnalités
* **Support multilingue** : Anglais, Français, Italien, Portugais, Chinois, Arabe
* **Modèle complet fine-tuné** : Performance et qualité maximales avec précision complète
* **Interface de chat interactive** : Conversation en temps réel avec le modèle
* **Invite système personnalisable** : Définissez la personnalité et le comportement de l'assistant
* **Mode Réflexion** : Activez le mode raisonnement avec des balises de réflexion
* **Appel d'outils** : Support pour l'appel de fonctions avec XML et Python
"""

joinus = """
## Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [Build Tonic](https://git.tonic-ai.com/contribute). 🤗 Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
# Default tool definition for demonstration
DEFAULT_TOOLS = [
    {
        "name": "get_weather",
        "description": "Get the weather in a city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to get the weather for"
                }
            }
        }
    },
    {
        "name": "calculate",
        "description": "Perform mathematical calculations",
        "parameters": {
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    "description": "Mathematical expression to evaluate"
                }
            }
        }
    }
]
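# Each entry uses a JSON-Schema-style "parameters" block. The list is passed to
# tokenizer.apply_chat_template() as `xml_tools` or `python_tools` in create_prompt()
# below, so the chat template can render the tool signatures into the prompt.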
def download_chat_template():
    """Download the chat template from the main repository"""
    try:
        chat_template_url = f"https://huggingface.co/{MAIN_MODEL_ID}/raw/main/chat_template.jinja"
        logger.info(f"Downloading chat template from {chat_template_url}")
        response = requests.get(chat_template_url, timeout=30)
        response.raise_for_status()
        chat_template_content = response.text
        logger.info("Chat template downloaded successfully")
        return chat_template_content
    except requests.exceptions.RequestException as e:
        logger.error(f"Error downloading chat template: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error downloading chat template: {e}")
        return None
def load_model():
    """Load the full fine-tuned model and tokenizer"""
    global model, tokenizer
    try:
        logger.info(f"Loading tokenizer from {MAIN_MODEL_ID}")
        # tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID, subfolder="int4")
        tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL_ID)
        # chat_template = download_chat_template()
        # if chat_template:
        #     tokenizer.chat_template = chat_template
        #     logger.info("Chat template downloaded and set successfully")
        # logger.info(f"Loading full fine-tuned model from {MAIN_MODEL_ID}")
        # Load the full fine-tuned model with optimized settings
        model_kwargs = {
            "device_map": "auto" if DEVICE == "cuda" else "cpu",
            "torch_dtype": torch.bfloat16 if DEVICE == "cuda" else torch.float32,  # bfloat16 on GPU, float32 on CPU
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
            # "attn_implementation": "flash_attention_2" if DEVICE == "cuda" else "eager"
        }
        logger.info(f"Model loading parameters: {model_kwargs}")
        # model = AutoModelForCausalLM.from_pretrained(MAIN_MODEL_ID, subfolder="int4", **model_kwargs)
        model = AutoModelForCausalLM.from_pretrained(MAIN_MODEL_ID, **model_kwargs)
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        logger.info("Full fine-tuned model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        logger.error(f"Model config: {model.config if model else 'Model not loaded'}")
        return False
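# Minimal sketch of a post-load smoke test (not part of the app; the transformers
# calls are standard, but the prompt text is only an illustration):
#   if load_model():
#       prompt = tokenizer.apply_chat_template(
#           [{"role": "user", "content": "Bonjour !"}],
#           tokenize=False, add_generation_prompt=True)
#       batch = tokenizer(prompt, return_tensors="pt").to(model.device)
#       print(tokenizer.decode(model.generate(**batch, max_new_tokens=16)[0]))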
def create_prompt(system_message, user_message, enable_thinking=True, tools=None, use_xml_tools=True):
    """Create prompt using the model's chat template with SmolLM3 features"""
    try:
        formatted_messages = []
        if system_message and system_message.strip():
            # Check if thinking flags are already present
            has_think_flag = "/think" in system_message
            has_no_think_flag = "/no_think" in system_message
            # Add thinking flag to system message if needed
            if not enable_thinking and not has_no_think_flag:
                system_message += "/no_think"
            elif enable_thinking and not has_think_flag and not has_no_think_flag:
                system_message += "/think"
            formatted_messages.append({"role": "system", "content": system_message})
        formatted_messages.append({"role": "user", "content": user_message})
        # Apply chat template with SmolLM3 features
        template_kwargs = {
            "tokenize": False,
            "add_generation_prompt": True,
            "enable_thinking": enable_thinking
        }
        # Add tool calling if tools are provided
        if tools and len(tools) > 0:
            if use_xml_tools:
                template_kwargs["xml_tools"] = tools
            else:
                template_kwargs["python_tools"] = tools
        prompt = tokenizer.apply_chat_template(formatted_messages, **template_kwargs)
        return prompt
    except Exception as e:
        logger.error(f"Error creating prompt: {e}")
        return ""
@spaces.GPU  # needed on ZeroGPU Spaces so this function is allocated a GPU (`spaces` was imported but unused)
def generate_response(message, history, system_message, max_tokens, temperature, top_p, repetition_penalty, do_sample, enable_thinking=True, tools=None, use_xml_tools=True):
    """Generate response using the full fine-tuned model with SmolLM3 features"""
    global model, tokenizer
    if model is None or tokenizer is None:
        return "Error: Model not loaded. Please wait for the model to load."
    # Parse tools from string if provided
    parsed_tools = None
    if tools and tools.strip():
        try:
            parsed_tools = json.loads(tools)
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing tools JSON: {e}")
            return "Error: Invalid tool definition JSON format."
    full_prompt = create_prompt(system_message, message, enable_thinking, parsed_tools, use_xml_tools)
    if not full_prompt:
        return "Error: Failed to create prompt."
    inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True)
    logger.info(f"Input tensor shapes: {[(k, v.shape, v.dtype) for k, v in inputs.items()]}")
    if DEVICE == "cuda":
        inputs = {k: v.cuda() for k, v in inputs.items()}
    with torch.no_grad():
        output_ids = model.generate(
            inputs['input_ids'],
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            attention_mask=inputs['attention_mask'],
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # cache_implementation="static"
        )
    # First decode WITH special tokens to find markers
    response_with_tokens = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    # Debug: Print the full raw response with tokens
    # logger.info(f"=== FULL RAW RESPONSE WITH TOKENS DEBUG ===")
    # logger.info(f"Raw response with tokens length: {len(response_with_tokens)}")
    # logger.info(f"Raw response with tokens: {repr(response_with_tokens)}")
    # More robust response extraction - look for assistant marker
    # logger.info(f"Looking for assistant marker in response...")
    if "<|im_start|>assistant" in response_with_tokens:
        # logger.info(f"Found assistant marker in response")
        # Find the start of assistant response
        assistant_start = response_with_tokens.find("<|im_start|>assistant")
        # logger.info(f"Assistant marker found at position: {assistant_start}")
        if assistant_start != -1:
            # Find the end of the assistant marker
            marker_end = response_with_tokens.find("\n", assistant_start)
            # logger.info(f"Marker end found at position: {marker_end}")
            if marker_end != -1:
                assistant_response = response_with_tokens[marker_end + 1:].strip()
                # logger.info(f"Using marker-based extraction")
            else:
                assistant_response = response_with_tokens[assistant_start + len("<|im_start|>assistant"):].strip()
                # logger.info(f"Using fallback marker extraction")
        else:
            # Fallback: decode only the newly generated tokens (slicing the decoded text
            # by len(full_prompt) is unreliable once special tokens are stripped)
            generated_ids = output_ids[0][inputs['input_ids'].shape[-1]:]
            assistant_response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
            # logger.info(f"Using token-slice extraction (marker not found)")
    else:
        # Fallback: decode only the newly generated tokens
        # logger.info(f"No assistant marker found, using token-slice extraction")
        generated_ids = output_ids[0][inputs['input_ids'].shape[-1]:]
        assistant_response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    # Clean up any remaining special tokens
    assistant_response = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', assistant_response, flags=re.DOTALL)
    assistant_response = re.sub(r'<\|im_start\|>', '', assistant_response)
    assistant_response = re.sub(r'<\|im_end\|>', '', assistant_response)
    # Debug: Print the extracted assistant response after cleanup
    # logger.info(f"=== EXTRACTED ASSISTANT RESPONSE AFTER CLEANUP DEBUG ===")
    # logger.info(f"Extracted response length: {len(assistant_response)}")
    # logger.info(f"Extracted response: {repr(assistant_response)}")
    # Debug: Print before cleanup
    # logger.info(f"=== BEFORE CLEANUP DEBUG ===")
    # logger.info(f"Before cleanup length: {len(assistant_response)}")
    # logger.info(f"Before cleanup: {repr(assistant_response)}")
    assistant_response = re.sub(r'<\|im_start\|>.*?<\|im_end\|>', '', assistant_response, flags=re.DOTALL)
    # Debug: Print after first cleanup
    # logger.info(f"=== AFTER FIRST CLEANUP DEBUG ===")
    # logger.info(f"After first cleanup length: {len(assistant_response)}")
    # logger.info(f"After first cleanup: {repr(assistant_response)}")
    if not enable_thinking:
        assistant_response = re.sub(r'<think>.*?</think>', '', assistant_response, flags=re.DOTALL)
    # Debug: Print after thinking cleanup
    # logger.info(f"=== AFTER THINKING CLEANUP DEBUG ===")
    # logger.info(f"After thinking cleanup length: {len(assistant_response)}")
    # logger.info(f"After thinking cleanup: {repr(assistant_response)}")
    # Debug: Print before tool call handling
    # logger.info(f"=== BEFORE TOOL CALL HANDLING DEBUG ===")
    # logger.info(f"Before tool call handling length: {len(assistant_response)}")
    # logger.info(f"Before tool call handling: {repr(assistant_response)}")
    # Handle tool calls if present
    if parsed_tools and ("<tool_call>" in assistant_response or "<code>" in assistant_response):
        if "<tool_call>" in assistant_response:
            tool_call_match = re.search(r'<tool_call>(.*?)</tool_call>', assistant_response, re.DOTALL)
            if tool_call_match:
                tool_call = tool_call_match.group(1)
                assistant_response += f"\n\n🔧 Tool Call Detected: {tool_call}\n\nNote: This is a simulated tool call. In a real scenario, the tool would be executed and its output would be used to generate a final response."
        elif "<code>" in assistant_response:
            code_match = re.search(r'<code>(.*?)</code>', assistant_response, re.DOTALL)
            if code_match:
                code_call = code_match.group(1)
                assistant_response += f"\n\n🐍 Python Tool Call: {code_call}\n\nNote: This is a simulated Python tool call. In a real scenario, the function would be executed and its output would be used to generate a final response."
    # Debug: Print after tool call handling
    # logger.info(f"=== AFTER TOOL CALL HANDLING DEBUG ===")
    # logger.info(f"After tool call handling length: {len(assistant_response)}")
    # logger.info(f"After tool call handling: {repr(assistant_response)}")
    assistant_response = assistant_response.strip()
    # Debug: Print final response
    # logger.info(f"=== FINAL RESPONSE DEBUG ===")
    # logger.info(f"Final response length: {len(assistant_response)}")
    # logger.info(f"Final response: {repr(assistant_response)}")
    # logger.info(f"=== END DEBUG ===")
    return assistant_response
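# Sketch of how a real (non-simulated) tool loop could look; execute_tool() and the
# "tool" role name are assumptions, not something this app defines:
#   result = execute_tool(json.loads(tool_call_json))
#   messages += [{"role": "assistant", "content": assistant_response},
#                {"role": "tool", "content": json.dumps(result)}]
#   ...then re-apply the chat template to `messages` and generate once more for the final answer.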
def user(user_message, history):
    """Add user message to history"""
    if history is None:
        history = []
    return "", history + [{"role": "user", "content": user_message}]

def bot(history, system_prompt, max_length, temperature, top_p, repetition_penalty, advanced_checkbox, enable_thinking, tools, use_xml_tools, use_tools):
    """Generate bot response"""
    if not history:
        return history
    user_message = history[-1]["content"] if history else ""
    do_sample = advanced_checkbox
    tools_to_use = tools if use_tools else None
    bot_message = generate_response(
        user_message, history, system_prompt, max_length, temperature, top_p, repetition_penalty,
        do_sample, enable_thinking, tools_to_use, use_xml_tools
    )
    history.append({"role": "assistant", "content": bot_message})
    return history
# Load model on startup
logger.info("Starting model loading process with full fine-tuned model...")
load_model()

# Create Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(presentation1)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(presentation2)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(joinus)
        with gr.Column(scale=1):
            pass  # Empty column for balance
    with gr.Row():
        with gr.Column(scale=2):
            system_prompt = gr.TextArea(
                label="📑 Contexte",
                placeholder="Tu es TonicIA, un assistant francophone rigoureux et bienveillant.",
                lines=5,
                value=DEFAULT_SYSTEM_PROMPT
            )
            user_input = gr.TextArea(
                label="🤷🏻♂️ Message",
                placeholder="Bonjour, je m'appelle Tonic !",
                lines=2
            )
            advanced_checkbox = gr.Checkbox(label="🧪 Advanced Settings", value=False)
            with gr.Column(visible=False) as advanced_settings:
                max_length = gr.Slider(
                    label="📏 Longueur de la réponse",
                    minimum=10,
                    maximum=9000,  # maximum=32768,
                    value=1256,
                    step=1
                )
                temperature = gr.Slider(
                    label="🌡️ Température",
                    minimum=0.01,
                    maximum=1.0,
                    value=0.6,  # Updated to SmolLM3 recommended
                    step=0.01
                )
                top_p = gr.Slider(
                    label="⚛️ Top-p (Echantillonnage)",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.01
                )
                repetition_penalty = gr.Slider(
                    label="🔄 Pénalité de Répétition",
                    minimum=1.0,
                    maximum=2.0,
                    value=1.1,
                    step=0.01
                )
                enable_thinking = gr.Checkbox(label="Mode Réflexion", value=True)
                use_tools = gr.Checkbox(label="🔧 Enable Tool Calling", value=False)
                use_xml_tools = gr.Checkbox(label="📋 Use XML Tools (vs Python)", value=True)
                with gr.Column(visible=False) as tool_options:
                    tools = gr.Code(
                        label="Tool Definition (JSON)",
                        value=json.dumps(DEFAULT_TOOLS, indent=2),
                        lines=15,
                        language="json"
                    )
            generate_button = gr.Button(value="🤖 Petite Elle L'Aime 3")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="🤖 Petite Elle L'Aime 3", type="messages", value=[])
    generate_button.click(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_prompt, max_length, temperature, top_p, repetition_penalty, advanced_checkbox, enable_thinking, tools, use_xml_tools, use_tools],
        chatbot
    )
    advanced_checkbox.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[advanced_checkbox],
        outputs=[advanced_settings]
    )
    use_tools.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[use_tools],
        outputs=[tool_options]
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(ssr_mode=False, mcp_server=True)