File size: 5,616 Bytes
8779583
c6706bd
 
8779583
c6706bd
4d4fccb
c6706bd
 
 
 
8779583
1006fab
c6706bd
 
8779583
 
c6706bd
 
 
8779583
 
 
4d4fccb
8ad42f5
 
 
4d4fccb
8ad42f5
 
 
 
4d4fccb
 
 
8ad42f5
 
4d4fccb
8ad42f5
 
 
 
 
 
 
 
 
 
 
 
4d4fccb
 
 
8ad42f5
 
 
 
 
4d4fccb
 
 
 
8ad42f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d4fccb
8ad42f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6706bd
 
 
8779583
c6706bd
 
 
 
 
 
8779583
 
 
 
c6706bd
 
8779583
c6706bd
 
 
1006fab
c6706bd
 
 
1006fab
c6706bd
8779583
c6706bd
1006fab
 
c6706bd
 
1006fab
c6706bd
 
1006fab
c6706bd
8779583
1006fab
8779583
c6706bd
8779583
 
 
c6706bd
1006fab
c6706bd
1006fab
c6706bd
8779583
c6706bd
1006fab
c6706bd
8779583
c6706bd
 
8779583
 
 
 
 
c6706bd
 
 
 
 
8779583
 
4d4fccb
ab19ad9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""FAISS-based semantic search engine using ID-mapped index"""
import faiss
import numpy as np
from typing import List, Tuple
import os
import random


class SearchEngine:
    """FAISS-based search engine for image embeddings.

    Embeddings live in a ``faiss.IndexIDMap`` wrapping an exact
    ``faiss.IndexFlatL2``, so every vector is addressed by its photo ID.
    The index is written to ``index_path`` after each insertion and re-read
    from disk before every query, so instances sharing the same file see
    each other's saved additions on their next query.
    """

    def __init__(self, dim: int = 1024, index_path: str = "faiss_index.bin"):
        """
        Args:
            dim: Dimensionality used when a fresh, empty index has to be created.
            index_path: File the index is loaded from and saved to.
        """
        self.dim = dim
        self.index_path = index_path

        # load() reads an existing index file, or builds an empty one if missing.
        self.load()

    def _new_index(self):
        """Build an empty ID-mapped exact-L2 index of dimension ``self.dim``."""
        return faiss.IndexIDMap(faiss.IndexFlatL2(self.dim))

    @staticmethod
    def _as_row(vector, expected_dim: int) -> np.ndarray:
        """Return *vector* as a contiguous float32 array of shape ``(1, expected_dim)``.

        Raises:
            ValueError: if *vector* does not hold exactly ``expected_dim`` values.
        """
        row = np.ascontiguousarray(vector, dtype=np.float32).reshape(1, -1)
        # Checked explicitly instead of relying on FAISS's own shape assertion,
        # so callers get a clear, catchable error for mismatched embeddings.
        if row.shape[1] != expected_dim:
            raise ValueError(
                f"embedding has dimension {row.shape[1]}, index expects {expected_dim}"
            )
        return row

    def create_albums(self, top_k: int = 5, distance_threshold: float = 0.3, album_size: int = 5) -> List[List[int]]:
        """
        Group similar images into albums (clusters).

        Photo IDs in the index are visited in random order. Each unvisited photo
        that has a database row with a stored embedding seeds a nearest-neighbour
        search. The seed's neighbours that are within ``distance_threshold`` and
        not yet used form one album. A photo ID appears in at most one album per call.
        Photos without a DB row or embedding are skipped as seeds, but they may
        still join another photo's album as a neighbour.

        Distances are the raw values FAISS reports. For the ``IndexFlatL2`` base
        this class creates, they are squared L2 distances.

        Args:
            top_k: Maximum number of albums to return.
            distance_threshold: Maximum distance for a neighbour to join an album (default 0.3).
            album_size: Neighbours requested per seed, including the seed itself,
                so each album has at most this many photos (default 5).

        Returns:
            Up to ``top_k`` albums (fewer if the index runs out of eligible seeds).
            Each album is a non-empty list of photo IDs. Contents vary between
            calls because the seed order is randomized.
        """
        # Imported lazily, presumably to avoid an import cycle with the app package.
        from cloudzy.database import SessionLocal
        from cloudzy.models import Photo
        from sqlmodel import select

        self.load()
        if self.index.ntotal == 0 or top_k <= 0 or album_size <= 0:
            return []

        # External photo IDs, stored by position in the IndexIDMap's id_map.
        id_map = self.index.id_map
        all_ids = [int(id_map.at(i)) for i in range(id_map.size())]

        # Shuffle for randomization, so different albums come back on each call.
        random.shuffle(all_ids)

        visited = set()  # plain Python ints, so membership checks are uniform
        albums: List[List[int]] = []

        # One session for the whole pass instead of one per candidate photo.
        session = SessionLocal()
        try:
            for photo_id in all_ids:
                if len(albums) >= top_k:
                    break
                if photo_id in visited:
                    continue

                photo = session.exec(select(Photo).where(Photo.id == photo_id)).first()
                if not photo:
                    continue

                embedding = photo.get_embedding()
                # Explicit checks work for both lists and numpy arrays; `not array`
                # would raise for a multi-element ndarray.
                if embedding is None or len(embedding) == 0:
                    continue

                query = np.asarray(embedding, dtype=np.float32).reshape(1, -1)
                distances, ids = self.index.search(query, album_size)

                # Keep in-range neighbours that no earlier album has claimed.
                album = []
                for pid, distance in zip(ids[0], distances[0]):
                    pid = int(pid)
                    if pid != -1 and pid not in visited and distance <= distance_threshold:
                        album.append(pid)
                        visited.add(pid)

                if album:
                    albums.append(album)
        finally:
            session.close()

        return albums

    def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
        """
        Add an embedding to the index and persist the index to disk.

        Args:
            photo_id: Photo identifier stored alongside the vector.
            embedding: Array-like holding exactly ``self.index.d`` values,
                e.g. a 1D array of shape ``(dim,)``.

        Raises:
            ValueError: if the embedding size does not match the index dimension.
        """
        row = self._as_row(embedding, self.index.d)

        # NOTE: IndexIDMap does not reject duplicate IDs, so re-adding a photo
        # stores an additional vector under the same ID.
        self.index.add_with_ids(row, np.array([photo_id], dtype=np.int64))

        # Persist immediately so other instances see this addition on their next load().
        self.save()

    def search(
        self,
        query_embedding: np.ndarray,
        top_k: int = 5,
        distance_threshold: float = 0.5,
    ) -> List[Tuple[int, float]]:
        """
        Search for the embeddings nearest to a query vector.

        The on-disk index is reloaded before searching.

        Args:
            query_embedding: Array-like holding exactly ``self.index.d`` values,
                e.g. a 1D array of shape ``(dim,)``.
            top_k: Maximum number of neighbours to retrieve. Values <= 0 yield ``[]``.
            distance_threshold: Results with a larger FAISS distance are dropped
                (default 0.5, the previously hard-coded cutoff).

        Returns:
            List of ``(photo_id, distance)`` tuples in FAISS result order,
            restricted to ``distance <= distance_threshold``.

        Raises:
            ValueError: if the index is non-empty and the query size does not
                match the index dimension.
        """
        if top_k <= 0:
            return []

        self.load()
        if self.index.ntotal == 0:
            return []

        query = self._as_row(query_embedding, self.index.d)
        distances, ids = self.index.search(query, top_k)

        # FAISS pads missing neighbours with id -1 when top_k exceeds ntotal.
        return [
            (int(photo_id), float(distance))
            for photo_id, distance in zip(ids[0], distances[0])
            if photo_id != -1 and distance <= distance_threshold
        ]

    def save(self) -> None:
        """Write the in-memory FAISS index to ``self.index_path``."""
        faiss.write_index(self.index, self.index_path)

    def load(self) -> None:
        """Replace the in-memory index with the one on disk, or an empty index if the file is missing."""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
        else:
            self.index = self._new_index()

    def get_stats(self) -> dict:
        """Return statistics for the in-memory index (the file is not re-read here)."""
        return {
            "total_embeddings": self.index.ntotal,
            "dimension": self.dim,
            "index_type": type(self.index).__name__,
        }