Yashashvibhardwaj committed on
Commit
fac2e05
·
verified ·
1 Parent(s): eb42a98

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +46 -130
main.py CHANGED
@@ -1,149 +1,65 @@
1
- from fastapi import FastAPI, UploadFile, Form
2
- from fastapi.middleware.cors import CORSMiddleware
3
- import requests
4
- import io
5
- import faiss
6
- import json
7
  import os
 
 
8
  import numpy as np
9
- from PIL import Image
 
10
  from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
11
 
12
- # Init FastAPI
13
  app = FastAPI()
 
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
- allow_origins=["*"], # you can restrict to your Vercel URL later
17
  allow_credentials=True,
18
  allow_methods=["*"],
19
- allow_headers=["*"]
20
  )
21
 
22
- # Load CLIP model once
23
- print("🧠 Loading CLIP model...")
24
- model = SentenceTransformer("clip-ViT-B-32")
25
-
26
- # Load dataset
27
- PRODUCTS_FILE = "products.json"
28
- INDEX_FILE = "products.index"
29
-
30
- with open(PRODUCTS_FILE, "r", encoding="utf-8", errors="ignore") as f:
31
  products = json.load(f)
32
 
33
- # Build or load FAISS index
34
- if os.path.exists(INDEX_FILE):
35
- print("📦 Loading existing FAISS index...")
36
- index = faiss.read_index(INDEX_FILE)
37
- else:
38
- print("⚡ Building FAISS index from products.json (first startup only)...")
39
- # Encode product names (lightweight, avoids downloading images)
40
- texts = [p["name"] + " " + p["category"] + " " + p["brand"]
41
- for p in products]
42
- embeddings = model.encode(
43
- texts, convert_to_numpy=True, normalize_embeddings=True)
44
- index = faiss.IndexFlatIP(embeddings.shape[1])
45
- index.add(embeddings.astype("float32"))
46
- faiss.write_index(index, INDEX_FILE)
47
- print(f"✅ Saved FAISS index with {index.ntotal} vectors")
48
 
 
 
 
49
 
50
- def embed_image(img: Image.Image):
51
- return model.encode(img, convert_to_numpy=True, normalize_embeddings=True)
52
-
53
-
54
- def embed_text(query: str):
55
- return model.encode([query], convert_to_numpy=True, normalize_embeddings=True)[0]
56
-
57
-
58
- @app.post("/match")
59
- async def match(
60
- file: UploadFile = None,
61
- image_url: str = Form(None),
62
- min_score: float = Form(0.6),
63
- top_k: int = Form(60),
64
- categories: str = Form(None),
65
- brands: str = Form(None),
66
- min_price: float = Form(0),
67
- max_price: float = Form(9999)
68
- ):
69
- try:
70
- # Get query image
71
- if file:
72
- img = Image.open(io.BytesIO(await file.read())).convert("RGB")
73
- elif image_url:
74
- img = Image.open(io.BytesIO(requests.get(
75
- image_url).content)).convert("RGB")
76
- else:
77
- return {"matches": []}
78
-
79
- # Encode query
80
- q_emb = embed_image(img).reshape(1, -1)
81
-
82
- # Search FAISS
83
- scores, ids = index.search(q_emb, top_k)
84
-
85
- # Parse filters
86
- categories = json.loads(categories) if categories else []
87
- brands = json.loads(brands) if brands else []
88
-
89
- # Collect results
90
- results = []
91
- for score, idx in zip(scores[0], ids[0]):
92
- if score < min_score:
93
- continue
94
- p = products[idx]
95
-
96
- # Apply filters
97
- if categories and p["category"] not in categories:
98
- continue
99
- if brands and p["brand"] not in brands:
100
- continue
101
- if not (min_price <= p["price"] <= max_price):
102
- continue
103
 
104
- results.append({**p, "score": float(score)})
105
- return {"matches": results}
106
- except Exception as e:
107
- return {"error": str(e)}
108
 
109
 
110
  @app.post("/search_text")
