|
|
import time |
|
|
import gradio as gr |
|
|
from transformers import pipeline |
|
|
from huggingface_hub import InferenceClient |
|
|
from typing import List, Dict, Tuple, Any, Optional |
|
|
from diffusers import AutoPipelineForText2Image |
|
|
import torch |
|
|
|
|
|
|
|
|
# --- Input limits -------------------------------------------------------------
# Hard cap on how many characters of an article are processed or quoted.
MAX_CHAR = 8000
# Default number of entity rows returned by run_ner_rows().
NER_NUM_ROWS = 10

# --- Hugging Face model identifiers -------------------------------------------
SUMM_MODEL_ID = "sshleifer/distilbart-cnn-12-6"
SENTIMENT_MODEL_ID = "ahmedrachid/FinancialBERT-Sentiment-Analysis"
FINCLS_MODEL_ID = "nickmuchi/distilroberta-finetuned-financial-text-classification"
NER_MODEL_ID = "dslim/bert-base-NER"
CHAT_MODEL_ID = "openai/gpt-oss-20b"
IMAGE_MODEL_ID = "stabilityai/sd-turbo"

# --- Lazily created model handles ---------------------------------------------
# Each stays None until the matching _get_* helper builds it on first use.
_summ_pipe = _sentiment_pipe = _fincls_pipe = _ner_pipe = None
_img_pipe_cpu = None

# --- Local image generation settings ------------------------------------------
# Passed to the diffusers pipeline call in generate_image().
IMG_STEPS = 2
IMG_GUIDANCE = 0.5
# Initial values of the width/height sliders in the Image tab.
IMG_WIDTH = IMG_HEIGHT = 512

# --- Chat settings --------------------------------------------------------------
# Sampling parameters forwarded to InferenceClient.chat_completion().
CHAT_MAX_TOKENS = 512
CHAT_TEMPERATURE = 0.7
CHAT_TOP_P = 0.95
# First system message of every chat request.
CHAT_SYSTEM_PROMPT = (
    "\nYou are assisting with analysis of a financial news article."
    "\nBe clear, cite facts from context, and avoid investment advice."
    "\nUse the provided ARTICLE as your primary context."
    "\nIf the user asks about something not in context, say what you do/don't know."
)

# Value passed as `device=` to every transformers pipeline built in this file.
DEVICE_CPU = -1
|
|
|
|
|
|
|
|
def _get_summ_pipe():
    """Return the shared summarization pipeline, building it on first call.

    The pipeline is stored in the module-level ``_summ_pipe`` so the model
    is only loaded once per process.
    """
    global _summ_pipe
    if _summ_pipe is not None:
        return _summ_pipe
    _summ_pipe = pipeline("summarization", model=SUMM_MODEL_ID, device=DEVICE_CPU)
    return _summ_pipe
|
|
|
|
|
def _get_sentiment_pipe():
    """Return the shared sentiment classifier, building it on first call.

    The pipeline is cached in the module-level ``_sentiment_pipe``.
    """
    global _sentiment_pipe
    if _sentiment_pipe is not None:
        return _sentiment_pipe
    options = {
        "model": SENTIMENT_MODEL_ID,
        "truncation": True,
        "device": DEVICE_CPU,
    }
    _sentiment_pipe = pipeline("text-classification", **options)
    return _sentiment_pipe
|
|
|
|
|
def _get_fincls_pipe():
    """Return the shared financial-stance classifier, building it on first call.

    The pipeline is cached in the module-level ``_fincls_pipe``. It is
    configured to return the score of every label rather than only the best
    one, so callers can pick the top label themselves.
    """
    global _fincls_pipe
    if _fincls_pipe is None:
        _fincls_pipe = pipeline(
            "text-classification",
            model=FINCLS_MODEL_ID,
            truncation=True,
            # `top_k=None` is the supported replacement for the deprecated
            # `return_all_scores=True`: every label's score is returned.
            top_k=None,
            device=DEVICE_CPU,
        )
    return _fincls_pipe
