russelllarkin commited on
Commit
51742f6
·
verified ·
1 Parent(s): 10a7a66

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +357 -136
app.py CHANGED
@@ -1,154 +1,375 @@
 
1
  import gradio as gr
2
- import numpy as np
3
- import random
4
-
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
 
 
 
 
 
 
 
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
 
 
 
 
 
 
 
 
 
 
 
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
-
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
70
-
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  )
 
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
 
 
 
 
 
98
  )
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
-
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
- )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
1
+ import time
2
  import gradio as gr
3
+ from transformers import pipeline
4
+ from huggingface_hub import InferenceClient
5
+ from typing import List, Dict, Tuple, Any, Optional
6
+ from diffusers import AutoPipelineForText2Image
 
7
  import torch
8
 
9
+ # Article Analysis Constants
10
+ MAX_CHAR = 8000
11
+ NER_NUM_ROWS = 10
12
+
13
+ # Model Constants
14
+ SUMM_MODEL_ID = "sshleifer/distilbart-cnn-12-6"
15
+ SENTIMENT_MODEL_ID = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
16
+ FINCLS_MODEL_ID = "nickmuchi/distilroberta-finetuned-financial-text-classification"
17
+ NER_MODEL_ID = "dslim/bert-base-NER"
18
+ CHAT_MODEL_ID = "openai/gpt-oss-20b"
19
+ IMAGE_MODEL_ID = "stabilityai/sd-turbo"
20
+
21
+ _summ_pipe = None
22
+ _sentiment_pipe = None
23
+ _fincls_pipe = None
24
+ _ner_pipe = None
25
+ _img_pipe_cpu = None
26
+
27
+ # Image Constants
28
+ IMG_STEPS = 2
29
+ IMG_GUIDANCE = 0.5
30
+ IMG_WIDTH = 512
31
+ IMG_HEIGHT = 512
32
+
33
+ # Chat Constants
34
+ CHAT_MAX_TOKENS = 512
35
+ CHAT_TEMPERATURE = 0.7
36
+ CHAT_TOP_P = 0.95
37
+ CHAT_SYSTEM_PROMPT = ("\nYou are assisting with analysis of a financial news article."
38
+ + "\nBe clear, cite facts from context, and avoid investment advice."
39
+ + "\nUse the provided ARTICLE as your primary context."
40
+ + "\nIf the user asks about something not in context, say what you do/don't know."
41
+ )
42
+
43
+ DEVICE_CPU = -1
44
+
45
+ # Assignment 4 Pipelines
46
+ def _get_summ_pipe():
47
+ global _summ_pipe
48
+ if _summ_pipe is None:
49
+ _summ_pipe = pipeline(
50
+ "summarization",
51
+ model=SUMM_MODEL_ID,
52
+ device=DEVICE_CPU,
53
+ )
54
+ return _summ_pipe
55
+
56
+ def _get_sentiment_pipe():
57
+ global _sentiment_pipe
58
+ if _sentiment_pipe is None:
59
+ _sentiment_pipe = pipeline(
60
+ "text-classification",
61
+ model=SENTIMENT_MODEL_ID,
62
+ truncation=True,
63
+ device=DEVICE_CPU,
64
+ )
65
+ return _sentiment_pipe
66
 
67
+ def _get_fincls_pipe():
68
+ global _fincls_pipe
69
+ if _fincls_pipe is None:
70
+ _fincls_pipe = pipeline(
71
+ "text-classification",
72
+ model=FINCLS_MODEL_ID,
73
+ truncation=True,
74
+ return_all_scores=True,
75
+ device=DEVICE_CPU,
76
+ )
77
+ return _fincls_pipe
78
 
79
+ def _get_ner_pipe():
80
+ global _ner_pipe
81
+ if _ner_pipe is None:
82
+ _ner_pipe = pipeline(
83
+ "token-classification",
84
+ model=NER_MODEL_ID,
85
+ aggregation_strategy="simple",
86
+ device=DEVICE_CPU,
87
+ )
88
+ return _ner_pipe
89
 
