""" app.py Document Analysis Gradio app — updated to support PDF, DOCX (Word), and PPTX (PowerPoint). - robust file reading - streaming PDF extraction - docx/pptx extraction to pages_texts (one element per paragraph/slide) - token-aware truncation, chunked summarization, sampled Q&A - multi-file upload UI (processes first supported file) """ import os # disable noisy HF symlink warning on Windows os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") import re import io import math import tempfile import threading from pathlib import Path from typing import List, Tuple, Optional import gradio as gr import pdfplumber import nltk from nltk.tokenize import sent_tokenize from transformers import pipeline, AutoTokenizer import torch import pandas as pd from tqdm.auto import tqdm # Try to import docx and pptx; if not present, we'll handle gracefully at runtime. try: from docx import Document as DocxDocument # python-docx except Exception: DocxDocument = None try: from pptx import Presentation as PptxPresentation # python-pptx except Exception: PptxPresentation = None # ------------------------- # NLTK: ensure punkt available, fallback later # ------------------------- try: nltk.download("punkt", quiet=True) try: nltk.download("punkt_tab", quiet=True) except Exception: pass except Exception: pass # ------------------------- # Device detection # ------------------------- DEVICE = 0 if torch.cuda.is_available() else -1 print("Device set to use", "cuda" if DEVICE >= 0 else "cpu") # ------------------------- # Cached pipelines and tokenizers # ------------------------- _models = {} _tokenizers = {} _models_lock = threading.Lock() def get_pipeline(name: str, task: str): """Return a cached HF pipeline for given task and model name.""" key = f"{task}__{name}" with _models_lock: if key in _models: return _models[key] print(f"Loading pipeline: task={task}, model={name} ... 
(this may take a while on first run)") p = pipeline(task, model=name, device=DEVICE) _models[key] = p try: _tokenizers[name] = p.tokenizer except Exception: pass return p def get_tokenizer(name: str): """Return a cached tokenizer (fallback to AutoTokenizer if not present).""" if name in _tokenizers: return _tokenizers[name] try: tok = AutoTokenizer.from_pretrained(name) _tokenizers[name] = tok return tok except Exception: return None # ------------------------- # Default models (adjust if you want smaller/faster ones) # ------------------------- # ------------------------- # Default models (adjusted for speed) # ------------------------- SUMMARIZER_MODEL = "t5-small" # This is already fast QG_MODEL = "valhalla/t5-small-qg-hl" # Use the 'small' version QA_MODEL = "distilbert-base-cased-distilled-squad" # Much faster than RoBERTa # ------------------------- # Helpers: filenames / types # ------------------------- SUPPORTED_EXTS = [".pdf", ".docx", ".pptx", ".txt", ".md", ".rtf", ".png", ".jpg", ".jpeg", ".tiff"] def ext_of_name(name: str) -> str: return Path(name).suffix.lower() def read_uploaded_file_to_bytes(file_obj): """ Accept many shapes of Gradio file objects and return bytes. Supports list/tuple (returns first readable file), dict-like, file-likes, paths, bytes. """ # If a list/tuple of uploaded files, try each candidate if isinstance(file_obj, (list, tuple)): last_err = None for elem in file_obj: try: return read_uploaded_file_to_bytes(elem) except Exception as e: last_err = e continue raise ValueError(f"No readable file in list. 
Last error: {last_err}") if file_obj is None: raise ValueError("No file provided") # if it's already bytes if isinstance(file_obj, (bytes, bytearray)): return bytes(file_obj) # dict-like if isinstance(file_obj, dict): for key in ("file", "tmp_path", "name", "data", "path"): val = file_obj.get(key) if isinstance(val, (bytes, bytearray)): return bytes(val) if isinstance(val, str) and Path(val).exists(): return Path(val).read_bytes() maybe = file_obj.get("file") if hasattr(maybe, "read"): data = maybe.read() if isinstance(data, str): return data.encode("utf-8") return data # string path if isinstance(file_obj, str) and Path(file_obj).exists(): return Path(file_obj).read_bytes() # has a .name attribute that points to a file if hasattr(file_obj, "name") and isinstance(getattr(file_obj, "name"), str) and Path(file_obj.name).exists(): try: return Path(file_obj.name).read_bytes() except Exception: pass # file-like with .read() if hasattr(file_obj, "read"): try: data = file_obj.read() if isinstance(data, str): return data.encode("utf-8") return data except Exception: pass # last resort: string representation -> path try: s = str(file_obj) if Path(s).exists(): return Path(s).read_bytes() except Exception: pass raise ValueError(f"Unsupported uploaded file object type: {type(file_obj)}") def get_uploaded_filenames(file_obj) -> List[str]: """ Return a list of human-friendly filenames from uploaded file object(s). 
""" names = [] if file_obj is None: return names if isinstance(file_obj, (list, tuple)): for elem in file_obj: names.extend(get_uploaded_filenames(elem)) return names if isinstance(file_obj, dict): for key in ("name", "filename", "file", "tmp_path", "path"): if key in file_obj: val = file_obj.get(key) if isinstance(val, str): names.append(Path(val).name) elif hasattr(val, "name"): names.append(Path(val.name).name) maybe = file_obj.get("file") if maybe is not None: if hasattr(maybe, "name"): names.append(Path(maybe.name).name) return names if isinstance(file_obj, str): return [Path(file_obj).name] if hasattr(file_obj, "name"): return [Path(getattr(file_obj, "name")).name] return [str(file_obj)] def find_first_supported_file(files) -> Tuple[Optional[object], Optional[str]]: """ From the uploaded list or single file-like object, find the first file with a supported extension. Returns (file_obj, filename) or (None, None) if none supported. """ if not files: return None, None candidates = [] if isinstance(files, (list, tuple)): for f in files: names = get_uploaded_filenames(f) for n in names: candidates.append((f, n)) else: names = get_uploaded_filenames(files) for n in names: candidates.append((files, n)) for fobj, name in candidates: ext = ext_of_name(name) if ext in SUPPORTED_EXTS: return fobj, name # fallback: if none matched, return first uploaded if candidates: return candidates[0] return None, None # ------------------------- # PDF extraction (streaming) # ------------------------- def extract_text_from_pdf_streaming(file_bytes: bytes, do_ocr: bool = False, extracted_txt_path: str = None): """Write PDF bytes to temp file and extract each page's text; returns extracted_txt_path, pages_texts list""" with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_pdf: tmp_pdf.write(file_bytes) tmp_pdf_path = tmp_pdf.name if extracted_txt_path is None: tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8") 
extracted_txt_path = tmp_txt.name tmp_txt.close() pages_texts = [] try: with pdfplumber.open(tmp_pdf_path) as pdf: for page in pdf.pages: page_text = page.extract_text() if (not page_text or page_text.strip() == "") and do_ocr: try: from pdf2image import convert_from_path import pytesseract images = convert_from_path(tmp_pdf_path, first_page=page.page_number, last_page=page.page_number) if images: page_text = pytesseract.image_to_string(images[0]) except Exception: page_text = "" if page_text is None: page_text = "" with open(extracted_txt_path, "a", encoding="utf-8") as fout: fout.write(page_text) fout.write("\n\n---PAGE_BREAK---\n\n") pages_texts.append(page_text) finally: try: os.remove(tmp_pdf_path) except Exception: pass return extracted_txt_path, pages_texts # ------------------------- # DOCX extraction # ------------------------- def extract_text_from_docx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]: """ Returns (tmp_text_path, pages_texts) where pages_texts is a list of paragraph groups. We'll treat each paragraph as a small 'page' or group paragraphs into ~200-word chunks. """ if DocxDocument is None: raise RuntimeError("python-docx not installed. 
Install with `pip install python-docx`") with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tmp: tmp.write(file_bytes) tmp_path = tmp.name pages_texts = [] try: doc = DocxDocument(tmp_path) # Collect non-empty paragraphs paras = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()] # Group paragraphs into passages of ~200 words current = [] cur_words = 0 for p in paras: w = len(p.split()) if cur_words + w <= 200: current.append(p) cur_words += w else: pages_texts.append(" ".join(current).strip()) current = [p] cur_words = w if current: pages_texts.append(" ".join(current).strip()) # write full extracted text to tmp file for download/debug if needed tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8") with open(tmp_txt.name, "w", encoding="utf-8") as f: f.write("\n\n".join(pages_texts)) txt_path = tmp_txt.name finally: try: os.remove(tmp_path) except Exception: pass return txt_path, pages_texts # ------------------------- # PPTX extraction # ------------------------- def extract_text_from_pptx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]: """ Returns (tmp_text_path, pages_texts) where each slide's text is one element. """ if PptxPresentation is None: raise RuntimeError("python-pptx not installed. 
Install with `pip install python-pptx`") with tempfile.NamedTemporaryFile(delete=False, suffix=".pptx") as tmp: tmp.write(file_bytes) tmp_path = tmp.name pages_texts = [] try: prs = PptxPresentation(tmp_path) for slide in prs.slides: texts = [] for shape in slide.shapes: try: if hasattr(shape, "text") and shape.text: texts.append(shape.text.strip()) except Exception: continue slide_text = "\n".join([t for t in texts if t]) pages_texts.append(slide_text) tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8") with open(tmp_txt.name, "w", encoding="utf-8") as f: f.write("\n\n".join(pages_texts)) txt_path = tmp_txt.name finally: try: os.remove(tmp_path) except Exception: pass return txt_path, pages_texts # ------------------------- # TXT/MD extraction # ------------------------- def extract_text_from_txt_bytes(file_bytes: bytes) -> Tuple[str, List[str]]: s = file_bytes.decode("utf-8", errors="ignore") # split into passages by blank lines or ~200 words paras = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()] pages_texts = [] cur = [] cur_words = 0 for p in paras: w = len(p.split()) if cur_words + w <= 200: cur.append(p) cur_words += w else: pages_texts.append(" ".join(cur).strip()) cur = [p] cur_words = w if cur: pages_texts.append(" ".join(cur).strip()) tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8") with open(tmp_txt.name, "w", encoding="utf-8") as f: f.write("\n\n".join(pages_texts)) return tmp_txt.name, pages_texts # ------------------------- # Token-aware truncation helpers # ------------------------- def truncate_by_tokens(text: str, tokenizer, reserve: int = 64) -> str: if not text: return text if tokenizer is None: return text if len(text) <= 3000 else text[:3000] try: ids = tokenizer.encode(text, add_special_tokens=False) max_len = getattr(tokenizer, "model_max_length", 512) allowed = max(1, max_len - reserve) if len(ids) > allowed: ids = ids[:allowed] return 
tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) return text except Exception: return text if len(text) <= 3000 else text[:3000] # ------------------------- # Sentence tokenization fallback # ------------------------- def safe_sentence_tokenize(text: str) -> List[str]: try: sents = sent_tokenize(text) if isinstance(sents, list) and len(sents) > 0: return sents except Exception: pass pieces = re.split(r"(?<=[.!?])\s+", text.strip()) return [p.strip() for p in pieces if p.strip()] # ------------------------- # Summarization (chunked, token-truncated) # ------------------------- def summarize_text_chunked(summarizer, pages_texts: List[str], pages_per_chunk: int = 8) -> str: if not pages_texts: return "(no text)" summaries = [] tokenizer = getattr(summarizer, "tokenizer", None) or get_tokenizer(SUMMARIZER_MODEL) num_pages = len(pages_texts) for i in range(0, num_pages, pages_per_chunk): chunk_pages = pages_texts[i : i + pages_per_chunk] chunk_text = "\n\n".join([p for p in chunk_pages if p.strip()]) if not chunk_text.strip(): continue safe_chunk = truncate_by_tokens(chunk_text, tokenizer, reserve=64) try: out = summarizer(safe_chunk, max_length=150, min_length=30, do_sample=False, truncation=True) summaries.append(out[0]["summary_text"]) except Exception: summaries.append(safe_chunk[:800]) return "\n\n".join(summaries) if summaries else "(no summary produced)" # ------------------------- # Passage splitting / QG / QA with token truncation # ------------------------- def split_into_passages_from_pages(pages_texts: List[str], max_words: int = 200) -> List[str]: all_passages = [] for page_text in pages_texts: if not page_text or not page_text.strip(): continue sents = safe_sentence_tokenize(page_text) cur = [] cur_len = 0 for s in sents: w = len(s.split()) if cur_len + w <= max_words: cur.append(s) cur_len += w else: if cur: all_passages.append(" ".join(cur).strip()) cur = [s] cur_len = w if cur: all_passages.append(" ".join(cur).strip()) 
return all_passages def generate_questions_from_passage(qg_pipeline, passage: str, min_questions: int = 3) -> List[str]: tok = getattr(qg_pipeline, "tokenizer", None) or get_tokenizer(QG_MODEL) safe_passage = truncate_by_tokens(passage, tok, reserve=32) prompt = f"generate questions: {safe_passage}" try: out = qg_pipeline(prompt, max_length=256, do_sample=False, truncation=True) gen_text = out[0].get("generated_text") or out[0].get("text") or "" except Exception: gen_text = "" candidates = [] if "" in gen_text: candidates = gen_text.split("") elif "\n" in gen_text: candidates = [line.strip() for line in gen_text.splitlines() if line.strip()] else: parts = [p.strip() for p in gen_text.split("?") if p.strip()] candidates = [p + "?" for p in parts] questions = [q.strip() for q in candidates if q.strip()] if len(questions) < min_questions: sentences = safe_sentence_tokenize(safe_passage) for i in range(len(sentences)): if len(questions) >= min_questions: break small = ". ".join(sentences[i : i + 2]) try: out2 = qg_pipeline(f"generate questions: {small}", max_length=128, do_sample=False, truncation=True) txt = out2[0].get("generated_text") or out2[0].get("text") or "" except Exception: txt = "" if "" in txt: more = txt.split("") else: more = [l.strip() for l in txt.splitlines() if l.strip()] for m in more: if len(questions) >= min_questions: break maybe = m.strip() if maybe and maybe not in questions: questions.append(maybe) return questions[:max(min_questions, len(questions))] def answer_questions_for_passage(qa_pipeline, passage: str, questions: List[str]) -> List[Tuple[str, str, float]]: results = [] tok = getattr(qa_pipeline, "tokenizer", None) or get_tokenizer(QA_MODEL) for q in questions: try: safe_ctx = truncate_by_tokens(passage, tok, reserve=64) res = qa_pipeline(question=q, context=safe_ctx) answer = res.get("answer", "") score = float(res.get("score", 0.0)) except Exception: answer = "" score = 0.0 results.append((q, answer, score)) return results # 
# -------------------------
# Unified extract_text_for_file: dispatch by extension
# -------------------------
_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".tiff", ".bmp")