|
|
|
|
|
def _get_ner_pipe():
    """Return the shared NER pipeline, building it on first call.

    The pipeline is cached in the module-level ``_ner_pipe`` and is created
    with ``aggregation_strategy="simple"``.
    """
    global _ner_pipe
    if _ner_pipe is not None:
        return _ner_pipe
    _ner_pipe = pipeline(
        "token-classification",
        model=NER_MODEL_ID,
        aggregation_strategy="simple",
        device=DEVICE_CPU,
    )
    return _ner_pipe
|
|
|
|
|
|
|
|
|
|
|
def _hf_token_str(hf_token):
    """Extract a token string from the various shapes a token may arrive in.

    Accepted inputs:
      * ``None``                      -> ``None``
      * a string                      -> the string, or ``None`` when empty
      * an object with ``.token``     -> that attribute's value
      * a dict                        -> ``hf_token.get("token")``
      * anything else                 -> ``None``
    """
    # Strings are checked first so they are never probed for a `.token` attribute.
    if isinstance(hf_token, str):
        return hf_token if hf_token else None

    if hf_token is not None and hasattr(hf_token, "token"):
        return hf_token.token

    if isinstance(hf_token, dict):
        return hf_token.get("token")

    return None
|
|
|
|
|
def _get_img_pipe_cpu():
    """Return the shared CPU text-to-image pipeline, building it on first call.

    The pipeline is loaded in float32, moved to the CPU and cached in the
    module-level ``_img_pipe_cpu``.
    """
    global _img_pipe_cpu
    if _img_pipe_cpu is not None:
        return _img_pipe_cpu

    sd_pipe = AutoPipelineForText2Image.from_pretrained(
        IMAGE_MODEL_ID,
        torch_dtype=torch.float32,
        use_safetensors=True,
    )
    sd_pipe.to("cpu")

    # Optional switches: a pipeline lacking one of these methods (or failing
    # inside it) is still usable, so any failure is ignored.
    for method_name in ("enable_attention_slicing", "enable_vae_slicing"):
        try:
            enable = getattr(sd_pipe, method_name)
            enable()
        except Exception:
            pass

    _img_pipe_cpu = sd_pipe
    return _img_pipe_cpu
|
|
|
|
|
def _try_cloud_text2image(prompt: str, hf_token: Optional[gr.OAuthToken]):
    """Try to render ``prompt`` through the hosted inference API.

    Args:
        prompt: Text prompt for the image model.
        hf_token: Object exposing the user's token as ``.token``, or ``None``.

    Returns:
        Whatever ``InferenceClient.text_to_image`` returns on success, or
        ``None`` when no token is available or the request fails. The caller
        is expected to fall back to local generation on ``None``.
    """
    tok = getattr(hf_token, "token", None) if hf_token else None
    if not tok:
        return None
    try:
        client = InferenceClient(token=tok)
        return client.text_to_image(prompt, model=IMAGE_MODEL_ID)
    except Exception as e:
        # Still best-effort, but log the reason instead of discarding it so a
        # failing cloud path can be diagnosed.
        print("Cloud image error:", e)
        return None
|
|
|
|
|
|
|
|
def _normalize_text(text: str, max_len: int = MAX_CHAR) -> str:
    """Strip surrounding whitespace from ``text`` and cap it at ``max_len`` chars.

    A falsy ``text`` (``None`` or ``""``) yields the empty string.
    """
    if not text:
        return ""
    stripped = text.strip()
    return stripped[:max_len]
|
|
|
|
|
def run_summary(text: str) -> str:
    """Summarize ``text`` with the shared summarization pipeline.

    Args:
        text: Article text; only the first 3000 normalized characters are
            sent to the model.

    Returns:
        The stripped summary, or ``""`` when the input is empty or any error
        occurs (the error is printed, never raised).
    """
    try:
        txt = _normalize_text(text, MAX_CHAR)
        if not txt:
            return ""
        sp = _get_summ_pipe()
        out = sp(
            txt[:3000],
            max_length=160,
            min_length=48,
            do_sample=False,
            # The character cap above does not bound the token count; let the
            # tokenizer truncate so dense text cannot overflow the model input.
            truncation=True,
        )
        return out[0]["summary_text"].strip() if out else ""
    except Exception as e:
        print("Summary error:", e)
        return ""
