"""Context retrieval with reranking capabilities.""" import os from typing import List, Optional, Tuple, Dict, Any from langchain.schema import Document from langchain_community.vectorstores import Qdrant from langchain_community.embeddings import HuggingFaceEmbeddings from sentence_transformers import CrossEncoder import numpy as np import torch from qdrant_client.http import models as rest import traceback from .filter import create_filter class ContextRetriever: """ Context retriever for hybrid search with optional filtering and reranking. """ def __init__(self, vectorstore: Qdrant, config: dict = None): """ Initialize the context retriever. Args: vectorstore: Qdrant vector store instance config: Configuration dictionary """ self.vectorstore = vectorstore self.config = config or {} self.reranker = None # BM25 attributes self.bm25_vectorizer = None self.bm25_matrix = None self.bm25_documents = None # Initialize reranker if available # Try to get reranker model from different config paths self.reranker_model_name = ( config.get('retrieval', {}).get('reranker_model') or config.get('ranker', {}).get('model') or config.get('reranker_model') or 'BAAI/bge-reranker-v2-m3' ) self.reranker_type = self._detect_reranker_type(self.reranker_model_name) try: if self.reranker_type == 'colbert': from colbert.infra import Run, ColBERTConfig from colbert.modeling.checkpoint import Checkpoint # ColBERT uses late interaction - different implementation needed print(f"✅ RERANKER: ColBERT model detected ({self.reranker_model_name})") print(f"🔍 INTERACTION TYPE: Late interaction (token-level embeddings)") # Create ColBERT config for CPU mode colbert_config = ColBERTConfig( doc_maxlen=300, query_maxlen=32, nbits=2, kmeans_niters=4, root="./colbert_data" ) # Load checkpoint (e.g. "colbert-ir/colbertv2.0") self.colbert_checkpoint = Checkpoint(self.reranker_model_name, colbert_config=colbert_config) self.colbert_model = self.colbert_checkpoint.model self.colbert_tokenizer = self.colbert_checkpoint.raw_tokenizer self.reranker = self._colbert_rerank # attach wrapper function print(f"✅ COLBERT: Model and tokenizer loaded successfully") else: # Standard CrossEncoder for BGE and other models from sentence_transformers import CrossEncoder self.reranker = CrossEncoder(self.reranker_model_name) print(f"✅ RERANKER: Initialized {self.reranker_model_name}") print(f"🔍 INTERACTION TYPE: Cross-encoder (single relevance score)") except Exception as e: print(f"âš ī¸ Reranker initialization failed: {e}") self.reranker = None def _detect_reranker_type(self, model_name: str) -> str: """ Detect the type of reranker based on model name. Args: model_name: Name of the reranker model Returns: 'colbert' for ColBERT models, 'crossencoder' for others """ model_name_lower = model_name.lower() # ColBERT model patterns colbert_patterns = [ 'colbert', 'colbert-ir', 'colbertv2', 'colbert-v2' ] for pattern in colbert_patterns: if pattern in model_name_lower: return 'colbert' # Default to cross-encoder for BGE and other models return 'crossencoder' def _similarity_search_with_colbert_embeddings(self, query: str, k: int = 5, **kwargs) -> List[Tuple[Document, float]]: """ Perform similarity search and fetch ColBERT embeddings for documents. Args: query: Search query k: Number of documents to retrieve **kwargs: Additional search parameters (filter, etc.) 
Returns: List of (Document, score) tuples with ColBERT embeddings in metadata """ try: print(f"🔍 COLBERT RETRIEVAL: Fetching documents with ColBERT embeddings") # Use the vectorstore's similarity_search_with_score method instead of direct client # This ensures proper filter handling if 'filter' in kwargs and kwargs['filter']: # Use the vectorstore method with filter result = self.vectorstore.similarity_search_with_score( query, k=k, filter=kwargs['filter'] ) else: # Use the vectorstore method without filter result = self.vectorstore.similarity_search_with_score(query, k=k) # Convert to the format we need if isinstance(result, tuple) and len(result) == 2: documents, scores = result elif isinstance(result, list): documents = [] scores = [] for item in result: if isinstance(item, tuple) and len(item) == 2: doc, score = item documents.append(doc) scores.append(score) else: documents.append(item) scores.append(0.0) else: documents = [] scores = [] # Now we need to fetch the ColBERT embeddings for these documents # We'll use the Qdrant client directly for this part since we need specific payload fields from qdrant_client.http import models as rest collection_name = self.vectorstore.collection_name # Get document IDs from the retrieved documents doc_ids = [] for doc in documents: # Extract ID from document metadata or use page_content hash as fallback doc_id = doc.metadata.get('id') or doc.metadata.get('_id') if not doc_id: # Use a hash of the content as ID import hashlib doc_id = hashlib.md5(doc.page_content.encode()).hexdigest() doc_ids.append(doc_id) # Fetch documents with ColBERT embeddings from Qdrant search_result = self.vectorstore.client.retrieve( collection_name=collection_name, ids=doc_ids, with_payload=True, with_vectors=False ) # Convert results to Document objects with ColBERT embeddings enhanced_documents = [] enhanced_scores = [] # Create a mapping from doc_id to original score doc_id_to_score = {} for i, doc in enumerate(documents): doc_id = doc.metadata.get('id') or doc.metadata.get('_id') if not doc_id: import hashlib doc_id = hashlib.md5(doc.page_content.encode()).hexdigest() doc_id_to_score[doc_id] = scores[i] for point in search_result: # Extract payload payload = point.payload # Get the original score for this document doc_id = str(point.id) original_score = doc_id_to_score.get(doc_id, 0.0) # Create Document object with ColBERT embeddings doc = Document( page_content=payload.get('page_content', ''), metadata={ **payload.get('metadata', {}), 'colbert_embedding': payload.get('colbert_embedding'), 'colbert_model': payload.get('colbert_model'), 'colbert_calculated_at': payload.get('colbert_calculated_at') } ) enhanced_documents.append(doc) enhanced_scores.append(original_score) print(f"✅ COLBERT RETRIEVAL: Retrieved {len(enhanced_documents)} documents with ColBERT embeddings") return list(zip(enhanced_documents, enhanced_scores)) except Exception as e: print(f"❌ COLBERT RETRIEVAL ERROR: {e}") print(f"❌ Falling back to regular similarity search") # Fallback to regular search - handle filter parameter correctly if 'filter' in kwargs and kwargs['filter']: return self.vectorstore.similarity_search_with_score(query, k=k, filter=kwargs['filter']) else: return self.vectorstore.similarity_search_with_score(query, k=k) def retrieve_context( self, query: str, k: int = 5, reports: Optional[List[str]] = None, sources: Optional[List[str]] = None, subtype: Optional[str] = None, year: Optional[str] = None, district: Optional[List[str]] = None, filenames: Optional[List[str]] = None, 
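    # For reference, the Qdrant point payload shape this retrieval path assumes
    # (field names are taken from the reads above; the concrete values are
    # illustrative only, not from a real collection):
    #
    #   {
    #       "page_content": "Chunk text...",
    #       "metadata": {"filename": "report.pdf", "year": "2023", ...},
    #       "colbert_embedding": [[0.01, ...], ...],   # one vector per token
    #       "colbert_model": "colbert-ir/colbertv2.0",
    #       "colbert_calculated_at": "2024-01-01T00:00:00Z"
    #   }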
    def retrieve_context(
        self,
        query: str,
        k: int = 5,
        reports: Optional[List[str]] = None,
        sources: Optional[List[str]] = None,
        subtype: Optional[str] = None,
        year: Optional[str] = None,
        district: Optional[List[str]] = None,
        filenames: Optional[List[str]] = None,
        use_reranking: bool = False,
        qdrant_filter: Optional[rest.Filter] = None
    ) -> List[Document]:
        """
        Retrieve context documents using hybrid search with optional filtering and reranking.

        Args:
            query: User query
            k: Number of documents to retrieve
            reports: List of report names to filter by
            sources: List of sources to filter by
            subtype: Document subtype to filter by
            year: Year to filter by
            district: List of districts to filter by
            filenames: List of filenames to filter by
            use_reranking: Whether to apply reranking
            qdrant_filter: Pre-built Qdrant filter to use

        Returns:
            List of retrieved documents
        """
        try:
            # Determine how many documents to retrieve. This could be raised
            # (e.g. to k * 3) when reranking, to give the reranker more candidates.
            retrieve_k = k

            # Build search kwargs. A pre-built qdrant_filter takes precedence.
            if qdrant_filter:
                search_kwargs = {"filter": qdrant_filter}
                print("✅ FILTERS APPLIED: Using inferred Qdrant filter")
            else:
                # Build a filter from the individual parameters
                filter_obj = create_filter(
                    reports=reports,
                    sources=sources,
                    subtype=subtype,
                    year=year,
                    district=district,
                    filenames=filenames
                )
                if filter_obj:
                    search_kwargs = {"filter": filter_obj}
                    print("✅ FILTERS APPLIED: Using built filter")
                else:
                    search_kwargs = {}
                    print("⚠️ NO FILTERS APPLIED: All documents will be searched")

            # Perform the vector search
            try:
                # ColBERT reranking needs the stored token-level embeddings
                if use_reranking and self.reranker_type == 'colbert':
                    result = self._similarity_search_with_colbert_embeddings(
                        query, k=retrieve_k, **search_kwargs
                    )
                else:
                    result = self.vectorstore.similarity_search_with_score(
                        query, k=retrieve_k, **search_kwargs
                    )

                # Normalize the different return formats
                if isinstance(result, tuple) and len(result) == 2:
                    documents, scores = result
                elif isinstance(result, list) and len(result) > 0:
                    # A list of (Document, score) tuples, or plain Documents
                    documents = []
                    scores = []
                    for item in result:
                        if isinstance(item, tuple) and len(item) == 2:
                            doc, score = item
                            documents.append(doc)
                            scores.append(score)
                        else:
                            documents.append(item)
                            scores.append(0.0)  # Default score
                else:
                    documents = []
                    scores = []

                print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(documents)} documents (requested: {retrieve_k})")

                # If we got fewer documents than requested, retry without filters
                if len(documents) < retrieve_k and search_kwargs.get('filter'):
                    print(f"⚠️ RETRIEVAL: Got {len(documents)} docs with filters, trying without filters...")
                    try:
                        result_no_filter = self.vectorstore.similarity_search_with_score(
                            query, k=retrieve_k
                        )
                        if isinstance(result_no_filter, tuple) and len(result_no_filter) == 2:
                            documents_no_filter, scores_no_filter = result_no_filter
                        elif isinstance(result_no_filter, list):
                            documents_no_filter = []
                            scores_no_filter = []
                            for item in result_no_filter:
                                if isinstance(item, tuple) and len(item) == 2:
                                    doc, score = item
                                    documents_no_filter.append(doc)
                                    scores_no_filter.append(score)
                                else:
                                    documents_no_filter.append(item)
                                    scores_no_filter.append(0.0)
                        else:
                            documents_no_filter = []
                            scores_no_filter = []

                        if len(documents_no_filter) > len(documents):
                            print(f"✅ RETRIEVAL: Got {len(documents_no_filter)} docs without filters")
                            documents = documents_no_filter
                            scores = scores_no_filter
                    except Exception as e:
                        print(f"⚠️ RETRIEVAL: Fallback search failed: {e}")

            except Exception as e:
                print(f"❌ RETRIEVAL ERROR: {str(e)}")
                return []

            # Apply reranking if enabled
            if use_reranking and len(documents) > 1:
                print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(documents)} documents...")
                try:
                    reranked_docs = self._apply_reranking(query, documents, scores)
                    if reranked_docs:
                        print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
                        # Reranked documents already carry reranking metadata
                        # (ranks and scores) set by the reranker itself
                        return reranked_docs[:k]
                    print("⚠️ RERANKING FAILED: Using original order")
                except Exception as e:
                    print(f"❌ RERANKING ERROR: {str(e)}")
                    print("⚠️ RERANKING FAILED: Using original order")
            elif use_reranking:
                print(f"ℹ️ RERANKING: Skipped (only {len(documents)} document(s) retrieved)")
            else:
                print("ℹ️ RERANKING: Skipped (disabled)")

            # Limit to the requested number of documents
            documents = documents[:k]
            scores = scores[:k] if scores else [0.0] * len(documents)

            # Add retrieval metadata to the (un-reranked) documents
            for i, (doc, score) in enumerate(zip(documents, scores)):
                if hasattr(doc, 'metadata'):
                    doc.metadata.update({
                        'reranking_applied': False,
                        'reranker_model': None,
                        'original_rank': i + 1,
                        'final_rank': i + 1,
                        'original_score': float(score) if score is not None else 0.0
                    })

            return documents

        except Exception as e:
            print(f"❌ CONTEXT RETRIEVAL ERROR: {str(e)}")
            return []

    def _apply_reranking(self, query: str, documents: List[Document],
                         scores: List[float]) -> List[Document]:
        """
        Apply reranking to documents using the appropriate reranker.

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        if not self.reranker or len(documents) == 0:
            return documents

        try:
            print(f"🔍 RERANKING METHOD: Starting reranking with {len(documents)} documents")
            print(f"🔍 RERANKING TYPE: {self.reranker_type.upper()}")

            if self.reranker_type == 'colbert':
                return self._apply_colbert_reranking(query, documents, scores)
            return self._apply_crossencoder_reranking(query, documents, scores)
        except Exception as e:
            print(f"❌ RERANKING ERROR: {str(e)}")
            return documents
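    # Illustrative cross-encoder scoring (the query, passages, and score
    # values below are made up): the reranker scores each (query, passage)
    # pair jointly, unlike a bi-encoder which embeds them separately.
    #
    #   pairs  = [["flood damage 2022", "The 2022 floods damaged..."],
    #             ["flood damage 2022", "Annual budget overview..."]]
    #   scores = reranker.predict(pairs)   # e.g. array([ 2.31, -4.87 ])
    #
    # Higher means more relevant; for BGE-style rerankers these are typically
    # raw logits, so they are only comparable within a single query.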
    def _apply_crossencoder_reranking(self, query: str, documents: List[Document],
                                      scores: List[float]) -> List[Document]:
        """
        Apply reranking using a CrossEncoder (BGE and other models).

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        # Prepare (query, passage) pairs for reranking
        pairs = [[query, doc.page_content] for doc in documents]
        print(f"🔍 CROSS-ENCODER: Prepared {len(pairs)} pairs for reranking")

        # Get reranking scores via the CrossEncoder API
        rerank_scores = self.reranker.predict(pairs)

        # Handle the single-score case
        if not isinstance(rerank_scores, (list, np.ndarray)):
            rerank_scores = [rerank_scores]

        # Ensure we have the right number of scores
        if len(rerank_scores) != len(documents):
            print(f"⚠️ RERANKING WARNING: Expected {len(documents)} scores, got {len(rerank_scores)}")
            return documents

        print(f"🔍 CROSS-ENCODER: Got {len(rerank_scores)} rerank scores")
        print(f"🔍 CROSS-ENCODER SCORES: {rerank_scores[:5]}...")  # Show the first 5 scores

        # Combine documents with their rerank scores and sort (descending)
        doc_scores = list(zip(documents, rerank_scores))
        doc_scores.sort(key=lambda x: x[1], reverse=True)

        # Extract the reranked documents and record the scores in metadata
        reranked_docs = []
        for i, (doc, rerank_score) in enumerate(doc_scores):
            # Find the original index to recover the original score
            original_idx = documents.index(doc)
            original_score = scores[original_idx] if original_idx < len(scores) else 0.0

            # Create a new document carrying the reranking metadata
            new_doc = Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    'reranking_applied': True,
                    'reranker_model': self.reranker_model_name,
                    'reranker_type': self.reranker_type,
                    'original_rank': original_idx + 1,
                    'final_rank': i + 1,
                    'original_score': float(original_score),
                    'reranked_score': float(rerank_score)
                }
            )
            reranked_docs.append(new_doc)

        print(f"✅ CROSS-ENCODER: Reranked {len(reranked_docs)} documents")
        return reranked_docs

    def _apply_colbert_reranking(self, query: str, documents: List[Document],
                                 scores: List[float]) -> List[Document]:
        """
        Apply reranking using ColBERT late interaction.

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        # Delegate to the actual ColBERT reranking implementation
        return self._colbert_rerank(query, documents, scores)

    def _colbert_rerank(self, query: str, documents: List[Document],
                        scores: List[float]) -> List[Document]:
        """
        ColBERT reranking using late interaction, with support for
        pre-calculated document embeddings.

        Args:
            query: User query
            documents: List of documents to rerank
            scores: Original scores

        Returns:
            Reranked list of documents
        """
        try:
            print(f"🔍 COLBERT: Starting late interaction reranking with {len(documents)} documents")

            # Split documents into those with pre-calculated ColBERT embeddings
            # and those whose embeddings still need to be computed
            pre_calculated_embeddings = []
            documents_without_embeddings = []
            documents_without_indices = []

            for i, doc in enumerate(documents):
                if (hasattr(doc, 'metadata')
                        and doc.metadata.get('colbert_embedding') is not None):
                    # Use the pre-calculated embedding
                    colbert_embedding = doc.metadata['colbert_embedding']
                    if isinstance(colbert_embedding, list):
                        colbert_embedding = torch.tensor(colbert_embedding)
                    pre_calculated_embeddings.append(colbert_embedding)
                else:
                    # Embedding needs to be calculated
                    documents_without_embeddings.append(doc)
                    documents_without_indices.append(i)

            # Calculate the query embedding (a matrix of token embeddings)
            query_embeddings = self.colbert_checkpoint.queryFromText([query])

            # Calculate embeddings for documents without pre-calculated ones
            if documents_without_embeddings:
                print(f"🔄 COLBERT: Calculating embeddings for {len(documents_without_embeddings)} documents without pre-calculated embeddings")
                doc_texts = [doc.page_content for doc in documents_without_embeddings]
                doc_embeddings = self.colbert_checkpoint.docFromText(doc_texts)

                # Insert the calculated embeddings back at their original
                # positions (indices are ascending, so insert() is safe here)
                for i, embedding in enumerate(doc_embeddings):
                    idx = documents_without_indices[i]
                    pre_calculated_embeddings.insert(idx, embedding)
            else:
                print(f"✅ COLBERT: Using pre-calculated embeddings for all {len(documents)} documents")

            # Calculate late interaction (MaxSim) scores: for each query token,
            # take the maximum similarity over document tokens, then sum
            colbert_scores = []
            for doc_embedding in pre_calculated_embeddings:
                # Similarity matrix between query tokens and document tokens
                sim_matrix = torch.matmul(query_embeddings[0], doc_embedding.transpose(-1, -2))
                # MaxSim: maximum similarity with the document per query token
                max_sim_per_query_token = torch.max(sim_matrix, dim=-1)[0]
                # Sum over query tokens to get the final score
                final_score = torch.sum(max_sim_per_query_token).item()
                colbert_scores.append(final_score)

            # Sort documents by their ColBERT scores (descending)
            doc_scores = list(zip(documents, colbert_scores))
            doc_scores.sort(key=lambda x: x[1], reverse=True)

            # Create the reranked documents with metadata
            reranked_docs = []
            for i, (doc, colbert_score) in enumerate(doc_scores):
                original_idx = documents.index(doc)
                original_score = scores[original_idx] if original_idx < len(scores) else 0.0

                new_doc = Document(
                    page_content=doc.page_content,
                    metadata={
                        **doc.metadata,
                        'reranking_applied': True,
                        'reranker_model': self.reranker_model_name,
                        'reranker_type': self.reranker_type,
                        'original_rank': original_idx + 1,
                        'final_rank': i + 1,
                        'original_score': float(original_score),
                        'reranked_score': float(colbert_score),
                        'colbert_score': float(colbert_score),
                        'colbert_embedding_pre_calculated': 'colbert_embedding' in doc.metadata
                    }
                )
                reranked_docs.append(new_doc)

            print(f"✅ COLBERT: Reranked {len(reranked_docs)} documents using late interaction")
            print(f"🔍 COLBERT SCORES: {[f'{score:.4f}' for score in colbert_scores[:5]]}...")
            return reranked_docs

        except Exception as e:
            print(f"❌ COLBERT RERANKING ERROR: {str(e)}")
            print(f"❌ COLBERT TRACEBACK: {traceback.format_exc()}")
            # Fall back to the original order
            return documents
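    # Worked MaxSim example with made-up 2-d token embeddings (real ColBERT
    # vectors are 128-d per token):
    #
    #   Q = [[1.0, 0.0],          # query token embeddings
    #        [0.0, 1.0]]
    #   D = [[1.0, 0.0],          # document token embeddings
    #        [0.5, 0.5]]
    #   Q @ D.T = [[1.0, 0.5],
    #              [0.0, 0.5]]
    #   max per query token -> [1.0, 0.5]; score = 1.0 + 0.5 = 1.5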
    def retrieve_with_scores(self, query: str, vectorstore=None, k: int = 5,
                             reports: List[str] = None, sources: List[str] = None,
                             subtype: List[str] = None, year: List[str] = None,
                             use_reranking: bool = False,
                             qdrant_filter: Optional[rest.Filter] = None) -> Tuple[List[Document], List[float]]:
        """
        Retrieve context documents with scores using hybrid search with optional reranking.

        Args:
            query: User query
            vectorstore: Optional vectorstore instance (for compatibility)
            k: Number of documents to retrieve
            reports: List of report names to filter by
            sources: List of sources to filter by
            subtype: List of document subtypes to filter by
            year: List of years to filter by
            use_reranking: Whether to apply reranking
            qdrant_filter: Pre-built Qdrant filter

        Returns:
            Tuple of (documents, scores)
        """
        try:
            # Use the provided vectorstore if available, otherwise the instance one
            if vectorstore:
                self.vectorstore = vectorstore

            # Determine the search strategy
            search_strategy = self.config.get('retrieval', {}).get('search_strategy', 'vector_only')

            if search_strategy == 'vector_only':
                # Vector search only
                print(f"🔄 VECTOR SEARCH: Retrieving {k} documents...")

                if qdrant_filter:
                    print("✅ QDRANT FILTER APPLIED: Using inferred Qdrant filter")
                    results = self.vectorstore.similarity_search_with_score(
                        query, k=k, filter=qdrant_filter
                    )
                else:
                    # Build a filter from the individual parameters
                    filter_conditions = self._build_filter_conditions(reports, sources, subtype, year)
                    if filter_conditions:
                        print(f"✅ FILTER APPLIED: {filter_conditions}")
                        results = self.vectorstore.similarity_search_with_score(
                            query, k=k, filter=filter_conditions
                        )
                    else:
                        print("ℹ️ NO FILTERS APPLIED: All documents will be searched")
                        results = self.vectorstore.similarity_search_with_score(query, k=k)

                print(f"🔍 SEARCH DEBUG: Raw result type: {type(results)}")
                print(f"🔍 SEARCH DEBUG: Raw result length: {len(results)}")

                # Handle the different result formats
                if results and isinstance(results[0], tuple):
                    documents = [doc for doc, score in results]
                    scores = [score for doc, score in results]
                    print(f"🔍 SEARCH DEBUG: After unpacking - documents: {len(documents)}, scores: {len(scores)}")
                else:
                    documents = results
                    scores = [0.0] * len(documents)
                    print("🔍 SEARCH DEBUG: No scores available, using default")

                print(f"🔧 CONVERTING: Converting {len(documents)} documents")

                # Convert to fresh Document objects, storing the original scores
                final_documents = []
                for i, (doc, score) in enumerate(zip(documents, scores)):
                    if hasattr(doc, 'page_content'):
                        new_doc = Document(
                            page_content=doc.page_content,
                            metadata=doc.metadata.copy()
                        )
                        # Store the original score in metadata
                        new_doc.metadata['original_score'] = float(score) if score is not None else 0.0
                        final_documents.append(new_doc)
                    else:
                        print(f"⚠️ WARNING: Document {i} has no page_content")

                print(f"✅ RETRIEVAL SUCCESS: Retrieved {len(final_documents)} documents")

                # Apply reranking if enabled
                if use_reranking and len(final_documents) > 1:
                    print(f"🔄 RERANKING: Applying {self.reranker_model_name} to {len(final_documents)} documents...")
                    final_documents = self._apply_reranking(query, final_documents, scores)
                    # Keep the returned scores aligned with the reranked order
                    scores = [doc.metadata.get('original_score', 0.0) for doc in final_documents]
                    print(f"✅ RERANKING APPLIED: {self.reranker_model_name}")
                else:
                    print("ℹ️ RERANKING: Skipped (disabled or no documents)")

                return final_documents, scores

            print(f"❌ UNSUPPORTED STRATEGY: {search_strategy}")
            return [], []

        except Exception as e:
            print(f"❌ RETRIEVAL ERROR: {e}")
            print(f"❌ RETRIEVAL TRACEBACK: {traceback.format_exc()}")
            return [], []

    def _build_filter_conditions(self, reports: List[str] = None, sources: List[str] = None,
                                 subtype: List[str] = None, year: List[str] = None) -> Optional[rest.Filter]:
        """
        Build Qdrant filter conditions from individual parameters.

        Args:
            reports: List of report names
            sources: List of sources
            subtype: List of document subtypes
            year: List of years

        Returns:
            Qdrant filter or None
        """
        conditions = []

        if reports:
            conditions.append(rest.FieldCondition(
                key="metadata.filename",
                match=rest.MatchAny(any=reports)
            ))
        if sources:
            conditions.append(rest.FieldCondition(
                key="metadata.source",
                match=rest.MatchAny(any=sources)
            ))
        if subtype:
            conditions.append(rest.FieldCondition(
                key="metadata.subtype",
                match=rest.MatchAny(any=subtype)
            ))
        if year:
            conditions.append(rest.FieldCondition(
                key="metadata.year",
                match=rest.MatchAny(any=year)
            ))

        if conditions:
            return rest.Filter(must=conditions)
        return None
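    # Example of the filter _build_filter_conditions produces (the field
    # values are illustrative): sources=["annual_report"], year=["2023"]
    # yields, roughly,
    #
    #   Filter(must=[
    #       FieldCondition(key="metadata.source", match=MatchAny(any=["annual_report"])),
    #       FieldCondition(key="metadata.year", match=MatchAny(any=["2023"])),
    #   ])
    #
    # i.e. conditions are AND-ed across fields and OR-ed within a field's list.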
def get_context(
    query: str,
    vectorstore: Qdrant,
    k: int = 5,
    reports: Optional[List[str]] = None,
    sources: Optional[List[str]] = None,
    subtype: Optional[str] = None,
    year: Optional[str] = None,
    use_reranking: bool = False,
    qdrant_filter: Optional[rest.Filter] = None
) -> List[Document]:
    """
    Convenience function to get context documents.

    Args:
        query: User query
        vectorstore: Qdrant vector store instance
        k: Number of documents to retrieve
        reports: Optional list of report names to filter by
        sources: Optional list of source categories to filter by
        subtype: Optional subtype to filter by
        year: Optional year to filter by
        use_reranking: Whether to apply reranking
        qdrant_filter: Optional pre-built Qdrant filter

    Returns:
        List of retrieved documents
    """
    retriever = ContextRetriever(vectorstore)
    return retriever.retrieve_context(
        query=query,
        k=k,
        reports=reports,
        sources=sources,
        subtype=subtype,
        year=year,
        use_reranking=use_reranking,
        qdrant_filter=qdrant_filter
    )


def format_context_for_llm(documents: List[Document]) -> str:
    """
    Format retrieved documents for LLM input.

    Args:
        documents: List of Document objects

    Returns:
        Formatted string for the LLM
    """
    if not documents:
        return ""

    formatted_parts = []
    for i, doc in enumerate(documents, 1):
        content = doc.page_content.strip()
        source = doc.metadata.get('filename', 'Unknown')
        formatted_parts.append(f"Document {i} (Source: {source}):\n{content}")

    return "\n\n".join(formatted_parts)


def get_context_metadata(documents: List[Document]) -> Dict[str, Any]:
    """
    Extract a metadata summary from retrieved documents.

    Args:
        documents: List of Document objects

    Returns:
        Dictionary with a metadata summary
    """
    if not documents:
        return {}

    sources = set()
    years = set()
    doc_types = set()

    for doc in documents:
        metadata = doc.metadata
        if 'filename' in metadata:
            sources.add(metadata['filename'])
        if 'year' in metadata:
            years.add(metadata['year'])
        if 'source' in metadata:
            doc_types.add(metadata['source'])

    return {
        "num_documents": len(documents),
        "sources": list(sources),
        "years": list(years),
        "document_types": list(doc_types)
    }
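

# Minimal usage sketch. The Qdrant URL, collection name, query, and embedding
# model below are assumptions for illustration, not part of this module;
# adapt them to the actual deployment.
if __name__ == "__main__":
    from qdrant_client import QdrantClient
    from langchain_community.embeddings import HuggingFaceEmbeddings

    client = QdrantClient(url="http://localhost:6333")  # hypothetical endpoint
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")  # hypothetical model
    vectorstore = Qdrant(client=client, collection_name="reports", embeddings=embeddings)

    docs = get_context("flood damage in 2022", vectorstore, k=5, use_reranking=True)
    print(format_context_for_llm(docs))
    print(get_context_metadata(docs))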