File size: 6,953 Bytes
73f8595 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
# /// script
# dependencies = [
# "torch",
# "numpy",
# ]
# ///
import torch
from torch import nn
from torch.nn import functional as F
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os
# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
_artifact_dir = Path(data_dir)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered artifact cannot execute code on load; it also pins the behavior
# across torch versions whose default for this flag differs.
router_weight = torch.load(_artifact_dir / 'router_weight.pt', weights_only=True)
router_bias = torch.load(_artifact_dir / 'router_bias.pt', weights_only=True)
gate_up_proj = torch.load(_artifact_dir / 'gate_up_proj.pt', weights_only=True)
gate_up_proj_bias = torch.load(_artifact_dir / 'gate_up_proj_bias.pt', weights_only=True)
down_proj = torch.load(_artifact_dir / 'down_proj.pt', weights_only=True)
down_proj_bias = torch.load(_artifact_dir / 'down_proj_bias.pt', weights_only=True)
print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Gather token rows into fixed-size per-expert bins.

    Args:
        x: Token matrix indexed by token id along dim 0, shape ``(N, H)``.
        indices: Flattened ``token * top_k + slot`` positions, ordered so that
            expert ``e``'s entries occupy ``indices[bins[e - 1]:bins[e]]``.
        bins: Inclusive cumulative per-expert counts (end offset of each
            expert's run in ``indices``), one entry per expert.
        expert_capacity: Maximum number of rows kept per expert; entries past
            it are dropped.
        top_k: Number of experts each token is routed to.

    Returns:
        Tensor of shape ``(E, expert_capacity, H)``; unused rows stay zero.
    """
    num_experts, hidden = bins.shape[0], x.shape[1]
    out = torch.zeros((num_experts, expert_capacity, hidden), device=x.device, dtype=x.dtype)
    # Read the bin boundaries to the host once instead of materializing a
    # device scalar (and synchronizing) for every expert and every token.
    ends = bins.tolist()
    start = 0
    for e, end in enumerate(ends):
        n = min(end - start, expert_capacity)
        if n > 0:
            # One vectorized row copy per expert instead of one per entry.
            token_ids = indices[start:start + n] // top_k
            out[e, :n] = x[token_ids]
        start = end
    return out
def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Scatter per-expert outputs back to tokens and sum over the top-k slots.

    Args:
        x: Expert outputs of shape ``(E, C, H)`` in the bin layout produced by
            ``binned_gather``.
        indices: Flattened ``token * top_k + slot`` positions sorted by expert.
        weights: Optional per-position scale indexed by flat position; ``None``
            leaves outputs unscaled.
        bins: Inclusive cumulative per-expert counts (end offsets into ``indices``).
        expert_capacity: Maximum number of rows read from each expert's bin.
        top_k: Number of experts each token is routed to.

    Returns:
        Tensor of shape ``(N, H)`` with ``N = len(indices) // top_k``. Slots
        dropped by the capacity limit contribute zero.
    """
    E, _, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    # Host copy of the boundaries: avoids a device sync per expert/entry.
    ends = bins.tolist()
    start = 0
    for e in range(E):
        end = ends[e]
        take = min(end - start, expert_capacity)
        if take > 0:
            flat_pos = indices[start:start + take]
            rows = x[e, :take]
            if weights is not None:
                # Cast so the scaled rows stay in the output buffer's dtype.
                rows = rows * weights[flat_pos].to(x.dtype).unsqueeze(-1)
            # Every flat position is unique, so (token, slot) pairs never collide.
            out[flat_pos // top_k, flat_pos % top_k] = rows
        start = end
    return out.sum(dim=1)
def sort_tokens_by_expert(router_indices, num_experts):
    """Group flattened (token, slot) routing entries by expert id.

    Args:
        router_indices: Integer tensor of expert ids, e.g. ``(N, top_k)``;
            flattened row-major, so flat position ``p`` is ``token * top_k + slot``.
        num_experts: Minimum number of expert bins to count.

    Returns:
        ``(sorted_indices, sorted_values, bins, tokens_per_expert)``:
        flat positions ordered by expert, the matching expert ids, the
        inclusive cumulative counts (end offset of each expert's run), and the
        per-expert counts.
    """
    flat_indices = router_indices.flatten()
    # stable=True keeps each expert's entries in token order, so when an
    # expert overflows its capacity the earliest tokens are the ones kept,
    # independent of the sort backend's tie-breaking.
    sorted_values, sorted_indices = torch.sort(flat_indices, stable=True)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert
def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
    *,
    limit=7.0,
    alpha=1.702,
):
    """Reference binned MoE expert computation with a clamped gated activation.

    Routed entries are sorted by expert, gathered into
    ``(E, expert_capacity, H)`` bins, passed through each expert's fused
    gate/up and down projections with batched matmuls, then scattered back to
    tokens and combined with their routing weights. Entries beyond
    ``expert_capacity`` for an expert are dropped.

    Args:
        hidden_states: Input of shape ``(B, S, H)``.
        router_indices: Selected expert ids, shape ``(B * S, K)``.
        routing_weights: Dense routing scores, shape ``(B * S, E)``; only the
            entries selected by ``router_indices`` are used.
        gate_up_proj: Per-expert fused weights for ``bmm``, presumably
            ``(E, H, 2 * I)`` with gate/up channels interleaved on the last dim.
        gate_up_proj_bias: Gate/up bias broadcast over bin rows, presumably ``(E, 2 * I)``.
        down_proj: Per-expert down weights for ``bmm``, presumably ``(E, I, H)``.
        down_proj_bias: Down bias broadcast over bin rows, presumably ``(E, H)``.
        expert_capacity: Maximum number of routed entries processed per expert.
        limit: Clamp bound; gate is clamped from above only, up to ``[-limit, limit]``.
        alpha: Slope applied inside the sigmoid of the gated activation.

    Returns:
        Tensor of shape ``(B, S, H)``.
    """
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[1], router_indices.shape[1]
    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    # reshape (not view) so non-contiguous inputs are accepted too.
    x = binned_gather(hidden_states.reshape(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj)
    gate_up += gate_up_proj_bias[..., None, :]
    # Gate and up channels are interleaved along the last dimension.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # Routing weight of every (token, slot) pair, flattened in the same
    # token * K + slot order used by the sorted flat positions.
    flat_dense = routing_weights.reshape(-1, E)
    flat_router = router_indices.reshape(-1, K)
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)

    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
    return y.view(B, S, H)
class BinnedRouter(nn.Module):
    """Linear top-k router returning dense per-expert scores.

    Each token's logits come from a linear layer. The ``top_k`` largest are
    softmax-normalized among themselves and written back into an otherwise
    zero score matrix.
    """

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Route tokens.

        Returns:
            ``(scores, indices)``: dense scores of shape
            ``(N, weight.shape[0])`` and the top-k expert ids of shape
            ``(N, top_k)``, where ``N`` is the number of flattened tokens.
        """
        flat_tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(flat_tokens, self.weight, self.bias)
        top_logits, top_experts = logits.topk(self.top_k, dim=-1)
        top_probs = torch.softmax(top_logits, dim=1, dtype=top_logits.dtype)
        dense_scores = torch.zeros_like(logits)
        dense_scores.scatter_(1, top_experts, top_probs)
        return dense_scores, top_experts
def ceil_div(a, b):
    """Return ``a / b`` rounded up, for integer ``a`` and positive integer ``b``.

    Pads the numerator by ``b - 1`` and uses floor division, so no
    floating-point arithmetic is involved.
    """
    padded = a + b - 1
    return padded // b
class BinnedMoEMLP(nn.Module):
    """Mixture-of-experts MLP backed by the binned reference expert kernel.

    Holds the router and the per-expert projection weights as parameters. On
    ``forward`` it routes tokens, sizes the per-expert bins and runs
    ``binned_experts_ref``.
    """

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = BinnedRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K
        # Expert weights - use the loaded weights
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        """Route ``hidden_states`` and run the binned experts.

        Args:
            hidden_states: Input of shape ``(B, S, H)``.

        Returns:
            ``(output, router_scores)``: output of shape ``(B, S, H)`` and the
            dense router scores with one row per token.
        """
        router_scores, router_indices = self.router(hidden_states)
        # Size the bins from the total number of routed tokens. The router
        # flattens all leading dims, so router_indices has one row per token.
        # hidden_states.shape[0] is only B for (B, S, H) input and would shrink
        # capacity by a factor of S, silently dropping most routed tokens.
        num_tokens = router_indices.shape[0]
        expert_capacity = ceil_div(num_tokens * self.top_k, self.num_experts)
        output = binned_experts_ref(
            hidden_states,
            router_indices,
            router_scores,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            expert_capacity,
        )
        return output, router_scores
# Run the model
set_seed(GENERAL_SEED)
device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)
print("\n=== Binned Implementation ===")
# Initialize model with loaded weights
# NOTE(review): parameters are moved to `device` but not cast to `dtype`; this
# assumes the saved artifacts already use the benchmark dtype — confirm.
model = BinnedMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device),
).to(device=device)
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
# Generate the same input as Yamoe
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")