|
|
|
|
|
def run_text_nlp(text: str) -> Tuple[str, float, str, float]:
    """Run sentiment and financial-stance classification on ``text``.

    Returns:
        ``(sentiment_label, sentiment_score, stance_label, stance_score)``.
        Empty input gives ``("", 0.0, "", 0.0)``; any exception is printed and
        reported as ``("Error", 0.0, "Error", 0.0)``.
    """
    try:
        cleaned = _normalize_text(text)
        if not cleaned:
            return "", 0.0, "", 0.0

        sentiment_model = _get_sentiment_pipe()
        stance_model = _get_fincls_pipe()

        sentiment = sentiment_model(cleaned)[0]
        # First element of the stance output is treated as the per-label
        # score list; the highest-scoring entry wins.
        label_scores = stance_model(cleaned)[0]
        if label_scores:
            best = max(label_scores, key=lambda item: item["score"])
        else:
            best = {"label": "", "score": 0.0}

        sentiment_label = sentiment.get("label", "")
        sentiment_score = float(sentiment.get("score", 0.0))
        stance_label = best.get("label", "")
        stance_score = float(best.get("score", 0.0))
        return sentiment_label, sentiment_score, stance_label, stance_score
    except Exception as err:
        print("Text NLP error:", err)
        return "Error", 0.0, "Error", 0.0
|
|
|
|
|
def run_ner_rows(text: str, limit: int = NER_NUM_ROWS) -> List[List[str]]:
    """Extract named entities from ``text`` as table rows.

    Returns:
        Up to ``limit`` rows of ``[entity_group, word, score]`` where the
        score is formatted with two decimals. Empty input gives ``[]``; on
        failure a single ``["Error", <message>, "0.00"]`` row is returned.
    """
    try:
        cleaned = _normalize_text(text, MAX_CHAR)
        if not cleaned:
            return []

        tagger = _get_ner_pipe()
        table = []
        for entity in tagger(cleaned):
            group = entity.get("entity_group", "")
            word = entity.get("word", "")
            score = f"{float(entity.get('score', 0.0)):.2f}"
            table.append([group, word, score])
        return table[:limit]
    except Exception as err:
        print("NER error:", err)
        return [["Error", str(err), "0.00"]]
|
|
|
|
|
|
|
|
def build_context_block(article: str, analysis: Dict[str, Any]) -> str:
    """Assemble the context text given to the chat model.

    Args:
        article: Raw article text; only the first ``MAX_CHAR`` characters are
            included. Falsy values are skipped.
        analysis: Analysis results keyed by ``sentiment``, ``sentiment_score``,
            ``category``, ``category_score``, ``summary`` and ``entities``
            (rows whose second element is the entity text). Falsy values are
            skipped; individual keys may be missing.

    Returns:
        The sections joined by blank lines, or ``""`` when there is nothing
        to report.
    """
    parts = []
    if article:
        parts.append(f"ARTICLE (truncated):\n{article[:MAX_CHAR]}")

    if analysis:
        # A missing or None score must not break the `:.2f` formatting.
        sentiment_score = float(analysis.get("sentiment_score") or 0.0)
        category_score = float(analysis.get("category_score") or 0.0)
        parts.append(
            "ANALYSIS SUMMARY:\n"
            f"- Sentiment: {analysis.get('sentiment')} ({sentiment_score:.2f})\n"
            f"- Financial stance: {analysis.get('category')} ({category_score:.2f})"
        )

        if analysis.get("summary"):
            parts.append(f"- Auto Summary: {analysis['summary']}")

        ents = analysis.get("entities") or []
        # dict.fromkeys removes duplicates while keeping first-seen order, so
        # the prompt is identical across runs (a set has no stable order).
        names = dict.fromkeys(
            str(row[1]) for row in ents[:40] if len(row) > 1 and row[1]
        )
        if names:
            parts.append(f"- Top entities: {', '.join(names)}")

    return "\n\n".join(parts)
|
|
|
|
|
def _warn_if_no_token(hf_token: Optional[gr.OAuthToken]) -> str:
    """Return a login reminder when no usable token is present, else ``""``.

    A token counts as usable only when ``hf_token`` is truthy and exposes a
    truthy ``.token`` attribute.
    """
    token_value = getattr(hf_token, "token", None) if hf_token else None
    if token_value:
        return ""
    return "\nYou are not logged in to Hugging Face. Click **Login** (left sidebar) for better reliability.\n\n"
