import os
import time
import torch
import gradio as gr
from typing import List, Dict, Any, Tuple
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from huggingface_hub import login
import threading
import spaces

"""
Gradio chat app for facebook/MobileLLM-Pro
- Uses the model's chat template when using the "instruct" subfolder
- Streams tokens to the Gradio UI
- Minimal controls: max_new_tokens, temperature, top_p
- Optional HF_TOKEN login via env var or textbox

To run locally:
  pip install -U gradio spaces transformers accelerate sentencepiece huggingface_hub
  HF_TOKEN=xxxx python app.py

On Hugging Face Spaces:
  - Remove explicit login() call or set HF_TOKEN as a secret
"""

MODEL_ID = "facebook/MobileLLM-Pro"
DEFAULT_VERSION = "instruct"  # "base" | "instruct"
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_P = 0.95

# ---- Optional: login to Hugging Face if token is provided ----
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    try:
        login(token=HF_TOKEN)
        print("[INFO] Logged in to Hugging Face Hub.")
    except Exception as e:
        print(f"[WARN] Could not login to Hugging Face: {e}")


def load_model(version: str = DEFAULT_VERSION):
    """Load tokenizer+model for the selected subfolder (base/instruct)."""
    print(f"[INFO] Loading {MODEL_ID}:{version} ...")
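    # Both the tokenizer and the weights live under a "base" or "instruct"
    # subfolder of the repo; trust_remote_code=True allows any custom modeling
    # code shipped with the checkpoint to load.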
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID, trust_remote_code=True, subfolder=version
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        subfolder=version,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    # Ensure special tokens are set to avoid warnings
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    print("[INFO] Model loaded.")
    return tokenizer, model


def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
    """Map Gradio history [(user, assistant), ...] to chat template messages."""
    messages: List[Dict[str, str]] = []
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


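# On ZeroGPU Spaces, the @spaces.GPU decorator requests a GPU for each call
# (here for up to 120 s); outside ZeroGPU the decorator is effectively a no-op.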
@spaces.GPU(duration=120)
def generate_stream(
    message: str,
    history: List[Tuple[str, str]],
    version: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    use_chat_template: bool,
    state: Dict[str, Any],
):
    """Streaming text generator compatible with gr.ChatInterface.

    Args map to UI controls. `state` holds tokenizer/model between calls.
    """
    tokenizer = state.get("tokenizer")
    model = state.get("model")

    # (Re)load model if version changed or not yet loaded
    if (
        tokenizer is None
        or model is None
        or state.get("version") != version
    ):
        tokenizer, model = load_model(version)
        state["tokenizer"], state["model"], state["version"] = tokenizer, model, version

    device = next(model.parameters()).device

    if use_chat_template and version == "instruct":
        messages = _history_to_messages(history) + [
            {"role": "user", "content": message}
        ]
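        # add_generation_prompt=True appends the assistant-turn header so the
        # model continues as the assistant rather than extending the user turn.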
        inputs = tokenizer.apply_chat_template(
            messages,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(device)
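        # apply_chat_template may hand back either a plain tensor or a dict of
        # tensors, so handle both shapes defensively.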
        input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
    else:
        input_ids = tokenizer(
            message,
            return_tensors="pt",
            add_special_tokens=True,
        )["input_ids"].to(device)

    # skip_prompt=True keeps the echoed prompt out of the streamed reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    do_sample = float(temperature) > 0.0
    gen_kwargs = dict(
        input_ids=input_ids,
        # Single unpadded sequence, so an all-ones attention mask is correct
        # and avoids the "attention mask is not set" warning.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Pass sampling parameters only when sampling; temperature=0.0 alongside
    # do_sample=False triggers avoidable warnings in transformers.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    # Run generation on a background thread; the streamer yields decoded text
    # chunks on this thread as they are produced.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    output_text = ""
    for new_text in streamer:
        output_text += new_text
        yield output_text
    thread.join()


with gr.Blocks(title="MobileLLM-Pro Chat") as demo:
    gr.Markdown("""
    # facebook/MobileLLM-Pro — Chat Demo
    - **Version**: choose `instruct` to enable the model's chat template.
    - **Streaming** is enabled. Use the controls in the right panel.
    """)
    gr.Markdown(
        "<div style='text-align: center;'>Built with <a href='https://huggingface.co/spaces/akhaliq/anycoder'>anycoder</a></div>",
        elem_id="anycoder_attribution"
    )

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=420, label="MobileLLM-Pro")
            msg = gr.Textbox(placeholder="Ask me anything…", scale=1)
            submit = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear chat")
        with gr.Column(scale=2):
            version = gr.Dropdown(["base", "instruct"], value=DEFAULT_VERSION, label="Subfolder (version)")
            use_chat_template = gr.Checkbox(value=True, label="Use chat template (instruct only)")
            max_new = gr.Slider(32, 1024, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens")
            temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, step=0.01, label="Top-p")
            hf_token_box = gr.Textbox(value=os.getenv("HF_TOKEN", ""), label="HF_TOKEN (optional)", type="password")

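            # Per-session cache so the tokenizer/model are only (re)loaded when
            # the selected version changes.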
            state = gr.State({"tokenizer": None, "model": None, "version": None})

    def _maybe_login(token: str):
        token = (token or "").strip()
        if not token:
            return "(No token provided; skipping login)"
        try:
            login(token=token)
            return "Logged in to Hugging Face Hub."
        except Exception as e:
            return f"Login failed: {e}"

    login_btn = gr.Button("Login to HF (optional)")
    login_status = gr.Markdown()
    login_btn.click(_maybe_login, inputs=[hf_token_box], outputs=[login_status])

    def user_submit(user_message, chat_history):
        # Immediately append the user's message so the stream shows inline
        return "", chat_history + [(user_message, None)]

    def bot_respond(chat_history, version, max_new, temperature, top_p, use_chat_template, state):
        # The last tuple is (user, None)
        user_message = chat_history[-1][0] if chat_history else ""
        partials = generate_stream(
            user_message,
            chat_history[:-1],
            version,
            int(max_new),
            float(temperature),
            float(top_p),
            bool(use_chat_template),
            state,
        )
        # Stream tokens to the last assistant message slot
        for chunk in partials:
            chat_history[-1] = (chat_history[-1][0], chunk)
            yield chat_history

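    # Submitting first appends the user message (user_submit), then .then()
    # streams the assistant reply into the same history entry (bot_respond).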
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
        [chatbot],
    )
    submit.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
        [chatbot],
    )

    def clear_chat():
        return []

    clear_btn.click(clear_chat, outputs=[chatbot])

if __name__ == "__main__":
    # Bind to 0.0.0.0 so the app is reachable inside containers/Spaces; PORT can override the default 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))