# /// script
# dependencies = [
#     "torch",
#     "kernels",
#     "numpy",
# ]
# ///

import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED,
)
from pathlib import Path
import os

# Discover the upstream artifact directory from the environment
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")

router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')

print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")


class YamoeRouter(nn.Module):
    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = F.linear(hidden_states, self.weight, self.bias)
        # Keep only the top-k logits per token, softmax over those k values,
        # then scatter them back into a dense (tokens, num_experts) score matrix
        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
        router_top_value = F.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
        return router_scores, router_indices


def ceil_div(a, b):
    return (a + b - 1) // b


class YamoeMoEMLP(nn.Module):
    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = YamoeRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Load the Yamoe kernel
        # self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe")
        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

        # Expert weights: reuse the shared weights loaded from artifacts
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape

        # Get routing decisions
        routing_weights, router_indices = self.router(hidden_states)

        # Flatten for the Yamoe kernel
        hidden_states_flat = hidden_states.view(-1, hidden_dim)
        routing_weights_flat = routing_weights.view(-1, self.num_experts)

        # Per-expert capacity passed to the kernel's token dispatch
        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)

        # Call the Yamoe optimized kernel
        output = self.yamoe.experts(
            hidden_states_flat,
            router_indices,
            routing_weights_flat,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            expert_capacity,
            self.num_experts,
            self.top_k,
        )

        # Reshape back to (batch, seq, hidden)
        output = output.view(batch_size, seq_len, hidden_dim)
        return output, routing_weights
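
# For reference, a minimal eager sketch of the computation the fused kernel is
# expected to perform. This is an assumption, not the kernel's actual
# implementation: it presumes a SwiGLU-style activation, a chunked (not
# interleaved) gate/up split, and the (num_experts, hidden, 2*intermediate) /
# (num_experts, intermediate, hidden) weight layout implied by the matmuls
# below. It also ignores expert_capacity (no token dropping), so it can only
# match the kernel when no expert overflows.
def naive_moe_forward(mlp: "YamoeMoEMLP", hidden_states: torch.Tensor) -> torch.Tensor:
    batch_size, seq_len, hidden_dim = hidden_states.shape
    routing_weights, _ = mlp.router(hidden_states)
    x = hidden_states.reshape(-1, hidden_dim)
    out = torch.zeros_like(x)
    for e in range(mlp.num_experts):
        gate_up = x @ mlp.gate_up_proj[e] + mlp.gate_up_proj_bias[e]
        gate, up = gate_up.chunk(2, dim=-1)
        act = F.silu(gate) * up  # assumed SwiGLU; the kernel's activation may differ
        expert_out = act @ mlp.down_proj[e] + mlp.down_proj_bias[e]
        # Routing scores are zero for unselected experts, so this dense sum
        # reduces to the top-k mixture per token
        out += routing_weights[:, e:e + 1] * expert_out
    return out.view(batch_size, seq_len, hidden_dim)
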
# Run the model
set_seed(GENERAL_SEED)

device = torch.device("cuda")  # this benchmark always runs on CUDA
dtype = to_dtype(DTYPE)

print("\n=== Yamoe Implementation ===")

# Initialize the model with the loaded weights
model = YamoeMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device),
).to(device=device)

print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate input
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(
    warmup=10,
    iters=50,
    device=device,
    dtype=dtype,
    tokens=tokens,
    save_json="yamoe_results.json",
    vary_inputs=True,
) as bench:
    output, stats = bench(model, x)

print(f"\nOutput sum: {output[0].sum().item():.6f}")
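
# Quick routing sanity check: per-expert token counts from the returned router
# scores. This assumes bench passes the model's outputs through unchanged (the
# source already indexes output[0] for the hidden states, so output[1] should
# be the routing_weights tensor); with vary_inputs=True the counts reflect the
# final benchmark iteration's input.
routing = output[1]
expert_load = (routing > 0).sum(dim=0)
print(f"Tokens routed per expert: {expert_load.tolist()}")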