Spaces:
Sleeping
Sleeping
File size: 5,258 Bytes
8779583 c6706bd 8779583 c6706bd 8779583 1006fab c6706bd 8779583 c6706bd 8779583 8ad42f5 c6706bd 8779583 c6706bd 8779583 c6706bd 8779583 c6706bd 1006fab c6706bd 1006fab c6706bd 8779583 c6706bd 1006fab c6706bd 1006fab c6706bd 1006fab c6706bd 8779583 1006fab 8779583 c6706bd 8779583 c6706bd 1006fab c6706bd 1006fab c6706bd 8779583 c6706bd 1006fab c6706bd 8779583 c6706bd 8779583 c6706bd 8779583 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
"""FAISS-based semantic search engine using ID-mapped index"""
import faiss
import numpy as np
from typing import List, Tuple
import os
class SearchEngine:
    """FAISS-based search engine for image embeddings.

    Embeddings are stored in an ID-mapped FAISS index so that every vector is
    addressed by its photo ID. The index is persisted to ``index_path``:
    ``add_embedding`` writes it after every insertion, and ``search`` /
    ``create_albums`` re-read it from disk before querying.

    Attributes:
        dim: Dimensionality used when a new, empty index has to be created.
        index_path: Location of the serialized FAISS index on disk.
        index: The FAISS index currently held in memory.
    """

    def __init__(self, dim: int = 1024, index_path: str = "faiss_index.bin"):
        """
        Open the index stored at ``index_path`` or start with an empty one.

        Args:
            dim: Embedding dimensionality for a newly created index
            index_path: File the index is loaded from and saved to
        """
        self.dim = dim
        self.index_path = index_path
        # load() reads the file when it exists and otherwise builds an empty
        # ID-mapped index, which is exactly the start-up behavior needed here.
        self.load()

    def _new_index(self):
        """Build an empty flat L2 index wrapped so vectors carry explicit IDs."""
        return faiss.IndexIDMap(faiss.IndexFlatL2(self.dim))

    @staticmethod
    def _filter_hits(ids, distances, distance_threshold: float) -> List[Tuple[int, float]]:
        """Keep real hits (FAISS pads missing results with ID -1) within the threshold."""
        return [
            (int(photo_id), float(distance))
            for photo_id, distance in zip(ids, distances)
            if photo_id != -1 and distance <= distance_threshold
        ]

    def create_albums(self, top_k: int = 5, distance_threshold: float = 0.3) -> List[List[int]]:
        """
        Group similar images into albums (clusters).

        Walks over every photo ID stored in the index. For each photo that is
        not yet part of an album, its embedding is read from the database and
        used to query the index for the ``top_k`` nearest neighbours. The
        neighbours that are within ``distance_threshold`` and not already
        assigned form a new album, and are marked as visited so that no photo
        ends up in two albums.

        Args:
            top_k: Number of neighbours requested per query; an album therefore
                holds at most ``top_k`` photos
            distance_threshold: Maximum distance to consider photos as similar
                (default 0.3)

        Returns:
            List of albums, each album is a list of photo_ids
        """
        # Imported lazily, as in the original code (presumably to avoid a
        # circular import between the search module and the database layer).
        from cloudzy.database import SessionLocal
        from cloudzy.models import Photo
        from sqlmodel import select

        self.load()
        if self.index.ntotal == 0:
            return []

        # Copy all stored photo IDs out of the ID map in one call and convert
        # them to plain Python ints.
        all_ids = faiss.vector_to_array(self.index.id_map).tolist()

        visited = set()
        albums = []

        # One session serves the whole pass instead of one session per photo.
        session = SessionLocal()
        try:
            for photo_id in all_ids:
                # Skip photos that already belong to an album
                if photo_id in visited:
                    continue

                photo = session.exec(select(Photo).where(Photo.id == photo_id)).first()
                if not photo:
                    continue

                embedding = photo.get_embedding()
                # Explicit emptiness test: a bare truth test would raise if
                # get_embedding() returned a multi-element NumPy array.
                if embedding is None or len(embedding) == 0:
                    continue

                query_embedding = np.asarray(embedding, dtype=np.float32).reshape(1, -1)
                distances, ids = self.index.search(query_embedding, top_k)

                # Build the album from unvisited neighbours within the threshold
                album = []
                for pid, distance in zip(ids[0], distances[0]):
                    pid = int(pid)
                    if pid != -1 and pid not in visited and distance <= distance_threshold:
                        album.append(pid)
                        visited.add(pid)

                # Only keep albums that contain at least one photo
                if album:
                    albums.append(album)
        finally:
            session.close()

        return albums

    def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
        """
        Add an embedding to the index and persist the index to disk.

        Args:
            photo_id: Unique photo identifier
            embedding: Array-like holding exactly one vector whose length
                equals the index dimension

        Raises:
            ValueError: If the embedding size does not match the index dimension
        """
        # FAISS expects a float32 matrix with one row per vector
        vector = np.asarray(embedding, dtype=np.float32).reshape(1, -1)

        # Validate against the index itself: an index loaded from disk may
        # have been built with a dimension different from self.dim.
        expected_dim = self.index.d
        if vector.shape[1] != expected_dim:
            raise ValueError(
                f"embedding has {vector.shape[1]} values, but the index expects {expected_dim}"
            )

        self.index.add_with_ids(vector, np.array([photo_id], dtype=np.int64))
        self.save()

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 5,
        distance_threshold: float = 0.5,
    ) -> List[Tuple[int, float]]:
        """
        Search for similar embeddings.

        The index is re-read from disk first, so results reflect the latest
        saved state.

        Args:
            query_embedding: Array-like holding one query vector
            top_k: Number of nearest neighbours requested from FAISS
            distance_threshold: Maximum distance for a hit to be returned
                (default 0.5). For the flat L2 index created by this class,
                FAISS reports squared L2 distances.

        Returns:
            List of (photo_id, distance) tuples with
            distance <= distance_threshold, in the order FAISS returned them
        """
        self.load()
        if self.index.ntotal == 0:
            return []

        query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        distances, ids = self.index.search(query, top_k)
        return self._filter_hits(ids[0], distances[0], distance_threshold)

    def save(self) -> None:
        """Save FAISS index to disk.

        The index is written to a temporary file next to ``index_path`` and
        then moved into place, so a reader never sees a half-written file.
        """
        tmp_path = f"{self.index_path}.tmp"
        try:
            faiss.write_index(self.index, tmp_path)
            os.replace(tmp_path, self.index_path)
        finally:
            # Only still present when writing or replacing failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load(self) -> None:
        """Load FAISS index from disk, or create an empty one if the file is missing."""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
        else:
            # Recreate empty ID-mapped index if missing
            self.index = self._new_index()

    def get_stats(self) -> dict:
        """Get index statistics.

        Returns:
            Dict with the number of stored embeddings, the configured
            dimension and the class name of the in-memory index
        """
        return {
            "total_embeddings": self.index.ntotal,
            "dimension": self.dim,
            "index_type": type(self.index).__name__,
        }
|