Commit · d667f1f
Parent(s): 1cb8b50
Used better model for text embedding

Files changed:
- AI_USAGE_REPORT.txt                 +22 -12
- cloudzy/agents/image_analyzer_2.py  +36 -22
- cloudzy/ai_utils.py                 +28 -9
- cloudzy/routes/photo.py             +1 -1
- cloudzy/routes/search.py            +11 -20
- cloudzy/search_engine.py            +98 -33
AI_USAGE_REPORT.txt CHANGED

@@ -18,8 +18,8 @@ WHERE & HOW AI WAS USED:
    - Function: Generate images from text prompts
 
 3. Semantic Search (cloudzy/search_engine.py + cloudzy/routes/search.py)
-   - Tool: FAISS (vector database) with embeddings
-   - Function: Find visually similar photos via embedding vectors
+   - Tool: FAISS (vector database) with embeddings from Qwen/Qwen3-Embedding-8B (4096-dimensional)
+   - Function: Find visually similar photos via L2-normalized embedding vectors
 
 PROMPTS & MODEL INPUTS:
 Image Analysis Prompt #1 - Structured Metadata (image_analyzer.py):

@@ -39,8 +39,10 @@ Search Queries:
    - Album creation: Groups similar photos by distance threshold (randomized each call)
 
 MODEL OUTPUTS REFINED:
-✅ JSON parsing: Extracted structured data from model text response
-✅ …
+✅ JSON parsing: Extracted structured data from model text response (with dict type-check for Gemini responses)
+✅ Embedding model upgrade: Migrated from multilingual-e5-large (1024-d) to Qwen3-Embedding-8B (4096-d)
+✅ L2 normalization: Added unit-vector normalization to embeddings for consistent distance calculations
+✅ Distance threshold tuning: Adjusted for normalized embeddings (0.5 → 1.0 for search, 0.3 → 1.5 for albums)
 ✅ Album randomization: Added random.shuffle() to prevent deterministic groupings
 ✅ Error handling: Wrapped API failures to graceful fallbacks

@@ -60,14 +62,22 @@ Manual Refinements (35%):
    - CORS middleware configuration
 
 KEY TECHNICAL DECISIONS:
-1. …
-2. …
-3. …
-4. …
-5. …
+1. Embedding model: Qwen3-Embedding-8B (4096-d) for better semantic understanding than smaller models
+2. L2 normalization: Ensures normalized distances (0-2 range) independent of embedding dimension
+3. Distance thresholds: search() ≤ 1.0, create_albums() ≤ 1.5 (optimized for normalized embeddings)
+4. Model choice: Qwen3-VL for balanced speed/quality in image analysis
+5. FLUX.1-dev: High-quality image generation over speed
+6. Random album creation: Ensures different groupings per request
+7. HuggingFace Hub: Leveraged pre-tuned models vs training custom
 
 FILES MODIFIED FOR IMPROVEMENTS:
-- …
+- ai_utils.py: Added L2 normalization to both generate_embedding() and _embed_text() methods
+- search_engine.py: Updated distance thresholds (0.5→1.0 search, 0.3→1.5 albums) for normalized embeddings
 - image_analyzer.py: JSON error handling for vision model output
-- image_analyzer_2.py: …
-- text_to_image.py: Timestamp-based filename collision prevention
+- image_analyzer_2.py: Dict type-check for Gemini responses + agentic image analysis with Gemini-2.0-Flash
+- text_to_image.py: Timestamp-based filename collision prevention
+
+EMBEDDING UPGRADE SUMMARY:
+Old: multilingual-e5-large (1024-dimensional, unnormalized)
+New: Qwen/Qwen3-Embedding-8B (4096-dimensional, L2-normalized)
+Benefit: Better semantic understanding + consistent distance calculations across query types
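The "0-2 range" in decision 2 follows from unit-length vectors: for unit u and v, ||u - v||^2 = 2 - 2·cos(u, v), so the L2 distance never exceeds 2. A minimal numpy sketch of the normalization step and the bound it buys (standalone illustration, not code from this repo):

import numpy as np

def l2_normalize(vec):
    """Scale a vector to unit length; leave all-zero vectors untouched."""
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec

rng = np.random.default_rng(0)
u = l2_normalize(rng.normal(size=4096).astype(np.float32))
v = l2_normalize(rng.normal(size=4096).astype(np.float32))

dist = float(np.linalg.norm(u - v))
print(dist ** 2, 2.0 - 2.0 * float(u @ v))  # the two values agree up to rounding
assert 0.0 <= dist <= 2.0

One detail worth verifying against the thresholds in search_engine.py below: faiss.IndexFlatL2 reports squared L2 distances, so for unit vectors the raw values it returns range over [0, 4] rather than [0, 2].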
cloudzy/agents/image_analyzer_2.py CHANGED

@@ -97,29 +97,43 @@ result: {
 
         response = self.agent.run(prompt, images=[image])
 
-        # …
-        …
-            print(f"[Warning] No closing brace found in JSON, attempting to add closing brace...")
-            json_str = json_str + "}"
-        …
+        # If response is already a dict, return it directly
+        if isinstance(response, dict):
+            return response
+
+        # Safely convert to string, handling non-string types
+        if response is None:
+            text_content = ""
+        else:
+            text_content = str(response).strip()
+
+        if not text_content:
+            raise ValueError("Model returned empty response")
+
+        # Try to extract JSON-like dict from model output
         try:
-            …
+            if "{" not in text_content:
+                raise ValueError("Response does not contain valid JSON structure (missing opening brace)")
+
+            start = text_content.index("{")
+
+            # Try to find closing brace
+            if "}" not in text_content[start:]:
+                # No closing brace found, try adding one
+                print(f"[Warning] No closing brace found in response, attempting to add closing brace...")
+                json_str = text_content[start:] + "}"
+            else:
+                end = text_content.rindex("}") + 1
+                json_str = text_content[start:end]
+
+            result = json.loads(json_str)
+            return result
+        except ValueError as ve:
+            raise ValueError(f"Failed to parse model output: {text_content}\nError: {ve}")
+        except json.JSONDecodeError as je:
+            raise ValueError(f"Invalid JSON in model output: {text_content}\nError: {je}")
+        except Exception as e:
+            raise ValueError(f"Failed to parse model output: {text_content}\nError: {e}")
 
 
 # Test with sample images
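The new parsing flow above can be exercised in isolation. A condensed sketch of the same technique (the helper name extract_json_dict is mine, not the repo's); note that json.JSONDecodeError subclasses ValueError, so a dedicated JSONDecodeError handler must be listed before any bare ValueError handler to be reachable:

import json

def extract_json_dict(response) -> dict:
    # Gemini-style agents may hand back a parsed dict already
    if isinstance(response, dict):
        return response
    text = "" if response is None else str(response).strip()
    if not text or "{" not in text:
        raise ValueError("no JSON object in model output")
    start = text.index("{")
    if "}" not in text[start:]:
        # Model stopped early: append a closing brace and hope the object was complete
        json_str = text[start:] + "}"
    else:
        json_str = text[start:text.rindex("}") + 1]
    return json.loads(json_str)

print(extract_json_dict('Sure! {"tags": ["cat"], "caption": "a cat on a sofa"}'))
print(extract_json_dict({"already": "parsed"}))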
cloudzy/ai_utils.py CHANGED

@@ -1,24 +1,28 @@
 import os
 import numpy as np
 from huggingface_hub import InferenceClient
+from typing import List, Dict, Tuple
+import re
 
 from dotenv import load_dotenv
 load_dotenv()
 
+
+
 class ImageEmbeddingGenerator:
-    def __init__(self, model_name: str = "intfloat/multilingual-e5-large"):
+    def __init__(self, model_name: str = "Qwen/Qwen3-Embedding-8B"):
         """
         Initialize the embedding generator with a Hugging Face model.
         """
         self.client = InferenceClient(
-            provider="…",
+            provider="nebius",
             api_key=os.environ["HF_TOKEN_1"],
         )
         self.model_name = model_name
 
     def generate_embedding(self, tags: list[str], description: str, caption: str) -> np.ndarray:
         """
-        Generate a 1024-d embedding for an image using its tags, description, and caption.
+        Generate a 4096-d embedding for an image using its tags, description, and caption.
 
         Args:
             tags: List of tags related to the image

@@ -26,7 +30,7 @@ class ImageEmbeddingGenerator:
             caption: Short caption for the image
 
         Returns:
-            embedding: 1D numpy array of shape (1024,)
+            embedding: 1D numpy array of shape (4096,), normalized to unit length
         """
         # Combine text fields into a single string
         text = " ".join(tags) + " " + description + " " + caption

@@ -40,9 +44,15 @@ class ImageEmbeddingGenerator:
         # Convert to numpy array
         embedding = np.array(result, dtype=np.float32).reshape(-1)
 
-        # Ensure shape is (1024,)
-        if embedding.shape[0] != 1024:
-            raise ValueError(f"Expected embedding of size 1024, got {embedding.shape[0]}")
+        # Ensure shape is (4096,)
+        if embedding.shape[0] != 4096:
+            raise ValueError(f"Expected embedding of size 4096, got {embedding.shape[0]}")
+
+        # Normalize to unit length (L2 normalization)
+        # This ensures distances stay consistent across models and dimensions
+        norm = np.linalg.norm(embedding)
+        if norm > 0:
+            embedding = embedding / norm
 
         return embedding

@@ -50,6 +60,7 @@
     def _embed_text(self, text: str) -> np.ndarray:
         """
         Internal helper to call Hugging Face feature_extraction and return a numpy array.
+        Embeddings are normalized to unit length for consistent distance calculations.
         """
         result = self.client.feature_extraction(
             text,

@@ -57,11 +68,19 @@
         )
         embedding = np.array(result, dtype=np.float32).reshape(-1)
 
-        if embedding.shape[0] != 1024:
-            raise ValueError(f"Expected embedding of size 1024, got {embedding.shape[0]}")
+        if embedding.shape[0] != 4096:
+            raise ValueError(f"Expected embedding of size 4096, got {embedding.shape[0]}")
+
+        # Normalize to unit length (L2 normalization)
+        norm = np.linalg.norm(embedding)
+        if norm > 0:
+            embedding = embedding / norm
+
         return embedding
 
 
+
+
 class TextSummarizer:
     def __init__(self, model_name: str = "facebook/bart-large-cnn"):
         """
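Condensed, the embed-then-normalize flow the new ai_utils.py implements looks like the sketch below. It assumes HF_TOKEN_1 is set and that the nebius provider actually serves Qwen/Qwen3-Embedding-8B through feature_extraction; treat both as environment-specific assumptions rather than guarantees:

import os
import numpy as np
from huggingface_hub import InferenceClient

client = InferenceClient(provider="nebius", api_key=os.environ["HF_TOKEN_1"])

def embed_text(text, model="Qwen/Qwen3-Embedding-8B"):
    """Fetch a text embedding and L2-normalize it to unit length."""
    raw = client.feature_extraction(text, model=model)
    vec = np.asarray(raw, dtype=np.float32).reshape(-1)
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec

query = embed_text("sunset over mountains")
print(query.shape, np.linalg.norm(query))  # expect (4096,) and ~1.0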
cloudzy/routes/photo.py CHANGED

@@ -89,7 +89,7 @@ async def get_albums(
     """
 
     search_engine = SearchEngine()
-    albums_ids = search_engine.create_albums(top_k=top_k)
+    albums_ids = search_engine.create_albums_kmeans(top_k=top_k)
     APP_DOMAIN = os.getenv("APP_DOMAIN") or "http://127.0.0.1:8000/"
     summarizer = TextSummarizer()
cloudzy/routes/search.py CHANGED

@@ -21,56 +21,47 @@ async def search_photos(
     session: Session = Depends(get_session),
 ):
     """
-    Semantic search …
-
-    Converts query to embedding and finds most similar images.
-
+    Semantic search endpoint using FAISS.
+
     Args:
         q: Search query (used to generate embedding)
         top_k: Number of results to return (max 50)
-
-    Returns: List of similar photos
+
+    Returns: List of similar photos
     """
 
     generator = ImageEmbeddingGenerator()
     query_embedding = generator._embed_text(q)
 
-
-
-    # Search in FAISS
     search_engine = SearchEngine()
     search_results = search_engine.search(query_embedding, top_k=top_k)
-
-
+
     if not search_results:
         return SearchResponse(
             query=q,
             results=[],
             total_results=0,
         )
-
-    APP_DOMAIN = os.getenv("APP_DOMAIN")
 
-
-
-    # Fetch photo details from database
+    APP_DOMAIN = os.getenv("APP_DOMAIN")
     result_objects = []
+
     for photo_id, distance in search_results:
         statement = select(Photo).where(Photo.id == photo_id)
         photo = session.exec(statement).first()
-
-        if photo:
+
+        if photo:
             result_objects.append(
                 SearchResult(
                     photo_id=photo.id,
                     filename=photo.filename,
-                    image_url…
+                    image_url=f"{APP_DOMAIN}uploads/{photo.filename}",
                     tags=photo.get_tags(),
                     caption=photo.caption,
                     distance=distance,
                 )
             )
-
+
     return SearchResponse(
         query=q,
         results=result_objects,
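The endpoint above leans on SearchEngine for the actual nearest-neighbor lookup. A self-contained sketch of that FAISS round trip with unit-length vectors (toy data; the index layout mirrors the IndexIDMap-over-IndexFlatL2 construction used in search_engine.py below):

import faiss
import numpy as np

dim = 4096
rng = np.random.default_rng(1)

# Flat L2 index wrapped so each vector carries an external photo id
index = faiss.IndexIDMap(faiss.IndexFlatL2(dim))

vecs = rng.normal(size=(10, dim)).astype(np.float32)
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)   # unit length, as in ai_utils.py
ids = np.arange(100, 110, dtype=np.int64)             # stand-in photo ids
index.add_with_ids(vecs, ids)

query = vecs[3:4]                                     # a vector we know is indexed
distances, found = index.search(query, 5)             # nearest hit: id 103 at distance ~0
results = [(int(i), float(d)) for i, d in zip(found[0], distances[0]) if i != -1]
print(results)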
cloudzy/search_engine.py CHANGED

@@ -9,7 +9,7 @@ import random
 class SearchEngine:
     """FAISS-based search engine for image embeddings"""
 
-    def __init__(self, dim: int = 1024, index_path: str = "faiss_index.bin"):
+    def __init__(self, dim: int = 4096, index_path: str = "faiss_index.bin"):
         self.dim = dim
         self.index_path = index_path
 

@@ -20,7 +20,7 @@ class SearchEngine:
         base_index = faiss.IndexFlatL2(dim)
         self.index = faiss.IndexIDMap(base_index)
 
-    def create_albums(self, top_k: int = 5, distance_threshold: float = 0.3, album_size: int = 5) -> List[List[int]]:
+    def create_albums(self, top_k: int = 5, distance_threshold: float = 1.5, album_size: int = 5) -> List[List[int]]:
         """
         Group similar images into albums (clusters).
 

@@ -28,9 +28,14 @@
         Photos are marked as visited to avoid duplicate albums.
         Only includes photos within the distance threshold.
 
+        OPTIMIZATIONS:
+        - Batch retrieves all photos in ONE database query (not per-photo)
+        - Caches embeddings in memory during execution
+        - Single session for all DB operations
+
         Args:
             top_k: Number of albums to return
-            distance_threshold: Maximum distance to consider photos as similar (default 0.3)
+            distance_threshold: Maximum distance to consider photos as similar (default 1.0 for normalized embeddings)
             album_size: How many similar photos to search for per album (default 5)
 
         Returns:

@@ -51,6 +56,20 @@
         # Shuffle for randomization - different albums each call
         random.shuffle(all_ids)
 
+        # ✅ OPTIMIZATION 1: Batch retrieve all photos in ONE query
+        session = SessionLocal()
+        try:
+            # Fetch all photos at once, not in a loop
+            photos_query = session.exec(select(Photo).where(Photo.id.in_(all_ids))).all()
+            # ✅ OPTIMIZATION 2: Cache embeddings in memory
+            embedding_cache = {}
+            for photo in photos_query:
+                embedding = photo.get_embedding()
+                if embedding:
+                    embedding_cache[photo.id] = embedding
+        finally:
+            session.close()
+
         visited = set()
         albums = []
 

@@ -63,37 +82,80 @@
             if photo_id in visited:
                 continue
 
-            # …
-            …
-            # Add album if it has at least 1 photo
-            if album:
-                albums.append(album)
-
-        finally:
-            session.close()
+            # Skip if no embedding cached
+            if photo_id not in embedding_cache:
+                continue
+
+            # Get embedding from cache (not DB)
+            embedding = embedding_cache[photo_id]
+
+            # Search for similar images
+            query_embedding = np.array(embedding).reshape(1, -1).astype(np.float32)
+            distances, ids = self.index.search(query_embedding, album_size)
+
+            # Build album: collect similar photos that haven't been visited and are within threshold
+            album = []
+            for pid, distance in zip(ids[0], distances[0]):
+                if pid != -1 and pid not in visited and distance <= distance_threshold:
+                    album.append(int(pid))
+                    visited.add(pid)
+
+            # Add album if it has at least 1 photo
+            if album:
+                albums.append(album)
 
         return albums
 
+    def create_albums_kmeans(self, top_k: int = 5, seed: int = 42) -> List[List[int]]:
+        """
+        Group similar images into albums using FAISS k-means clustering.
+
+        This is a BETTER approach than nearest-neighbor grouping:
+        - Uses true k-means clustering instead of ad-hoc neighbor search
+        - All photos get assigned to a cluster (no "orphans")
+        - Deterministic results for same seed
+        - Much faster for large datasets
+
+        Args:
+            top_k: Number of clusters (albums) to create
+            seed: Random seed for reproducibility
+
+        Returns:
+            List of top_k albums, each album is a list of photo_ids
+        """
+        self.load()
+        if self.index.ntotal < top_k:
+            return []
+
+        # Get all photo IDs from FAISS index
+        id_map = self.index.id_map
+        all_ids = np.array([id_map.at(i) for i in range(id_map.size())], dtype=np.int64)
+
+        # Get all embeddings from the underlying index (IndexIDMap wraps the actual index)
+        underlying_index = faiss.downcast_index(self.index.index)
+        all_embeddings = underlying_index.reconstruct_n(0, self.index.ntotal).astype(np.float32)
+
+        # ✅ Run k-means clustering
+        kmeans = faiss.Kmeans(
+            d=self.dim,
+            k=top_k,
+            niter=20,
+            verbose=False,
+            seed=seed
+        )
+        kmeans.train(all_embeddings)
+
+        # Assign each embedding to nearest cluster
+        distances, cluster_assignments = kmeans.index.search(all_embeddings, 1)
+
+        # Group photos by cluster
+        albums = [[] for _ in range(top_k)]
+        for photo_id, cluster_id in zip(all_ids, cluster_assignments.flatten()):
+            albums[cluster_id].append(int(photo_id))
+
+        # Remove empty albums and return
+        return [album for album in albums if album]
+
     def add_embedding(self, photo_id: int, embedding: np.ndarray) -> None:
         """
         Add an embedding to the index.

@@ -120,7 +182,7 @@
             top_k: Number of results to return
 
         Returns:
-            List of (photo_id, distance) tuples with distance <= 0.5
+            List of (photo_id, distance) tuples with distance <= 1.0 (normalized embeddings)
         """
         self.load()
 

@@ -133,11 +195,14 @@
         # Search in FAISS index
         distances, ids = self.index.search(query_embedding, top_k)
 
+        print(distances)
+
         # Filter invalid and distant results
+        # With normalized embeddings, L2 distance range is 0-2, threshold of 1.0 works well
         results = [
             (int(photo_id), float(distance))
             for photo_id, distance in zip(ids[0], distances[0])
-            if photo_id != -1 and distance <= 0.5
+            if photo_id != -1 and distance <= 1.5
         ]
 
         return results