audit_assistant / src /retrieval /colbert_cache.py
Ara Yeroyan
add src
f5df983
raw
history blame
2.75 kB
"""
ColBERT embeddings cache for test set documents.
Provides O(1) lookup for ColBERT embeddings during late interaction.
"""
import json
import numpy as np
from pathlib import Path
from typing import Dict, Optional, Any
class ColBERTCache:
    """In-memory store of ColBERT embeddings loaded from a quantized JSON cache.

    The cache file is a JSON object whose values each describe one 8-bit
    quantized embedding (``min``, ``max``, ``embedding``, ``shape``). Every
    entry is dequantized once, at construction time, so later lookups are
    plain dict accesses. The lookup methods name their argument
    ``document_text``, so the keys are presumably the raw document texts --
    NOTE(review): confirm against the script that writes the cache.
    """

    def __init__(
        self,
        cache_file: str = "test_set_colbert_cache.json",
        cache_dir: str = "outputs/caches",
    ):
        """Resolve the cache path and eagerly load its contents.

        Args:
            cache_file: Name of the cache file, resolved inside ``cache_dir``.
                An absolute path is used as-is (``pathlib`` join semantics).
            cache_dir: Directory holding the cache file. The default is
                relative to the current working directory and matches the
                previously hard-coded location, so existing callers are
                unaffected.
        """
        self.cache_file = Path(cache_dir) / cache_file
        # Maps cache key -> dequantized embedding array.
        self.embeddings_cache: Dict[str, np.ndarray] = {}
        self._load_cache()

    @staticmethod
    def _dequantize(entry: Dict[str, Any]) -> np.ndarray:
        """Rebuild one embedding from its quantized cache entry.

        ``entry['embedding']`` holds integer codes (expected 0..255) that are
        mapped linearly back onto ``[entry['min'], entry['max']]`` and then
        reshaped to ``entry['shape']``. The arithmetic runs in float32.
        """
        lo = entry['min']
        hi = entry['max']
        codes = np.array(entry['embedding'], dtype=np.uint8)
        unit = codes.astype(np.float32) / 255.0  # codes scaled into [0, 1]
        return (unit * (hi - lo) + lo).reshape(entry['shape'])

    def _load_cache(self) -> None:
        """Populate ``embeddings_cache`` from ``cache_file``.

        Loading is best-effort. A missing file, or any error while reading or
        decoding, leaves the cache empty and prints a message instead of
        raising, so lookups then return ``None``/``False``.
        """
        if not self.cache_file.exists():
            print(f"⚠️ ColBERT cache not found: {self.cache_file}")
            print("πŸ’‘ Run 'python precalculate_test_set_colbert.py' to create cache")
            return

        print(f"πŸ“‚ Loading ColBERT cache from {self.cache_file}...")
        try:
            # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                cache_data = json.load(f)
            # Decode into a fresh dict so a failure part-way through never
            # leaves a half-populated cache behind.
            loaded: Dict[str, np.ndarray] = {
                doc_id: self._dequantize(entry)
                for doc_id, entry in cache_data.items()
            }
        except Exception as e:
            # Deliberately broad: a corrupt cache degrades to "no cache"
            # instead of crashing the caller.
            print(f"❌ Error loading ColBERT cache: {e}")
            self.embeddings_cache = {}
            return

        self.embeddings_cache = loaded
        print(f"βœ… Loaded {len(self.embeddings_cache)} ColBERT embeddings from cache")

    def get_embedding(self, document_text: str) -> Optional[np.ndarray]:
        """Return the cached embedding for ``document_text``, or ``None`` if absent."""
        return self.embeddings_cache.get(document_text)

    def has_embedding(self, document_text: str) -> bool:
        """Return True when an embedding is cached for ``document_text``."""
        return document_text in self.embeddings_cache

    def get_cache_stats(self) -> Dict[str, Any]:
        """Summarize the cache.

        Returns:
            A dict with the number of loaded embeddings, the cache path as a
            string, and whether that file exists right now (checked on each
            call, not remembered from load time).
        """
        return {
            'total_embeddings': len(self.embeddings_cache),
            'cache_file': str(self.cache_file),
            'cache_exists': self.cache_file.exists(),
        }
# Process-wide ColBERTCache, created lazily on the first call to get_colbert_cache().
_colbert_cache = None


def get_colbert_cache() -> ColBERTCache:
    """Return the shared ColBERTCache, constructing it on first use.

    The first call builds a default ``ColBERTCache()``, which loads the cache
    file. Every later call returns that same instance without reloading.
    """
    global _colbert_cache
    if _colbert_cache is not None:
        return _colbert_cache
    _colbert_cache = ColBERTCache()
    return _colbert_cache