import os
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F

from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)

# Artifact directory populated by the upstream save_data step.
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

# Shared weights saved by the data-generation step; loading them from disk
# keeps every implementation under test numerically comparable.
router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")


def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Copy routed tokens into per-expert bins of fixed capacity.

    x: (num_tokens, H) flattened activations. indices: flat routing positions
    sorted by expert. bins: cumulative end offset of each expert's segment.
    Tokens beyond an expert's capacity are dropped.
    """
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else int(bins[e - 1])
        end = int(bins[e])
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k  # flat position -> source token id
            out[e, i] = x[tok]
    return out


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Scatter expert outputs back to token order, summing over top_k slots.

    x: (E, capacity, H) expert outputs. weights: per-copy routing weights
    indexed by flat position, or None for unweighted.
    """
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else int(bins[e - 1])
        end = int(bins[e])
        n = end - start
        if n == 0:
            continue
        # Copies beyond the expert's capacity were never gathered; skip them.
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k   # source token id
            slot = flat_pos % top_k   # which of the token's top_k picks
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)


def sort_tokens_by_expert(router_indices, num_experts):
    """Sort flat routing positions by expert id and compute bin boundaries."""
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)  # bins[e] = end of expert e's segment
    return sorted_indices, sorted_values, bins, tokens_per_expert
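
# Worked example (added illustration): router_indices = [[0, 2], [1, 0]] with
# 4 experts flattens to [0, 2, 1, 0]; sorting by expert gives counts
# [2, 1, 1, 0] and bins [2, 3, 4, 4], so expert e owns sorted positions
# [bins[e-1], bins[e]). The round trip below sanity-checks the helpers: with
# an identity "expert" and no routing weights, gather followed by scatter
# reproduces top_k copies of every token.
_E, _K, _H = 4, 2, 8
_xs = torch.randn(6, _H)
_ri = torch.randint(0, _E, (6, _K))
_idx, _, _bins, _counts = sort_tokens_by_expert(_ri, _E)
_cap = int(_counts.max())  # capacity large enough that nothing is dropped
_back = binned_scatter(binned_gather(_xs, _idx, _bins, _cap, _K), _idx, None, _bins, _cap, _K)
assert torch.allclose(_back, _K * _xs, atol=1e-6)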


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[1], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj)
    gate_up += gate_up_proj_bias[..., None, :]

    # Interleaved gate/up split with clamping; gate * sigmoid(1.702 * gate)
    # is the sigmoid approximation of GELU.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # Pull each routed copy's weight from the dense score matrix, then
    # scatter the expert outputs back to token order.
    flat_dense = routing_weights.view(-1, E)
    flat_router = router_indices.view(-1, K)
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)

    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
    return y.view(B, S, H)
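
# Minimal unbinned reference (an added sketch, not part of the original
# script): loops over every routed copy directly. It matches
# binned_experts_ref only when expert_capacity is large enough that no copy
# is dropped, which makes it a handy correctness oracle at small sizes.
def naive_experts_ref(hidden_states, router_indices, routing_weights,
                      gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
    B, S, H = hidden_states.shape
    K = router_indices.shape[1]
    x = hidden_states.reshape(-1, H)
    out = torch.zeros_like(x)
    flat_dense = routing_weights.view(-1, routing_weights.shape[-1])
    selected = torch.gather(flat_dense, 1, router_indices.view(-1, K)).reshape(-1)
    for pos, e in enumerate(router_indices.reshape(-1).tolist()):
        tok = pos // K
        h = x[tok] @ gate_up_proj[e] + gate_up_proj_bias[e]
        gate, up = h[::2], h[1::2]
        gate = gate.clamp(max=7.0)
        up = up.clamp(min=-7.0, max=7.0)
        h = (up + 1) * (gate * torch.sigmoid(gate * 1.702))
        out[tok] += selected[pos] * (h @ down_proj[e] + down_proj_bias[e])
    return out.view(B, S, H)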


class BinnedRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        # Softmax over the selected top_k logits only, then scatter the
        # probabilities into a dense (num_tokens, num_experts) score matrix.
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices
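
# Hedged smoke test (added): the router contract is a dense (N, E) score
# matrix with exactly top_k nonzeros per row, and those nonzeros sum to 1.
_probe = torch.randn(4, HIDDEN_SIZE, dtype=router_weight.dtype)
_scores, _picks = BinnedRouter(router_weight, router_bias)(_probe)
assert _scores.count_nonzero(dim=-1).eq(TOP_K).all()
assert torch.allclose(_scores.sum(dim=-1), torch.ones(4, dtype=_scores.dtype), atol=2e-2)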


def ceil_div(a, b):
    # Integer ceiling division, e.g. ceil_div(5, 2) == 3.
    return (a + b - 1) // b


class BinnedMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = BinnedRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Clone the shared weights so this model owns its parameters.
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        router_scores, router_indices = self.router(hidden_states)
        # Capacity must cover every routed copy: each of the B*S tokens emits
        # top_k copies spread across num_experts bins.
        num_tokens = hidden_states.shape[0] * hidden_states.shape[1]
        expert_capacity = ceil_div(num_tokens * self.top_k, self.num_experts)

        output = binned_experts_ref(
            hidden_states,
            router_indices,
            router_scores,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            expert_capacity,
        )

        return output, router_scores
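
# Capacity arithmetic with illustrative numbers (not the benchmark config):
# 1024 tokens routed top-4 across 32 experts gives
# ceil_div(1024 * 4, 32) = 128 slots per expert; under imbalanced routing,
# copies beyond an expert's 128 slots are dropped by binned_gather.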


set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== Binned Implementation ===")

model = BinnedMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device),
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Benchmark on a seeded input; bench_context handles warmup, timing, and
# saving results to JSON.
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")