prthm11 committed
Commit adfd01f · verified · 1 Parent(s): 9e9f81f

Update app.py

Files changed (1):
  1. app.py +61 -5
app.py CHANGED
@@ -27,6 +27,17 @@ from langchain.chat_models import ChatOpenAI
from langchain_openai import ChatOpenAI
from pydantic import Field, SecretStr
from difflib import get_close_matches
+ import torch
+ from transformers import AutoImageProcessor, AutoModel
+
+ # --- Config (tune threads as needed) ---
+ DINOV2_MODEL = "facebook/dinov2-small"  # small = best CPU latency/quality tradeoff
+ DEVICE = torch.device("cpu")
+ torch.set_num_threads(4)  # tune for your CPU
+
+ # --- Globals for single-shot model load ---
+ _dinov2_processor = None
+ _dinov2_model = None

os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "default_key_or_placeholder")
class ChatOpenRouter(ChatOpenAI):
@@ -308,6 +319,43 @@ agent_json_resolver = create_react_agent(
    prompt=SYSTEM_PROMPT_JSON_CORRECTOR
)

+ # adding the new embedding models:
+ def init_dinov2(model_name: str = DINOV2_MODEL, device: torch.device = DEVICE):
+     """Lazy-initialize DINOv2 processor & model (call once before embedding)."""
+     global _dinov2_processor, _dinov2_model
+     if _dinov2_processor is None or _dinov2_model is None:
+         _dinov2_processor = AutoImageProcessor.from_pretrained(model_name)
+         _dinov2_model = AutoModel.from_pretrained(model_name)
+         _dinov2_model.eval().to(device)
+
+ def embed_bytesio_list(bytesio_list, batch_size: int = 8):
+     """
+     Accepts a list of BytesIO objects (each contains an image, like your sprite_images_bytes).
+     Returns: np.ndarray shape (N, D) of L2-normalized embeddings (dtype float32).
+     """
+     if _dinov2_processor is None or _dinov2_model is None:
+         init_dinov2()
+
+     imgs = [Image.open(b).convert("RGB") for b in bytesio_list]
+     embs = []
+     for i in range(0, len(imgs), batch_size):
+         batch = imgs[i : i + batch_size]
+         inputs = _dinov2_processor(images=batch, return_tensors="pt")
+         inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+         with torch.no_grad():
+             out = _dinov2_model(**inputs)
+         # global image embedding from CLS token
+         cls = out.last_hidden_state[:, 0, :]  # (B, D)
+         cls = torch.nn.functional.normalize(cls, p=2, dim=1)  # L2 normalize rows
+         embs.append(cls.cpu().numpy())
+     if not embs:
+         return np.zeros((0, _dinov2_model.config.hidden_size), dtype=np.float32)
+     return np.vstack(embs).astype(np.float32)
+
+ def l2_normalize_rows(a: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+     norm = np.linalg.norm(a, axis=1, keepdims=True)
+     return a / (norm + eps)
+
# Helper function to load the block catalog from a JSON file
def _load_block_catalog(block_type: str) -> Dict:
    """
@@ -2342,7 +2390,7 @@ def similarity_matching(sprites_data: str, project_folder: str) -> str:
    # -----------------------------------------
    # Load reference embeddings from JSON
    # -----------------------------------------
-     with open(f"{BLOCKS_DIR}/embeddings.json", "r") as f:
+     with open(f"{BLOCKS_DIR}/embed.json", "r") as f:
        embedding_json = json.load(f)

    # =========================================
@@ -2364,10 +2412,18 @@ def similarity_matching(sprites_data: str, project_folder: str) -> str:
    # ============================== #
    #      EMBED SPRITE IMAGES       #
    # ============================== #
-     sprite_features = clip_embd.embed_image(sprite_images_bytes)
-
-     sprite_matrix = np.vstack(sprite_features)
-     img_matrix = np.array([img["embeddings"] for img in embedding_json])
+     # ensure model is initialized (fast no-op after first call)
+     init_dinov2()
+
+     # embed the incoming sprite BytesIO images (same data structure you already use)
+     sprite_matrix = embed_bytesio_list(sprite_images_bytes, batch_size=8)  # shape (N, D)
+
+     # load reference embeddings from JSON (they must be numeric lists)
+     img_matrix = np.array([img["embeddings"] for img in embedding_json], dtype=np.float32)
+
+     # normalize both sides (important: stored embeddings may not be normalized)
+     sprite_matrix = l2_normalize_rows(sprite_matrix)
+     img_matrix = l2_normalize_rows(img_matrix)

    # =========================================
    # Compute similarities & pick best match
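
Note (not part of the commit): once both matrices are L2-normalized, the "Compute similarities & pick best match" step that follows this hunk reduces to a matrix product plus an argmax. Below is a minimal sketch of that step, assuming sprite_matrix is (N, D), img_matrix is (M, D), and each sprite is matched to the single closest reference; the helper name pick_best_matches and its return values are illustrative, not taken from app.py.

import numpy as np

def pick_best_matches(sprite_matrix: np.ndarray, img_matrix: np.ndarray):
    # rows are unit-norm, so cosine similarity is just a dot product
    similarity = sprite_matrix @ img_matrix.T                  # (N, M)
    best_idx = similarity.argmax(axis=1)                       # closest reference per sprite
    best_score = similarity[np.arange(similarity.shape[0]), best_idx]
    return best_idx, best_score

Each best_idx[i] indexes into embedding_json, so embedding_json[best_idx[i]] would be the catalog entry whose stored embedding is closest to sprite i.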
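
A related assumption worth making explicit: because the encoder changes from CLIP to DINOv2 (and the reference file from embeddings.json to embed.json), the stored "embeddings" lists must have been regenerated with the same DINOv2 model; CLS features from facebook/dinov2-small are 384-dimensional, so embeddings left over from the CLIP model would not even match in shape. A hypothetical regeneration sketch using the helpers added in this commit, assumed to live in app.py; the catalog_images structure and the "name" field are assumptions, only the "embeddings" key is taken from the code above.

import json
from io import BytesIO

def rebuild_reference_embeddings(catalog_images, out_path=f"{BLOCKS_DIR}/embed.json"):
    # catalog_images: assumed list of dicts like {"name": ..., "image_bytes": b"..."}
    init_dinov2()
    streams = [BytesIO(item["image_bytes"]) for item in catalog_images]
    embs = embed_bytesio_list(streams, batch_size=8)  # (M, 384), already L2-normalized
    records = [
        {"name": item["name"], "embeddings": emb.tolist()}
        for item, emb in zip(catalog_images, embs)
    ]
    with open(out_path, "w") as f:
        json.dump(records, f)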