| """ | |
| app.py | |
| Document Analysis Gradio app — updated to support PDF, DOCX (Word), and PPTX (PowerPoint). | |
| - robust file reading | |
| - streaming PDF extraction | |
| - docx/pptx extraction to pages_texts (one element per paragraph/slide) | |
| - token-aware truncation, chunked summarization, sampled Q&A | |
| - multi-file upload UI (processes first supported file) | |
| """ | |

import os
# Disable the noisy HF symlink warning on Windows.
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")

import re
import io
import math
import tempfile
import threading
from pathlib import Path
from typing import List, Tuple, Optional

import gradio as gr
import pdfplumber
import nltk
from nltk.tokenize import sent_tokenize
from transformers import pipeline, AutoTokenizer
import torch
import pandas as pd
from tqdm.auto import tqdm

# Try to import docx and pptx; if not present, handle gracefully at runtime.
try:
    from docx import Document as DocxDocument  # python-docx
except Exception:
    DocxDocument = None

try:
    from pptx import Presentation as PptxPresentation  # python-pptx
except Exception:
    PptxPresentation = None

# -------------------------
# NLTK: ensure punkt is available; a regex fallback is used later if it is not
# -------------------------
try:
    nltk.download("punkt", quiet=True)
    try:
        nltk.download("punkt_tab", quiet=True)
    except Exception:
        pass
except Exception:
    pass

# -------------------------
# Device detection
# -------------------------
DEVICE = 0 if torch.cuda.is_available() else -1
print("Device set to use", "cuda" if DEVICE >= 0 else "cpu")

# -------------------------
# Cached pipelines and tokenizers
# -------------------------
_models = {}
_tokenizers = {}
_models_lock = threading.Lock()


def get_pipeline(name: str, task: str):
    """Return a cached HF pipeline for the given task and model name."""
    key = f"{task}__{name}"
    with _models_lock:
        if key in _models:
            return _models[key]
        print(f"Loading pipeline: task={task}, model={name} ... (this may take a while on first run)")
        p = pipeline(task, model=name, device=DEVICE)
        _models[key] = p
        try:
            _tokenizers[name] = p.tokenizer
        except Exception:
            pass
        return p


def get_tokenizer(name: str):
    """Return a cached tokenizer (falls back to AutoTokenizer if not already cached)."""
    if name in _tokenizers:
        return _tokenizers[name]
    try:
        tok = AutoTokenizer.from_pretrained(name)
        _tokenizers[name] = tok
        return tok
    except Exception:
        return None
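
# Example (sketch): the first call downloads/loads the model; later calls return the
# cached object, so repeated "Analyze" clicks do not reload the weights.
#   summarizer = get_pipeline(SUMMARIZER_MODEL, "summarization")
#   summarizer_again = get_pipeline(SUMMARIZER_MODEL, "summarization")  # same cached pipeline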

# -------------------------
# Default models (small/fast defaults; swap in larger checkpoints if you prefer quality over speed)
# -------------------------
SUMMARIZER_MODEL = "t5-small"                          # already fast
QG_MODEL = "valhalla/t5-small-qg-hl"                   # the "small" question-generation variant
QA_MODEL = "distilbert-base-cased-distilled-squad"     # much faster than RoBERTa-based QA

# -------------------------
# Helpers: filenames / types
# -------------------------
SUPPORTED_EXTS = [".pdf", ".docx", ".pptx", ".txt", ".md", ".rtf", ".png", ".jpg", ".jpeg", ".tiff"]


def ext_of_name(name: str) -> str:
    return Path(name).suffix.lower()


def read_uploaded_file_to_bytes(file_obj):
    """
    Accept the many shapes of Gradio file objects and return raw bytes.
    Supports list/tuple (returns the first readable file), dict-like objects,
    file-like objects, filesystem paths, and raw bytes.
    """
    # If a list/tuple of uploaded files, try each candidate.
    if isinstance(file_obj, (list, tuple)):
        last_err = None
        for elem in file_obj:
            try:
                return read_uploaded_file_to_bytes(elem)
            except Exception as e:
                last_err = e
                continue
        raise ValueError(f"No readable file in list. Last error: {last_err}")
    if file_obj is None:
        raise ValueError("No file provided")
    # Already bytes.
    if isinstance(file_obj, (bytes, bytearray)):
        return bytes(file_obj)
    # Dict-like payloads.
    if isinstance(file_obj, dict):
        for key in ("file", "tmp_path", "name", "data", "path"):
            val = file_obj.get(key)
            if isinstance(val, (bytes, bytearray)):
                return bytes(val)
            if isinstance(val, str) and Path(val).exists():
                return Path(val).read_bytes()
        maybe = file_obj.get("file")
        if hasattr(maybe, "read"):
            data = maybe.read()
            if isinstance(data, str):
                return data.encode("utf-8")
            return data
    # String path.
    if isinstance(file_obj, str) and Path(file_obj).exists():
        return Path(file_obj).read_bytes()
    # Object with a .name attribute that points to a file on disk.
    if hasattr(file_obj, "name") and isinstance(getattr(file_obj, "name"), str) and Path(file_obj.name).exists():
        try:
            return Path(file_obj.name).read_bytes()
        except Exception:
            pass
    # File-like object with .read().
    if hasattr(file_obj, "read"):
        try:
            data = file_obj.read()
            if isinstance(data, str):
                return data.encode("utf-8")
            return data
        except Exception:
            pass
    # Last resort: the string representation might be a path.
    try:
        s = str(file_obj)
        if Path(s).exists():
            return Path(s).read_bytes()
    except Exception:
        pass
    raise ValueError(f"Unsupported uploaded file object type: {type(file_obj)}")


def get_uploaded_filenames(file_obj) -> List[str]:
    """
    Return a list of human-friendly filenames from the uploaded file object(s).
    """
    names = []
    if file_obj is None:
        return names
    if isinstance(file_obj, (list, tuple)):
        for elem in file_obj:
            names.extend(get_uploaded_filenames(elem))
        return names
    if isinstance(file_obj, dict):
        for key in ("name", "filename", "file", "tmp_path", "path"):
            if key in file_obj:
                val = file_obj.get(key)
                if isinstance(val, str):
                    names.append(Path(val).name)
                elif hasattr(val, "name"):
                    names.append(Path(val.name).name)
        maybe = file_obj.get("file")
        if maybe is not None:
            if hasattr(maybe, "name"):
                names.append(Path(maybe.name).name)
        return names
    if isinstance(file_obj, str):
        return [Path(file_obj).name]
    if hasattr(file_obj, "name"):
        return [Path(getattr(file_obj, "name")).name]
    return [str(file_obj)]


def find_first_supported_file(files) -> Tuple[Optional[object], Optional[str]]:
    """
    From the uploaded list (or a single file-like object), find the first file with a
    supported extension. If no extension matches, fall back to the first uploaded file.
    Returns (file_obj, filename), or (None, None) when nothing was uploaded.
    """
    if not files:
        return None, None
    candidates = []
    if isinstance(files, (list, tuple)):
        for f in files:
            names = get_uploaded_filenames(f)
            for n in names:
                candidates.append((f, n))
    else:
        names = get_uploaded_filenames(files)
        for n in names:
            candidates.append((files, n))
    for fobj, name in candidates:
        ext = ext_of_name(name)
        if ext in SUPPORTED_EXTS:
            return fobj, name
    # Fallback: if nothing matched, return the first uploaded file.
    if candidates:
        return candidates[0]
    return None, None
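
# Example (sketch; paths are hypothetical). With gr.File(file_count="multiple"), Gradio
# typically passes a list of temp-file paths or file-like objects; both shapes are handled:
#   fobj, fname = find_first_supported_file(["/tmp/slides.pptx", "/tmp/notes.txt"])
#   raw = read_uploaded_file_to_bytes(fobj)   # -> bytes of slides.pptx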

# -------------------------
# PDF extraction (streaming)
# -------------------------
def extract_text_from_pdf_streaming(file_bytes: bytes, do_ocr: bool = False, extracted_txt_path: str = None):
    """Write the PDF bytes to a temp file and extract each page's text; returns (extracted_txt_path, pages_texts)."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_pdf:
        tmp_pdf.write(file_bytes)
        tmp_pdf_path = tmp_pdf.name
    if extracted_txt_path is None:
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        extracted_txt_path = tmp_txt.name
        tmp_txt.close()
    pages_texts = []
    try:
        with pdfplumber.open(tmp_pdf_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if (not page_text or page_text.strip() == "") and do_ocr:
                    try:
                        from pdf2image import convert_from_path
                        import pytesseract
                        images = convert_from_path(tmp_pdf_path, first_page=page.page_number, last_page=page.page_number)
                        if images:
                            page_text = pytesseract.image_to_string(images[0])
                    except Exception:
                        page_text = ""
                if page_text is None:
                    page_text = ""
                with open(extracted_txt_path, "a", encoding="utf-8") as fout:
                    fout.write(page_text)
                    fout.write("\n\n---PAGE_BREAK---\n\n")
                pages_texts.append(page_text)
    finally:
        try:
            os.remove(tmp_pdf_path)
        except Exception:
            pass
    return extracted_txt_path, pages_texts
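
# Example (sketch; "report.pdf" is a hypothetical path):
#   txt_path, pages = extract_text_from_pdf_streaming(Path("report.pdf").read_bytes())
#   print(f"{len(pages)} pages; full text written to {txt_path}")
# Pages are appended to the temp .txt file one at a time, so memory use stays flat
# even for large PDFs.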

# -------------------------
# DOCX extraction
# -------------------------
def extract_text_from_docx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    """
    Returns (tmp_text_path, pages_texts), where pages_texts groups paragraphs into
    roughly 200-word passages so downstream code can treat them like PDF pages.
    """
    if DocxDocument is None:
        raise RuntimeError("python-docx not installed. Install with `pip install python-docx`")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name
    pages_texts = []
    try:
        doc = DocxDocument(tmp_path)
        # Collect non-empty paragraphs.
        paras = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
        # Group paragraphs into passages of ~200 words.
        current = []
        cur_words = 0
        for p in paras:
            w = len(p.split())
            if cur_words + w <= 200:
                current.append(p)
                cur_words += w
            else:
                if current:  # guard: avoid emitting an empty passage when a single paragraph exceeds the budget
                    pages_texts.append(" ".join(current).strip())
                current = [p]
                cur_words = w
        if current:
            pages_texts.append(" ".join(current).strip())
        # Write the full extracted text to a temp file for download/debugging if needed.
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        tmp_txt.close()
        with open(tmp_txt.name, "w", encoding="utf-8") as f:
            f.write("\n\n".join(pages_texts))
        txt_path = tmp_txt.name
    finally:
        try:
            os.remove(tmp_path)
        except Exception:
            pass
    return txt_path, pages_texts

# -------------------------
# PPTX extraction
# -------------------------
def extract_text_from_pptx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    """
    Returns (tmp_text_path, pages_texts), with one element per slide.
    """
    if PptxPresentation is None:
        raise RuntimeError("python-pptx not installed. Install with `pip install python-pptx`")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pptx") as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name
    pages_texts = []
    try:
        prs = PptxPresentation(tmp_path)
        for slide in prs.slides:
            texts = []
            for shape in slide.shapes:
                try:
                    if hasattr(shape, "text") and shape.text:
                        texts.append(shape.text.strip())
                except Exception:
                    continue
            slide_text = "\n".join([t for t in texts if t])
            pages_texts.append(slide_text)
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        tmp_txt.close()
        with open(tmp_txt.name, "w", encoding="utf-8") as f:
            f.write("\n\n".join(pages_texts))
        txt_path = tmp_txt.name
    finally:
        try:
            os.remove(tmp_path)
        except Exception:
            pass
    return txt_path, pages_texts

# -------------------------
# TXT/MD extraction
# -------------------------
def extract_text_from_txt_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    s = file_bytes.decode("utf-8", errors="ignore")
    # Split into passages on blank lines, then regroup into ~200-word chunks.
    paras = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()]
    pages_texts = []
    cur = []
    cur_words = 0
    for p in paras:
        w = len(p.split())
        if cur_words + w <= 200:
            cur.append(p)
            cur_words += w
        else:
            if cur:
                pages_texts.append(" ".join(cur).strip())
            cur = [p]
            cur_words = w
    if cur:
        pages_texts.append(" ".join(cur).strip())
    tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
    tmp_txt.close()
    with open(tmp_txt.name, "w", encoding="utf-8") as f:
        f.write("\n\n".join(pages_texts))
    return tmp_txt.name, pages_texts

# -------------------------
# Token-aware truncation helpers
# -------------------------
def truncate_by_tokens(text: str, tokenizer, reserve: int = 64) -> str:
    if not text:
        return text
    if tokenizer is None:
        return text if len(text) <= 3000 else text[:3000]
    try:
        ids = tokenizer.encode(text, add_special_tokens=False)
        max_len = getattr(tokenizer, "model_max_length", 512)
        # Some tokenizers report a huge sentinel value when no max length is set; cap it.
        if not isinstance(max_len, int) or max_len > 100_000:
            max_len = 512
        allowed = max(1, max_len - reserve)
        if len(ids) > allowed:
            ids = ids[:allowed]
            return tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return text
    except Exception:
        return text if len(text) <= 3000 else text[:3000]
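
# Worked example (sketch): t5-small's tokenizer reports model_max_length == 512, so with
# reserve=64 at most 512 - 64 = 448 content tokens are kept; shorter inputs pass through
# unchanged.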

# -------------------------
# Sentence tokenization fallback
# -------------------------
def safe_sentence_tokenize(text: str) -> List[str]:
    try:
        sents = sent_tokenize(text)
        if isinstance(sents, list) and len(sents) > 0:
            return sents
    except Exception:
        pass
    pieces = re.split(r"(?<=[.!?])\s+", text.strip())
    return [p.strip() for p in pieces if p.strip()]

# -------------------------
# Summarization (chunked, token-truncated)
# -------------------------
def summarize_text_chunked(summarizer, pages_texts: List[str], pages_per_chunk: int = 8) -> str:
    if not pages_texts:
        return "(no text)"
    summaries = []
    tokenizer = getattr(summarizer, "tokenizer", None) or get_tokenizer(SUMMARIZER_MODEL)
    num_pages = len(pages_texts)
    for i in range(0, num_pages, pages_per_chunk):
        chunk_pages = pages_texts[i : i + pages_per_chunk]
        chunk_text = "\n\n".join([p for p in chunk_pages if p.strip()])
        if not chunk_text.strip():
            continue
        safe_chunk = truncate_by_tokens(chunk_text, tokenizer, reserve=64)
        try:
            out = summarizer(safe_chunk, max_length=150, min_length=30, do_sample=False, truncation=True)
            summaries.append(out[0]["summary_text"])
        except Exception:
            summaries.append(safe_chunk[:800])
    return "\n\n".join(summaries) if summaries else "(no summary produced)"

# -------------------------
# Passage splitting / QG / QA with token truncation
# -------------------------
def split_into_passages_from_pages(pages_texts: List[str], max_words: int = 200) -> List[str]:
    all_passages = []
    for page_text in pages_texts:
        if not page_text or not page_text.strip():
            continue
        sents = safe_sentence_tokenize(page_text)
        cur = []
        cur_len = 0
        for s in sents:
            w = len(s.split())
            if cur_len + w <= max_words:
                cur.append(s)
                cur_len += w
            else:
                if cur:
                    all_passages.append(" ".join(cur).strip())
                cur = [s]
                cur_len = w
        if cur:
            all_passages.append(" ".join(cur).strip())
    return all_passages
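
# Example (sketch): a page containing thirty ~15-word sentences (~450 words) is split
# into roughly three passages of up to 200 words each; passages never cross page
# boundaries.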


def generate_questions_from_passage(qg_pipeline, passage: str, min_questions: int = 3) -> List[str]:
    tok = getattr(qg_pipeline, "tokenizer", None) or get_tokenizer(QG_MODEL)
    safe_passage = truncate_by_tokens(passage, tok, reserve=32)
    prompt = f"generate questions: {safe_passage}"
    try:
        out = qg_pipeline(prompt, max_length=256, do_sample=False, truncation=True)
        gen_text = out[0].get("generated_text") or out[0].get("text") or ""
    except Exception:
        gen_text = ""
    candidates = []
    if "<sep>" in gen_text:
        candidates = gen_text.split("<sep>")
    elif "\n" in gen_text:
        candidates = [line.strip() for line in gen_text.splitlines() if line.strip()]
    else:
        parts = [p.strip() for p in gen_text.split("?") if p.strip()]
        candidates = [p + "?" for p in parts]
    questions = [q.strip() for q in candidates if q.strip()]
    if len(questions) < min_questions:
        sentences = safe_sentence_tokenize(safe_passage)
        for i in range(len(sentences)):
            if len(questions) >= min_questions:
                break
            small = ". ".join(sentences[i : i + 2])
            try:
                out2 = qg_pipeline(f"generate questions: {small}", max_length=128, do_sample=False, truncation=True)
                txt = out2[0].get("generated_text") or out2[0].get("text") or ""
            except Exception:
                txt = ""
            if "<sep>" in txt:
                more = txt.split("<sep>")
            else:
                more = [l.strip() for l in txt.splitlines() if l.strip()]
            for m in more:
                if len(questions) >= min_questions:
                    break
                maybe = m.strip()
                if maybe and maybe not in questions:
                    questions.append(maybe)
    return questions[:max(min_questions, len(questions))]
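
# Note (assumption): the parsing above assumes the QG model may join several questions
# with a literal "<sep>" token, the convention of the valhalla e2e question-generation
# checkpoints; newline and "?" splitting are fallbacks for other text2text models.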


def answer_questions_for_passage(qa_pipeline, passage: str, questions: List[str]) -> List[Tuple[str, str, float]]:
    results = []
    tok = getattr(qa_pipeline, "tokenizer", None) or get_tokenizer(QA_MODEL)
    for q in questions:
        try:
            safe_ctx = truncate_by_tokens(passage, tok, reserve=64)
            res = qa_pipeline(question=q, context=safe_ctx)
            answer = res.get("answer", "")
            score = float(res.get("score", 0.0))
        except Exception:
            answer = ""
            score = 0.0
        results.append((q, answer, score))
    return results

# -------------------------
# Unified extract_text_for_file: dispatch by extension
# -------------------------
def extract_text_for_file(file_bytes: bytes, filename: str, do_ocr: bool = False) -> Tuple[str, List[str]]:
    """
    Given raw bytes and a filename, return (extracted_txt_path, pages_texts).
    Supports PDF (streamed), DOCX, PPTX, and TXT/MD. For images, OCR is attempted when do_ocr is True.
    """
    ext = ext_of_name(filename)
    if ext == ".pdf":
        return extract_text_from_pdf_streaming(file_bytes, do_ocr=do_ocr)
    if ext == ".docx":
        return extract_text_from_docx_bytes(file_bytes)
    if ext == ".pptx":
        return extract_text_from_pptx_bytes(file_bytes)
    if ext in (".txt", ".md", ".rtf"):
        return extract_text_from_txt_bytes(file_bytes)
    # Images: try OCR if requested.
    if ext in (".png", ".jpg", ".jpeg", ".tiff", ".bmp") and do_ocr:
        # Write the image bytes to a temp file and run OCR with pytesseract if available.
        try:
            from PIL import Image
            import pytesseract
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp.write(file_bytes)
                tmp_path = tmp.name
            img = Image.open(tmp_path)
            ocr_txt = pytesseract.image_to_string(img)
            pages = [p.strip() for p in re.split(r"\n\s*\n", ocr_txt) if p.strip()]
            tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
            tmp_txt.close()
            with open(tmp_txt.name, "w", encoding="utf-8") as f:
                f.write("\n\n".join(pages))
            try:
                os.remove(tmp_path)
            except Exception:
                pass
            return tmp_txt.name, pages
        except Exception:
            # Fallback: treat as empty.
            tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
            tmp_txt.close()
            with open(tmp_txt.name, "w", encoding="utf-8") as f:
                f.write("")
            return tmp_txt.name, [""]
    # Unsupported extension: decode the bytes as UTF-8 (ignoring errors) and return the raw text.
    tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
    tmp_txt.close()
    try:
        s = file_bytes.decode("utf-8", errors="ignore")
    except Exception:
        s = ""
    with open(tmp_txt.name, "w", encoding="utf-8") as f:
        f.write(s)
    pages = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()]
    return tmp_txt.name, pages if pages else [s]
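
# Example (sketch; the filename is hypothetical):
#   path, pages = extract_text_for_file(raw_bytes, "lecture.pptx")
#   # -> pages has one string per slide; for a .docx or .txt the elements are
#   #    ~200-word paragraph groups, and for a .pdf one string per page.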

# -------------------------
# Main analyze function
# -------------------------
def analyze_document(file_obj, filename: str, do_ocr: bool, max_passages_for_qa: int = 5):
    """
    file_obj: uploaded object
    filename: filename string (used to detect the extension)
    Returns (extracted_preview, combined_summary, answered), where answered is a list of
    {"passage_idx", "question", "answer", "score"} dicts.
    """
    try:
        file_bytes = read_uploaded_file_to_bytes(file_obj)
    except Exception as e:
        return f"(error reading file: {e})", "", []
    extracted_txt_path, pages_texts = extract_text_for_file(file_bytes, filename, do_ocr=do_ocr)
    # Preview of the extracted text.
    preview_chars = 20000
    extracted_preview = ""
    try:
        with open(extracted_txt_path, "r", encoding="utf-8", errors="ignore") as f:
            extracted_preview = f.read(preview_chars)
        if len(extracted_preview) >= preview_chars:
            extracted_preview += "\n\n... (preview truncated) ..."
    except Exception:
        extracted_preview = "(could not read extracted text preview)"
    summarizer = get_pipeline(SUMMARIZER_MODEL, "summarization")
    combined_summary = summarize_text_chunked(summarizer, pages_texts, pages_per_chunk=8)
    all_passages = split_into_passages_from_pages(pages_texts, max_words=200)
    total = len(all_passages)
    if total == 0:
        return extracted_preview, combined_summary, []
    # Sample passages evenly across the document so Q&A stays fast on long inputs.
    if total <= max_passages_for_qa:
        chosen_passages = list(enumerate(all_passages))
    else:
        step = max(1, math.floor(total / max_passages_for_qa))
        chosen_passages = [(i, all_passages[i]) for i in range(0, total, step)][:max_passages_for_qa]
    qg = get_pipeline(QG_MODEL, "text2text-generation")
    qa = get_pipeline(QA_MODEL, "question-answering")
    answered = []
    answered_set = set()
    for (p_idx, passage) in chosen_passages:
        if not passage.strip():
            continue
        questions = generate_questions_from_passage(qg, passage, min_questions=3)
        unique_questions = [q for q in questions if q not in answered_set]
        if not unique_questions:
            continue
        answers = answer_questions_for_passage(qa, passage, unique_questions)
        for q, a, score in answers:
            answered.append({"passage_idx": int(p_idx), "question": q, "answer": a, "score": float(score)})
            answered_set.add(q)
    return extracted_preview, combined_summary, answered
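
# Worked example (sketch): with total == 37 passages and max_passages_for_qa == 5,
# step = floor(37 / 5) = 7, so passages 0, 7, 14, 21, 28 are sampled for Q&A
# (range(0, 37, 7) also yields 35, but the list is capped at 5 entries).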

# -------------------------
# Gradio UI
# -------------------------
def build_demo():
    with gr.Blocks(title="Document Analysis (LLMs)") as demo:
        gr.Markdown("# Document Analysis using LLMs\nUpload a supported file (PDF, DOCX, PPTX, TXT) and get summary + Q&A.")
        with gr.Row():
            with gr.Column(scale=1):
                files_in = gr.File(label="Upload files (PDF, DOCX, PPTX, TXT, images...)", file_count="multiple")
                do_ocr = gr.Checkbox(label="Try OCR for images/PDF pages (requires OCR libs)", value=False)
                max_pass = gr.Slider(label="Max passages to run Q&A on (lower = faster)", minimum=1, maximum=20, step=1, value=5)
                run_btn = gr.Button("Analyze Document")
            with gr.Column(scale=2):
                tabs = gr.Tabs()
                with tabs:
                    with gr.TabItem("Uploaded files"):
                        uploaded_list = gr.Textbox(label="Uploaded filenames", lines=4)
                    with gr.TabItem("Extracted Text"):
                        extracted_out = gr.Textbox(label="Extracted text (preview)", lines=15)
                    with gr.TabItem("Summary"):
                        summary_out = gr.Textbox(label="Summary", lines=8)
                    with gr.TabItem("Q&A"):
                        qa_out = gr.Dataframe(headers=["passage_idx", "question", "answer", "score"],
                                              datatype=["number", "text", "text", "number"])

        def _run(files, do_ocr_val, max_pass_val):
            names = get_uploaded_filenames(files)
            uploaded_str = "\n".join(names) if names else "(no files uploaded)"
            fobj, fname = find_first_supported_file(files)
            if fobj is None or fname is None:
                return uploaded_str, "(no supported file found)", "", pd.DataFrame(columns=["passage_idx", "question", "answer", "score"])
            text, summary, qa = analyze_document(fobj, fname, do_ocr=do_ocr_val, max_passages_for_qa=int(max_pass_val))
            if not qa:
                qa_df = pd.DataFrame(columns=["passage_idx", "question", "answer", "score"])
            else:
                qa_df = pd.DataFrame(qa)
                qa_df = qa_df.loc[:, ["passage_idx", "question", "answer", "score"]]
                qa_df["passage_idx"] = qa_df["passage_idx"].astype(int)
                qa_df["question"] = qa_df["question"].astype(str)
                qa_df["answer"] = qa_df["answer"].astype(str)
                qa_df["score"] = qa_df["score"].astype(float)
            return uploaded_str, text or "(no text extracted)", summary or "(no summary)", qa_df

        run_btn.click(_run, inputs=[files_in, do_ocr, max_pass], outputs=[uploaded_list, extracted_out, summary_out, qa_out])
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
    # demo.launch(server_name="0.0.0.0")
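    # Note (assumption): uncommenting the line above makes the app listen on all network
    # interfaces (useful inside containers); Gradio serves on port 7860 by default unless
    # server_port is passed.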