kamalrajkannanmcw committed
Commit 209ae73 · 1 Parent(s): 2c968b4

Remove hardcoded .cuda() calls to support a single forward pass on CPU and ensure DeepSeekOCR model compatibility with transformers==4.52.4

modeling_deepseekocr.py CHANGED
@@ -502,7 +502,7 @@ class DeepseekOCRModel(DeepseekV2Model):
                     images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                     # exit()
 
-                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
+                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1), images_in_this_batch)
 
                 idx += 1
 
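The change above is what makes the CPU forward pass possible: calling .cuda() fails outright on a machine without a GPU, and Tensor.masked_scatter_ expects its mask to be on the same device as the tensor being written, so the mask should simply stay wherever inputs_embeds lives. A minimal sketch of the pattern, with illustrative shapes and names that are not taken from the model:

import torch

# Illustrative shapes: 8 token positions, hidden size 4, 3 image-token positions.
inputs_embeds = torch.zeros(8, 4)                        # may live on CPU or GPU
images_seq_mask = torch.tensor([0, 1, 1, 0, 0, 1, 0, 0], dtype=torch.bool)
images_in_this_batch = torch.randn(3, 4)                 # one row per masked position

# Dropping .cuda() keeps the mask on whatever device inputs_embeds uses;
# .to(inputs_embeds.device) is the fully explicit equivalent.
mask = images_seq_mask.unsqueeze(-1).to(inputs_embeds.device)
inputs_embeds.masked_scatter_(mask, images_in_this_batch)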
modeling_deepseekv2.py CHANGED
@@ -36,7 +36,6 @@ from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
-    LlamaFlashAttention2
 )
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -60,6 +59,8 @@ from transformers.utils.import_utils import is_torch_fx_available
 
 from .configuration_deepseek_v2 import DeepseekV2Config
 
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1235,7 +1236,6 @@ ATTENTION_CLASSES = {
     "mla_flash_attention_2": DeepseekV2FlashAttention2,
 
     "mha_eager": LlamaAttention,
-    "mha_flash_attention_2": LlamaFlashAttention2
 }
 
 
@@ -1269,6 +1269,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.post_attention_layernorm = DeepseekV2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
+        # Compute position_embeddings
+        self.rotary_emb = LlamaRotaryEmbedding(config=config)
 
     def forward(
         self,
@@ -1303,15 +1305,18 @@
         residual = hidden_states
 
         hidden_states = self.input_layernorm(hidden_states)
+
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
         # Self Attention
-        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+        hidden_states, self_attn_weights = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            position_embeddings=position_embeddings,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -1327,9 +1332,6 @@
         if output_attentions:
             outputs += (self_attn_weights,)
 
-        if use_cache:
-            outputs += (present_key_value,)
-
         return outputs
 
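These edits track the Llama interface in recent transformers releases: LlamaFlashAttention2 is no longer a separate importable class, LlamaAttention.forward expects precomputed rotary position_embeddings (a cos/sin tuple) and returns two values, and the KV cache is updated in place through the Cache object rather than returned from each layer. A minimal sketch of that calling convention, assuming transformers==4.52.4; the config values below are illustrative toys, not the DeepSeek-OCR configuration:

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding

# Toy config for illustration only; DeepseekV2DecoderLayer uses its own config.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
attn = LlamaAttention(config, layer_idx=0)
rotary_emb = LlamaRotaryEmbedding(config=config)

hidden_states = torch.randn(1, 8, config.hidden_size)
position_ids = torch.arange(8).unsqueeze(0)

# Rotary embeddings are computed once outside the attention module and passed in,
# which is what the added self.rotary_emb in the decoder layer provides.
position_embeddings = rotary_emb(hidden_states, position_ids)   # (cos, sin) tuple

# Two return values: attention output and (optional) attention weights; no
# present_key_value is returned, so the decoder layer no longer appends one.
attn_output, attn_weights = attn(
    hidden_states=hidden_states,
    position_embeddings=position_embeddings,
    attention_mask=None,
)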