# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent.futures
import logging
import os
import re
from pathlib import Path
from typing import Dict, List

import fitz  # PyMuPDF
from PIL import Image
from langchain.docstore.document import Document as LangchainDocument
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.text_splitter import NLTKTextSplitter
from langchain_community.vectorstores import Chroma
from tqdm import tqdm

logger = logging.getLogger(__name__)

IMAGE_SUMMARY_PROMPT = """Summarize key findings in this image."""

class KnowledgeBase:
    """Processes a source PDF and builds a self-contained, searchable RAG knowledge base."""

    def __init__(self, models: dict, config_overrides: dict | None = None):
        """Initializes the builder with the necessary models and configuration."""
        self.embedder = models.get("embedder")
        self.ner_pipeline = models.get("ner_pipeline")
        # Optional multimodal LLM; used only when image summaries are enabled.
        # Without this, _summarize_image would reference an attribute that is never set.
        self.llm_pipeline = models.get("llm_pipeline")

        # Set the default config and apply any overrides.
        self.config = self._get_default_config()
        if config_overrides:
            self.config.update(config_overrides)

        # For consistent chunking, the RAG query path reuses the same enrichment
        # and chunking logic as the knowledge base build.
        self.document_enricher = self._enrich_documents
        self.chunker = self._create_chunks_from_documents

        self.retriever: EnsembleRetriever | None = None
        self.page_map: Dict[int, Dict] = {}
        self.source_filepath = ""

        # Create the necessary directories from config.
        Path(self.config["IMAGE_DIR"]).mkdir(parents=True, exist_ok=True)
        Path(self.config["CHROMA_PERSIST_DIR"]).mkdir(parents=True, exist_ok=True)
    def _get_default_config(self):
        """Returns the default configuration for the KnowledgeBase."""
        return {
            "IMAGE_DIR": Path("processed_figures_kb/"),
            "CHROMA_PERSIST_DIR": Path("chroma_db_store/"),
            "MEDICAL_ENTITY_TYPES_TO_EXTRACT": ["PROBLEM"],
            "EXTRACT_IMAGE_SUMMARIES": False,  # Disabled by default; requires "llm_pipeline" in models.
            "FILTER_FIRST_PAGES": 6,
            "FIGURE_MIN_WIDTH": 30,
            "FIGURE_MIN_HEIGHT": 30,
            "SENTENCE_CHUNK_SIZE": 250,
            "CHUNK_FILTER_SIZE": 20,
            "RETRIEVER_TOP_K": 20,
            "ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER": [0.2, 0.3, 0.5],
            "SENTENCE_SCORE_THRESHOLD": 0.6,
            "NER_SCORE_THRESHOLD": 0.5,
            "MAX_PARALLEL_WORKERS": 16,
        }
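
    # Illustrative override usage; the model objects are hypothetical placeholders:
    #     kb = KnowledgeBase(
    #         models={"embedder": my_embedder, "ner_pipeline": my_ner_pipeline},
    #         config_overrides={"RETRIEVER_TOP_K": 10, "MAX_PARALLEL_WORKERS": 4},
    #     )
    # Any key above can be overridden; dict.update() also carries extra keys along.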
    def build(self, pdf_filepath: str):
        """The main public method to build the knowledge base from a PDF."""
        logger.info(f"--------- Building Knowledge Base from '{pdf_filepath}' ---------")
        pdf_path = Path(pdf_filepath)
        if not pdf_path.exists():
            logger.error(f"ERROR: PDF file not found at {pdf_filepath}")
            return None
        self.source_filepath = pdf_path

        # Step 1: Process the PDF and build the structured page_map.
        self.page_map = self._process_and_structure_pdf(pdf_path)
        all_docs = [
            doc for page_data in self.page_map.values() for doc in page_data["blocks"]
        ]

        # Step 2: Enrich documents with NER metadata.
        enriched_docs = self._enrich_documents(
            all_docs, self.config.get("EXTRACT_IMAGE_SUMMARIES", False)
        )

        # Step 3: Chunk the enriched documents into final searchable units.
        final_chunks = self._create_chunks_from_documents(enriched_docs)

        # Step 4: Build the final ensemble retriever.
        self.retriever = self._build_ensemble_retriever(final_chunks)
        if self.retriever:
            logger.info("--------- Knowledge Base Built Successfully ---------")
        else:
            logger.error("--------- Knowledge Base Building Failed ---------")
        return self
    # --- Step 1: PDF Content Extraction ---
    def _process_and_structure_pdf(self, pdf_path: Path) -> dict:
        """Processes a PDF in parallel and directly builds the final page_map.

        Opens the PDF only once, then fans the pages out to a thread pool.
        """
        logger.info("Step 1: Processing PDF and building structured page map...")
        page_map = {}
        try:
            # Open the PDF once to get all preliminary info.
            with fitz.open(pdf_path) as doc:
                pdf_bytes_buffer = doc.write()
                page_count = len(doc)
                toc = doc.get_toc()

            # Build a page -> chapter lookup from the top-level TOC entries.
            page_to_chapter_id = {}
            if toc:
                chapters = [item for item in toc if item[0] == 1]
                for i, (lvl, title, start_page) in enumerate(chapters):
                    end_page = (
                        chapters[i + 1][2] - 1 if i + 1 < len(chapters) else page_count
                    )
                    for page_num in range(start_page, end_page + 1):
                        page_to_chapter_id[page_num] = i

            # Create one task tuple per page for the thread pool.
            tasks = [
                (
                    pdf_bytes_buffer,
                    i,
                    self.config,
                    pdf_path.name,
                    page_to_chapter_id,
                )
                for i in range(self.config["FILTER_FIRST_PAGES"], page_count)
            ]

            # Parallel processing.
            num_workers = min(self.config["MAX_PARALLEL_WORKERS"], os.cpu_count() or 1)
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=num_workers
            ) as executor:
                futures = [
                    executor.submit(self.process_single_page, task) for task in tasks
                ]
                progress_bar = tqdm(
                    concurrent.futures.as_completed(futures),
                    total=len(tasks),
                    desc="Processing & Structuring Pages",
                )
                for future in progress_bar:
                    result = future.result()
                    if result:
                        # Each worker returns a fully formed entry for the page_map.
                        page_map[result["page_num"]] = result["content"]
        except Exception as e:
            logger.error(f"❌ Failed to process PDF {pdf_path.name}: {e}")
            return {}

        logger.info(f"✅ PDF processed. Created a map of {len(page_map)} pages.")
        return dict(sorted(page_map.items()))
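
    # Illustrative page_map shape produced above (values abridged):
    #     {7: {"chapter_id": 0, "blocks": [LangchainDocument(...), ...]},
    #      8: {"chapter_id": 0, "blocks": [...]}}
    # Keys are 1-based page numbers (the first FILTER_FIRST_PAGES pages are
    # skipped); chapter_id is -1 when a page has no top-level TOC entry.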
    # --- Step 2: Document Enrichment ---
    def _enrich_documents(
        self, docs: List[LangchainDocument], summarize: bool = False
    ) -> List[LangchainDocument]:
        """Enriches a list of documents with NER metadata and image summaries."""
        logger.info("\nStep 2: Enriching documents...")

        # NER enrichment.
        if self.ner_pipeline:
            logger.info("Adding NER metadata...")
            for doc in tqdm(docs, desc="Enriching with NER"):
                # Skip documents that have no actual text content.
                if not doc.page_content or not doc.page_content.strip():
                    continue
                try:
                    # Process only the text of the current document, so the
                    # extracted entities unambiguously belong to this 'doc'.
                    processed_doc = self.ner_pipeline(doc.page_content)
                    entities = [
                        ent.text
                        for ent in processed_doc.ents
                        if ent.type in self.config["MEDICAL_ENTITY_TYPES_TO_EXTRACT"]
                    ]
                    # Attach the deduplicated, sorted entities to the metadata.
                    if entities:
                        unique_entities = sorted(set(entities))
                        doc.metadata["block_ner_entities"] = ", ".join(unique_entities)
                except Exception as e:
                    # A single failing block should not abort the whole build.
                    logger.warning(
                        "Warning: Could not process NER for a block on page"
                        f" {doc.metadata.get('page_number', 'N/A')}: {e}"
                    )

        # Image summary enrichment.
        if summarize:
            logger.info("Generating image summaries...")
            docs_with_figures = [
                doc for doc in docs if "linked_figure_path" in doc.metadata
            ]
            for doc in tqdm(docs_with_figures, desc="Summarizing Images"):
                try:
                    img = Image.open(doc.metadata["linked_figure_path"]).convert("RGB")
                    summary = self._summarize_image(img)
                    if summary:
                        doc.metadata["image_summary"] = summary
                except Exception as e:
                    logger.warning(
                        "Warning: Could not summarize image"
                        f" {doc.metadata.get('linked_figure_path', '')}: {e}"
                    )
        return docs
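
    # Example of an enriched block's metadata after this step (illustrative values):
    #     {"source_pdf": "atlas.pdf", "page_number": 12, "chunk_type": "text_block",
    #      "block_id": "12_3", "original_block_text": "...",
    #      "block_ner_entities": "hypertension, pleural effusion"}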
    def _summarize_image(self, pil_image: Image.Image) -> str:
        """Helper method that calls the multimodal LLM for image summarization."""
        if not self.llm_pipeline:
            return ""
        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": IMAGE_SUMMARY_PROMPT},
                {"type": "image", "image": pil_image},
            ],
        }]
        try:
            output = self.llm_pipeline(text=messages, max_new_tokens=150)
            return output[0]["generated_text"][-1]["content"].strip()
        except Exception:
            return ""
    # --- Step 3: Document Chunking ---
    def _create_chunks_from_documents(
        self, enriched_docs: List[LangchainDocument], display_results: bool = True
    ) -> List[LangchainDocument]:
        """Takes enriched documents and creates the final list of chunks for indexing.

        This method has a single responsibility: chunking.
        """
        if display_results:
            logger.info("\nStep 3: Creating final chunks...")

        # Sentence splitting.
        if display_results:
            logger.info("Applying NLTK Sentence Splitting...")
        splitter = NLTKTextSplitter(chunk_size=self.config["SENTENCE_CHUNK_SIZE"])
        sentence_chunks = splitter.split_documents(enriched_docs)
        if display_results:
            logger.info(f"Generated {len(sentence_chunks)} sentence-level chunks.")

        # NER entity chunking (based on the previously enriched metadata).
        if display_results:
            logger.info("Creating NER Entity Chunks...")
        ner_entity_chunks = [
            LangchainDocument(
                page_content=entity,
                metadata={**doc.metadata, "chunk_type": "ner_entity_standalone"},
            )
            for doc in enriched_docs
            if (entities_str := doc.metadata.get("block_ner_entities"))
            for entity in entities_str.split(", ")
            if entity
        ]
        if display_results:
            logger.info(f"Added {len(ner_entity_chunks)} NER entity chunks.")

        all_chunks = sentence_chunks + ner_entity_chunks
        return [chunk for chunk in all_chunks if chunk.page_content]
    # --- Step 4: Retriever Building ---
    def _build_ensemble_retriever(
        self, chunks: List[LangchainDocument]
    ) -> EnsembleRetriever | None:
        """Builds the final ensemble retriever from the chunks."""
        if not chunks:
            logger.error("No chunks to build retriever from.")
            return None
        logger.info("\nStep 4: Building specialized retrievers...")

        sentence_chunks = [
            doc
            for doc in chunks
            if doc.metadata.get("chunk_type") != "ner_entity_standalone"
        ]
        ner_chunks = [
            doc
            for doc in chunks
            if doc.metadata.get("chunk_type") == "ner_entity_standalone"
        ]

        retrievers, weights = [], []
        if sentence_chunks:
            bm25_retriever = BM25Retriever.from_documents(sentence_chunks)
            bm25_retriever.k = self.config["RETRIEVER_TOP_K"]
            retrievers.append(bm25_retriever)
            weights.append(self.config["ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER"][0])

            sentence_vs = Chroma.from_documents(
                documents=sentence_chunks,
                embedding=self.embedder,
                persist_directory=str(self.config["CHROMA_PERSIST_DIR"] / "sentences"),
            )
            vector_retriever = sentence_vs.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={
                    "k": self.config["RETRIEVER_TOP_K"],
                    "score_threshold": self.config["SENTENCE_SCORE_THRESHOLD"],
                },
            )
            retrievers.append(vector_retriever)
            weights.append(self.config["ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER"][1])

        if ner_chunks:
            ner_vs = Chroma.from_documents(
                documents=ner_chunks,
                embedding=self.embedder,
                persist_directory=str(self.config["CHROMA_PERSIST_DIR"] / "entities"),
            )
            ner_retriever = ner_vs.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={
                    "k": self.config["RETRIEVER_TOP_K"],
                    "score_threshold": self.config["NER_SCORE_THRESHOLD"],
                },
            )
            retrievers.append(ner_retriever)
            weights.append(self.config["ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER"][2])

        if not retrievers:
            logger.error("⚠️ Could not create any retrievers.")
            return None
        logger.info(f"Creating final ensemble with weights: {weights}")
        return EnsembleRetriever(retrievers=retrievers, weights=weights)
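
    # Minimal query sketch against the built ensemble (assumes build() succeeded;
    # the query string is a placeholder):
    #     docs = kb.retriever.invoke("findings associated with pleural effusion")
    # EnsembleRetriever fuses the BM25 and vector results with weighted
    # Reciprocal Rank Fusion, using the ENSEMBLE_WEIGHTS_BM25,SENTENCE,NER values.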
    @staticmethod
    def process_single_page(args_tuple: tuple) -> dict | None:
        """Worker function for parallel PDF processing.

        Processes one page and returns a structured dictionary for that page.
        """
        # Unpack the task tuple.
        pdf_bytes_buffer, page_num_idx, config, pdf_filename, page_to_chapter_id = (
            args_tuple
        )
        lc_documents = []
        page_num = page_num_idx + 1
        try:
            # Use a 'with' statement so the document is always closed.
            with fitz.open(stream=pdf_bytes_buffer, filetype="pdf") as doc:
                page = doc[page_num_idx]

                # 1. Extract raw, potentially fragmented text blocks.
                raw_text_blocks = page.get_text("blocks", sort=True)

                # 2. Immediately merge the blocks into paragraphs.
                paragraph_blocks = KnowledgeBase._merge_text_blocks(raw_text_blocks)

                # 3. Process figures.
                page_figures = []
                for fig_j, path_dict in enumerate(page.get_drawings()):
                    bbox = path_dict["rect"]
                    if (
                        bbox.is_empty
                        or bbox.width < config["FIGURE_MIN_WIDTH"]
                        or bbox.height < config["FIGURE_MIN_HEIGHT"]
                    ):
                        continue
                    # Pad the bounding box slightly, clamped to the page.
                    padded_bbox = bbox + (-2, -2, 2, 2)
                    padded_bbox.intersect(page.rect)
                    if padded_bbox.is_empty:
                        continue
                    pix = page.get_pixmap(clip=padded_bbox, dpi=150)
                    if pix.width > 0 and pix.height > 0:
                        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                        img_path = (
                            config["IMAGE_DIR"]
                            / f"{Path(pdf_filename).stem}_p{page_num}_fig{fig_j + 1}.png"
                        )
                        img.save(img_path)
                        page_figures.append({
                            "bbox": bbox,
                            "path": str(img_path),
                            "id": f"Figure {fig_j + 1} on {pdf_filename}, page {page_num}",
                        })

                # 4. Process the clean paragraph blocks.
                text_blocks_on_page = [
                    {
                        "bbox": fitz.Rect(x0, y0, x1, y1),
                        "text": text.strip(),
                        "original_idx": b_idx,
                    }
                    for b_idx, (x0, y0, x1, y1, text, _, _) in enumerate(
                        paragraph_blocks
                    )
                    if text.strip()
                ]

                # 5. Link captions to figures and create documents.
                potential_captions = [
                    b
                    for b in text_blocks_on_page
                    if re.match(r"^\s*Figure\s*\d+", b["text"], re.I)
                ]
                mapped_caption_indices = set()
                for fig_data in page_figures:
                    cap_text, cap_idx = KnowledgeBase.find_best_caption_for_figure(
                        fig_data["bbox"], potential_captions
                    )
                    if cap_text and cap_idx not in mapped_caption_indices:
                        mapped_caption_indices.add(cap_idx)
                        metadata = {
                            "source_pdf": pdf_filename,
                            "page_number": page_num,
                            "chunk_type": "figure-caption",
                            "linked_figure_path": fig_data["path"],
                            "linked_figure_id": fig_data["id"],
                            "block_id": f"{page_num}_{cap_idx}",
                            "original_block_text": cap_text,
                        }
                        lc_documents.append(
                            LangchainDocument(page_content=cap_text, metadata=metadata)
                        )

                for block_data in text_blocks_on_page:
                    if block_data["original_idx"] in mapped_caption_indices:
                        continue
                    if KnowledgeBase.should_filter_text_block(
                        block_data["text"],
                        block_data["bbox"],
                        page.rect.height,
                        config["CHUNK_FILTER_SIZE"],
                    ):
                        continue
                    metadata = {
                        "source_pdf": pdf_filename,
                        "page_number": page_num,
                        "chunk_type": "text_block",
                        "block_id": f"{page_num}_{block_data['original_idx']}",
                        "original_block_text": block_data["text"],
                    }
                    lc_documents.append(
                        LangchainDocument(
                            page_content=block_data["text"], metadata=metadata
                        )
                    )
        except Exception as e:
            logger.error(f"Error processing {pdf_filename} page {page_num}: {e}")
            return None

        if not lc_documents:
            return None

        # Structure the final output, sorted by block index within the page.
        lc_documents.sort(
            key=lambda d: int(d.metadata.get("block_id", "0_0").split("_")[-1])
        )
        return {
            "page_num": page_num,
            "content": {
                "chapter_id": page_to_chapter_id.get(page_num, -1),
                "blocks": lc_documents,
            },
        }
    @staticmethod
    def _merge_text_blocks(blocks: list) -> list:
        """Merges fragmented text blocks into coherent paragraphs."""
        if not blocks:
            return []
        merged_blocks = []
        current_text = ""
        current_bbox = fitz.Rect()
        sentence_enders = {".", "?", "!", "•"}
        for i, block in enumerate(blocks):
            block_text = block[4].strip()
            if not current_text:  # Starting a new paragraph.
                current_bbox = fitz.Rect(block[:4])
                current_text = block_text
            else:  # Continue the existing paragraph.
                current_bbox.include_rect(block[:4])
                current_text = f"{current_text} {block_text}"
            is_last_block = i == len(blocks) - 1
            ends_with_punctuation = block_text.endswith(tuple(sentence_enders))
            if ends_with_punctuation or is_last_block:
                merged_blocks.append((
                    current_bbox.x0,
                    current_bbox.y0,
                    current_bbox.x1,
                    current_bbox.y1,
                    current_text,
                    len(merged_blocks),
                    0,
                ))
                current_text = ""
        return merged_blocks
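
    # Worked example: the fragments ("The patient presented", "with fever.") merge
    # into a single block whose text is "The patient presented with fever." and
    # whose bbox is the union of both fragments, because only the second fragment
    # ends with a sentence-ending character.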
    @staticmethod
    def should_filter_text_block(
        block_text: str,
        block_bbox: fitz.Rect,
        page_height: float,
        filter_size: int,
    ) -> bool:
        """Determines whether a short header/footer text block should be filtered out."""
        is_in_header_area = block_bbox.y0 < (page_height * 0.10)
        is_in_footer_area = block_bbox.y1 > (page_height * 0.80)
        is_short_text = len(block_text) < filter_size
        return (is_in_header_area or is_in_footer_area) and is_short_text
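
    # E.g. with page_height = 800 and filter_size = 20 (the CHUNK_FILTER_SIZE
    # default), a 10-character page header at y0 = 50 falls in the top 10%
    # (y0 < 80) and is filtered, while a 200-character paragraph at the same
    # position is kept.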
    @staticmethod
    def find_best_caption_for_figure(
        figure_bbox: fitz.Rect, potential_captions_on_page: list
    ) -> tuple:
        """Finds the best caption for a figure based on proximity and alignment."""
        best_caption_info = (None, -1)
        min_score = float("inf")
        for cap_info in potential_captions_on_page:
            cap_bbox = cap_info["bbox"]
            # Heuristic: only consider captions that start at or below the bottom
            # of the figure (with a 10 pt tolerance).
            if cap_bbox.y0 >= figure_bbox.y1 - 10:
                vertical_dist = cap_bbox.y0 - figure_bbox.y1
                # Require some horizontal overlap between figure and caption.
                overlap_x_start = max(figure_bbox.x0, cap_bbox.x0)
                overlap_x_end = min(figure_bbox.x1, cap_bbox.x1)
                if (overlap_x_end - overlap_x_start) > 0:
                    fig_center_x = (figure_bbox.x0 + figure_bbox.x1) / 2
                    cap_center_x = (cap_bbox.x0 + cap_bbox.x1) / 2
                    horizontal_center_dist = abs(fig_center_x - cap_center_x)
                    # Score combines the vertical gap and horizontal misalignment.
                    score = vertical_dist + (horizontal_center_dist * 0.5)
                    if score < min_score:
                        min_score = score
                        best_caption_info = (cap_info["text"], cap_info["original_idx"])
        return best_caption_info
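

# A minimal end-to-end sketch, not part of the library. The embedding model and
# NER pipeline below are assumptions: any LangChain-compatible embeddings work,
# and the NER pipeline must return an object exposing `.ents` with `.text` and
# `.type` attributes (e.g. a Stanza clinical pipeline, whose PROBLEM label
# matches MEDICAL_ENTITY_TYPES_TO_EXTRACT). "example.pdf" is a placeholder.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    from langchain_community.embeddings import HuggingFaceEmbeddings  # assumed available

    models = {
        "embedder": HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"  # hypothetical choice
        ),
        "ner_pipeline": None,  # plug in a clinical NER pipeline here; None skips NER
    }
    kb = KnowledgeBase(models).build("example.pdf")
    if kb and kb.retriever:
        for doc in kb.retriever.invoke("sample clinical query"):
            print(doc.metadata.get("page_number"), "|", doc.page_content[:80])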