import os
import threading
from typing import List, Tuple, Dict
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import spaces
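
# Model ID and generation defaults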
MODEL_ID = "facebook/MobileLLM-Pro"
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95
# --- Silent Hub auth via env/Space Secret (no UI) ---
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception:
        pass  # stay silent

# Globals so we only load once
_tokenizer = None
_model = None
_device = None
def _ensure_loaded():
    global _tokenizer, _model, _device
    if _tokenizer is not None and _model is not None:
        return
    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True
    )
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if torch.cuda.is_available() else None,
    )
    if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
        _tokenizer.pad_token = _tokenizer.eos_token
    _model.eval()
    _device = next(_model.parameters()).device

def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    msgs: List[Dict[str, str]] = []
    for user_msg, bot_msg in history:
        if user_msg:
            msgs.append({"role": "user", "content": user_msg})
        if bot_msg:
            msgs.append({"role": "assistant", "content": bot_msg})
    return msgs
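
# ZeroGPU: the decorator below requests a GPU for up to 120 s per call.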
@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
    """
    Minimal streaming chat function for gr.ChatInterface.
    Uses the instruct chat template. No token UI. No extra controls.
    """
    _ensure_loaded()
    messages = _history_to_messages(history) + [{"role": "user", "content": message}]
    inputs = _tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    )
    input_ids = inputs["input_ids"] if isinstance(inputs, dict) else inputs
    input_ids = input_ids.to(_device)
    # IMPORTANT: don't stream the prompt (prevents system/user text from appearing)
    streamer = TextIteratorStreamer(
        _tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,  # <-- key fix
    )
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=TEMPERATURE > 0.0,
        temperature=float(TEMPERATURE),
        top_p=float(TOP_P),
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )
    thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
    thread.start()
    output = ""
    for new_text in streamer:
        output += new_text
        yield output

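# --- Gradio UI ---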
with gr.Blocks(title="MobileLLM-Pro β Chat") as demo:
    gr.Markdown(
        """
# MobileLLM-Pro β Chat
Streaming chat with facebook/MobileLLM-Pro (instruct)
<div style="text-align:center;">
Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>
</div>
""")
    gr.ChatInterface(
        fn=generate_stream,
        chatbot=gr.Chatbot(height=420, label="MobileLLM-Pro"),
        title=None,  # header handled by Markdown above
        description=None,
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))