# Spaces: Sleeping  (hosting-page status text captured during extraction; not part of the module)
"""Hybrid search implementation combining vector and sparse retrieval."""
import json
import os
import pickle
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
from langchain.docstore.document import Document
from langchain_community.retrievers import BM25Retriever
from langchain_qdrant import QdrantVectorStore

from .filter import create_filter
class HybridRetriever:
    """
    Hybrid retrieval system combining vector search (dense) and BM25 (sparse) search.

    Supports configurable search modes: vector_only, sparse_only, or hybrid.
    In hybrid mode both retrievers are queried, their scores are min-max
    normalized and blended as ``alpha * vector + (1 - alpha) * bm25``.
    """

    # Upper bound on the enlarged per-retriever candidate pool used for fusion.
    _MAX_FUSION_CANDIDATES = 50

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize hybrid retriever.

        Args:
            config: Configuration dictionary with hybrid search settings.
                Keys read by this class: ``bm25.top_k`` and ``retriever.top_k``
                (both default to 20 when absent).
        """
        self.config = config
        # BM25 retriever; stays None until initialize_bm25() succeeds.
        self.bm25_retriever = None
        # Documents handed to the most recent initialize_bm25() call.
        self.documents = []
        # Path of the pickle cache; resolved lazily by initialize_bm25().
        self._bm25_cache_file = None

    def _get_bm25_cache_path(self) -> str:
        """Get path for BM25 cache file, creating the cache directory if needed."""
        cache_dir = Path("cache/bm25")
        cache_dir.mkdir(parents=True, exist_ok=True)
        return str(cache_dir / "bm25_retriever.pkl")

    def initialize_bm25(self, documents: List["Document"], force_rebuild: bool = False) -> None:
        """
        Initialize BM25 retriever with documents.

        A cached retriever is reused when one exists on disk, unless
        ``force_rebuild`` is set, the cache cannot be loaded, or the cache
        indexes a different number of documents than ``documents``. When the
        index has to be built and building fails, BM25 search is disabled
        (``self.bm25_retriever`` is None). A failure to *write* the cache is
        reported but does not disable the freshly built retriever.

        Args:
            documents: List of Document objects to index
            force_rebuild: Whether to force rebuilding the BM25 index
        """
        self.documents = documents
        self._bm25_cache_file = self._get_bm25_cache_path()
        top_k = self.config.get("bm25", {}).get("top_k", 20)

        # Try to load cached BM25 retriever
        if not force_rebuild and os.path.exists(self._bm25_cache_file):
            try:
                print("Loading cached BM25 retriever...")
                # NOTE(security): pickle.load executes arbitrary code embedded in
                # the file. Only safe while the cache directory is trusted and
                # written exclusively by this application.
                with open(self._bm25_cache_file, 'rb') as f:
                    cached = pickle.load(f)
                # Reject a cache built from a different corpus size. An empty
                # `documents` argument is treated as "use whatever is cached".
                cached_docs = getattr(cached, "docs", None)
                if documents and cached_docs is not None and len(cached_docs) != len(documents):
                    raise ValueError(
                        f"cache indexes {len(cached_docs)} documents, expected {len(documents)}"
                    )
                # Apply the configured top_k to cached retrievers as well.
                cached.k = top_k
                self.bm25_retriever = cached
                print(f"✅ Loaded cached BM25 retriever with {len(self.documents)} documents")
                return
            except Exception as e:
                print(f"⚠️ Failed to load cached BM25 retriever: {e}")
                print("Building new BM25 index...")

        # Build new BM25 retriever
        print("Building BM25 index...")
        try:
            # Use langchain's BM25Retriever
            retriever = BM25Retriever.from_documents(documents)
            # Configure BM25 parameters
            retriever.k = top_k
        except Exception as e:
            print(f"❌ Failed to build BM25 retriever: {e}")
            print("BM25 search will be disabled")
            self.bm25_retriever = None
            return
        self.bm25_retriever = retriever

        # Cache the BM25 retriever. Write to a temporary file and rename so an
        # interrupted dump never leaves a truncated cache behind.
        tmp_path = f"{self._bm25_cache_file}.tmp"
        try:
            with open(tmp_path, 'wb') as f:
                pickle.dump(retriever, f)
            os.replace(tmp_path, self._bm25_cache_file)
        except Exception as e:
            print(f"⚠️ Built BM25 retriever with {len(documents)} documents but failed to cache it: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the temp file may never have been created
        else:
            print(f"✅ Built and cached BM25 retriever with {len(documents)} documents")

    def _filter_documents_by_metadata(
        self,
        documents: List["Document"],
        reports: Optional[List[str]] = None,
        sources: Optional[str] = None,
        subtype: Optional[List[str]] = None,
        year: Optional[List[str]] = None
    ) -> List["Document"]:
        """
        Filter documents by metadata criteria.

        A document is kept only if it satisfies every criterion that is set.
        When no criterion is set the input list is returned unchanged.

        Args:
            documents: List of documents to filter
            reports: Report names; a document matches if any of them is a
                substring of its ``filename`` metadata
            sources: Source category; must equal the ``source`` metadata
            subtype: Allowed values for the ``subtype`` metadata
            year: Allowed years; compared as strings against the ``year``
                metadata, so ints and strings are both accepted

        Returns:
            Filtered list of documents
        """
        if not any([reports, sources, subtype, year]):
            return documents

        wanted_years = None
        if year:
            # A bare string would otherwise be iterated character by character.
            year_values = [year] if isinstance(year, str) else year
            wanted_years = {str(value) for value in year_values}

        filtered_docs = []
        for doc in documents:
            metadata = doc.metadata or {}
            # Filter by reports
            if reports:
                filename = metadata.get('filename') or ''
                if not any(report in filename for report in reports):
                    continue
            # Filter by sources
            if sources and sources != metadata.get('source', ''):
                continue
            # Filter by subtype
            if subtype and metadata.get('subtype', '') not in subtype:
                continue
            # Filter by year
            if wanted_years is not None and str(metadata.get('year', '')) not in wanted_years:
                continue
            filtered_docs.append(doc)
        return filtered_docs

    def _bm25_search(
        self,
        query: str,
        k: int = 20,
        reports: Optional[List[str]] = None,
        sources: Optional[str] = None,
        subtype: Optional[List[str]] = None,
        year: Optional[List[str]] = None
    ) -> List[Tuple["Document", float]]:
        """
        Perform BM25 sparse search.

        When metadata filters are given, the whole ranked corpus is fetched
        before filtering so that matching documents ranked below ``k`` are not
        lost; the filtered list is then cut to ``k``. The retriever's own ``k``
        setting is restored afterwards.

        Args:
            query: Search query
            k: Number of documents to retrieve
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years

        Returns:
            List of (Document, score) tuples. Scores are rank-based
            placeholders in [0.1, 1.0], not real BM25 scores. An empty list is
            returned when BM25 is unavailable or the search fails.
        """
        if not self.bm25_retriever:
            print("⚠️ BM25 retriever not available")
            return []

        retriever = self.bm25_retriever
        has_filters = any([reports, sources, subtype, year])
        try:
            fetch_k = k
            if has_filters:
                corpus = getattr(retriever, "docs", None) or self.documents
                fetch_k = max(k, len(corpus))

            # Get BM25 results
            previous_k = getattr(retriever, "k", k)
            retriever.k = fetch_k
            try:
                bm25_docs = retriever.invoke(query)
            finally:
                retriever.k = previous_k

            # Apply metadata filtering
            if has_filters:
                bm25_docs = self._filter_documents_by_metadata(
                    bm25_docs, reports, sources, subtype, year
                )
            bm25_docs = bm25_docs[:k]

            # BM25Retriever doesn't return scores directly, so placeholder
            # scores are derived from the rank: the first hit gets 1.0 and
            # later hits decrease linearly, floored at 0.1.
            total = len(bm25_docs)
            return [
                (doc, max(0.1, 1.0 - (rank / total)))
                for rank, doc in enumerate(bm25_docs)
            ]
        except Exception as e:
            print(f"❌ BM25 search failed: {e}")
            return []

    def _vector_search(
        self,
        vectorstore: "QdrantVectorStore",
        query: str,
        k: int = 20,
        reports: Optional[List[str]] = None,
        sources: Optional[str] = None,
        subtype: Optional[List[str]] = None,
        year: Optional[List[str]] = None
    ) -> List[Tuple["Document", float]]:
        """
        Perform vector similarity search.

        Args:
            vectorstore: QdrantVectorStore instance
            query: Search query
            k: Number of documents to retrieve
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years

        Returns:
            List of (Document, score) tuples as returned by the vector store,
            or an empty list when the search fails.
        """
        try:
            # Create filter
            filter_obj = create_filter(
                reports=reports,
                sources=sources,
                subtype=subtype,
                year=year
            )
            # Perform vector search; only pass a filter when one was built.
            if filter_obj:
                return vectorstore.similarity_search_with_score(
                    query, k=k, filter=filter_obj
                )
            return vectorstore.similarity_search_with_score(query, k=k)
        except Exception as e:
            print(f"❌ Vector search failed: {e}")
            return []

    def _normalize_scores(
        self,
        results: List[Tuple["Document", float]],
        method: str = "min_max"
    ) -> List[Tuple["Document", float]]:
        """
        Normalize scores so different retrievers become comparable.

        Args:
            results: List of (Document, score) tuples
            method: Normalization method. ``'min_max'`` rescales to [0, 1].
                ``'z_score'`` standardizes and clips negatives to 0, so values
                are >= 0 but not bounded above. Any other value returns the
                results unchanged.

        Returns:
            List of (Document, normalized_score) tuples. When all scores are
            identical every document receives 1.0.
        """
        if not results:
            return results

        scores = [score for _, score in results]
        if method == "min_max":
            min_score = min(scores)
            max_score = max(scores)
            if max_score == min_score:
                return [(doc, 1.0) for doc, _ in results]
            span = max_score - min_score
            return [(doc, (score - min_score) / span) for doc, score in results]

        if method == "z_score":
            mean_score = np.mean(scores)
            std_score = np.std(scores)
            if std_score == 0:
                return [(doc, 1.0) for doc, _ in results]
            # float() keeps NumPy scalar types out of the returned tuples.
            return [
                (doc, float(max(0, (score - mean_score) / std_score)))
                for doc, score in results
            ]

        return results

    @staticmethod
    def _doc_key(doc: "Document") -> Tuple[Any, Any]:
        """Build a content-based identity for a document.

        Vector-store hits and BM25 hits are distinct Python objects even when
        they represent the same chunk, so object identity cannot be used to
        match them. The key is the text plus the ``filename`` metadata.
        NOTE(review): assumes both retrievers carry the same ``filename``
        metadata for a given chunk - confirm against the indexing pipeline.
        """
        metadata = getattr(doc, "metadata", None) or {}
        return (getattr(doc, "page_content", None), metadata.get("filename"))

    def _best_by_key(
        self,
        results: List[Tuple["Document", float]]
    ) -> Dict[Tuple[Any, Any], Tuple["Document", float]]:
        """Map each document key to its highest-scoring (doc, score) entry."""
        best: Dict[Tuple[Any, Any], Tuple[Any, float]] = {}
        for doc, score in results:
            key = self._doc_key(doc)
            if key not in best or score > best[key][1]:
                best[key] = (doc, score)
        return best

    def _combine_results(
        self,
        vector_results: List[Tuple["Document", float]],
        bm25_results: List[Tuple["Document", float]],
        alpha: float = 0.5
    ) -> List[Tuple["Document", float]]:
        """
        Combine vector and BM25 results with weighted scoring.

        Both lists are min-max normalized, matched by document content (see
        ``_doc_key``) and blended. A document missing from one list
        contributes 0.0 for that list. When a document appears in both lists
        the vector-store object is the one returned.

        Args:
            vector_results: Vector search results
            bm25_results: BM25 search results
            alpha: Weight for vector scores (1-alpha for BM25 scores)

        Returns:
            Combined results sorted by score, descending. Ties keep a
            deterministic order: vector hits first, then BM25-only hits.
        """
        # Normalize scores
        vector_docs = self._best_by_key(self._normalize_scores(vector_results))
        bm25_docs = self._best_by_key(self._normalize_scores(bm25_results))

        combined = []
        for key, (doc, vector_score) in vector_docs.items():
            bm25_score = bm25_docs[key][1] if key in bm25_docs else 0.0
            # Weighted combination
            combined.append((doc, alpha * vector_score + (1 - alpha) * bm25_score))
        for key, (doc, bm25_score) in bm25_docs.items():
            if key not in vector_docs:
                combined.append((doc, (1 - alpha) * bm25_score))

        # Sort by combined score (descending); list.sort is stable.
        combined.sort(key=lambda item: item[1], reverse=True)
        return combined

    @classmethod
    def _fusion_pool_size(cls, k: int) -> int:
        """Return how many candidates each retriever should fetch for fusion.

        Twice ``k`` capped at ``_MAX_FUSION_CANDIDATES``, but never fewer than
        ``k`` itself, so large ``k`` values are not silently truncated.
        """
        return max(k, min(k * 2, cls._MAX_FUSION_CANDIDATES))

    def retrieve(
        self,
        vectorstore: "QdrantVectorStore",
        query: str,
        mode: str = "hybrid",
        reports: Optional[List[str]] = None,
        sources: Optional[str] = None,
        subtype: Optional[List[str]] = None,
        year: Optional[List[str]] = None,
        alpha: float = 0.5,
        k: Optional[int] = None
    ) -> List["Document"]:
        """
        Retrieve documents using the specified search mode.

        Args:
            vectorstore: QdrantVectorStore instance
            query: Search query
            mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years
            alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
            k: Number of documents to retrieve; defaults to the configured
                ``retriever.top_k`` (20 when absent)

        Returns:
            List of relevant Document objects

        Raises:
            ValueError: If ``mode`` is not a known search mode.
        """
        scored = self.retrieve_with_scores(
            vectorstore,
            query,
            mode=mode,
            reports=reports,
            sources=sources,
            subtype=subtype,
            year=year,
            alpha=alpha,
            k=k,
        )
        # Return just the documents
        return [doc for doc, _ in scored]

    def retrieve_with_scores(
        self,
        vectorstore: "QdrantVectorStore",
        query: str,
        mode: str = "hybrid",
        reports: Optional[List[str]] = None,
        sources: Optional[str] = None,
        subtype: Optional[List[str]] = None,
        year: Optional[List[str]] = None,
        alpha: float = 0.5,
        k: Optional[int] = None
    ) -> List[Tuple["Document", float]]:
        """
        Retrieve documents with scores using the specified search mode.

        Args:
            vectorstore: QdrantVectorStore instance
            query: Search query
            mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
            reports: List of specific report filenames
            sources: Source category
            subtype: List of subtypes
            year: List of years
            alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
            k: Number of documents to retrieve; defaults to the configured
                ``retriever.top_k`` (20 when absent)

        Returns:
            List of (Document, score) tuples, at most ``k`` long

        Raises:
            ValueError: If ``mode`` is not a known search mode.
        """
        if k is None:
            k = self.config.get("retriever", {}).get("top_k", 20)

        if mode == "vector_only":
            # Vector search only
            results = self._vector_search(
                vectorstore, query, k, reports, sources, subtype, year
            )
        elif mode == "sparse_only":
            # BM25 search only
            results = self._bm25_search(
                query, k, reports, sources, subtype, year
            )
        elif mode == "hybrid":
            # Hybrid search - fetch a larger pool from each retriever so the
            # fusion step has more candidates to rank.
            retrieval_k = self._fusion_pool_size(k)
            vector_results = self._vector_search(
                vectorstore, query, retrieval_k, reports, sources, subtype, year
            )
            bm25_results = self._bm25_search(
                query, retrieval_k, reports, sources, subtype, year
            )
            results = self._combine_results(vector_results, bm25_results, alpha)
        else:
            raise ValueError(f"Unknown search mode: {mode}")

        # Limit to top k results
        return results[:k]
def get_available_search_modes() -> List[str]:
    """Get list of available search modes.

    Returns:
        A new list with the mode names "vector_only", "sparse_only" and
        "hybrid", in that order.
    """
    return ["vector_only", "sparse_only", "hybrid"]
def get_search_mode_description() -> Dict[str, str]:
    """Get descriptions for each search mode.

    Returns:
        A new dict mapping each mode name ("vector_only", "sparse_only",
        "hybrid") to a one-line human-readable description.
    """
    return {
        "vector_only": "Semantic search using dense embeddings - good for conceptual matching",
        "sparse_only": "Keyword search using BM25 - good for exact term matching",
        "hybrid": "Combined semantic and keyword search - balanced approach",
    }