|
|
import os |
|
|
import json |
|
|
import faiss |
|
|
import numpy as np |
|
|
from fastapi import FastAPI, UploadFile, File, Form |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from PIL import Image |
|
|
import io |
|
|
import requests |
|
|
|
|
|
|
|
|
# Point every Hugging Face-related cache location at the local ./cache directory.
# NOTE(review): these variables are assigned after `sentence_transformers` has
# already been imported at the top of the file. Any library that reads them at
# import time may therefore not see them. The model load in this file also
# passes cache_folder="./cache" explicitly, which presumably is what actually
# redirects the download. Consider moving this before the imports — confirm.
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "SENTENCE_TRANSFORMERS_HOME"),
        "./cache",
    )
)
|
|
|
|
|
app = FastAPI()

# Accept cross-origin requests from any origin, with any method and any header.
# NOTE(review): the FastAPI CORS docs say allow_origins/methods/headers cannot be
# ["*"] when allow_credentials=True. Starlette handles this by echoing the
# request Origin when cookies are sent, which lets any site make credentialed
# calls. The endpoints visible in this file read no cookies or auth headers, so
# allow_credentials=False is probably safe. Confirm with the frontend before
# changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
|
|
|
|
|
|
# Product catalogue, loaded once at import time. The search endpoints index into
# it with FAISS result ids. Presumably products[i] describes row i of the index
# — TODO confirm against the index-building script.
with open("products.json", encoding="utf-8") as f:
    products = json.loads(f.read())

print(f"π¦ Loaded {len(products)} products")
|
|
|
|
|
|
|
|
# Prebuilt FAISS index. The search endpoints map its result ids directly onto
# positions in `products`.
index = faiss.read_index("products.index")

# CLIP encoder shared by the text and image endpoints. It embeds both modalities
# into the same vector space that the index was built over (assumed — TODO
# confirm the index was built with this exact model).
print("π§ Loading CLIP model...")
model = SentenceTransformer(
    model_name_or_path="sentence-transformers/clip-ViT-B-32",
    cache_folder="./cache",
)
|
|
|
|
|
|
|
|
@app.get("/")
def root():
    """Liveness probe: return a fixed JSON message confirming the service is up."""
    payload = {"message": "π Visual Product Matcher API is running!"}
    return payload
|
|
|
|
|
|
|
|
@app.post("/search_text")
def search_text(query: str = Form(...), top_k: int = 5, min_score: float = 0.0):
    """
    Search products using text query.

    The query is embedded with the module-level CLIP ``model`` and looked up in
    the FAISS ``index``. Each hit is a copy of the matching ``products`` entry
    with an added ``"score"`` key.

    Args:
        query: Free-text search string (form field).
        top_k: Maximum number of candidates to retrieve from the index.
            Values <= 0 yield no matches. Values above the index size are
            clamped to it.
        min_score: Hits whose score is below this value are dropped.
            Assumes an inner-product index over L2-normalized embeddings, so
            the score is a cosine similarity where higher is better — TODO
            confirm the index type.

    Returns:
        ``{"matches": [...]}`` in the order returned by the index.
    """
    # FAISS needs k > 0. When k exceeds the number of indexed vectors it pads
    # the result with idx == -1, so clamp k to what the index can return.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return {"matches": []}

    query_emb = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    sims, indices = index.search(query_emb, k)

    results = []
    for sim, idx in zip(sims[0], indices[0]):
        # -1 means "no result". products[-1] would silently return the last item.
        if idx < 0:
            continue
        score = float(sim)
        if score >= min_score:
            item = products[idx].copy()
            item["score"] = score
            results.append(item)

    return {"matches": results}
|
|
|
|
|
|
|
|
@app.post("/match")
async def search_image(
    file: UploadFile = File(None),
    image_url: str = Form(None),
    top_k: int = 5,
    min_score: float = 0.0
):
    """
    Search products using image query (upload or URL).

    An uploaded ``file`` takes precedence over ``image_url``. The image is
    embedded with the module-level CLIP ``model`` and looked up in the FAISS
    ``index``.

    Args:
        file: Optional uploaded image (multipart form field).
        image_url: Optional URL to download the image from (form field).
        top_k: Maximum number of candidates to retrieve from the index.
            Values <= 0 yield no matches. Values above the index size are
            clamped to it.
        min_score: Hits whose score is below this value are dropped.
            Assumes an inner-product index over L2-normalized embeddings, so
            the score is a cosine similarity where higher is better — TODO
            confirm the index type.

    Returns:
        ``{"matches": [...]}`` on success. Each match is a copy of the
        ``products`` entry with an added ``"score"`` key.
        ``{"error": "..."}`` is returned when no image was provided, the
        download failed, or the bytes could not be decoded as an image.
    """
    # This handler is `async def`, so it runs on the event loop. Blocking work
    # (the HTTP download and model inference) is moved to a worker thread so it
    # does not stall every other request.
    from fastapi.concurrency import run_in_threadpool

    if file:
        image_bytes = await file.read()
    elif image_url:
        # NOTE(review): image_url is client-controlled and fetched server-side.
        # That allows SSRF against internal hosts; consider a host allow-list.
        try:
            response = await run_in_threadpool(requests.get, image_url, timeout=10)
            response.raise_for_status()
        except requests.RequestException as exc:
            return {"error": f"Could not download image: {exc}"}
        image_bytes = response.content
    else:
        return {"error": "No image provided"}

    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL's UnidentifiedImageError (not an image, or truncated data) is an
        # OSError subclass.
        return {"error": f"Invalid image: {exc}"}

    # FAISS needs k > 0. When k exceeds the number of indexed vectors it pads
    # the result with idx == -1, so clamp k to what the index can return.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return {"matches": []}

    image_emb = await run_in_threadpool(
        model.encode, [image], convert_to_numpy=True, normalize_embeddings=True
    )
    sims, indices = index.search(image_emb, k)

    results = []
    for sim, idx in zip(sims[0], indices[0]):
        # -1 means "no result". products[-1] would silently return the last item.
        if idx < 0:
            continue
        score = float(sim)
        if score >= min_score:
            item = products[idx].copy()
            item["score"] = score
            results.append(item)

    return {"matches": results}
|
|
|