Update gradio_app.py
gradio_app.py  CHANGED  (+26 -11)
@@ -71,26 +71,30 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
     rgb = img_array
     mask = np.ones((*img_array.shape[:2], 1), dtype=np.float32)

-    # Convert to tensors
-    rgb = torch.from_numpy(rgb).float()
-    mask = torch.from_numpy(mask).float()
+    # Convert to tensors and keep in channel-last format initially
+    rgb = torch.from_numpy(rgb).float()  # [H, W, 3]
+    mask = torch.from_numpy(mask).float()  # [H, W, 1]
     print("[debug] rgb tensor shape:", rgb.shape)
     print("[debug] mask tensor shape:", mask.shape)

     # Create background blend
-    bg_tensor = torch.tensor(BACKGROUND_COLOR)[ …
+    bg_tensor = torch.tensor(BACKGROUND_COLOR)  # [3]
     print("[debug] bg_tensor shape:", bg_tensor.shape)

     # Blend RGB with background using mask
-    rgb_cond = torch.lerp(
-    …
+    rgb_cond = torch.lerp(
+        bg_tensor.view(1, 1, 3),  # [1, 1, 3]
+        rgb,  # [H, W, 3]
+        mask  # [H, W, 1]
+    )
+    print("[debug] rgb_cond shape after blend:", rgb_cond.shape)

-    # Permute the tensors to …
-    rgb_cond = …
-    mask = …
+    # Permute the tensors to [B, C, H, W] format at the end
+    rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
+    mask = mask.permute(2, 0, 1).unsqueeze(0)  # [1, 1, H, W]

-    print("[debug] rgb_cond …
-    print("[debug] mask …
+    print("[debug] rgb_cond final shape:", rgb_cond.shape)
+    print("[debug] mask final shape:", mask.shape)

     batch = {
         "rgb_cond": rgb_cond,
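The hunk above builds rgb_cond in channel-last [H, W, C] layout, blends it against the background with torch.lerp, and only permutes to [B, C, H, W] at the very end. The standalone sketch below reproduces that shape flow with dummy data; the BACKGROUND_COLOR value and the 64x64 resolution are placeholder assumptions, not values taken from gradio_app.py. torch.lerp(start, end, weight) returns start + weight * (end - start), so mask == 1 keeps the image pixel and mask == 0 falls back to the background color.

import torch

BACKGROUND_COLOR = (0.5, 0.5, 0.5)  # placeholder value for this sketch
H = W = 64                          # placeholder resolution

rgb = torch.rand(H, W, 3)                   # channel-last RGB in [0, 1]
mask = torch.ones(H, W, 1)                  # foreground mask (1 = keep pixel)
bg_tensor = torch.tensor(BACKGROUND_COLOR)  # [3]

# Broadcasted blend: [1, 1, 3], [H, W, 3], [H, W, 1] -> [H, W, 3]
rgb_cond = torch.lerp(bg_tensor.view(1, 1, 3), rgb, mask)

# Permute to [B, C, H, W] only at the end, as in the patched create_batch
rgb_cond = rgb_cond.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
mask = mask.permute(2, 0, 1).unsqueeze(0)          # [1, 1, H, W]

print(rgb_cond.shape, mask.shape)
# torch.Size([1, 3, 64, 64]) torch.Size([1, 1, 64, 64])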
@@ -109,6 +113,17 @@ def create_batch(input_image: Image.Image) -> dict[str, Any]:
 def forward_model(batch, system, guidance_scale=3.0, seed=0, device="cuda"):
     """Process batch through model and generate point cloud."""
     print("[debug] Starting forward_model")
+    print("[debug] Input rgb_cond shape:", batch["rgb_cond"].shape)
+
+    # Ensure input is in correct format [B, C, H, W]
+    if batch["rgb_cond"].shape[1] != 3:
+        batch["rgb_cond"] = batch["rgb_cond"].permute(0, 3, 1, 2)
+    if batch["mask_cond"].shape[1] != 1:
+        batch["mask_cond"] = batch["mask_cond"].permute(0, 3, 1, 2)
+
+    print("[debug] Processed rgb_cond shape:", batch["rgb_cond"].shape)
+    print("[debug] Processed mask_cond shape:", batch["mask_cond"].shape)
+
     batch_size = batch["rgb_cond"].shape[0]

     # Generate point cloud tokens
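The guard added to forward_model treats dimension 1 as the channel axis: if rgb_cond does not have 3 channels there (or mask_cond does not have 1), the batch is assumed to be channel-last and is permuted to [B, C, H, W]. Below is a minimal sketch of that check with made-up shapes; it is an illustration, not part of the commit.

import torch

# Dummy channel-last batch; the 64x64 resolution is made up for this sketch.
batch = {
    "rgb_cond": torch.rand(1, 64, 64, 3),
    "mask_cond": torch.ones(1, 64, 64, 1),
}

# Same guard as in forward_model: permute [B, H, W, C] -> [B, C, H, W] if needed.
if batch["rgb_cond"].shape[1] != 3:
    batch["rgb_cond"] = batch["rgb_cond"].permute(0, 3, 1, 2)
if batch["mask_cond"].shape[1] != 1:
    batch["mask_cond"] = batch["mask_cond"].permute(0, 3, 1, 2)

print(batch["rgb_cond"].shape)   # torch.Size([1, 3, 64, 64])
print(batch["mask_cond"].shape)  # torch.Size([1, 1, 64, 64])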