neovalle committed
Commit ca66ea6 · verified · 1 Parent(s): 865d725

Update app.py

Files changed (1): app.py (+12, -40)
app.py CHANGED

```diff
@@ -1,5 +1,4 @@
 # app.py
-import io
 import tempfile
 from datetime import datetime
 
@@ -12,7 +11,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # Config
 # ----------------------------
 
-# Small, free, instruction-tuned models that can run on CPU (Basic Space).
 DEFAULT_MODELS = [
     "google/gemma-2-2b-it",
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
@@ -27,13 +25,12 @@ _MODEL_CACHE = {}  # cache: model_id -> (tokenizer, model)
 # ----------------------------
 
 def _load_model(model_id: str):
-    """Load tokenizer and model (cached)."""
     if model_id in _MODEL_CACHE:
         return _MODEL_CACHE[model_id]
 
     tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
 
-    # Ensure we have a pad token to avoid generate() warnings/errors.
+    # Ensure pad token exists for generate()
     if tok.pad_token is None:
         if tok.eos_token is not None:
             tok.pad_token = tok.eos_token
```
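The comment rewrite above keeps the intent: many checkpoints ship a tokenizer without a pad token, and batched `generate()` needs one for padding. A minimal sketch of the same fallback chain, using `gpt2` purely as an illustration; the hunk cuts off before any `add_special_tokens` branch, so that last resort is an assumption here:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # gpt2 defines eos but no pad token
if tok.pad_token is None:
    if tok.eos_token is not None:
        tok.pad_token = tok.eos_token  # cheapest fix: reuse EOS for padding
    else:
        # Assumed last-resort branch: register a brand-new pad token.
        tok.add_special_tokens({"pad_token": "<|pad|>"})

print(tok.pad_token)  # -> <|endoftext|>
```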
```diff
@@ -47,8 +44,6 @@ def _load_model(model_id: str):
         low_cpu_mem_usage=True,
         device_map="auto",
     )
-
-    # If we added tokens, resize embeddings.
     if model.get_input_embeddings().num_embeddings != len(tok):
         model.resize_token_embeddings(len(tok))
 
```
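This guard matters only when the tokenizer actually grew, i.e. when a new pad token was registered instead of reusing EOS. A hedged sketch of that path; `sshleifer/tiny-gpt2` is a stand-in test checkpoint, not one of the app's models:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # assumption: tiny stand-in for a quick demo
tok = AutoTokenizer.from_pretrained(model_id)
tok.add_special_tokens({"pad_token": "<|pad|>"})  # vocab grows by one entry

model = AutoModelForCausalLM.from_pretrained(model_id)
if model.get_input_embeddings().num_embeddings != len(tok):
    model.resize_token_embeddings(len(tok))  # keep embedding rows == vocab size
```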
```diff
@@ -57,7 +52,6 @@
 
 
 def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
-    """Prefer each model's chat template; fallback to a simple instruction format."""
     sys = (system_prompt or "").strip()
     usr = (user_prompt or "").strip()
 
```
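The deleted docstring described the contract: prefer the model's own chat template, fall back to a plain tagged format. The template call is visible in the next hunk; a self-contained sketch of the whole pattern, with the message-building step (not shown in this diff) filled in as an assumption:

```python
from transformers import AutoTokenizer

def format_prompt(tok, system_prompt: str, user_prompt: str) -> str:
    msgs = []
    if system_prompt:
        msgs.append({"role": "system", "content": system_prompt})
    msgs.append({"role": "user", "content": user_prompt})
    if getattr(tok, "chat_template", None):  # model ships its own template
        return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # Fallback plain format, mirroring the tags used in app.py
    prefix = f"<<SYS>>\n{system_prompt}\n<</SYS>>\n\n" if system_prompt else ""
    return f"{prefix}<<USER>>\n{user_prompt}\n<</USER>>\n<<ASSISTANT>>\n"

tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(format_prompt(tok, "Be terse.", "Why is the sky blue?"))
```

Note that some templates (Gemma's, for one) reject a `system` role, so the real helper may need to fold the system prompt into the user turn.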
```diff
@@ -72,12 +66,11 @@ def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
             add_generation_prompt=True,
         )
 
-    # Fallback plain format
     prefix = f"<<SYS>>\n{sys}\n<</SYS>>\n\n" if sys else ""
     return f"{prefix}<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n"
 
 
-def generate_batch(
+def generate_batch_df(
     model_id: str,
     system_prompt: str,
     prompts_multiline: str,
@@ -87,16 +80,13 @@
     top_k: int,
     repetition_penalty: float,
 ) -> pd.DataFrame:
-    """Generate responses for multiple user prompts (one per line)."""
     tok, model = _load_model(model_id)
     device = model.device
 
-    # Split lines, drop empties
     prompts = [p.strip() for p in prompts_multiline.splitlines() if p.strip()]
     if not prompts:
         return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])
 
-    # Build formatted prompts and encode
    formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]
     enc = tok(
         formatted,
@@ -105,7 +95,6 @@
         truncation=True,
     ).to(device)
 
-    # True prompt lengths per row (ignore padding)
     prompt_lens = enc["attention_mask"].sum(dim=1)
 
     with torch.no_grad():
@@ -121,7 +110,6 @@
             pad_token_id=tok.pad_token_id,
         )
 
-    # Slice generated tokens using prompt lengths
     responses, tokens_out = [], []
     for i in range(gen.size(0)):
         start = int(prompt_lens[i].item())
```
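The two deletions above were pure bookkeeping comments; the logic they described still holds. `attention_mask.sum(dim=1)` recovers each row's true prompt length despite padding, and `gen[i, start:]` drops the prompt before decoding. A toy illustration with fabricated tensors (no model involved, token ids are made up):

```python
import torch

# Two right-padded prompts; 0 in the mask marks padding.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 1, 1]])
prompt_lens = attention_mask.sum(dim=1)  # tensor([3, 4])

# generate() returns prompt (plus any padding) followed by new tokens per row.
gen = torch.tensor([[11, 12, 13, 0, 21, 22],    # 3 prompt ids, 1 pad, 2 new
                    [14, 15, 16, 17, 23, 24]])  # 4 prompt ids, 2 new
for i in range(gen.size(0)):
    start = int(prompt_lens[i].item())
    print(gen[i, start:].tolist())  # row 0 still carries the pad id; decoding
                                    # with skip_special_tokens=True drops it
```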
```diff
@@ -135,9 +123,7 @@
     )
 
 
-def write_csv_tempfile(df: pd.DataFrame) -> str:
-    """Write CSV to a real temp file and return its path (works in Spaces)."""
-    # Use NamedTemporaryFile with delete=False so Gradio can read after returning.
+def write_csv_path(df: pd.DataFrame) -> str:
     ts = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
     tmp = tempfile.NamedTemporaryFile(prefix=f"batch_{ts}_", suffix=".csv", delete=False, dir="/tmp")
     df.to_csv(tmp.name, index=False)
```
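The rename from `write_csv_tempfile` to `write_csv_path` does not change the job: persist the DataFrame to a real file that outlives the request, because Gradio serves downloads by path after the callback returns. A standalone sketch; the timestamp prefix mirrors the app, and `/tmp` being writable is an assumption that holds on Spaces:

```python
import tempfile
from datetime import datetime

import pandas as pd

df = pd.DataFrame([{"user_prompt": "hi", "response": "hello", "tokens_out": 2}])
ts = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
# delete=False keeps the file on disk after the handle closes, so it can be
# read later by path instead of vanishing with the object.
tmp = tempfile.NamedTemporaryFile(prefix=f"batch_{ts}_", suffix=".csv",
                                  delete=False, dir="/tmp")
df.to_csv(tmp.name, index=False)
tmp.close()
print(tmp.name)  # e.g. /tmp/batch_20250101-120000_xxxxxxxx.csv
```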
```diff
@@ -153,7 +139,7 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
         """
         # 🧪 Multi-Prompt Chat for HF Space
         Pick a small free model, set a **system prompt**, and enter **multiple user prompts** (one per line).
-        Click **Generate** to get batched responses, then **Download CSV** to save them.
+        Click **Generate** to get batched responses and a **downloadable CSV**.
         """
     )
 
@@ -186,9 +172,6 @@
             run_btn = gr.Button("Generate", variant="primary")
 
         with gr.Column(scale=1):
-            # Keep last results for downloading
-            state_df = gr.State(value=None)
-
            out_df = gr.Dataframe(
                 headers=["user_prompt", "response", "tokens_out"],
                 datatype=["str", "str", "number"],
@@ -196,19 +179,16 @@
                 wrap=True,
                 interactive=False,
                 row_count=(0, "dynamic"),
-                type="pandas",  # ensure callbacks receive a pandas DataFrame
+                type="pandas",  # ensures pandas goes into callbacks
             )
 
-            # File widget that will display a real downloadable file
-            out_file = gr.File(label="Download CSV", visible=False)
-
-            # Separate button to trigger file creation
-            csv_btn = gr.Button("Prepare CSV for download")
+            # IMPORTANT: type="filepath" so we can return a string path
+            csv_out = gr.File(label="Scored CSV", interactive=False, type="filepath")
 
-    # -------- Callbacks --------
+    # -------- Callback: generate table AND CSV path in one go --------
 
     def _generate_cb(model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
-        df = generate_batch(
+        df = generate_batch_df(
             model_id=model_id,
             system_prompt=system_prompt,
             prompts_multiline=prompts_multiline,
@@ -218,23 +198,15 @@
             top_k=int(top_k),
             repetition_penalty=float(repetition_penalty),
         )
-        return df, df  # (table, state)
+        csv_path = write_csv_path(df)
+        return df, csv_path  # DataFrame to table, path to File(type="filepath")
 
     run_btn.click(
         _generate_cb,
         inputs=[model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=[out_df, state_df],
+        outputs=[out_df, csv_out],
         api_name="generate_batch",
     )
 
-    def _prepare_csv_cb(df_state):
-        # Robust across Gradio versions: write to a real temp file and return its path
-        if df_state is None or len(df_state) == 0:
-            df_state = pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])
-        path = write_csv_tempfile(df_state)
-        return gr.File.update(value=path, visible=True)
-
-    csv_btn.click(_prepare_csv_cb, inputs=[state_df], outputs=[out_file], api_name="download_csv")
-
 if __name__ == "__main__":
     demo.launch()
```
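Net effect of the commit: the `gr.State` plus "Prepare CSV" round-trip is gone, and a single click handler feeds both outputs, a DataFrame for the table and a plain path for `gr.File(type="filepath")`. A stripped-down sketch of that wiring; component names and the CSV contents are placeholders:

```python
import gradio as gr
import pandas as pd

def run():
    df = pd.DataFrame([{"a": 1, "b": 2}])
    path = "/tmp/demo.csv"  # placeholder path
    df.to_csv(path, index=False)
    return df, path  # DataFrame -> Dataframe component, str path -> File

with gr.Blocks() as demo:
    table = gr.Dataframe(type="pandas", interactive=False)
    csv_out = gr.File(label="CSV", type="filepath", interactive=False)
    gr.Button("Run").click(run, outputs=[table, csv_out])

if __name__ == "__main__":
    demo.launch()
```

Returning the path directly also removes the version-fragile `gr.File.update(...)` call that the old two-step flow relied on.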