import os
import threading
from typing import List, Tuple, Dict

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
import spaces
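# `spaces` is the Hugging Face helper package that provides the @spaces.GPU
# decorator used below to request ZeroGPU hardware on Spaces.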

MODEL_ID = "facebook/MobileLLM-Pro"
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.7
TOP_P = 0.95

# --- Silent Hub auth via env/Space Secret (no UI) ---
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
    except Exception:
        pass  # stay silent

# Globals so we only load once
_tokenizer = None
_model = None
_device = None

def _ensure_loaded():
    global _tokenizer, _model, _device
    if _tokenizer is not None and _model is not None:
        return
    _tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True
    )
    _model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if torch.cuda.is_available() else None,
    )
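    # fp16 + device_map="auto" only when CUDA is visible; otherwise fp32 on a
    # single CPU device.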
    if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
        _tokenizer.pad_token = _tokenizer.eos_token
    _model.eval()
    _device = next(_model.parameters()).device

def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    msgs: List[Dict[str, str]] = []
    for user_msg, bot_msg in history:
        if user_msg:
            msgs.append({"role": "user", "content": user_msg})
        if bot_msg:
            msgs.append({"role": "assistant", "content": bot_msg})
    return msgs
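# Example (illustrative values only): a tuple-style history such as
#   [("Hi", "Hello! How can I help?")]
# becomes
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello! How can I help?"}]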

@spaces.GPU(duration=120)
def generate_stream(message: str, history: List[Tuple[str, str]]):
    """
    Minimal streaming chat function for gr.ChatInterface.
    Applies the model's instruct chat template; no token UI or extra controls.
    """
    _ensure_loaded()

    messages = _history_to_messages(history) + [{"role": "user", "content": message}]
    inputs = _tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
        return_dict=True,  # always return input_ids plus attention_mask
    )
    input_ids = inputs["input_ids"].to(_device)
    attention_mask = inputs["attention_mask"].to(_device)

    # IMPORTANT: don't stream the prompt (prevents system/user text from appearing)
    streamer = TextIteratorStreamer(
        _tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,      # don't re-emit the prompt tokens
    )

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=TEMPERATURE > 0.0,
        temperature=float(TEMPERATURE),
        top_p=float(TOP_P),
        pad_token_id=_tokenizer.pad_token_id,
        eos_token_id=_tokenizer.eos_token_id,
        streamer=streamer,
    )

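    # generate() blocks until completion, so it runs on a background thread
    # while TextIteratorStreamer yields decoded text chunks to this generator.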
    thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
    thread.start()

    output = ""
    for new_text in streamer:
        output += new_text
        yield output

with gr.Blocks(title="MobileLLM-Pro – Chat") as demo:
    gr.Markdown(
        """
# MobileLLM-Pro – Chat
Streaming chat with facebook/MobileLLM-Pro (instruct)

<div style="text-align:center;">
  Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a>
</div>
""")
    gr.ChatInterface(
        fn=generate_stream,
        chatbot=gr.Chatbot(height=420, label="MobileLLM-Pro"),
        title=None,            # header handled by Markdown above
        description=None,
    )

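# Note (assumption about Gradio defaults): recent Gradio releases enable the
# request queue automatically on launch, which is what lets this generator fn
# stream partial replies; on older versions, calling demo.queue() before
# launch() may be needed.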
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))