neovalle committed
Commit 865d725 · verified · 1 Parent(s): 4763e5c

Update app.py

Files changed (1): app.py (+37 -43)
app.py CHANGED
@@ -1,5 +1,6 @@
 # app.py
 import io
+import tempfile
 from datetime import datetime

 import gradio as gr
@@ -11,26 +12,20 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # Config
 # ----------------------------

-# Small, free, instruction-tuned models that run on CPU in a Basic Space.
+# Small, free, instruction-tuned models that can run on CPU (Basic Space).
 DEFAULT_MODELS = [
     "google/gemma-2-2b-it",
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "Qwen/Qwen2.5-1.5B-Instruct",
 ]

-_MODEL_CACHE = {}  # (tokenizer, model) cache
+_MODEL_CACHE = {}  # cache: model_id -> (tokenizer, model)


 # ----------------------------
 # Utilities
 # ----------------------------

-def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
-    buf = io.StringIO()
-    df.to_csv(buf, index=False)
-    return buf.getvalue().encode("utf-8")
-
-
 def _load_model(model_id: str):
     """Load tokenizer and model (cached)."""
     if model_id in _MODEL_CACHE:
@@ -38,9 +33,8 @@ def _load_model(model_id: str):

     tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)

-    # Ensure we have a pad token to avoid warnings in generate
+    # Ensure we have a pad token to avoid generate() warnings/errors.
     if tok.pad_token is None:
-        # Prefer eos_token, else add a pad token
         if tok.eos_token is not None:
             tok.pad_token = tok.eos_token
         else:
@@ -53,7 +47,8 @@ def _load_model(model_id: str):
         low_cpu_mem_usage=True,
         device_map="auto",
     )
-    # If we added a pad token, resize embeddings
+
+    # If we added tokens, resize embeddings.
     if model.get_input_embeddings().num_embeddings != len(tok):
         model.resize_token_embeddings(len(tok))

@@ -62,9 +57,7 @@


 def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
-    """
-    Prefer the model's chat template. Fallback to a light instruction format.
-    """
+    """Prefer each model's chat template; fallback to a simple instruction format."""
     sys = (system_prompt or "").strip()
     usr = (user_prompt or "").strip()

@@ -79,7 +72,7 @@ def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
             add_generation_prompt=True,
         )

-    # Fallback format
+    # Fallback plain format
     prefix = f"<<SYS>>\n{sys}\n<</SYS>>\n\n" if sys else ""
     return f"{prefix}<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n"

@@ -98,14 +91,13 @@ def generate_batch(
     tok, model = _load_model(model_id)
     device = model.device

-    # Split lines, discard empties
+    # Split lines, drop empties
     prompts = [p.strip() for p in prompts_multiline.splitlines() if p.strip()]
     if not prompts:
         return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])

-    # Build formatted prompts per model
+    # Build formatted prompts and encode
     formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]
-
     enc = tok(
         formatted,
         return_tensors="pt",
@@ -113,7 +105,7 @@
         truncation=True,
     ).to(device)

-    # True prompt lengths per row (use attention mask sum to ignore padding)
+    # True prompt lengths per row (ignore padding)
     prompt_lens = enc["attention_mask"].sum(dim=1)

     with torch.no_grad():
@@ -129,9 +121,8 @@
             pad_token_id=tok.pad_token_id,
         )

-    # Slice generated tokens per row using actual prompt length
-    responses = []
-    tokens_out = []
+    # Slice generated tokens using prompt lengths
+    responses, tokens_out = [], []
     for i in range(gen.size(0)):
         start = int(prompt_lens[i].item())
         gen_ids = gen[i, start:]
@@ -139,14 +130,18 @@
         responses.append(text)
         tokens_out.append(len(gen_ids))

-    df = pd.DataFrame(
-        {
-            "user_prompt": prompts,
-            "response": responses,
-            "tokens_out": tokens_out,
-        }
+    return pd.DataFrame(
+        {"user_prompt": prompts, "response": responses, "tokens_out": tokens_out}
     )
-    return df
+
+
+def write_csv_tempfile(df: pd.DataFrame) -> str:
+    """Write CSV to a real temp file and return its path (works in Spaces)."""
+    # Use NamedTemporaryFile with delete=False so Gradio can read after returning.
+    ts = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
+    tmp = tempfile.NamedTemporaryFile(prefix=f"batch_{ts}_", suffix=".csv", delete=False, dir="/tmp")
+    df.to_csv(tmp.name, index=False)
+    return tmp.name


 # ----------------------------
@@ -158,7 +153,7 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
         """
         # 🧪 Multi-Prompt Chat for HF Space
         Pick a small free model, set a **system prompt**, and enter **multiple user prompts** (one per line).
-        Click **Generate** to get batched responses, then **Download CSV** for offline use.
+        Click **Generate** to get batched responses, then **Download CSV** to save them.
         """
     )

@@ -191,7 +186,7 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
             run_btn = gr.Button("Generate", variant="primary")

         with gr.Column(scale=1):
-            # Keep last results for stable downloads
+            # Keep last results for downloading
             state_df = gr.State(value=None)

             out_df = gr.Dataframe(
@@ -201,14 +196,14 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
                 wrap=True,
                 interactive=False,
                 row_count=(0, "dynamic"),
-                type="pandas",  # ensure callbacks get a pandas DataFrame
+                type="pandas",  # ensure callbacks receive a pandas DataFrame
             )

-            # Older Gradio versions don't support file_name on DownloadButton
-            download_btn = gr.DownloadButton(
-                label="Download CSV",
-                value=None,  # we update this with bytes on demand
-            )
+            # File widget that will display a real downloadable file
+            out_file = gr.File(label="Download CSV", visible=False)
+
+            # Separate button to trigger file creation
+            csv_btn = gr.Button("Prepare CSV for download")

     # -------- Callbacks --------

@@ -223,7 +218,7 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
             top_k=int(top_k),
             repetition_penalty=float(repetition_penalty),
         )
-        return df, df  # show in table, also store in state
+        return df, df  # (table, state)

     run_btn.click(
         _generate_cb,
@@ -233,14 +228,13 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
     )

     def _prepare_csv_cb(df_state):
-        # Fallback-safe: produce bytes only (older Gradio uses a default filename)
+        # Robust across Gradio versions: write to a real temp file and return its path
         if df_state is None or len(df_state) == 0:
             df_state = pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])
-        csv_bytes = df_to_csv_bytes(df_state)
-        # Some Gradio versions ignore filename updates; return bytes only for compatibility
-        return gr.DownloadButton.update(value=csv_bytes)
+        path = write_csv_tempfile(df_state)
+        return gr.File.update(value=path, visible=True)

-    download_btn.click(_prepare_csv_cb, inputs=[state_df], outputs=[download_btn], api_name="download_csv")
+    csv_btn.click(_prepare_csv_cb, inputs=[state_df], outputs=[out_file], api_name="download_csv")

 if __name__ == "__main__":
     demo.launch()
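
For reference, the core of the new download flow is: write the DataFrame to a named temp file created with delete=False, then hand the resulting path to a gr.File output. A minimal, self-contained sketch of that pattern follows (component and function names here are illustrative, not the Space's own):

import tempfile

import gradio as gr
import pandas as pd


def export_csv(df: pd.DataFrame) -> str:
    # delete=False keeps the file on disk so Gradio can still serve it after the callback returns.
    tmp = tempfile.NamedTemporaryFile(suffix=".csv", delete=False)
    df.to_csv(tmp.name, index=False)
    return tmp.name  # returning a filepath populates the gr.File component


with gr.Blocks() as sketch:
    table = gr.Dataframe(value=pd.DataFrame({"a": [1, 2], "b": [3, 4]}), interactive=False)
    btn = gr.Button("Prepare CSV")
    file_out = gr.File(label="Download CSV")
    btn.click(export_csv, inputs=[table], outputs=[file_out])

if __name__ == "__main__":
    sketch.launch()

Returning just the path, as above, also sidesteps gr.File.update, which newer Gradio releases no longer provide.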