import time
import traceback
import gradio as gr
from transformers import pipeline
from huggingface_hub import InferenceClient
from typing import List, Dict, Tuple, Any, Optional
from diffusers import AutoPipelineForText2Image
import torch
# Article Analysis Constants
MAX_CHAR = 8000
NER_NUM_ROWS = 10
# Model Constants
SUMM_MODEL_ID = "sshleifer/distilbart-cnn-12-6"
SENTIMENT_MODEL_ID = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
FINCLS_MODEL_ID = "nickmuchi/distilroberta-finetuned-financial-text-classification"
NER_MODEL_ID = "dslim/bert-base-NER"
CHAT_MODEL_ID = "openai/gpt-oss-20b"
IMAGE_MODEL_ID = "stabilityai/sd-turbo"
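# Lazily-initialized model handles (populated on first use to keep startup fast)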
_summ_pipe = None
_sentiment_pipe = None
_fincls_pipe = None
_ner_pipe = None
_img_pipe_cpu = None
# Image Constants
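# sd-turbo is distilled for very low step counts and little/no classifier-free guidance,
# so the small IMG_STEPS and IMG_GUIDANCE values below are intentional.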
IMG_STEPS = 2
IMG_GUIDANCE = 0.5
IMG_WIDTH = 512
IMG_HEIGHT = 512
# Chat Constants
CHAT_MAX_TOKENS = 512
CHAT_TEMPERATURE = 0.7
CHAT_TOP_P = 0.95
CHAT_SYSTEM_PROMPT = ("\nYou are assisting with analysis of a financial news article."
+ "\nBe clear, cite facts from context, and avoid investment advice."
+ "\nUse the provided ARTICLE as your primary context."
+ "\nIf the user asks about something not in context, say what you do/don't know."
)
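# transformers pipelines interpret device=-1 as CPU (>= 0 selects a CUDA device)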
DEVICE_CPU = -1
# Assignment 4 Pipelines
def _get_summ_pipe():
global _summ_pipe
if _summ_pipe is None:
_summ_pipe = pipeline(
"summarization",
model=SUMM_MODEL_ID,
device=DEVICE_CPU,
)
return _summ_pipe
def _get_sentiment_pipe():
global _sentiment_pipe
if _sentiment_pipe is None:
_sentiment_pipe = pipeline(
"text-classification",
model=SENTIMENT_MODEL_ID,
truncation=True,
device=DEVICE_CPU,
)
return _sentiment_pipe
def _get_fincls_pipe():
global _fincls_pipe
if _fincls_pipe is None:
_fincls_pipe = pipeline(
"text-classification",
model=FINCLS_MODEL_ID,
truncation=True,
            return_all_scores=True,  # deprecated upstream in favor of top_k=None (note: top_k=None also un-nests the output)
device=DEVICE_CPU,
)
return _fincls_pipe
def _get_ner_pipe():
global _ner_pipe
if _ner_pipe is None:
_ner_pipe = pipeline(
"token-classification",
model=NER_MODEL_ID,
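            # "simple" merges word-piece tokens into whole entity spans (e.g. one ORG per name)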
aggregation_strategy="simple",
device=DEVICE_CPU,
)
return _ner_pipe
# Image Generation
def _hf_token_str(hf_token):
    """Return a plain string token from a LoginButton/OAuth value, or None."""
if hf_token is None:
return None
if isinstance(hf_token, str):
return hf_token or None
# gr.OAuthToken-like object
if hasattr(hf_token, "token"):
return hf_token.token
# dict {"token": "..."}
if isinstance(hf_token, dict):
return hf_token.get("token")
return None
def _get_img_pipe_cpu():
global _img_pipe_cpu
if _img_pipe_cpu is None:
pipe = AutoPipelineForText2Image.from_pretrained(
IMAGE_MODEL_ID,
torch_dtype=torch.float32,
use_safetensors=True,
)
pipe.to("cpu")
for fn in ("enable_attention_slicing", "enable_vae_slicing"):
try:
getattr(pipe, fn)()
except Exception:
pass
_img_pipe_cpu = pipe
return _img_pipe_cpu
def _try_cloud_text2image(prompt: str, hf_token: Optional[gr.OAuthToken]):
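    """Attempt serverless text-to-image via the HF Inference API; return None on any failure."""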
    tok = _hf_token_str(hf_token)
if not tok:
return None
try:
client = InferenceClient(token=tok)
return client.text_to_image(prompt, model=IMAGE_MODEL_ID)
except Exception:
return None
# Analysis helpers
def _normalize_text(text: str, max_len: int = MAX_CHAR) -> str:
return (text or "").strip()[:max_len]
def run_summary(text: str) -> str:
try:
txt = _normalize_text(text, MAX_CHAR)
if not txt:
return ""
sp = _get_summ_pipe()
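        # ~3000 chars keeps the input roughly within DistilBART's 1024-token window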
out = sp(txt[:3000], max_length=160, min_length=48, do_sample=False)
return out[0]["summary_text"].strip() if out else ""
except Exception as e:
print("Summary error:", e)
return ""
def run_text_nlp(text: str) -> Tuple[str, float, str, float]:
try:
txt = _normalize_text(text)
if not txt:
return "", 0.0, "", 0.0
sp = _get_sentiment_pipe()
fp = _get_fincls_pipe()
s_pred = sp(txt)[0]
dist = fp(txt)[0]
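        # fp returns the full label distribution (one dict per class); take the argmax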
top = max(dist, key=lambda d: d["score"]) if dist else {"label": "", "score": 0.0}
return (
s_pred.get("label", ""),
float(s_pred.get("score", 0.0)),
top.get("label", ""),
float(top.get("score", 0.0)),
)
except Exception as e:
print("Text NLP error:", e)
return "Error", 0.0, "Error", 0.0
def run_ner_rows(text: str, limit: int = NER_NUM_ROWS) -> List[List[str]]:
try:
txt = _normalize_text(text, MAX_CHAR)
if not txt:
return []
ner = _get_ner_pipe()
ents = ner(txt)
rows = [
[e.get("entity_group", ""), e.get("word", ""), f"{float(e.get('score', 0.0)):.2f}"]
for e in ents
]
return rows[:limit]
except Exception as e:
print("NER error:", e)
return [["Error", str(e), "0.00"]]
# Chat helpers
def build_context_block(article: str, analysis: Dict[str, Any]) -> str:
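    """Assemble a system-message context block from the article text and prior analysis results."""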
parts = []
if article:
parts.append(f"ARTICLE (truncated):\n{article[:MAX_CHAR]}")
if analysis:
parts.append(
"ANALYSIS SUMMARY:\n"
f"- Sentiment: {analysis.get('sentiment')} ({analysis.get('sentiment_score'):.2f})\n"
f"- Financial stance: {analysis.get('category')} ({analysis.get('category_score'):.2f})"
)
if analysis.get("summary"):
parts.append(f"- Auto Summary: {analysis['summary']}")
ents = analysis.get("entities", [])
if ents:
ent_str = ", ".join({r[1] for r in ents[:40]})
parts.append(f"- Top entities: {ent_str}")
return "\n\n".join(parts)
def _warn_if_no_token(hf_token: Optional[gr.OAuthToken]) -> str:
    if not _hf_token_str(hf_token):
return "\nYou are not logged in to Hugging Face. Click **Login** (left sidebar) for better reliability.\n\n"
return ""
def respond_chat(
message: str,
history: List[Dict[str, str]],
article_text: str,
analysis: Dict[str, Any],
hf_token: gr.OAuthToken,
_profile,
):
tok = _hf_token_str(hf_token)
login_notice = _warn_if_no_token(hf_token)
client = InferenceClient(
token=tok,
model=CHAT_MODEL_ID
)
context_block = build_context_block(article_text or "", analysis or {})
    messages = [
        {"role": "system", "content": CHAT_SYSTEM_PROMPT},
{"role": "system", "content": context_block},
*history,
{"role": "user", "content": message},
]
response = login_notice
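    # Stream token deltas back to the UI; each yield re-renders the partial reply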
try:
stream = client.chat_completion(
messages,
max_tokens=CHAT_MAX_TOKENS,
stream=True,
temperature=CHAT_TEMPERATURE,
top_p=CHAT_TOP_P,
)
for chunk in stream:
choices = getattr(chunk, "choices", [])
piece = ""
if choices and getattr(choices[0], "delta", None) and choices[0].delta.content:
piece = choices[0].delta.content
response += piece
yield response
except Exception as e:
response += (
f"\nChat request failed for model `{CHAT_MODEL_ID}`.\n"
f"Error: {e}\n"
)
yield response
# Image helpers
def generate_image(prompt, width, height, hf_token, *args):
t0 = time.time()
prompt = (prompt or "").strip()
if not prompt:
return None, "Provide a prompt."
# 1) Cloud first (shared GPU)
try:
img = _try_cloud_text2image(prompt, hf_token)
if img is not None:
return img, f"{time.time()-t0:.2f}s"
except Exception as e:
print("Cloud image error:", e)
traceback.print_exc()
# 2) CPU fallback
try:
pipe = _get_img_pipe_cpu()
width, height = int(width), int(height)
out = pipe(
prompt=prompt,
num_inference_steps=IMG_STEPS,
guidance_scale=IMG_GUIDANCE,
width=width,
height=height,
)
img = out.images[0]
return img, f"{time.time()-t0:.2f}s | steps={IMG_STEPS}, g={IMG_GUIDANCE}"
except Exception as e:
print("CPU image error:", e)
traceback.print_exc()
return None, f"Generation failed: {e}"
# Gradio UI
with gr.Blocks(fill_height=True) as demo:
gr.Markdown("**ARIN 460 Final — Financial News Multi-Model**")
article_state = gr.State("")
analysis_state = gr.State({})
with gr.Sidebar():
login_btn = gr.LoginButton()
gr.Markdown("**Workflow**\n1) Input\n2) Analysis (Assignment 4)\n3) Chat\n4) Image")
with gr.Tabs():
with gr.Tab("Input"):
txt_in = gr.Textbox(lines=12, label="Article text")
analyze_btn = gr.Button("Analyze", variant="primary")
run_status = gr.Markdown()
with gr.Tab("Text Analysis"):
summary_box = gr.Textbox(label="Summary", lines=4, interactive=False)
sent_lbl = gr.Textbox(label="Sentiment", interactive=False)
sent_score = gr.Number(label="Sentiment score", precision=3, interactive=False)
fin_lbl = gr.Textbox(label="Financial Category", interactive=False)
fin_score = gr.Number(label="Category score", precision=3, interactive=False)
ta_status = gr.Markdown()
with gr.Tab("NER"):
ner_out = gr.Dataframe(headers=["entity", "text", "score"],
datatype=["str", "str", "str"], interactive=False)
ner_status = gr.Markdown()
with gr.Tab("Chat"):
chat = gr.ChatInterface(
respond_chat,
type="messages",
additional_inputs=[
article_state, analysis_state, login_btn
],
)
chat.chatbot.height = 400
with gr.Tab("Image"):
img_prompt = gr.Textbox(label="Prompt", lines=3)
width_slider = gr.Slider(256, 768, value=IMG_WIDTH, step=64, label="Width")
height_slider = gr.Slider(256, 768, value=IMG_HEIGHT, step=64, label="Height")
gen_btn = gr.Button("Generate Image", variant="primary")
image_out = gr.Image(label="Result", type="pil")
gen_status = gr.Markdown()
gen_btn.click(
generate_image,
inputs=[img_prompt, width_slider, height_slider, login_btn],
outputs=[image_out, gen_status]
)
def _analyze_all(text):
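        """Run summarization, sentiment/classification, and NER, and fan results out to the UI."""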
t0 = time.time()
summ = run_summary(text)
s_lbl, s_score, c_lbl, c_score = run_text_nlp(text)
ner_rows = run_ner_rows(text)
dt = time.time() - t0
analysis = {
"summary": summ,
"sentiment": s_lbl,
"sentiment_score": s_score,
"category": c_lbl,
"category_score": c_score,
"entities": ner_rows,
}
return (
f"Processed in **{dt:.2f}s**.",
summ, s_lbl, s_score, c_lbl, c_score, f"Updated at {time.strftime('%H:%M:%S')}",
ner_rows, f"Extracted {len(ner_rows)} entities.",
text, analysis
)
# Analyze button
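    # Chain: disable the button, run the (slow) analysis, then re-enable it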
analyze_btn.click(lambda: gr.update(value="Analyzing...", interactive=False), [], [analyze_btn]) \
.then(_analyze_all, inputs=[txt_in],
outputs=[run_status, summary_box, sent_lbl, sent_score, fin_lbl, fin_score,
ta_status, ner_out, ner_status, article_state, analysis_state]) \
.then(lambda: gr.update(value="Analyze", interactive=True), [], [analyze_btn])
if __name__ == "__main__":
demo.launch()