111
- async def search_text(
112
- query: str = Form(...),
113
- min_score: float = Form(0.6),
114
- top_k: int = Form(60),
115
- categories: str = Form(None),
116
- brands: str = Form(None),
117
- min_price: float = Form(0),
118
- max_price: float = Form(9999)
119
- ):
120
- try:
121
- # Encode text query
122
- q_emb = embed_text(query).reshape(1, -1)
123
-
124
- # Search FAISS
125
- scores, ids = index.search(q_emb, top_k)
126
-
127
- # Parse filters
128
- categories = json.loads(categories) if categories else []
129
- brands = json.loads(brands) if brands else []
130
-
131
- # Collect results
132
- results = []
133
- for score, idx in zip(scores[0], ids[0]):
134
- if score < min_score:
135
- continue
136
- p = products[idx]
137
-
138
- # Apply filters
139
- if categories and p["category"] not in categories:
140
- continue
141
- if brands and p["brand"] not in brands:
142
- continue
143
- if not (min_price <= p["price"] <= max_price):
144
- continue
145
-
146
- results.append({**p, "score": float(score)})
147
- return {"matches": results}
148
- except Exception as e:
149
- return {"error": str(e)}
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
+ import faiss
4
  import numpy as np
5
+ from fastapi import FastAPI, UploadFile, File, Form
6
+ from fastapi.middleware.cors import CORSMiddleware
7
  from sentence_transformers import SentenceTransformer
8
+ from PIL import Image
9
+ import io
10
+
11
# Point the Hugging Face cache locations at a local directory.
# NOTE(review): these variables are assigned after sentence_transformers has
# already been imported, so anything that reads them at import time may not
# see the new values -- confirm. The explicit cache_folder passed to
# SentenceTransformer below does not depend on them.
os.environ.update(
    {
        "HF_HOME": "./cache",
        "TRANSFORMERS_CACHE": "./cache",
        "SENTENCE_TRANSFORMERS_HOME": "./cache",
    }
)

app = FastAPI()

# Allow cross-origin requests so a browser frontend can call this API.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Product metadata; the search handlers subscript it with the ids returned by
# the FAISS index.
with open("id_mapping.json", mode="r", encoding="utf-8") as f:
    products = json.load(f)

# Pre-built FAISS index, read from disk.
index = faiss.read_index("products.index")

# The CLIP model is loaded once at import time and shared by every request.
print("🧠 Loading CLIP model...")
model = SentenceTransformer(
    "sentence-transformers/clip-ViT-B-32",
    cache_folder="./cache",
)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
@app.get("/")
def root():
    """Report that the service is up.

    Returns:
        A single-key dict whose ``message`` value is a fixed status string.
    """
    status = {"message": "🚀 Visual Product Matcher API is running!"}
    return status
 
42
 
43
 
44
  @app.post("/search_text")
45
+ def search_text(query: str = Form(...), top_k: int = 5):
46
+ """
47
+ Search products using text query.
48
+ """
49
+ query_emb = model.encode([query], convert_to_numpy=True)
50
+ distances, indices = index.search(query_emb, top_k)
51
+ results = [products[i] for i in indices[0]]
52
+ return {"query": query, "results": results}
53
+
54
+
55
@app.post("/search_image")
async def search_image(file: UploadFile = File(...), top_k: int = 5):
    """Search products using an uploaded image as the query.

    Args:
        file: Uploaded image; its bytes are decoded with PIL and converted
            to RGB before being embedded.
        top_k: Maximum number of products to return; must be at least 1.

    Returns:
        A dict with a ``results`` list holding the ``products`` entries for
        the neighbours the index returned, in the order the index returned
        them.

    Raises:
        HTTPException: 422 when ``top_k`` is smaller than 1; 400 when the
            upload cannot be decoded as an image.
    """
    # Imported locally so the handler is self-contained and the module's
    # top-level import list does not have to change.
    from fastapi import HTTPException

    if top_k < 1:
        raise HTTPException(
            status_code=422,
            detail="top_k must be a positive integer",
        )

    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as exc:
        # PIL reports unreadable or truncated image data as OSError
        # subclasses; that is a client error, not a server fault.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image",
        ) from exc

    # Never ask the index for more neighbours than it holds.
    k = min(top_k, index.ntotal)
    if k == 0:
        return {"results": []}

    # NOTE(review): encoding runs inside an async handler, so it appears to
    # occupy the event loop while it computes -- confirm whether it should be
    # moved to a worker thread.
    image_emb = model.encode([image], convert_to_numpy=True)
    _, indices = index.search(image_emb, k)

    # FAISS marks missing neighbours with -1. Used as a subscript, -1 would
    # silently pick an entry from the end of ``products``, so skip it.
    results = [products[i] for i in indices[0] if i >= 0]
    return {"results": results}