File size: 5,051 Bytes
73f8595 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# /// script
# dependencies = [
# "torch",
# "kernels",
# "numpy",
# ]
# ///
import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os
# Discover the upstream artifact directory from env (falls back to the CWD).
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")


def _load_artifact(name):
    """Load ``<data_dir>/<name>.pt`` saved by the upstream step.

    ``weights_only=True`` restricts unpickling to tensors and plain
    containers. The directory comes from an environment variable, and a
    full pickle load of a tampered file could execute arbitrary code.
    """
    return torch.load(Path(data_dir) / f"{name}.pt", weights_only=True)


router_weight = _load_artifact('router_weight')
router_bias = _load_artifact('router_bias')
gate_up_proj = _load_artifact('gate_up_proj')
gate_up_proj_bias = _load_artifact('gate_up_proj_bias')
down_proj = _load_artifact('down_proj')
down_proj_bias = _load_artifact('down_proj_bias')
print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
class YamoeRouter(nn.Module):
    """Linear top-k router.

    Scores each flattened token with a linear layer and keeps the ``top_k``
    largest logits. It softmax-normalises only those logits and scatters them
    into an otherwise-zero dense score matrix.
    """

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        # Clone so the router owns its parameters instead of aliasing the inputs.
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Return ``(router_scores, router_indices)``.

        ``hidden_states`` is flattened to ``(-1, hidden_dim)``. ``router_scores``
        has one row per flattened token, holding the softmaxed top-k weights at
        the selected columns and zeros elsewhere. ``router_indices`` holds the
        ``top_k`` selected column ids per token.
        """
        tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(tokens, self.weight, self.bias)
        top_logits, top_idx = torch.topk(logits, self.top_k, dim=-1)
        # Normalise over the k selected logits only, keeping the logits' dtype.
        top_probs = F.softmax(top_logits, dim=1, dtype=top_logits.dtype)
        dense_scores = torch.zeros_like(logits)
        dense_scores.scatter_(1, top_idx, top_probs)
        return dense_scores, top_idx
def ceil_div(a, b):
    """Return ``a`` divided by ``b`` rounded up, via ``(a + b - 1) // b``.

    The result is the exact ceiling for integers with a positive ``b``.
    """
    numerator = a + b - 1
    return numerator // b
class YamoeMoEMLP(nn.Module):
    """MoE MLP: routes with ``YamoeRouter``, then runs ``experts`` from the hub kernel.

    The kernel is fetched with ``get_kernel("drbh/yamoe", revision="v0.2.0")``.
    All weight tensors are cloned into ``nn.Parameter``s, so the module owns
    its own copies.
    """

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = YamoeRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K

        # Hub kernel. For a local build, swap in:
        #   get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe")
        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

        # Register expert weights in a fixed order (matters for parameters()/state_dict()).
        expert_tensors = (
            ("gate_up_proj", gate_up_proj),
            ("gate_up_proj_bias", gate_up_proj_bias),
            ("down_proj", down_proj),
            ("down_proj_bias", down_proj_bias),
        )
        for attr_name, tensor in expert_tensors:
            setattr(self, attr_name, nn.Parameter(tensor.clone()))

    def forward(self, hidden_states):
        """Route ``hidden_states`` and apply the experts through the Yamoe kernel.

        Args:
            hidden_states: a 3-D tensor, unpacked as (batch, seq, hidden).

        Returns:
            ``(output, routing_weights)``. ``output`` is the kernel result viewed
            back to (batch, seq, hidden), and ``routing_weights`` is the router's
            dense score matrix.
        """
        bsz, seq, dim = hidden_states.shape
        scores, top_idx = self.router(hidden_states)

        flat_inputs = hidden_states.view(-1, dim)
        flat_scores = scores.view(-1, self.num_experts)
        # NOTE(review): capacity is derived from the leading (batch) dim only,
        # not batch * seq tokens. Confirm this matches what yamoe.experts expects.
        capacity = ceil_div(bsz * self.top_k, self.num_experts)

        kernel_out = self.yamoe.experts(
            flat_inputs,
            top_idx,
            flat_scores,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            capacity,
            self.num_experts,
            self.top_k,
        )
        return kernel_out.view(bsz, seq, dim), scores
# Run the model
set_seed(GENERAL_SEED)
# This script always runs on CUDA, regardless of config.DEVICE.
device = torch.device("cuda")
dtype = to_dtype(DTYPE)
print("\n=== Yamoe Implementation ===")

# Initialize model with loaded weights.
# NOTE(review): the weights keep the dtype they were saved with, while the
# input below is created in `dtype`. Confirm these agree (or that
# bench_context handles casting) before changing DTYPE.
model = YamoeMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device),
).to(device=device)
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")

# Generate input
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1

# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json", vary_inputs=True) as bench:
    output, stats = bench(model, x)
    print(f"\nOutput sum: {output[0].sum().item():.6f}")