import os
from collections import namedtuple
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F

from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED,
)
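
# This cell loads the shared MoE weights produced by the earlier save-data cell,
# builds a MegaBlocks kernel-backed model from them, and benchmarks a forward pass.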

data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')

print(f"Loading weights from: {data_dir}")

# Shared weights saved by the upstream cell; loading them here guarantees that
# every implementation under test starts from identical parameters.
router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

# Checksums give a quick cross-implementation sanity check.
print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")


def build_megablocks_model(device: torch.device):
    # Fetch the MegaBlocks kernel from the Hub at a pinned revision so the
    # benchmark stays reproducible.
    megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
    model = megablocks.layers.MegaBlocksMoeMLP()

    # A namedtuple type serves as a lightweight attribute container for the
    # expert parameters; its fields are assigned below.
    model.experts = namedtuple(
        "Experts",
        ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"],
    )

    # Router: a plain linear layer initialized from the shared weights.
    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
    with torch.no_grad():
        model.router.weight.copy_(router_weight)
        model.router.bias.copy_(router_bias)

    # Expert parameters. alpha=1.702 is the sigmoid coefficient in the GELU
    # approximation x * sigmoid(1.702 * x); capacity_factor bounds how many
    # tokens each expert may receive.
    e = model.experts
    e.alpha = 1.702
    e.capacity_factor = 32
    e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
    e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
    e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
    e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
    e.hidden_size = HIDDEN_SIZE

    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
    print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
    print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")

    return model
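
# Note: get_local_kernel (imported above) is the local-development counterpart
# of get_kernel. Assuming the kernels-library signature
# get_local_kernel(repo_path, package_name), a locally built copy of the kernel
# could be swapped in inside build_megablocks_model (hypothetical path):
#   megablocks = get_local_kernel(Path("./megablocks"), "megablocks")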


class MegaBlocksMoEWrapper(nn.Module):
    """Thin wrapper so the kernel-backed MLP plugs into the benchmark harness."""

    def __init__(self, megablocks_model):
        super().__init__()
        self.model = megablocks_model

    def forward(self, hidden_states):
        # The kernel returns (output, routing_weights); pass both through so the
        # harness sees the same interface as the other implementations.
        output, dummy_routing_weights = self.model(hidden_states)
        return output, dummy_routing_weights


set_seed(GENERAL_SEED)

device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== MegaBlocks Implementation ===")

megablocks_model = build_megablocks_model(device)
model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)

# Deterministic input, scaled down to keep activations in a moderate range.
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark: 10 warmup iterations, 50 timed iterations; results are saved to
# megablocks_results.json for later comparison. vary_inputs presumably varies
# the input across iterations to avoid measuring a cached result.
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(
    warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens,
    save_json="megablocks_results.json", vary_inputs=True,
) as bench:
    output, stats = bench(model, x)

print(f"\nOutput sum: {output[0].sum().item():.6f}")