Update app.py
Browse files
app.py
CHANGED
|
@@ -27,6 +27,17 @@ from langchain.chat_models import ChatOpenAI
|
|
| 27 |
from langchain_openai import ChatOpenAI
|
| 28 |
from pydantic import Field, SecretStr
|
| 29 |
from difflib import get_close_matches
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "default_key_or_placeholder")
|
| 32 |
class ChatOpenRouter(ChatOpenAI):
|
|
@@ -308,6 +319,43 @@ agent_json_resolver = create_react_agent(
|
|
| 308 |
prompt=SYSTEM_PROMPT_JSON_CORRECTOR
|
| 309 |
)
|
| 310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
# Helper function to load the block catalog from a JSON file
|
| 312 |
def _load_block_catalog(block_type: str) -> Dict:
|
| 313 |
"""
|
|
@@ -2342,7 +2390,7 @@ def similarity_matching(sprites_data: str, project_folder: str) -> str:
|
|
| 2342 |
# -----------------------------------------
|
| 2343 |
# Load reference embeddings from JSON
|
| 2344 |
# -----------------------------------------
|
| 2345 |
-
with open(f"{BLOCKS_DIR}/
|
| 2346 |
embedding_json = json.load(f)
|
| 2347 |
|
| 2348 |
# =========================================
|
|
@@ -2364,10 +2412,18 @@ def similarity_matching(sprites_data: str, project_folder: str) -> str:
|
|
| 2364 |
# ============================== #
|
| 2365 |
# EMBED SPRITE IMAGES #
|
| 2366 |
# ============================== #
|
| 2367 |
-
|
| 2368 |
-
|
| 2369 |
-
|
| 2370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2371 |
|
| 2372 |
# =========================================
|
| 2373 |
# Compute similarities & pick best match
|
|
|
|
| 27 |
from langchain_openai import ChatOpenAI
|
| 28 |
from pydantic import Field, SecretStr
|
| 29 |
from difflib import get_close_matches
|
| 30 |
+
import torch
|
| 31 |
+
from transformers import AutoImageProcessor, AutoModel
|
| 32 |
+
|
| 33 |
+
# --- Config ---
# Hugging Face model id used for image embeddings.
DINOV2_MODEL = "facebook/dinov2-small"  # small = best CPU latency/quality tradeoff
DEVICE = torch.device("cpu")
# Intra-op thread count for torch. Defaults to 4; set DINOV2_NUM_THREADS to tune
# it for the host CPU without editing this file. Clamped to at least 1.
torch.set_num_threads(max(1, int(os.getenv("DINOV2_NUM_THREADS", "4"))))

# --- Globals for single-shot model load ---
# Both stay None until the model is first loaded, then are reused for every call.
_dinov2_processor = None
_dinov2_model = None
|
| 41 |
|
| 42 |
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "default_key_or_placeholder")
|
| 43 |
class ChatOpenRouter(ChatOpenAI):
|
|
|
|
| 319 |
prompt=SYSTEM_PROMPT_JSON_CORRECTOR
|
| 320 |
)
|
| 321 |
|
| 322 |
+
# DINOv2 image-embedding helpers: lazy model load, batched embedding, row normalization.
|
| 323 |
+
def init_dinov2(model_name: str = DINOV2_MODEL, device: torch.device = DEVICE) -> None:
    """Lazy-initialize the DINOv2 processor & model (call once before embedding).

    The first successful call loads the processor and model, switches the model
    to eval mode and moves it to ``device``. Later calls return immediately, so
    ``model_name`` and ``device`` passed after that point are ignored.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        device: Torch device the model is moved to.
    """
    global _dinov2_processor, _dinov2_model
    if _dinov2_processor is not None and _dinov2_model is not None:
        return

    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    model.eval().to(device)

    # Publish the globals only after everything succeeded; otherwise a failure in
    # eval()/to() would leave a half-prepared model that later calls treat as ready.
    _dinov2_processor = processor
    _dinov2_model = model
|
| 330 |
+
|
| 331 |
+
def embed_bytesio_list(bytesio_list, batch_size: int = 8):
    """
    Embed in-memory images with DINOv2.

    Args:
        bytesio_list: Iterable of file-like objects (e.g. BytesIO), each holding
            one image that PIL's ``Image.open`` can read.
        batch_size: Number of images per forward pass; must be >= 1.

    Returns:
        np.ndarray of shape (N, D) with L2-normalized rows, dtype float32.
        For empty input the shape is (0, D), where D is the model's hidden size.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A negative step would make range() empty and silently drop every image.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if _dinov2_processor is None or _dinov2_model is None:
        init_dinov2()

    # Send inputs to wherever the model actually lives: init_dinov2() accepts a
    # device argument, so the module-level DEVICE is not guaranteed to match.
    model_device = next(_dinov2_model.parameters()).device

    sources = list(bytesio_list)
    embs = []
    for start in range(0, len(sources), batch_size):
        # Decode only the current batch so at most batch_size images are held at once.
        batch = [
            Image.open(b).convert("RGB")
            for b in sources[start : start + batch_size]
        ]
        inputs = _dinov2_processor(images=batch, return_tensors="pt")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
        with torch.no_grad():
            out = _dinov2_model(**inputs)
        # global image embedding from CLS token
        cls = out.last_hidden_state[:, 0, :]  # (B, D)
        cls = torch.nn.functional.normalize(cls, p=2, dim=1)  # L2 normalize rows
        embs.append(cls.cpu().numpy())

    if not embs:
        return np.zeros((0, _dinov2_model.config.hidden_size), dtype=np.float32)
    return np.vstack(embs).astype(np.float32)
