akhaliq (HF Staff) committed
Commit 01bada7 · verified · 1 Parent(s): 9f0ab40

Update app.py

Files changed (1)
  1. app.py +189 -379
app.py CHANGED
@@ -1,414 +1,224 @@
  import os
- import threading
- from typing import List, Dict, Tuple, Any, Optional
-
  import torch
  import gradio as gr
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
  from huggingface_hub import login

- # --- Optional: Hugging Face Spaces GPU decorator (safe locally) ---
- try:
-     import spaces  # type: ignore
-     GPU_DECORATOR = spaces.GPU
- except Exception:  # running locally without `spaces`
-     def GPU_DECORATOR(*args, **kwargs):  # no-op decorator
-         def _wrap(fn):
-             return fn
-         return _wrap
-
- # =========================
- # Configuration
- # =========================
  MODEL_ID = "facebook/MobileLLM-Pro"
- MODEL_SUBFOLDER = "instruct"  # "base" | "instruct"
- MAX_HISTORY_LENGTH = 10
- MAX_NEW_TOKENS = 512
- DEFAULT_SYSTEM_PROMPT = (
-     "You are a helpful, friendly, and intelligent assistant. "
-     "Provide clear, accurate, and thoughtful responses."
- )

- # =========================
- # HF Login (optional)
- # =========================
  HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
      try:
          login(token=HF_TOKEN)
-         print("Successfully logged in to Hugging Face")
      except Exception as e:
-         print(f"Warning: Could not login to Hugging Face: {e}")
-

- # =========================
- # Utilities
- # =========================
-
- def tuples_from_messages(messages: List[Any]) -> List[List[str]]:
-     """
-     Normalize a Chatbot history to tuples [[user, assistant], ...].
-     Accepts either tuples-style or messages-style ({role, content}) lists.
-     """
-     if not messages:
-         return []
-     # Already tuples-like
-     if isinstance(messages[0], (list, tuple)) and len(messages[0]) == 2:
-         out: List[List[str]] = []
-         for x in messages:
-             try:
-                 a, b = x[0], x[1]
-             except Exception:
-                 continue
-             out.append([str(a) if a is not None else "", str(b) if b is not None else ""])
-         return out
-
-     # Convert from messages-style
-     pairs: List[List[str]] = []
-     last_user: Optional[str] = None
-     for m in messages:
-         if not isinstance(m, dict):
-             # Skip any stray items
-             continue
-         role = m.get("role")
-         content = m.get("content", "")
-         if role == "user":
-             last_user = str(content)
-         elif role == "assistant":
-             if last_user is None:
-                 pairs.append(["", str(content)])
-             else:
-                 pairs.append([last_user, str(content)])
-             last_user = None
-     if last_user is not None:
-         pairs.append([last_user, ""])
-     return pairs
-
-
- def messages_from_tuples(history_tuples: List[List[str]]) -> List[Dict[str, str]]:
-     """
-     Convert tuples [[user, assistant], ...] into list of role dictionaries:
-     [{"role": "user", ...}, {"role": "assistant", ...}, ...]
-     """
-     messages: List[Dict[str, str]] = []
-     for pair in history_tuples:
-         if not isinstance(pair, (list, tuple)) or len(pair) != 2:
-             # Skip malformed entries defensively
-             continue
-         u, a = pair
-         u = "" if u is None else str(u)
-         a = "" if a is None else str(a)
-         if u:
-             messages.append({"role": "user", "content": u})
-         if a:
-             messages.append({"role": "assistant", "content": a})
-     return messages
-
-
- # =========================
- # Chat Model Wrapper
- # =========================
- class MobileLLMChat:
-     def __init__(self):
-         self.model = None
-         self.tokenizer = None
-         self.device = None
-         self.model_loaded = False
-         self.version = None
-         self.load_model(version=MODEL_SUBFOLDER)
-
-     def load_model(self, version: str = "instruct") -> bool:
-         """Load tokenizer+model; choose dtype/device_map safely for CPU/GPU."""
-         try:
-             print(f"Loading {MODEL_ID} ({version}) ...")
-             use_cuda = torch.cuda.is_available()
-             torch_dtype = torch.float16 if use_cuda else torch.float32
-
-             self.tokenizer = AutoTokenizer.from_pretrained(
-                 MODEL_ID, trust_remote_code=True, subfolder=version
-             )
-             self.model = AutoModelForCausalLM.from_pretrained(
-                 MODEL_ID,
-                 trust_remote_code=True,
-                 subfolder=version,
-                 torch_dtype=torch_dtype,
-                 low_cpu_mem_usage=True,
-                 device_map="auto" if use_cuda else None,
-             )
-             if self.tokenizer.pad_token_id is None:
-                 self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-
-             self.model.eval()
-             self.version = version
-             self.device = next(self.model.parameters()).device
-             self.model_loaded = True
-             print("Model loaded successfully.")
-             return True
-         except Exception as e:
-             print(f"Error loading model: {e}")
-             self.model_loaded = False
-             return False
-
-     def format_chat_history(
-         self, history_msgs: List[Dict[str, str]], system_prompt: str
-     ) -> List[Dict[str, str]]:
-         messages = [{"role": "system", "content": system_prompt}]
-         trimmed = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
-         if MAX_HISTORY_LENGTH > 0:
-             trimmed = trimmed[-(MAX_HISTORY_LENGTH * 2) :]
-         messages.extend(trimmed)
-         return messages
-
-     @GPU_DECORATOR(duration=120)
-     def generate_once(
-         self,
-         user_input: str,
-         history_msgs: List[Dict[str, str]],
-         system_prompt: str,
-         temperature: float = 0.7,
-         max_new_tokens: int = MAX_NEW_TOKENS,
-         top_p: float = 0.95,
-     ) -> str:
-         """Single-shot generation (no streaming)."""
-         if not self.model_loaded:
-             return "Model not loaded. Please reload."
-         try:
-             messages = self.format_chat_history(history_msgs + [{"role": "user", "content": user_input}], system_prompt)
-             inputs = self.tokenizer.apply_chat_template(
-                 messages, return_tensors="pt", add_generation_prompt=True
-             )
-             input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
-             input_ids = input_ids.to(self.device)
-
-             with torch.no_grad():
-                 outputs = self.model.generate(
-                     input_ids,
-                     max_new_tokens=max_new_tokens,
-                     temperature=float(temperature),
-                     do_sample=temperature > 0,
-                     top_p=float(top_p),
-                     pad_token_id=self.tokenizer.pad_token_id,
-                     eos_token_id=self.tokenizer.eos_token_id,
-                 )
-             gen_ids = outputs[0][input_ids.shape[1] :]
-             return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
-         except Exception as e:
-             return f"Error generating response: {e}"
-
-     @GPU_DECORATOR(duration=120)
-     def stream_generate(
-         self,
-         user_input: str,
-         history_msgs: List[Dict[str, str]],
-         system_prompt: str,
-         temperature: float = 0.7,
-         max_new_tokens: int = MAX_NEW_TOKENS,
-         top_p: float = 0.95,
-     ):
-         """Streaming generator using TextIteratorStreamer."""
-         messages = self.format_chat_history(history_msgs + [{"role": "user", "content": user_input}], system_prompt)
-         inputs = self.tokenizer.apply_chat_template(
-             messages, return_tensors="pt", add_generation_prompt=True
-         )
-         input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
-         input_ids = input_ids.to(self.device)
-
-         streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
-         gen_kwargs = dict(
-             input_ids=input_ids,
-             max_new_tokens=max_new_tokens,
-             temperature=float(temperature),
-             do_sample=temperature > 0,
-             top_p=float(top_p),
-             pad_token_id=self.tokenizer.pad_token_id,
-             eos_token_id=self.tokenizer.eos_token_id,
-             streamer=streamer,
-         )

-         thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs)
-         thread.start()
-
-         partial = ""
-         for text in streamer:
-             partial += text
-             yield partial
-
-
- # =========================
- # Initialize Chat Model
- # =========================
- print("Initializing MobileLLM-Pro model...")
- chat_model = MobileLLMChat()
-
-
- # =========================
- # Gradio Helpers
- # =========================
-
- def clear_chat():
-     return [], ""
-
-
- def chat_fn(message, history, system_prompt, temperature, top_p):
-     """Non-streaming chat handler (returns tuples)."""
-     history = tuples_from_messages(history or [])
-     if not chat_model.model_loaded:
-         return history + [[message, "Please wait for the model to load or reload the space."]]
-
-     formatted_history = messages_from_tuples(history)
-     response = chat_model.generate_once(message, formatted_history, system_prompt, temperature, MAX_NEW_TOKENS, top_p)

-     # Always return strict [[str, str], ...]
-     return tuples_from_messages(history + [[message, response]])


- def chat_stream_fn(message, history, system_prompt, temperature, top_p):
-     """Streaming chat handler: yields updated tuples as tokens arrive."""
-     history = tuples_from_messages(history or [])
-     if not chat_model.model_loaded:
-         yield history + [[message, "Please wait for the model to load or reload the space."]]
-         return

-     formatted_history = messages_from_tuples(history)

-     # Start a new row for the assistant and fill progressively
-     base = history + [[message, ""]]
-     for chunk in chat_model.stream_generate(message, formatted_history, system_prompt, temperature, MAX_NEW_TOKENS, top_p):
-         yield tuples_from_messages(base[:-1] + [[message, chunk]])
-     # Final state already yielded
-     # Ensure completion (in case streamer ended exactly on boundary)
-     # No extra yield needed; last chunk already yielded.


- def handle_chat(message, history, system_prompt, temperature, top_p, streaming):
-     return (
-         chat_stream_fn(message, history, system_prompt, temperature, top_p)
-         if streaming
-         else chat_fn(message, history, system_prompt, temperature, top_p)
      )


- # =========================
- # Gradio UI
- # =========================
- with gr.Blocks(
-     title="MobileLLM-Pro Chat",
-     theme=gr.themes.Soft(),
-     css="""
-     .gradio-container { max-width: 900px !important; margin: auto !important; }
-     .message { padding: 12px !important; border-radius: 8px !important; margin-bottom: 8px !important; }
-     .user-message { background-color: #e3f2fd !important; margin-left: 20% !important; }
-     .assistant-message { background-color: #f5f5f5 !important; margin-right: 20% !important; }
-     """
- ) as demo:
-
-     gr.HTML(
-         """
-         <div style="text-align: center; margin-bottom: 20px;">
-             <h1>🤖 MobileLLM-Pro Chat</h1>
-             <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
-             <p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
-         </div>
-         """
-     )

-     with gr.Row():
-         model_status = gr.Textbox(
-             label="Model Status",
-             value="Model loaded and ready!" if chat_model.model_loaded else "Model loading...",
-             interactive=False,
-             container=True,
-         )

-     with gr.Accordion("⚙️ Configuration", open=False):
-         with gr.Row():
-             system_prompt = gr.Textbox(
-                 value=DEFAULT_SYSTEM_PROMPT,
-                 label="System Prompt",
-                 lines=3,
-                 info="Customize the AI's behavior and personality",
-             )
-         with gr.Row():
-             temperature = gr.Slider(
-                 minimum=0.0,
-                 maximum=2.0,
-                 value=0.7,
-                 step=0.05,
-                 label="Temperature",
-                 info="Controls randomness (higher = more creative)",
-             )
-             top_p = gr.Slider(
-                 minimum=0.1,
-                 maximum=1.0,
-                 value=0.95,
-                 step=0.01,
-                 label="Top-p",
-                 info="Nucleus sampling threshold",
-             )
-             streaming = gr.Checkbox(
-                 value=True,
-                 label="Enable Streaming",
-                 info="Show responses as they're being generated",
-             )
-
-     chatbot = gr.Chatbot(
-         type="tuples",
-         value=[],  # ensure initial value is a list of [user, assistant]
-         label="Chat History",
-         height=500,
-         show_copy_button=True,
-     )

      with gr.Row():
-         msg = gr.Textbox(
-             label="Your Message",
-             placeholder="Type your message here...",
-             scale=4,
-             container=False,
          )
-         submit_btn = gr.Button("Send", variant="primary", scale=1)
-         clear_btn = gr.Button("Clear", scale=0)
-
-     msg.submit(
-         handle_chat,
-         inputs=[msg, chatbot, system_prompt, temperature, top_p, streaming],
-         outputs=[chatbot],
-     ).then(lambda: "", None, msg)
-
-     submit_btn.click(
-         handle_chat,
-         inputs=[msg, chatbot, system_prompt, temperature, top_p, streaming],
-         outputs=[chatbot],
-     ).then(lambda: "", None, msg)
-
-     clear_btn.click(
-         clear_chat,
-         outputs=[chatbot, msg],
      )
-
-     gr.Examples(
-         examples=[
-             ["What are the benefits of on-device AI models?"],
-             ["Explain quantum computing in simple terms."],
-             ["Write a short poem about technology."],
-             ["What's the difference between machine learning and deep learning?"],
-             ["How can I improve my productivity?"],
-         ],
-         inputs=[msg],
-         label="Example Prompts",
      )

-     gr.HTML(
-         """
-         <div style="text-align: center; margin-top: 20px; color: #666;">
-             <p>⚠️ Note: Model is pre-loaded for faster inference. GPU is allocated only during generation.</p>
-             <p>Model: <a href="https://huggingface.co/facebook/MobileLLM-Pro" target="_blank">facebook/MobileLLM-Pro</a></p>
-         </div>
-         """
-     )

- # Improve streaming UX
- demo.queue()

  if __name__ == "__main__":
-     demo.launch(show_error=True, debug=True)
  import os
+ import time
  import torch
  import gradio as gr
+ from typing import List, Dict, Any, Tuple
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TextIteratorStreamer,
+ )
  from huggingface_hub import login
+ import threading
+
+ """
+ Gradio chat app for facebook/MobileLLM-Pro
+ - Uses the model's chat template when using the "instruct" subfolder
+ - Streams tokens to the Gradio UI
+ - Minimal controls: max_new_tokens, temperature, top_p
+ - Optional HF_TOKEN login via env var or textbox
+
+ To run locally:
+     pip install -U gradio transformers accelerate sentencepiece huggingface_hub
+     HF_TOKEN=xxxx python app.py
+
+ On Hugging Face Spaces:
+ - Remove explicit login() call or set HF_TOKEN as a secret
+ """

  MODEL_ID = "facebook/MobileLLM-Pro"
+ DEFAULT_VERSION = "instruct"  # "base" | "instruct"
+ DEFAULT_MAX_NEW_TOKENS = 256
+ DEFAULT_TEMPERATURE = 0.7
+ DEFAULT_TOP_P = 0.95

+ # ---- Optional: login to Hugging Face if token is provided ----
  HF_TOKEN = os.getenv("HF_TOKEN")
  if HF_TOKEN:
      try:
          login(token=HF_TOKEN)
+         print("[INFO] Logged in to Hugging Face Hub.")
      except Exception as e:
+         print(f"[WARN] Could not login to Hugging Face: {e}")


+ def load_model(version: str = DEFAULT_VERSION):
+     """Load tokenizer+model for the selected subfolder (base/instruct)."""
+     print(f"[INFO] Loading {MODEL_ID}:{version} ...")
+     tokenizer = AutoTokenizer.from_pretrained(
+         MODEL_ID, trust_remote_code=True, subfolder=version
+     )
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         trust_remote_code=True,
+         subfolder=version,
+         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+         low_cpu_mem_usage=True,
+         device_map="auto" if torch.cuda.is_available() else None,
+     )

+     # Ensure special tokens are set to avoid warnings
+     if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+         tokenizer.pad_token = tokenizer.eos_token

+     model.eval()
+     print("[INFO] Model loaded.")
+     return tokenizer, model


+ def _history_to_messages(history: List[Tuple[str, str]]) -> List[Dict[str, str]]:
+     """Map Gradio history [(user, assistant), ...] to chat template messages."""
+     messages: List[Dict[str, str]] = []
+     for user_msg, bot_msg in history:
+         if user_msg:
+             messages.append({"role": "user", "content": user_msg})
+         if bot_msg:
+             messages.append({"role": "assistant", "content": bot_msg})
+     return messages


+ def generate_stream(
+     message: str,
+     history: List[Tuple[str, str]],
+     version: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     use_chat_template: bool,
+     state: Dict[str, Any],
+ ):
+     """Streaming text generator compatible with gr.ChatInterface.

+     Args map to UI controls. `state` holds tokenizer/model between calls.
+     """
+     tokenizer = state.get("tokenizer")
+     model = state.get("model")
+
+     # (Re)load model if version changed or not yet loaded
+     if (
+         tokenizer is None
+         or model is None
+         or state.get("version") != version
+     ):
+         tokenizer, model = load_model(version)
+         state["tokenizer"], state["model"], state["version"] = tokenizer, model, version
+
+     device = next(model.parameters()).device
+
+     if use_chat_template and version == "instruct":
+         messages = _history_to_messages(history) + [
+             {"role": "user", "content": message}
+         ]
+         inputs = tokenizer.apply_chat_template(
+             messages,
+             return_tensors="pt",
+             add_generation_prompt=True,
+         ).to(device)
+         input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
+     else:
+         input_ids = tokenizer(
+             message,
+             return_tensors="pt",
+             add_special_tokens=True,
+         )["input_ids"].to(device)
+
+     streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+     gen_kwargs = dict(
+         input_ids=input_ids,
+         max_new_tokens=max_new_tokens,
+         do_sample=temperature > 0.0,
+         temperature=max(0.0, float(temperature)),
+         top_p=float(top_p),
+         pad_token_id=tokenizer.pad_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         streamer=streamer,
      )

+     thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
+     thread.start()

+     output_text = ""
+     for new_text in streamer:
+         output_text += new_text
+         yield output_text


+ with gr.Blocks(title="MobileLLM-Pro Chat") as demo:
+     gr.Markdown("""
+     # facebook/MobileLLM-Pro — Chat Demo
+     - **Version**: choose `instruct` to enable the model's chat template.
+     - **Streaming** is enabled. Use the controls in the right panel.
+     """)

      with gr.Row():
+         with gr.Column(scale=3):
+             chatbot = gr.Chatbot(height=420, label="MobileLLM-Pro")
+             msg = gr.Textbox(placeholder="Ask me anything…", scale=1)
+             submit = gr.Button("Send", variant="primary")
+             clear_btn = gr.Button("Clear chat")
+         with gr.Column(scale=2):
+             version = gr.Dropdown(["base", "instruct"], value=DEFAULT_VERSION, label="Subfolder (version)")
+             use_chat_template = gr.Checkbox(value=True, label="Use chat template (instruct only)")
+             max_new = gr.Slider(32, 1024, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens")
+             temperature = gr.Slider(0.0, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
+             top_p = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, step=0.01, label="Top-p")
+             hf_token_box = gr.Textbox(value=os.getenv("HF_TOKEN", ""), label="HF_TOKEN (optional)")
+
+     state = gr.State({"tokenizer": None, "model": None, "version": None})
+
+     def _maybe_login(token: str):
+         token = (token or "").strip()
+         if not token:
+             return "(No token provided; skipping login)"
+         try:
+             login(token=token)
+             return "Logged in to Hugging Face Hub."
+         except Exception as e:
+             return f"Login failed: {e}"
+
+     login_btn = gr.Button("Login to HF (optional)")
+     login_status = gr.Markdown()
+     login_btn.click(_maybe_login, inputs=[hf_token_box], outputs=[login_status])
+
+     def user_submit(user_message, chat_history):
+         # Immediately append the user's message so the stream shows inline
+         return "", chat_history + [(user_message, None)]
+
+     def bot_respond(chat_history, version, max_new, temperature, top_p, use_chat_template, state):
+         # The last tuple is (user, None)
+         user_message = chat_history[-1][0] if chat_history else ""
+         partials = generate_stream(
+             user_message,
+             chat_history[:-1],
+             version,
+             int(max_new),
+             float(temperature),
+             float(top_p),
+             bool(use_chat_template),
+             state,
          )
+         # Stream tokens to the last assistant message slot
+         for chunk in partials:
+             chat_history[-1] = (chat_history[-1][0], chunk)
+             yield chat_history
+
+     msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
+         bot_respond,
+         [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
+         [chatbot],
      )
+     submit.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
+         bot_respond,
+         [chatbot, version, max_new, temperature, top_p, use_chat_template, state],
+         [chatbot],
      )

+     def clear_chat():
+         return []

+     clear_btn.click(clear_chat, outputs=[chatbot])

  if __name__ == "__main__":
+     # For Spaces, Gradio will call `demo.launch()` automatically; locally we launch here.
+     demo.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
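
The core technique the new generate_stream() relies on is transformers' TextIteratorStreamer driven by model.generate() on a background thread, with the main thread yielding partial text to Gradio. A minimal standalone sketch of that pattern, for illustration only (the "gpt2" checkpoint and prompt below are placeholders, not part of this commit; the app above loads facebook/MobileLLM-Pro with trust_remote_code=True and subfolder="instruct"):

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder checkpoint purely to illustrate the streaming mechanics.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Hello there", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_special_tokens=True)

# generate() blocks until completion, so it runs in a worker thread while the
# main thread consumes decoded text chunks from the streamer as they arrive.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=32, do_sample=False, streamer=streamer),
)
thread.start()
for piece in streamer:
    print(piece, end="", flush=True)
thread.join()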