Spaces:
Runtime error
Runtime error
Update autoregressive/models/gpt_t2i.py
Browse files
autoregressive/models/gpt_t2i.py
CHANGED
|
@@ -429,6 +429,7 @@ class Transformer(nn.Module):
|
|
| 429 |
self.freqs_cis = self.freqs_cis.to(h.device)
|
| 430 |
else:
|
| 431 |
if cond_idx is not None: # prefill in inference
|
|
|
|
| 432 |
token_embeddings = self.cls_embedding(cond_idx, train=self.training)
|
| 433 |
token_embeddings = token_embeddings[:,:self.cls_token_num]
|
| 434 |
if condition is not None:
|
|
|
|
| 429 |
self.freqs_cis = self.freqs_cis.to(h.device)
|
| 430 |
else:
|
| 431 |
if cond_idx is not None: # prefill in inference
|
| 432 |
+
self.control_strength = control_strength
|
| 433 |
token_embeddings = self.cls_embedding(cond_idx, train=self.training)
|
| 434 |
token_embeddings = token_embeddings[:,:self.cls_token_num]
|
| 435 |
if condition is not None:
|