Update hy3dgen/shapegen/models/conditioner.py
hy3dgen/shapegen/models/conditioner.py CHANGED
@@ -103,7 +103,7 @@ class ImageEncoder(nn.Module):
 
         return last_hidden_state
 
-    def unconditional_embedding(self, batch_size):
+    def unconditional_embedding(self, batch_size, **kwargs):
         device = next(self.model.parameters()).device
         dtype = next(self.model.parameters()).dtype
         zero = torch.zeros(
@@ -159,9 +159,6 @@ class DinoImageEncoderMV(DinoImageEncoder):
         image = image.to(self.model.device, dtype=self.model.dtype)
 
         bs, num_views, c, h, w = image.shape
-        # TODO: find a better place to set view_num?
-        self.view_num = num_views
-
         image = image.view(bs * num_views, c, h, w)
 
         inputs = self.transform(image)
@@ -190,12 +187,12 @@ class DinoImageEncoderMV(DinoImageEncoder):
             last_hidden_state.shape[-1])
         return last_hidden_state
 
-    def unconditional_embedding(self, batch_size):
+    def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
         device = next(self.model.parameters()).device
         dtype = next(self.model.parameters()).dtype
         zero = torch.zeros(
             batch_size,
-            self.num_patches * self.view_num,
+            self.num_patches * len(view_idxs[0]),
             self.model.config.hidden_size,
             device=device,
             dtype=dtype,
@@ -224,17 +221,17 @@ class DualImageEncoder(nn.Module):
         self.main_image_encoder = build_image_encoder(main_image_encoder)
         self.additional_image_encoder = build_image_encoder(additional_image_encoder)
 
-    def forward(self, image, mask=None):
+    def forward(self, image, mask=None, **kwargs):
         outputs = {
-            'main': self.main_image_encoder(image, mask=mask),
-            'additional': self.additional_image_encoder(image, mask=mask),
+            'main': self.main_image_encoder(image, mask=mask, **kwargs),
+            'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
         }
         return outputs
 
-    def unconditional_embedding(self, batch_size):
+    def unconditional_embedding(self, batch_size, **kwargs):
         outputs = {
-            'main': self.main_image_encoder.unconditional_embedding(batch_size),
-            'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
+            'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
+            'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
         }
         return outputs
 
@@ -253,8 +250,8 @@ class SingleImageEncoder(nn.Module):
         }
         return outputs
 
-    def unconditional_embedding(self, batch_size):
+    def unconditional_embedding(self, batch_size, **kwargs):
         outputs = {
-            'main': self.main_image_encoder.unconditional_embedding(batch_size),
+            'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
         }
         return outputs
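In short, the commit threads `**kwargs` through every `unconditional_embedding` (and through `DualImageEncoder.forward`), and `DinoImageEncoderMV` now sizes its zero embedding from an explicit `view_idxs` argument rather than a `self.view_num` attribute cached as a side effect of `forward`. Below is a minimal, self-contained sketch of the new sizing logic; `num_patches` and `hidden_size` values are assumptions for illustration, not taken from the diff.

import torch

# Hypothetical standalone illustration of the sizing change in
# DinoImageEncoderMV.unconditional_embedding.
num_patches = 257            # patch tokens per view (assumed value)
hidden_size = 1024           # model.config.hidden_size (assumed value)
batch_size = 2
view_idxs = [[0, 1, 2, 3]] * batch_size  # per-sample view indices

# Old behavior: num_patches * self.view_num, where view_num was cached
# during the last forward() call. New behavior: sized directly from the
# views explicitly requested via view_idxs.
zero = torch.zeros(
    batch_size,
    num_patches * len(view_idxs[0]),
    hidden_size,
)
print(zero.shape)  # torch.Size([2, 1028, 1024]) with the assumed values

Presumably the point is that the unconditional embedding's shape no longer depends on whatever `forward` ran last, which matters when the zero embedding is built independently (e.g. for classifier-free guidance) with a different view count.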