Yashashvibhardwaj's picture
Update main.py
dae3a4f verified
raw
history blame
2.61 kB
import os

# Fix caching permissions for Hugging Face: send every HF cache to ./cache.
# These MUST be set before importing sentence_transformers (which imports
# transformers and huggingface_hub): those libraries read their cache
# locations into module-level constants at import time, so assigning the
# variables afterwards does not redirect their caches.
os.environ["HF_HOME"] = "./cache"
os.environ["TRANSFORMERS_CACHE"] = "./cache"  # legacy name, read by older transformers
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "./cache"

import io  # noqa: E402
import json  # noqa: E402

import faiss  # noqa: E402
import numpy as np  # noqa: E402
from fastapi import FastAPI, UploadFile, File, Form  # noqa: E402
from fastapi.middleware.cors import CORSMiddleware  # noqa: E402
from PIL import Image  # noqa: E402
from sentence_transformers import SentenceTransformer  # noqa: E402
app = FastAPI()

# Enable CORS so the frontend (on Netlify) can call this backend (on HF).
# Origins default to "*" (any site). To restrict them without a code change,
# set ALLOWED_ORIGINS to a comma-separated list of origins,
# e.g. ALLOWED_ORIGINS="https://your-site.netlify.app".
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
]
# NOTE(review): browsers reject a "*" origin combined with credentials; with
# allow_credentials=True Starlette instead echoes the caller's Origin, which
# lets any site make credentialed calls. Nothing in the visible endpoints reads
# cookies or auth headers, so allow_credentials=False is likely safe. It is kept
# True for compatibility; confirm the frontend sends no credentials first.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Startup: load everything the request handlers share, once per process ---

# Product metadata; the search endpoints index into it with the ids that the
# FAISS index returns.
with open("id_mapping.json", "r", encoding="utf-8") as fh:
    products = json.load(fh)

# Pre-built vector index that the search endpoints query.
index = faiss.read_index("products.index")

# CLIP encoder used to embed both text queries and query images.
print("🧠 Loading CLIP model...")
model = SentenceTransformer(
    "sentence-transformers/clip-ViT-B-32",
    cache_folder="./cache",
)
@app.get("/")
def root():
    """Return a static message confirming that the API process is running."""
    # The emoji below was previously stored mis-encoded, so clients received
    # garbled characters instead of the rocket.
    return {"message": "🚀 Visual Product Matcher API is running!"}
@app.post("/search_text")
def search_text(query: str = Form(...), top_k: int = 5, min_score: float = 0.0):
    """
    Search products using text query.

    The query string is embedded with the shared CLIP model, and the nearest
    product vectors are looked up in the FAISS index.

    Args:
        query: Free-text search query, sent as a form field.
        top_k: Maximum number of neighbours to request from the index.
            Values <= 0 yield an empty result.
        min_score: Hits whose index score is below this value are dropped.

    Returns:
        ``{"matches": [...]}``; each match is a copy of the product's metadata
        with an added ``"score"`` key.
    """
    if top_k <= 0:
        # faiss requires k > 0; asking for no results trivially matches nothing.
        return {"matches": []}

    query_emb = model.encode([query], convert_to_numpy=True)
    # NOTE(review): the embedding is not explicitly normalised here. If
    # products.index is an inner-product index built from L2-normalised
    # vectors, pass normalize_embeddings=True; confirm against the indexing
    # script. Likewise, `score >= min_score` assumes larger = more similar,
    # which would be inverted for an L2-distance index.
    distances, indices = index.search(query_emb, top_k)

    results = []
    for score, idx in zip(distances[0], indices[0]):
        # FAISS pads with id -1 when the index holds fewer than top_k vectors;
        # products[-1] would silently return the last product instead.
        if idx >= 0 and score >= min_score:  # filter by threshold
            # Copy before annotating: products[idx] is shared module-level
            # state and FastAPI runs sync endpoints in a thread pool, so
            # writing "score" into it in place leaks/races between requests.
            match = dict(products[idx])
            match["score"] = float(score)
            results.append(match)
    return {"matches": results}
@app.post("/match")  # 👈 Renamed to match frontend
async def search_image(file: UploadFile = File(None), image_url: str = Form(None), top_k: int = 5, min_score: float = 0.0):
    """
    Search products using image query (upload or URL).

    An uploaded ``file`` takes precedence over ``image_url``. The image is
    embedded with the shared CLIP model, and the nearest product vectors are
    looked up in the FAISS index.

    Args:
        file: Optional uploaded image.
        image_url: Optional URL to download the image from (form field).
        top_k: Maximum number of neighbours to request from the index.
            Values <= 0 yield an empty result.
        min_score: Hits whose index score is below this value are dropped.

    Returns:
        ``{"matches": [...]}`` (copies of product metadata with a ``"score"``
        key), or ``{"error": "..."}`` when no image was given or it could not
        be fetched or decoded.
    """
    # This handler is async, so blocking work (HTTP download, CLIP inference)
    # is pushed to a worker thread instead of stalling the event loop.
    from fastapi.concurrency import run_in_threadpool

    if file:
        image_bytes = await file.read()
    elif image_url:
        import requests

        # SECURITY(review): the server fetches an arbitrary caller-supplied URL
        # (SSRF risk: internal hosts become reachable). Consider a host
        # allow-list if this API is public.
        try:
            # Timeout (seconds) so an unresponsive host cannot hang the request.
            response = await run_in_threadpool(requests.get, image_url, timeout=10)
            response.raise_for_status()
        except requests.RequestException as exc:
            return {"error": f"Could not fetch image_url: {exc}"}
        image_bytes = response.content
    else:
        return {"error": "No image provided"}

    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # OSError includes PIL.UnidentifiedImageError (not a supported image)
        # and truncated image data.
        return {"error": f"Could not read image: {exc}"}

    if top_k <= 0:
        # faiss requires k > 0; asking for no results trivially matches nothing.
        return {"matches": []}

    image_emb = await run_in_threadpool(model.encode, [image], convert_to_numpy=True)
    # NOTE(review): the embedding is not explicitly normalised, and
    # `score >= min_score` assumes larger = more similar (e.g. an
    # inner-product index over normalised vectors). Confirm against how
    # products.index was built.
    distances, indices = index.search(image_emb, top_k)

    results = []
    for score, idx in zip(distances[0], indices[0]):
        # FAISS pads with id -1 when the index holds fewer than top_k vectors;
        # products[-1] would silently return the last product instead.
        if idx >= 0 and score >= min_score:
            # Copy before annotating: products[idx] is shared module-level
            # state, so writing "score" into it in place leaks between requests.
            match = dict(products[idx])
            match["score"] = float(score)
            results.append(match)
    return {"matches": results}