Spaces: Running on Zero
New masking implementation
llama_diffusion_model.py (+6 -5)
llama_diffusion_model.py CHANGED

@@ -97,8 +97,8 @@ class CustomTransformerModel(PreTrainedModel):
         self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
         self.llama.resize_token_embeddings(config.vocab_size)

-        for i, layer in enumerate(self.llama.model.layers):
-            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)
+        # for i, layer in enumerate(self.llama.model.layers):
+        #     layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)

         for param in self.llama.parameters():
             param.requires_grad = False
@@ -113,7 +113,7 @@ class CustomTransformerModel(PreTrainedModel):

         self.llama = get_peft_model(self.llama, lora_config)
         self.llama.print_trainable_parameters()
-        self.llama = self.llama.to(torch.float16)
+        # self.llama = self.llama.to(torch.float16)

     def forward(self, input_ids, labels=None, **kwargs):
         batch_size, seq_len = input_ids.shape
@@ -121,8 +121,8 @@ class CustomTransformerModel(PreTrainedModel):

         # Build attention mask
         device = input_ids.device
-
-        masking_type = getattr(self.config, "masking_type", "
+
+        masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
         if masking_type == 'bidirectional':
             base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
         elif masking_type == 'bidirectional_masked':
@@ -142,6 +142,7 @@ class CustomTransformerModel(PreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             output_hidden_states=True,
+            use_cache=False,
             **kwargs
         )

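With this change, the masking behaviour is selected in forward() from config.masking_type rather than by swapping BidirectionalLlamaAttention into every layer at load time. As a rough illustration of what the two named branches could look like end to end, here is a minimal sketch of building such a mask. The build_attention_mask helper, the diagonal-hiding behaviour of 'bidirectional_masked', and the causal fallback are illustrative assumptions, not code from this repo:

import torch

def build_attention_mask(input_ids: torch.Tensor,
                         masking_type: str = "bidirectional_masked") -> torch.Tensor:
    # Illustrative helper (not from the repo): returns a boolean mask of shape
    # (batch_size, seq_len, seq_len) where True marks positions allowed to attend.
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    if masking_type == "bidirectional":
        # Every token may attend to every token, as in the diff's first branch.
        base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
    elif masking_type == "bidirectional_masked":
        # Assumption: bidirectional attention, but each token is hidden from itself.
        base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        base_mask.fill_diagonal_(False)
    else:
        # Assumed fallback: a standard causal (lower-triangular) mask.
        base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

    # Broadcast the shared (seq_len, seq_len) pattern across the batch.
    return base_mask.unsqueeze(0).expand(batch_size, -1, -1)

# Example: a 2 x 6 dummy batch with the default masking type.
dummy_ids = torch.zeros(2, 6, dtype=torch.long)
attention_mask = build_attention_mask(dummy_ids)
print(attention_mask.shape)  # torch.Size([2, 6, 6])

The added use_cache=False in the model call presumably follows from the same change: a KV cache is only meaningful for causal, left-to-right decoding, so it is disabled when the attention pattern is bidirectional.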