<div class="code-compare" style="display: grid; grid-template-columns: 1fr 1fr; gap: 1rem; margin: 1.5rem 0;">
<div class="code-column" style="border: 1px solid #e2e8f0; border-radius: 8px; overflow: hidden;">
<div class="code-header" style="background: #f8f9fa; padding: 0.75rem 1rem; font-weight: 600; color: #495057; border-bottom: 1px solid #e2e8f0;">
modular_glm.py
</div>
<pre style="margin: 0; padding: 1rem; background: #ffffff; overflow-x: auto; font-size: 0.9em;"><code class="language-python">class GlmMLP(Phi3MLP):
pass
class GlmAttention(LlamaAttention):
def __init__(self, config, layer_idx=None):
super().__init__(config, layer_idx)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim,
config.hidden_size,
bias=False
)
class GlmForCausalLM(LlamaForCausalLM):
pass</code></pre>
</div>
<div class="code-column" style="border: 1px solid #e2e8f0; border-radius: 8px; overflow: hidden;">
<div class="code-header" style="background: #f8f9fa; padding: 0.75rem 1rem; font-weight: 600; color: #495057; border-bottom: 1px solid #e2e8f0;">
modeling_glm.py (auto-expanded)
</div>
<pre style="margin: 0; padding: 1rem; background: #ffffff; overflow-x: auto; font-size: 0.9em; max-height: 400px;"><code class="language-python">class GlmMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.gate_up_proj = nn.Linear(
config.hidden_size,
2 * config.intermediate_size,
bias=False
)
self.down_proj = nn.Linear(
config.intermediate_size,
config.hidden_size,
bias=False
)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.activation_fn(gate)
return self.down_proj(up_states)
class GlmAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(
config, "head_dim",
config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = nn.Linear(
config.hidden_size,
config.num_attention_heads * self.head_dim,
bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size,
config.num_key_value_heads * self.head_dim,
bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim,
config.hidden_size,
bias=False
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self, query_states, key_states, value_states,
attention_mask, dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling, **kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@use_kernel_forward_from_hub("RMSNorm")
class GlmRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
# ... (many more classes and functions would follow)</code></pre>
</div>
</div>
<p style="text-align: center; font-style: italic; color: #6c757d; margin-top: 1rem;">
<strong>Left:</strong> the concise modular definition, expressed through inheritance.
<strong>Right:</strong> the auto-generated <code>modeling_glm.py</code>, with all inherited code expanded in place.
</p>
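<p style="color: #495057; margin-top: 1rem;">
A minimal sketch (assuming the GLM model ships with your installed <code>transformers</code> version and that the generated file matches the excerpt above): the expanded <code>GlmMLP</code> can be imported and inspected directly, confirming it carries the Phi3-style fused <code>gate_up_proj</code> even though <code>modular_glm.py</code> only writes <code>class GlmMLP(Phi3MLP): pass</code>.
</p>
<pre style="margin: 0; padding: 1rem; background: #ffffff; border: 1px solid #e2e8f0; border-radius: 8px; overflow-x: auto; font-size: 0.9em;"><code class="language-python"># Minimal sketch: inspect the auto-generated GlmMLP.
# Assumes GLM is available in the installed transformers version and that
# the default GlmConfig sizes are acceptable for a quick check.
import torch
from transformers import GlmConfig
from transformers.models.glm.modeling_glm import GlmMLP

config = GlmConfig()   # default hidden_size / intermediate_size / hidden_act
mlp = GlmMLP(config)
print(mlp)             # shows the fused gate_up_proj inherited via Phi3MLP

hidden_states = torch.randn(1, 4, config.hidden_size)
print(mlp(hidden_states).shape)  # torch.Size([1, 4, hidden_size])</code></pre>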