# Hugging Face Hub page header captured with this file (not part of the script):
# uploaded by drbh (HF Staff) — "Upload folder using huggingface_hub",
# commit 73f8595 (verified), 5.05 kB.
# /// script
# dependencies = [
# "torch",
# "kernels",
# "numpy",
# ]
# ///
import torch
from torch import nn
from torch.nn import functional as F
from kernels import get_kernel, get_local_kernel
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
from config import (
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
)
from pathlib import Path
import os
# Discover the upstream artifact directory from env
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")
artifact_dir = Path(data_dir)
# The directory is chosen by an environment variable, so restrict unpickling to
# tensors / plain containers (weights_only=True) rather than arbitrary objects.
# Passing it explicitly also avoids depending on the torch-version default.
router_weight = torch.load(artifact_dir / 'router_weight.pt', weights_only=True)
router_bias = torch.load(artifact_dir / 'router_bias.pt', weights_only=True)
gate_up_proj = torch.load(artifact_dir / 'gate_up_proj.pt', weights_only=True)
gate_up_proj_bias = torch.load(artifact_dir / 'gate_up_proj_bias.pt', weights_only=True)
down_proj = torch.load(artifact_dir / 'down_proj.pt', weights_only=True)
down_proj_bias = torch.load(artifact_dir / 'down_proj_bias.pt', weights_only=True)
print("Loaded shared weights from artifacts")
# Checksums so runs of different implementations can be compared on identical weights.
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
class YamoeRouter(nn.Module):
    """Linear top-k router that emits dense per-expert routing scores.

    The supplied weight and bias are cloned into parameters, so the module never
    aliases the caller's tensors. ``top_k``, ``num_experts`` and ``hidden_dim``
    are taken from the module-level config constants.
    """

    def __init__(self, router_weight, router_bias):
        super().__init__()
        self.top_k = TOP_K
        self.num_experts = NUM_EXPERTS
        self.hidden_dim = HIDDEN_SIZE
        self.weight = nn.Parameter(router_weight.clone())
        self.bias = nn.Parameter(router_bias.clone())

    def forward(self, hidden_states):
        """Select the ``top_k`` experts for every token.

        Args:
            hidden_states: any tensor that reshapes to ``(tokens, hidden_dim)``.

        Returns:
            ``(router_scores, router_indices)``. ``router_scores`` has the
            logits' shape (one row per token, one column per row of ``weight``).
            It is zero everywhere except at the chosen experts, which hold a
            softmax taken over the top-k logits only. ``router_indices`` is
            ``(tokens, top_k)``.
        """
        tokens = hidden_states.reshape(-1, self.hidden_dim)
        logits = F.linear(tokens, self.weight, self.bias)
        top_logits, top_experts = logits.topk(self.top_k, dim=-1)
        # Normalise across the selected experts only, keeping the logits' dtype.
        top_probs = F.softmax(top_logits, dim=1, dtype=top_logits.dtype)
        dense_scores = torch.zeros_like(logits)
        dense_scores.scatter_(1, top_experts, top_probs)
        return dense_scores, top_experts
def ceil_div(a, b):
    """Return ``a / b`` rounded up to the next integer, for a positive ``b``.

    Uses floor division on a padded numerator instead of a float ``math.ceil``.
    """
    padded = a + b - 1  # push any remainder over the next multiple of b
    return padded // b
class YamoeMoEMLP(nn.Module):
    """Mixture-of-experts MLP whose expert computation runs in the yamoe kernel.

    Routing is done by :class:`YamoeRouter`. The flattened tokens, routing
    indices/weights and expert weights are then passed to ``yamoe.experts``.
    All weight tensors are cloned into parameters, so the module never aliases
    the caller's tensors.
    """

    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
        super().__init__()
        self.router = YamoeRouter(router_weight, router_bias)
        self.num_experts = NUM_EXPERTS
        self.hidden_size = HIDDEN_SIZE
        self.top_k = TOP_K
        # Fetch the pinned yamoe kernel from the Hub. To benchmark a local build
        # instead, swap in get_local_kernel(<path to build result>, "yamoe").
        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
        # Expert weights, cloned from the shared artifacts.
        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
        self.down_proj = nn.Parameter(down_proj.clone())
        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())

    def forward(self, hidden_states):
        """Route tokens and run the yamoe expert kernel.

        Args:
            hidden_states: ``(batch_size, seq_len, hidden_dim)`` tensor.

        Returns:
            ``(output, routing_weights)``. ``output`` has the input's
            ``(batch_size, seq_len, hidden_dim)`` shape. ``routing_weights`` is
            the router's dense score matrix, with one row per flattened token.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        # Get routing decisions
        routing_weights, router_indices = self.router(hidden_states)
        # reshape (not view) so non-contiguous inputs are accepted, as the
        # router already does; for contiguous tensors this is still a no-copy view.
        hidden_states_flat = hidden_states.reshape(-1, hidden_dim)
        routing_weights_flat = routing_weights.reshape(-1, self.num_experts)
        # NOTE(review): capacity is derived from batch_size only, not from the
        # number of flattened tokens (batch_size * seq_len). Confirm how
        # yamoe.experts uses expert_capacity before relying on this for seq_len > 1.
        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
        # Call Yamoe optimized kernel
        output = self.yamoe.experts(
            hidden_states_flat,
            router_indices,
            routing_weights_flat,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            expert_capacity,
            self.num_experts,
            self.top_k,
        )
        # Restore the input's (batch, seq, hidden) layout.
        output = output.reshape(batch_size, seq_len, hidden_dim)
        return output, routing_weights
# Run the model
set_seed(GENERAL_SEED)
# Always run on CUDA, regardless of config.DEVICE. The former expression
# `DEVICE if DEVICE == "cuda" else "cuda"` resolved to "cuda" on both branches.
# Presumably this is because the yamoe kernel targets CUDA only — confirm.
device = torch.device("cuda")
dtype = to_dtype(DTYPE)
print("\n=== Yamoe Implementation ===")
# Initialize model with loaded weights. Only the device is changed; the weights
# keep their loaded dtype. NOTE(review): confirm bench_context reconciles this
# with `dtype` used for the input.
model = YamoeMoEMLP(
    router_weight.to(device),
    router_bias.to(device),
    gate_up_proj.to(device),
    gate_up_proj_bias.to(device),
    down_proj.to(device),
    down_proj_bias.to(device),
).to(device=device)
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
# Generate input
set_seed(INPUT_SEED)
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
# Benchmark the model with varied inputs to prevent caching artifacts
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(
    warmup=10,
    iters=50,
    device=device,
    dtype=dtype,
    tokens=tokens,
    save_json="yamoe_results.json",
    vary_inputs=True,
) as bench:
    output, stats = bench(model, x)
    # Presumably `output` is the model's (output, routing_weights) tuple, so
    # index 0 is the MoE output — confirm against bench_utils.
    print(f"\nOutput sum: {output[0].sum().item():.6f}")