|
|
""" |
|
|
ColBERT embeddings cache for test set documents. |
|
|
Provides O(1) lookup for ColBERT embeddings during late interaction. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import numpy as np |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Any |
|
|
|
|
|
|
|
|
class ColBERTCache:
    """Cache for ColBERT embeddings of test set documents.

    Embeddings are stored on disk as a JSON file with uint8-quantized
    values plus per-document min/max and shape; on load they are
    dequantized back to float32 and held in memory keyed by document
    text, giving O(1) lookup during late interaction.
    """

    def __init__(self, cache_file: str = "test_set_colbert_cache.json"):
        """Initialize the cache and eagerly load it from disk.

        Args:
            cache_file: File name of the cache JSON under outputs/caches.
        """
        # Cache JSON lives under outputs/caches relative to the CWD.
        self.cache_file = Path("outputs/caches") / cache_file
        # Maps document text -> dequantized float32 embedding matrix.
        self.embeddings_cache: Dict[str, np.ndarray] = {}
        self._load_cache()

    def _load_cache(self) -> None:
        """Load and dequantize embeddings from the cache file, if present.

        Any failure (missing file, corrupt JSON, bad entries) degrades to
        an empty cache rather than raising, so callers simply see misses.
        """
        if not self.cache_file.exists():
            print(f"⚠️ ColBERT cache not found: {self.cache_file}")
            print("💡 Run 'python precalculate_test_set_colbert.py' to create cache")
            return

        print(f"📂 Loading ColBERT cache from {self.cache_file}...")

        try:
            with open(self.cache_file, 'r') as f:
                cache_data = json.load(f)

            for doc_id, data in cache_data.items():
                # Each entry stores a flat uint8 array plus the min/max
                # used for linear quantization and the original shape.
                embedding_min = data['min']
                embedding_max = data['max']
                quantized_embedding = np.array(data['embedding'], dtype=np.uint8)

                # Invert the quantization: map [0, 255] back to [min, max].
                reconstructed = (quantized_embedding.astype(np.float32) / 255.0) * (embedding_max - embedding_min) + embedding_min
                self.embeddings_cache[doc_id] = reconstructed.reshape(data['shape'])

            print(f"✅ Loaded {len(self.embeddings_cache)} ColBERT embeddings from cache")

        except Exception as e:
            # Deliberate best-effort: report and fall back to an empty cache.
            print(f"❌ Error loading ColBERT cache: {e}")
            self.embeddings_cache = {}

    def get_embedding(self, document_text: str) -> Optional[np.ndarray]:
        """Return the cached ColBERT embedding for a document, or None (O(1) lookup)."""
        return self.embeddings_cache.get(document_text)

    def has_embedding(self, document_text: str) -> bool:
        """Return True if an embedding is cached for the given document text."""
        return document_text in self.embeddings_cache

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return summary statistics: entry count, cache path, and existence."""
        return {
            'total_embeddings': len(self.embeddings_cache),
            'cache_file': str(self.cache_file),
            'cache_exists': self.cache_file.exists()
        }
|
|
|
|
|
|
|
|
|
|
|
# Module-level singleton instance; populated lazily by get_colbert_cache().
_colbert_cache: Optional[ColBERTCache] = None
|
|
|
|
|
def get_colbert_cache() -> ColBERTCache:
    """Return the module-wide ColBERTCache, constructing it on first use."""
    global _colbert_cache
    cache = _colbert_cache
    if cache is None:
        # First call: build the cache (loads embeddings from disk) and memoize it.
        cache = ColBERTCache()
        _colbert_cache = cache
    return cache