"""FAISS-based semantic search engine using ID-mapped index"""
import faiss
import numpy as np
from typing import List, Tuple
import os
import random
class SearchEngine:
    """FAISS-based semantic search engine over image embeddings.

    Wraps an ID-mapped flat L2 index (``IndexIDMap`` around ``IndexFlatL2``)
    so search results map back to photo IDs. The index is persisted at
    ``index_path`` and reloaded from disk on demand.
    """

    def __init__(self, dim: int = 4096, index_path: str = "faiss_index.bin"):
        """
        Args:
            dim: Dimensionality of the stored embeddings.
            index_path: File path used to persist/restore the FAISS index.
        """
        self.dim = dim
        self.index_path = index_path
        # Load an existing index from disk, or start with an empty ID-mapped one.
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        else:
            base_index = faiss.IndexFlatL2(dim)
            self.index = faiss.IndexIDMap(base_index)

    def _all_photo_ids(self) -> List[int]:
        """Return every photo ID stored in the ID-mapped index."""
        id_map = self.index.id_map
        return [int(id_map.at(i)) for i in range(id_map.size())]

    def create_albums(self, top_k: int = 5, distance_threshold: float = 1.5, album_size: int = 5) -> List[List[int]]:
        """
        Group similar images into albums (clusters) via nearest-neighbor search.

        Photo IDs are shuffled, then each unvisited photo seeds an album of its
        nearest neighbors within ``distance_threshold``. Photos are marked as
        visited so no photo appears in more than one album. Returns fewer than
        ``top_k`` albums if there are not enough images.

        Optimizations:
        - All photos are fetched in ONE database query (not per photo).
        - Embeddings are cached in memory for the duration of the call.
        - A single session is used for all DB operations.

        Args:
            top_k: Maximum number of albums to return.
            distance_threshold: Maximum L2 distance for photos to count as
                similar (default 1.5; embeddings are assumed normalized, so
                distances lie in [0, 2]).
            album_size: How many similar photos to search for per album.

        Returns:
            Up to ``top_k`` albums, each a list of photo IDs, in randomized
            order on every call. Empty list if the index holds no images.
        """
        # Imported here (not module level) to avoid circular imports.
        from cloudzy.database import SessionLocal
        from cloudzy.models import Photo
        from sqlmodel import select

        self.load()
        if self.index.ntotal == 0:
            return []

        # Shuffle so repeated calls seed albums from different photos.
        all_ids = self._all_photo_ids()
        random.shuffle(all_ids)

        # Batch-retrieve every photo in one query and cache embeddings in memory.
        session = SessionLocal()
        try:
            photos = session.exec(select(Photo).where(Photo.id.in_(all_ids))).all()
            embedding_cache = {}
            for photo in photos:
                embedding = photo.get_embedding()
                # NOTE(review): truthiness check assumes get_embedding() returns
                # a plain list (or None/empty) — confirm it never returns a
                # numpy array, which would make `if embedding` raise.
                if embedding:
                    embedding_cache[photo.id] = embedding
        finally:
            session.close()

        visited = set()
        albums: List[List[int]] = []
        for photo_id in all_ids:
            # Stop once we have enough albums.
            if len(albums) >= top_k:
                break
            # Skip photos already grouped, or with no cached embedding.
            if photo_id in visited or photo_id not in embedding_cache:
                continue

            # Search the index for this photo's nearest neighbors.
            query = np.array(embedding_cache[photo_id]).reshape(1, -1).astype(np.float32)
            distances, ids = self.index.search(query, album_size)

            # Collect unvisited neighbors within the distance threshold.
            # FAISS pads missing results with id -1.
            album = []
            for pid, distance in zip(ids[0], distances[0]):
                if pid != -1 and pid not in visited and distance <= distance_threshold:
                    album.append(int(pid))
                    visited.add(int(pid))

            # Keep the album if it captured at least one photo.
            if album:
                albums.append(album)
        return albums

    def create_albums_kmeans(self, top_k: int = 5, seed: int = 42) -> List[List[int]]:
        """
        Group similar images into albums using FAISS k-means clustering.

        Compared with the nearest-neighbor grouping in ``create_albums``:
        - Uses true k-means clustering instead of ad-hoc neighbor search.
        - Every photo gets assigned to a cluster (no "orphans").
        - Deterministic results for the same seed.
        - k is automatically reduced if there are fewer images than ``top_k``.

        Args:
            top_k: Number of clusters (albums) to create.
            seed: Random seed passed to FAISS k-means for reproducibility.

        Returns:
            Up to ``top_k`` albums, each a list of photo IDs; fewer if the
            index holds fewer images. Empty list if no images exist.
        """
        self.load()
        if self.index.ntotal == 0:
            return []

        # Never ask for more clusters than there are images.
        actual_k = min(top_k, self.index.ntotal)

        all_ids = np.array(self._all_photo_ids(), dtype=np.int64)

        # IndexIDMap wraps the real index; reconstruct the stored vectors from it.
        underlying_index = faiss.downcast_index(self.index.index)
        all_embeddings = underlying_index.reconstruct_n(0, self.index.ntotal).astype(np.float32)

        # Run k-means with the adjusted cluster count.
        kmeans = faiss.Kmeans(
            d=self.dim,
            k=actual_k,
            niter=20,
            verbose=False,
            seed=seed,
        )
        kmeans.train(all_embeddings)

        # Assign each embedding to its nearest centroid.
        _, cluster_assignments = kmeans.index.search(all_embeddings, 1)

        # Group photo IDs by cluster.
        albums: List[List[int]] = [[] for _ in range(actual_k)]
        for photo_id, cluster_id in zip(all_ids, cluster_assignments.flatten()):
            albums[cluster_id].append(int(photo_id))

        # Drop clusters that received no photos.
        return [album for album in albums if album]

    def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
        """
        Add an embedding to the index and persist it to disk.

        Args:
            photo_id: Unique photo identifier.
            embedding: 1D numpy array of shape (dim,).
        """
        # FAISS requires float32 row vectors and int64 IDs.
        embedding = embedding.astype(np.float32).reshape(1, -1)
        self.index.add_with_ids(embedding, np.array([photo_id], dtype=np.int64))
        # Persist immediately so the on-disk index never lags behind.
        self.save()

    def search(self, query_embedding: np.ndarray, top_k: int = 5, distance_threshold: float = 1.5) -> List[Tuple[int, float]]:
        """
        Search for embeddings similar to the query.

        Args:
            query_embedding: 1D numpy array of shape (dim,).
            top_k: Maximum number of results to return.
            distance_threshold: Maximum L2 distance to keep a result
                (default 1.5; with normalized embeddings distances lie in [0, 2]).

        Returns:
            List of (photo_id, distance) tuples with distance <= distance_threshold.
            Empty list if the index holds no images.
        """
        self.load()
        if self.index.ntotal == 0:
            return []

        # Ensure the query is a float32 row vector.
        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)
        distances, ids = self.index.search(query_embedding, top_k)

        # Filter FAISS padding (-1) and results beyond the threshold.
        return [
            (int(photo_id), float(distance))
            for photo_id, distance in zip(ids[0], distances[0])
            if photo_id != -1 and distance <= distance_threshold
        ]

    def save(self) -> None:
        """Save the FAISS index to disk at ``self.index_path``."""
        faiss.write_index(self.index, self.index_path)

    def load(self) -> None:
        """Load the FAISS index from disk, or recreate an empty one if missing."""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
        else:
            base_index = faiss.IndexFlatL2(self.dim)
            self.index = faiss.IndexIDMap(base_index)

    def get_stats(self) -> dict:
        """Return basic index statistics (count, dimension, index type)."""
        return {
            "total_embeddings": self.index.ntotal,
            "dimension": self.dim,
            "index_type": type(self.index).__name__,
        }