# File size: 6,162 Bytes
# 73f8595
# /// script
# dependencies = [
# "torch",
# "numpy",
# ]
# ///
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os
# Discover the upstream artifact directory from the environment; default to
# the current working directory when UVNOTE_INPUT_SAVE_DATA is unset.
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
_artifact_dir = Path(data_dir)


def _load_artifact(name):
    """Load ``<name>.pt`` from the upstream artifact directory.

    ``weights_only=True`` restricts unpickling to tensors and plain
    containers, so a file in the (environment-selected) artifact directory
    cannot execute arbitrary code on load. This is the default from PyTorch
    2.6 onward; passing it explicitly keeps older releases (>= 1.13) safe too.
    """
    return torch.load(_artifact_dir / f'{name}.pt', weights_only=True)


router_weight = _load_artifact('router_weight')
router_bias = _load_artifact('router_bias')
gate_up_proj = _load_artifact('gate_up_proj')
gate_up_proj_bias = _load_artifact('gate_up_proj_bias')
down_proj = _load_artifact('down_proj')
down_proj_bias = _load_artifact('down_proj_bias')
print("Loaded shared weights from artifacts")
# Weight checksums, printed for cross-checking against other runs.
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
class GptOssTrainingRouter(nn.Module):
    """Top-k softmax router.

    Every token is scored against all router outputs (one per expert) with a
    linear layer. Only the ``top_k`` largest logits per token are kept, and a
    softmax is taken over just those. Experts that were not selected get a
    score of exactly zero.
    """

    def __init__(self, router_weight, router_bias):
        super().__init__()
        # Hyper-parameters come from the shared benchmark config.
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Clone so the module owns its parameters instead of aliasing inputs.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Select experts for each token.

        Args:
            hidden_states: tensor whose last dimension is ``hidden_dim``; all
                leading dimensions are flattened into one token axis.

        Returns:
            ``(router_scores, router_indices)``:
            - ``router_scores`` is a dense ``(tokens, weight.shape[0])`` tensor
              holding the top-k softmax probabilities and zero elsewhere.
            - ``router_indices`` is the ``(tokens, top_k)`` tensor of selected
              column ids.
        """
        flat = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(flat, self.weight, self.bias)
        top_logits, top_ids = torch.topk(logits, self.top_k, dim=-1)
        # Normalise over the k kept logits only, not over every expert.
        top_probs = F.softmax(top_logits, dim=1, dtype=top_logits.dtype)
        scores = torch.zeros_like(logits)
        scores.scatter_(1, top_ids, top_probs)
        return scores, top_ids
class GptOssTrainingExperts(nn.Module):
    """Mixture-of-experts feed-forward layer evaluated one expert at a time.

    Each expert ``e`` that at least one token was routed to processes the
    tokens routed to it with a clamped gated-linear unit:

        gate_up = x @ gate_up_proj[e] + gate_up_proj_bias[e]
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]   # interleaved columns
        gate = gate.clamp(max=limit); up = up.clamp(-limit, limit)
        h = (up + 1) * gate * sigmoid(alpha * gate)
        out = h @ down_proj[e] + down_proj_bias[e]

    Each token's result is the sum of ``out`` over its experts, with each
    term scaled by the corresponding routing weight.
    """

    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        # Per-expert intermediate width. gate_up_proj interleaves gate and up
        # columns, so its last dimension is 2 * expert_dim. Deriving it from the
        # weight keeps it correct even when intermediate size != hidden size.
        self.expert_dim = gate_up_proj.shape[-1] // 2
        # Clone so the module owns its parameters instead of aliasing inputs.
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
        # x * sigmoid(1.702 * x) is the sigmoid approximation of GELU.
        self.alpha = 1.702
        # Clamp bound: upper bound only for `gate`, symmetric for `up`.
        self.limit = 7.0

    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
        """Mix the outputs of the routed experts for every token.

        Args:
            hidden_states: tensor whose last dimension is ``hidden_size``. The
                first dimension is taken as the batch size; everything else is
                flattened into one token axis.
            router_indices: 2-D integer tensor ``(tokens, k)`` of expert ids
                chosen for each token.
            routing_weights: dense ``(tokens, num_experts)`` tensor;
                ``routing_weights[t, e]`` scales expert ``e``'s output for
                token ``t``.

        Returns:
            Tensor viewed as ``(batch_size, -1, hidden_size)`` holding the
            weighted sum of expert outputs per token.

        Raises:
            ValueError: if ``router_indices`` or ``routing_weights`` is None.
        """
        if router_indices is None or routing_weights is None:
            raise ValueError("router_indices and routing_weights are required")
        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states.reshape(-1, self.hidden_size)
        num_experts = routing_weights.shape[1]
        # This module always uses the per-expert loop (there is no batched path).
        next_states = torch.zeros_like(hidden_states)
        with torch.no_grad():
            # expert_mask[e, k, t] == 1 iff token t picked expert e in slot k.
            expert_mask = F.one_hot(router_indices, num_classes=num_experts).permute(2, 1, 0)
            # Experts picked by at least one token in any slot; skip the rest.
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
        for hit in expert_hit:
            expert_idx = hit[0]
            with torch.no_grad():
                _, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
            gate = gate.clamp(min=None, max=self.limit)
            up = up.clamp(min=-self.limit, max=self.limit)
            glu = gate * torch.sigmoid(gate * self.alpha)
            gated_output = (up + 1) * glu
            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
            weighted_output = out * routing_weights[token_idx, expert_idx, None]
            # Accumulate into the rows of the routed tokens (tokens hit by
            # several experts receive several contributions).
            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
        return next_states.view(batch_size, -1, self.hidden_size)
class GptOssTrainingMoEMLP(nn.Module):
    """MoE block: a top-k router feeding a per-expert-loop expert layer.

    All weights are passed in pre-loaded. The router and the experts each
    clone the tensors they receive.
    """

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = GptOssTrainingRouter(router_weight, router_bias)
        self.experts = GptOssTrainingExperts(
            gate_up_proj,
            gate_up_proj_bias,
            down_proj,
            down_proj_bias,
        )

    def forward(self, hidden_states):
        """Route ``hidden_states`` and mix the selected experts' outputs.

        Returns:
            ``(routed_out, router_scores)``: the output of ``self.experts``
            and the dense scores produced by ``self.router``.
        """
        scores, indices = self.router(hidden_states)
        mixed = self.experts(
            hidden_states,
            router_indices=indices,
            routing_weights=scores,
        )
        return mixed, scores
# Run the model
set_seed(GENERAL_SEED)
device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)
print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
# Build the model from the shared weights. Only the device is set here, so the
# parameters keep whatever dtype the artifacts were saved with.
# NOTE(review): `x` below is created in `dtype`; this assumes the artifacts were
# saved in that same dtype (otherwise the matmuls fail on a dtype mismatch) —
# confirm.
model = GptOssTrainingMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device),
).to(device=device)
# Put the module in training mode, matching the "Training Mode" label above.
# GptOssTrainingExperts runs its per-expert loop regardless of this flag, so the
# flag does not select the code path here.
model.train()
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
print(f"Model training mode: {model.training}")
# Generate the same input as other implementations (same seed, shape and scale).
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
# Benchmark the model with varied inputs to prevent caching artifacts.
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    # The model returns (routed_out, router_scores). Assuming `bench` passes that
    # tuple through, output[0] is the routed MoE output.
    print(f"\nOutput sum: {output[0].sum().item():.6f}")