def extract_text_for_file(file_bytes: bytes, filename: str, do_ocr: bool = False) -> Tuple[str, List[str]]:
    """Extract text from ``file_bytes``, choosing the extractor by ``filename``'s extension.

    Supports PDF (streamed), DOCX, PPTX and TXT/MD/RTF (RTF is decoded as plain
    text). Images are OCR'd only when ``do_ocr`` is True; otherwise they yield
    no text rather than decoded binary. Any other extension is decoded as UTF-8
    text (undecodable bytes dropped).

    Returns:
        (extracted_txt_path, pages_texts); the temp file is not auto-deleted.
    """
    ext = ext_of_name(filename)
    if ext == ".pdf":
        return extract_text_from_pdf_streaming(file_bytes, do_ocr=do_ocr)
    if ext == ".docx":
        return extract_text_from_docx_bytes(file_bytes)
    if ext == ".pptx":
        return extract_text_from_pptx_bytes(file_bytes)
    if ext in (".txt", ".md", ".rtf"):
        return extract_text_from_txt_bytes(file_bytes)

    def _save(text: str) -> str:
        # delete=False: the caller reads the file by path after we return
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8") as f:
            f.write(text)
            return f.name

    if ext in _IMAGE_EXTS:
        if not do_ocr:
            # Decoding image bytes as UTF-8 would only feed binary garbage to the models.
            return _save(""), []
        try:
            from PIL import Image
            import pytesseract

            # Work from memory: no temp image file to leak, and the handle is closed.
            with Image.open(io.BytesIO(file_bytes)) as img:
                ocr_txt = pytesseract.image_to_string(img)
        except Exception as e:
            print(f"OCR failed for {filename}: {e}")
            return _save(""), []
        pages = [p.strip() for p in re.split(r"\n\s*\n", ocr_txt) if p.strip()]
        return _save("\n\n".join(pages)), pages

    # Unknown extension: best-effort decode as text.
    s = file_bytes.decode("utf-8", errors="ignore")
    pages = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()]
    return _save(s), (pages if pages else [s])


