File size: 459 Bytes
73f8595
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///

"""Shared configuration for both implementations.

Every value below is a plain module-level constant bound once at import
time. The only runtime probe is the CUDA availability check that selects
``DEVICE``.
"""
import torch

# ---------------------------------------------------------------------------
# Model dimensions
# ---------------------------------------------------------------------------
# NOTE(review): the names suggest a mixture-of-experts layer (expert count
# plus top-k routing); confirm against the implementations importing this.
NUM_EXPERTS: int = 128
HIDDEN_SIZE: int = 1152
INTERMEDIATE_SIZE: int = 3072
TOP_K: int = 4

# ---------------------------------------------------------------------------
# Input shape, precision and placement
# ---------------------------------------------------------------------------
BATCH_SIZE: int = 1
SEQ_LEN: int = 100
# Kept as a dtype *name* string rather than a torch.dtype object.
DTYPE: str = "float32"

# Use the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE: str
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---------------------------------------------------------------------------
# Fixed RNG seeds for reproducibility
# ---------------------------------------------------------------------------
# One seed per purpose, as each name indicates. How each seed is consumed is
# not visible in this file.
WEIGHT_SEED: int = 999
EXPERT_SEED: int = 777
INPUT_SEED: int = 123
GENERAL_SEED: int = 42