|
| 354 |
+
|
| 355 |
+
def l2_normalize_rows(a: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row of a 2-D array to unit Euclidean length.

    Args:
        a: Array whose rows (axis 1) are normalized.
        eps: Small constant added to every row norm so an all-zero row
            divides by ``eps`` instead of zero and stays all-zero.

    Returns:
        Array of the same shape as ``a`` with each row divided by its norm plus ``eps``.
    """
    row_lengths = np.linalg.norm(a, axis=1, keepdims=True)
    safe_lengths = row_lengths + eps
    return a / safe_lengths
|
| 358 |
+
|
| 359 |
# Helper function to load the block catalog from a JSON file
|
| 360 |
def _load_block_catalog(block_type: str) -> Dict:
|
| 361 |
"""
|
|
|
|
| 2390 |
# -----------------------------------------
|
| 2391 |
# Load reference embeddings from JSON
|
| 2392 |
# -----------------------------------------
|
| 2393 |
+
with open(f"{BLOCKS_DIR}/embed.json", "r") as f:
|
| 2394 |
embedding_json = json.load(f)
|
| 2395 |
|
| 2396 |
# =========================================
|
|
|
|
| 2412 |
# ============================== #
|
| 2413 |
# EMBED SPRITE IMAGES #
|
| 2414 |
# ============================== #
|
| 2415 |
+
# ensure model is initialized (fast no-op after first call)
|
| 2416 |
+
init_dinov2()
|
| 2417 |
+
|
| 2418 |
+
# embed the incoming sprite BytesIO images (same data structure you already use)
|
| 2419 |
+
sprite_matrix = embed_bytesio_list(sprite_images_bytes, batch_size=8) # shape (N, D)
|
| 2420 |
+
|
| 2421 |
+
# load reference embeddings from JSON (they must be numeric lists)
|
| 2422 |
+
img_matrix = np.array([img["embeddings"] for img in embedding_json], dtype=np.float32)
|
| 2423 |
+
|
| 2424 |
+
# normalize both sides (important — stored embeddings may not be normalized)
|
| 2425 |
+
sprite_matrix = l2_normalize_rows(sprite_matrix)
|
| 2426 |
+
img_matrix = l2_normalize_rows(img_matrix)
|
| 2427 |
|
| 2428 |
# =========================================
|
| 2429 |
# Compute similarities & pick best match
|