Molbap (HF Staff) committed on
Commit 2e2c7be · verified · 1 Parent(s): cc25da0

Create modeling_glm.py

Files changed (1):
  1. content/modeling_glm.py +131 -0
content/modeling_glm.py ADDED
@@ -0,0 +1,131 @@
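# Editor's note (not part of the original commit): this file is an excerpt and does not
# ship its import block. A sketch of the imports it relies on, assuming a recent
# `transformers` release (exact module paths may differ between versions):
from typing import Callable, Optional, Tuple

import torch
from torch import nn

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.integrations import use_kernel_forward_from_hub
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.glm.configuration_glm import GlmConfig
from transformers.models.glm.modeling_glm import apply_rotary_pos_emb
from transformers.processing_utils import Unpack
from transformers.utils import logging

logger = logging.get_logger(__name__)

# `eager_attention_forward` is defined further down in the full modeling_glm.py; a
# generic reference sketch is included after GlmAttention below.
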
class GlmMLP(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        up_states = self.gate_up_proj(hidden_states)

        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.activation_fn(gate)

        return self.down_proj(up_states)

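# Editor's sketch (not part of the original commit): a minimal check of GlmMLP with a
# tiny, hypothetical config. SimpleNamespace only stands in for GlmConfig here; the
# point is the fused gate/up projection that is chunked and gated before down_proj.
def _glm_mlp_demo():
    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=8, intermediate_size=16, hidden_act="silu")
    mlp = GlmMLP(cfg)
    out = mlp(torch.randn(2, 4, 8))  # (batch, seq_len, hidden_size) in, same shape out
    assert out.shape == (2, 4, 8)
    return out
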
class GlmAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward

        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

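# Editor's sketch (not part of the original commit): `eager_attention_forward` is defined
# elsewhere in the full modeling_glm.py. For reference, a generic Llama-style eager
# attention with grouped-query KV repetition looks roughly like this:
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times so KV heads match the number of query heads."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Expand KV heads to match query heads (grouped-query attention).
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask is additive: large negative values on disallowed positions.
        attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
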
@use_kernel_forward_from_hub("RMSNorm")
class GlmRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        GlmRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"

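# Editor's sketch (not part of the original commit): GlmRMSNorm rescales by the
# root-mean-square over the last dimension and, unlike LayerNorm, does not subtract
# the mean. With the default unit weight, outputs end up with RMS close to 1.
def _glm_rmsnorm_demo():
    norm = GlmRMSNorm(hidden_size=8, eps=1e-6)
    x = torch.randn(2, 3, 8)
    y = norm(x)
    rms = y.pow(2).mean(-1).sqrt()
    assert torch.allclose(rms, torch.ones(2, 3), atol=1e-3)
    return y
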
class GlmRotaryEmbedding(nn.Module):
    def __init__(self, config: GlmConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq
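
# Editor's note (not part of the original commit): the excerpt ends after
# GlmRotaryEmbedding.__init__. For reference, the usual transformers-style forward
# would compute cos/sin tables from `inv_freq` along these lines (sketch only):
#
#     @torch.no_grad()
#     def forward(self, x, position_ids):
#         inv_freq = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
#         pos = position_ids[:, None, :].float()
#         freqs = (inv_freq.to(x.device) @ pos).transpose(1, 2)
#         emb = torch.cat((freqs, freqs), dim=-1)
#         cos = emb.cos() * self.attention_scaling
#         sin = emb.sin() * self.attention_scaling
#         return cos.to(x.dtype), sin.to(x.dtype)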