# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///
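#
# Run with any PEP 723-aware runner, e.g. `uv run <this file>` (runner choice
# assumed); the inline metadata above declares the script's dependencies.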

"""
Generate deterministic shared weights once and save as artifacts so
both implementations load identical parameters.
"""
import torch
from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
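# NOTE: `config` is a sibling module (not shown here); it is assumed to define
# the integer sizes NUM_EXPERTS and HIDDEN_SIZE plus the integer seeds
# WEIGHT_SEED and EXPERT_SEED used below.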

def save_shared_weights():
    # Router: Kaiming-uniform weight init (as used by both implementations), zero bias
    torch.manual_seed(WEIGHT_SEED)
    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
    torch.nn.init.kaiming_uniform_(router_weight)
    router_bias = torch.zeros(NUM_EXPERTS)

    # Experts: weights ~ Normal(mean=0.0, std=0.02), zero biases
    torch.manual_seed(EXPERT_SEED)
    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)

    # Save each tensor as a separate .pt artifact in the working directory
    torch.save(router_weight, 'router_weight.pt')
    torch.save(router_bias, 'router_bias.pt')
    torch.save(gate_up_proj, 'gate_up_proj.pt')
    torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
    torch.save(down_proj, 'down_proj.pt')
    torch.save(down_proj_bias, 'down_proj_bias.pt')

    print("Saved shared weights to artifacts")
    print(f"Router weight sum: {router_weight.sum().item():.6f}")
    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
    print(f"Down sum: {down_proj.sum().item():.6f}")

if __name__ == "__main__":
    save_shared_weights()