|
|
import io
import json
import os

import faiss
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
|
|
# Point every Hugging Face-related cache directory at ./cache.
# NOTE(review): huggingface_hub / transformers typically read these variables
# when they are first imported, and sentence_transformers is imported above
# this point, so these assignments may not change those libraries' defaults.
# The model cache still lands in ./cache because SentenceTransformer(...) is
# given cache_folder="./cache" explicitly. Move this above the imports if the
# env vars are meant to take effect. (TRANSFORMERS_CACHE is also reportedly
# deprecated in newer transformers releases in favour of HF_HOME — verify.)
os.environ.update(
    {
        "HF_HOME": "./cache",
        "TRANSFORMERS_CACHE": "./cache",
        "SENTENCE_TRANSFORMERS_HOME": "./cache",
    }
)
|
|
|
|
|
app = FastAPI()

# CORS policy: any origin, method and header is accepted, and credentialed
# requests are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# website make credentialed cross-origin calls (Starlette may echo the
# caller's Origin back in this configuration). None of the endpoints in this
# file read cookies or auth headers, so allow_credentials=False is probably
# safe — confirm no frontend sends credentials before changing it.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
|
|
|
|
# Locations of the artefacts loaded once at import time (paths are relative to
# the process working directory).
_PRODUCTS_PATH = "id_mapping.json"
_INDEX_PATH = "products.index"
_MODEL_NAME = "sentence-transformers/clip-ViT-B-32"
_MODEL_CACHE_DIR = "./cache"

# Product metadata. The search endpoints look entries up with the integer ids
# returned by the FAISS index, so this is presumably a JSON array whose
# positions match those ids — verify against how the index was built.
with open(_PRODUCTS_PATH, "r", encoding="utf-8") as mapping_file:
    products = json.load(mapping_file)

# Pre-built FAISS index of product embeddings.
index = faiss.read_index(_INDEX_PATH)

# NOTE(review): the leading characters of this message look like a
# mis-decoded emoji; confirm the file's source encoding.
print("π§ Loading CLIP model...")
model = SentenceTransformer(_MODEL_NAME, cache_folder=_MODEL_CACHE_DIR)
|
|
|
|
|
|
|
|
@app.get("/")
def root():
    """Return a static message confirming the API process is serving requests."""
    # NOTE(review): the leading characters look like a mis-decoded emoji;
    # kept verbatim because clients may compare against this exact text.
    status_text = "π Visual Product Matcher API is running!"
    payload = {"message": status_text}
    return payload
|
|
|
|
|
|
|
|
@app.post("/search_text")
def search_text(query: str = Form(...), top_k: int = 5):
    """
    Search products using a free-text query.

    The query is embedded with the module-level ``model`` and the FAISS
    ``index`` is searched for its nearest neighbours.

    Args:
        query: Free-text description, submitted as a form field.
        top_k: Maximum number of products to return (query parameter).

    Returns:
        ``{"query": query, "results": [...]}`` where ``results`` holds the
        matching entries of ``products`` in the order FAISS ranked them. It may
        contain fewer than ``top_k`` items if the index holds fewer products.

    Raises:
        HTTPException: 400 if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1")

    # Never ask FAISS for more neighbours than the index holds; an empty index
    # simply yields no results.
    k = min(top_k, index.ntotal)
    if k == 0:
        return {"query": query, "results": []}

    query_emb = model.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_emb, k)

    # FAISS pads missing neighbours with -1. As a list index, -1 would silently
    # return the last product, so those slots are dropped.
    results = [products[int(i)] for i in indices[0] if i >= 0]
    return {"query": query, "results": results}
|
|
|
|
|
|
|
|
@app.post("/search_image")
async def search_image(file: UploadFile = File(...), top_k: int = 5):
    """
    Search products using an uploaded image as the query.

    The upload is decoded with Pillow, converted to RGB, embedded with the
    module-level ``model`` and matched against the FAISS ``index``.

    Args:
        file: Uploaded image (multipart form field ``file``).
        top_k: Maximum number of products to return (query parameter).

    Returns:
        ``{"results": [...]}`` with the matching entries of ``products`` in
        the order FAISS ranked them. It may contain fewer than ``top_k`` items
        if the index holds fewer products.

    Raises:
        HTTPException: 400 if ``top_k`` is smaller than 1 or the upload cannot
            be decoded as an image.
    """
    if top_k < 1:
        raise HTTPException(status_code=400, detail="top_k must be at least 1")

    image_bytes = await file.read()
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except OSError as err:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # non-image uploads and OSError for truncated/corrupt data; report it
        # as a client error instead of an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from err

    # Never ask FAISS for more neighbours than the index holds; an empty index
    # simply yields no results.
    k = min(top_k, index.ntotal)
    if k == 0:
        return {"results": []}

    # model.encode is blocking, CPU-heavy work. Calling it directly inside this
    # ``async def`` would stall the event loop (and every concurrent request),
    # so run it in the threadpool, as FastAPI does for plain ``def`` endpoints.
    image_emb = await run_in_threadpool(model.encode, [image], convert_to_numpy=True)
    _, indices = index.search(image_emb, k)

    # FAISS pads missing neighbours with -1. As a list index, -1 would silently
    # return the last product, so those slots are dropped.
    results = [products[int(i)] for i in indices[0] if i >= 0]
    return {"results": results}
|
|
|