Spaces:
Running
on
Zero
Running
on
Zero
Update src/models/models/worldmirror.py
Browse files
src/models/models/worldmirror.py
CHANGED
|
@@ -131,40 +131,27 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
|
|
| 131 |
imgs = views['img']
|
| 132 |
|
| 133 |
# Enable conditional input during training if enabled, or during inference if any cond_flags are set
|
| 134 |
-
use_cond = (
|
| 135 |
-
(self.training and self.enable_cond) or
|
| 136 |
-
(not self.training and sum(cond_flags) > 0)
|
| 137 |
-
)
|
| 138 |
|
| 139 |
# Extract priors and process features based on conditional input
|
| 140 |
-
context_token_list = None
|
| 141 |
if use_cond:
|
| 142 |
priors = self.extract_priors(views)
|
| 143 |
token_list, patch_start_idx = self.visual_geometry_transformer(
|
| 144 |
imgs, priors, cond_flags=cond_flags
|
| 145 |
)
|
| 146 |
-
if self.enable_gs:
|
| 147 |
-
cnums = views["context_nums"]
|
| 148 |
-
context_priors = (priors[0][:,:cnums], priors[1][:,:cnums], priors[2][:,:cnums])
|
| 149 |
-
context_token_list = self.visual_geometry_transformer(
|
| 150 |
-
imgs[:,:cnums], context_priors, cond_flags=cond_flags
|
| 151 |
-
)[0]
|
| 152 |
else:
|
| 153 |
token_list, patch_start_idx = self.visual_geometry_transformer(imgs)
|
| 154 |
-
if self.enable_gs:
|
| 155 |
-
cnums = views["context_nums"] if "context_nums" in views else imgs.shape[1]
|
| 156 |
-
context_token_list = self.visual_geometry_transformer(imgs[:,:cnums])[0]
|
| 157 |
|
| 158 |
# Execute predictions
|
| 159 |
with torch.amp.autocast('cuda', enabled=False):
|
| 160 |
# Generate all predictions
|
| 161 |
preds = self._gen_all_preds(
|
| 162 |
-
token_list,
|
| 163 |
)
|
| 164 |
|
| 165 |
return preds
|
| 166 |
|
| 167 |
-
def _gen_all_preds(self, token_list,
|
| 168 |
imgs, patch_start_idx, views):
|
| 169 |
"""Generate all enabled predictions"""
|
| 170 |
preds = {}
|
|
@@ -175,9 +162,7 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
|
|
| 175 |
cam_seq = self.cam_head(token_list)
|
| 176 |
cam_params = cam_seq[-1]
|
| 177 |
preds["camera_params"] = cam_params
|
| 178 |
-
|
| 179 |
-
context_cam_params = self.cam_head(context_token_list)[-1]
|
| 180 |
-
context_preds = {"camera_params": context_cam_params}
|
| 181 |
ext_mat, int_mat = vector_to_camera_matrices(
|
| 182 |
cam_params, image_hw=(imgs.shape[-2], imgs.shape[-1])
|
| 183 |
)
|
|
@@ -216,9 +201,8 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
|
|
| 216 |
|
| 217 |
# 3D Gaussian Splatting
|
| 218 |
if self.enable_gs:
|
| 219 |
-
views['context_nums'] = imgs.shape[1] if "context_nums" not in views else views["context_nums"]
|
| 220 |
gs_feat, gs_depth, gs_depth_conf = self.gs_head(
|
| 221 |
-
|
| 222 |
)
|
| 223 |
|
| 224 |
preds["gs_depth"] = gs_depth
|
|
@@ -228,7 +212,6 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
|
|
| 228 |
images=imgs,
|
| 229 |
predictions=preds,
|
| 230 |
views=views,
|
| 231 |
-
context_predictions=context_preds
|
| 232 |
)
|
| 233 |
|
| 234 |
return preds
|
|
@@ -246,7 +229,7 @@ class WorldMirror(nn.Module, PyTorchModelHubMixin):
|
|
| 246 |
h, w = views['img'].shape[-2:]
|
| 247 |
|
| 248 |
# Initialize prior variables
|
| 249 |
-
|
| 250 |
|
| 251 |
# Extract camera pose
|
| 252 |
if 'camera_pose' in views:
|
|
|
|
| 131 |
imgs = views['img']
|
| 132 |
|
| 133 |
# Enable conditional input during training if enabled, or during inference if any cond_flags are set
|
| 134 |
+
use_cond = sum(cond_flags) > 0
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
# Extract priors and process features based on conditional input
|
|
|
|
| 137 |
if use_cond:
|
| 138 |
priors = self.extract_priors(views)
|
| 139 |
token_list, patch_start_idx = self.visual_geometry_transformer(
|
| 140 |
imgs, priors, cond_flags=cond_flags
|
| 141 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
else:
|
| 143 |
token_list, patch_start_idx = self.visual_geometry_transformer(imgs)
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
# Execute predictions
|
| 146 |
with torch.amp.autocast('cuda', enabled=False):
|
| 147 |
# Generate all predictions
|
| 148 |
preds = self._gen_all_preds(
|
| 149 |
+
token_list, imgs, patch_start_idx, views
|
| 150 |
)
|
| 151 |
|
| 152 |
return preds
|
| 153 |
|
| 154 |
+
def _gen_all_preds(self, token_list,
|
| 155 |
imgs, patch_start_idx, views):
|
| 156 |
"""Generate all enabled predictions"""
|
| 157 |
preds = {}
|
|
|
|
| 162 |
cam_seq = self.cam_head(token_list)
|
| 163 |
cam_params = cam_seq[-1]
|
| 164 |
preds["camera_params"] = cam_params
|
| 165 |
+
|
|
|
|
|
|
|
| 166 |
ext_mat, int_mat = vector_to_camera_matrices(
|
| 167 |
cam_params, image_hw=(imgs.shape[-2], imgs.shape[-1])
|
| 168 |
)
|
|
|
|
| 201 |
|
| 202 |
# 3D Gaussian Splatting
|
| 203 |
if self.enable_gs:
|
|
|
|
| 204 |
gs_feat, gs_depth, gs_depth_conf = self.gs_head(
|
| 205 |
+
token_list, images=imgs, patch_start_idx=patch_start_idx
|
| 206 |
)
|
| 207 |
|
| 208 |
preds["gs_depth"] = gs_depth
|
|
|
|
| 212 |
images=imgs,
|
| 213 |
predictions=preds,
|
| 214 |
views=views,
|
|
|
|
| 215 |
)
|
| 216 |
|
| 217 |
return preds
|
|
|
|
| 229 |
h, w = views['img'].shape[-2:]
|
| 230 |
|
| 231 |
# Initialize prior variables
|
| 232 |
+
depths = rays = poses = None
|
| 233 |
|
| 234 |
# Extract camera pose
|
| 235 |
if 'camera_pose' in views:
|