"""FAISS-based semantic search engine using ID-mapped index"""
import faiss
import numpy as np
from typing import List, Tuple
import os


class SearchEngine:
    """FAISS-based search engine for image embeddings"""

    def __init__(self, dim: int = 1024, index_path: str = "faiss_index.bin"):
        self.dim = dim
        self.index_path = index_path

        # Load existing index or create a new one
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
        else:
            base_index = faiss.IndexFlatL2(dim)
            self.index = faiss.IndexIDMap(base_index)

    def create_albums(self, top_k: int = 5, distance_threshold: float = 0.3) -> List[List[int]]:
        """
        Group similar images into albums (clusters).
        
        For each unvisited photo, finds its top_k most similar photos and creates an album.
        Photos are marked as visited to avoid duplicate albums.
        Only includes photos within the distance threshold.
        
        Args:
            top_k: Number of similar images to find for each album
            distance_threshold: Maximum distance to consider photos as similar (default 0.3)
            
        Returns:
            List of albums, each album is a list of photo_ids
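
        Example (illustrative, not from real data): with top_k=3, if photo 1's
        nearest neighbours within the threshold are photos 1, 4 and 7, they form
        one album, and 4 and 7 are then skipped as seeds on later iterations.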
        """
        from cloudzy.database import SessionLocal
        from cloudzy.models import Photo
        from sqlmodel import select
        
        self.load()
        if self.index.ntotal == 0:
            return []

        # Get all photo IDs from FAISS index
        id_map = self.index.id_map
        all_ids = [id_map.at(i) for i in range(id_map.size())]

        visited = set()
        albums = []

        for photo_id in all_ids:
            # Skip if already in an album
            if photo_id in visited:
                continue
            
            # Get embedding from database
            session = SessionLocal()
            try:
                photo = session.exec(select(Photo).where(Photo.id == photo_id)).first()
                if not photo:
                    continue
                
                embedding = photo.get_embedding()
                if not embedding:
                    continue
                
                # Search for similar images
                query_embedding = np.array(embedding).reshape(1, -1).astype(np.float32)
                distances, ids = self.index.search(query_embedding, top_k)
                
                # Build album: collect similar photos that haven't been visited and are within threshold
                album = []
                for pid, distance in zip(ids[0], distances[0]):
                    if pid != -1 and pid not in visited and distance <= distance_threshold:
                        album.append(int(pid))
                        visited.add(pid)
                
                # Add album if it has at least 1 photo
                if album:
                    albums.append(album)
                    
            finally:
                session.close()
        
        return albums

    def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
        """
        Add an embedding to the index.

        Args:
            photo_id: Unique photo identifier
            embedding: 1D numpy array of shape (dim,)
        """
        # Ensure embedding is float32 and correct shape
        embedding = embedding.astype(np.float32).reshape(1, -1)

        # Add embedding with its ID
        self.index.add_with_ids(embedding, np.array([photo_id], dtype=np.int64))

        # Save index to disk
        self.save()

    def search(self, query_embedding: np.ndarray, top_k: int = 5) -> List[Tuple[int, float]]:
        """
        Search for similar embeddings.

        Args:
            query_embedding: 1D numpy array of shape (dim,)
            top_k: Number of results to return

        Returns:
            List of (photo_id, distance) tuples, keeping only hits with (squared L2) distance <= 0.5
        """
        self.load()

        if self.index.ntotal == 0:
            return []

        # Ensure query is float32 and correct shape
        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)

        # Search in FAISS index
        distances, ids = self.index.search(query_embedding, top_k)

        # Filter invalid and distant results
        results = [
            (int(photo_id), float(distance))
            for photo_id, distance in zip(ids[0], distances[0])
            if photo_id != -1 and distance <= 0.5
        ]

        return results

    def save(self) -> None:
        """Save FAISS index to disk"""
        faiss.write_index(self.index, self.index_path)

    def load(self) -> None:
        """Load FAISS index from disk"""
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
        else:
            # Recreate empty ID-mapped index if missing
            base_index = faiss.IndexFlatL2(self.dim)
            self.index = faiss.IndexIDMap(base_index)

    def get_stats(self) -> dict:
        """Get index statistics"""
        return {
            "total_embeddings": self.index.ntotal,
            "dimension": self.dim,
            "index_type": type(self.index).__name__,
        }
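

# Usage sketch (illustrative): exercises add_embedding, search and get_stats with
# random vectors. The index path and photo ids below are placeholders, and
# create_albums() is skipped because it needs the cloudzy database. Note that the
# sketch persists "demo_faiss_index.bin" via save().
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    engine = SearchEngine(dim=1024, index_path="demo_faiss_index.bin")

    # Index a few fake photos, keeping the embeddings so one can be reused as a query.
    embeddings = {photo_id: rng.random(1024, dtype=np.float32) for photo_id in (1, 2, 3)}
    for photo_id, embedding in embeddings.items():
        engine.add_embedding(photo_id, embedding)

    # Querying with an already-indexed embedding returns that photo at distance 0.0;
    # unrelated random vectors fall outside search()'s 0.5 distance cutoff.
    print(engine.search(embeddings[1], top_k=3))
    print(engine.get_stats())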