|
|
|
|
|
def respond_chat(
    message: str,
    history: List[Dict[str, str]],
    article_text: str,
    analysis: Dict[str, Any],
    hf_token: Optional[gr.OAuthToken],
    _profile,
):
    """Stream a chat reply grounded in the analysed article.

    Args:
        message: The user's new message.
        history: Earlier turns as role/content dicts (may be empty or None).
        article_text: Article text stored by the analysis step.
        analysis: Analysis results stored by the analysis step.
        hf_token: The user's OAuth token, or ``None`` when not logged in.
            The annotation is Optional on purpose: with a non-optional
            ``gr.OAuthToken`` Gradio rejects anonymous users before this
            function runs, which made the login notice below unreachable.
        _profile: Unused; absorbs the extra additional input.

    Yields:
        The accumulated response text after each streamed chunk. When the
        request fails, the error text is appended and yielded instead of
        raising.
    """
    tok = _hf_token_str(hf_token)
    login_notice = _warn_if_no_token(hf_token)

    client = InferenceClient(token=tok, model=CHAT_MODEL_ID)

    context_block = build_context_block(article_text or "", analysis or {})

    messages = [{"role": "system", "content": CHAT_SYSTEM_PROMPT}]
    if context_block:
        # Skip the context message entirely when nothing has been analysed.
        messages.append({"role": "system", "content": context_block})
    for turn in history or []:
        if isinstance(turn, dict) and "role" in turn and "content" in turn:
            # Forward only the fields the chat API defines.
            # NOTE(review): Gradio history entries may carry extra keys
            # (e.g. metadata) - confirm against the Gradio version in use.
            messages.append({"role": turn["role"], "content": turn["content"]})
        else:
            messages.append(turn)
    messages.append({"role": "user", "content": message})

    response = login_notice
    try:
        stream = client.chat_completion(
            messages,
            max_tokens=CHAT_MAX_TOKENS,
            stream=True,
            temperature=CHAT_TEMPERATURE,
            top_p=CHAT_TOP_P,
        )
        for chunk in stream:
            choices = getattr(chunk, "choices", []) or []
            piece = ""
            if choices and getattr(choices[0], "delta", None) and choices[0].delta.content:
                piece = choices[0].delta.content
            response += piece
            yield response
    except Exception as e:
        response += (
            f"\nChat request failed for model `{CHAT_MODEL_ID}`.\n"
            f"Error: {e}\n"
        )
        yield response
|
|
|
|
|
|
|
|
def generate_image(prompt, width, height, hf_token: Optional[gr.OAuthToken], *args):
    """Generate an image for ``prompt``: hosted inference first, local CPU second.

    Args:
        prompt: Text prompt; blank prompts are rejected.
        width: Output width for local generation (converted with ``int``).
        height: Output height for local generation (converted with ``int``).
        hf_token: The user's OAuth token, or ``None``. The annotation is what
            makes Gradio supply the real token here; without it this slot
            received the login button's own value, so the cloud path could
            never find a ``.token`` and was always skipped.
        *args: Absorbs any extra component values wired to this handler.

    Returns:
        ``(image, status)`` where ``image`` is ``None`` on failure and
        ``status`` is a timing string or an error message.
    """
    import traceback

    t0 = time.time()
    prompt = (prompt or "").strip()
    if not prompt:
        return None, "Provide a prompt."

    # 1) Hosted inference (only attempted when a token is available).
    try:
        img = _try_cloud_text2image(prompt, hf_token)
        if img is not None:
            return img, f"{time.time()-t0:.2f}s"
    except Exception as e:
        print("Cloud image error:", e)
        traceback.print_exc()

    # 2) Local CPU fallback.
    try:
        pipe = _get_img_pipe_cpu()
        width, height = int(width), int(height)
        out = pipe(
            prompt=prompt,
            num_inference_steps=IMG_STEPS,
            guidance_scale=IMG_GUIDANCE,
            width=width,
            height=height,
        )
        img = out.images[0]
        return img, f"{time.time()-t0:.2f}s | steps={IMG_STEPS}, g={IMG_GUIDANCE}"
    except Exception as e:
        print("CPU image error:", e)
        traceback.print_exc()
        return None, f"Generation failed: {e}"
