prthm11 commited on
Commit
1971458
·
verified ·
1 Parent(s): d60a89a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -60
app.py CHANGED
@@ -3389,44 +3389,43 @@ SPRITE_DIR/"Zebra.sprite3"/"f3e322a25b9f79801066056de6f33fb1.png"
3389
  folder_image_paths = [os.path.normpath(str(p)) for p in folder_image_paths]
3390
 
3391
 
3392
- # ============================== #
3393
- # EMBED SPRITE IMAGES #
3394
- # (using CLIP again) #
3395
- # ============================== #
3396
-
3397
- # Make sure all buffers are at start
3398
- for buf in sprite_images_bytes:
3399
- try:
3400
- buf.seek(0)
3401
- except Exception:
3402
- pass
3403
-
3404
- # Try the fast path: embed whole list at once (many CLIP wrappers accept a list of BytesIO/PIL)
3405
- try:
3406
- sprite_matrix = clip_embd.embed_image(sprite_images_bytes, batch_size=8)
3407
- sprite_matrix = np.array(sprite_matrix, dtype=np.float32)
3408
- except Exception:
3409
- sprite_feats = []
3410
- for buf in sprite_images_bytes:
3411
- buf.seek(0)
3412
- try:
3413
- feats = clip_embd.embed_image([buf])[0]
3414
- except Exception:
3415
- buf.seek(0)
3416
- pil_img = Image.open(buf).convert("RGB")
3417
- try:
3418
- feats = clip_embd.embed_image([pil_img])[0]
3419
- except Exception:
3420
- pil_arr = np.array(pil_img)
3421
- feats = clip_embd.embed_image([pil_arr])[0]
3422
- sprite_feats.append(np.asarray(feats, dtype=np.float32))
3423
- sprite_matrix = np.vstack(sprite_feats) # shape (N, D)
3424
 
3425
  # --- load reference embeddings (unchanged) ---
3426
- with open(f"{BLOCKS_DIR}/openclip_embeddings.json", "r") as f:
3427
- embedding_json = json.load(f)
3428
 
3429
- img_matrix = np.array([img["embeddings"] for img in embedding_json], dtype=np.float32)
3430
 
3431
 
3432
  # =========================================
@@ -3443,21 +3442,21 @@ SPRITE_DIR/"Zebra.sprite3"/"f3e322a25b9f79801066056de6f33fb1.png"
3443
  # # ============================== #
3444
  # # EMBED SPRITE IMAGES #
3445
  # # ============================== #
3446
- # sprite_features = []
3447
- # for b64 in sprite_base64:
3448
- # if "," in b64: # strip data URI prefix if present
3449
- # b64 = b64.split(",", 1)[1]
3450
 
3451
- # img_bytes = base64.b64decode(b64)
3452
- # pil_img = Image.open(BytesIO(img_bytes)).convert("RGB")
3453
 
3454
- # # optional re-encode to PNG for CLIP
3455
- # buf = BytesIO()
3456
- # pil_img.save(buf, format="PNG")
3457
- # buf.seek(0)
3458
 
3459
- # feats = clip_embd.embed_image([buf])[0] # extract CLIP embedding
3460
- # sprite_features.append(feats)
3461
 
3462
  # sprite_matrix = np.array(sprite_features, dtype=np.float32)
3463
  # # ============================== #
@@ -3474,23 +3473,37 @@ SPRITE_DIR/"Zebra.sprite3"/"f3e322a25b9f79801066056de6f33fb1.png"
3474
 
3475
  # normalize both sides (important — stored embeddings may not be normalized)
3476
 
3477
- def l2_normalize_rows(x: np.ndarray, eps: float = 1e-10) -> np.ndarray:
3478
- """
3479
- L2-normalize each row of a 2D numpy array.
3480
 
3481
- Args:
3482
- x: Array of shape (N, D).
3483
- eps: Small constant to avoid division by zero.
3484
 
3485
- Returns:
3486
- Normalized array of shape (N, D) where each row has unit norm.
3487
- """
3488
- norms = np.linalg.norm(x, axis=1, keepdims=True)
3489
- return x / np.maximum(norms, eps)
3490
 
3491
- sprite_matrix = l2_normalize_rows(sprite_matrix)
3492
- img_matrix = l2_normalize_rows(img_matrix)
 
 
 
 
 
 
 
 
 
 
 
3493
 
 
 
 
3494
  # =========================================
3495
  # Compute similarities & pick best match
3496
  # =========================================
 
3389
  folder_image_paths = [os.path.normpath(str(p)) for p in folder_image_paths]
3390
 
3391
 
