Spaces:
Sleeping
Sleeping
File size: 5,616 Bytes
8779583 c6706bd 8779583 c6706bd 4d4fccb c6706bd 8779583 1006fab c6706bd 8779583 c6706bd 8779583 4d4fccb 8ad42f5 4d4fccb 8ad42f5 4d4fccb 8ad42f5 4d4fccb 8ad42f5 4d4fccb 8ad42f5 4d4fccb 8ad42f5 4d4fccb 8ad42f5 c6706bd 8779583 c6706bd 8779583 c6706bd 8779583 c6706bd 1006fab c6706bd 1006fab c6706bd 8779583 c6706bd 1006fab c6706bd 1006fab c6706bd 1006fab c6706bd 8779583 1006fab 8779583 c6706bd 8779583 c6706bd 1006fab c6706bd 1006fab c6706bd 8779583 c6706bd 1006fab c6706bd 8779583 c6706bd 8779583 c6706bd 8779583 4d4fccb ab19ad9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
"""FAISS-based semantic search engine using ID-mapped index"""
import faiss
import numpy as np
from typing import List, Tuple
import os
import random
class SearchEngine:
    """FAISS-based search engine for image embeddings.

    The index file at ``index_path`` is treated as shared state: the query and
    mutation methods reload it from disk before use, and ``add_embedding``
    writes it back immediately, so instances sharing the same file pick up each
    other's additions. There is no file locking, so concurrent writers can
    still race between reload and save.
    """

    def __init__(self, dim: int = 1024, index_path: str = "faiss_index.bin"):
        """
        Args:
            dim: Embedding dimensionality used when a new index has to be created.
            index_path: File the FAISS index is read from and written to.
        """
        self.dim = dim
        self.index_path = index_path
        # Load the existing index, or start with an empty ID-mapped one.
        # NOTE(review): if the file on disk was built with a different
        # dimensionality, self.dim will not match the loaded index — confirm.
        self.load()

    def create_albums(self, top_k: int = 5, distance_threshold: float = 0.3, album_size: int = 5) -> List[List[int]]:
        """
        Group similar images into albums (clusters).

        Photos are visited in random order. Each photo not yet placed in an
        album seeds a new album made of its ``album_size`` nearest neighbours
        (normally including itself) that are within ``distance_threshold`` and
        not already in an earlier album. A photo appears in at most one album.

        Args:
            top_k: Maximum number of albums to return
            distance_threshold: Maximum FAISS distance to consider photos as
                similar (default 0.3). A freshly created index here is
                ``IndexFlatL2``, so distances are compared as FAISS reports them.
            album_size: How many nearest neighbours to search for per album (default 5)

        Returns:
            Up to top_k albums (fewer when there are not enough photos), each a
            list of photo_ids. The result is randomized on every call.
        """
        from cloudzy.database import SessionLocal
        from cloudzy.models import Photo
        from sqlmodel import select

        self.load()
        if self.index.ntotal == 0 or top_k <= 0 or album_size <= 0:
            return []

        # All photo IDs stored in the FAISS index, shuffled so that each call
        # produces a different set of albums.
        id_map = self.index.id_map
        all_ids = [int(id_map.at(i)) for i in range(id_map.size())]
        random.shuffle(all_ids)

        visited = set()
        albums = []
        # One DB session for the whole pass rather than one per photo.
        session = SessionLocal()
        try:
            for photo_id in all_ids:
                if len(albums) >= top_k:
                    break
                if photo_id in visited:
                    continue

                # The seed embedding is read back from the database (presumably
                # because the ID-mapped index is not used to reconstruct vectors).
                photo = session.exec(select(Photo).where(Photo.id == photo_id)).first()
                if photo is None:
                    continue
                embedding = photo.get_embedding()
                # Explicit checks: `not embedding` raises ValueError for
                # multi-element NumPy arrays.
                if embedding is None or len(embedding) == 0:
                    continue

                query = np.asarray(embedding, dtype=np.float32).reshape(1, -1)
                distances, ids = self.index.search(query, album_size)

                # Keep neighbours that are real hits (-1 is FAISS padding when
                # fewer than album_size vectors exist), unassigned, and close enough.
                album = []
                for pid, distance in zip(ids[0], distances[0]):
                    pid = int(pid)
                    if pid != -1 and pid not in visited and distance <= distance_threshold:
                        album.append(pid)
                        visited.add(pid)
                if album:
                    albums.append(album)
        finally:
            session.close()
        return albums

    def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
        """
        Add an embedding to the index and persist the index to disk.

        The index is reloaded from disk first, so entries written through other
        SearchEngine instances since this one last loaded are not overwritten
        by the save.

        Args:
            photo_id: Unique photo identifier
            embedding: 1D array-like of shape (dim,)

        NOTE(review): nothing here removes an existing entry for photo_id, so
        re-adding the same ID presumably stores a second vector under it — confirm
        whether callers rely on that.
        """
        self.load()
        # FAISS expects a float32 matrix of shape (n, dim); here n == 1.
        vector = np.asarray(embedding, dtype=np.float32).reshape(1, -1)
        self.index.add_with_ids(vector, np.array([photo_id], dtype=np.int64))
        self.save()

    def search(self, query_embedding: np.ndarray, top_k: int = 5, distance_threshold: float = 0.5) -> List[Tuple[int, float]]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: 1D array-like of shape (dim,)
            top_k: Number of nearest neighbours to request from FAISS
            distance_threshold: Results with a distance above this are dropped
                (default 0.5, the previously hard-coded cut-off)

        Returns:
            List of (photo_id, distance) tuples, nearest first, with
            distance <= distance_threshold. Empty if the index is empty or
            top_k is not positive.
        """
        self.load()
        if self.index.ntotal == 0 or top_k <= 0:
            return []

        query = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)
        distances, ids = self.index.search(query, top_k)

        # Drop FAISS padding (-1) and hits beyond the threshold.
        return [
            (int(photo_id), float(distance))
            for photo_id, distance in zip(ids[0], distances[0])
            if photo_id != -1 and distance <= distance_threshold
        ]

    def save(self) -> None:
        """Save the FAISS index to disk, replacing the file in one step.

        The index is written to a uniquely named temporary file next to
        ``index_path`` and then moved over it with ``os.replace``. A concurrent
        ``load()`` therefore never reads a partially written file (rename is
        atomic on POSIX). On failure the temporary file is removed and the
        original error is re-raised.
        """
        import uuid

        tmp_path = f"{self.index_path}.{uuid.uuid4().hex}.tmp"
        try:
            faiss.write_index(self.index, tmp_path)
            os.replace(tmp_path, self.index_path)
        except BaseException:
            # Best-effort cleanup; the original exception is what matters.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def load(self) -> None:
        """Load the FAISS index from disk, or create an empty one if the file is missing."""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
        else:
            # Recreate an empty ID-mapped flat L2 index of the configured size.
            base_index = faiss.IndexFlatL2(self.dim)
            self.index = faiss.IndexIDMap(base_index)

    def get_stats(self) -> dict:
        """Return index statistics, reloading first so counts match what search() sees."""
        self.load()
        return {
            "total_embeddings": self.index.ntotal,
            "dimension": self.dim,
            "index_type": type(self.index).__name__,
        }
|