cloudzy_ai_challenge / cloudzy /search_engine.py
matinsn2000's picture
Added embedding
1006fab
raw
history blame
2.99 kB
"""FAISS-based semantic search engine"""
import faiss
import numpy as np
from typing import List, Tuple, Optional
import os
import pickle
class SearchEngine:
    """FAISS-based search engine for image embeddings.

    The engine keeps two pieces of state that must stay aligned:

    * ``index`` -- the FAISS index holding the vectors, persisted at
      ``index_path``.
    * ``id_map`` -- a list where position ``i`` holds the photo ID of the
      vector stored at FAISS row ``i``, persisted at ``index_path + ".ids"``.
    """

    def __init__(self, dim: int = 1024, index_path: str = "faiss_index.bin"):
        """
        Open the index stored at ``index_path`` or create an empty one.

        Args:
            dim: Dimensionality used when a new ``IndexFlatL2`` is created.
            index_path: Location of the persisted FAISS index; the ID list
                lives next to it with an added ``.ids`` suffix.
        """
        self.dim = dim
        self.index_path = index_path
        self.id_map: List[int] = []  # Map FAISS indices to photo IDs

        # Load existing index or create new one
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            # Restore the ID list together with the index. Without this the
            # next add_embedding() would append to an empty list and save()
            # would overwrite the persisted IDs, desynchronising both files.
            self._load_id_map()
        else:
            self.index = faiss.IndexFlatL2(dim)

    def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
        """
        Add an embedding to the index and persist the result.

        Args:
            photo_id: Unique photo identifier
            embedding: 1D numpy array of shape (dim,)
        """
        # FAISS expects a float32 matrix with one row per vector.
        vector = np.asarray(embedding, dtype=np.float32).reshape(1, -1)

        # Add to FAISS index
        self.index.add(vector)

        # Track photo ID at the same position as the new FAISS row
        self.id_map.append(photo_id)

        # Save index and IDs to disk after every insertion
        self.save()

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 5,
        max_distance: float = 0.5,
    ) -> List[Tuple[int, float]]:
        """
        Search for similar embeddings.

        The index and ID list are reloaded from disk first, so results
        reflect whatever is currently persisted at ``index_path``.

        Args:
            query_embedding: 1D numpy array of shape (dim,)
            top_k: Maximum number of results to return
            max_distance: Largest distance (as reported by the FAISS index)
                a hit may have to be included. Defaults to 0.5.

        Returns:
            List of (photo_id, distance) tuples with distance <= max_distance,
            in the order returned by FAISS. Empty when the index holds no
            vectors or ``top_k`` is not positive.
        """
        self.load()

        if top_k <= 0 or self.index.ntotal == 0:
            return []

        # Ensure query is float32 and has shape (1, dim)
        query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)

        # Never ask for more neighbours than there are stored vectors
        distances, indices = self.index.search(query, min(top_k, self.index.ntotal))

        # Map back to photo IDs. FAISS marks missing results with label -1;
        # those must be skipped, otherwise id_map[-1] would silently return
        # the last photo ID.
        return [
            (self.id_map[int(idx)], float(distance))
            for distance, idx in zip(distances[0], indices[0])
            if idx >= 0 and distance <= max_distance
        ]

    def save(self) -> None:
        """Save index and id_map to disk"""
        faiss.write_index(self.index, self.index_path)
        with open(self.index_path + ".ids", "wb") as f:
            pickle.dump(self.id_map, f)

    def load(self) -> None:
        """Load index and id_map from disk (each only if its file exists)"""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
        self._load_id_map()

    def _load_id_map(self) -> None:
        """Replace ``id_map`` with the persisted ID list, if one exists."""
        ids_path = self.index_path + ".ids"
        if os.path.exists(ids_path):
            # SECURITY NOTE: pickle can execute arbitrary code while loading.
            # Only point index_path at locations this application controls.
            with open(ids_path, "rb") as f:
                self.id_map = pickle.load(f)

    def get_stats(self) -> dict:
        """
        Get index statistics.

        Returns:
            Dict with the number of vectors in the index, the configured
            dimension and the length of the ID list.
        """
        return {
            "total_embeddings": self.index.ntotal,
            "dimension": self.dim,
            "id_map_size": len(self.id_map),
        }