# Source: main.py from a Hugging Face Space (author: Yashashvibhardwaj,
# commit fac2e05 "Update main.py", verified; 1.85 kB).
import io
import json
import os

import faiss
import numpy as np
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

# Fix caching permissions for Hugging Face: point every cache at a writable
# local directory. These assignments must happen BEFORE sentence_transformers
# is imported. huggingface_hub and transformers resolve their cache locations
# from HF_HOME / TRANSFORMERS_CACHE when they are first imported, so setting
# the variables afterwards has no effect on those libraries.
os.environ["HF_HOME"] = "./cache"
os.environ["TRANSFORMERS_CACHE"] = "./cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "./cache"

# Deliberately imported after the cache env vars above (hence the E402 waiver).
from sentence_transformers import SentenceTransformer  # noqa: E402
app = FastAPI()

# CORS settings so a browser frontend hosted on another origin can call this API.
# NOTE(review): a wildcard origin combined with allow_credentials=True makes
# Starlette reflect the caller's Origin header rather than send "*". The
# endpoints in this file use no cookies or auth, so allow_credentials=False
# is probably sufficient. Confirm with the frontend before changing it.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# --- Search resources, loaded once when the module is imported ---

# Product metadata. The search endpoints index into it with the row ids
# returned by FAISS, so position i must describe vector i in the index.
with open("id_mapping.json", encoding="utf-8") as f:
    products = json.load(f)

# Pre-built FAISS index. Its dimensionality must match the CLIP embeddings
# below, because query embeddings from `model` are searched against it.
index = faiss.read_index("products.index")

# CLIP encoder shared by the text and image search endpoints. Weights are
# downloaded into / read from the local ./cache directory.
print("🧠 Loading CLIP model...")
model = SentenceTransformer(
    "sentence-transformers/clip-ViT-B-32",
    cache_folder="./cache",
)
@app.get("/")
def root():
    """Return a static liveness message confirming the API is running."""
    # The emoji was stored mis-encoded ("πŸš€"); restored to the intended 🚀.
    return {"message": "🚀 Visual Product Matcher API is running!"}
@app.post("/search_text")
def search_text(query: str = Form(...), top_k: int = 5):
    """
    Search products using a text query.

    The query is embedded with the shared CLIP model and looked up in the
    FAISS index. Each returned row id is mapped to its entry in ``products``.

    Args:
        query: Free-text search string (form field).
        top_k: Maximum number of products to return (query parameter, >= 1).

    Returns:
        ``{"query": query, "results": [...]}`` holding at most ``top_k``
        products, in the order FAISS returns them. There are fewer results
        when the index holds fewer matching vectors.

    Raises:
        HTTPException: 422 if ``top_k`` is less than 1.
    """
    from fastapi import HTTPException

    if top_k < 1:
        # FAISS asserts k > 0; without this check the caller would get an opaque 500.
        raise HTTPException(status_code=422, detail="top_k must be >= 1")

    # Never request more neighbours than the index holds. A larger k only
    # allocates bigger result arrays padded with -1.
    k = min(top_k, index.ntotal)
    if k == 0:
        return {"query": query, "results": []}

    query_emb = model.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_emb, k)
    # FAISS uses -1 for "no neighbour". products[-1] would silently return the
    # last product, so those slots are dropped instead.
    results = [products[i] for i in indices[0] if i >= 0]
    return {"query": query, "results": results}
@app.post("/search_image")
async def search_image(file: UploadFile = File(...), top_k: int = 5):
    """
    Search products using an uploaded image.

    The upload is decoded with Pillow, converted to RGB, embedded with the
    shared CLIP model and looked up in the FAISS index. Each returned row id
    is mapped to its entry in ``products``.

    Args:
        file: Uploaded image file (multipart form field).
        top_k: Maximum number of products to return (query parameter, >= 1).

    Returns:
        ``{"results": [...]}`` holding at most ``top_k`` products, in the
        order FAISS returns them.

    Raises:
        HTTPException: 422 if ``top_k`` is less than 1; 400 if the upload
            cannot be decoded as an image.
    """
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    if top_k < 1:
        # FAISS asserts k > 0; without this check the caller would get an opaque 500.
        raise HTTPException(status_code=422, detail="top_k must be >= 1")

    image_bytes = await file.read()
    try:
        # The context manager closes Pillow's file handle once the RGB copy exists.
        with Image.open(io.BytesIO(image_bytes)) as uploaded:
            image = uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (non-image or empty upload) and truncated-file
        # errors are OSError subclasses. Both are client errors, not server faults.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image"
        ) from err

    # Never request more neighbours than the index holds.
    k = min(top_k, index.ntotal)
    if k == 0:
        return {"results": []}

    def _embed_and_search():
        # CPU-bound CLIP encoding plus the FAISS lookup.
        image_emb = model.encode([image], convert_to_numpy=True)
        return index.search(image_emb, k)

    # Run the heavy work in the threadpool so this async endpoint does not
    # block the event loop (and every other request) while encoding.
    _, indices = await run_in_threadpool(_embed_and_search)
    # FAISS uses -1 for "no neighbour". products[-1] would silently return the
    # last product, so those slots are dropped instead.
    results = [products[i] for i in indices[0] if i >= 0]
    return {"results": results}