Spaces:
Runtime error
Runtime error
Update autoregressive/models/gpt_t2i.py
Browse files
autoregressive/models/gpt_t2i.py
CHANGED
|
@@ -367,6 +367,7 @@ class Transformer(nn.Module):
|
|
| 367 |
self.mask = get_causal_mask(256)
|
| 368 |
self.global_token = None
|
| 369 |
|
|
|
|
| 370 |
|
| 371 |
def initialize_weights(self):
|
| 372 |
# Initialize nn.Linear and nn.Embedding
|
|
@@ -411,7 +412,8 @@ class Transformer(nn.Module):
|
|
| 411 |
targets: Optional[torch.Tensor] = None,
|
| 412 |
mask: Optional[torch.Tensor] = None,
|
| 413 |
valid: Optional[torch.Tensor] = None,
|
| 414 |
-
condition: Optional[torch.Tensor] = None
|
|
|
|
| 415 |
):
|
| 416 |
if idx is not None and cond_idx is not None: # training or naive inference
|
| 417 |
cond_embeddings,drop_ids = self.cls_embedding(cond_idx, train=self.training)
|
|
@@ -432,6 +434,9 @@ class Transformer(nn.Module):
|
|
| 432 |
if condition is not None:
|
| 433 |
condition_embeddings = self.condition_mlp(condition,train=self.training)#.to(torch.bfloat16),train=self.training)
|
| 434 |
self.condition_token = condition_embeddings
|
|
|
|
|
|
|
|
|
|
| 435 |
|
| 436 |
else: # decode_n_tokens(kv cache) in inference
|
| 437 |
token_embeddings = self.tok_embeddings(idx)
|
|
@@ -451,9 +456,11 @@ class Transformer(nn.Module):
|
|
| 451 |
h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i//self.layer_internal](self.condition_token)
|
| 452 |
else:
|
| 453 |
if len(input_pos)>1:
|
| 454 |
-
h[:, -1:] = h[:, -1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1])
|
|
|
|
| 455 |
else:
|
| 456 |
-
h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos-self.cls_token_num+1])
|
|
|
|
| 457 |
h = layer(h, freqs_cis, input_pos, mask)
|
| 458 |
# output layers
|
| 459 |
h = self.norm(h)
|
|
|
|
| 367 |
self.mask = get_causal_mask(256)
|
| 368 |
self.global_token = None
|
| 369 |
|
| 370 |
+
self.control_strength = 1
|
| 371 |
|
| 372 |
def initialize_weights(self):
|
| 373 |
# Initialize nn.Linear and nn.Embedding
|
|
|
|
| 412 |
targets: Optional[torch.Tensor] = None,
|
| 413 |
mask: Optional[torch.Tensor] = None,
|
| 414 |
valid: Optional[torch.Tensor] = None,
|
| 415 |
+
condition: Optional[torch.Tensor] = None,
|
| 416 |
+
control_strength: Optional[int] = 1
|
| 417 |
):
|
| 418 |
if idx is not None and cond_idx is not None: # training or naive inference
|
| 419 |
cond_embeddings,drop_ids = self.cls_embedding(cond_idx, train=self.training)
|
|
|
|
| 434 |
if condition is not None:
|
| 435 |
condition_embeddings = self.condition_mlp(condition,train=self.training)#.to(torch.bfloat16),train=self.training)
|
| 436 |
self.condition_token = condition_embeddings
|
| 437 |
+
self.condition_token = [self.condition_layer[0](self.condition_token),
|
| 438 |
+
self.condition_layer[1](self.condition_token),
|
| 439 |
+
self.condition_layer[2](self.condition_token)]
|
| 440 |
|
| 441 |
else: # decode_n_tokens(kv cache) in inference
|
| 442 |
token_embeddings = self.tok_embeddings(idx)
|
|
|
|
| 456 |
h[:, self.cls_token_num-1:] = h[:, self.cls_token_num-1:] + self.condition_layers[i//self.layer_internal](self.condition_token)
|
| 457 |
else:
|
| 458 |
if len(input_pos)>1:
|
| 459 |
+
# h[:, -1:] = h[:, -1:] + self.condition_layers[i//self.layer_internal](self.condition_token[:,0:1])
|
| 460 |
+
h[:,-1:] = h[:, -1:] + self.control_strength*self.condition_token[i//self.layer_internal][:,0:1]
|
| 461 |
else:
|
| 462 |
+
# h = h + self.condition_layers[i//self.layer_internal](self.condition_token[:,input_pos-self.cls_token_num+1])
|
| 463 |
+
h = h + self.control_strength*self.condition_token[i//self.layer_internal][:,input_pos-self.cls_token_num+1]
|
| 464 |
h = layer(h, freqs_cis, input_pos, mask)
|
| 465 |
# output layers
|
| 466 |
h = self.norm(h)
|