"""
app.py

Document Analysis Gradio app — updated to support PDF, DOCX (Word), and PPTX (PowerPoint).

- robust file reading
- streaming PDF extraction
- docx/pptx extraction to pages_texts (one element per paragraph/slide)
- token-aware truncation, chunked summarization, sampled Q&A
- multi-file upload UI (processes first supported file)
"""

import os
# disable noisy HF symlink warning on Windows
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")

import re
import io
import math
import tempfile
import threading
from pathlib import Path
from typing import List, Tuple, Optional

import gradio as gr
import pdfplumber
import nltk
from nltk.tokenize import sent_tokenize
from transformers import pipeline, AutoTokenizer
import torch
import pandas as pd
from tqdm.auto import tqdm

# Try to import docx and pptx; if not present, we'll handle gracefully at runtime.
try:
    from docx import Document as DocxDocument  # python-docx
except Exception:
    DocxDocument = None

try:
    from pptx import Presentation as PptxPresentation  # python-pptx
except Exception:
    PptxPresentation = None

# -------------------------
# NLTK: ensure punkt available, fallback later
# -------------------------
try:
    nltk.download("punkt", quiet=True)
    try:
        nltk.download("punkt_tab", quiet=True)
    except Exception:
        pass
except Exception:
    pass
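# Note: newer NLTK releases look for the "punkt_tab" data package in addition to
# "punkt"; downloading both keeps sent_tokenize working across NLTK versions, and
# safe_sentence_tokenize() below falls back to a regex splitter if neither loads.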

# -------------------------
# Device detection
# -------------------------
DEVICE = 0 if torch.cuda.is_available() else -1
print("Device set to use", "cuda" if DEVICE >= 0 else "cpu")
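# (For Hugging Face pipelines, device=0 selects the first CUDA GPU and device=-1
# forces CPU execution.)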

# -------------------------
# Cached pipelines and tokenizers
# -------------------------
_models = {}
_tokenizers = {}
_models_lock = threading.Lock()


def get_pipeline(name: str, task: str):
    """Return a cached HF pipeline for given task and model name."""
    key = f"{task}__{name}"
    with _models_lock:
        if key in _models:
            return _models[key]
        print(f"Loading pipeline: task={task}, model={name} ... (this may take a while on first run)")
        p = pipeline(task, model=name, device=DEVICE)
        _models[key] = p
        try:
            _tokenizers[name] = p.tokenizer
        except Exception:
            pass
        return p


def get_tokenizer(name: str):
    """Return a cached tokenizer (fallback to AutoTokenizer if not present)."""
    if name in _tokenizers:
        return _tokenizers[name]
    try:
        tok = AutoTokenizer.from_pretrained(name)
        _tokenizers[name] = tok
        return tok
    except Exception:
        return None
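
# Example (illustrative, not executed at import time): the cached helpers are
# meant to be used together, e.g.
#
#   summarizer = get_pipeline(SUMMARIZER_MODEL, "summarization")
#   tok = get_tokenizer(SUMMARIZER_MODEL)
#
# Repeated calls with the same (task, model) pair reuse the already-loaded
# pipeline, so the expensive model load happens only once per process.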


# -------------------------
# Default models (adjusted for speed; see the note below for higher-quality alternatives)
# -------------------------
SUMMARIZER_MODEL = "t5-small"  # This is already fast
QG_MODEL = "valhalla/t5-small-qg-hl" # Use the 'small' version
QA_MODEL = "distilbert-base-cased-distilled-squad" # Much faster than RoBERTa
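# If quality matters more than speed, larger public checkpoints such as
# "facebook/bart-large-cnn" (summarization) or "deepset/roberta-base-squad2"
# (extractive QA) should work as drop-in replacements here, at the cost of much
# longer download and inference times.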


# -------------------------
# Helpers: filenames / types
# -------------------------
SUPPORTED_EXTS = [".pdf", ".docx", ".pptx", ".txt", ".md", ".rtf", ".png", ".jpg", ".jpeg", ".tiff", ".bmp"]
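# Note: .txt/.md/.rtf are read as plain text (RTF markup is not parsed), and the
# image formats only yield text when the OCR checkbox is enabled.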

def ext_of_name(name: str) -> str:
    return Path(name).suffix.lower()


def read_uploaded_file_to_bytes(file_obj):
    """
    Accept many shapes of Gradio file objects and return bytes.

    Supports list/tuple (returns first readable file), dict-like, file-likes, paths, bytes.
    """
    # If a list/tuple of uploaded files, try each candidate
    if isinstance(file_obj, (list, tuple)):
        last_err = None
        for elem in file_obj:
            try:
                return read_uploaded_file_to_bytes(elem)
            except Exception as e:
                last_err = e
                continue
        raise ValueError(f"No readable file in list. Last error: {last_err}")

    if file_obj is None:
        raise ValueError("No file provided")

    # if it's already bytes
    if isinstance(file_obj, (bytes, bytearray)):
        return bytes(file_obj)

    # dict-like
    if isinstance(file_obj, dict):
        for key in ("file", "tmp_path", "name", "data", "path"):
            val = file_obj.get(key)
            if isinstance(val, (bytes, bytearray)):
                return bytes(val)
            if isinstance(val, str) and Path(val).exists():
                return Path(val).read_bytes()
        maybe = file_obj.get("file")
        if hasattr(maybe, "read"):
            data = maybe.read()
            if isinstance(data, str):
                return data.encode("utf-8")
            return data

    # string path
    if isinstance(file_obj, str) and Path(file_obj).exists():
        return Path(file_obj).read_bytes()

    # has a .name attribute that points to a file
    if hasattr(file_obj, "name") and isinstance(getattr(file_obj, "name"), str) and Path(file_obj.name).exists():
        try:
            return Path(file_obj.name).read_bytes()
        except Exception:
            pass

    # file-like with .read()
    if hasattr(file_obj, "read"):
        try:
            data = file_obj.read()
            if isinstance(data, str):
                return data.encode("utf-8")
            return data
        except Exception:
            pass

    # last resort: string representation -> path
    try:
        s = str(file_obj)
        if Path(s).exists():
            return Path(s).read_bytes()
    except Exception:
        pass

    raise ValueError(f"Unsupported uploaded file object type: {type(file_obj)}")
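
# Example (illustrative, hypothetical path): read_uploaded_file_to_bytes("/tmp/report.pdf")
# returns that file's raw bytes; passing a list of Gradio upload objects returns the
# bytes of the first entry that can be read.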


def get_uploaded_filenames(file_obj) -> List[str]:
    """Return a list of human-friendly filenames from uploaded file object(s)."""
    names = []
    if file_obj is None:
        return names
    if isinstance(file_obj, (list, tuple)):
        for elem in file_obj:
            names.extend(get_uploaded_filenames(elem))
        return names
    if isinstance(file_obj, dict):
        for key in ("name", "filename", "file", "tmp_path", "path"):
            if key in file_obj:
                val = file_obj.get(key)
                if isinstance(val, str):
                    names.append(Path(val).name)
                elif hasattr(val, "name"):
                    names.append(Path(val.name).name)
        maybe = file_obj.get("file")
        if maybe is not None:
            if hasattr(maybe, "name"):
                names.append(Path(maybe.name).name)
        return names
    if isinstance(file_obj, str):
        return [Path(file_obj).name]
    if hasattr(file_obj, "name"):
        return [Path(getattr(file_obj, "name")).name]
    return [str(file_obj)]


def find_first_supported_file(files) -> Tuple[Optional[object], Optional[str]]:
    """
    From the uploaded list or single file-like object, find the first file with a supported extension.

    Returns (file_obj, filename) or (None, None) if none supported.
    """
    if not files:
        return None, None
    candidates = []
    if isinstance(files, (list, tuple)):
        for f in files:
            names = get_uploaded_filenames(f)
            for n in names:
                candidates.append((f, n))
    else:
        names = get_uploaded_filenames(files)
        for n in names:
            candidates.append((files, n))
    for fobj, name in candidates:
        ext = ext_of_name(name)
        if ext in SUPPORTED_EXTS:
            return fobj, name
    # fallback: if none matched, return first uploaded
    if candidates:
        return candidates[0]
    return None, None


# -------------------------
# PDF extraction (streaming)
# -------------------------
def extract_text_from_pdf_streaming(file_bytes: bytes, do_ocr: bool = False, extracted_txt_path: Optional[str] = None) -> Tuple[str, List[str]]:
    """Write PDF bytes to temp file and extract each page's text; returns extracted_txt_path, pages_texts list"""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_pdf:
        tmp_pdf.write(file_bytes)
        tmp_pdf_path = tmp_pdf.name

    if extracted_txt_path is None:
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        extracted_txt_path = tmp_txt.name
        tmp_txt.close()

    pages_texts = []
    try:
        with pdfplumber.open(tmp_pdf_path) as pdf:
            for page in pdf.pages:
                page_text = page.extract_text()
                if (not page_text or page_text.strip() == "") and do_ocr:
                    try:
                        from pdf2image import convert_from_path
                        import pytesseract
                        images = convert_from_path(tmp_pdf_path, first_page=page.page_number, last_page=page.page_number)
                        if images:
                            page_text = pytesseract.image_to_string(images[0])
                    except Exception:
                        page_text = ""
                if page_text is None:
                    page_text = ""
                with open(extracted_txt_path, "a", encoding="utf-8") as fout:
                    fout.write(page_text)
                    fout.write("\n\n---PAGE_BREAK---\n\n")
                pages_texts.append(page_text)
    finally:
        try:
            os.remove(tmp_pdf_path)
        except Exception:
            pass

    return extracted_txt_path, pages_texts


# -------------------------
# DOCX extraction
# -------------------------
def extract_text_from_docx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    """
    Returns (tmp_text_path, pages_texts) where pages_texts is a list of paragraph groups.

    Paragraphs are grouped into ~200-word passages, so each element behaves like a small 'page'.
    """
    if DocxDocument is None:
        raise RuntimeError("python-docx not installed. Install with `pip install python-docx`")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name

    pages_texts = []
    try:
        doc = DocxDocument(tmp_path)
        # Collect non-empty paragraphs
        paras = [p.text.strip() for p in doc.paragraphs if p.text and p.text.strip()]
        # Group paragraphs into passages of ~200 words
        current = []
        cur_words = 0
        for p in paras:
            w = len(p.split())
            if cur_words + w <= 200:
                current.append(p)
                cur_words += w
            else:
                pages_texts.append(" ".join(current).strip())
                current = [p]
                cur_words = w
        if current:
            pages_texts.append(" ".join(current).strip())
        # write full extracted text to tmp file for download/debug if needed
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        tmp_txt.close()  # close the handle before reopening the path below
        with open(tmp_txt.name, "w", encoding="utf-8") as f:
            f.write("\n\n".join(pages_texts))
        txt_path = tmp_txt.name
    finally:
        try:
            os.remove(tmp_path)
        except Exception:
            pass

    return txt_path, pages_texts


# -------------------------
# PPTX extraction
# -------------------------
def extract_text_from_pptx_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    """Returns (tmp_text_path, pages_texts) where each slide's text is one element."""
    if PptxPresentation is None:
        raise RuntimeError("python-pptx not installed. Install with `pip install python-pptx`")

    with tempfile.NamedTemporaryFile(delete=False, suffix=".pptx") as tmp:
        tmp.write(file_bytes)
        tmp_path = tmp.name

    pages_texts = []
    try:
        prs = PptxPresentation(tmp_path)
        for slide in prs.slides:
            texts = []
            for shape in slide.shapes:
                try:
                    if hasattr(shape, "text") and shape.text:
                        texts.append(shape.text.strip())
                except Exception:
                    continue
            slide_text = "\n".join([t for t in texts if t])
            pages_texts.append(slide_text)
        tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
        tmp_txt.close()  # close the handle before reopening the path below
        with open(tmp_txt.name, "w", encoding="utf-8") as f:
            f.write("\n\n".join(pages_texts))
        txt_path = tmp_txt.name
    finally:
        try:
            os.remove(tmp_path)
        except Exception:
            pass

    return txt_path, pages_texts


# -------------------------
# TXT/MD extraction
# -------------------------
def extract_text_from_txt_bytes(file_bytes: bytes) -> Tuple[str, List[str]]:
    s = file_bytes.decode("utf-8", errors="ignore")
    # split on blank lines, then group paragraphs into ~200-word passages
    paras = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()]
    pages_texts = []
    cur = []
    cur_words = 0
    for p in paras:
        w = len(p.split())
        if cur_words + w <= 200:
            cur.append(p)
            cur_words += w
        else:
            pages_texts.append(" ".join(cur).strip())
            cur = [p]
            cur_words = w
    if cur:
        pages_texts.append(" ".join(cur).strip())
    tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
    tmp_txt.close()  # close the handle before reopening the path below
    with open(tmp_txt.name, "w", encoding="utf-8") as f:
        f.write("\n\n".join(pages_texts))
    return tmp_txt.name, pages_texts


# -------------------------
# Token-aware truncation helpers
# -------------------------
def truncate_by_tokens(text: str, tokenizer, reserve: int = 64) -> str:
    if not text:
        return text
    if tokenizer is None:
        return text if len(text) <= 3000 else text[:3000]
    try:
        ids = tokenizer.encode(text, add_special_tokens=False)
        max_len = getattr(tokenizer, "model_max_length", 512)
        allowed = max(1, max_len - reserve)
        if len(ids) > allowed:
            ids = ids[:allowed]
            return tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return text
    except Exception:
        return text if len(text) <= 3000 else text[:3000]
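
# Rough sketch of the effect (assuming a tokenizer whose model_max_length is 512):
# a 10,000-token passage is trimmed to 512 - 64 = 448 tokens (using the default
# reserve of 64) and decoded back to text, which keeps downstream pipeline calls
# within the model's context window. Without a tokenizer the function falls back
# to a crude 3,000-character cap.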


# -------------------------
# Sentence tokenization fallback
# -------------------------
def safe_sentence_tokenize(text: str) -> List[str]:
    try:
        sents = sent_tokenize(text)
        if isinstance(sents, list) and len(sents) > 0:
            return sents
    except Exception:
        pass
    pieces = re.split(r"(?<=[.!?])\s+", text.strip())
    return [p.strip() for p in pieces if p.strip()]


# -------------------------
# Summarization (chunked, token-truncated)
# -------------------------
def summarize_text_chunked(summarizer, pages_texts: List[str], pages_per_chunk: int = 8) -> str:
    if not pages_texts:
        return "(no text)"
    summaries = []
    tokenizer = getattr(summarizer, "tokenizer", None) or get_tokenizer(SUMMARIZER_MODEL)
    num_pages = len(pages_texts)
    for i in range(0, num_pages, pages_per_chunk):
        chunk_pages = pages_texts[i : i + pages_per_chunk]
        chunk_text = "\n\n".join([p for p in chunk_pages if p.strip()])
        if not chunk_text.strip():
            continue
        safe_chunk = truncate_by_tokens(chunk_text, tokenizer, reserve=64)
        try:
            out = summarizer(safe_chunk, max_length=150, min_length=30, do_sample=False, truncation=True)
            summaries.append(out[0]["summary_text"])
        except Exception:
            summaries.append(safe_chunk[:800])
    return "\n\n".join(summaries) if summaries else "(no summary produced)"
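
# Example of the chunking arithmetic (illustrative): a 40-page document with
# pages_per_chunk=8 results in ceil(40 / 8) = 5 summarizer calls, and the five
# partial summaries are joined with blank lines into the returned summary string.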


# -------------------------
# Passage splitting / QG / QA with token truncation
# -------------------------
def split_into_passages_from_pages(pages_texts: List[str], max_words: int = 200) -> List[str]:
    all_passages = []
    for page_text in pages_texts:
        if not page_text or not page_text.strip():
            continue
        sents = safe_sentence_tokenize(page_text)
        cur = []
        cur_len = 0
        for s in sents:
            w = len(s.split())
            if cur_len + w <= max_words:
                cur.append(s)
                cur_len += w
            else:
                if cur:
                    all_passages.append(" ".join(cur).strip())
                cur = [s]
                cur_len = w
        if cur:
            all_passages.append(" ".join(cur).strip())
    return all_passages
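
# Worked example (illustrative): sentences of 120, 90 and 60 words with
# max_words=200 produce two passages: the 120-word sentence on its own (adding
# the next sentence would exceed the cap), then the 90- and 60-word sentences
# together.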


def generate_questions_from_passage(qg_pipeline, passage: str, min_questions: int = 3) -> List[str]:
    tok = getattr(qg_pipeline, "tokenizer", None) or get_tokenizer(QG_MODEL)
    safe_passage = truncate_by_tokens(passage, tok, reserve=32)
    prompt = f"generate questions: {safe_passage}"
    try:
        out = qg_pipeline(prompt, max_length=256, do_sample=False, truncation=True)
        gen_text = out[0].get("generated_text") or out[0].get("text") or ""
    except Exception:
        gen_text = ""
    candidates = []
    if "<sep>" in gen_text:
        candidates = gen_text.split("<sep>")
    elif "\n" in gen_text:
        candidates = [line.strip() for line in gen_text.splitlines() if line.strip()]
    else:
        parts = [p.strip() for p in gen_text.split("?") if p.strip()]
        candidates = [p + "?" for p in parts]
    questions = [q.strip() for q in candidates if q.strip()]
    if len(questions) < min_questions:
        sentences = safe_sentence_tokenize(safe_passage)
        for i in range(len(sentences)):
            if len(questions) >= min_questions:
                break
            small = ". ".join(sentences[i : i + 2])
            try:
                out2 = qg_pipeline(f"generate questions: {small}", max_length=128, do_sample=False, truncation=True)
                txt = out2[0].get("generated_text") or out2[0].get("text") or ""
            except Exception:
                txt = ""
            if "<sep>" in txt:
                more = txt.split("<sep>")
            else:
                more = [l.strip() for l in txt.splitlines() if l.strip()]
            for m in more:
                if len(questions) >= min_questions:
                    break
                maybe = m.strip()
                if maybe and maybe not in questions:
                    questions.append(maybe)
    return questions[:max(min_questions, len(questions))]
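
# Parsing note: some question-generation checkpoints (e.g. the valhalla e2e
# variants) emit several questions separated by "<sep>", which the code above
# prefers; newline and "?" splitting are only fallbacks for models that format
# their output differently.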


def answer_questions_for_passage(qa_pipeline, passage: str, questions: List[str]) -> List[Tuple[str, str, float]]:
    results = []
    tok = getattr(qa_pipeline, "tokenizer", None) or get_tokenizer(QA_MODEL)
    for q in questions:
        try:
            safe_ctx = truncate_by_tokens(passage, tok, reserve=64)
            res = qa_pipeline(question=q, context=safe_ctx)
            answer = res.get("answer", "")
            score = float(res.get("score", 0.0))
        except Exception:
            answer = ""
            score = 0.0
        results.append((q, answer, score))
    return results


# -------------------------
# Unified extract_text_for_file: dispatch by extension
# -------------------------
def extract_text_for_file(file_bytes: bytes, filename: str, do_ocr: bool = False) -> Tuple[str, List[str]]:
    """
    Given raw bytes and a filename, return (extracted_txt_path, pages_texts).

    Supports PDF (streamed), DOCX, PPTX, TXT/MD. For images, OCR is attempted if do_ocr is True.
    """
    ext = ext_of_name(filename)
    if ext == ".pdf":
        return extract_text_from_pdf_streaming(file_bytes, do_ocr=do_ocr)
    if ext == ".docx":
        return extract_text_from_docx_bytes(file_bytes)
    if ext == ".pptx":
        return extract_text_from_pptx_bytes(file_bytes)
    if ext in (".txt", ".md", ".rtf"):
        return extract_text_from_txt_bytes(file_bytes)
    # images: try OCR if requested
    if ext in (".png", ".jpg", ".jpeg", ".tiff", ".bmp") and do_ocr:
        # write image bytes to temp file and run OCR with pytesseract if available
        try:
            from PIL import Image
            import pytesseract
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
                tmp.write(file_bytes)
                tmp_path = tmp.name
            img = Image.open(tmp_path)
            ocr_txt = pytesseract.image_to_string(img)
            pages = [p.strip() for p in re.split(r"\n\s*\n", ocr_txt) if p.strip()]
            tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
            with open(tmp_txt.name, "w", encoding="utf-8") as f:
                f.write("\n\n".join(pages))
            try:
                os.remove(tmp_path)
            except Exception:
                pass
            return tmp_txt.name, pages
        except Exception:
            # fallback: treat as empty
            tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
            with open(tmp_txt.name, "w", encoding="utf-8") as f:
                f.write("")
            return tmp_txt.name, [""]
    # unsupported extension: write bytes to txt and return raw decode
    tmp_txt = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w+", encoding="utf-8")
    try:
        s = file_bytes.decode("utf-8", errors="ignore")
    except Exception:
        s = ""
    with open(tmp_txt.name, "w", encoding="utf-8") as f:
        f.write(s)
    pages = [p.strip() for p in re.split(r"\n\s*\n", s) if p.strip()]
    return tmp_txt.name, pages if pages else [s]
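
# Example (illustrative, hypothetical filename):
#   extract_text_for_file(Path("slides.pptx").read_bytes(), "slides.pptx")
# writes the full extracted text to a temp .txt file and returns its path plus a
# pages_texts list with one entry per slide; a .pdf filename would instead stream
# pages through pdfplumber.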


# -------------------------
# Main analyze function
# -------------------------
def analyze_document(file_obj, filename: str, do_ocr: bool, max_passages_for_qa: int = 5):
    """
    file_obj: uploaded object
    filename: filename string (used to detect the extension)
    do_ocr: whether to attempt OCR for image files / empty PDF pages
    max_passages_for_qa: cap on how many passages are sampled for Q&A
    """
    try:
        file_bytes = read_uploaded_file_to_bytes(file_obj)
    except Exception as e:
        return f"(error reading file: {e})", "", []

    extracted_txt_path, pages_texts = extract_text_for_file(file_bytes, filename, do_ocr=do_ocr)

    # preview
    preview_chars = 20000
    extracted_preview = ""
    try:
        with open(extracted_txt_path, "r", encoding="utf-8", errors="ignore") as f:
            extracted_preview = f.read(preview_chars)
            if len(extracted_preview) >= preview_chars:
                extracted_preview += "\n\n... (preview truncated) ..."
    except Exception:
        extracted_preview = "(could not read extracted text preview)"

    summarizer = get_pipeline(SUMMARIZER_MODEL, "summarization")
    combined_summary = summarize_text_chunked(summarizer, pages_texts, pages_per_chunk=8)

    all_passages = split_into_passages_from_pages(pages_texts, max_words=200)
    total = len(all_passages)
    if total == 0:
        return extracted_preview, combined_summary, []

    if total <= max_passages_for_qa:
        chosen_passages = list(enumerate(all_passages))
    else:
        step = max(1, math.floor(total / max_passages_for_qa))
        chosen_passages = [(i, all_passages[i]) for i in range(0, total, step)][:max_passages_for_qa]

    qg = get_pipeline(QG_MODEL, "text2text-generation")
    qa = get_pipeline(QA_MODEL, "question-answering")

    answered = []
    answered_set = set()
    for (p_idx, passage) in chosen_passages:
        if not passage.strip():
            continue
        questions = generate_questions_from_passage(qg, passage, min_questions=3)
        unique_questions = [q for q in questions if q not in answered_set]
        if not unique_questions:
            continue
        answers = answer_questions_for_passage(qa, passage, unique_questions)
        for q, a, score in answers:
            answered.append({"passage_idx": int(p_idx), "question": q, "answer": a, "score": float(score)})
            answered_set.add(q)

    return extracted_preview, combined_summary, answered
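
# analyze_document returns three values: a preview of the extracted text, the
# combined summary, and a list of dicts shaped like
#   {"passage_idx": 3, "question": "...", "answer": "...", "score": 0.87}
# (values illustrative), which _run() below turns into a pandas DataFrame for the
# Q&A tab.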


# -------------------------
# Gradio UI
# -------------------------
def build_demo():
    with gr.Blocks(title="Document Analysis (LLMs)") as demo:
        gr.Markdown("# Document Analysis using LLMs\nUpload a supported file (PDF, DOCX, PPTX, TXT) and get summary + Q&A.")
        with gr.Row():
            with gr.Column(scale=1):
                files_in = gr.File(label="Upload files (PDF, DOCX, PPTX, TXT, images...)", file_count="multiple")
                do_ocr = gr.Checkbox(label="Try OCR for images/PDF pages (requires OCR libs)", value=False)
                max_pass = gr.Slider(label="Max passages to run Q&A on (lower = faster)", minimum=1, maximum=20, step=1, value=5)
                run_btn = gr.Button("Analyze Document")

            with gr.Column(scale=2):
                tabs = gr.Tabs()
                with tabs:
                    with gr.TabItem("Uploaded files"):
                        uploaded_list = gr.Textbox(label="Uploaded filenames", lines=4)
                    with gr.TabItem("Extracted Text"):
                        extracted_out = gr.Textbox(label="Extracted text (preview)", lines=15)
                    with gr.TabItem("Summary"):
                        summary_out = gr.Textbox(label="Summary", lines=8)
                    with gr.TabItem("Q&A"):
                        qa_out = gr.Dataframe(headers=["passage_idx", "question", "answer", "score"],
                                              datatype=["number", "text", "text", "number"])

        def _run(files, do_ocr_val, max_pass_val):
            names = get_uploaded_filenames(files)
            uploaded_str = "\n".join(names) if names else "(no files uploaded)"

            fobj, fname = find_first_supported_file(files)
            if fobj is None or fname is None:
                return uploaded_str, "(no supported file found)", "", pd.DataFrame(columns=["passage_idx", "question", "answer", "score"])

            text, summary, qa = analyze_document(fobj, fname, do_ocr=do_ocr_val, max_passages_for_qa=int(max_pass_val))
            if not qa:
                qa_df = pd.DataFrame(columns=["passage_idx", "question", "answer", "score"])
            else:
                qa_df = pd.DataFrame(qa)
                qa_df = qa_df.loc[:, ["passage_idx", "question", "answer", "score"]]
                qa_df["passage_idx"] = qa_df["passage_idx"].astype(int)
                qa_df["question"] = qa_df["question"].astype(str)
                qa_df["answer"] = qa_df["answer"].astype(str)
                qa_df["score"] = qa_df["score"].astype(float)

            return uploaded_str, text or "(no text extracted)", summary or "(no summary)", qa_df

        run_btn.click(_run, inputs=[files_in, do_ocr, max_pass], outputs=[uploaded_list, extracted_out, summary_out, qa_out])

    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()
    # demo.launch(server_name="0.0.0.0")
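    # Standard Gradio options such as share=True (temporary public URL) or
    # server_name="0.0.0.0" (listen on all interfaces) can be passed to launch()
    # if you need remote access.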