Yashashvibhardwaj committed on
Commit
3eaabcf
·
1 Parent(s): 9293eee

Deploy backend code

Browse files
Files changed (5) hide show
  1. .gitignore +2 -0
  2. DOCKERFILE +15 -0
  3. build_index.py +80 -0
  4. main.py +149 -0
  5. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ products.json
2
+ products.index
DOCKERFILE ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# Build tools for wheels that need compiling; --no-install-recommends and
# removing the apt lists keep the image layer small.
RUN apt-get update \
    && apt-get install -y --no-install-recommends git wget curl build-essential \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install dependencies before copying the code so Docker caches this layer
# across source-only edits.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 7860

# Build index if missing, then run FastAPI
CMD ["bash", "-c", "python build_index.py && uvicorn main:app --host 0.0.0.0 --port 7860"]
build_index.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Build a FAISS image-similarity index for the product catalog.

Reads products.json (next to this script), downloads each unique product
image, embeds it with CLIP, and writes a flat inner-product FAISS index
to products.index.  Run as a script (the Dockerfile CMD does this before
starting the API server).
"""
import os
import json
import requests
import io
import faiss
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
from tqdm import tqdm  # progress bar

# ---------------------------------------------------
# Locate products.json in the same folder as this script
# ---------------------------------------------------
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PRODUCTS_FILE = os.path.join(BASE_DIR, "products.json")
INDEX_FILE = os.path.join(BASE_DIR, "products.index")

# Output dimension of clip-ViT-B-32; used for the zero-vector fallback
# when an image cannot be downloaded or decoded.
EMB_DIM = 512

# ---------------------------------------------------
# Skip the expensive rebuild when the index already exists.
# The Dockerfile comment says "Build index if missing" and this script
# runs on every container start — an unconditional rebuild would
# re-download and re-embed every product image each time.
# ---------------------------------------------------
if os.path.exists(INDEX_FILE):
    print(f"📦 {INDEX_FILE} already exists — skipping rebuild")
    raise SystemExit(0)

# ---------------------------------------------------
# Load product metadata
# ---------------------------------------------------
if not os.path.exists(PRODUCTS_FILE):
    raise FileNotFoundError(f"❌ Could not find {PRODUCTS_FILE}")

with open(PRODUCTS_FILE, "r", encoding="utf-8") as f:
    products = json.load(f)

print(f"📦 Loaded {len(products)} products from {PRODUCTS_FILE}")

# ---------------------------------------------------
# Load CLIP model
# ---------------------------------------------------
print("🧠 Loading CLIP model (this may take a few seconds)...")
model = SentenceTransformer("clip-ViT-B-32")

# ---------------------------------------------------
# Collect unique image URLs (avoid redundant downloads)
# ---------------------------------------------------
unique_urls = list({p["image_url"] for p in products})
print(f"🔗 Found {len(unique_urls)} unique image URLs")

# ---------------------------------------------------
# Compute embeddings for unique URLs
# ---------------------------------------------------
url_to_emb = {}

for url in tqdm(unique_urls, desc="Embedding unique images"):
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        img = Image.open(io.BytesIO(response.content)).convert("RGB")
        emb = model.encode(img, convert_to_numpy=True,
                           normalize_embeddings=True)
        url_to_emb[url] = emb
    except Exception as e:
        # Best-effort: a failed download must not abort the whole build.
        # A zero vector keeps the embeddings matrix rectangular; its inner
        # product with any query is 0, so it never matches.
        print(f"⚠️ Error processing {url}: {e}")
        url_to_emb[url] = np.zeros(EMB_DIM, dtype=np.float32)

# ---------------------------------------------------
# Build embeddings array for all products
# (several products may share one image URL)
# ---------------------------------------------------
embeddings = []
for p in products:
    embeddings.append(url_to_emb[p["image_url"]])

embeddings = np.array(embeddings).astype("float32")

print(f"✅ Built embeddings array: {embeddings.shape}")

# ---------------------------------------------------
# Create FAISS index (cosine similarity via inner product,
# valid because the embeddings are L2-normalized)
# ---------------------------------------------------
dim = embeddings.shape[1]  # 512 for CLIP
index = faiss.IndexFlatIP(dim)
index.add(embeddings)

# ---------------------------------------------------
# Save FAISS index
# ---------------------------------------------------
faiss.write_index(index, INDEX_FILE)
print(f"🎉 Saved FAISS index with {index.ntotal} vectors → {INDEX_FILE}")
main.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""FastAPI backend for product similarity search.

Loads a CLIP model and a FAISS index at import time; the /match and
/search_text endpoints defined later in this module query that index.
"""
from fastapi import FastAPI, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
import requests
import io
import faiss
import json
import os
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer

# Init FastAPI
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # you can restrict to your Vercel URL later
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)

# Load CLIP model once, at import time, so individual requests don't pay
# the model-loading cost.
print("🧠 Loading CLIP model...")
model = SentenceTransformer("clip-ViT-B-32")

# Load dataset
# NOTE(review): relative paths — assumes the process starts in the
# directory holding these files (true for the Docker WORKDIR /app).
PRODUCTS_FILE = "products.json"
INDEX_FILE = "products.index"

# errors="ignore" silently drops undecodable bytes; presumably the product
# feed sometimes contains invalid UTF-8 — confirm against the data source.
with open(PRODUCTS_FILE, "r", encoding="utf-8", errors="ignore") as f:
    products = json.load(f)