# -------------------------
# Main analyze function
# -------------------------
_PREVIEW_CHARS = 20000


def _read_preview_and_cleanup(path: str, limit: int = _PREVIEW_CHARS) -> str:
    """Return up to ``limit`` characters of ``path`` (with a truncation note), then delete the file."""
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            # read one extra char to tell "exactly limit chars" apart from "truncated"
            preview = f.read(limit + 1)
        if len(preview) > limit:
            preview = preview[:limit] + "\n\n... (preview truncated) ..."
        return preview
    except Exception:
        return "(could not read extracted text preview)"
    finally:
        # The extracted text file is only used for this preview; don't leave it behind.
        try:
            os.remove(path)
        except OSError:
            pass


def _sample_passage_indices(total: int, max_passages: int) -> List[int]:
    """Pick up to ``max_passages`` indices spread evenly over range(total), first and last included."""
    k = max(1, int(max_passages))
    if total <= k:
        return list(range(total))
    if k == 1:
        return [0]
    # step (total-1)/(k-1) > 1 here, so the floored indices are strictly increasing
    return [i * (total - 1) // (k - 1) for i in range(k)]


def analyze_document(file_obj, filename: str, do_ocr: bool, max_passages_for_qa: int = 5):
    """Extract, summarize and run sampled Q&A on an uploaded document.

    Args:
        file_obj: uploaded object (any shape read_uploaded_file_to_bytes accepts).
        filename: filename string, used to detect the extension.
        do_ocr: OCR images / text-less PDF pages.
        max_passages_for_qa: how many passages (evenly sampled) to generate Q&A for.

    Returns:
        (extracted_preview, combined_summary, answered) where ``answered`` is a
        list of dicts with keys passage_idx, question, answer, score. Read or
        extraction errors are reported in extracted_preview with an empty
        summary and no Q&A.
    """
    try:
        file_bytes = read_uploaded_file_to_bytes(file_obj)
    except Exception as e:
        return f"(error reading file: {e})", "", []

    try:
        extracted_txt_path, pages_texts = extract_text_for_file(file_bytes, filename, do_ocr=do_ocr)
    except Exception as e:
        return f"(error extracting text from {filename}: {e})", "", []

    extracted_preview = _read_preview_and_cleanup(extracted_txt_path)

    summarizer = get_pipeline(SUMMARIZER_MODEL, "summarization")
    combined_summary = summarize_text_chunked(summarizer, pages_texts, pages_per_chunk=8)

    all_passages = split_into_passages_from_pages(pages_texts, max_words=200)
    if not all_passages:
        return extracted_preview, combined_summary, []

    qg = get_pipeline(QG_MODEL, "text2text-generation")
    qa = get_pipeline(QA_MODEL, "question-answering")

    answered = []
    seen_questions = set()
    for p_idx in _sample_passage_indices(len(all_passages), max_passages_for_qa):
        passage = all_passages[p_idx]
        if not passage.strip():
            continue
        # keep only questions not already answered (also dedupes within this passage)
        unique_questions = []
        for q in generate_questions_from_passage(qg, passage, min_questions=3):
            if q not in seen_questions:
                seen_questions.add(q)
                unique_questions.append(q)
        if not unique_questions:
            continue
        for q, a, score in answer_questions_for_passage(qa, passage, unique_questions):
            answered.append({"passage_idx": int(p_idx), "question": q, "answer": a, "score": float(score)})
    return extracted_preview, combined_summary, answered


# -------------------------
# Gradio UI
# -------------------------
def build_demo():
    """Build (but do not launch) the Gradio Blocks UI."""
    qa_columns = ["passage_idx", "question", "answer", "score"]

    def _empty_qa_frame():
        return pd.DataFrame(columns=qa_columns)

    with gr.Blocks(title="Document Analysis (LLMs)") as demo:
        gr.Markdown("# Document Analysis using LLMs\nUpload a supported file (PDF, DOCX, PPTX, TXT) and get summary + Q&A.")
        with gr.Row():
            with gr.Column(scale=1):
                files_in = gr.File(label="Upload files (PDF, DOCX, PPTX, TXT, images...)", file_count="multiple")
                do_ocr = gr.Checkbox(label="Try OCR for images/PDF pages (requires OCR libs)", value=False)
                max_pass = gr.Slider(label="Max passages to run Q&A on (lower = faster)", minimum=1, maximum=20, step=1, value=5)
                run_btn = gr.Button("Analyze Document")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Uploaded files"):
                        uploaded_list = gr.Textbox(label="Uploaded filenames", lines=4)
                    with gr.TabItem("Extracted Text"):
                        extracted_out = gr.Textbox(label="Extracted text (preview)", lines=15)
                    with gr.TabItem("Summary"):
                        summary_out = gr.Textbox(label="Summary", lines=8)
                    with gr.TabItem("Q&A"):
                        qa_out = gr.Dataframe(headers=qa_columns, datatype=["number", "text", "text", "number"])

        def _run(files, do_ocr_val, max_pass_val):
            """Click handler: list uploads, analyze the first supported one, fill the tabs."""
            names = get_uploaded_filenames(files)
            uploaded_str = "\n".join(names) if names else "(no files uploaded)"
            fobj, fname = find_first_supported_file(files)
            if fobj is None or fname is None:
                return uploaded_str, "(no supported file found)", "", _empty_qa_frame()

            text, summary, qa = analyze_document(fobj, fname, do_ocr=do_ocr_val, max_passages_for_qa=int(max_pass_val))
            if qa:
                qa_df = pd.DataFrame(qa, columns=qa_columns).astype(
                    {"passage_idx": int, "question": str, "answer": str, "score": float}
                )
            else:
                qa_df = _empty_qa_frame()
            return uploaded_str, text or "(no text extracted)", summary or "(no summary)", qa_df

        run_btn.click(_run, inputs=[files_in, do_ocr, max_pass], outputs=[uploaded_list, extracted_out, summary_out, qa_out])
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
    # To listen on all interfaces instead: demo.launch(server_name="0.0.0.0")