Commit 97acba9 (verified) · Parent(s): 484a8d1
Committed by dkundel-openai

streaming (#22)

- feat: streaming for analysis and answer (7579caaf8a844b420fe864d14f30b04efd1e63bf)
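
Editor's note: the diff below drops the blocking transformers pipeline() call in favour of token-by-token streaming. As a minimal sketch of the pattern generate_stream now relies on (using "gpt2" purely as a small placeholder model id, not the one the Space loads), transformers' TextIteratorStreamer is fed by model.generate() running on a background thread while the caller iterates over decoded text chunks:

# Sketch only: "gpt2" is a placeholder; the Space loads MODEL_ID instead.
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Classify this example:", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until completion, so it runs in a worker thread;
# the streamer then yields decoded text pieces as tokens are produced.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=32, streamer=streamer),
)
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)  # the Space yields to Gradio here instead of printing
thread.join()

In the Space, the loop yields partial (analysis, answer, meta) tuples rather than printing, and Gradio streams successive yields from a generator handler into the output components.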

Files changed (1)
  1. app.py +86 -77
app.py CHANGED
@@ -1,10 +1,15 @@
+import spaces
+
 import os
+import re
 import time
 from typing import List, Dict, Tuple
+import threading
 
+import torch
 import gradio as gr
-from transformers import pipeline
-import spaces
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
 
 # === Config (override via Space secrets/env vars) ===
 MODEL_ID = os.environ.get("MODEL_ID", "openai/gpt-oss-safeguard-20b")
@@ -14,6 +19,8 @@ DEFAULT_TOP_P = float(os.environ.get("TOP_P", 1.0))
 DEFAULT_REPETITION_PENALTY = float(os.environ.get("REPETITION_PENALTY", 1.0))
 ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", 120)) # seconds
 
+ANALYSIS_PATTERN = analysis_match = re.compile(r'^(.*)assistantfinal', flags=re.DOTALL)
+
 SAMPLE_POLICY = """
 Spam Policy (#SP)
 GOAL: Identify spam. Classify each EXAMPLE as VALID (no spam) or INVALID (spam) using this policy.
@@ -123,13 +130,38 @@ If financial harm or fraud → classify SP4.
 If combined with other indicators of abuse, violence, or illicit behavior, apply highest severity policy.
 """
 
-_pipe = None # cached pipeline
+_tokenizer = None
+_model = None
+_device = None
 
 
+def _ensure_loaded():
+    print("Loading model and tokenizer")
+    global _tokenizer, _model, _device
+    if _tokenizer is not None and _model is not None:
+        return
+    _tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_ID, trust_remote_code=True
+    )
+    _model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        trust_remote_code=True,
+        # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        low_cpu_mem_usage=True,
+        device_map="auto" if torch.cuda.is_available() else None,
+    )
+    if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
+        _tokenizer.pad_token = _tokenizer.eos_token
+    _model.eval()
+    _device = next(_model.parameters()).device
+
+_ensure_loaded()
+
 # ----------------------------
 # Helpers (simple & explicit)
 # ----------------------------
 
+
 def _to_messages(policy: str, user_prompt: str) -> List[Dict[str, str]]:
     msgs: List[Dict[str, str]] = []
     if policy.strip():
@@ -138,94 +170,71 @@ def _to_messages(policy: str, user_prompt: str) -> List[Dict[str, str]]:
     return msgs
 
 
-def _extract_assistant_content(outputs) -> str:
-    """Extract the assistant's content from the known shape:
-    outputs = [
-      {
-        'generated_text': [
-           {'role': 'system', 'content': ...},
-           {'role': 'user', 'content': ...},
-           {'role': 'assistant', 'content': 'analysis...assistantfinal...'}
-        ]
-      }
-    ]
-    Keep this forgiving and minimal.
-    """
-    try:
-        msgs = outputs[0]["generated_text"]
-        for m in reversed(msgs):
-            if isinstance(m, dict) and m.get("role") == "assistant":
-                return m.get("content", "")
-        last = msgs[-1]
-        return last.get("content", "") if isinstance(last, dict) else str(last)
-    except Exception:
-        return str(outputs)
-
-
-def _parse_harmony_output_from_string(s: str) -> Tuple[str, str]:
-    """Split a Harmony-style concatenated string into (analysis, final).
-    Expects markers 'analysis' ... 'assistantfinal'.
-    No heavy parsing — just string finds.
-    """
-    if not isinstance(s, str):
-        s = str(s)
-    final_key = "assistantfinal"
-    j = s.find(final_key)
-    if j != -1:
-        final_text = s[j + len(final_key):].strip()
-        i = s.find("analysis")
-        if i != -1 and i < j:
-            analysis_text = s[i + len("analysis"): j].strip()
-        else:
-            analysis_text = s[:j].strip()
-        return analysis_text, final_text
-    # no explicit final marker
-    if s.startswith("analysis"):
-        return s[len("analysis"):].strip(), ""
-    return "", s.strip()
-
-
 # ----------------------------
 # Inference
 # ----------------------------
 
 @spaces.GPU(duration=ZGPU_DURATION)
-def generate_long_prompt(
-    policy: str,
-    prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    repetition_penalty: float,
+def generate_stream(
+    policy: str,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    repetition_penalty: float,
 ) -> Tuple[str, str, str]:
-    global _pipe
-    start = time.time()
 
-    if _pipe is None:
-        _pipe = pipeline(
-            task="text-generation",
-            model=MODEL_ID,
-            torch_dtype="auto",
-            device_map="auto",
-        )
+    start = time.time()
 
     messages = _to_messages(policy, prompt)
 
-    outputs = _pipe(
+    streamer = TextIteratorStreamer(
+        _tokenizer,
+        skip_special_tokens=True,
+        skip_prompt=True,  # <-- key fix
+    )
+
+    inputs = _tokenizer.apply_chat_template(
         messages,
+        return_tensors="pt",
+        add_generation_prompt=True,
+    )
+    input_ids = inputs["input_ids"] if isinstance(inputs, dict) else inputs
+    input_ids = input_ids.to(_device)
+
+    gen_kwargs = dict(
+        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
+        do_sample=temperature > 0.0,
+        temperature=float(temperature),
        top_p=top_p,
-        repetition_penalty=repetition_penalty,
+        pad_token_id=_tokenizer.pad_token_id,
+        eos_token_id=_tokenizer.eos_token_id,
+        streamer=streamer,
    )
 
-    assistant_str = _extract_assistant_content(outputs)
-    analysis_text, final_text = _parse_harmony_output_from_string(assistant_str)
-
-    elapsed = time.time() - start
-    meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
-    return analysis_text or "(No analysis)", final_text or "(No answer)", meta
+    thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
+    thread.start()
+
+    analysis = ""
+    output = ""
+    for new_text in streamer:
+        output += new_text
+        if not analysis:
+            m = ANALYSIS_PATTERN.match(output)
+            if m:
+                analysis = re.sub(r'^analysis\s*', '', m.group(1))
+                output = ""
+
+        if not analysis:
+            analysis_text = re.sub(r'^analysis\s*', '', output)
+            final_text = None
+        else:
+            analysis_text = analysis
+            final_text = output
+        elapsed = time.time() - start
+        meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
+        yield analysis_text or "(No analysis)", final_text or "(No answer)", meta
 
 
 # ----------------------------
@@ -269,7 +278,7 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
     meta = gr.Markdown()
 
     btn.click(
-        fn=generate_long_prompt,
+        fn=generate_stream,
        inputs=[policy, prompt, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=[analysis, answer, meta],
        concurrency_limit=1,
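
Editor's note: the ANALYSIS_PATTERN regex introduced above splits the assistant content, which arrives as one Harmony-style string with "analysis" ... "assistantfinal" markers (the shape the removed _parse_harmony_output_from_string docstring described). A small self-contained check of that split, with an invented sample string; in the real streaming loop the buffer is reset at the marker and later chunks accumulate as the final answer:

import re

ANALYSIS_PATTERN = re.compile(r'^(.*)assistantfinal', flags=re.DOTALL)

# Invented example of the accumulated stream once the "assistantfinal" marker has arrived.
output = "analysisRepeated promo links for the same product, matches SP2.assistantfinalINVALID"

m = ANALYSIS_PATTERN.match(output)
assert m is not None
analysis = re.sub(r'^analysis\s*', '', m.group(1))  # reasoning text without the leading "analysis" marker
final = output[m.end():]                            # text after the marker (the streaming loop resets its buffer here)
print(analysis)  # -> Repeated promo links for the same product, matches SP2.
print(final)     # -> INVALID

Because generate_stream is now a generator, Gradio streams each yielded (analysis, answer, meta) triple into the three outputs, so the analysis pane fills in while generation is still running and the answer appears once the marker has been seen.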