# Build or load FAISS index.  build_index.py normally produces the index
# from product *images*; this fallback builds a cheaper one from text.
if os.path.exists(INDEX_FILE):
    print("📦 Loading existing FAISS index...")
    index = faiss.read_index(INDEX_FILE)
else:
    print("⚡ Building FAISS index from products.json (first startup only)...")
    # Encode product names (lightweight, avoids downloading images)
    # NOTE(review): assumes every product has "name", "category" and
    # "brand" keys — a missing key raises KeyError at startup.
    texts = [p["name"] + " " + p["category"] + " " + p["brand"]
             for p in products]
    embeddings = model.encode(
        texts, convert_to_numpy=True, normalize_embeddings=True)
    # Inner product over L2-normalized vectors == cosine similarity.
    index = faiss.IndexFlatIP(embeddings.shape[1])
    index.add(embeddings.astype("float32"))
    faiss.write_index(index, INDEX_FILE)
    print(f"✅ Saved FAISS index with {index.ntotal} vectors")
50
def embed_image(img: Image.Image):
    """Encode a PIL image into a unit-norm CLIP vector (numpy array)."""
    embedding = model.encode(
        img,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return embedding
53
+
54
def embed_text(query: str):
    """Encode a text query into a unit-norm CLIP vector (numpy array)."""
    vectors = model.encode(
        [query],
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return vectors[0]
57
+
58
@app.post("/match")
async def match(
    file: UploadFile = None,
    image_url: str = Form(None),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999)
):
    """Image-to-product similarity search over the FAISS index.

    Form fields:
        file / image_url: the query image (upload takes precedence).
        min_score: minimum cosine similarity for a hit.
        top_k: neighbours to retrieve before filtering.
        categories / brands: optional JSON-encoded lists of allowed values.
        min_price / max_price: inclusive price range.

    Returns {"matches": [...]} (possibly empty) or {"error": "..."}.
    """
    try:
        # Get query image — explicit None check: UploadFile truthiness is
        # not a reliable presence test.
        if file is not None:
            img = Image.open(io.BytesIO(await file.read())).convert("RGB")
        elif image_url:
            # Timeout so a dead URL can't hang the worker; surface HTTP
            # errors instead of trying to decode an error page as an image.
            resp = requests.get(image_url, timeout=10)
            resp.raise_for_status()
            img = Image.open(io.BytesIO(resp.content)).convert("RGB")
        else:
            return {"matches": []}

        # Encode query
        q_emb = embed_image(img).reshape(1, -1).astype("float32")

        # Search FAISS
        scores, ids = index.search(q_emb, top_k)

        # Parse filters (new names — don't shadow the Form parameters)
        category_filter = json.loads(categories) if categories else []
        brand_filter = json.loads(brands) if brands else []

        # Collect results
        results = []
        for score, idx in zip(scores[0], ids[0]):
            # FAISS pads ids with -1 when the index holds fewer than top_k
            # vectors; products[-1] would silently return the last product.
            if idx < 0 or score < min_score:
                continue
            p = products[idx]

            # Apply filters
            if category_filter and p["category"] not in category_filter:
                continue
            if brand_filter and p["brand"] not in brand_filter:
                continue
            if not (min_price <= p["price"] <= max_price):
                continue

            results.append({**p, "score": float(score)})
        return {"matches": results}
    except Exception as e:
        # Keep the original contract: errors come back as a 200 payload
        # with an "error" key rather than an HTTP error status.
        return {"error": str(e)}
109
+
110
@app.post("/search_text")
async def search_text(
    query: str = Form(...),
    min_score: float = Form(0.6),
    top_k: int = Form(60),
    categories: str = Form(None),
    brands: str = Form(None),
    min_price: float = Form(0),
    max_price: float = Form(9999)
):
    """Text-to-product search over the FAISS index.

    Form fields:
        query: free-text search string (required).
        min_score: minimum cosine similarity for a hit.
        top_k: neighbours to retrieve before filtering.
        categories / brands: optional JSON-encoded lists of allowed values.
        min_price / max_price: inclusive price range.

    Returns {"matches": [...]} (possibly empty) or {"error": "..."}.
    """
    try:
        # Encode text query
        q_emb = embed_text(query).reshape(1, -1).astype("float32")

        # Search FAISS
        scores, ids = index.search(q_emb, top_k)

        # Parse filters (new names — don't shadow the Form parameters)
        category_filter = json.loads(categories) if categories else []
        brand_filter = json.loads(brands) if brands else []

        # Collect results
        results = []
        for score, idx in zip(scores[0], ids[0]):
            # FAISS pads ids with -1 when the index holds fewer than top_k
            # vectors; products[-1] would silently return the last product.
            if idx < 0 or score < min_score:
                continue
            p = products[idx]

            # Apply filters
            if category_filter and p["category"] not in category_filter:
                continue
            if brand_filter and p["brand"] not in brand_filter:
                continue
            if not (min_price <= p["price"] <= max_price):
                continue

            results.append({**p, "score": float(score)})
        return {"matches": results}
    except Exception as e:
        # Keep the original contract: errors come back as a 200 payload
        # with an "error" key rather than an HTTP error status.
        return {"error": str(e)}
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
fastapi
uvicorn
faiss-cpu
sentence-transformers
pillow
requests
numpy
tqdm