# MobileLLM-Pro / app.py
import os
import threading
from typing import List, Tuple, Dict
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import spaces
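# Model repo and default generation settings (nucleus sampling).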
MODEL_ID = "facebook/MobileLLM-Pro"
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95
# --- Silent Hub auth via env/Space Secret (no UI) ---
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
try:
login(token=HF_TOKEN)
except Exception:
pass # stay silent
# Globals so we only load once
_tokenizer = None
_model = None
_device = None
def _ensure_loaded():
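    """Load the tokenizer and model once per process and cache them in module globals."""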
global _tokenizer, _model, _device
if _tokenizer is not None and _model is not None:
return
_tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID, trust_remote_code=True
)
_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
low_cpu_mem_usage=True,
device_map="auto" if torch.cuda.is_available() else None,
)
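    # Some tokenizers ship without a pad token; fall back to EOS so generation can pad.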
if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
_tokenizer.pad_token = _tokenizer.eos_token
_model.eval()
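    # Remember which device the weights ended up on so inputs can be moved there.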
_device = next(_model.parameters()).device
def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
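    """Convert Gradio's (user, assistant) history pairs into chat-template message dicts."""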
msgs: List[Dict[str, str]] = []
for user_msg, bot_msg in history:
if user_msg:
msgs.append({"role": "user", "content": user_msg})
if bot_msg:
msgs.append({"role": "assistant", "content": bot_msg})
return msgs
@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
"""
Minimal streaming chat function for gr.ChatInterface.
Uses instruct chat template. No token UI. No extra controls.
"""
_ensure_loaded()
messages = _history_to_messages(history) + [{"role": "user", "content": message}]
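    # Render the conversation with the model's chat template, ending with an assistant prompt.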
inputs = _tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=True,
)
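    # apply_chat_template may return a plain tensor or a dict depending on the transformers version.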
input_ids = inputs["input_ids"] if isinstance(inputs, dict) else inputs
input_ids = input_ids.to(_device)
# IMPORTANT: don't stream the prompt (prevents system/user text from appearing)
streamer = TextIteratorStreamer(
_tokenizer,
skip_special_tokens=True,
skip_prompt=True, # <-- key fix
)
gen_kwargs = dict(
input_ids=input_ids,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=TEMPERATURE > 0.0,
temperature=float(TEMPERATURE),
top_p=float(TOP_P),
pad_token_id=_tokenizer.pad_token_id,
eos_token_id=_tokenizer.eos_token_id,
streamer=streamer,
)
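    # Run generation on a background thread so the streamer can be consumed as tokens arrive.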
thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
thread.start()
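    # Yield the accumulated text; gr.ChatInterface re-renders the bot message on every yield.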
output = ""
for new_text in streamer:
output += new_text
yield output
with gr.Blocks(title="MobileLLM-Pro - Chat") as demo:
gr.Markdown(
"""
# MobileLLM-Pro - Chat
Streaming chat with facebook/MobileLLM-Pro (instruct)
<div style="text-align:center;">
Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>
</div>
""")
gr.ChatInterface(
fn=generate_stream,
chatbot=gr.Chatbot(height=420, label="MobileLLM-Pro"),
title=None, # header handled by Markdown above
description=None,
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
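# Untested usage sketch: if this Space is published (assumed id: akhaliq/MobileLLM-Pro),
# gr.ChatInterface exposes a default /chat endpoint that gradio_client can call:
#   from gradio_client import Client
#   client = Client("akhaliq/MobileLLM-Pro")
#   print(client.predict(message="Hello!", api_name="/chat"))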