90
+ # Image Generation
91
+ # Return a plain string token from LoginButton value.
92
+ def _hf_token_str(hf_token):
93
+ if hf_token is None:
94
+ return None
95
+ if isinstance(hf_token, str):
96
+ return hf_token or None
97
+ # gr.OAuthToken-like object
98
+ if hasattr(hf_token, "token"):
99
+ return hf_token.token
100
+ # dict {"token": "..."}
101
+ if isinstance(hf_token, dict):
102
+ return hf_token.get("token")
103
+ return None
104
 
105
+ def _get_img_pipe_cpu():
106
+ global _img_pipe_cpu
107
+ if _img_pipe_cpu is None:
108
+ pipe = AutoPipelineForText2Image.from_pretrained(
109
+ IMAGE_MODEL_ID,
110
+ torch_dtype=torch.float32,
111
+ use_safetensors=True,
112
+ )
113
+ pipe.to("cpu")
114
+ for fn in ("enable_attention_slicing", "enable_vae_slicing"):
115
+ try:
116
+ getattr(pipe, fn)()
117
+ except Exception:
118
+ pass
119
+ _img_pipe_cpu = pipe
120
+ return _img_pipe_cpu
121
 
122
+ def _try_cloud_text2image(prompt: str, hf_token: Optional[gr.OAuthToken]):
123
+ tok = getattr(hf_token, "token", None) if hf_token else None
124
+ if not tok:
125
+ return None
126
+ try:
127
+ client = InferenceClient(token=tok)
128
+ return client.text_to_image(prompt, model=IMAGE_MODEL_ID)
129
+ except Exception:
130
+ return None
131
+
132
+ # Analysis helpers
133
+ def _normalize_text(text: str, max_len: int = MAX_CHAR) -> str:
134
+ return (text or "").strip()[:max_len]
135
+
136
+ def run_summary(text: str) -> str:
137
+ try:
138
+ txt = _normalize_text(text, MAX_CHAR)
139
+ if not txt:
140
+ return ""
141
+ sp = _get_summ_pipe()
142
+ out = sp(txt[:3000], max_length=160, min_length=48, do_sample=False)
143
+ return out[0]["summary_text"].strip() if out else ""
144
+ except Exception as e:
145
+ print("Summary error:", e)
146
+ return ""
147
+
148
+ def run_text_nlp(text: str) -> Tuple[str, float, str, float]:
149
+ try:
150
+ txt = _normalize_text(text)
151
+ if not txt:
152
+ return "", 0.0, "", 0.0
153
+ sp = _get_sentiment_pipe()
154
+ fp = _get_fincls_pipe()
155
+ s_pred = sp(txt)[0]
156
+ dist = fp(txt)[0]
157
+ top = max(dist, key=lambda d: d["score"]) if dist else {"label": "", "score": 0.0}
158
+ return (
159
+ s_pred.get("label", ""),
160
+ float(s_pred.get("score", 0.0)),
161
+ top.get("label", ""),
162
+ float(top.get("score", 0.0)),
163
+ )
164
+ except Exception as e:
165
+ print("Text NLP error:", e)
166
+ return "Error", 0.0, "Error", 0.0
167
+
168
+ def run_ner_rows(text: str, limit: int = NER_NUM_ROWS) -> List[List[str]]:
169
+ try:
170
+ txt = _normalize_text(text, MAX_CHAR)
171
+ if not txt:
172
+ return []
173
+ ner = _get_ner_pipe()
174
+ ents = ner(txt)
175
+ rows = [
176
+ [e.get("entity_group", ""), e.get("word", ""), f"{float(e.get('score', 0.0)):.2f}"]
177
+ for e in ents
178
+ ]
179
+ return rows[:limit]
180
+ except Exception as e:
181
+ print("NER error:", e)
182
+ return [["Error", str(e), "0.00"]]
183
+
184
+ # Chat helpers
185
+ def build_context_block(article: str, analysis: Dict[str, Any]) -> str:
186
+ parts = []
187
+ if article:
188
+ parts.append(f"ARTICLE (truncated):\n{article[:MAX_CHAR]}")
189
+ if analysis:
190
+ parts.append(
191
+ "ANALYSIS SUMMARY:\n"
192
+ f"- Sentiment: {analysis.get('sentiment')} ({analysis.get('sentiment_score'):.2f})\n"
193
+ f"- Financial stance: {analysis.get('category')} ({analysis.get('category_score'):.2f})"
194
+ )
195
+ if analysis.get("summary"):
196
+ parts.append(f"- Auto Summary: {analysis['summary']}")
197
+ ents = analysis.get("entities", [])
198
+ if ents:
199
+ ent_str = ", ".join({r[1] for r in ents[:40]})
200
+ parts.append(f"- Top entities: {ent_str}")
201
+ return "\n\n".join(parts)
202
+
203
+ def _warn_if_no_token(hf_token: Optional[gr.OAuthToken]) -> str:
204
+ if not hf_token or not getattr(hf_token, "token", None):
205
+ return "\nYou are not logged in to Hugging Face. Click **Login** (left sidebar) for better reliability.\n\n"
206
+ return ""
207
+
208
+ def respond_chat(
209
+ message: str,
210
+ history: List[Dict[str, str]],
211
+ article_text: str,
212
+ analysis: Dict[str, Any],
213
+ hf_token: gr.OAuthToken,
214
+ _profile,
215
  ):
