import os
import time
import torch
import gradio as gr
from typing import List, Dict, Any, Tuple
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
from huggingface_hub import login
import threading
import spaces
"""
Gradio chat app for facebook/MobileLLM-Pro
- Uses the model's chat template when using the "instruct" subfolder
- Streams tokens to the Gradio UI
- Minimal controls: max_new_tokens, temperature, top_p
- Optional HF_TOKEN login via env var or textbox
To run locally:
pip install -U gradio transformers accelerate sentencepiece huggingface_hub
HF_TOKEN=xxxx python app.py
On Hugging Face Spaces:
- Set HF_TOKEN as a Space secret if the model requires authentication; the login() call is skipped when no token is provided
"""
MODEL_ID = "facebook/MobileLLM-Pro"
DEFAULT_VERSION = "instruct" # "base" | "instruct"
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.95
# ---- Optional: login to Hugging Face if token is provided ----
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
try:
login(token=HF_TOKEN)
print("[INFO] Logged in to Hugging Face Hub.")
except Exception as e:
print(f"[WARN] Could not login to Hugging Face: {e}")
def load_model(version: str = DEFAULT_VERSION):
"""Load tokenizer+model for the selected subfolder (base/instruct)."""
print(f"[INFO] Loading {MODEL_ID}:{version} ...")
tokenizer = AutoTokenizer.from_pretrained(
MODEL_ID, trust_remote_code=True, subfolder=version
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
trust_remote_code=True,
subfolder=version,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
low_cpu_mem_usage=True,
device_map="auto" if torch.cuda.is_available() else None,
)
# Ensure special tokens are set to avoid warnings
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token = tokenizer.eos_token
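    # Inference only: eval() disables dropout and other training-time behavior.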
model.eval()
print("[INFO] Model loaded.")
return tokenizer, model
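# tokenizer.apply_chat_template expects a list of role/content dicts, e.g.:
#   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
# The helper below converts Gradio's (user, assistant) tuple history into that list.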
def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
"""Map Gradio history [(user, assistant), ...] to chat template messages."""
messages: List[Dict[str, str]] = []
for user_msg, bot_msg in history:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if bot_msg:
messages.append({"role": "assistant", "content": bot_msg})
return messages
@spaces.GPU(duration=120)
def generate_stream(
message: str,
history: List[Tuple[str, str]],
version: str,
max_new_tokens: int,
temperature: float,
top_p: float,
use_chat_template: bool,
state: Dict[str, Any],
):
"""Streaming text generator compatible with gr.ChatInterface.
Args map to UI controls. `state` holds tokenizer/model between calls.
"""
tokenizer = state.get("tokenizer")
model = state.get("model")
# (Re)load model if version changed or not yet loaded
if (
tokenizer is None
or model is None
or state.get("version") != version
):
tokenizer, model = load_model(version)
state["tokenizer"], state["model"], state["version"] = tokenizer, model, version
device = next(model.parameters()).device
if use_chat_template and version == "instruct":
messages = _history_to_messages(history) + [
{"role": "user", "content": message}
]
inputs = tokenizer.apply_chat_template(
messages,
return_tensors="pt",
add_generation_prompt=True,
).to(device)
input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
else:
input_ids = tokenizer(
message,
return_tensors="pt",
add_special_tokens=True,
)["input_ids"].to(device)
    # Stream decoded text as it is generated; skip_prompt avoids echoing the input
    # prompt back into the chat, and skip_special_tokens drops EOS/BOS markers.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    do_sample = temperature > 0.0
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),  # single un-padded prompt
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    if do_sample:
        # Sampling parameters are ignored (and warn) under greedy decoding,
        # so only pass them when sampling is enabled.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)
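    # model.generate blocks until generation finishes, so run it in a background
    # thread and consume the streamer in the foreground to yield partial text.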
thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
output_text = ""
for new_text in streamer:
output_text += new_text
yield output_text
with gr.Blocks(title="MobileLLM-Pro Chat") as demo:
gr.Markdown("""
# facebook/MobileLLM-Pro — Chat Demo
- **Version**: choose `instruct` to enable the model's chat template.
- **Streaming** is enabled. Use the controls in the right panel.
""")
gr.Markdown(
"<div style='text-align: center;'>Built with <a href='https://huggingface.co/spaces/akhaliq/anycoder'>anycoder</a></div>",
elem_id="anycoder_attribution"
)
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(height=420, label="MobileLLM-Pro")
msg = gr.Textbox(placeholder="Ask me anything…", scale=1)
submit = gr.Button("Send", variant="primary")
clear_btn = gr.Button("Clear chat")
with gr.Column(scale=2):
version = gr.Dropdown(["base", "instruct"], value=DEFAULT_VERSION, label="Subfolder (version)")
use_chat_template = gr.Checkbox(value=True, label="Use chat template (instruct only)")
max_new = gr.Slider(32, 1024, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens")
temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, step=0.01, label="Top-p")
            hf_token_box = gr.Textbox(value=os.getenv("HF_TOKEN", ""), label="HF_TOKEN (optional)", type="password")
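    # Session state caches the loaded tokenizer/model (and which subfolder they came
    # from) so the weights are not reloaded on every message.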
state = gr.State({"tokenizer": None, "model": None, "version": None})
def _maybe_login(token: str):
token = (token or "").strip()
if not token:
return "(No token provided; skipping login)"
try:
login(token=token)
return "Logged in to Hugging Face Hub."
except Exception as e:
return f"Login failed: {e}"
login_btn = gr.Button("Login to HF (optional)")
login_status = gr.Markdown()
login_btn.click(_maybe_login, inputs=[hf_token_box], outputs=[login_status])
def user_submit(user_message, chat_history):
# Immediately append the user's message so the stream shows inline
return "", chat_history + [(user_message, None)]
def bot_respond(chat_history, version, max_new, temperature, top_p, use_chat_template, state):
# The last tuple is (user, None)
user_message = chat_history[-1][0] if chat_history else ""
partials = generate_stream(
user_message,
chat_history[:-1],
version,
int(max_new),
float(temperature),
float(top_p),
bool(use_chat_template),
state,
)
# Stream tokens to the last assistant message slot
for chunk in partials:
chat_history[-1] = (chat_history[-1][0], chunk)
yield chat_history
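    # Wire up submission: first append the user's message to the chat, then stream
    # the assistant's reply into the last history slot (bot_respond is a generator).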
msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
bot_respond,
[chatbot, version, max_new, temperature, top_p, use_chat_template, state],
[chatbot],
)
submit.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
bot_respond,
[chatbot, version, max_new, temperature, top_p, use_chat_template, state],
[chatbot],
)
def clear_chat():
return []
clear_btn.click(clear_chat, outputs=[chatbot])
if __name__ == "__main__":
# For Spaces, Gradio will call `demo.launch()` automatically; locally we launch here.
demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))