akhaliq (HF Staff) committed
Commit 195d6db · verified · 1 Parent(s): 221127d

Update app.py

Files changed (1):
  1. app.py +121 -116
app.py CHANGED
@@ -1,12 +1,21 @@
 import os
-import time
-from typing import List, Dict, Tuple, Any
+import threading
+from typing import List, Dict, Tuple, Any, Optional

 import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from huggingface_hub import login
-import spaces
+
+# --- Optional: Hugging Face Spaces GPU decorator (safe locally) ---
+try:
+    import spaces  # type: ignore
+    GPU_DECORATOR = spaces.GPU
+except Exception:  # running locally without `spaces`
+    def GPU_DECORATOR(*args, **kwargs):  # no-op decorator
+        def _wrap(fn):
+            return fn
+        return _wrap

 # =========================
 # Configuration
@@ -35,20 +44,21 @@ if HF_TOKEN:
 # =========================
 # Utilities
 # =========================
-def tuples_from_messages(messages: List[Dict[str, Any]]) -> List[List[str]]:
+
+def tuples_from_messages(messages: List[Any]) -> List[List[str]]:
     """
-    Convert a Chatbot(type='messages') style history into tuples format
-    [[user, assistant], ...]. If already tuples-like, return as-is.
+    Normalize a Chatbot history to tuples [[user, assistant], ...].
+    Accepts either tuples-style or messages-style ({role, content}) lists.
     """
     if not messages:
         return []
-    # If already tuples-like (list with elements of length 2), return
+    # Already tuples-like
     if isinstance(messages[0], (list, tuple)) and len(messages[0]) == 2:
         return [list(x) for x in messages]

-    # Otherwise, convert from [{"role": "...", "content": "..."}, ...]
+    # Convert from messages-style
     pairs: List[List[str]] = []
-    last_user: str | None = None
+    last_user: Optional[str] = None
     for m in messages:
         role = m.get("role")
         content = m.get("content", "")
@@ -56,12 +66,10 @@ def tuples_from_messages(messages: List[Dict[str, Any]]) -> List[List[str]]:
             last_user = content
         elif role == "assistant":
             if last_user is None:
-                # If assistant appears first (odd state), pair with empty user
                 pairs.append(["", content])
             else:
                 pairs.append([last_user, content])
             last_user = None
-    # If there's a dangling user without assistant, pair with empty string
     if last_user is not None:
         pairs.append([last_user, ""])
     return pairs
@@ -74,7 +82,8 @@ def messages_from_tuples(history_tuples: List[List[str]]) -> List[Dict[str, str]
     """
     messages: List[Dict[str, str]] = []
     for u, a in history_tuples:
-        messages.append({"role": "user", "content": u})
+        if u:
+            messages.append({"role": "user", "content": u})
         if a:
             messages.append({"role": "assistant", "content": a})
     return messages
@@ -89,12 +98,16 @@ class MobileLLMChat:
         self.tokenizer = None
         self.device = None
         self.model_loaded = False
+        self.version = None
         self.load_model(version=MODEL_SUBFOLDER)

-    def load_model(self, version="instruct"):
-        """Load the MobileLLM-Pro model and tokenizer (initially to CPU)."""
+    def load_model(self, version: str = "instruct") -> bool:
+        """Load tokenizer+model; choose dtype/device_map safely for CPU/GPU."""
         try:
-            print(f"Loading {MODEL_ID} ({version})...")
+            print(f"Loading {MODEL_ID} ({version}) ...")
+            use_cuda = torch.cuda.is_available()
+            torch_dtype = torch.float16 if use_cuda else torch.float32
+
             self.tokenizer = AutoTokenizer.from_pretrained(
                 MODEL_ID, trust_remote_code=True, subfolder=version
             )
@@ -102,91 +115,107 @@ class MobileLLMChat:
                 MODEL_ID,
                 trust_remote_code=True,
                 subfolder=version,
-                torch_dtype=torch.float16,
+                torch_dtype=torch_dtype,
                 low_cpu_mem_usage=True,
+                device_map="auto" if use_cuda else None,
             )
-            # Safety: ensure pad token exists (some LLMs don't set it)
             if self.tokenizer.pad_token_id is None:
                 self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

             self.model.eval()
+            self.version = version
+            self.device = next(self.model.parameters()).device
             self.model_loaded = True
-            print("Model loaded successfully to system memory (CPU).")
+            print("Model loaded successfully.")
             return True
         except Exception as e:
             print(f"Error loading model: {e}")
+            self.model_loaded = False
             return False

     def format_chat_history(
-        self, history: List[Dict[str, str]], system_prompt: str
+        self, history_msgs: List[Dict[str, str]], system_prompt: str
     ) -> List[Dict[str, str]]:
-        """Format chat history for tokenizer's chat template."""
         messages = [{"role": "system", "content": system_prompt}]
-        # Truncate to keep the last N turns
-        trimmed = []
-        for msg in history:
-            if msg["role"] in ("user", "assistant"):
-                trimmed.append(msg)
+        trimmed = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
         if MAX_HISTORY_LENGTH > 0:
             trimmed = trimmed[-(MAX_HISTORY_LENGTH * 2) :]
         messages.extend(trimmed)
         return messages

-    @spaces.GPU(duration=120)
-    def generate_response(
+    @GPU_DECORATOR(duration=120)
+    def generate_once(
         self,
         user_input: str,
-        history: List[Dict[str, str]],
+        history_msgs: List[Dict[str, str]],
         system_prompt: str,
         temperature: float = 0.7,
         max_new_tokens: int = MAX_NEW_TOKENS,
+        top_p: float = 0.95,
     ) -> str:
-        """Generate a full response (GPU during inference)."""
+        """Single-shot generation (no streaming)."""
         if not self.model_loaded:
-            return "Model not loaded. Please try reloading the space."
+            return "Model not loaded. Please reload."
         try:
-            # Choose device (Spaces GPU if available)
-            use_cuda = torch.cuda.is_available()
-            self.device = torch.device("cuda" if use_cuda else "cpu")
-            self.model.to(self.device)
-
-            # Append the new user message
-            history.append({"role": "user", "content": user_input})
-            messages = self.format_chat_history(history, system_prompt)
-
-            # Build inputs with chat template
-            input_ids = self.tokenizer.apply_chat_template(
+            messages = self.format_chat_history(history_msgs + [{"role": "user", "content": user_input}], system_prompt)
+            inputs = self.tokenizer.apply_chat_template(
                 messages, return_tensors="pt", add_generation_prompt=True
-            ).to(self.device)
-            # No padding used here -> full ones mask
-            attention_mask = torch.ones_like(input_ids)
+            )
+            input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
+            input_ids = input_ids.to(self.device)

             with torch.no_grad():
                 outputs = self.model.generate(
                     input_ids,
-                    attention_mask=attention_mask,
                     max_new_tokens=max_new_tokens,
-                    temperature=temperature,
-                    do_sample=True,
-                    pad_token_id=self.tokenizer.eos_token_id,
+                    temperature=float(temperature),
+                    do_sample=temperature > 0,
+                    top_p=float(top_p),
+                    pad_token_id=self.tokenizer.pad_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
                 )
-
-            # Slice only the newly generated tokens
             gen_ids = outputs[0][input_ids.shape[1] :]
-            response = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+            return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+        except Exception as e:
+            return f"Error generating response: {e}"

-            # Update history (internal state for the caller if desired)
-            history.append({"role": "assistant", "content": response})
+    @GPU_DECORATOR(duration=120)
+    def stream_generate(
+        self,
+        user_input: str,
+        history_msgs: List[Dict[str, str]],
+        system_prompt: str,
+        temperature: float = 0.7,
+        max_new_tokens: int = MAX_NEW_TOKENS,
+        top_p: float = 0.95,
+    ):
+        """Streaming generator using TextIteratorStreamer."""
+        messages = self.format_chat_history(history_msgs + [{"role": "user", "content": user_input}], system_prompt)
+        inputs = self.tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True
+        )
+        input_ids = inputs if isinstance(inputs, torch.Tensor) else inputs["input_ids"]
+        input_ids = input_ids.to(self.device)
+
+        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)
+        gen_kwargs = dict(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            temperature=float(temperature),
+            do_sample=temperature > 0,
+            top_p=float(top_p),
+            pad_token_id=self.tokenizer.pad_token_id,
+            eos_token_id=self.tokenizer.eos_token_id,
+            streamer=streamer,
+        )

-            # Free GPU VRAM
-            if use_cuda:
-                self.model.to("cpu")
-                torch.cuda.empty_cache()
+        thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs)
+        thread.start()

-            return response
-        except Exception as e:
-            return f"Error generating response: {str(e)}"
+        partial = ""
+        for text in streamer:
+            partial += text
+            yield partial


 # =========================
@@ -199,65 +228,44 @@ chat_model = MobileLLMChat()
 # =========================
 # Gradio Helpers
 # =========================
+
 def clear_chat():
-    """Clear the chat history and input box."""
     return [], ""


-def chat_fn(message, history, system_prompt, temperature):
+def chat_fn(message, history, system_prompt, temperature, top_p):
     """Non-streaming chat handler (returns tuples)."""
-    # DEFENSIVE: ensure tuples format
     history = tuples_from_messages(history)
-
     if not chat_model.model_loaded:
         return history + [[message, "Please wait for the model to load or reload the space."]]

-    # Convert tuples -> role dicts for the model
     formatted_history = messages_from_tuples(history)
-
-    # Generate full response once
-    response = chat_model.generate_response(message, formatted_history, system_prompt, temperature)
-
-    # Return updated tuples history
+    response = chat_model.generate_once(message, formatted_history, system_prompt, temperature, MAX_NEW_TOKENS, top_p)
     return history + [[message, response]]


-def chat_stream_fn(message, history, system_prompt, temperature):
-    """Streaming chat handler (tuples): generate once, then chunk out."""
-    # DEFENSIVE: ensure tuples format
+def chat_stream_fn(message, history, system_prompt, temperature, top_p):
+    """Streaming chat handler: yields updated tuples as tokens arrive."""
     history = tuples_from_messages(history)
-
     if not chat_model.model_loaded:
         yield history + [[message, "Please wait for the model to load or reload the space."]]
         return

-    # Convert tuples -> role dicts for the model
     formatted_history = messages_from_tuples(history)

-    # Generate full response (GPU)
-    full_response = chat_model.generate_response(
-        message, formatted_history, system_prompt, temperature
-    )
-
-    # Start new row and progressively fill assistant side
+    # Start a new row for the assistant and fill progressively
     base = history + [[message, ""]]
-    if not isinstance(full_response, str):
-        full_response = str(full_response)
-
-    step = max(8, len(full_response) // 40) # ~40 chunks
-    for i in range(0, len(full_response), step):
-        partial = full_response[: i + step]
-        yield base[:-1] + [[message, partial]]
+    for chunk in chat_model.stream_generate(message, formatted_history, system_prompt, temperature, MAX_NEW_TOKENS, top_p):
+        yield base[:-1] + [[message, chunk]]
+    # Ensure completion (in case streamer ended exactly on boundary)
+    # No extra yield needed; last chunk already yielded.

-    # Final ensure complete
-    yield base[:-1] + [[message, full_response]]

-
-def handle_chat(message, history, system_prompt, temperature, streaming):
+def handle_chat(message, history, system_prompt, temperature, top_p, streaming):
     return (
-        chat_stream_fn(message, history, system_prompt, temperature)
+        chat_stream_fn(message, history, system_prompt, temperature, top_p)
         if streaming
-        else chat_fn(message, history, system_prompt, temperature)
+        else chat_fn(message, history, system_prompt, temperature, top_p)
     )


@@ -275,18 +283,16 @@ with gr.Blocks(
     """
 ) as demo:

-    # Header
     gr.HTML(
         """
-        <div style="text-align: center; margin-bottom: 20px;">
+        <div style=\"text-align: center; margin-bottom: 20px;\">
            <h1>🤖 MobileLLM-Pro Chat</h1>
-            <p>Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>
+            <p>Built with <a href=\"https://huggingface.co/spaces/akhaliq/anycoder\" target=\"_blank\">anycoder</a></p>
            <p>Chat with Facebook's MobileLLM-Pro model optimized for on-device inference</p>
        </div>
        """
    )

-    # Model status
    with gr.Row():
        model_status = gr.Textbox(
            label="Model Status",
@@ -295,7 +301,6 @@
            container=True,
        )

-    # Config
    with gr.Accordion("⚙️ Configuration", open=False):
        with gr.Row():
            system_prompt = gr.Textbox(
@@ -306,20 +311,27 @@
            )
        with gr.Row():
            temperature = gr.Slider(
-                minimum=0.1,
+                minimum=0.0,
                maximum=2.0,
                value=0.7,
-                step=0.1,
+                step=0.05,
                label="Temperature",
                info="Controls randomness (higher = more creative)",
            )
+            top_p = gr.Slider(
+                minimum=0.1,
+                maximum=1.0,
+                value=0.95,
+                step=0.01,
+                label="Top-p",
+                info="Nucleus sampling threshold",
+            )
            streaming = gr.Checkbox(
                value=True,
                label="Enable Streaming",
                info="Show responses as they're being generated",
            )

-    # Chatbot in TUPLES mode (explicit)
    chatbot = gr.Chatbot(
        type="tuples",
        label="Chat History",
@@ -337,16 +349,15 @@
        submit_btn = gr.Button("Send", variant="primary", scale=1)
        clear_btn = gr.Button("Clear", scale=0)

-    # Wire events (also clear the input box after send)
    msg.submit(
        handle_chat,
-        inputs=[msg, chatbot, system_prompt, temperature, streaming],
+        inputs=[msg, chatbot, system_prompt, temperature, top_p, streaming],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)

    submit_btn.click(
        handle_chat,
-        inputs=[msg, chatbot, system_prompt, temperature, streaming],
+        inputs=[msg, chatbot, system_prompt, temperature, top_p, streaming],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)

@@ -355,7 +366,6 @@
        outputs=[chatbot, msg],
    )

-    # Examples
    gr.Examples(
        examples=[
            ["What are the benefits of on-device AI models?"],
@@ -368,22 +378,17 @@
        label="Example Prompts",
    )

-    # Footer
    gr.HTML(
        """
-        <div style="text-align: center; margin-top: 20px; color: #666;">
+        <div style=\"text-align: center; margin-top: 20px; color: #666;\">
            <p>⚠️ Note: Model is pre-loaded for faster inference. GPU is allocated only during generation.</p>
-            <p>Model: <a href="https://huggingface.co/facebook/MobileLLM-Pro" target="_blank">facebook/MobileLLM-Pro</a></p>
+            <p>Model: <a href=\"https://huggingface.co/facebook/MobileLLM-Pro\" target=\"_blank\">facebook/MobileLLM-Pro</a></p>
        </div>
        """
    )

-# Optional: queue to improve streaming UX
+# Improve streaming UX
 demo.queue()

-# Launch (NO share=True on Spaces)
 if __name__ == "__main__":
-    demo.launch(
-        show_error=True,
-        debug=True,
-    )
+    demo.launch(show_error=True, debug=True)
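
The main functional change in this commit is that streaming now comes from transformers' TextIteratorStreamer, with model.generate running on a background thread, instead of generating the full reply and re-chunking it. A minimal sketch of driving the new stream_generate path outside Gradio (illustrative only, not part of the commit; it assumes this Space's app.py is importable as `app`, that the model loaded successfully, and that the no-op GPU_DECORATOR fallback is in effect when running locally):

    # Illustrative sketch: consume MobileLLMChat.stream_generate from a plain script.
    # `from app import chat_model` is a hypothetical import of this Space's app.py;
    # importing it builds the Gradio Blocks and loads the model once (but does not launch).
    from app import chat_model

    history_msgs = []  # messages-style history: [{"role": "user"/"assistant", "content": ...}]
    for partial in chat_model.stream_generate(
        "Give one benefit of on-device LLMs.",   # user_input
        history_msgs,
        "You are a helpful assistant.",          # system_prompt
        temperature=0.7,
        max_new_tokens=64,
        top_p=0.95,
    ):
        # Each yield is the accumulated text so far, which is what
        # chat_stream_fn forwards to the Chatbot row.
        print(partial)

Each yielded string replaces the assistant cell of the last chat row, mirroring what chat_stream_fn does with base[:-1] + [[message, chunk]].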