Ara Yeroyan
add src
f5df983
raw
history blame
16.7 kB
"""Hybrid search implementation combining vector and sparse retrieval."""
import json
import numpy as np
from typing import List, Dict, Any, Tuple
from pathlib import Path
from langchain.docstore.document import Document
from langchain_qdrant import QdrantVectorStore
from langchain_community.retrievers import BM25Retriever
from .filter import create_filter
import pickle
import os
class HybridRetriever:
"""
Hybrid retrieval system combining vector search (dense) and BM25 (sparse) search.
Supports configurable search modes: vector_only, sparse_only, or hybrid.
"""
def __init__(self, config: Dict[str, Any]):
"""
Initialize hybrid retriever.
Args:
config: Configuration dictionary with hybrid search settings
"""
self.config = config
self.bm25_retriever = None
self.documents = []
self._bm25_cache_file = None
def _get_bm25_cache_path(self) -> str:
"""Get path for BM25 cache file."""
cache_dir = Path("cache/bm25")
cache_dir.mkdir(parents=True, exist_ok=True)
return str(cache_dir / "bm25_retriever.pkl")
def initialize_bm25(self, documents: List[Document], force_rebuild: bool = False) -> None:
"""
Initialize BM25 retriever with documents.
Args:
documents: List of Document objects to index
force_rebuild: Whether to force rebuilding the BM25 index
"""
self.documents = documents
self._bm25_cache_file = self._get_bm25_cache_path()
# Try to load cached BM25 retriever
if not force_rebuild and os.path.exists(self._bm25_cache_file):
try:
print("Loading cached BM25 retriever...")
with open(self._bm25_cache_file, 'rb') as f:
self.bm25_retriever = pickle.load(f)
print(f"✅ Loaded cached BM25 retriever with {len(self.documents)} documents")
return
except Exception as e:
print(f"⚠️ Failed to load cached BM25 retriever: {e}")
print("Building new BM25 index...")
# Build new BM25 retriever
print("Building BM25 index...")
try:
# Use langchain's BM25Retriever
self.bm25_retriever = BM25Retriever.from_documents(documents)
# Configure BM25 parameters
bm25_config = self.config.get("bm25", {})
k = bm25_config.get("top_k", 20)
self.bm25_retriever.k = k
# Cache the BM25 retriever
with open(self._bm25_cache_file, 'wb') as f:
pickle.dump(self.bm25_retriever, f)
print(f"✅ Built and cached BM25 retriever with {len(documents)} documents")
except Exception as e:
print(f"❌ Failed to build BM25 retriever: {e}")
print("BM25 search will be disabled")
self.bm25_retriever = None
def _filter_documents_by_metadata(
self,
documents: List[Document],
reports: List[str] = None,
sources: str = None,
subtype: List[str] = None,
year: List[str] = None
) -> List[Document]:
"""
Filter documents by metadata criteria.
Args:
documents: List of documents to filter
reports: List of specific report filenames
sources: Source category
subtype: List of subtypes
year: List of years
Returns:
Filtered list of documents
"""
if not any([reports, sources, subtype, year]):
return documents
filtered_docs = []
for doc in documents:
metadata = doc.metadata
# Filter by reports
if reports:
filename = metadata.get('filename', '')
if not any(report in filename for report in reports):
continue
# Filter by sources
if sources:
doc_source = metadata.get('source', '')
if sources != doc_source:
continue
# Filter by subtype
if subtype:
doc_subtype = metadata.get('subtype', '')
if doc_subtype not in subtype:
continue
# Filter by year
if year:
doc_year = str(metadata.get('year', ''))
if doc_year not in year:
continue
filtered_docs.append(doc)
return filtered_docs
def _bm25_search(
self,
query: str,
k: int = 20,
reports: List[str] = None,
sources: str = None,
subtype: List[str] = None,
year: List[str] = None
) -> List[Tuple[Document, float]]:
"""
Perform BM25 sparse search.
Args:
query: Search query
k: Number of documents to retrieve
reports: List of specific report filenames
sources: Source category
subtype: List of subtypes
year: List of years
Returns:
List of (Document, score) tuples
"""
if not self.bm25_retriever:
print("⚠️ BM25 retriever not available")
return []
try:
# Get BM25 results
self.bm25_retriever.k = k
bm25_docs = self.bm25_retriever.invoke(query)
# Apply metadata filtering
if any([reports, sources, subtype, year]):
bm25_docs = self._filter_documents_by_metadata(
bm25_docs, reports, sources, subtype, year
)
# BM25Retriever doesn't return scores directly, so we'll use placeholder scores
# In a production system, you'd want to access the actual BM25 scores
results = []
for i, doc in enumerate(bm25_docs):
# Assign decreasing scores based on rank (higher rank = higher score)
# Normalize to [0, 1] range for consistency with vector search
score = max(0.1, 1.0 - (i / max(len(bm25_docs), 1)))
results.append((doc, score))
return results
except Exception as e:
print(f"❌ BM25 search failed: {e}")
return []
def _vector_search(
self,
vectorstore: QdrantVectorStore,
query: str,
k: int = 20,
reports: List[str] = None,
sources: str = None,
subtype: List[str] = None,
year: List[str] = None
) -> List[Tuple[Document, float]]:
"""
Perform vector similarity search.
Args:
vectorstore: QdrantVectorStore instance
query: Search query
k: Number of documents to retrieve
reports: List of specific report filenames
sources: Source category
subtype: List of subtypes
year: List of years
Returns:
List of (Document, score) tuples
"""
try:
# Create filter
filter_obj = create_filter(
reports=reports,
sources=sources,
subtype=subtype,
year=year
)
# Perform vector search
if filter_obj:
results = vectorstore.similarity_search_with_score(
query, k=k, filter=filter_obj
)
else:
results = vectorstore.similarity_search_with_score(query, k=k)
return results
except Exception as e:
print(f"❌ Vector search failed: {e}")
return []
def _normalize_scores(self, results: List[Tuple[Document, float]], method: str = "min_max") -> List[Tuple[Document, float]]:
"""
Normalize scores to [0, 1] range.
Args:
results: List of (Document, score) tuples
method: Normalization method ('min_max' or 'z_score')
Returns:
List of (Document, normalized_score) tuples
"""
if not results:
return results
scores = [score for _, score in results]
if method == "min_max":
min_score = min(scores)
max_score = max(scores)
if max_score == min_score:
normalized_results = [(doc, 1.0) for doc, _ in results]
else:
normalized_results = [
(doc, (score - min_score) / (max_score - min_score))
for doc, score in results
]
elif method == "z_score":
mean_score = np.mean(scores)
std_score = np.std(scores)
if std_score == 0:
normalized_results = [(doc, 1.0) for doc, _ in results]
else:
normalized_results = [
(doc, max(0, (score - mean_score) / std_score))
for doc, score in results
]
else:
normalized_results = results
return normalized_results
def _combine_results(
self,
vector_results: List[Tuple[Document, float]],
bm25_results: List[Tuple[Document, float]],
alpha: float = 0.5
) -> List[Tuple[Document, float]]:
"""
Combine vector and BM25 results with weighted scoring.
Args:
vector_results: Vector search results
bm25_results: BM25 search results
alpha: Weight for vector scores (1-alpha for BM25 scores)
Returns:
Combined and ranked results
"""
# Normalize scores
vector_results = self._normalize_scores(vector_results)
bm25_results = self._normalize_scores(bm25_results)
# Create document ID mapping for both result sets
vector_docs = {id(doc): (doc, score) for doc, score in vector_results}
bm25_docs = {id(doc): (doc, score) for doc, score in bm25_results}
# Combine scores
combined_scores = {}
all_doc_ids = set(vector_docs.keys()) | set(bm25_docs.keys())
for doc_id in all_doc_ids:
vector_score = vector_docs.get(doc_id, (None, 0.0))[1]
bm25_score = bm25_docs.get(doc_id, (None, 0.0))[1]
# Weighted combination
combined_score = alpha * vector_score + (1 - alpha) * bm25_score
# Get document object
doc = vector_docs.get(doc_id, bm25_docs.get(doc_id))[0]
combined_scores[doc_id] = (doc, combined_score)
# Sort by combined score (descending)
sorted_results = sorted(
combined_scores.values(),
key=lambda x: x[1],
reverse=True
)
return sorted_results
def retrieve(
self,
vectorstore: QdrantVectorStore,
query: str,
mode: str = "hybrid",
reports: List[str] = None,
sources: str = None,
subtype: List[str] = None,
year: List[str] = None,
alpha: float = 0.5,
k: int = None
) -> List[Document]:
"""
Retrieve documents using the specified search mode.
Args:
vectorstore: QdrantVectorStore instance
query: Search query
mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
reports: List of specific report filenames
sources: Source category
subtype: List of subtypes
year: List of years
alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
k: Number of documents to retrieve
Returns:
List of relevant Document objects
"""
if k is None:
k = self.config.get("retriever", {}).get("top_k", 20)
results = []
if mode == "vector_only":
# Vector search only
vector_results = self._vector_search(
vectorstore, query, k, reports, sources, subtype, year
)
results = [(doc, score) for doc, score in vector_results]
elif mode == "sparse_only":
# BM25 search only
bm25_results = self._bm25_search(
query, k, reports, sources, subtype, year
)
results = [(doc, score) for doc, score in bm25_results]
elif mode == "hybrid":
# Hybrid search - combine both
# Get more results from each method to have better fusion
retrieval_k = min(k * 2, 50) # Get more candidates for fusion
vector_results = self._vector_search(
vectorstore, query, retrieval_k, reports, sources, subtype, year
)
bm25_results = self._bm25_search(
query, retrieval_k, reports, sources, subtype, year
)
results = self._combine_results(vector_results, bm25_results, alpha)
else:
raise ValueError(f"Unknown search mode: {mode}")
# Limit to top k results
results = results[:k]
# Return just the documents
return [doc for doc, score in results]
def retrieve_with_scores(
self,
vectorstore: QdrantVectorStore,
query: str,
mode: str = "hybrid",
reports: List[str] = None,
sources: str = None,
subtype: List[str] = None,
year: List[str] = None,
alpha: float = 0.5,
k: int = None
) -> List[Tuple[Document, float]]:
"""
Retrieve documents with scores using the specified search mode.
Args:
vectorstore: QdrantVectorStore instance
query: Search query
mode: Search mode ('vector_only', 'sparse_only', or 'hybrid')
reports: List of specific report filenames
sources: Source category
subtype: List of subtypes
year: List of years
alpha: Weight for vector scores in hybrid mode (0.5 = equal weight)
k: Number of documents to retrieve
Returns:
List of (Document, score) tuples
"""
if k is None:
k = self.config.get("retriever", {}).get("top_k", 20)
results = []
if mode == "vector_only":
# Vector search only
results = self._vector_search(
vectorstore, query, k, reports, sources, subtype, year
)
elif mode == "sparse_only":
# BM25 search only
results = self._bm25_search(
query, k, reports, sources, subtype, year
)
elif mode == "hybrid":
# Hybrid search - combine both
# Get more results from each method to have better fusion
retrieval_k = min(k * 2, 50) # Get more candidates for fusion
vector_results = self._vector_search(
vectorstore, query, retrieval_k, reports, sources, subtype, year
)
bm25_results = self._bm25_search(
query, retrieval_k, reports, sources, subtype, year
)
results = self._combine_results(vector_results, bm25_results, alpha)
else:
raise ValueError(f"Unknown search mode: {mode}")
# Limit to top k results
return results[:k]
def get_available_search_modes() -> List[str]:
"""Get list of available search modes."""
return ["vector_only", "sparse_only", "hybrid"]
def get_search_mode_description() -> Dict[str, str]:
"""Get descriptions for each search mode."""
return {
"vector_only": "Semantic search using dense embeddings - good for conceptual matching",
"sparse_only": "Keyword search using BM25 - good for exact term matching",
"hybrid": "Combined semantic and keyword search - balanced approach"
}