# Provenance (HuggingFace Hub file-viewer metadata, preserved as comments so the
# file remains valid Python):
#   uploaded by drbh (HF Staff) via huggingface_hub
#   commit 93c5002 (verified), "Upload folder using huggingface_hub"
#   file size: 1.52 kB
# /// script
# dependencies = [
# "torch",
# "numpy",
# ]
# ///
"""
Generate deterministic shared weights once and save as artifacts so
both implementations load identical parameters.
"""
import torch
from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
def save_shared_weights(out_dir="."):
    """Generate deterministic router/expert weights and save them as ``.pt`` artifacts.

    The global torch RNG is re-seeded before each initialization group so that
    every run — and therefore both implementations loading these artifacts —
    gets byte-identical parameters.

    Args:
        out_dir: Directory to write the artifact files into. Defaults to the
            current working directory, matching the original behavior.
    """
    import os

    # Router: Kaiming-uniform weight (the init both implementations use), zero bias.
    torch.manual_seed(WEIGHT_SEED)
    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
    torch.nn.init.kaiming_uniform_(router_weight)
    router_bias = torch.zeros(NUM_EXPERTS)

    # Experts: weights ~ N(0, 0.02), biases zero. Separate seed keeps expert
    # init independent of the router init.
    torch.manual_seed(EXPERT_SEED)
    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)

    # One artifact file per tensor; filenames are part of the contract with the
    # loaders, so they must not change.
    artifacts = {
        'router_weight.pt': router_weight,
        'router_bias.pt': router_bias,
        'gate_up_proj.pt': gate_up_proj,
        'gate_up_proj_bias.pt': gate_up_proj_bias,
        'down_proj.pt': down_proj,
        'down_proj_bias.pt': down_proj_bias,
    }
    for filename, tensor in artifacts.items():
        torch.save(tensor, os.path.join(out_dir, filename))

    print("Saved shared weights to artifacts")
    # Tensor sums act as cheap checksums for cross-implementation sanity checks.
    print(f"Router weight sum: {router_weight.sum().item():.6f}")
    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
    print(f"Down sum: {down_proj.sum().item():.6f}")
save_shared_weights()