loki2910 committed
Commit 1351f70 · verified · 1 Parent(s): 85fa7cd

Upload 2 files

Files changed (2)
  1. app.py +715 -0
  2. requirements.txt +13 -0
app.py ADDED
@@ -0,0 +1,715 @@
"""
app.py

Document Analysis Gradio app — updated to support PDF, DOCX (Word), and PPTX (PowerPoint).
- robust file reading
- streaming PDF extraction
- docx/pptx extraction to pages_texts (one element per paragraph/slide)
- token-aware truncation, chunked summarization, sampled Q&A
- multi-file upload UI (processes first supported file)
"""

import os
# disable noisy HF symlink warning on Windows
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")

import re
import io
import math
import tempfile
import threading
from pathlib import Path
from typing import List, Tuple, Optional

import gradio as gr
import pdfplumber
import nltk
from nltk.tokenize import sent_tokenize
from transformers import pipeline, AutoTokenizer
import torch
import pandas as pd
from tqdm.auto import tqdm

# Try to import docx and pptx; if not present, we'll handle gracefully at runtime.
try:
    from docx import Document as DocxDocument  # python-docx
except Exception:
    DocxDocument = None

try:
    from pptx import Presentation as PptxPresentation  # python-pptx
except Exception:
    PptxPresentation = None

# -------------------------
# NLTK: ensure punkt available, fallback later
# -------------------------
try:
    nltk.download("punkt", quiet=True)
    try:
        nltk.download("punkt_tab", quiet=True)
    except Exception:
        pass
except Exception:
    pass

# -------------------------
# Device detection
# -------------------------
DEVICE = 0 if torch.cuda.is_available() else -1
print("Device set to use", "cuda" if DEVICE >= 0 else "cpu")

# -------------------------
# Cached pipelines and tokenizers
# -------------------------
_models = {}
_tokenizers = {}
_models_lock = threading.Lock()


def get_pipeline(name: str, task: str):
    """Return a cached HF pipeline for given task and model name."""
    key = f"{task}__{name}"
    with _models_lock:
        if key in _models:
            return _models[key]
        print(f"Loading pipeline: task={task}, model={name} ... (this may take a while on first run)")
        p = pipeline(task, model=name, device=DEVICE)
        _models[key] = p
        try:
            _tokenizers[name] = p.tokenizer
        except Exception:
            pass
        return p


def get_tokenizer(name: str):
    """Return a cached tokenizer (fallback to AutoTokenizer if not present)."""
    if name in _tokenizers:
        return _tokenizers[name]
    try:
        tok = AutoTokenizer.from_pretrained(name)
        _tokenizers[name] = tok
        return tok
    except Exception:
        return None


# -------------------------
# Default models (adjust if you want smaller/faster ones)
# -------------------------
SUMMARIZER_MODEL = "t5-small"
QG_MODEL = "valhalla/t5-base-qg-hl"
QA_MODEL = "deepset/roberta-base-squad2"
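# Possible drop-in alternatives (any compatible Hugging Face checkpoint should work), e.g.
# "sshleifer/distilbart-cnn-12-6" for summarization or
# "distilbert-base-cased-distilled-squad" for extractive QA.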
# -------------------------
# Helpers: filenames / types
# -------------------------
SUPPORTED_EXTS = [".pdf", ".docx", ".pptx", ".txt", ".md", ".rtf", ".png", ".jpg", ".jpeg", ".tiff"]

def ext_of_name(name: str) -> str:
    return Path(name).suffix.lower()


# -------------------------
# Robust file reading (supports lists)
# -------------------------
def read_uploaded_file_to_bytes(file_obj):
    """
    Accept many shapes of Gradio file objects and return bytes.
    Supports list/tuple (returns first readable file), dict-like, file-likes, paths, bytes.
    """
    # If a list/tuple of uploaded files, try each candidate
    if isinstance(file_obj, (list, tuple)):
        last_err = None
        for elem in file_obj:
            try:
                return read_uploaded_file_to_bytes(elem)
            except Exception as e:
                last_err = e
                continue
        raise ValueError(f"No readable file in list. Last error: {last_err}")

    if file_obj is None:
        raise ValueError("No file provided")

    # if it's already bytes
    if isinstance(file_obj, (bytes, bytearray)):
        return bytes(file_obj)

    # dict-like
    if isinstance(file_obj, dict):
        for key in ("file", "tmp_path", "name", "data", "path"):
            val = file_obj.get(key)
            if isinstance(val, (bytes, bytearray)):
                return bytes(val)
            if isinstance(val, str) and Path(val).exists():
                return Path(val).read_bytes()
        maybe = file_obj.get("file")
        if hasattr(maybe, "read"):
            data = maybe.read()
            if isinstance(data, str):
                return data.encode("utf-8")
            return data

    # string path
    if isinstance(file_obj, str) and Path(file_obj).exists():
        return Path(file_obj).read_bytes()

    # has a .name attribute that points to a file
    if hasattr(file_obj, "name") and isinstance(getattr(file_obj, "name"), str) and Path(file_obj.name).exists():
        try:
            return Path(file_obj.name).read_bytes()
        except Exception:
            pass

    # file-like with .read()
    if hasattr(file_obj, "read"):
        try:
            data = file_obj.read()
            if isinstance(data, str):
                return data.encode("utf-8")
            return data
        except Exception:
            pass

    # last resort: string representation -> path
    try:
        s = str(file_obj)
        if Path(s).exists():
            return Path(s).read_bytes()
    except Exception:
        pass

    raise ValueError(f"Unsupported uploaded file object type: {type(file_obj)}")


def get_uploaded_filenames(file_obj) -> List[str]:
    """
    Return a list of human-friendly filenames from uploaded file object(s).
    """
    names = []
    if file_obj is None:
        return names
    if isinstance(file_obj, (list, tuple)):
        for elem in file_obj:
            names.extend(get_uploaded_filenames(elem))
        return names
    if isinstance(file_obj, dict):
        for key in ("name", "filename", "file", "tmp_path", "path"):
            if key in file_obj:
                val = file_obj.get(key)
                if isinstance(val, str):
                    names.append(Path(val).name)
                elif hasattr(val, "name"):
                    names.append(Path(val.name).name)
        maybe = file_obj.get("file")
        if maybe is not None:
            if hasattr(maybe, "name"):
                names.append(Path(maybe.name).name)
        return names
    if isinstance(file_obj, str):
        return [Path(file_obj).name]
    if hasattr(file_obj, "name"):
        return [Path(getattr(file_obj, "name")).name]
    return [str(file_obj)]


def find_first_supported_file(files) -> Tuple[Optional[object], Optional[str]]:
    """
    From the uploaded list or single file-like object, find the first file with a supported extension.
    Returns (file_obj, filename) or (None, None) if none supported.
    """
    if not files:
        return None, None
    candidates = []
    if isinstance(files, (list, tuple)):
        for f in files:
            names = get_uploaded_filenames(f)
            for n in names:
                candidates.append((f, n))
    else:
        names = get_uploaded_filenames(files)
        for n in names:
            candidates.append((files, n))
    for fobj, name in candidates:
        ext = ext_of_name(name)
        if ext in SUPPORTED_EXTS:
            return fobj, name
    # fallback: if none matched, return first uploaded
    if candidates:
        return candidates[0]
    return None, None


# -------------------------
# PDF extraction (streaming)
# -------------------------
def extract_text_from_pdf_streaming(file_bytes: bytes, do_ocr: bool = False, extracted_txt_path: str = None):
    """Write PDF bytes to temp file and extract each page's text; returns extracted_txt_path, pages_texts list"""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_pdf:
        tmp_pdf.write(file_bytes)
        tmp_pdf_path = tmp_pdf.name

    if extracted_txt_path is None:
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        extracted_txt_path = tmp_txt.name
        tmp_txt.close()

    pages_texts = []
    try:
        with pdfplumber.open(tmp_pdf_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if (not page_text or page_text.strip() == "") and do_ocr:
                    try:
                        from pdf2image import convert_from_path
                        import pytesseract
                        images = convert_from_path(tmp_pdf_path, first_page=page.page_number, last_page=page.page_number)
                        if images:
                            page_text = pytesseract.image_to_string(images[0])
                    except Exception:
                        page_text = ""
                if page_text is None:
                    page_text = ""
                with open(extracted_txt_path, "a", encoding="utf-8") as fout:
                    fout.write(page_text)
                    fout.write("\n\n---PAGE_BREAK---\n\n")
                pages_texts.append(page_text)
    finally:
        try:
            os.remove(tmp_pdf_path)
        except Exception:
            pass

    return extracted_txt_path, pages_texts


# -------------------------
# DOCX extraction
# -------------------------
def extract_text_from_docx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    """
    Returns (tmp_text_path, pages_texts) where pages_texts is a list of paragraph groups.
    We'll treat each paragraph as a small 'page' or group paragraphs into ~200-word chunks.
    """
    if DocxDocument is None:
        raise RuntimeError("python-docx not installed. Install with `pip install python-docx`")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name

    pages_texts = []
    try:
        doc = DocxDocument(tmp_path)
        # Collect non-empty paragraphs
        paras = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
        # Group paragraphs into passages of ~200 words
        current = []
        cur_words = 0
        for p in paras:
            w = len(p.split())
            if cur_words + w <= 200:
                current.append(p)
                cur_words += w
            else:
                pages_texts.append(" ".join(current).strip())
                current = [p]
                cur_words = w
        if current:
            pages_texts.append(" ".join(current).strip())
        # write full extracted text to tmp file for download/debug if needed
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        with open(tmp_txt.name, "w", encoding="utf-8") as f:
            f.write("\n\n".join(pages_texts))
        txt_path = tmp_txt.name
    finally:
        try:
            os.remove(tmp_path)
        except Exception:
            pass

    return txt_path, pages_texts


# -------------------------
# PPTX extraction
# -------------------------
def extract_text_from_pptx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    """
    Returns (tmp_text_path, pages_texts) where each slide's text is one element.
    """
    if PptxPresentation is None:
        raise RuntimeError("python-pptx not installed. Install with `pip install python-pptx`")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".pptx") as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name

    pages_texts = []
    try:
        prs = PptxPresentation(tmp_path)
        for slide in prs.slides:
            texts = []
            for shape in slide.shapes:
                try:
                    if hasattr(shape, "text") and shape.text:
                        texts.append(shape.text.strip())
                except Exception:
                    continue
            slide_text = "\n".join([t for t in texts if t])
            pages_texts.append(slide_text)
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        with open(tmp_txt.name, "w", encoding="utf-8") as f:
            f.write("\n\n".join(pages_texts))
        txt_path = tmp_txt.name
    finally:
        try:
            os.remove(tmp_path)
        except Exception:
            pass

    return txt_path, pages_texts


# -------------------------
# TXT/MD extraction
# -------------------------
def extract_text_from_txt_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    s = file_bytes.decode("utf-8", errors="ignore")
    # split into passages by blank lines or ~200 words
    paras = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()]
    pages_texts = []
    cur = []
    cur_words = 0
    for p in paras:
        w = len(p.split())
        if cur_words + w <= 200:
            cur.append(p)
            cur_words += w
        else:
            pages_texts.append(" ".join(cur).strip())
            cur = [p]
            cur_words = w
    if cur:
        pages_texts.append(" ".join(cur).strip())
    tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
    with open(tmp_txt.name, "w", encoding="utf-8") as f:
        f.write("\n\n".join(pages_texts))
    return tmp_txt.name, pages_texts


# -------------------------
# Token-aware truncation helpers
# -------------------------
def truncate_by_tokens(text: str, tokenizer, reserve: int = 64) -> str:
    if not text:
        return text
    if tokenizer is None:
        return text if len(text) <= 3000 else text[:3000]
    try:
        ids = tokenizer.encode(text, add_special_tokens=False)
        max_len = getattr(tokenizer, "model_max_length", 512)
        allowed = max(1, max_len - reserve)
        if len(ids) > allowed:
            ids = ids[:allowed]
            return tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return text
    except Exception:
        return text if len(text) <= 3000 else text[:3000]


# -------------------------
# Sentence tokenization fallback
# -------------------------
def safe_sentence_tokenize(text: str) -> List[str]:
    try:
        sents = sent_tokenize(text)
        if isinstance(sents, list) and len(sents) > 0:
            return sents
    except Exception:
        pass
    pieces = re.split(r"(?<=[.!?])\s+", text.strip())
    return [p.strip() for p in pieces if p.strip()]


# -------------------------
# Summarization (chunked, token-truncated)
# -------------------------
def summarize_text_chunked(summarizer, pages_texts: List[str], pages_per_chunk: int = 8) -> str:
    if not pages_texts:
        return "(no text)"
    summaries = []
    tokenizer = getattr(summarizer, "tokenizer", None) or get_tokenizer(SUMMARIZER_MODEL)
    num_pages = len(pages_texts)
    for i in range(0, num_pages, pages_per_chunk):
        chunk_pages = pages_texts[i : i + pages_per_chunk]
        chunk_text = "\n\n".join([p for p in chunk_pages if p.strip()])
        if not chunk_text.strip():
            continue
        safe_chunk = truncate_by_tokens(chunk_text, tokenizer, reserve=64)
        try:
            out = summarizer(safe_chunk, max_length=150, min_length=30, do_sample=False, truncation=True)
            summaries.append(out[0]["summary_text"])
        except Exception:
            summaries.append(safe_chunk[:800])
    return "\n\n".join(summaries) if summaries else "(no summary produced)"


# -------------------------
# Passage splitting / QG / QA with token truncation
# -------------------------
def split_into_passages_from_pages(pages_texts: List[str], max_words: int = 200) -> List[str]:
    all_passages = []
    for page_text in pages_texts:
        if not page_text or not page_text.strip():
            continue
        sents = safe_sentence_tokenize(page_text)
        cur = []
        cur_len = 0
        for s in sents:
            w = len(s.split())
            if cur_len + w <= max_words:
                cur.append(s)
                cur_len += w
            else:
                if cur:
                    all_passages.append(" ".join(cur).strip())
                cur = [s]
                cur_len = w
        if cur:
            all_passages.append(" ".join(cur).strip())
    return all_passages


def generate_questions_from_passage(qg_pipeline, passage: str, min_questions: int = 3) -> List[str]:
    tok = getattr(qg_pipeline, "tokenizer", None) or get_tokenizer(QG_MODEL)
    safe_passage = truncate_by_tokens(passage, tok, reserve=32)
    prompt = f"generate questions: {safe_passage}"
    try:
        out = qg_pipeline(prompt, max_length=256, do_sample=False, truncation=True)
        gen_text = out[0].get("generated_text") or out[0].get("text") or ""
    except Exception:
        gen_text = ""
    candidates = []
    if "<sep>" in gen_text:
        candidates = gen_text.split("<sep>")
    elif "\n" in gen_text:
        candidates = [line.strip() for line in gen_text.splitlines() if line.strip()]
    else:
        parts = [p.strip() for p in gen_text.split("?") if p.strip()]
        candidates = [p + "?" for p in parts]
    questions = [q.strip() for q in candidates if q.strip()]
    if len(questions) < min_questions:
        sentences = safe_sentence_tokenize(safe_passage)
        for i in range(len(sentences)):
            if len(questions) >= min_questions:
                break
            small = ". ".join(sentences[i : i + 2])
            try:
                out2 = qg_pipeline(f"generate questions: {small}", max_length=128, do_sample=False, truncation=True)
                txt = out2[0].get("generated_text") or out2[0].get("text") or ""
            except Exception:
                txt = ""
            if "<sep>" in txt:
                more = txt.split("<sep>")
            else:
                more = [l.strip() for l in txt.splitlines() if l.strip()]
            for m in more:
                if len(questions) >= min_questions:
                    break
                maybe = m.strip()
                if maybe and maybe not in questions:
                    questions.append(maybe)
    return questions[:max(min_questions, len(questions))]


def answer_questions_for_passage(qa_pipeline, passage: str, questions: List[str]) -> List[Tuple[str, str, float]]:
    results = []
    tok = getattr(qa_pipeline, "tokenizer", None) or get_tokenizer(QA_MODEL)
    for q in questions:
        try:
            safe_ctx = truncate_by_tokens(passage, tok, reserve=64)
            res = qa_pipeline(question=q, context=safe_ctx)
            answer = res.get("answer", "")
            score = float(res.get("score", 0.0))
        except Exception:
            answer = ""
            score = 0.0
        results.append((q, answer, score))
    return results


# -------------------------
# Unified extract_text_for_file: dispatch by extension
# -------------------------
def extract_text_for_file(file_bytes: bytes, filename: str, do_ocr: bool = False) -> Tuple[str, List[str]]:
    """
    Given raw bytes and a filename, return (extracted_txt_path, pages_texts).
    Supports PDF (streamed), DOCX, PPTX, TXT/MD. For images, attempt OCR if do_ocr True.
    """
    ext = ext_of_name(filename)
    if ext == ".pdf":
        return extract_text_from_pdf_streaming(file_bytes, do_ocr=do_ocr)
    if ext == ".docx":
        return extract_text_from_docx_bytes(file_bytes)
    if ext == ".pptx":
        return extract_text_from_pptx_bytes(file_bytes)
    if ext in (".txt", ".md", ".rtf"):
        return extract_text_from_txt_bytes(file_bytes)
    # images: try OCR if requested
    if ext in (".png", ".jpg", ".jpeg", ".tiff", ".bmp") and do_ocr:
        # write image bytes to temp file and run OCR with pytesseract if available
        try:
            from PIL import Image
            import pytesseract
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp.write(file_bytes)
                tmp_path = tmp.name
            img = Image.open(tmp_path)
            ocr_txt = pytesseract.image_to_string(img)
            pages = [p.strip() for p in re.split(r"\n\s*\n", ocr_txt) if p.strip()]
            tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
            with open(tmp_txt.name, "w", encoding="utf-8") as f:
                f.write("\n\n".join(pages))
            try:
                os.remove(tmp_path)
            except Exception:
                pass
            return tmp_txt.name, pages
        except Exception:
            # fallback: treat as empty
            tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
            with open(tmp_txt.name, "w", encoding="utf-8") as f:
                f.write("")
            return tmp_txt.name, [""]
    # unsupported extension: write bytes to txt and return raw decode
    tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
    try:
        s = file_bytes.decode("utf-8", errors="ignore")
    except Exception:
        s = ""
    with open(tmp_txt.name, "w", encoding="utf-8") as f:
        f.write(s)
    pages = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()]
    return tmp_txt.name, pages if pages else [s]


# -------------------------
# Main analyze function
# -------------------------
def analyze_document(file_obj, filename: str, do_ocr: bool, max_passages_for_qa: int = 5):
    """
    file_obj: uploaded object
    filename: filename string (to detect extension)
    """
    try:
        file_bytes = read_uploaded_file_to_bytes(file_obj)
    except Exception as e:
        return f"(error reading file: {e})", "", []

    extracted_txt_path, pages_texts = extract_text_for_file(file_bytes, filename, do_ocr=do_ocr)

    # preview
    preview_chars = 20000
    extracted_preview = ""
    try:
        with open(extracted_txt_path, "r", encoding="utf-8", errors="ignore") as f:
            extracted_preview = f.read(preview_chars)
        if len(extracted_preview) >= preview_chars:
            extracted_preview += "\n\n... (preview truncated) ..."
    except Exception:
        extracted_preview = "(could not read extracted text preview)"

    summarizer = get_pipeline(SUMMARIZER_MODEL, "summarization")
    combined_summary = summarize_text_chunked(summarizer, pages_texts, pages_per_chunk=8)

    all_passages = split_into_passages_from_pages(pages_texts, max_words=200)
    total = len(all_passages)
    if total == 0:
        return extracted_preview, combined_summary, []

    if total <= max_passages_for_qa:
        chosen_passages = list(enumerate(all_passages))
    else:
        step = max(1, math.floor(total / max_passages_for_qa))
        chosen_passages = [(i, all_passages[i]) for i in range(0, total, step)][:max_passages_for_qa]

    qg = get_pipeline(QG_MODEL, "text2text-generation")
    qa = get_pipeline(QA_MODEL, "question-answering")

    answered = []
    answered_set = set()
    for (p_idx, passage) in chosen_passages:
        if not passage.strip():
            continue
        questions = generate_questions_from_passage(qg, passage, min_questions=3)
        unique_questions = [q for q in questions if q not in answered_set]
        if not unique_questions:
            continue
        answers = answer_questions_for_passage(qa, passage, unique_questions)
        for q, a, score in answers:
            answered.append({"passage_idx": int(p_idx), "question": q, "answer": a, "score": float(score)})
            answered_set.add(q)

    return extracted_preview, combined_summary, answered


# -------------------------
# Gradio UI
# -------------------------
def build_demo():
    with gr.Blocks(title="Document Analysis (LLMs)") as demo:
        gr.Markdown("# Document Analysis using LLMs\nUpload a supported file (PDF, DOCX, PPTX, TXT) and get summary + Q&A.")
        with gr.Row():
            with gr.Column(scale=1):
                files_in = gr.File(label="Upload files (PDF, DOCX, PPTX, TXT, images...)", file_count="multiple")
                do_ocr = gr.Checkbox(label="Try OCR for images/PDF pages (requires OCR libs)", value=False)
                max_pass = gr.Slider(label="Max passages to run Q&A on (lower = faster)", minimum=1, maximum=20, step=1, value=5)
                run_btn = gr.Button("Analyze Document")

            with gr.Column(scale=2):
                tabs = gr.Tabs()
                with tabs:
                    with gr.TabItem("Uploaded files"):
                        uploaded_list = gr.Textbox(label="Uploaded filenames", lines=4)
                    with gr.TabItem("Extracted Text"):
                        extracted_out = gr.Textbox(label="Extracted text (preview)", lines=15)
                    with gr.TabItem("Summary"):
                        summary_out = gr.Textbox(label="Summary", lines=8)
                    with gr.TabItem("Q&A"):
                        qa_out = gr.Dataframe(headers=["passage_idx", "question", "answer", "score"],
                                              datatype=["number", "text", "text", "number"])

        def _run(files, do_ocr_val, max_pass_val):
            names = get_uploaded_filenames(files)
            uploaded_str = "\n".join(names) if names else "(no files uploaded)"

            fobj, fname = find_first_supported_file(files)
            if fobj is None or fname is None:
                return uploaded_str, "(no supported file found)", "", pd.DataFrame(columns=["passage_idx", "question", "answer", "score"])

            text, summary, qa = analyze_document(fobj, fname, do_ocr=do_ocr_val, max_passages_for_qa=int(max_pass_val))
            if not qa:
                qa_df = pd.DataFrame(columns=["passage_idx", "question", "answer", "score"])
            else:
                qa_df = pd.DataFrame(qa)
                qa_df = qa_df.loc[:, ["passage_idx", "question", "answer", "score"]]
                qa_df["passage_idx"] = qa_df["passage_idx"].astype(int)
                qa_df["question"] = qa_df["question"].astype(str)
                qa_df["answer"] = qa_df["answer"].astype(str)
                qa_df["score"] = qa_df["score"].astype(float)

            return uploaded_str, text or "(no text extracted)", summary or "(no summary)", qa_df

        run_btn.click(_run, inputs=[files_in, do_ocr, max_pass], outputs=[uploaded_list, extracted_out, summary_out, qa_out])

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
    # demo.launch(server_name="0.0.0.0")
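    # The commented-out call above binds Gradio to all network interfaces,
    # which is typically what you want when serving from Docker or a remote host.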
requirements.txt ADDED
@@ -0,0 +1,13 @@
gradio>=3.30
transformers>=4.30.0
torch>=2.0.0
pdfplumber>=0.11
pandas
tqdm
nltk>=3.8
sentencepiece
python-docx
python-pptx
Pillow
pytesseract
pdf2image
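# Note: the optional OCR path also needs system binaries that pip does not install:
# the Tesseract OCR engine (for pytesseract) and Poppler (for pdf2image).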