216
+ tok = _hf_token_str(hf_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+ login_notice = _warn_if_no_token(hf_token)
219
 
220
+ client = InferenceClient(
221
+ token=tok,
222
+ model=CHAT_MODEL_ID
223
+ )
224
+
225
+ context_block = build_context_block(article_text or "", analysis or {})
226
+ sys = (CHAT_SYSTEM_PROMPT)
227
+
228
+ messages = [
229
+ {"role": "system", "content": sys},
230
+ {"role": "system", "content": context_block},
231
+ *history,
232
+ {"role": "user", "content": message},
233
+ ]
234
+
235
+ response = login_notice
236
+ try:
237
+ stream = client.chat_completion(
238
+ messages,
239
+ max_tokens=CHAT_MAX_TOKENS,
240
+ stream=True,
241
+ temperature=CHAT_TEMPERATURE,
242
+ top_p=CHAT_TOP_P,
243
+ )
244
+ for chunk in stream:
245
+ choices = getattr(chunk, "choices", [])
246
+ piece = ""
247
+ if choices and getattr(choices[0], "delta", None) and choices[0].delta.content:
248
+ piece = choices[0].delta.content
249
+ response += piece
250
+ yield response
251
+ except Exception as e:
252
+ response += (
253
+ f"\nChat request failed for model `{CHAT_MODEL_ID}`.\n"
254
+ f"Error: {e}\n"
255
+ )
256
+ yield response
257
+
258
+ # Image helpers
259
+ def generate_image(prompt, width, height, hf_token, *args):
260
+ import traceback
261
+ t0 = time.time()
262
+ prompt = (prompt or "").strip()
263
+ if not prompt:
264
+ return None, "Provide a prompt."
265
+
266
+ # 1) Cloud first (shared GPU)
267
+ try:
268
+ img = _try_cloud_text2image(prompt, hf_token)
269
+ if img is not None:
270
+ return img, f"{time.time()-t0:.2f}s"
271
+ except Exception as e:
272
+ print("Cloud image error:", e)
273
+ traceback.print_exc()
274
+
275
+ # 2) CPU fallback
276
+ try:
277
+ pipe = _get_img_pipe_cpu()
278
+ width, height = int(width), int(height)
279
+ out = pipe(
280
+ prompt=prompt,
281
+ num_inference_steps=IMG_STEPS,
282
+ guidance_scale=IMG_GUIDANCE,
283
+ width=width,
284
+ height=height,
285
+ )
286
+ img = out.images[0]
287
+ return img, f"{time.time()-t0:.2f}s | steps={IMG_STEPS}, g={IMG_GUIDANCE}"
288
+ except Exception as e:
289
+ print("CPU image error:", e)
290
+ traceback.print_exc()
291
+ return None, f"Generation failed: {e}"
292
+
293
+ # Gradio UI
294
+ with gr.Blocks(fill_height=True) as demo:
295
+ gr.Markdown("**ARIN 460 Final — Financial News Multi-Model**")
296
 
297
+ article_state = gr.State("")
298
+ analysis_state = gr.State({})
299
+
300
+ with gr.Sidebar():
301
+ login_btn = gr.LoginButton()
302
+ gr.Markdown("**Workflow**\n1) Input\n2) Analysis (Assignment 4)\n3) Chat\n4) Image")
303
+
304
+ with gr.Tabs():
305
+ with gr.Tab("Input"):
306
+ txt_in = gr.Textbox(lines=12, label="Article text")
307
+ analyze_btn = gr.Button("Analyze", variant="primary")
308
+ run_status = gr.Markdown()
309
+
310
+ with gr.Tab("Text Analysis"):
311
+ summary_box = gr.Textbox(label="Summary", lines=4, interactive=False)
312
+ sent_lbl = gr.Textbox(label="Sentiment", interactive=False)
313
+ sent_score = gr.Number(label="Sentiment score", precision=3, interactive=False)
314
+ fin_lbl = gr.Textbox(label="Financial Category", interactive=False)
315
+ fin_score = gr.Number(label="Category score", precision=3, interactive=False)
316
+ ta_status = gr.Markdown()
317
+
318
+ with gr.Tab("NER"):
319
+ ner_out = gr.Dataframe(headers=["entity", "text", "score"],
320
+ datatype=["str", "str", "str"], interactive=False)
321
+ ner_status = gr.Markdown()
322
+
323
+ with gr.Tab("Chat"):
324
+ chat = gr.ChatInterface(
325
+ respond_chat,
326
+ type="messages",
327
+ additional_inputs=[
328
+ article_state, analysis_state, login_btn
329
+ ],
330
  )
331
+ chat.chatbot.height = 400
332
 
333
+ with gr.Tab("Image"):
334
+ img_prompt = gr.Textbox(label="Prompt", lines=3)
335
+ width_slider = gr.Slider(256, 768, value=IMG_WIDTH, step=64, label="Width")
336
+ height_slider = gr.Slider(256, 768, value=IMG_HEIGHT, step=64, label="Height")
337
+ gen_btn = gr.Button("Generate Image", variant="primary")
338
+ image_out = gr.Image(label="Result", type="pil")
339
+ gen_status = gr.Markdown()
340
+ gen_btn.click(
341
+ generate_image,
342
+ inputs=[img_prompt, width_slider, height_slider, login_btn],
343
+ outputs=[image_out, gen_status]
344
  )
345
 
346
+ def _analyze_all(text):
347
+ t0 = time.time()
348
+ summ = run_summary(text)
349
+ s_lbl, s_score, c_lbl, c_score = run_text_nlp(text)
350
+ ner_rows = run_ner_rows(text)
351
+ dt = time.time() - t0
352
+ analysis = {
353
+ "summary": summ,
354
+ "sentiment": s_lbl,
355
+ "sentiment_score": s_score,
356
+ "category": c_lbl,
357
+ "category_score": c_score,
358
+ "entities": ner_rows,
359
+ }
360
+ return (
361
+ f"Processed in **{dt:.2f}s**.",
362
+ summ, s_lbl, s_score, c_lbl, c_score, f"Updated at {time.strftime('%H:%M:%S')}",
363
+ ner_rows, f"Extracted {len(ner_rows)} entities.",
364
+ text, analysis
365
+ )
366
+
367
+ # Analyze button
368
+ analyze_btn.click(lambda: gr.update(value="Analyzing...", interactive=False), [], [analyze_btn]) \
369
+ .then(_analyze_all, inputs=[txt_in],
370
+ outputs=[run_status, summary_box, sent_lbl, sent_score, fin_lbl, fin_score,
371
+ ta_status, ner_out, ner_status, article_state, analysis_state]) \
372
+ .then(lambda: gr.update(value="Analyze", interactive=True), [], [analyze_btn])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  if __name__ == "__main__":
375
+ demo.launch()