bkhmsi committed
Commit 9a48e97 · 1 parent: 4c963c4

removed flash attention

models/micro_llama.py CHANGED
@@ -249,7 +249,7 @@ class MiCRoLlama(LlamaPreTrainedModel, GenerationMixin):
         self.config: MiCRoLlamaConfig = config
         self.config.torch_dtype = torch.bfloat16
         self.config.use_bfloat16 = True
-        self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
+        self.config._attn_implementation = "eager" # {sdpa, flash_attention_2, eager}
         self.config.backbone_num_layers = self.config.num_hidden_layers
         self.config.num_hidden_layers = self.config.num_hidden_layers * run_config["num-experts"]
         self.config.loss_type = "ForCausalLMLoss"
models/micro_moe_llama.py CHANGED
@@ -275,7 +275,7 @@ class MiCRoLlamaMoE(LlamaPreTrainedModel, GenerationMixin):
         self.config: MiCRoLlamaMoEConfig = config
         self.config.torch_dtype = torch.bfloat16
         self.config.use_bfloat16 = True
-        self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
+        self.config._attn_implementation = "eager" # {sdpa, flash_attention_2, eager}
         self.config.use_cache = True
         self.config.backbone_num_layers = self.config.num_hidden_layers
         self.config.num_hidden_layers = self.config.num_hidden_layers
models/micro_olmo.py CHANGED
@@ -191,7 +191,7 @@ class MiCRoOLMo(Olmo2PreTrainedModel, GenerationMixin):
         self.config: Olmo2Config = config
         self.config.torch_dtype = torch.bfloat16
         self.config.use_bfloat16 = True
-        self.config._attn_implementation = "flash_attention_2" # {sdpa, flash_attention_2, eager}
+        self.config._attn_implementation = "eager" # {sdpa, flash_attention_2, eager}
         self.config.use_cache = True
         self.config.backbone_num_layers = self.config.num_hidden_layers
         self.config.num_hidden_layers = self.config.num_hidden_layers * run_config["num-experts"]
requirements.txt CHANGED
@@ -3,5 +3,4 @@ plotly>=5.22.0
 pandas>=2.2.0
 torch==2.7.1
 transformers==4.53.2
-numpy==2.3.4
-flash-attn
+numpy==2.3.4
router_backend.py CHANGED
@@ -195,7 +195,7 @@ def build_model(model_id: str, hf_token: str, use_cache: bool = True):
 
     model_config.torch_dtype = torch.bfloat16
     model_config.use_bfloat16 = True
-    model_config._attn_implementation = "flash_attention_2"
+    model_config._attn_implementation = "eager" # {sdpa, flash_attention_2, eager}
    model_config.use_cache = use_cache
    model_config.ablate = []
 
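
With flash-attn dropped from requirements.txt, every config above pins _attn_implementation to "eager". A minimal sketch of an alternative approach, not part of this commit (the helper name and fallback order are illustrative), that picks a backend from the same {sdpa, flash_attention_2, eager} set based on what is actually installed:

import importlib.util

import torch


def pick_attn_implementation() -> str:
    # Hypothetical helper, not part of this commit: prefer flash_attention_2
    # when the flash-attn package and a CUDA device are available, then fall
    # back to PyTorch SDPA, and finally to the eager implementation.
    if importlib.util.find_spec("flash_attn") is not None and torch.cuda.is_available():
        return "flash_attention_2"
    if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        return "sdpa"
    return "eager"


# Illustrative usage in the constructors above:
# self.config._attn_implementation = pick_attn_implementation()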