3392
+ # # ============================== #
3393
+ # # EMBED SPRITE IMAGES #
3394
+ # # (using CLIP again) #
3395
+ # # ============================== #
3396
+ # # Make sure all buffers are at start
3397
+ # for buf in sprite_images_bytes:
3398
+ # try:
3399
+ # buf.seek(0)
3400
+ # except Exception:
3401
+ # pass
3402
+
3403
+ # # Try the fast path: embed whole list at once (many CLIP wrappers accept a list of BytesIO/PIL)
3404
+ # try:
3405
+ # sprite_matrix = clip_embd.embed_image(sprite_images_bytes, batch_size=8)
3406
+ # sprite_matrix = np.array(sprite_matrix, dtype=np.float32)
3407
+ # except Exception:
3408
+ # sprite_feats = []
3409
+ # for buf in sprite_images_bytes:
3410
+ # buf.seek(0)
3411
+ # try:
3412
+ # feats = clip_embd.embed_image([buf])[0]
3413
+ # except Exception:
3414
+ # buf.seek(0)
3415
+ # pil_img = Image.open(buf).convert("RGB")
3416
+ # try:
3417
+ # feats = clip_embd.embed_image([pil_img])[0]
3418
+ # except Exception:
3419
+ # pil_arr = np.array(pil_img)
3420
+ # feats = clip_embd.embed_image([pil_arr])[0]
3421
+ # sprite_feats.append(np.asarray(feats, dtype=np.float32))
3422
+ # sprite_matrix = np.vstack(sprite_feats) # shape (N, D)
 
3423
 
3424
  # --- load reference embeddings (unchanged) ---
3425
+ # with open(f"{BLOCKS_DIR}/openclip_embeddings.json", "r") as f:
3426
+ # embedding_json = json.load(f)
3427
 
3428
+ # img_matrix = np.array([img["embeddings"] for img in embedding_json], dtype=np.float32)
3429
 
3430
 
3431
  # =========================================
 
3442
  # # ============================== #
3443
  # # EMBED SPRITE IMAGES #
3444
  # # ============================== #
3445
+ sprite_features = []
3446
+ for b64 in sprite_base64:
3447
+ if "," in b64: # strip data URI prefix if present
3448
+ b64 = b64.split(",", 1)[1]
3449
 
3450
+ img_bytes = base64.b64decode(b64)
3451
+ pil_img = Image.open(BytesIO(img_bytes)).convert("RGB")
3452
 
3453
+ # optional re-encode to PNG for CLIP
3454
+ buf = BytesIO()
3455
+ pil_img.save(buf, format="PNG")
3456
+ buf.seek(0)
3457
 
3458
+ feats = clip_embd.embed_image([buf])[0] # extract CLIP embedding
3459
+ sprite_features.append(feats)
3460
 
3461
  # sprite_matrix = np.array(sprite_features, dtype=np.float32)
3462
  # # ============================== #
 
3473
 
3474
  # normalize both sides (important — stored embeddings may not be normalized)
3475
 
3476
+ # def l2_normalize_rows(x: np.ndarray, eps: float = 1e-10) -> np.ndarray:
3477
+ # """
3478
+ # L2-normalize each row of a 2D numpy array.
3479
 
3480
+ # Args:
3481
+ # x: Array of shape (N, D).
3482
+ # eps: Small constant to avoid division by zero.
3483
 
3484
+ # Returns:
3485
+ # Normalized array of shape (N, D) where each row has unit norm.
3486
+ # """
3487
+ # norms = np.linalg.norm(x, axis=1, keepdims=True)
3488
+ # return x / np.maximum(norms, eps)
3489
 
3490
+ # sprite_matrix = l2_normalize_rows(sprite_matrix)
3491
+ # img_matrix = l2_normalize_rows(img_matrix)
3492
+ sprite_features = clip_embd.embed_image(sprite_image_paths)
3493
+
3494
+ # ============================== #
3495
+ # COMPUTE SIMILARITIES #
3496
+ # ============================== #
3497
+ with open(f"{BLOCKS_DIR}/openclip_embeddings.json", "r") as f:
3498
+ embedding_json = json.load(f)
3499
+ # print(f"\n\n EMBEDDING JSON: {embedding_json}")
3500
+
3501
+ img_matrix = np.array([img["embeddings"] for img in embedding_json])
3502
+ sprite_matrix = np.array(sprite_features)
3503
 
3504
+ similarity = np.matmul(sprite_matrix, img_matrix.T)
3505
+ most_similar_indices = np.argmax(similarity, axis=1)
3506
+
3507
  # =========================================
3508
  # Compute similarities & pick best match
3509
  # =========================================