Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
·
b991bac
1
Parent(s):
be61cf2
update discr
Browse files
score_sde/models/discriminator.py
CHANGED
|
@@ -252,12 +252,31 @@ class SmallCondAttnDiscriminator(nn.Module):
|
|
| 252 |
class Discriminator_large(nn.Module):
|
| 253 |
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
| 254 |
|
| 255 |
-
def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
|
| 256 |
super().__init__()
|
| 257 |
# Gaussian random feature embedding layer for time
|
| 258 |
self.cond_proj = nn.Linear(cond_size, ngf*8)
|
| 259 |
self.act = act
|
| 260 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
self.t_embed = TimestepEmbedding(
|
| 262 |
embedding_dim=t_emb_dim,
|
| 263 |
hidden_dim=t_emb_dim,
|
|
@@ -317,7 +336,21 @@ class Discriminator_large(nn.Module):
|
|
| 317 |
out = self.act(out)
|
| 318 |
|
| 319 |
out = out.view(out.shape[0], out.shape[1], -1).sum(2)
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
return out
|
| 322 |
|
| 323 |
|
|
|
|
| 252 |
class Discriminator_large(nn.Module):
|
| 253 |
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
| 254 |
|
| 255 |
+
def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768, attn_pool=False, attn_pool_kw=None):
|
| 256 |
super().__init__()
|
| 257 |
# Gaussian random feature embedding layer for time
|
| 258 |
self.cond_proj = nn.Linear(cond_size, ngf*8)
|
| 259 |
self.act = act
|
| 260 |
+
if attn_pool:
|
| 261 |
+
if attn_pool_kw is None:
|
| 262 |
+
attn_pool_kw = dict(
|
| 263 |
+
depth=1,
|
| 264 |
+
dim_head = 64,
|
| 265 |
+
heads = 8,
|
| 266 |
+
num_latents = 64,
|
| 267 |
+
num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
|
| 268 |
+
max_seq_len = 512,
|
| 269 |
+
ff_mult = 4,
|
| 270 |
+
cosine_sim_attn = False,
|
| 271 |
+
)
|
| 272 |
+
self.attn_pool = layers.PerceiverResampler(
|
| 273 |
+
dim=cond_size,
|
| 274 |
+
**attn_pool_kw,
|
| 275 |
+
)
|
| 276 |
+
max_text_len = 512
|
| 277 |
+
self.null_text_embed = torch.nn.Parameter(torch.randn(1, max_text_len, cond_size))
|
| 278 |
+
else:
|
| 279 |
+
self.attn_pool = None
|
| 280 |
self.t_embed = TimestepEmbedding(
|
| 281 |
embedding_dim=t_emb_dim,
|
| 282 |
hidden_dim=t_emb_dim,
|
|
|
|
| 336 |
out = self.act(out)
|
| 337 |
|
| 338 |
out = out.view(out.shape[0], out.shape[1], -1).sum(2)
|
| 339 |
+
|
| 340 |
+
if self.attn_pool is not None:
|
| 341 |
+
(cond_pooled, cond, cond_mask) = cond
|
| 342 |
+
if len(cond_mask.shape) == 2:
|
| 343 |
+
cond_mask = cond_mask.view(cond_mask.shape[0], cond_mask.shape[1], 1)
|
| 344 |
+
cond = torch.where(
|
| 345 |
+
cond_mask,
|
| 346 |
+
cond,
|
| 347 |
+
self.null_text_embed[:, :cond.shape[1]]
|
| 348 |
+
)
|
| 349 |
+
cond = self.attn_pool(cond)
|
| 350 |
+
cond = cond.mean(dim=1)
|
| 351 |
+
cond = self.cond_proj(cond)
|
| 352 |
+
|
| 353 |
+
out = self.end_linear(out) + (cond * out).sum(dim=1, keepdim=True)
|
| 354 |
return out
|
| 355 |
|
| 356 |
|