|
|
|
|
|
|
|
|
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("**ARIN 460 Final — Financial News Multi-Model**")

    # State shared across tabs: the analysed article and its analysis dict.
    article_state = gr.State("")
    analysis_state = gr.State({})

    with gr.Sidebar():
        login_btn = gr.LoginButton()
        gr.Markdown("**Workflow**\n1) Input\n2) Analysis (Assignment 4)\n3) Chat\n4) Image")

    with gr.Tabs():
        with gr.Tab("Input"):
            txt_in = gr.Textbox(lines=12, label="Article text")
            analyze_btn = gr.Button("Analyze", variant="primary")
            run_status = gr.Markdown()

        with gr.Tab("Text Analysis"):
            summary_box = gr.Textbox(label="Summary", lines=4, interactive=False)
            sent_lbl = gr.Textbox(label="Sentiment", interactive=False)
            sent_score = gr.Number(label="Sentiment score", precision=3, interactive=False)
            fin_lbl = gr.Textbox(label="Financial Category", interactive=False)
            fin_score = gr.Number(label="Category score", precision=3, interactive=False)
            ta_status = gr.Markdown()

        with gr.Tab("NER"):
            ner_out = gr.Dataframe(
                headers=["entity", "text", "score"],
                datatype=["str", "str", "str"],
                interactive=False,
            )
            ner_status = gr.Markdown()

        with gr.Tab("Chat"):
            # The stored article/analysis and the login button are passed to
            # respond_chat as additional inputs.
            chat = gr.ChatInterface(
                respond_chat,
                type="messages",
                additional_inputs=[article_state, analysis_state, login_btn],
            )
            chat.chatbot.height = 400

        with gr.Tab("Image"):
            img_prompt = gr.Textbox(label="Prompt", lines=3)
            width_slider = gr.Slider(256, 768, value=IMG_WIDTH, step=64, label="Width")
            height_slider = gr.Slider(256, 768, value=IMG_HEIGHT, step=64, label="Height")
            gen_btn = gr.Button("Generate Image", variant="primary")
            image_out = gr.Image(label="Result", type="pil")
            gen_status = gr.Markdown()

            gen_btn.click(
                generate_image,
                inputs=[img_prompt, width_slider, height_slider, login_btn],
                outputs=[image_out, gen_status],
            )

    def _analyze_all(text):
        """Run summary, classification and NER on ``text``.

        Returns one value per output component wired below, in order:
        run status, summary, sentiment label/score, category label/score,
        text-analysis status, NER rows, NER status, article state and
        analysis state.
        """
        started = time.time()
        summ = run_summary(text)
        s_lbl, s_score, c_lbl, c_score = run_text_nlp(text)
        ner_rows = run_ner_rows(text)
        elapsed = time.time() - started

        analysis = dict(
            summary=summ,
            sentiment=s_lbl,
            sentiment_score=s_score,
            category=c_lbl,
            category_score=c_score,
            entities=ner_rows,
        )

        run_msg = f"Processed in **{elapsed:.2f}s**."
        ta_msg = f"Updated at {time.strftime('%H:%M:%S')}"
        ner_msg = f"Extracted {len(ner_rows)} entities."
        return (
            run_msg,
            summ, s_lbl, s_score, c_lbl, c_score, ta_msg,
            ner_rows, ner_msg,
            text, analysis,
        )

    # Order must match the tuple returned by _analyze_all.
    analysis_outputs = [
        run_status, summary_box, sent_lbl, sent_score, fin_lbl, fin_score,
        ta_status, ner_out, ner_status, article_state, analysis_state,
    ]

    # Disable the button while the analysis runs, then restore it.
    click_event = analyze_btn.click(
        lambda: gr.update(value="Analyzing...", interactive=False), [], [analyze_btn]
    )
    analysis_event = click_event.then(
        _analyze_all, inputs=[txt_in], outputs=analysis_outputs
    )
    analysis_event.then(
        lambda: gr.update(value="Analyze", interactive=True), [], [analyze_btn]
    )


if __name__ == "__main__":
    demo.launch()