neovalle committed
Commit 19ada00 · verified · 1 Parent(s): fa6b03e

Update app.py

Files changed (1)
  1. app.py +161 -51
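
The central pattern this commit leans on is the `spaces.GPU` decorator with a local no-op fallback, so the same code runs on a ZeroGPU Space and on a plain CPU machine. Below is a minimal, self-contained sketch of that pattern; the `run_inference` function is a hypothetical stand-in for illustration and is not part of the commit.

# Sketch of the ZeroGPU decorator-with-fallback pattern used in the diff below.
try:
    import spaces  # HF Spaces helper; provides @spaces.GPU() on ZeroGPU hardware
except Exception:
    class _Noop:
        def GPU(self, *args, **kwargs):
            def deco(fn):
                return fn  # no-op: return the function unchanged when not on Spaces
            return deco
    spaces = _Noop()

@spaces.GPU()  # on a ZeroGPU Space this requests a GPU for the call; locally it does nothing
def run_inference(prompt: str) -> str:
    # hypothetical body for illustration only
    return prompt.upper()

if __name__ == "__main__":
    print(run_inference("hello"))  # works with or without the spaces package installed

The point of the `_Noop` shim is that the decorator keeps the exact same call shape (`@spaces.GPU()`), so no other code has to change between environments.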
app.py CHANGED
@@ -1,4 +1,6 @@
-# app.py
+# app.py — ZeroGPU-optimised Gradio app (HF Spaces)
+
+import os
 import tempfile
 from datetime import datetime

@@ -7,6 +9,25 @@ import pandas as pd
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer

+# ---- ZeroGPU decorator ----
+try:
+    import spaces  # HF Spaces utility (provides @spaces.GPU())
+except Exception:
+    # Fallback: make a no-op decorator so the app still runs locally/CPU
+    class _Noop:
+        def GPU(self, *args, **kwargs):
+            def deco(fn):
+                return fn
+            return deco
+    spaces = _Noop()
+
+# ---- Optional quantisation (GPU only) ----
+try:
+    from transformers import BitsAndBytesConfig
+    HAS_BNB = True
+except Exception:
+    HAS_BNB = False
+
 # ----------------------------
 # Config
 # ----------------------------
@@ -17,36 +38,94 @@ DEFAULT_MODELS = [
     "neovalle/tinyllama-1.1B-h4rmony-trained",
 ]

+# Keep batches reasonable on ZeroGPU for low latency
+MICROBATCH = 4
+
+# Cap encoder length to avoid wasting time on very long inputs
+MAX_INPUT_TOKENS = 1024
+
+# Speed on GPU (TF32 gives extra throughput on Ampere+)
+if torch.cuda.is_available():
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+else:
+    # On CPU, reducing threads sometimes helps stability/predictability
+    try:
+        torch.set_num_threads(max(1, (os.cpu_count() or 4) // 2))
+    except Exception:
+        pass
+
 _MODEL_CACHE = {}  # cache: model_id -> (tokenizer, model)


 # ----------------------------
-# Utilities
+# Helpers
 # ----------------------------

+def _all_eos_ids(tok):
+    """Collect a few likely EOS ids so generation can stop earlier."""
+    ids = set()
+    if tok.eos_token_id is not None:
+        ids.add(tok.eos_token_id)
+    for t in ("<|im_end|>", "<|endoftext|>", "</s>"):
+        try:
+            tid = tok.convert_tokens_to_ids(t)
+            if isinstance(tid, int) and tid >= 0:
+                ids.add(tid)
+        except Exception:
+            pass
+    return list(ids) if ids else None
+
+
 def _load_model(model_id: str):
+    """Load & cache model/tokenizer. On GPU, prefer 4-bit NF4 with BF16 compute."""
     if model_id in _MODEL_CACHE:
         return _MODEL_CACHE[model_id]

     tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

-    # Ensure pad token exists for generate()
+    # Ensure a pad token for batch generate()
     if tok.pad_token is None:
         if tok.eos_token is not None:
             tok.pad_token = tok.eos_token
         else:
             tok.add_special_tokens({"pad_token": "<|pad|>"})

-    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+    use_gpu = torch.cuda.is_available()
+    dtype = (
+        torch.bfloat16 if (use_gpu and torch.cuda.is_bf16_supported()) else
+        (torch.float16 if use_gpu else torch.float32)
+    )
+
+    quant_cfg = None
+    if use_gpu and HAS_BNB:
+        quant_cfg = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
+        )
+
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
-        torch_dtype=dtype,
+        torch_dtype=(torch.bfloat16 if use_gpu else torch.float32),
         low_cpu_mem_usage=True,
         device_map="auto",
-    )
+        quantization_config=quant_cfg,  # 4-bit on GPU if available; None on CPU
+        trust_remote_code=True,  # helps for chat templates (e.g., Qwen)
+        # attn_implementation="flash_attention_2",  # enable only if flash-attn in requirements
+    ).eval()
+
+    # Resize if we added new pad token
     if model.get_input_embeddings().num_embeddings != len(tok):
         model.resize_token_embeddings(len(tok))

+    # Prefer KV cache
+    try:
+        model.generation_config.use_cache = True
+    except Exception:
+        pass
+
     _MODEL_CACHE[model_id] = (tok, model)
     return tok, model

@@ -70,6 +149,37 @@ def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
     return f"{prefix}<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n"


+@torch.inference_mode()
+def _generate_microbatch(tok, model, formatted_prompts, gen_kwargs):
+    """Generate for a list of formatted prompts. Returns (texts, tokens_out)."""
+    device = model.device
+    eos_ids = _all_eos_ids(tok)
+
+    enc = tok(
+        formatted_prompts,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=MAX_INPUT_TOKENS,
+    ).to(device)
+
+    prompt_lens = enc["attention_mask"].sum(dim=1)
+    outputs = model.generate(
+        **enc,
+        eos_token_id=eos_ids,
+        pad_token_id=tok.pad_token_id,
+        **gen_kwargs,
+    )
+
+    texts, toks_out = [], []
+    for i in range(outputs.size(0)):
+        start = int(prompt_lens[i].item())
+        gen_ids = outputs[i, start:]
+        texts.append(tok.decode(gen_ids, skip_special_tokens=True).strip())
+        toks_out.append(int(gen_ids.numel()))
+    return texts, toks_out
+
+
 def generate_batch_df(
     model_id: str,
     system_prompt: str,
@@ -81,45 +191,40 @@ def generate_batch_df(
     repetition_penalty: float,
 ) -> pd.DataFrame:
     tok, model = _load_model(model_id)
-    device = model.device

+    # Split user inputs
     prompts = [p.strip() for p in prompts_multiline.splitlines() if p.strip()]
     if not prompts:
         return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])

     formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]
-    enc = tok(
-        formatted,
-        return_tensors="pt",
-        padding=True,
-        truncation=True,
-    ).to(device)
-
-    prompt_lens = enc["attention_mask"].sum(dim=1)

-    with torch.no_grad():
-        gen = model.generate(
-            **enc,
-            max_new_tokens=int(max_new_tokens),
-            do_sample=(temperature > 0.0),
-            temperature=float(temperature) if temperature > 0 else None,
-            top_p=float(top_p),
-            top_k=int(top_k) if int(top_k) > 0 else None,
-            repetition_penalty=float(repetition_penalty),
-            eos_token_id=tok.eos_token_id,
-            pad_token_id=tok.pad_token_id,
-        )
+    # Micro-batch multi-line input to keep latency low on ZeroGPU
+    B = MICROBATCH if len(formatted) > MICROBATCH else len(formatted)
+
+    # Greedy is fine (and fastest). If temp > 0, enable sampling knobs.
+    do_sample = bool(temperature > 0.0)
+    gen_kwargs = dict(
+        max_new_tokens=int(max_new_tokens),
+        do_sample=do_sample,
+        temperature=float(temperature) if do_sample else None,
+        top_p=float(top_p) if do_sample else None,
+        top_k=int(top_k) if (do_sample and int(top_k) > 0) else None,
+        repetition_penalty=float(repetition_penalty),
+        num_beams=1,
+        return_dict_in_generate=False,
+        use_cache=True,
+    )

-    responses, tokens_out = [], []
-    for i in range(gen.size(0)):
-        start = int(prompt_lens[i].item())
-        gen_ids = gen[i, start:]
-        text = tok.decode(gen_ids, skip_special_tokens=True).strip()
-        responses.append(text)
-        tokens_out.append(len(gen_ids))
+    all_texts, all_toks = [], []
+    for i in range(0, len(formatted), B):
+        batch_prompts = formatted[i : i + B]
+        texts, toks = _generate_microbatch(tok, model, batch_prompts, gen_kwargs)
+        all_texts.extend(texts)
+        all_toks.extend(toks)

     return pd.DataFrame(
-        {"user_prompt": prompts, "response": responses, "tokens_out": tokens_out}
+        {"user_prompt": prompts, "response": all_texts, "tokens_out": all_toks}
     )


@@ -134,11 +239,11 @@ def write_csv_path(df: pd.DataFrame) -> str:
 # Gradio UI
 # ----------------------------

-with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
+with gr.Blocks(title="Multi-Prompt Chat (ZeroGPU-optimised)") as demo:
     gr.Markdown(
         """
-        # Multi-Prompt Chat to test system prompt effects
-        Pick a small free model, set a **system prompt**, and enter **multiple user prompts** (one per line).
+        # Multi-Prompt Chat to test system prompt effects (ZeroGPU-optimised)
+        Pick a small model, set a **system prompt**, and enter **multiple user prompts** (one per line).
         Click **Generate** to get batched responses and a **downloadable CSV**.
         """
     )
@@ -149,24 +254,24 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
         choices=DEFAULT_MODELS,
         value=DEFAULT_MODELS[0],
         label="Model",
-        info="Free, small instruction-tuned models that run on CPU and free HF Space",
+        info="ZeroGPU attaches an H200 dynamically. 4-bit is used automatically on GPU.",
     )
     system_prompt = gr.Textbox(
         label="System prompt",
-        placeholder="e.g., You are an ecolinguistics-aware assistant that always prioritise planetary well-being over anthropocentrism.",
+        placeholder="e.g., You are an ecolinguistics-aware assistant...",
         lines=5,
     )
     prompts_multiline = gr.Textbox(
         label="User prompts (one per line)",
-        placeholder="One query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips for students\nSummarise the benefits of multilingual models",
+        placeholder="One query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips\nSummarise benefits of multilingual models",
        lines=10,
     )

     with gr.Accordion("Generation settings", open=False):
-        max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
-        temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
-        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
-        top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 disables)")
+        max_new_tokens = gr.Slider(16, 1024, value=200, step=1, label="max_new_tokens")
+        temperature = gr.Slider(0.0, 2.0, value=0.0, step=0.05, label="temperature (0 = greedy, fastest)")
+        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p (used if temp > 0)")
+        top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 disables; used if temp > 0)")
         repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="repetition_penalty")

     run_btn = gr.Button("Generate", variant="primary")
@@ -179,15 +284,18 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
         wrap=True,
         interactive=False,
         row_count=(0, "dynamic"),
-        type="pandas",  # ensures pandas goes into callbacks
+        type="pandas",
     )
-
-    # IMPORTANT: type="filepath" so we can return a string path
     csv_out = gr.File(label="CSV output", interactive=False, type="filepath")

-    # -------- Callback: generate table AND CSV path in one go --------
+    # -------- Callback: GPU-decorated for ZeroGPU --------
+
+    @spaces.GPU()  # <— This tells ZeroGPU to attach a GPU for this request
+    def _generate_cb(model_id, system_prompt, prompts_multiline,
+                     max_new_tokens, temperature, top_p, top_k, repetition_penalty,
+                     progress=gr.Progress(track_tqdm=True)):

-    def _generate_cb(model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+        progress(0.05, desc="Requesting ZeroGPU…")
         df = generate_batch_df(
             model_id=model_id,
             system_prompt=system_prompt,
@@ -198,8 +306,10 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
             top_k=int(top_k),
             repetition_penalty=float(repetition_penalty),
         )
+        progress(0.95, desc="Preparing CSV…")
         csv_path = write_csv_path(df)
-        return df, csv_path  # DataFrame to table, path to File(type="filepath")
+        progress(1.0, desc="Done")
+        return df, csv_path

     run_btn.click(
         _generate_cb,