# MobileLLM-Pro chat demo for Hugging Face Spaces (running on ZeroGPU).
import os
import threading
from typing import List, Dict, Tuple, Any, Optional

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login

# --- Optional: Hugging Face Spaces GPU decorator (safe locally) ---
try:
    import spaces  # type: ignore

    GPU_DECORATOR = spaces.GPU
except Exception:  # running locally without `spaces`
    def GPU_DECORATOR(*args, **kwargs):  # no-op decorator
        def _wrap(fn):
            return fn
        return _wrap
# =========================
# Configuration
# =========================
MODEL_ID = "facebook/MobileLLM-Pro"
MODEL_SUBFOLDER = "instruct"  # "base" | "instruct"
MAX_HISTORY_LENGTH = 10
MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, friendly, and intelligent assistant. "
    "Provide clear, accurate, and thoughtful responses."
)

# =========================
# HF Login (optional)
# =========================
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged in to Hugging Face")
    except Exception as e:
        print(f"Warning: Could not login to Hugging Face: {e}")
# =========================
# Utilities
# =========================
def tuples_from_messages(messages: List[Any]) -> List[List[str]]:
    """
    Normalize a Chatbot history to tuples [[user, assistant], ...].
    Accepts either tuples-style or messages-style ({role, content}) lists.
    """
    if not messages:
        return []
    # Already tuples-like
    if isinstance(messages[0], (list, tuple)) and len(messages[0]) == 2:
        return [list(x) for x in messages]
    # Convert from messages-style
    pairs: List[List[str]] = []
    last_user: Optional[str] = None
    for m in messages:
        role = m.get("role")
        content = m.get("content", "")
        if role == "user":
            last_user = content
        elif role == "assistant":
            if last_user is None:
                pairs.append(["", content])
            else:
                pairs.append([last_user, content])
            last_user = None
    if last_user is not None:
        pairs.append([last_user, ""])
    return pairs


def messages_from_tuples(history_tuples: List[List[str]]) -> List[Dict[str, str]]:
    """
    Convert tuples [[user, assistant], ...] into list of role dictionaries:
    [{"role": "user", ...}, {"role": "assistant", ...}, ...]
    """
    messages: List[Dict[str, str]] = []
    for u, a in history_tuples:
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    return messages
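
# Illustrative round trip between the two history formats (comment only, not
# executed); the strings are made up for the example:
#   messages_from_tuples([["Hi", "Hello!"]])
#     -> [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
#   tuples_from_messages([{"role": "user", "content": "Hi"},
#                         {"role": "assistant", "content": "Hello!"}])
#     -> [["Hi", "Hello!"]]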

# =========================
# Chat Model Wrapper
# =========================
class MobileLLMChat:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.device = None
        self.model_loaded = False
        self.version = None
        self.load_model(version=MODEL_SUBFOLDER)

    def load_model(self, version: str = "instruct") -> bool:
        """Load tokenizer+model; choose dtype/device_map safely for CPU/GPU."""
        try:
            print(f"Loading {MODEL_ID} ({version}) ...")
            use_cuda = torch.cuda.is_available()
            torch_dtype = torch.float16 if use_cuda else torch.float32
            self.tokenizer = AutoTokenizer.from_pretrained(
                MODEL_ID, trust_remote_code=True, subfolder=version
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                trust_remote_code=True,
                subfolder=version,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                device_map="auto" if use_cuda else None,
            )
            if self.tokenizer.pad_token_id is None:
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.model.eval()
            self.version = version
            self.device = next(self.model.parameters()).device
            self.model_loaded = True
            print("Model loaded successfully.")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            self.model_loaded = False
            return False

    def format_chat_history(
        self, history_msgs: List[Dict[str, str]], system_prompt: str
    ) -> List[Dict[str, str]]:
        """Prepend the system prompt and keep only the most recent exchanges."""
        messages = [{"role": "system", "content": system_prompt}]
        trimmed = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
        if MAX_HISTORY_LENGTH > 0:
            # Keep the last MAX_HISTORY_LENGTH user/assistant pairs (two messages per pair).
            trimmed = trimmed[-(MAX_HISTORY_LENGTH * 2):]
        messages.extend(trimmed)
        return messages

    def generate_once(
        self,
        user_input: str,
        history_msgs: List[Dict[str, str]],
        system_prompt: str,
        temperature: float = 0.7,
        max_new_tokens: int = MAX_NEW_TOKENS,
        top_p: float = 0.95,
    ) -> str:
        """Single-shot generation (no streaming)."""
        if not self.model_loaded:
            return "Model not loaded. Please reload."
        try:
            messages = self.format_chat_history(
                history_msgs + [{"role": "user", "content": user_input}], system_prompt
            )
            inputs = self.tokenizer.apply_chat_template(
                messages, return_tensors="pt", add_generation_prompt=True
            )
            input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
            input_ids = input_ids.to(self.device)
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    temperature=float(temperature),
                    do_sample=temperature > 0,
                    top_p=float(top_p),
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )
            # Decode only the newly generated tokens (everything after the prompt).
            gen_ids = outputs[0][input_ids.shape[1]:]
            return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
        except Exception as e:
            return f"Error generating response: {e}"

    def stream_generate(
        self,
        user_input: str,
        history_msgs: List[Dict[str, str]],
        system_prompt: str,
        temperature: float = 0.7,
        max_new_tokens: int = MAX_NEW_TOKENS,
        top_p: float = 0.95,
    ):
        """Streaming generator using TextIteratorStreamer; yields the cumulative text."""
        messages = self.format_chat_history(
            history_msgs + [{"role": "user", "content": user_input}], system_prompt
        )
        inputs = self.tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True
        )
        input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
        input_ids = input_ids.to(self.device)
        # skip_prompt=True so the streamer emits only newly generated text,
        # not the chat-template prompt.
        streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        gen_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            temperature=float(temperature),
            do_sample=temperature > 0,
            top_p=float(top_p),
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            streamer=streamer,
        )
        # Run generation in a background thread; the streamer feeds this generator.
        thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()
        partial = ""
        for text in streamer:
            partial += text
            yield partial
        thread.join()
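
# Note on MobileLLMChat.stream_generate: each yielded value is the cumulative
# response text so far (not a delta), so a caller can simply replace the
# assistant message with the latest yield, as the streaming handler below does.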

# =========================
# Initialize Chat Model
# =========================
print("Initializing MobileLLM-Pro model...")
chat_model = MobileLLMChat()

# =========================
# Gradio Helpers
# =========================
def clear_chat():
    return [], ""

# The GPU decorator is applied to the generation entry points so that, on a
# ZeroGPU Space, a GPU is allocated only while a response is being generated
# (it is a no-op when running locally without the `spaces` package).
@GPU_DECORATOR()
def chat_fn(message, history, system_prompt, temperature, top_p):
    """Non-streaming chat handler (returns tuples)."""
    history = tuples_from_messages(history)
    if not chat_model.model_loaded:
        return history + [[message, "Please wait for the model to load or reload the space."]]
    formatted_history = messages_from_tuples(history)
    response = chat_model.generate_once(
        message, formatted_history, system_prompt, temperature, MAX_NEW_TOKENS, top_p
    )
    return history + [[message, response]]


@GPU_DECORATOR()
def chat_stream_fn(message, history, system_prompt, temperature, top_p):
    """Streaming chat handler: yields updated tuples as tokens arrive."""
    history = tuples_from_messages(history)
    if not chat_model.model_loaded:
        yield history + [[message, "Please wait for the model to load or reload the space."]]
        return
    formatted_history = messages_from_tuples(history)
    # Start a new row for the assistant and fill it progressively.
    for chunk in chat_model.stream_generate(
        message, formatted_history, system_prompt, temperature, MAX_NEW_TOKENS, top_p
    ):
        yield history + [[message, chunk]]
    # The last chunk already contains the full response; no extra yield needed.


def handle_chat(message, history, system_prompt, temperature, top_p, streaming):
    # Implemented as a generator so Gradio streams partial updates when
    # streaming is enabled, while the single-shot path yields once.
    if streaming:
        yield from chat_stream_fn(message, history, system_prompt, temperature, top_p)
    else:
        yield chat_fn(message, history, system_prompt, temperature, top_p)

# =========================
# Gradio UI
# =========================
with gr.Blocks(
    title="MobileLLM-Pro Chat",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container { max-width: 900px !important; margin: auto !important; }
    .message { padding: 12px !important; border-radius: 8px !important; margin-bottom: 8px !important; }
    .user-message { background-color: #e3f2fd !important; margin-left: 20% !important; }
    .assistant-message { background-color: #f5f5f5 !important; margin-right: 20% !important; }
    """,
) as demo:
    gr.HTML(
        """
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🤖 MobileLLM-Pro Chat</h1>
            <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
            <p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
        </div>
        """
    )

    with gr.Row():
        model_status = gr.Textbox(
            label="Model Status",
            value="Model loaded and ready!" if chat_model.model_loaded else "Model loading...",
            interactive=False,
            container=True,
        )

    with gr.Accordion("⚙️ Configuration", open=False):
        with gr.Row():
            system_prompt = gr.Textbox(
                value=DEFAULT_SYSTEM_PROMPT,
                label="System Prompt",
                lines=3,
                info="Customize the AI's behavior and personality",
            )
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.0,
                maximum=2.0,
                value=0.7,
                step=0.05,
                label="Temperature",
                info="Controls randomness (higher = more creative)",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.01,
                label="Top-p",
                info="Nucleus sampling threshold",
            )
            streaming = gr.Checkbox(
                value=True,
                label="Enable Streaming",
                info="Show responses as they're being generated",
            )

    chatbot = gr.Chatbot(
        type="tuples",
        label="Chat History",
        height=500,
        show_copy_button=True,
    )

    with gr.Row():
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your message here...",
            scale=4,
            container=False,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)
        clear_btn = gr.Button("Clear", scale=0)

    msg.submit(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, top_p, streaming],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)
    submit_btn.click(
        handle_chat,
        inputs=[msg, chatbot, system_prompt, temperature, top_p, streaming],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)
    clear_btn.click(
        clear_chat,
        outputs=[chatbot, msg],
    )

    gr.Examples(
        examples=[
            ["What are the benefits of on-device AI models?"],
            ["Explain quantum computing in simple terms."],
            ["Write a short poem about technology."],
            ["What's the difference between machine learning and deep learning?"],
            ["How can I improve my productivity?"],
        ],
        inputs=[msg],
        label="Example Prompts",
    )

    gr.HTML(
        """
        <div style="text-align: center; margin-top: 20px; color: #666;">
            <p>⚠️ Note: Model is pre-loaded for faster inference. GPU is allocated only during generation.</p>
            <p>Model: <a href="https://huggingface.co/facebook/MobileLLM-Pro" target="_blank">facebook/MobileLLM-Pro</a></p>
        </div>
        """
    )

# Improve streaming UX
demo.queue()

if __name__ == "__main__":
    demo.launch(show_error=True, debug=True)
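
# Local usage (assumed, not part of the Space config): install roughly
#   pip install torch transformers accelerate gradio huggingface_hub
# then run `python app.py`; set HF_TOKEN in the environment if the model
# repository requires authenticated access.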