# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import List

from PIL import Image
from langchain.docstore.document import Document as LangchainDocument

from .knowledge_base import KnowledgeBase

logger = logging.getLogger(__name__)

def format_context_messages_to_string(context_messages: list[dict]) -> str:
    """Takes a list of context message dicts and formats them into a single string."""
    if not context_messages:
        return "No relevant context was retrieved from the guideline document."
    full_text = [
        msg.get("text", "") for msg in context_messages if msg.get("type") == "text"
    ]
    return "\n".join(full_text)

class RAGContextEngine:
    """Uses a pre-built KnowledgeBase to retrieve and format context for queries."""

    def __init__(self, knowledge_base: KnowledgeBase, config_overrides: dict | None = None):
        if not isinstance(knowledge_base, KnowledgeBase) or not knowledge_base.retriever:
            raise ValueError("An initialized KnowledgeBase with a built retriever is required.")
        self.kb = knowledge_base
        self.config = self._get_default_config()
        if config_overrides:
            self.config.update(config_overrides)

    def _get_default_config(self):
        return {
            "FINAL_CONTEXT_TOP_K": 5,
            "CONTEXT_SELECTION_STRATEGY": "chapter_aware_window_expansion",
            "CONTEXT_WINDOW_SIZE": 0,
            "ADD_MAPPED_FIGURES_TO_PROMPT": False,
        }
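
    # The defaults above can be overridden per engine, e.g. (illustrative):
    #   engine = RAGContextEngine(kb, config_overrides={"CONTEXT_WINDOW_SIZE": 1})
    # FINAL_CONTEXT_TOP_K caps how many blocks/pages survive final selection;
    # CONTEXT_SELECTION_STRATEGY picks one of the branches in
    # select_final_context below; CONTEXT_WINDOW_SIZE is how many neighbouring
    # pages (within the same chapter) to pull in around each hit; and
    # ADD_MAPPED_FIGURES_TO_PROMPT toggles inlining linked figure images.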

    def get_context_messages(self, query_text: str) -> list[dict] | None:
        """Public API to get final, formatted context messages for a long query."""
        final_context_docs = self.retrieve_context_docs(query_text)
        if not final_context_docs:
            logger.warning(f"No relevant context found for query: {query_text}")
            return None
        context_messages, _ = self.build_context_messages(final_context_docs)
        return context_messages

    def retrieve_context_docs(self, query_text: str) -> list:
        """Handles both short and long queries to retrieve context documents."""
        logger.info(f"Retrieving context documents with query: {query_text}")
        if len(query_text.split()) > 5:
            logger.info("Long query detected. Decomposing into sub-queries...")
            temp_doc = LangchainDocument(page_content=query_text)
            enriched_temp_docs = self.kb.document_enricher([temp_doc], summarize=False)
            query_chunks_as_docs = self.kb.chunker(enriched_docs=enriched_temp_docs, display_results=False)
            # Deduplicate while preserving chunk order (a bare set() would not).
            sub_queries = list(dict.fromkeys(doc.page_content for doc in query_chunks_as_docs))
        else:
            logger.info("Short query detected. Using direct retrieval.")
            sub_queries = [query_text]
        return self.retrieve_context_docs_for_simple_queries(sub_queries)
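
    # Note: "long" queries (more than five words) are decomposed by pushing the
    # query text through the KnowledgeBase's own enrichment/chunking pipeline;
    # each resulting chunk is retrieved independently and the hits are pooled
    # before final selection in select_final_context.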

    def get_context_messages_for_simple_queries(self, queries: list[str]) -> list:
        """Retrieves context docs and builds them into formatted messages."""
        final_context_docs = self.retrieve_context_docs_for_simple_queries(queries)
        if not final_context_docs:
            logger.warning(f"No relevant context found for queries: {queries}")
            return []
        context_messages, _ = self.build_context_messages(final_context_docs)
        return context_messages

    def retrieve_context_docs_for_simple_queries(self, queries: list[str]) -> list:
        """Invokes the retriever for a list of simple queries and selects the final documents."""
        logger.info(f"Retrieving context documents with simple queries: {queries}")
        retrieved_docs = []
        for query in queries:
            docs = self.kb.retriever.invoke(query)
            retrieved_docs.extend(docs)
        return RAGContextEngine.select_final_context(
            retrieved_docs=retrieved_docs,
            config=self.config,
            page_map=self.kb.page_map,
        )

    def build_context_messages(
        self, docs: List[LangchainDocument]
    ) -> tuple[list[dict], list[Image.Image]]:
        """Builds a structured list of messages by grouping consecutive text blocks."""
        if not docs:
            return [], []
        context_messages = []
        images_found = []
        prose_buffer = []

        def flush_prose_buffer():
            if prose_buffer:
                full_prose = "\n\n".join(prose_buffer)
                context_messages.append({"type": "text", "text": full_prose})
                prose_buffer.clear()

        add_images = self.config.get("ADD_MAPPED_FIGURES_TO_PROMPT", False)
        for i, doc in enumerate(docs):
            current_page = doc.metadata.get("page_number")
            is_new_page = (i > 0) and (current_page != docs[i - 1].metadata.get("page_number"))
            is_caption = doc.metadata.get("chunk_type") == "figure-caption"
            if is_new_page or (add_images and is_caption):
                flush_prose_buffer()
            if add_images and is_caption:
                source_info = f"--- Source: Page {current_page} ---"
                caption_text = f"{source_info}\n{doc.page_content}"
                context_messages.append({"type": "text", "text": caption_text})
                image_path = doc.metadata.get("linked_figure_path")
                if image_path and os.path.exists(image_path):
                    try:
                        image = Image.open(image_path).convert("RGB")
                        context_messages.append({"type": "image", "image": image})
                        images_found.append(image)
                    except Exception as e:
                        logger.warning(f"Could not load image {image_path}: {e}")
            else:
                if not prose_buffer:
                    source_info = f"--- Source: Page {current_page} ---"
                    prose_buffer.append(f"\n{source_info}\n")
                prose_buffer.append(doc.page_content)
        flush_prose_buffer()
        return context_messages, images_found
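
    # The returned messages are a flat, ordered list mixing prose and figures:
    #   [{"type": "text", "text": "\n--- Source: Page 3 ---\n\n..."},
    #    {"type": "image", "image": <PIL.Image.Image>},
    #    ...]
    # plus a separate list of every PIL image that was successfully loaded.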

    @staticmethod
    def select_final_context(retrieved_docs: list, config: dict, page_map: dict) -> list:
        """Selects final context from retrieved documents using the specified strategy."""
        strategy = config.get("CONTEXT_SELECTION_STRATEGY")
        top_k = config.get("FINAL_CONTEXT_TOP_K", 5)

        def _calculate_block_frequencies(docs_list: list) -> list:
            blocks = {}
            for doc in docs_list:
                if block_id := doc.metadata.get("block_id"):
                    if block_id not in blocks:
                        blocks[block_id] = []
                    blocks[block_id].append(doc)
            return sorted(blocks.items(), key=lambda item: len(item[1]), reverse=True)

        def _expand_chunks_to_blocks(chunks: list) -> list:
            return [
                LangchainDocument(
                    page_content=c.metadata.get("original_block_text", c.page_content),
                    metadata=c.metadata,
                )
                for c in chunks
            ]

        final_context = []
        if strategy == "chapter_aware_window_expansion":
            if not retrieved_docs or not page_map:
                return []
            scored_blocks = _calculate_block_frequencies(retrieved_docs)
            if not scored_blocks:
                return _expand_chunks_to_blocks(retrieved_docs[:top_k])
            primary_hit_page = scored_blocks[0][1][0].metadata.get("page_number")
            important_pages = {
                c[0].metadata.get("page_number")
                for _, c in scored_blocks[:top_k]
                if c and c[0].metadata.get("page_number")
            }
            window_size = config.get("CONTEXT_WINDOW_SIZE", 0)
            pages_to_extract = set()
            for page_num in important_pages:
                current_chapter_info = page_map.get(page_num)
                if not current_chapter_info:
                    continue
                current_chapter_id = current_chapter_info["chapter_id"]
                pages_to_extract.add(page_num)
                for i in range(1, window_size + 1):
                    if (prev_info := page_map.get(page_num - i)) and prev_info["chapter_id"] == current_chapter_id:
                        pages_to_extract.add(page_num - i)
                    if (next_info := page_map.get(page_num + i)) and next_info["chapter_id"] == current_chapter_id:
                        pages_to_extract.add(page_num + i)
            sorted_pages = sorted(pages_to_extract)
            if primary_hit_page and primary_hit_page in page_map:
                final_context.extend(page_map[primary_hit_page]["blocks"])
            for page_num in sorted_pages:
                if page_num != primary_hit_page and page_num in page_map:
                    final_context.extend(page_map[page_num]["blocks"])
        elif strategy == "rerank_by_frequency":
            scored_blocks = _calculate_block_frequencies(retrieved_docs)
            representative_chunks = [chunks[0] for _, chunks in scored_blocks[:top_k]]
            final_context = _expand_chunks_to_blocks(representative_chunks)
        elif strategy == "select_by_rank":
            unique_docs_map = {
                f"{doc.metadata.get('block_id', '')}_{doc.page_content}": doc
                for doc in retrieved_docs
            }
            representative_chunks = list(unique_docs_map.values())[:top_k]
            final_context = _expand_chunks_to_blocks(representative_chunks)
        else:
            logger.warning(f"Unknown strategy '{strategy}'. Defaulting to top-k raw chunks.")
            final_context = retrieved_docs[:top_k]
        logger.info(f"Selected {len(final_context)} final context blocks using '{strategy}' strategy.")
        return final_context
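

# Usage sketch (illustrative; assumes a KnowledgeBase has already been built
# elsewhere with a retriever and page_map, which this module does not do itself):
#
#     kb = KnowledgeBase(...)  # built and indexed ahead of time
#     engine = RAGContextEngine(kb, config_overrides={"CONTEXT_WINDOW_SIZE": 1})
#     messages = engine.get_context_messages("How should figure captions be formatted?")
#     print(format_context_messages_to_string(messages or []))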