drbh HF Staff committed on
Commit
39291b0
·
verified ·
1 Parent(s): a6ab428

Upload folder using huggingface_hub

Browse files
flash_attn/benchmark.html ADDED
The diff for this file is too large to render. See raw diff
 
flash_attn/cells/benchmark.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "numpy",
4
+ # "torch",
5
+ # "kernels",
6
+ # "pandas",
7
+ # "matplotlib"
8
+ # ]
9
+ # ///
10
+ # Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths
11
+
12
+ import functools
13
+ import os
14
+ import pathlib
15
+
16
+ import matplotlib.pyplot as plt
17
+ import torch
18
+ import torch._dynamo.config
19
+ import triton
20
+ import triton.language as tl
21
+
22
# Optional attention backends. Each import is guarded so the benchmark
# degrades gracefully when a backend is not installed: the name is set to
# None and the registry below skips it. `except Exception` (not a bare
# `except:`) is used so KeyboardInterrupt/SystemExit still propagate;
# Exception also covers non-ImportError failures from `get_kernel`
# (which may download from the Hub) and from backends that fail to
# initialize their CUDA libraries.
try:
    from flash_attn import flash_attn_func
except Exception:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except Exception:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel
    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
except Exception:
    hf_kernels_flash_attn = None
    hf_kernels_flash_attn_3 = None
    print("HF Kernels not found.")

try:
    from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
except Exception:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except Exception:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except Exception:
    xops = None
    print("xFormers not found.")
62
+
63
+
64
# Shared matplotlib style applied to every plot this script produces.
_PLOT_STYLE = {
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
}
plt.rcParams.update(_PLOT_STYLE)
84
+
85
+
86
# We want to compare the best compiled version for each specific shape
# (dynamic=False), so raise the recompile cache limit far above the number
# of shapes we sweep.
torch._dynamo.config.cache_size_limit = 10000

# We need to suppress_errors for FA3 to work under torch.compile — failing
# graphs then fall back to eager mode. A cleaner workaround was not found.
torch._dynamo.config.suppress_errors = True

# All benchmark artifacts (plot + CSV) land in this directory.
output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)

# Flux-style workload: 4096 image tokens (1024x1024px) plus a varying
# number of text-conditioning tokens.
batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + text_len for text_len in text_sequence_lengths]
102
+
103
+
104
+ def _attention_torch(query, key, value, *, backend):
105
+ query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
106
+ with torch.nn.attention.sdpa_kernel(backend):
107
+ out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
108
+ out = out.transpose(1, 2).contiguous()
109
+ return out
110
+
111
+
112
_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
def _attention_torch_compile_default(query, key, value, *, backend):
    """`_attention_torch` compiled once (default inductor mode, per-shape)."""
    return _compiled_attention_torch_default(query, key, value, backend=backend)


_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    """`_attention_torch` compiled once with max-autotune kernel search."""
    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)
120
+
121
+
122
def _attention_flash_attn_2(query, key, value):
    """Flash Attention 2 forward (non-causal, default softmax scale)."""
    return flash_attn_func(query, key, value)


_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_default(query, key, value):
    """FA2 under torch.compile (default mode)."""
    return _compiled_flash_attn_2_default(query, key, value)


_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    """FA2 under torch.compile (max-autotune mode)."""
    return _compiled_flash_attn_2_max_autotune(query, key, value)
134
+
135
+
136
# For fullgraph=True tracing to be compatible: wrap FA3 in a custom op so
# dynamo treats it as an opaque call instead of tracing into its internals.
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # flash_attn_3_func returns (out, lse); only the attention output is used.
    out, _lse = flash_attn_3_func(query, key, value)
    return out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Fake (meta) kernel for tracing: output shape/dtype mirror the query.
    return torch.empty_like(query)
146
+
147
+
148
def _attention_flash_attn_3(query, key, value):
    """Flash Attention 3 via the custom-op wrapper (compile friendly)."""
    return _wrapped_flash_attn_3(query, key, value)


_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_default(query, key, value):
    """FA3 under torch.compile (default mode)."""
    return _compiled_flash_attn_3_default(query, key, value)


_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    """FA3 under torch.compile (max-autotune mode)."""
    return _compiled_flash_attn_3_max_autotune(query, key, value)
161
+
162
+
163
def _attention_hf_kernels_flash_attn(query, key, value):
    """FA2 via the kernels-community/flash-attn hub kernel (non-causal)."""
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    """FA3 via the kernels-community/flash-attn3 hub kernel (non-causal)."""
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    """SageAttention int8-QK / fp16-PV CUDA kernel ("NHD" = batch, seq, heads, dim)."""
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    """SageAttention int8-QK / fp16-PV Triton kernel."""
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    """SageAttention int8-QK / fp8-PV kernel; registry only enables it on SM90+."""
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")
181
+
182
+
183
if DotProductAttention is not None:
    def set_te_backend(backend):
        # must be applied before first use of
        # transformer_engine.pytorch.attention
        # NOTE(review): transformer_engine is already imported at the top of
        # this file — confirm these env vars still take effect here.
        env_vars = {
            "flash": "NVTE_FLASH_ATTN",
            "fused": "NVTE_FUSED_ATTN",
            "unfused": "NVTE_UNFUSED_ATTN",
        }
        # Disable all three, then enable only the requested backend.
        for var in env_vars.values():
            os.environ[var] = '0'
        if backend in env_vars:
            os.environ[env_vars[backend]] = '1'

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")


def _attention_te(query, key, value):
    """TE fused attention; re-expands TE's flattened (heads*dim) last axis."""
    out = te_attn_fn(query, key, value)
    return out.unflatten(2, (num_attention_heads, attention_head_dim))
212
+
213
+
214
# Cannot fullgraph compile TE, so both compiled variants use fullgraph=False.
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
def _attention_te_compile_default(query, key, value):
    """`_attention_te` under torch.compile (default mode)."""
    return _compiled_te_attn_fn_default(query, key, value)


# Cannot fullgraph compile TE.
_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
def _attention_te_compile_max_autotune(query, key, value):
    """`_attention_te` under torch.compile (max-autotune mode)."""
    return _compiled_te_attn_fn_max_autotune(query, key, value)


def _attention_xformers(query, key, value):
    """xFormers memory-efficient attention (batch, seq, heads, dim inputs)."""
    return xops.memory_efficient_attention(query, key, value)


_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
def _attention_xformers_compile_default(query, key, value):
    """xFormers attention under torch.compile (default mode)."""
    return _compiled_xformers_default(query, key, value)


_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_xformers_compile_max_autotune(query, key, value):
    """xFormers attention under torch.compile (max-autotune mode)."""
    return _compiled_xformers_max_autotune(query, key, value)
238
+
239
+
240
# Registry of every attention implementation available in this environment.
# Keys become the legend labels in the benchmark plot; insertion order is
# preserved (plain dict), matching the original registration order.
_SDP = torch.nn.attention.SDPBackend

attention_ops = {
    "torch_cudnn": functools.partial(_attention_torch, backend=_SDP.CUDNN_ATTENTION),
    "torch_cudnn_compile_d": functools.partial(_attention_torch_compile_default, backend=_SDP.CUDNN_ATTENTION),
    "torch_cudnn_compile_ma": functools.partial(_attention_torch_compile_max_autotune, backend=_SDP.CUDNN_ATTENTION),
    "torch_flash": functools.partial(_attention_torch, backend=_SDP.FLASH_ATTENTION),
    "torch_flash_compile_d": functools.partial(_attention_torch_compile_default, backend=_SDP.FLASH_ATTENTION),
    "torch_flash_compile_ma": functools.partial(_attention_torch_compile_max_autotune, backend=_SDP.FLASH_ATTENTION),
}
if hf_kernels_flash_attn is not None:
    attention_ops.update({
        "hf_flash_attn": _attention_hf_kernels_flash_attn,
        "hf_flash_attn3": _attention_hf_kernels_flash_attn3,
    })
if flash_attn_func is not None:
    attention_ops.update({
        "flash_attn_2": _attention_flash_attn_2,
        "flash_attn_2_compile_d": _attention_flash_attn_2_compile_default,
        "flash_attn_2_compile_ma": _attention_flash_attn_2_compile_max_autotune,
    })
if flash_attn_3_func is not None:
    attention_ops.update({
        "flash_attn_3": _attention_flash_attn_3,
        "flash_attn_3_compile_d": _attention_flash_attn_3_compile_default,
        "flash_attn_3_compile_ma": _attention_flash_attn_3_compile_max_autotune,
    })
if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops.update({
        "sageattn_qk_int8_pv_fp16_cuda": _attention_sageattn_qk_int8_pv_fp16_cuda,
        "sageattn_qk_int8_pv_fp16_triton": _attention_sageattn_qk_int8_pv_fp16_triton,
    })
    # The fp8 PV kernel requires a Hopper-class (SM90+) GPU.
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
if DotProductAttention is not None:
    attention_ops.update({
        "te_fused": _attention_te,
        "te_fused_compile_d": _attention_te_compile_default,
        "te_fused_compile_ma": _attention_te_compile_max_autotune,
    })
if xops is not None:
    attention_ops.update({
        "xformers": _attention_xformers,
        "xformers_compile_d": _attention_xformers_compile_default,
        "xformers_compile_ma": _attention_xformers_compile_max_autotune,
    })
271
+
272
+
273
def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    """Return `n` distinct (color, linestyle) pairs for plotting.

    Colors cycle fastest; the line style advances every `len(colors)`
    entries, giving len(colors) * len(line_styles) = 32 distinct styles.
    (The original annotation said `tuple[str, str]`, but the function has
    always returned a list of such tuples.)

    Raises:
        ValueError: if `n` exceeds the number of available combinations.
    """
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
    return [(colors[i % len(colors)], line_styles[i // len(colors)]) for i in range(n)]
284
+
285
+
286
def correctness():
    """Numerical sanity check for every registered op.

    A float32 math-backend SDPA result serves as the golden reference; each
    backend then runs in bfloat16 and absmax / MAE / MSE are reported.
    """
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")

        query, key, value = (
            torch.randn(shape, device="cuda", dtype=torch.float32) for _ in range(3)
        )

        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
        query, key, value = (x.bfloat16() for x in (query, key, value))

        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = (out - golden_truth).abs()
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean((golden_truth - out) ** 2)
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")
305
+
306
+
307
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=[x.removeprefix("solution_") for x in attention_ops.keys()],
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    """Time one provider at one sequence length with triton's do_bench."""
    torch.manual_seed(0)

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)

    def _make_input():
        # randn then randint — same RNG call order as the original script,
        # so the benchmark inputs are bit-identical.
        noise = torch.randn(shape, device="cuda", dtype=torch.bfloat16)
        scale = torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
        return noise * scale

    query, key, value = _make_input(), _make_input(), _make_input()

    op = attention_ops[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: op(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    # perf_report expects (median, upper, lower).
    return ms, max_ms, min_ms
337
+
338
+
339
# Entry point: validate numerics first, then run the timing sweep.
# The plot and per-provider CSVs are written into `output_dir`.
with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
flash_attn/index.html ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html>
<head>
<meta charset='UTF-8'>
<title>Directory Index</title>
<!-- Minimal static listing for the flash_attn benchmark output. -->
<style>
body { font-family: monospace; margin: 20px; }
h1 { font-size: 1.5em; }
ul { list-style-type: none; padding-left: 20px; }
li { margin: 5px 0; }
.dir { font-weight: bold; }
.file { color: #0066cc; }
a { text-decoration: none; }
a:hover { text-decoration: underline; }
</style>
</head>
<body>
<h1>Index of /flash_attn</h1>
<ul>
<li><a href='../index.html' class='dir'>../</a></li>
<li><a href='benchmark.html' class='file'>benchmark.html</a></li>
</ul>
</body>
</html>
index.html CHANGED
@@ -17,8 +17,8 @@
17
  <body>
18
  <h1>Index of /</h1>
19
  <ul>
20
- <li><a href='megablocks/index.html' class='dir'>megablocks/</a></li>
21
- <li><a href='megablocks_yamoe/index.html' class='dir'>megablocks_yamoe/</a></li>
22
  </ul>
23
  </body>
24
  </html>
 
17
  <body>
18
  <h1>Index of /</h1>
19
  <ul>
20
+ <li><a href='flash_attn/index.html' class='dir'>flash_attn/</a></li>
21
+ <li><a href='moe_benchmarks/index.html' class='dir'>moe_benchmarks/</a></li>
22
  </ul>
23
  </body>
24
  </html>
moe_benchmarks/megablocks/cells/forward_and_backward.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.12"
3
+ # dependencies = [
4
+ # "accelerate>=1.10.1",
5
+ # "torch>=2.7.0",
6
+ # "kernels==0.10.0",
7
+ # "transformers@https://github.com/huggingface/transformers.git",
8
+ # "ipdb>=0.13.13",
9
+ # "matplotlib>=3.7.2",
10
+ # "numpy>=1.24.3",
11
+ # ]
12
+ # ///
13
+
14
+ import torch
15
+ from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
16
+ import time
17
+ import torch.nn as nn
18
+ from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
19
+ import sys
20
+ import torch.profiler
21
+ import gc
22
+ import logging
23
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
24
+
25
# Remove the hub (liger) RMSNorm kernel override so the stock
# implementation is what gets exercised in this test.
replace_kernel_forward_from_hub(GptOssRMSNorm, None)

# NOTE(review): comment said "debug logging" but the configured level is
# INFO — kept as-is to preserve behavior.
logging.basicConfig(level=logging.INFO)
30
+
31
def reset_peak_memory_stats():
    """Clear CUDA cache and reset memory allocation counters.

    Safe no-op on CPU-only hosts: the original guarded only
    `reset_peak_memory_stats()` while calling `empty_cache()`
    unconditionally; both CUDA calls are now behind the availability check.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    # Drop Python-side references so freed tensors can actually be released.
    gc.collect()
37
+
38
def get_memory_stats():
    """Get current and peak CUDA memory usage.

    Returns a dict with allocated/peak/reserved sizes in GB (decimal, /1e9);
    all zeros when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
    gb = 1e9
    return {
        "allocated_gb": torch.cuda.memory_allocated() / gb,
        "peak_gb": torch.cuda.max_memory_allocated() / gb,
        "reserved_gb": torch.cuda.memory_reserved() / gb,
    }
47
+
48
def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Helper to dynamically override the kernel_layer_name in a model class.

    Scans loaded modules for an nn.Module subclass named `cls_name`; patches
    the first match and reports whether anything was found.
    """
    for module in sys.modules.values():
        if module is None:
            continue
        candidate = getattr(module, cls_name, None)
        if isinstance(candidate, type) and issubclass(candidate, nn.Module):
            setattr(candidate, "kernel_layer_name", value)
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False
59
+
60
+
61
# Init the model the normal way
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
# Dequantize the MXFP4 checkpoint at load time.
quantization_config = Mxfp4Config(dequantize=True)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,
    quantization_config=quantization_config,
).eval()

# Single system-prompt conversation used for the generation pass.
messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")
85
+
86
max_tokens = 128  # Reduced to help with memory usage

# Clear memory before backward pass
reset_peak_memory_stats()
print(f"Pre-generation memory: {get_memory_stats()}")

# Greedy generation with grad explicitly enabled (a backward pass follows
# later in the script).
with torch.autograd.set_grad_enabled(True):
    start_time = time.perf_counter()
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    end_time = time.perf_counter()

    print(tokenizer.decode(generated[0], skip_special_tokens=False))
    print(f"Generation took {end_time - start_time:.2f} seconds")
    print(f"Post-generation memory: {get_memory_stats()}")
105
+
106
# Use gradient checkpointing to reduce memory usage during the backward pass.
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()
    print("Enabled gradient checkpointing")

# Reduce sequence length if needed for memory
max_seq_len = 512  # Limit sequence length for backward pass
if generated.size(1) > max_seq_len:
    # Keep the most recent tokens (suffix) when truncating.
    print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
    full_sequence = generated[:, -max_seq_len:]
else:
    full_sequence = generated

# Get model outputs for the full sequence
model.train()  # Enable dropout and other training behaviors
121
+
122
try:
    # Teacher-forced pass over the generated tokens; passing labels lets the
    # model compute the LM loss internally when it supports that.
    outputs = model(
        input_ids=full_sequence,
        labels=full_sequence,  # This will compute loss internally
        return_dict=True
    )
    print(f"Post-forward memory: {get_memory_stats()}")

    if outputs.loss is None:
        # Model didn't produce a loss: do the standard causal-LM shift by hand.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = full_sequence[..., 1:].contiguous()

        # Ignore padding tokens (fall back to -100 when no pad id is set).
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_id)
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )
    else:
        loss = outputs.loss

    print(f"Loss: {loss.item():.4f}")

    # Clear intermediate tensors to save memory
    del outputs
    torch.cuda.empty_cache()

    print("Running backward pass...")
    print(f"Pre-backward memory: {get_memory_stats()}")

    loss.backward()
    print(f"Post-backward memory: {get_memory_stats()}")

except torch.cuda.OutOfMemoryError as e:
    print(f"OOM during forward/backward pass: {e}")
    print("Try reducing max_tokens or max_seq_len")
    raise
161
+
162
# Calculate gradient statistics and print sample gradients.
# (Fixes the original's `enumerate(...)` whose index was never used.)
total_norm = 0.0
param_count = 0
grad_samples = {}

# Parameter-name substrings whose gradients get detailed stats.
_sample_keys = ('embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm')

for name, p in model.named_parameters():
    if p.grad is not None:
        param_count += 1
        grad_norm = p.grad.data.norm(2).item()
        # Accumulate squared per-parameter norms; sqrt below gives the
        # global L2 gradient norm.
        total_norm += grad_norm ** 2

        # Collect gradient statistics for key layers
        if any(key in name for key in _sample_keys):
            grad_samples[name] = {
                'norm': grad_norm,
                'mean': p.grad.data.mean().item(),
                'std': p.grad.data.std().item(),
                'max': p.grad.data.max().item(),
                'min': p.grad.data.min().item(),
            }

total_norm = total_norm ** 0.5

print(f"\nGradient norm: {total_norm:.4f}")
print(f"Parameters with gradients: {param_count}")

# Print sample gradients from important layers (first 10 collected).
print("\nSample gradient statistics:")
for name, stats in list(grad_samples.items())[:10]:
    print(f"  {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")

# Optional: zero gradients for next iteration
model.zero_grad()
model.eval()  # Switch back to eval mode
196
+
moe_benchmarks/megablocks/megablocks_only.html CHANGED
@@ -3710,7 +3710,7 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
3710
  <div class="system-info">
3711
  <div class="system-info-header">Generated on:</div>
3712
  <div class="system-info-content">
3713
- Linux x86_64 | Linux-6.11.0-1018-azure-x86_64-with-glibc2.39
3714
  </div>
3715
  </div>
3716
 
@@ -3724,122 +3724,219 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
3724
  <p>Next we can run with Megablocks kernels enabled.</p>
3725
  <h3>Forward</h3>
3726
  <p>First, we run a forward pass with Megablocks kernels.</p>
3727
- <div class="cell cell-failed" id="cell-forward_only">
 
 
3728
  <div class="cell-header">
3729
  <span class="collapse-indicators">
3730
- <span onclick="toggleCode('forward_only')" style="cursor: pointer;">▼ code</span>
3731
- <span onclick="toggleOutput('forward_only')" style="cursor: pointer;">▼ output</span>
3732
- <span id="uv-indicator-forward_only" onclick="toggleUvLogsFromHeader('forward_only')" style="cursor: pointer;">▶ uv-logs</span>
3733
  </span> |
3734
- Cell: forward_only | 118.48s | FAILED
3735
- | <button class="run-btn" onclick="runCell('forward_only')">▶ run</button>
3736
- <button class="copy-btn" onclick="copyCell('forward_only')">Copy</button>
3737
- <a href="cells/forward_only.py" target="_blank" class="raw-btn">Raw</a>
3738
  </div>
3739
- <div id="code-forward_only" class="cell-code" data-lines="101">
3740
  <div class="highlight-with-lines">
3741
- <div class="line-numbers" id="lines-forward_only">
3742
- <a class="line-number" data-cell="forward_only" data-line="1" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 1, true);">1</a>
3743
- <a class="line-number" data-cell="forward_only" data-line="2" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 2, true);">2</a>
3744
- <a class="line-number" data-cell="forward_only" data-line="3" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 3, true);">3</a>
3745
- <a class="line-number" data-cell="forward_only" data-line="4" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 4, true);">4</a>
3746
- <a class="line-number" data-cell="forward_only" data-line="5" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 5, true);">5</a>
3747
- <a class="line-number" data-cell="forward_only" data-line="6" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 6, true);">6</a>
3748
- <a class="line-number" data-cell="forward_only" data-line="7" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 7, true);">7</a>
3749
- <a class="line-number" data-cell="forward_only" data-line="8" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 8, true);">8</a>
3750
- <a class="line-number" data-cell="forward_only" data-line="9" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 9, true);">9</a>
3751
- <a class="line-number" data-cell="forward_only" data-line="10" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 10, true);">10</a>
3752
- <a class="line-number" data-cell="forward_only" data-line="11" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 11, true);">11</a>
3753
- <a class="line-number" data-cell="forward_only" data-line="12" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 12, true);">12</a>
3754
- <a class="line-number" data-cell="forward_only" data-line="13" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 13, true);">13</a>
3755
- <a class="line-number" data-cell="forward_only" data-line="14" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 14, true);">14</a>
3756
- <a class="line-number" data-cell="forward_only" data-line="15" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 15, true);">15</a>
3757
- <a class="line-number" data-cell="forward_only" data-line="16" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 16, true);">16</a>
3758
- <a class="line-number" data-cell="forward_only" data-line="17" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 17, true);">17</a>
3759
- <a class="line-number" data-cell="forward_only" data-line="18" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 18, true);">18</a>
3760
- <a class="line-number" data-cell="forward_only" data-line="19" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 19, true);">19</a>
3761
- <a class="line-number" data-cell="forward_only" data-line="20" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 20, true);">20</a>
3762
- <a class="line-number" data-cell="forward_only" data-line="21" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 21, true);">21</a>
3763
- <a class="line-number" data-cell="forward_only" data-line="22" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 22, true);">22</a>
3764
- <a class="line-number" data-cell="forward_only" data-line="23" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 23, true);">23</a>
3765
- <a class="line-number" data-cell="forward_only" data-line="24" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 24, true);">24</a>
3766
- <a class="line-number" data-cell="forward_only" data-line="25" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 25, true);">25</a>
3767
- <a class="line-number" data-cell="forward_only" data-line="26" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 26, true);">26</a>
3768
- <a class="line-number" data-cell="forward_only" data-line="27" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 27, true);">27</a>
3769
- <a class="line-number" data-cell="forward_only" data-line="28" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 28, true);">28</a>
3770
- <a class="line-number" data-cell="forward_only" data-line="29" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 29, true);">29</a>
3771
- <a class="line-number" data-cell="forward_only" data-line="30" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 30, true);">30</a>
3772
- <a class="line-number" data-cell="forward_only" data-line="31" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 31, true);">31</a>
3773
- <a class="line-number" data-cell="forward_only" data-line="32" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 32, true);">32</a>
3774
- <a class="line-number" data-cell="forward_only" data-line="33" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 33, true);">33</a>
3775
- <a class="line-number" data-cell="forward_only" data-line="34" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 34, true);">34</a>
3776
- <a class="line-number" data-cell="forward_only" data-line="35" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 35, true);">35</a>
3777
- <a class="line-number" data-cell="forward_only" data-line="36" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 36, true);">36</a>
3778
- <a class="line-number" data-cell="forward_only" data-line="37" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 37, true);">37</a>
3779
- <a class="line-number" data-cell="forward_only" data-line="38" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 38, true);">38</a>
3780
- <a class="line-number" data-cell="forward_only" data-line="39" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 39, true);">39</a>
3781
- <a class="line-number" data-cell="forward_only" data-line="40" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 40, true);">40</a>
3782
- <a class="line-number" data-cell="forward_only" data-line="41" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 41, true);">41</a>
3783
- <a class="line-number" data-cell="forward_only" data-line="42" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 42, true);">42</a>
3784
- <a class="line-number" data-cell="forward_only" data-line="43" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 43, true);">43</a>
3785
- <a class="line-number" data-cell="forward_only" data-line="44" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 44, true);">44</a>
3786
- <a class="line-number" data-cell="forward_only" data-line="45" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 45, true);">45</a>
3787
- <a class="line-number" data-cell="forward_only" data-line="46" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 46, true);">46</a>
3788
- <a class="line-number" data-cell="forward_only" data-line="47" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 47, true);">47</a>
3789
- <a class="line-number" data-cell="forward_only" data-line="48" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 48, true);">48</a>
3790
- <a class="line-number" data-cell="forward_only" data-line="49" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 49, true);">49</a>
3791
- <a class="line-number" data-cell="forward_only" data-line="50" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 50, true);">50</a>
3792
- <a class="line-number" data-cell="forward_only" data-line="51" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 51, true);">51</a>
3793
- <a class="line-number" data-cell="forward_only" data-line="52" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 52, true);">52</a>
3794
- <a class="line-number" data-cell="forward_only" data-line="53" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 53, true);">53</a>
3795
- <a class="line-number" data-cell="forward_only" data-line="54" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 54, true);">54</a>
3796
- <a class="line-number" data-cell="forward_only" data-line="55" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 55, true);">55</a>
3797
- <a class="line-number" data-cell="forward_only" data-line="56" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 56, true);">56</a>
3798
- <a class="line-number" data-cell="forward_only" data-line="57" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 57, true);">57</a>
3799
- <a class="line-number" data-cell="forward_only" data-line="58" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 58, true);">58</a>
3800
- <a class="line-number" data-cell="forward_only" data-line="59" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 59, true);">59</a>
3801
- <a class="line-number" data-cell="forward_only" data-line="60" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 60, true);">60</a>
3802
- <a class="line-number" data-cell="forward_only" data-line="61" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 61, true);">61</a>
3803
- <a class="line-number" data-cell="forward_only" data-line="62" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 62, true);">62</a>
3804
- <a class="line-number" data-cell="forward_only" data-line="63" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 63, true);">63</a>
3805
- <a class="line-number" data-cell="forward_only" data-line="64" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 64, true);">64</a>
3806
- <a class="line-number" data-cell="forward_only" data-line="65" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 65, true);">65</a>
3807
- <a class="line-number" data-cell="forward_only" data-line="66" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 66, true);">66</a>
3808
- <a class="line-number" data-cell="forward_only" data-line="67" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 67, true);">67</a>
3809
- <a class="line-number" data-cell="forward_only" data-line="68" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 68, true);">68</a>
3810
- <a class="line-number" data-cell="forward_only" data-line="69" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 69, true);">69</a>
3811
- <a class="line-number" data-cell="forward_only" data-line="70" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 70, true);">70</a>
3812
- <a class="line-number" data-cell="forward_only" data-line="71" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 71, true);">71</a>
3813
- <a class="line-number" data-cell="forward_only" data-line="72" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 72, true);">72</a>
3814
- <a class="line-number" data-cell="forward_only" data-line="73" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 73, true);">73</a>
3815
- <a class="line-number" data-cell="forward_only" data-line="74" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 74, true);">74</a>
3816
- <a class="line-number" data-cell="forward_only" data-line="75" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 75, true);">75</a>
3817
- <a class="line-number" data-cell="forward_only" data-line="76" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 76, true);">76</a>
3818
- <a class="line-number" data-cell="forward_only" data-line="77" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 77, true);">77</a>
3819
- <a class="line-number" data-cell="forward_only" data-line="78" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 78, true);">78</a>
3820
- <a class="line-number" data-cell="forward_only" data-line="79" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 79, true);">79</a>
3821
- <a class="line-number" data-cell="forward_only" data-line="80" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 80, true);">80</a>
3822
- <a class="line-number" data-cell="forward_only" data-line="81" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 81, true);">81</a>
3823
- <a class="line-number" data-cell="forward_only" data-line="82" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 82, true);">82</a>
3824
- <a class="line-number" data-cell="forward_only" data-line="83" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 83, true);">83</a>
3825
- <a class="line-number" data-cell="forward_only" data-line="84" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 84, true);">84</a>
3826
- <a class="line-number" data-cell="forward_only" data-line="85" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 85, true);">85</a>
3827
- <a class="line-number" data-cell="forward_only" data-line="86" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 86, true);">86</a>
3828
- <a class="line-number" data-cell="forward_only" data-line="87" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 87, true);">87</a>
3829
- <a class="line-number" data-cell="forward_only" data-line="88" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 88, true);">88</a>
3830
- <a class="line-number" data-cell="forward_only" data-line="89" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 89, true);">89</a>
3831
- <a class="line-number" data-cell="forward_only" data-line="90" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 90, true);">90</a>
3832
- <a class="line-number" data-cell="forward_only" data-line="91" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 91, true);">91</a>
3833
- <a class="line-number" data-cell="forward_only" data-line="92" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 92, true);">92</a>
3834
- <a class="line-number" data-cell="forward_only" data-line="93" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 93, true);">93</a>
3835
- <a class="line-number" data-cell="forward_only" data-line="94" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 94, true);">94</a>
3836
- <a class="line-number" data-cell="forward_only" data-line="95" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 95, true);">95</a>
3837
- <a class="line-number" data-cell="forward_only" data-line="96" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 96, true);">96</a>
3838
- <a class="line-number" data-cell="forward_only" data-line="97" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 97, true);">97</a>
3839
- <a class="line-number" data-cell="forward_only" data-line="98" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 98, true);">98</a>
3840
- <a class="line-number" data-cell="forward_only" data-line="99" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 99, true);">99</a>
3841
- <a class="line-number" data-cell="forward_only" data-line="100" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 100, true);">100</a>
3842
- <a class="line-number" data-cell="forward_only" data-line="101" href="#cell-forward_only" onclick="event.preventDefault(); selectCellLine('forward_only', 101, true);">101</a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3843
  </div>
3844
  <div class="code-wrap">
3845
  <div class="highlight"><pre><span></span><span class="c1"># /// script</span>
@@ -3866,7 +3963,7 @@ Cell: forward_only | 118.48s | FAILED
3866
  <span class="kn">import</span><span class="w"> </span><span class="nn">logging</span>
3867
  <span class="kn">from</span><span class="w"> </span><span class="nn">transformers.models.gpt_oss.modeling_gpt_oss</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssRMSNorm</span>
3868
 
3869
-
3870
  <span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssRMSNorm</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
3871
 
3872
  <span class="c1"># set to debug logging</span>
@@ -3907,8 +4004,6 @@ Cell: forward_only | 118.48s | FAILED
3907
  <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">PreTrainedTokenizerFast</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_id</span><span class="p">)</span>
3908
  <span class="n">quantization_config</span> <span class="o">=</span> <span class="n">Mxfp4Config</span><span class="p">(</span><span class="n">dequantize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
3909
 
3910
-
3911
-
3912
  <span class="n">model</span> <span class="o">=</span> <span class="n">GptOssForCausalLM</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span>
3913
  <span class="n">model_id</span><span class="p">,</span>
3914
  <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;bfloat16&quot;</span><span class="p">,</span>
@@ -3929,9 +4024,14 @@ Cell: forward_only | 118.48s | FAILED
3929
  <span class="n">reasoning_effort</span><span class="o">=</span><span class="s2">&quot;low&quot;</span><span class="p">,</span>
3930
  <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span>
3931
 
3932
- <span class="n">max_tokens</span> <span class="o">=</span> <span class="mi">256</span>
 
 
 
 
3933
 
3934
- <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span>
 
3935
  <span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
3936
  <span class="n">generated</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span>
3937
  <span class="o">**</span><span class="n">inputs</span><span class="p">,</span>
@@ -3940,144 +4040,124 @@ Cell: forward_only | 118.48s | FAILED
3940
  <span class="n">temperature</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
3941
  <span class="p">)</span>
3942
  <span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
3943
-
3944
- <span class="nb">print</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">generated</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">skip_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
3945
- <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Generation took </span><span class="si">{</span><span class="n">end_time</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">start_time</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> seconds&quot;</span><span class="p">)</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3946
  </pre></div>
3947
 
3948
- <div class="code-line-highlight" id="line-highlight-forward_only"></div>
3949
  </div>
3950
  </div>
3951
  </div>
3952
- <div id="output-forward_only" class="cell-output">
3953
- <div class="uv-install-logs" id="uv-logs-forward_only">
3954
- <div class="uv-logs-header" onclick="toggleUvLogs(this)">▶ UV Install Logs</div>
3955
- <div class="uv-logs-content" style="display: none;">
3956
  Updating https://github.com/huggingface/transformers.git (HEAD)
3957
- Updated https://github.com/huggingface/transformers.git (7258ea44bc0c0a425a468f66f8559d1de8c4126d)
3958
- Building transformers @ git+https://github.com/huggingface/transformers.git@7258ea44bc0c0a425a468f66f8559d1de8c4126d
3959
- Downloading triton (148.4MiB)
3960
- Downloading nvidia-cuda-cupti-cu12 (9.8MiB)
3961
- Downloading hf-xet (3.0MiB)
3962
- Downloading pillow (6.3MiB)
3963
- Downloading tokenizers (3.1MiB)
3964
- Downloading jedi (1.5MiB)
3965
- Downloading nvidia-nvjitlink-cu12 (37.4MiB)
3966
- Downloading nvidia-cufile-cu12 (1.1MiB)
3967
- Downloading networkx (1.9MiB)
3968
- Downloading nvidia-cusparselt-cu12 (273.9MiB)
3969
- Downloading nvidia-cusolver-cu12 (255.1MiB)
3970
- Downloading nvidia-cufft-cu12 (184.2MiB)
3971
- Downloading nvidia-curand-cu12 (60.7MiB)
3972
- Downloading nvidia-cublas-cu12 (566.8MiB)
3973
- Downloading nvidia-nccl-cu12 (307.4MiB)
3974
- Downloading nvidia-cudnn-cu12 (674.0MiB)
3975
- Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
3976
- Downloading sympy (6.0MiB)
3977
- Downloading fonttools (4.7MiB)
3978
- Downloading nvidia-cusparse-cu12 (274.9MiB)
3979
- Downloading torch (846.8MiB)
3980
- Downloading numpy (15.9MiB)
3981
- Downloading matplotlib (8.3MiB)
3982
- Downloading kiwisolver (1.4MiB)
3983
- Downloading nvidia-cufile-cu12
3984
- Downloading kiwisolver
3985
- Downloading hf-xet
3986
- Downloading tokenizers
3987
- Downloading networkx
3988
- Downloading fonttools
3989
- Downloading pillow
3990
- Downloading matplotlib
3991
- Downloading nvidia-cuda-cupti-cu12
3992
- Downloading sympy
3993
- Downloading numpy
3994
- Downloading jedi
3995
- Built transformers @ git+https://github.com/huggingface/transformers.git@7258ea44bc0c0a425a468f66f8559d1de8c4126d
3996
- Downloading nvidia-nvjitlink-cu12
3997
- Downloading nvidia-curand-cu12
3998
- Downloading nvidia-cuda-nvrtc-cu12
3999
- Downloading triton
4000
- Downloading nvidia-cufft-cu12
4001
- Downloading nvidia-cusolver-cu12
4002
- Downloading nvidia-cusparselt-cu12
4003
- Downloading nvidia-cusparse-cu12
4004
- Downloading nvidia-nccl-cu12
4005
- Downloading nvidia-cublas-cu12
4006
- Downloading nvidia-cudnn-cu12
4007
- Downloading torch
4008
- Installed 69 packages in 321ms
4009
- </div>
4010
  </div>
4011
- <div class="cell-stderr">Fetching 3 files: 0%| | 0/3 [00:00&lt;?, ?it/s]
4012
- Fetching 3 files: 0%| | 0/3 [00:50&lt;?, ?it/s]
4013
- Traceback (most recent call last):
4014
- File &quot;/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks/.uvnote/cells/forward_only.py&quot;, line 68, in &lt;module&gt;
4015
- model = GptOssForCausalLM.from_pretrained(
4016
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4017
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py&quot;, line 285, in _wrapper
4018
- return func(*args, **kwargs)
4019
- ^^^^^^^^^^^^^^^^^^^^^
4020
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py&quot;, line 4904, in from_pretrained
4021
- checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
4022
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4023
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py&quot;, line 1239, in _get_resolved_checkpoint_files
4024
- checkpoint_files, sharded_metadata = get_checkpoint_shard_files(
4025
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^
4026
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py&quot;, line 1116, in get_checkpoint_shard_files
4027
- cached_filenames = cached_files(
4028
- ^^^^^^^^^^^^^
4029
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py&quot;, line 564, in cached_files
4030
- raise e
4031
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py&quot;, line 491, in cached_files
4032
- snapshot_download(
4033
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py&quot;, line 114, in _inner_fn
4034
- return fn(*args, **kwargs)
4035
- ^^^^^^^^^^^^^^^^^^^
4036
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py&quot;, line 332, in snapshot_download
4037
- thread_map(
4038
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py&quot;, line 69, in thread_map
4039
- return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
4040
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4041
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py&quot;, line 51, in _executor_map
4042
- return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs))
4043
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4044
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/std.py&quot;, line 1181, in __iter__
4045
- for obj in iterable:
4046
- File &quot;/usr/lib/python3.12/concurrent/futures/_base.py&quot;, line 619, in result_iterator
4047
- yield _result_or_cancel(fs.pop())
4048
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^
4049
- File &quot;/usr/lib/python3.12/concurrent/futures/_base.py&quot;, line 317, in _result_or_cancel
4050
- return fut.result(timeout)
4051
- ^^^^^^^^^^^^^^^^^^^
4052
- File &quot;/usr/lib/python3.12/concurrent/futures/_base.py&quot;, line 456, in result
4053
- return self.__get_result()
4054
- ^^^^^^^^^^^^^^^^^^^
4055
- File &quot;/usr/lib/python3.12/concurrent/futures/_base.py&quot;, line 401, in __get_result
4056
- raise self._exception
4057
- File &quot;/usr/lib/python3.12/concurrent/futures/thread.py&quot;, line 58, in run
4058
- result = self.fn(*self.args, **self.kwargs)
4059
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4060
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py&quot;, line 306, in _inner_hf_hub_download
4061
- return hf_hub_download(
4062
- ^^^^^^^^^^^^^^^^
4063
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py&quot;, line 114, in _inner_fn
4064
- return fn(*args, **kwargs)
4065
- ^^^^^^^^^^^^^^^^^^^
4066
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py&quot;, line 1010, in hf_hub_download
4067
- return _hf_hub_download_to_cache_dir(
4068
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4069
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py&quot;, line 1171, in _hf_hub_download_to_cache_dir
4070
- _download_to_tmp_and_move(
4071
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py&quot;, line 1723, in _download_to_tmp_and_move
4072
- xet_get(
4073
- File &quot;/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py&quot;, line 629, in xet_get
4074
- download_files(
4075
- RuntimeError: Data processing error: CAS service error : IO Error: No space left on device (os error 28)</div>
4076
  </div>
4077
  </div>
4078
-
4079
- <h2>Forward and Backward</h2>
4080
- <p>Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.</p>
4081
  </div>
4082
 
4083
  </body>
 
3710
  <div class="system-info">
3711
  <div class="system-info-header">Generated on:</div>
3712
  <div class="system-info-content">
3713
+ Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
3714
  </div>
3715
  </div>
3716
 
 
3724
  <p>Next we can run with Megablocks kernels enabled.</p>
3725
  <h3>Forward</h3>
3726
  <p>First, we run a forward pass with Megablocks kernels.</p>
3727
+ <h2>Forward and Backward</h2>
3728
+ <p>Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.</p>
3729
+ <div class="cell cell-failed" id="cell-forward_and_backward">
3730
  <div class="cell-header">
3731
  <span class="collapse-indicators">
3732
+ <span onclick="toggleCode('forward_and_backward')" style="cursor: pointer;">▼ code</span>
3733
+ <span onclick="toggleOutput('forward_and_backward')" style="cursor: pointer;">▼ output</span>
3734
+ <span id="uv-indicator-forward_and_backward" style="cursor: default; opacity: 0.3;">▶ uv-logs</span>
3735
  </span> |
3736
+ Cell: forward_and_backward | 19.43s | FAILED
3737
+ | <button class="run-btn" onclick="runCell('forward_and_backward')">▶ run</button>
3738
+ <button class="copy-btn" onclick="copyCell('forward_and_backward')">Copy</button>
3739
+ <a href="cells/forward_and_backward.py" target="_blank" class="raw-btn">Raw</a>
3740
  </div>
3741
+ <div id="code-forward_and_backward" class="cell-code" data-lines="196">
3742
  <div class="highlight-with-lines">
3743
+ <div class="line-numbers" id="lines-forward_and_backward">
3744
+ <a class="line-number" data-cell="forward_and_backward" data-line="1" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 1, true);">1</a>
3745
+ <a class="line-number" data-cell="forward_and_backward" data-line="2" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 2, true);">2</a>
3746
+ <a class="line-number" data-cell="forward_and_backward" data-line="3" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 3, true);">3</a>
3747
+ <a class="line-number" data-cell="forward_and_backward" data-line="4" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 4, true);">4</a>
3748
+ <a class="line-number" data-cell="forward_and_backward" data-line="5" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 5, true);">5</a>
3749
+ <a class="line-number" data-cell="forward_and_backward" data-line="6" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 6, true);">6</a>
3750
+ <a class="line-number" data-cell="forward_and_backward" data-line="7" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 7, true);">7</a>
3751
+ <a class="line-number" data-cell="forward_and_backward" data-line="8" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 8, true);">8</a>
3752
+ <a class="line-number" data-cell="forward_and_backward" data-line="9" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 9, true);">9</a>
3753
+ <a class="line-number" data-cell="forward_and_backward" data-line="10" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 10, true);">10</a>
3754
+ <a class="line-number" data-cell="forward_and_backward" data-line="11" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 11, true);">11</a>
3755
+ <a class="line-number" data-cell="forward_and_backward" data-line="12" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 12, true);">12</a>
3756
+ <a class="line-number" data-cell="forward_and_backward" data-line="13" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 13, true);">13</a>
3757
+ <a class="line-number" data-cell="forward_and_backward" data-line="14" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 14, true);">14</a>
3758
+ <a class="line-number" data-cell="forward_and_backward" data-line="15" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 15, true);">15</a>
3759
+ <a class="line-number" data-cell="forward_and_backward" data-line="16" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 16, true);">16</a>
3760
+ <a class="line-number" data-cell="forward_and_backward" data-line="17" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 17, true);">17</a>
3761
+ <a class="line-number" data-cell="forward_and_backward" data-line="18" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 18, true);">18</a>
3762
+ <a class="line-number" data-cell="forward_and_backward" data-line="19" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 19, true);">19</a>
3763
+ <a class="line-number" data-cell="forward_and_backward" data-line="20" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 20, true);">20</a>
3764
+ <a class="line-number" data-cell="forward_and_backward" data-line="21" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 21, true);">21</a>
3765
+ <a class="line-number" data-cell="forward_and_backward" data-line="22" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 22, true);">22</a>
3766
+ <a class="line-number" data-cell="forward_and_backward" data-line="23" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 23, true);">23</a>
3767
+ <a class="line-number" data-cell="forward_and_backward" data-line="24" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 24, true);">24</a>
3768
+ <a class="line-number" data-cell="forward_and_backward" data-line="25" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 25, true);">25</a>
3769
+ <a class="line-number" data-cell="forward_and_backward" data-line="26" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 26, true);">26</a>
3770
+ <a class="line-number" data-cell="forward_and_backward" data-line="27" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 27, true);">27</a>
3771
+ <a class="line-number" data-cell="forward_and_backward" data-line="28" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 28, true);">28</a>
3772
+ <a class="line-number" data-cell="forward_and_backward" data-line="29" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 29, true);">29</a>
3773
+ <a class="line-number" data-cell="forward_and_backward" data-line="30" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 30, true);">30</a>
3774
+ <a class="line-number" data-cell="forward_and_backward" data-line="31" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 31, true);">31</a>
3775
+ <a class="line-number" data-cell="forward_and_backward" data-line="32" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 32, true);">32</a>
3776
+ <a class="line-number" data-cell="forward_and_backward" data-line="33" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 33, true);">33</a>
3777
+ <a class="line-number" data-cell="forward_and_backward" data-line="34" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 34, true);">34</a>
3778
+ <a class="line-number" data-cell="forward_and_backward" data-line="35" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 35, true);">35</a>
3779
+ <a class="line-number" data-cell="forward_and_backward" data-line="36" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 36, true);">36</a>
3780
+ <a class="line-number" data-cell="forward_and_backward" data-line="37" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 37, true);">37</a>
3781
+ <a class="line-number" data-cell="forward_and_backward" data-line="38" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 38, true);">38</a>
3782
+ <a class="line-number" data-cell="forward_and_backward" data-line="39" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 39, true);">39</a>
3783
+ <a class="line-number" data-cell="forward_and_backward" data-line="40" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 40, true);">40</a>
3784
+ <a class="line-number" data-cell="forward_and_backward" data-line="41" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 41, true);">41</a>
3785
+ <a class="line-number" data-cell="forward_and_backward" data-line="42" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 42, true);">42</a>
3786
+ <a class="line-number" data-cell="forward_and_backward" data-line="43" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 43, true);">43</a>
3787
+ <a class="line-number" data-cell="forward_and_backward" data-line="44" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 44, true);">44</a>
3788
+ <a class="line-number" data-cell="forward_and_backward" data-line="45" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 45, true);">45</a>
3789
+ <a class="line-number" data-cell="forward_and_backward" data-line="46" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 46, true);">46</a>
3790
+ <a class="line-number" data-cell="forward_and_backward" data-line="47" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 47, true);">47</a>
3791
+ <a class="line-number" data-cell="forward_and_backward" data-line="48" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 48, true);">48</a>
3792
+ <a class="line-number" data-cell="forward_and_backward" data-line="49" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 49, true);">49</a>
3793
+ <a class="line-number" data-cell="forward_and_backward" data-line="50" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 50, true);">50</a>
3794
+ <a class="line-number" data-cell="forward_and_backward" data-line="51" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 51, true);">51</a>
3795
+ <a class="line-number" data-cell="forward_and_backward" data-line="52" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 52, true);">52</a>
3796
+ <a class="line-number" data-cell="forward_and_backward" data-line="53" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 53, true);">53</a>
3797
+ <a class="line-number" data-cell="forward_and_backward" data-line="54" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 54, true);">54</a>
3798
+ <a class="line-number" data-cell="forward_and_backward" data-line="55" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 55, true);">55</a>
3799
+ <a class="line-number" data-cell="forward_and_backward" data-line="56" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 56, true);">56</a>
3800
+ <a class="line-number" data-cell="forward_and_backward" data-line="57" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 57, true);">57</a>
3801
+ <a class="line-number" data-cell="forward_and_backward" data-line="58" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 58, true);">58</a>
3802
+ <a class="line-number" data-cell="forward_and_backward" data-line="59" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 59, true);">59</a>
3803
+ <a class="line-number" data-cell="forward_and_backward" data-line="60" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 60, true);">60</a>
3804
+ <a class="line-number" data-cell="forward_and_backward" data-line="61" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 61, true);">61</a>
3805
+ <a class="line-number" data-cell="forward_and_backward" data-line="62" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 62, true);">62</a>
3806
+ <a class="line-number" data-cell="forward_and_backward" data-line="63" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 63, true);">63</a>
3807
+ <a class="line-number" data-cell="forward_and_backward" data-line="64" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 64, true);">64</a>
3808
+ <a class="line-number" data-cell="forward_and_backward" data-line="65" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 65, true);">65</a>
3809
+ <a class="line-number" data-cell="forward_and_backward" data-line="66" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 66, true);">66</a>
3810
+ <a class="line-number" data-cell="forward_and_backward" data-line="67" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 67, true);">67</a>
3811
+ <a class="line-number" data-cell="forward_and_backward" data-line="68" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 68, true);">68</a>
3812
+ <a class="line-number" data-cell="forward_and_backward" data-line="69" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 69, true);">69</a>
3813
+ <a class="line-number" data-cell="forward_and_backward" data-line="70" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 70, true);">70</a>
3814
+ <a class="line-number" data-cell="forward_and_backward" data-line="71" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 71, true);">71</a>
3815
+ <a class="line-number" data-cell="forward_and_backward" data-line="72" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 72, true);">72</a>
3816
+ <a class="line-number" data-cell="forward_and_backward" data-line="73" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 73, true);">73</a>
3817
+ <a class="line-number" data-cell="forward_and_backward" data-line="74" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 74, true);">74</a>
3818
+ <a class="line-number" data-cell="forward_and_backward" data-line="75" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 75, true);">75</a>
3819
+ <a class="line-number" data-cell="forward_and_backward" data-line="76" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 76, true);">76</a>
3820
+ <a class="line-number" data-cell="forward_and_backward" data-line="77" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 77, true);">77</a>
3821
+ <a class="line-number" data-cell="forward_and_backward" data-line="78" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 78, true);">78</a>
3822
+ <a class="line-number" data-cell="forward_and_backward" data-line="79" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 79, true);">79</a>
3823
+ <a class="line-number" data-cell="forward_and_backward" data-line="80" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 80, true);">80</a>
3824
+ <a class="line-number" data-cell="forward_and_backward" data-line="81" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 81, true);">81</a>
3825
+ <a class="line-number" data-cell="forward_and_backward" data-line="82" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 82, true);">82</a>
3826
+ <a class="line-number" data-cell="forward_and_backward" data-line="83" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 83, true);">83</a>
3827
+ <a class="line-number" data-cell="forward_and_backward" data-line="84" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 84, true);">84</a>
3828
+ <a class="line-number" data-cell="forward_and_backward" data-line="85" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 85, true);">85</a>
3829
+ <a class="line-number" data-cell="forward_and_backward" data-line="86" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 86, true);">86</a>
3830
+ <a class="line-number" data-cell="forward_and_backward" data-line="87" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 87, true);">87</a>
3831
+ <a class="line-number" data-cell="forward_and_backward" data-line="88" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 88, true);">88</a>
3832
+ <a class="line-number" data-cell="forward_and_backward" data-line="89" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 89, true);">89</a>
3833
+ <a class="line-number" data-cell="forward_and_backward" data-line="90" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 90, true);">90</a>
3834
+ <a class="line-number" data-cell="forward_and_backward" data-line="91" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 91, true);">91</a>
3835
+ <a class="line-number" data-cell="forward_and_backward" data-line="92" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 92, true);">92</a>
3836
+ <a class="line-number" data-cell="forward_and_backward" data-line="93" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 93, true);">93</a>
3837
+ <a class="line-number" data-cell="forward_and_backward" data-line="94" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 94, true);">94</a>
3838
+ <a class="line-number" data-cell="forward_and_backward" data-line="95" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 95, true);">95</a>
3839
+ <a class="line-number" data-cell="forward_and_backward" data-line="96" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 96, true);">96</a>
3840
+ <a class="line-number" data-cell="forward_and_backward" data-line="97" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 97, true);">97</a>
3841
+ <a class="line-number" data-cell="forward_and_backward" data-line="98" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 98, true);">98</a>
3842
+ <a class="line-number" data-cell="forward_and_backward" data-line="99" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 99, true);">99</a>
3843
+ <a class="line-number" data-cell="forward_and_backward" data-line="100" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 100, true);">100</a>
3844
+ <a class="line-number" data-cell="forward_and_backward" data-line="101" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 101, true);">101</a>
3845
+ <a class="line-number" data-cell="forward_and_backward" data-line="102" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 102, true);">102</a>
3846
+ <a class="line-number" data-cell="forward_and_backward" data-line="103" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 103, true);">103</a>
3847
+ <a class="line-number" data-cell="forward_and_backward" data-line="104" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 104, true);">104</a>
3848
+ <a class="line-number" data-cell="forward_and_backward" data-line="105" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 105, true);">105</a>
3849
+ <a class="line-number" data-cell="forward_and_backward" data-line="106" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 106, true);">106</a>
3850
+ <a class="line-number" data-cell="forward_and_backward" data-line="107" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 107, true);">107</a>
3851
+ <a class="line-number" data-cell="forward_and_backward" data-line="108" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 108, true);">108</a>
3852
+ <a class="line-number" data-cell="forward_and_backward" data-line="109" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 109, true);">109</a>
3853
+ <a class="line-number" data-cell="forward_and_backward" data-line="110" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 110, true);">110</a>
3854
+ <a class="line-number" data-cell="forward_and_backward" data-line="111" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 111, true);">111</a>
3855
+ <a class="line-number" data-cell="forward_and_backward" data-line="112" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 112, true);">112</a>
3856
+ <a class="line-number" data-cell="forward_and_backward" data-line="113" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 113, true);">113</a>
3857
+ <a class="line-number" data-cell="forward_and_backward" data-line="114" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 114, true);">114</a>
3858
+ <a class="line-number" data-cell="forward_and_backward" data-line="115" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 115, true);">115</a>
3859
+ <a class="line-number" data-cell="forward_and_backward" data-line="116" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 116, true);">116</a>
3860
+ <a class="line-number" data-cell="forward_and_backward" data-line="117" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 117, true);">117</a>
3861
+ <a class="line-number" data-cell="forward_and_backward" data-line="118" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 118, true);">118</a>
3862
+ <a class="line-number" data-cell="forward_and_backward" data-line="119" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 119, true);">119</a>
3863
+ <a class="line-number" data-cell="forward_and_backward" data-line="120" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 120, true);">120</a>
3864
+ <a class="line-number" data-cell="forward_and_backward" data-line="121" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 121, true);">121</a>
3865
+ <a class="line-number" data-cell="forward_and_backward" data-line="122" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 122, true);">122</a>
3866
+ <a class="line-number" data-cell="forward_and_backward" data-line="123" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 123, true);">123</a>
3867
+ <a class="line-number" data-cell="forward_and_backward" data-line="124" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 124, true);">124</a>
3868
+ <a class="line-number" data-cell="forward_and_backward" data-line="125" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 125, true);">125</a>
3869
+ <a class="line-number" data-cell="forward_and_backward" data-line="126" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 126, true);">126</a>
3870
+ <a class="line-number" data-cell="forward_and_backward" data-line="127" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 127, true);">127</a>
3871
+ <a class="line-number" data-cell="forward_and_backward" data-line="128" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 128, true);">128</a>
3872
+ <a class="line-number" data-cell="forward_and_backward" data-line="129" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 129, true);">129</a>
3873
+ <a class="line-number" data-cell="forward_and_backward" data-line="130" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 130, true);">130</a>
3874
+ <a class="line-number" data-cell="forward_and_backward" data-line="131" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 131, true);">131</a>
3875
+ <a class="line-number" data-cell="forward_and_backward" data-line="132" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 132, true);">132</a>
3876
+ <a class="line-number" data-cell="forward_and_backward" data-line="133" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 133, true);">133</a>
3877
+ <a class="line-number" data-cell="forward_and_backward" data-line="134" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 134, true);">134</a>
3878
+ <a class="line-number" data-cell="forward_and_backward" data-line="135" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 135, true);">135</a>
3879
+ <a class="line-number" data-cell="forward_and_backward" data-line="136" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 136, true);">136</a>
3880
+ <a class="line-number" data-cell="forward_and_backward" data-line="137" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 137, true);">137</a>
3881
+ <a class="line-number" data-cell="forward_and_backward" data-line="138" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 138, true);">138</a>
3882
+ <a class="line-number" data-cell="forward_and_backward" data-line="139" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 139, true);">139</a>
3883
+ <a class="line-number" data-cell="forward_and_backward" data-line="140" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 140, true);">140</a>
3884
+ <a class="line-number" data-cell="forward_and_backward" data-line="141" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 141, true);">141</a>
3885
+ <a class="line-number" data-cell="forward_and_backward" data-line="142" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 142, true);">142</a>
3886
+ <a class="line-number" data-cell="forward_and_backward" data-line="143" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 143, true);">143</a>
3887
+ <a class="line-number" data-cell="forward_and_backward" data-line="144" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 144, true);">144</a>
3888
+ <a class="line-number" data-cell="forward_and_backward" data-line="145" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 145, true);">145</a>
3889
+ <a class="line-number" data-cell="forward_and_backward" data-line="146" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 146, true);">146</a>
3890
+ <a class="line-number" data-cell="forward_and_backward" data-line="147" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 147, true);">147</a>
3891
+ <a class="line-number" data-cell="forward_and_backward" data-line="148" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 148, true);">148</a>
3892
+ <a class="line-number" data-cell="forward_and_backward" data-line="149" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 149, true);">149</a>
3893
+ <a class="line-number" data-cell="forward_and_backward" data-line="150" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 150, true);">150</a>
3894
+ <a class="line-number" data-cell="forward_and_backward" data-line="151" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 151, true);">151</a>
3895
+ <a class="line-number" data-cell="forward_and_backward" data-line="152" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 152, true);">152</a>
3896
+ <a class="line-number" data-cell="forward_and_backward" data-line="153" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 153, true);">153</a>
3897
+ <a class="line-number" data-cell="forward_and_backward" data-line="154" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 154, true);">154</a>
3898
+ <a class="line-number" data-cell="forward_and_backward" data-line="155" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 155, true);">155</a>
3899
+ <a class="line-number" data-cell="forward_and_backward" data-line="156" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 156, true);">156</a>
3900
+ <a class="line-number" data-cell="forward_and_backward" data-line="157" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 157, true);">157</a>
3901
+ <a class="line-number" data-cell="forward_and_backward" data-line="158" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 158, true);">158</a>
3902
+ <a class="line-number" data-cell="forward_and_backward" data-line="159" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 159, true);">159</a>
3903
+ <a class="line-number" data-cell="forward_and_backward" data-line="160" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 160, true);">160</a>
3904
+ <a class="line-number" data-cell="forward_and_backward" data-line="161" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 161, true);">161</a>
3905
+ <a class="line-number" data-cell="forward_and_backward" data-line="162" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 162, true);">162</a>
3906
+ <a class="line-number" data-cell="forward_and_backward" data-line="163" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 163, true);">163</a>
3907
+ <a class="line-number" data-cell="forward_and_backward" data-line="164" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 164, true);">164</a>
3908
+ <a class="line-number" data-cell="forward_and_backward" data-line="165" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 165, true);">165</a>
3909
+ <a class="line-number" data-cell="forward_and_backward" data-line="166" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 166, true);">166</a>
3910
+ <a class="line-number" data-cell="forward_and_backward" data-line="167" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 167, true);">167</a>
3911
+ <a class="line-number" data-cell="forward_and_backward" data-line="168" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 168, true);">168</a>
3912
+ <a class="line-number" data-cell="forward_and_backward" data-line="169" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 169, true);">169</a>
3913
+ <a class="line-number" data-cell="forward_and_backward" data-line="170" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 170, true);">170</a>
3914
+ <a class="line-number" data-cell="forward_and_backward" data-line="171" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 171, true);">171</a>
3915
+ <a class="line-number" data-cell="forward_and_backward" data-line="172" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 172, true);">172</a>
3916
+ <a class="line-number" data-cell="forward_and_backward" data-line="173" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 173, true);">173</a>
3917
+ <a class="line-number" data-cell="forward_and_backward" data-line="174" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 174, true);">174</a>
3918
+ <a class="line-number" data-cell="forward_and_backward" data-line="175" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 175, true);">175</a>
3919
+ <a class="line-number" data-cell="forward_and_backward" data-line="176" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 176, true);">176</a>
3920
+ <a class="line-number" data-cell="forward_and_backward" data-line="177" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 177, true);">177</a>
3921
+ <a class="line-number" data-cell="forward_and_backward" data-line="178" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 178, true);">178</a>
3922
+ <a class="line-number" data-cell="forward_and_backward" data-line="179" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 179, true);">179</a>
3923
+ <a class="line-number" data-cell="forward_and_backward" data-line="180" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 180, true);">180</a>
3924
+ <a class="line-number" data-cell="forward_and_backward" data-line="181" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 181, true);">181</a>
3925
+ <a class="line-number" data-cell="forward_and_backward" data-line="182" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 182, true);">182</a>
3926
+ <a class="line-number" data-cell="forward_and_backward" data-line="183" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 183, true);">183</a>
3927
+ <a class="line-number" data-cell="forward_and_backward" data-line="184" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 184, true);">184</a>
3928
+ <a class="line-number" data-cell="forward_and_backward" data-line="185" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 185, true);">185</a>
3929
+ <a class="line-number" data-cell="forward_and_backward" data-line="186" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 186, true);">186</a>
3930
+ <a class="line-number" data-cell="forward_and_backward" data-line="187" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 187, true);">187</a>
3931
+ <a class="line-number" data-cell="forward_and_backward" data-line="188" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 188, true);">188</a>
3932
+ <a class="line-number" data-cell="forward_and_backward" data-line="189" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 189, true);">189</a>
3933
+ <a class="line-number" data-cell="forward_and_backward" data-line="190" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 190, true);">190</a>
3934
+ <a class="line-number" data-cell="forward_and_backward" data-line="191" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 191, true);">191</a>
3935
+ <a class="line-number" data-cell="forward_and_backward" data-line="192" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 192, true);">192</a>
3936
+ <a class="line-number" data-cell="forward_and_backward" data-line="193" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 193, true);">193</a>
3937
+ <a class="line-number" data-cell="forward_and_backward" data-line="194" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 194, true);">194</a>
3938
+ <a class="line-number" data-cell="forward_and_backward" data-line="195" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 195, true);">195</a>
3939
+ <a class="line-number" data-cell="forward_and_backward" data-line="196" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 196, true);">196</a>
3940
  </div>
3941
  <div class="code-wrap">
3942
  <div class="highlight"><pre><span></span><span class="c1"># /// script</span>
 
3963
  <span class="kn">import</span><span class="w"> </span><span class="nn">logging</span>
3964
  <span class="kn">from</span><span class="w"> </span><span class="nn">transformers.models.gpt_oss.modeling_gpt_oss</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssRMSNorm</span>
3965
 
3966
+ <span class="c1"># remove liger kernel for testing </span>
3967
  <span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssRMSNorm</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
3968
 
3969
  <span class="c1"># set to debug logging</span>
 
4004
  <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">PreTrainedTokenizerFast</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_id</span><span class="p">)</span>
4005
  <span class="n">quantization_config</span> <span class="o">=</span> <span class="n">Mxfp4Config</span><span class="p">(</span><span class="n">dequantize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
4006
 
 
 
4007
  <span class="n">model</span> <span class="o">=</span> <span class="n">GptOssForCausalLM</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span>
4008
  <span class="n">model_id</span><span class="p">,</span>
4009
  <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;bfloat16&quot;</span><span class="p">,</span>
 
4024
  <span class="n">reasoning_effort</span><span class="o">=</span><span class="s2">&quot;low&quot;</span><span class="p">,</span>
4025
  <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span>
4026
 
4027
+ <span class="n">max_tokens</span> <span class="o">=</span> <span class="mi">128</span> <span class="c1"># Reduced to help with memory usage</span>
4028
+
4029
+ <span class="c1"># Clear memory before backward pass</span>
4030
+ <span class="n">reset_peak_memory_stats</span><span class="p">()</span>
4031
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Pre-generation memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4032
 
4033
+ <span class="c1"># forward and backward pass</span>
4034
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">set_grad_enabled</span><span class="p">(</span><span class="kc">True</span><span class="p">):</span>
4035
  <span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
4036
  <span class="n">generated</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span>
4037
  <span class="o">**</span><span class="n">inputs</span><span class="p">,</span>
 
4040
  <span class="n">temperature</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
4041
  <span class="p">)</span>
4042
  <span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
4043
+ <span class="nb">print</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">generated</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">skip_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
4044
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Generation took </span><span class="si">{</span><span class="n">end_time</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">start_time</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> seconds&quot;</span><span class="p">)</span>
4045
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Post-generation memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4046
+
4047
+ <span class="c1"># Use gradient checkpointing to reduce memory usage</span>
4048
+ <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s1">&#39;gradient_checkpointing_enable&#39;</span><span class="p">):</span>
4049
+ <span class="n">model</span><span class="o">.</span><span class="n">gradient_checkpointing_enable</span><span class="p">()</span>
4050
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Enabled gradient checkpointing&quot;</span><span class="p">)</span>
4051
+
4052
+ <span class="c1"># Reduce sequence length if needed for memory</span>
4053
+ <span class="n">max_seq_len</span> <span class="o">=</span> <span class="mi">512</span> <span class="c1"># Limit sequence length for backward pass</span>
4054
+ <span class="k">if</span> <span class="n">generated</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">&gt;</span> <span class="n">max_seq_len</span><span class="p">:</span>
4055
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Truncating sequence from </span><span class="si">{</span><span class="n">generated</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="si">}</span><span class="s2"> to </span><span class="si">{</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s2"> tokens&quot;</span><span class="p">)</span>
4056
+ <span class="n">full_sequence</span> <span class="o">=</span> <span class="n">generated</span><span class="p">[:,</span> <span class="o">-</span><span class="n">max_seq_len</span><span class="p">:]</span>
4057
+ <span class="k">else</span><span class="p">:</span>
4058
+ <span class="n">full_sequence</span> <span class="o">=</span> <span class="n">generated</span>
4059
+
4060
+ <span class="c1"># Get model outputs for the full sequence</span>
4061
+ <span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="c1"># Enable dropout and other training behaviors</span>
4062
+
4063
+ <span class="k">try</span><span class="p">:</span>
4064
+ <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span>
4065
+ <span class="n">input_ids</span><span class="o">=</span><span class="n">full_sequence</span><span class="p">,</span>
4066
+ <span class="n">labels</span><span class="o">=</span><span class="n">full_sequence</span><span class="p">,</span> <span class="c1"># This will compute loss internally</span>
4067
+ <span class="n">return_dict</span><span class="o">=</span><span class="kc">True</span>
4068
+ <span class="p">)</span>
4069
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Post-forward memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4070
+
4071
+ <span class="c1"># If model doesn&#39;t compute loss, compute it manually</span>
4072
+ <span class="k">if</span> <span class="n">outputs</span><span class="o">.</span><span class="n">loss</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
4073
+ <span class="n">shift_logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="o">.</span><span class="n">logits</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
4074
+ <span class="n">shift_labels</span> <span class="o">=</span> <span class="n">full_sequence</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
4075
+
4076
+ <span class="c1"># Use CrossEntropyLoss with ignore_index for padding tokens</span>
4077
+ <span class="n">loss_fct</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">ignore_index</span><span class="o">=</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token_id</span> <span class="k">if</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="o">-</span><span class="mi">100</span><span class="p">)</span>
4078
+ <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fct</span><span class="p">(</span>
4079
+ <span class="n">shift_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">shift_logits</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)),</span>
4080
+ <span class="n">shift_labels</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
4081
+ <span class="p">)</span>
4082
+ <span class="k">else</span><span class="p">:</span>
4083
+ <span class="n">loss</span> <span class="o">=</span> <span class="n">outputs</span><span class="o">.</span><span class="n">loss</span>
4084
+
4085
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Loss: </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4086
+
4087
+ <span class="c1"># Clear intermediate tensors to save memory</span>
4088
+ <span class="k">del</span> <span class="n">outputs</span>
4089
+ <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
4090
+
4091
+ <span class="c1"># Perform backward pass with memory management</span>
4092
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Running backward pass...&quot;</span><span class="p">)</span>
4093
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Pre-backward memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4094
+
4095
+ <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
4096
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Post-backward memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4097
+
4098
+ <span class="k">except</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">OutOfMemoryError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
4099
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;OOM during forward/backward pass: </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4100
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Try reducing max_tokens or max_seq_len&quot;</span><span class="p">)</span>
4101
+ <span class="k">raise</span>
4102
+
4103
+ <span class="c1"># Calculate gradient statistics and print sample gradients</span>
4104
+ <span class="n">total_norm</span> <span class="o">=</span> <span class="mf">0.0</span>
4105
+ <span class="n">param_count</span> <span class="o">=</span> <span class="mi">0</span>
4106
+ <span class="n">grad_samples</span> <span class="o">=</span> <span class="p">{}</span>
4107
+
4108
+ <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">():</span>
4109
+ <span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
4110
+ <span class="n">param_count</span> <span class="o">+=</span> <span class="mi">1</span>
4111
+ <span class="n">grad_norm</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
4112
+ <span class="n">total_norm</span> <span class="o">+=</span> <span class="n">grad_norm</span> <span class="o">**</span> <span class="mi">2</span>
4113
+
4114
+ <span class="c1"># Collect gradient statistics for key layers</span>
4115
+ <span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">key</span> <span class="ow">in</span> <span class="n">name</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">&#39;embed&#39;</span><span class="p">,</span> <span class="s1">&#39;lm_head&#39;</span><span class="p">,</span> <span class="s1">&#39;mlp.up&#39;</span><span class="p">,</span> <span class="s1">&#39;mlp.down&#39;</span><span class="p">,</span> <span class="s1">&#39;self_attn.q_proj&#39;</span><span class="p">,</span> <span class="s1">&#39;norm&#39;</span><span class="p">]):</span>
4116
+ <span class="n">grad_samples</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span>
4117
+ <span class="s1">&#39;norm&#39;</span><span class="p">:</span> <span class="n">grad_norm</span><span class="p">,</span>
4118
+ <span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
4119
+ <span class="s1">&#39;std&#39;</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
4120
+ <span class="s1">&#39;max&#39;</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
4121
+ <span class="s1">&#39;min&#39;</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
4122
+ <span class="p">}</span>
4123
+
4124
+ <span class="n">total_norm</span> <span class="o">=</span> <span class="n">total_norm</span> <span class="o">**</span> <span class="mf">0.5</span>
4125
+
4126
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Gradient norm: </span><span class="si">{</span><span class="n">total_norm</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4127
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Parameters with gradients: </span><span class="si">{</span><span class="n">param_count</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4128
+
4129
+ <span class="c1"># Print sample gradients from important layers</span>
4130
+ <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;</span><span class="se">\n</span><span class="s2">Sample gradient statistics:&quot;</span><span class="p">)</span>
4131
+ <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">stats</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">grad_samples</span><span class="o">.</span><span class="n">items</span><span class="p">())[:</span><span class="mi">10</span><span class="p">]):</span>
4132
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot; </span><span class="si">{</span><span class="n">name</span><span class="p">[:</span><span class="mi">60</span><span class="p">]</span><span class="si">:</span><span class="s2">&lt;60</span><span class="si">}</span><span class="s2"> | norm: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;norm&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.4e</span><span class="si">}</span><span class="s2"> | mean: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;mean&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.4e</span><span class="si">}</span><span class="s2"> | std: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">&#39;std&#39;</span><span class="p">]</span><span class="si">:</span><span class="s2">.4e</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
4133
+
4134
+ <span class="c1"># Optional: zero gradients for next iteration</span>
4135
+ <span class="n">model</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
4136
+ <span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="c1"># Switch back to eval mode</span>
4137
  </pre></div>
4138
 
4139
+ <div class="code-line-highlight" id="line-highlight-forward_and_backward"></div>
4140
  </div>
4141
  </div>
4142
  </div>
4143
+ <div id="output-forward_and_backward" class="cell-output">
4144
+ <div class="cell-stderr">Downloading cpython-3.13.7-linux-x86_64-gnu (download) (32.0MiB)
4145
+ Downloading cpython-3.13.7-linux-x86_64-gnu (download)
 
4146
  Updating https://github.com/huggingface/transformers.git (HEAD)
4147
+ Updated https://github.com/huggingface/transformers.git (449533af73874470e914a203391635e04ac2ffc8)
4148
+ × No solution found when resolving script dependencies:
4149
+ ╰─▶ Because only transformers==4.57.0.dev0 is available and
4150
+ transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1,
4151
+ we can conclude that all versions of transformers depend on
4152
+ huggingface-hub==1.0.0rc1.
4153
+ And because kernels==0.10.0 depends on huggingface-hub&gt;=0.26.0,&lt;1.0,
4154
+ we can conclude that kernels==0.10.0 and all versions of transformers
4155
+ are incompatible.
4156
+ And because you require kernels==0.10.0 and transformers, we can
4157
+ conclude that your requirements are unsatisfiable.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4158
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4159
  </div>
4160
  </div>
 
 
 
4161
  </div>
4162
 
4163
  </body>
moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "implementation": "binned_results",
3
+ "config": {
4
+ "warmup": 10,
5
+ "iters": 50,
6
+ "device": "cuda",
7
+ "dtype": "torch.float32",
8
+ "tokens": 100,
9
+ "vary_inputs": true
10
+ },
11
+ "stats": {
12
+ "avg_ms": 36.06324691992995,
13
+ "min_ms": 33.29206800026441,
14
+ "max_ms": 38.40615900026023,
15
+ "std_ms": 1.258567678508065,
16
+ "p50_ms": 36.21510599987232,
17
+ "p95_ms": 37.524451049966956,
18
+ "p99_ms": 38.03603995002959,
19
+ "num_iters": 50,
20
+ "tokens_per_s": 2772.906172925215,
21
+ "throughput_variance": 98.28636435515342
22
+ },
23
+ "output_sum": 3.97190523147583
24
+ }
moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "implementation": "gptoss_results",
3
+ "config": {
4
+ "warmup": 10,
5
+ "iters": 50,
6
+ "device": "cuda",
7
+ "dtype": "torch.float32",
8
+ "tokens": 100,
9
+ "vary_inputs": true
10
+ },
11
+ "stats": {
12
+ "avg_ms": 45.286630379978305,
13
+ "min_ms": 38.91367899996112,
14
+ "max_ms": 49.84392799997295,
15
+ "std_ms": 3.2326168009526866,
16
+ "p50_ms": 45.42240999990099,
17
+ "p95_ms": 49.729684149951936,
18
+ "p99_ms": 49.82545450991893,
19
+ "num_iters": 50,
20
+ "tokens_per_s": 2208.1572234663554,
21
+ "throughput_variance": 161.27578702324564
22
+ },
23
+ "output_sum": 11.53223705291748
24
+ }
moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "implementation": "gptoss_training_results",
3
+ "config": {
4
+ "warmup": 10,
5
+ "iters": 50,
6
+ "device": "cuda",
7
+ "dtype": "torch.float32",
8
+ "tokens": 100,
9
+ "vary_inputs": true
10
+ },
11
+ "stats": {
12
+ "avg_ms": 46.01034353989235,
13
+ "min_ms": 39.20698799993261,
14
+ "max_ms": 51.09754699969926,
15
+ "std_ms": 3.2594474712819497,
16
+ "p50_ms": 46.132551999562565,
17
+ "p95_ms": 50.721096600273086,
18
+ "p99_ms": 51.0080171399477,
19
+ "num_iters": 50,
20
+ "tokens_per_s": 2173.4243282338675,
21
+ "throughput_variance": 158.68467070353637
22
+ },
23
+ "output_sum": 11.53223705291748
24
+ }
moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "implementation": "yamoe_results",
3
+ "config": {
4
+ "warmup": 10,
5
+ "iters": 50,
6
+ "device": "cuda",
7
+ "dtype": "torch.float32",
8
+ "tokens": 100,
9
+ "vary_inputs": true
10
+ },
11
+ "stats": {
12
+ "avg_ms": 4.2510544400101935,
13
+ "min_ms": 4.144352999901457,
14
+ "max_ms": 4.320155999266717,
15
+ "std_ms": 0.02873328656403644,
16
+ "p50_ms": 4.2539659998510615,
17
+ "p95_ms": 4.2857709999225335,
18
+ "p99_ms": 4.306132199617423,
19
+ "num_iters": 50,
20
+ "tokens_per_s": 23523.575482547854,
21
+ "throughput_variance": 160.28680309512873
22
+ },
23
+ "output_sum": 3.97190523147583
24
+ }
moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc CHANGED
Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc and b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc differ
 
moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc CHANGED
Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc and b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc differ
 
moe_benchmarks/megablocks_yamoe/cells/binned_run.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "torch",
4
+ # "numpy",
5
+ # ]
6
+ # ///
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
12
+ from config import (
13
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
14
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
15
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
16
+ )
17
+ from pathlib import Path
18
+ import os
19
+
20
+ # Discover the upstream artifact directory from env
21
+ data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
22
+
23
+ router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
24
+ router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
25
+ gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
26
+ gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
27
+ down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
28
+ down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
29
+
30
+ print("Loaded shared weights from artifacts")
31
+ print(f"Router weight sum: {router_weight.sum().item():.6f}")
32
+ print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
33
+ print(f"Down sum: {down_proj.sum().item():.6f}")
34
+
35
+ def binned_gather(x, indices, bins, expert_capacity, top_k):
36
+ E, H = bins.shape[0], x.shape[1]
37
+ out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
38
+ for e in range(E):
39
+ start = 0 if e == 0 else bins[e - 1]
40
+ end = bins[e]
41
+ n = min(end - start, expert_capacity)
42
+ for i in range(n):
43
+ flat_pos = indices[start + i]
44
+ tok = flat_pos // top_k
45
+ out[e, i] = x[tok]
46
+ return out
47
+
48
+ def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
49
+ E, C, H = x.shape
50
+ N = indices.shape[0] // top_k
51
+ out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
52
+ for e in range(E):
53
+ start = 0 if e == 0 else bins[e - 1]
54
+ end = bins[e]
55
+ n = end - start
56
+ if n == 0:
57
+ continue
58
+ take = min(n, expert_capacity)
59
+ for i in range(take):
60
+ flat_pos = indices[start + i]
61
+ tok = flat_pos // top_k
62
+ slot = flat_pos % top_k
63
+ scale = weights[flat_pos] if weights is not None else 1.0
64
+ out[tok, slot] = x[e, i] * scale
65
+ return out.sum(dim=1)
66
+
67
+ def sort_tokens_by_expert(router_indices, num_experts):
68
+ flat_indices = router_indices.flatten()
69
+ sorted_values, sorted_indices = torch.sort(flat_indices)
70
+ tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
71
+ bins = torch.cumsum(tokens_per_expert, dim=0)
72
+ return sorted_indices, sorted_values, bins, tokens_per_expert
73
+
74
+ def binned_experts_ref(
75
+ hidden_states,
76
+ router_indices,
77
+ routing_weights,
78
+ gate_up_proj,
79
+ gate_up_proj_bias,
80
+ down_proj,
81
+ down_proj_bias,
82
+ expert_capacity,
83
+ ):
84
+ B, S, H = hidden_states.shape
85
+ E, K = routing_weights.shape[1], router_indices.shape[1]
86
+
87
+ indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
88
+ x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
89
+
90
+ gate_up = torch.bmm(x, gate_up_proj)
91
+ gate_up += gate_up_proj_bias[..., None, :]
92
+
93
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
94
+
95
+ # clamp to limit
96
+ limit = 7.0
97
+ gate = gate.clamp(min=None, max=limit)
98
+ up = up.clamp(min=-limit, max=limit)
99
+
100
+ glu = gate * torch.sigmoid(gate * 1.702)
101
+ x = (up + 1) * glu
102
+ x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
103
+
104
+ # build routing weights aligned to (token, slot)
105
+ flat_dense = routing_weights.view(-1, E)
106
+ flat_router = router_indices.view(-1, K)
107
+ selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
108
+
109
+ # scatter back
110
+ y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
111
+
112
+ return y.view(B, S, H)
113
+
114
+ class BinnedRouter(nn.Module):
115
+ def __init__(self, router_weight, router_bias):
116
+ super().__init__()
117
+ self.top_k = TOP_K
118
+ self.num_experts = NUM_EXPERTS
119
+ self.hidden_dim = HIDDEN_SIZE
120
+ self.weight = nn.Parameter(router_weight.clone())
121
+ self.bias = nn.Parameter(router_bias.clone())
122
+
123
+ def forward(self, hidden_states):
124
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
125
+ router_logits = F.linear(hidden_states, self.weight, self.bias)
126
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
127
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
128
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
129
+ return router_scores, router_indices
130
+
131
+ def ceil_div(a, b):
132
+ return (a + b - 1) // b
133
+
134
+ class BinnedMoEMLP(nn.Module):
135
+ def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
136
+ super().__init__()
137
+ self.router = BinnedRouter(router_weight, router_bias)
138
+ self.num_experts = NUM_EXPERTS
139
+ self.hidden_size = HIDDEN_SIZE
140
+ self.top_k = TOP_K
141
+
142
+ # Expert weights - use the loaded weights
143
+ self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
144
+ self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
145
+ self.down_proj = nn.Parameter(down_proj.clone())
146
+ self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
147
+
148
+ def forward(self, hidden_states):
149
+ router_scores, router_indices = self.router(hidden_states)
150
+ batch_size = hidden_states.shape[0]
151
+ expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
152
+
153
+ output = binned_experts_ref(
154
+ hidden_states,
155
+ router_indices,
156
+ router_scores,
157
+ self.gate_up_proj,
158
+ self.gate_up_proj_bias,
159
+ self.down_proj,
160
+ self.down_proj_bias,
161
+ expert_capacity,
162
+ )
163
+
164
+ return output, router_scores
165
+
166
+ # Run the model
167
+ set_seed(GENERAL_SEED)
168
+
169
+ device = torch.device(DEVICE)
170
+ dtype = to_dtype(DTYPE)
171
+
172
+ print("\n=== Binned Implementation ===")
173
+ # Initialize model with loaded weights
174
+ model = BinnedMoEMLP(
175
+ router_weight.to(device),
176
+ router_bias.to(device),
177
+ gate_up_proj.to(device),
178
+ gate_up_proj_bias.to(device),
179
+ down_proj.to(device),
180
+ down_proj_bias.to(device)
181
+ ).to(device=device)
182
+
183
+ print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
184
+ print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
185
+ print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
186
+
187
+ # Generate the same input as Yamoe
188
+ set_seed(INPUT_SEED)
189
+ x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
190
+
191
+ # Benchmark the model with varied inputs to prevent caching artifacts
192
+ tokens = BATCH_SIZE * SEQ_LEN
193
+ with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
194
+ output, stats = bench(model, x)
195
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "torch",
4
+ # "numpy",
5
+ # ]
6
+ # ///
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
12
+ from config import (
13
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
14
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
15
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
16
+ )
17
+ from pathlib import Path
18
+ import os
19
+
20
+ # Discover the upstream artifact directory from env
21
+ data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
22
+
23
+ router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
24
+ router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
25
+ gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
26
+ gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
27
+ down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
28
+ down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
29
+
30
+ print("Loaded shared weights from artifacts")
31
+ print(f"Router weight sum: {router_weight.sum().item():.6f}")
32
+ print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
33
+ print(f"Down sum: {down_proj.sum().item():.6f}")
34
+
35
+ class GptOssRouter(nn.Module):
36
+ def __init__(self, router_weight, router_bias):
37
+ super().__init__()
38
+ self.top_k = TOP_K
39
+ self.num_experts = NUM_EXPERTS
40
+ self.hidden_dim = HIDDEN_SIZE
41
+ self.weight = nn.Parameter(router_weight.clone())
42
+ self.bias = nn.Parameter(router_bias.clone())
43
+
44
+ def forward(self, hidden_states):
45
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
46
+ router_logits = F.linear(hidden_states, self.weight, self.bias)
47
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
48
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
49
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
50
+ return router_scores, router_indices
51
+
52
+ class GptOssExperts(nn.Module):
53
+ def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
54
+ super().__init__()
55
+ self.num_experts = NUM_EXPERTS
56
+ self.hidden_size = HIDDEN_SIZE
57
+ self.expert_dim = self.hidden_size
58
+ self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
59
+ self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
60
+ self.down_proj = nn.Parameter(down_proj.clone())
61
+ self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
62
+ self.alpha = 1.702
63
+ self.limit = 7.0
64
+
65
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
66
+ batch_size = hidden_states.shape[0]
67
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
68
+ num_experts = routing_weights.shape[1]
69
+
70
+ if hidden_states.device.type == "cpu" or self.training:
71
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
72
+ with torch.no_grad():
73
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
74
+ expert_mask = expert_mask.permute(2, 1, 0)
75
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
76
+
77
+ for expert_idx in expert_hit[:]:
78
+ expert_idx = expert_idx[0]
79
+ with torch.no_grad():
80
+ _, token_idx = torch.where(expert_mask[expert_idx])
81
+ current_state = hidden_states[token_idx]
82
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
83
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
84
+ gate = gate.clamp(min=None, max=self.limit)
85
+ up = up.clamp(min=-self.limit, max=self.limit)
86
+ glu = gate * torch.sigmoid(gate * self.alpha)
87
+ gated_output = (up + 1) * glu
88
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
89
+ weighted_output = out * routing_weights[token_idx, expert_idx, None]
90
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
91
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
92
+ else:
93
+ hidden_states = hidden_states.repeat(num_experts, 1)
94
+ hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
95
+ gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
96
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
97
+ gate = gate.clamp(min=None, max=self.limit)
98
+ up = up.clamp(min=-self.limit, max=self.limit)
99
+ glu = gate * torch.sigmoid(gate * self.alpha)
100
+ next_states = torch.bmm(((up + 1) * glu), self.down_proj)
101
+ next_states = next_states + self.down_proj_bias[..., None, :]
102
+ next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
103
+ next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
104
+ next_states = next_states.sum(dim=0)
105
+ return next_states
106
+
107
+ class GptOssMoEMLP(nn.Module):
108
+ def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
109
+ super().__init__()
110
+ self.router = GptOssRouter(router_weight, router_bias)
111
+ self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
112
+
113
+ def forward(self, hidden_states):
114
+ router_scores, router_indices = self.router(hidden_states)
115
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
116
+ return routed_out, router_scores
117
+
118
+ # Run the model
119
+ set_seed(GENERAL_SEED)
120
+
121
+ device = torch.device(DEVICE)
122
+ dtype = to_dtype(DTYPE)
123
+
124
+ print("\n=== GPT-OSS Implementation ===")
125
+ # Initialize model with loaded weights
126
+ model = GptOssMoEMLP(
127
+ router_weight.to(device),
128
+ router_bias.to(device),
129
+ gate_up_proj.to(device),
130
+ gate_up_proj_bias.to(device),
131
+ down_proj.to(device),
132
+ down_proj_bias.to(device)
133
+ ).to(device=device)
134
+
135
+ print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
136
+ print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
137
+ print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
138
+
139
+ # Generate the same input as other implementations
140
+ set_seed(INPUT_SEED)
141
+ x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
142
+
143
+ # Benchmark the model with varied inputs to prevent caching artifacts
144
+ tokens = BATCH_SIZE * SEQ_LEN
145
+ with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
146
+ output, stats = bench(model, x)
147
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "torch",
4
+ # "numpy",
5
+ # ]
6
+ # ///
7
+
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
12
+ from config import (
13
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
14
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
15
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
16
+ )
17
+ from pathlib import Path
18
+ import os
19
+
20
+ # Discover the upstream artifact directory from env
21
+ data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
22
+
23
+ router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
24
+ router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
25
+ gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
26
+ gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
27
+ down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
28
+ down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
29
+
30
+ print("Loaded shared weights from artifacts")
31
+ print(f"Router weight sum: {router_weight.sum().item():.6f}")
32
+ print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
33
+ print(f"Down sum: {down_proj.sum().item():.6f}")
34
+
35
+ class GptOssTrainingRouter(nn.Module):
36
+ def __init__(self, router_weight, router_bias):
37
+ super().__init__()
38
+ self.top_k = TOP_K
39
+ self.num_experts = NUM_EXPERTS
40
+ self.hidden_dim = HIDDEN_SIZE
41
+ self.weight = nn.Parameter(router_weight.clone())
42
+ self.bias = nn.Parameter(router_bias.clone())
43
+
44
+ def forward(self, hidden_states):
45
+ hidden_states = hidden_states.reshape(-1, self.hidden_dim)
46
+ router_logits = F.linear(hidden_states, self.weight, self.bias)
47
+ router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
48
+ router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
49
+ router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
50
+ return router_scores, router_indices
51
+
52
+ class GptOssTrainingExperts(nn.Module):
53
+ def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
54
+ super().__init__()
55
+ self.num_experts = NUM_EXPERTS
56
+ self.hidden_size = HIDDEN_SIZE
57
+ self.expert_dim = self.hidden_size
58
+ self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
59
+ self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
60
+ self.down_proj = nn.Parameter(down_proj.clone())
61
+ self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
62
+ self.alpha = 1.702
63
+ self.limit = 7.0
64
+
65
+ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
66
+ batch_size = hidden_states.shape[0]
67
+ hidden_states = hidden_states.reshape(-1, self.hidden_size)
68
+ num_experts = routing_weights.shape[1]
69
+
70
+ # Force training mode path (expert loop instead of batched)
71
+ next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
72
+ with torch.no_grad():
73
+ expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
74
+ expert_mask = expert_mask.permute(2, 1, 0)
75
+ expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
76
+
77
+ for expert_idx in expert_hit[:]:
78
+ expert_idx = expert_idx[0]
79
+ with torch.no_grad():
80
+ _, token_idx = torch.where(expert_mask[expert_idx])
81
+ current_state = hidden_states[token_idx]
82
+ gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
83
+ gate, up = gate_up[..., ::2], gate_up[..., 1::2]
84
+ gate = gate.clamp(min=None, max=self.limit)
85
+ up = up.clamp(min=-self.limit, max=self.limit)
86
+ glu = gate * torch.sigmoid(gate * self.alpha)
87
+ gated_output = (up + 1) * glu
88
+ out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
89
+ weighted_output = out * routing_weights[token_idx, expert_idx, None]
90
+ next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
91
+ next_states = next_states.view(batch_size, -1, self.hidden_size)
92
+ return next_states
93
+
94
+ class GptOssTrainingMoEMLP(nn.Module):
95
+ def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
96
+ super().__init__()
97
+ self.router = GptOssTrainingRouter(router_weight, router_bias)
98
+ self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
99
+
100
+ def forward(self, hidden_states):
101
+ router_scores, router_indices = self.router(hidden_states)
102
+ routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
103
+ return routed_out, router_scores
104
+
105
+ # Run the model
106
+ set_seed(GENERAL_SEED)
107
+
108
+ device = torch.device(DEVICE)
109
+ dtype = to_dtype(DTYPE)
110
+
111
+ print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
112
+ # Initialize model with loaded weights and force training mode
113
+ model = GptOssTrainingMoEMLP(
114
+ router_weight.to(device),
115
+ router_bias.to(device),
116
+ gate_up_proj.to(device),
117
+ gate_up_proj_bias.to(device),
118
+ down_proj.to(device),
119
+ down_proj_bias.to(device)
120
+ ).to(device=device)
121
+
122
+ # Set to training mode to force expert loop path
123
+ model.train()
124
+
125
+ print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
126
+ print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
127
+ print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
128
+ print(f"Model training mode: {model.training}")
129
+
130
+ # Generate the same input as other implementations
131
+ set_seed(INPUT_SEED)
132
+ x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
133
+
134
+ # Benchmark the model with varied inputs to prevent caching artifacts
135
+ tokens = BATCH_SIZE * SEQ_LEN
136
+ with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
137
+ output, stats = bench(model, x)
138
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "torch",
4
+ # "numpy",
5
+ # "kernels",
6
+ # ]
7
+ # ///
8
+
9
+ import torch
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from kernels import get_kernel, get_local_kernel
13
+ from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
14
+ from config import (
15
+ NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
16
+ BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
17
+ WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
18
+ )
19
+ from pathlib import Path
20
+ from collections import namedtuple
21
+ import os
22
+
23
+ # Discover the upstream artifact directory from env
24
+ data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
25
+
26
+ print(f"Loading weights from: {data_dir}")
27
+
28
+ router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
29
+ router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
30
+ gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
31
+ gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
32
+ down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
33
+ down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
34
+
35
+ print("Loaded shared weights from artifacts")
36
+ print(f"Router weight sum: {router_weight.sum().item():.6f}")
37
+ print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
38
+ print(f"Down sum: {down_proj.sum().item():.6f}")
39
+
40
+ def build_megablocks_model(device: torch.device):
41
+ # Download optimized kernels from the Hugging Face hub
42
+ megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
43
+ model = megablocks.layers.MegaBlocksMoeMLP()
44
+
45
+ # Create attribute container for expert weights
46
+ model.experts = namedtuple(
47
+ "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
48
+ )
49
+
50
+ # Use loaded router weights for consistency
51
+ model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
52
+ with torch.no_grad():
53
+ model.router.weight.copy_(router_weight)
54
+ model.router.bias.copy_(router_bias)
55
+
56
+ # Attach loaded expert weights to the experts container
57
+ e = model.experts
58
+ e.alpha = 1.702
59
+ e.capacity_factor = 32
60
+ e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
61
+ e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
62
+ e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
63
+ e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
64
+ e.hidden_size = HIDDEN_SIZE
65
+
66
+ # Log weight statistics for comparison
67
+ print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
68
+ print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
69
+ print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
70
+
71
+ return model
72
+
73
+ # Create a wrapper to match the interface of other implementations
74
+ class MegaBlocksMoEWrapper(nn.Module):
75
+ def __init__(self, megablocks_model):
76
+ super().__init__()
77
+ self.model = megablocks_model
78
+
79
+ def forward(self, hidden_states):
80
+ # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
81
+ output, dummy_routing_weights = self.model(hidden_states)
82
+ return output, dummy_routing_weights
83
+
84
+ # Run the model
85
+ set_seed(GENERAL_SEED)
86
+
87
+ device = torch.device(DEVICE)
88
+ dtype = to_dtype(DTYPE)
89
+
90
+ print("\n=== MegaBlocks Implementation ===")
91
+ # Build MegaBlocks model with loaded weights
92
+ megablocks_model = build_megablocks_model(device)
93
+ model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
94
+
95
+ # Generate the same input as other implementations
96
+ set_seed(INPUT_SEED)
97
+ x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
98
+
99
+ # Benchmark the model with varied inputs to prevent caching artifacts
100
+ tokens = BATCH_SIZE * SEQ_LEN
101
+ with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
102
+ output, stats = bench(model, x)
103
+ print(f"\nOutput sum: {output[0].sum().item():.6f}")
moe_benchmarks/megablocks_yamoe/cells/setup.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.12"
3
+ # dependencies = [
4
+ # "accelerate>=1.10.1",
5
+ # "torch>=2.7.0",
6
+ # "kernels==0.10.0",
7
+ # "transformers@https://github.com/huggingface/transformers.git",
8
+ # "ipdb>=0.13.13",
9
+ # "matplotlib>=3.7.2",
10
+ # "numpy>=1.24.3",
11
+ # ]
12
+ # ///
13
+
14
+ import torch
15
+ from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
16
+ import time
17
+ import torch.nn as nn
18
+ from kernels import register_kernel_mapping, Mode, LayerRepository
19
+ import sys
20
+ import torch.profiler
21
+ import gc
22
+ import logging
23
+
24
+ # set to debug logging
25
+ logging.basicConfig(level=logging.INFO)
26
+
27
+ def reset_peak_memory_stats():
28
+ """Clear CUDA cache and reset memory allocation counters."""
29
+ torch.cuda.empty_cache()
30
+ if torch.cuda.is_available():
31
+ torch.cuda.reset_peak_memory_stats()
32
+ gc.collect()
33
+
34
+ def get_memory_stats():
35
+ """Get current and peak CUDA memory usage."""
36
+ if not torch.cuda.is_available():
37
+ return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
38
+ return {
39
+ "allocated_gb": torch.cuda.memory_allocated() / 1e9,
40
+ "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
41
+ "reserved_gb": torch.cuda.memory_reserved() / 1e9,
42
+ }
43
+
44
+ def override_kernel_layer_name(cls_name: str, value) -> bool:
45
+ """Helper to dynamically override the kernel_layer_name in a model class."""
46
+ for mod in sys.modules.values():
47
+ if mod is None:
48
+ continue
49
+ obj = getattr(mod, cls_name, None)
50
+ if isinstance(obj, type) and issubclass(obj, nn.Module):
51
+ setattr(obj, "kernel_layer_name", value)
52
+ print(f"Overrode {cls_name}.kernel_layer_name to {value}")
53
+ return True
54
+ return False
55
+
56
+
57
+ # Init the model the normal way
58
+ model_id = "openai/gpt-oss-20b"
59
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
60
+ quantization_config = Mxfp4Config(dequantize=True)
61
+
62
+
63
+ from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode
64
+
65
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm
66
+
67
+ replace_kernel_forward_from_hub(GptOssMLP, "Yamoe")
68
+ replace_kernel_forward_from_hub(GptOssRMSNorm, None)
69
+ custom_mapping = {
70
+ "Yamoe": {
71
+ "cuda": {
72
+ Mode.INFERENCE: LayerRepository(
73
+ repo_id="drbh/yamoe",
74
+ layer_name="Yamoe",
75
+ revision="v0.3.0",
76
+ )
77
+ }
78
+ }
79
+ }
80
+ register_kernel_mapping(custom_mapping)
81
+
82
+
83
+ model = GptOssForCausalLM.from_pretrained(
84
+ model_id,
85
+ dtype="bfloat16",
86
+ device_map="auto",
87
+ use_kernels=True,
88
+ quantization_config=quantization_config,
89
+ ).eval()
90
+
91
+ messages = [
92
+ {"role": "system", "content": "What is Tensor Parallelism?"},
93
+ ]
94
+
95
+ inputs = tokenizer.apply_chat_template(
96
+ messages,
97
+ add_generation_prompt=True,
98
+ return_tensors="pt",
99
+ return_dict=True,
100
+ reasoning_effort="low",
101
+ ).to("cuda")
102
+
103
+ max_tokens = 256
104
+
105
+ with torch.inference_mode():
106
+ start_time = time.perf_counter()
107
+ generated = model.generate(
108
+ **inputs,
109
+ max_new_tokens=max_tokens,
110
+ do_sample=False,
111
+ temperature=None,
112
+ )
113
+ end_time = time.perf_counter()
114
+
115
+ print(tokenizer.decode(generated[0], skip_special_tokens=False))
116
+ print(f"Generation took {end_time - start_time:.2f} seconds")
moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html CHANGED
@@ -3710,61 +3710,288 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
3710
  <div class="system-info">
3711
  <div class="system-info-header">Generated on:</div>
3712
  <div class="system-info-content">
3713
- Linux x86_64 | Linux-6.11.0-1018-azure-x86_64-with-glibc2.39
3714
  </div>
3715
  </div>
3716
 
3717
  <div class="main-content">
3718
- <div class="cell cell-failed" id="cell-nv">
 
 
 
 
3719
  <div class="cell-header">
3720
  <span class="collapse-indicators">
3721
- <span onclick="toggleCode('nv')" style="cursor: pointer;">▼ code</span>
3722
- <span onclick="toggleOutput('nv')" style="cursor: pointer;">▼ output</span>
3723
- <span id="uv-indicator-nv" style="cursor: default; opacity: 0.3;">▶ uv-logs</span>
3724
  </span> |
3725
- Cell: nv | 0.07s | FAILED
3726
- | <button class="run-btn" onclick="runCell('nv')">▶ run</button>
3727
- <button class="copy-btn" onclick="copyCell('nv')">Copy</button>
3728
- <a href="cells/nv.py" target="_blank" class="raw-btn">Raw</a>
3729
  </div>
3730
- <div id="code-nv" class="cell-code" data-lines="3">
3731
  <div class="highlight-with-lines">
3732
- <div class="line-numbers" id="lines-nv">
3733
- <a class="line-number" data-cell="nv" data-line="1" href="#cell-nv" onclick="event.preventDefault(); selectCellLine('nv', 1, true);">1</a>
3734
- <a class="line-number" data-cell="nv" data-line="2" href="#cell-nv" onclick="event.preventDefault(); selectCellLine('nv', 2, true);">2</a>
3735
- <a class="line-number" data-cell="nv" data-line="3" href="#cell-nv" onclick="event.preventDefault(); selectCellLine('nv', 3, true);">3</a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3736
  </div>
3737
  <div class="code-wrap">
3738
- <div class="highlight"><pre><span></span><span class="kn">import</span><span class="w"> </span><span class="nn">subprocess</span>
3739
-
3740
- <span class="nb">print</span><span class="p">(</span><span class="n">subprocess</span><span class="o">.</span><span class="n">run</span><span class="p">([</span><span class="s2">&quot;nvidia-smi&quot;</span><span class="p">],</span> <span class="n">capture_output</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">text</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span><span class="o">.</span><span class="n">stdout</span><span class="p">)</span>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3741
  </pre></div>
3742
 
3743
- <div class="code-line-highlight" id="line-highlight-nv"></div>
3744
  </div>
3745
  </div>
3746
  </div>
3747
- <div id="output-nv" class="cell-output">
3748
- <div class="cell-stderr">Traceback (most recent call last):
3749
- File &quot;/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks_yamoe/.uvnote/cells/nv.py&quot;, line 3, in &lt;module&gt;
3750
- print(subprocess.run([&quot;nvidia-smi&quot;], capture_output=True, text=True).stdout)
3751
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
3752
- File &quot;/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/subprocess.py&quot;, line 548, in run
3753
- with Popen(*popenargs, **kwargs) as process:
3754
- ^^^^^^^^^^^^^^^^^^^^^^^^^^^
3755
- File &quot;/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/subprocess.py&quot;, line 1026, in __init__
3756
- self._execute_child(args, executable, preexec_fn, close_fds,
3757
- File &quot;/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/subprocess.py&quot;, line 1955, in _execute_child
3758
- raise child_exception_type(errno_num, err_msg, err_filename)
3759
- FileNotFoundError: [Errno 2] No such file or directory: &#x27;nvidia-smi&#x27;
 
 
3760
  </div>
3761
  </div>
3762
  </div>
3763
-
3764
- <h1>Comparison of Megablocks and Yamoe Kernels</h1>
3765
- <p>This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.</p>
3766
- <h2>Megablocks kernel</h2>
3767
- <h2>Yamoe Kernel</h2>
3768
  </div>
3769
 
3770
  </body>
 
3710
  <div class="system-info">
3711
  <div class="system-info-header">Generated on:</div>
3712
  <div class="system-info-content">
3713
+ Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
3714
  </div>
3715
  </div>
3716
 
3717
  <div class="main-content">
3718
+ <h1>Comparison of Megablocks and Yamoe Kernels</h1>
3719
+ <p>This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.</p>
3720
+ <h2>Megablocks kernel</h2>
3721
+ <h2>Yamoe Kernel</h2>
3722
+ <div class="cell cell-failed" id="cell-setup">
3723
  <div class="cell-header">
3724
  <span class="collapse-indicators">
3725
+ <span onclick="toggleCode('setup')" style="cursor: pointer;">▼ code</span>
3726
+ <span onclick="toggleOutput('setup')" style="cursor: pointer;">▼ output</span>
3727
+ <span id="uv-indicator-setup" style="cursor: default; opacity: 0.3;">▶ uv-logs</span>
3728
  </span> |
3729
+ Cell: setup | 19.20s | FAILED
3730
+ | <button class="run-btn" onclick="runCell('setup')">▶ run</button>
3731
+ <button class="copy-btn" onclick="copyCell('setup')">Copy</button>
3732
+ <a href="cells/setup.py" target="_blank" class="raw-btn">Raw</a>
3733
  </div>
3734
+ <div id="code-setup" class="cell-code" data-lines="116">
3735
  <div class="highlight-with-lines">
3736
+ <div class="line-numbers" id="lines-setup">
3737
+ <a class="line-number" data-cell="setup" data-line="1" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 1, true);">1</a>
3738
+ <a class="line-number" data-cell="setup" data-line="2" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 2, true);">2</a>
3739
+ <a class="line-number" data-cell="setup" data-line="3" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 3, true);">3</a>
3740
+ <a class="line-number" data-cell="setup" data-line="4" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 4, true);">4</a>
3741
+ <a class="line-number" data-cell="setup" data-line="5" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 5, true);">5</a>
3742
+ <a class="line-number" data-cell="setup" data-line="6" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 6, true);">6</a>
3743
+ <a class="line-number" data-cell="setup" data-line="7" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 7, true);">7</a>
3744
+ <a class="line-number" data-cell="setup" data-line="8" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 8, true);">8</a>
3745
+ <a class="line-number" data-cell="setup" data-line="9" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 9, true);">9</a>
3746
+ <a class="line-number" data-cell="setup" data-line="10" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 10, true);">10</a>
3747
+ <a class="line-number" data-cell="setup" data-line="11" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 11, true);">11</a>
3748
+ <a class="line-number" data-cell="setup" data-line="12" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 12, true);">12</a>
3749
+ <a class="line-number" data-cell="setup" data-line="13" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 13, true);">13</a>
3750
+ <a class="line-number" data-cell="setup" data-line="14" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 14, true);">14</a>
3751
+ <a class="line-number" data-cell="setup" data-line="15" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 15, true);">15</a>
3752
+ <a class="line-number" data-cell="setup" data-line="16" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 16, true);">16</a>
3753
+ <a class="line-number" data-cell="setup" data-line="17" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 17, true);">17</a>
3754
+ <a class="line-number" data-cell="setup" data-line="18" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 18, true);">18</a>
3755
+ <a class="line-number" data-cell="setup" data-line="19" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 19, true);">19</a>
3756
+ <a class="line-number" data-cell="setup" data-line="20" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 20, true);">20</a>
3757
+ <a class="line-number" data-cell="setup" data-line="21" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 21, true);">21</a>
3758
+ <a class="line-number" data-cell="setup" data-line="22" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 22, true);">22</a>
3759
+ <a class="line-number" data-cell="setup" data-line="23" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 23, true);">23</a>
3760
+ <a class="line-number" data-cell="setup" data-line="24" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 24, true);">24</a>
3761
+ <a class="line-number" data-cell="setup" data-line="25" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 25, true);">25</a>
3762
+ <a class="line-number" data-cell="setup" data-line="26" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 26, true);">26</a>
3763
+ <a class="line-number" data-cell="setup" data-line="27" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 27, true);">27</a>
3764
+ <a class="line-number" data-cell="setup" data-line="28" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 28, true);">28</a>
3765
+ <a class="line-number" data-cell="setup" data-line="29" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 29, true);">29</a>
3766
+ <a class="line-number" data-cell="setup" data-line="30" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 30, true);">30</a>
3767
+ <a class="line-number" data-cell="setup" data-line="31" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 31, true);">31</a>
3768
+ <a class="line-number" data-cell="setup" data-line="32" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 32, true);">32</a>
3769
+ <a class="line-number" data-cell="setup" data-line="33" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 33, true);">33</a>
3770
+ <a class="line-number" data-cell="setup" data-line="34" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 34, true);">34</a>
3771
+ <a class="line-number" data-cell="setup" data-line="35" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 35, true);">35</a>
3772
+ <a class="line-number" data-cell="setup" data-line="36" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 36, true);">36</a>
3773
+ <a class="line-number" data-cell="setup" data-line="37" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 37, true);">37</a>
3774
+ <a class="line-number" data-cell="setup" data-line="38" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 38, true);">38</a>
3775
+ <a class="line-number" data-cell="setup" data-line="39" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 39, true);">39</a>
3776
+ <a class="line-number" data-cell="setup" data-line="40" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 40, true);">40</a>
3777
+ <a class="line-number" data-cell="setup" data-line="41" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 41, true);">41</a>
3778
+ <a class="line-number" data-cell="setup" data-line="42" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 42, true);">42</a>
3779
+ <a class="line-number" data-cell="setup" data-line="43" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 43, true);">43</a>
3780
+ <a class="line-number" data-cell="setup" data-line="44" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 44, true);">44</a>
3781
+ <a class="line-number" data-cell="setup" data-line="45" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 45, true);">45</a>
3782
+ <a class="line-number" data-cell="setup" data-line="46" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 46, true);">46</a>
3783
+ <a class="line-number" data-cell="setup" data-line="47" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 47, true);">47</a>
3784
+ <a class="line-number" data-cell="setup" data-line="48" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 48, true);">48</a>
3785
+ <a class="line-number" data-cell="setup" data-line="49" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 49, true);">49</a>
3786
+ <a class="line-number" data-cell="setup" data-line="50" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 50, true);">50</a>
3787
+ <a class="line-number" data-cell="setup" data-line="51" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 51, true);">51</a>
3788
+ <a class="line-number" data-cell="setup" data-line="52" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 52, true);">52</a>
3789
+ <a class="line-number" data-cell="setup" data-line="53" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 53, true);">53</a>
3790
+ <a class="line-number" data-cell="setup" data-line="54" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 54, true);">54</a>
3791
+ <a class="line-number" data-cell="setup" data-line="55" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 55, true);">55</a>
3792
+ <a class="line-number" data-cell="setup" data-line="56" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 56, true);">56</a>
3793
+ <a class="line-number" data-cell="setup" data-line="57" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 57, true);">57</a>
3794
+ <a class="line-number" data-cell="setup" data-line="58" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 58, true);">58</a>
3795
+ <a class="line-number" data-cell="setup" data-line="59" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 59, true);">59</a>
3796
+ <a class="line-number" data-cell="setup" data-line="60" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 60, true);">60</a>
3797
+ <a class="line-number" data-cell="setup" data-line="61" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 61, true);">61</a>
3798
+ <a class="line-number" data-cell="setup" data-line="62" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 62, true);">62</a>
3799
+ <a class="line-number" data-cell="setup" data-line="63" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 63, true);">63</a>
3800
+ <a class="line-number" data-cell="setup" data-line="64" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 64, true);">64</a>
3801
+ <a class="line-number" data-cell="setup" data-line="65" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 65, true);">65</a>
3802
+ <a class="line-number" data-cell="setup" data-line="66" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 66, true);">66</a>
3803
+ <a class="line-number" data-cell="setup" data-line="67" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 67, true);">67</a>
3804
+ <a class="line-number" data-cell="setup" data-line="68" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 68, true);">68</a>
3805
+ <a class="line-number" data-cell="setup" data-line="69" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 69, true);">69</a>
3806
+ <a class="line-number" data-cell="setup" data-line="70" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 70, true);">70</a>
3807
+ <a class="line-number" data-cell="setup" data-line="71" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 71, true);">71</a>
3808
+ <a class="line-number" data-cell="setup" data-line="72" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 72, true);">72</a>
3809
+ <a class="line-number" data-cell="setup" data-line="73" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 73, true);">73</a>
3810
+ <a class="line-number" data-cell="setup" data-line="74" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 74, true);">74</a>
3811
+ <a class="line-number" data-cell="setup" data-line="75" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 75, true);">75</a>
3812
+ <a class="line-number" data-cell="setup" data-line="76" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 76, true);">76</a>
3813
+ <a class="line-number" data-cell="setup" data-line="77" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 77, true);">77</a>
3814
+ <a class="line-number" data-cell="setup" data-line="78" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 78, true);">78</a>
3815
+ <a class="line-number" data-cell="setup" data-line="79" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 79, true);">79</a>
3816
+ <a class="line-number" data-cell="setup" data-line="80" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 80, true);">80</a>
3817
+ <a class="line-number" data-cell="setup" data-line="81" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 81, true);">81</a>
3818
+ <a class="line-number" data-cell="setup" data-line="82" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 82, true);">82</a>
3819
+ <a class="line-number" data-cell="setup" data-line="83" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 83, true);">83</a>
3820
+ <a class="line-number" data-cell="setup" data-line="84" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 84, true);">84</a>
3821
+ <a class="line-number" data-cell="setup" data-line="85" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 85, true);">85</a>
3822
+ <a class="line-number" data-cell="setup" data-line="86" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 86, true);">86</a>
3823
+ <a class="line-number" data-cell="setup" data-line="87" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 87, true);">87</a>
3824
+ <a class="line-number" data-cell="setup" data-line="88" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 88, true);">88</a>
3825
+ <a class="line-number" data-cell="setup" data-line="89" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 89, true);">89</a>
3826
+ <a class="line-number" data-cell="setup" data-line="90" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 90, true);">90</a>
3827
+ <a class="line-number" data-cell="setup" data-line="91" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 91, true);">91</a>
3828
+ <a class="line-number" data-cell="setup" data-line="92" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 92, true);">92</a>
3829
+ <a class="line-number" data-cell="setup" data-line="93" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 93, true);">93</a>
3830
+ <a class="line-number" data-cell="setup" data-line="94" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 94, true);">94</a>
3831
+ <a class="line-number" data-cell="setup" data-line="95" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 95, true);">95</a>
3832
+ <a class="line-number" data-cell="setup" data-line="96" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 96, true);">96</a>
3833
+ <a class="line-number" data-cell="setup" data-line="97" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 97, true);">97</a>
3834
+ <a class="line-number" data-cell="setup" data-line="98" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 98, true);">98</a>
3835
+ <a class="line-number" data-cell="setup" data-line="99" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 99, true);">99</a>
3836
+ <a class="line-number" data-cell="setup" data-line="100" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 100, true);">100</a>
3837
+ <a class="line-number" data-cell="setup" data-line="101" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 101, true);">101</a>
3838
+ <a class="line-number" data-cell="setup" data-line="102" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 102, true);">102</a>
3839
+ <a class="line-number" data-cell="setup" data-line="103" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 103, true);">103</a>
3840
+ <a class="line-number" data-cell="setup" data-line="104" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 104, true);">104</a>
3841
+ <a class="line-number" data-cell="setup" data-line="105" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 105, true);">105</a>
3842
+ <a class="line-number" data-cell="setup" data-line="106" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 106, true);">106</a>
3843
+ <a class="line-number" data-cell="setup" data-line="107" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 107, true);">107</a>
3844
+ <a class="line-number" data-cell="setup" data-line="108" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 108, true);">108</a>
3845
+ <a class="line-number" data-cell="setup" data-line="109" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 109, true);">109</a>
3846
+ <a class="line-number" data-cell="setup" data-line="110" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 110, true);">110</a>
3847
+ <a class="line-number" data-cell="setup" data-line="111" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 111, true);">111</a>
3848
+ <a class="line-number" data-cell="setup" data-line="112" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 112, true);">112</a>
3849
+ <a class="line-number" data-cell="setup" data-line="113" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 113, true);">113</a>
3850
+ <a class="line-number" data-cell="setup" data-line="114" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 114, true);">114</a>
3851
+ <a class="line-number" data-cell="setup" data-line="115" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 115, true);">115</a>
3852
+ <a class="line-number" data-cell="setup" data-line="116" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 116, true);">116</a>
3853
  </div>
3854
  <div class="code-wrap">
3855
+ <div class="highlight"><pre><span></span><span class="c1"># /// script</span>
3856
+ <span class="c1"># requires-python = &quot;&gt;=3.12&quot;</span>
3857
+ <span class="c1"># dependencies = [</span>
3858
+ <span class="c1"># &quot;accelerate&gt;=1.10.1&quot;,</span>
3859
+ <span class="c1"># &quot;torch&gt;=2.7.0&quot;,</span>
3860
+ <span class="c1"># &quot;kernels==0.10.0&quot;,</span>
3861
+ <span class="c1"># &quot;transformers@https://github.com/huggingface/transformers.git&quot;,</span>
3862
+ <span class="c1"># &quot;ipdb&gt;=0.13.13&quot;,</span>
3863
+ <span class="c1"># &quot;matplotlib&gt;=3.7.2&quot;,</span>
3864
+ <span class="c1"># &quot;numpy&gt;=1.24.3&quot;,</span>
3865
+ <span class="c1"># ]</span>
3866
+ <span class="c1"># ///</span>
3867
+
3868
+ <span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
3869
+ <span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssForCausalLM</span><span class="p">,</span> <span class="n">PreTrainedTokenizerFast</span><span class="p">,</span> <span class="n">Mxfp4Config</span>
3870
+ <span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
3871
+ <span class="kn">import</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
3872
+ <span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">register_kernel_mapping</span><span class="p">,</span> <span class="n">Mode</span><span class="p">,</span> <span class="n">LayerRepository</span>
3873
+ <span class="kn">import</span><span class="w"> </span><span class="nn">sys</span>
3874
+ <span class="kn">import</span><span class="w"> </span><span class="nn">torch.profiler</span>
3875
+ <span class="kn">import</span><span class="w"> </span><span class="nn">gc</span>
3876
+ <span class="kn">import</span><span class="w"> </span><span class="nn">logging</span>
3877
+
3878
+ <span class="c1"># set to debug logging</span>
3879
+ <span class="n">logging</span><span class="o">.</span><span class="n">basicConfig</span><span class="p">(</span><span class="n">level</span><span class="o">=</span><span class="n">logging</span><span class="o">.</span><span class="n">INFO</span><span class="p">)</span>
3880
+
3881
+ <span class="k">def</span><span class="w"> </span><span class="nf">reset_peak_memory_stats</span><span class="p">():</span>
3882
+ <span class="w"> </span><span class="sd">&quot;&quot;&quot;Clear CUDA cache and reset memory allocation counters.&quot;&quot;&quot;</span>
3883
+ <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
3884
+ <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
3885
+ <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">reset_peak_memory_stats</span><span class="p">()</span>
3886
+ <span class="n">gc</span><span class="o">.</span><span class="n">collect</span><span class="p">()</span>
3887
+
3888
+ <span class="k">def</span><span class="w"> </span><span class="nf">get_memory_stats</span><span class="p">():</span>
3889
+ <span class="w"> </span><span class="sd">&quot;&quot;&quot;Get current and peak CUDA memory usage.&quot;&quot;&quot;</span>
3890
+ <span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
3891
+ <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;allocated_gb&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;peak_gb&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;reserved_gb&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">}</span>
3892
+ <span class="k">return</span> <span class="p">{</span>
3893
+ <span class="s2">&quot;allocated_gb&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_allocated</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span><span class="p">,</span>
3894
+ <span class="s2">&quot;peak_gb&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">max_memory_allocated</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span><span class="p">,</span>
3895
+ <span class="s2">&quot;reserved_gb&quot;</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_reserved</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span><span class="p">,</span>
3896
+ <span class="p">}</span>
3897
+
3898
+ <span class="k">def</span><span class="w"> </span><span class="nf">override_kernel_layer_name</span><span class="p">(</span><span class="n">cls_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
3899
+ <span class="w"> </span><span class="sd">&quot;&quot;&quot;Helper to dynamically override the kernel_layer_name in a model class.&quot;&quot;&quot;</span>
3900
+ <span class="k">for</span> <span class="n">mod</span> <span class="ow">in</span> <span class="n">sys</span><span class="o">.</span><span class="n">modules</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
3901
+ <span class="k">if</span> <span class="n">mod</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
3902
+ <span class="k">continue</span>
3903
+ <span class="n">obj</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">mod</span><span class="p">,</span> <span class="n">cls_name</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
3904
+ <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">type</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
3905
+ <span class="nb">setattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s2">&quot;kernel_layer_name&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
3906
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Overrode </span><span class="si">{</span><span class="n">cls_name</span><span class="si">}</span><span class="s2">.kernel_layer_name to </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
3907
+ <span class="k">return</span> <span class="kc">True</span>
3908
+ <span class="k">return</span> <span class="kc">False</span>
3909
+
3910
+
3911
+ <span class="c1"># Init the model the normal way</span>
3912
+ <span class="n">model_id</span> <span class="o">=</span> <span class="s2">&quot;openai/gpt-oss-20b&quot;</span>
3913
+ <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">PreTrainedTokenizerFast</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_id</span><span class="p">)</span>
3914
+ <span class="n">quantization_config</span> <span class="o">=</span> <span class="n">Mxfp4Config</span><span class="p">(</span><span class="n">dequantize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
3915
+
3916
+
3917
+ <span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">replace_kernel_forward_from_hub</span><span class="p">,</span> <span class="n">register_kernel_mapping</span><span class="p">,</span> <span class="n">LayerRepository</span><span class="p">,</span> <span class="n">Mode</span>
3918
+
3919
+ <span class="kn">from</span><span class="w"> </span><span class="nn">transformers.models.gpt_oss.modeling_gpt_oss</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssMLP</span><span class="p">,</span> <span class="n">GptOssRMSNorm</span>
3920
+
3921
+ <span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssMLP</span><span class="p">,</span> <span class="s2">&quot;Yamoe&quot;</span><span class="p">)</span>
3922
+ <span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssRMSNorm</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
3923
+ <span class="n">custom_mapping</span> <span class="o">=</span> <span class="p">{</span>
3924
+ <span class="s2">&quot;Yamoe&quot;</span><span class="p">:</span> <span class="p">{</span>
3925
+ <span class="s2">&quot;cuda&quot;</span><span class="p">:</span> <span class="p">{</span>
3926
+ <span class="n">Mode</span><span class="o">.</span><span class="n">INFERENCE</span><span class="p">:</span> <span class="n">LayerRepository</span><span class="p">(</span>
3927
+ <span class="n">repo_id</span><span class="o">=</span><span class="s2">&quot;drbh/yamoe&quot;</span><span class="p">,</span>
3928
+ <span class="n">layer_name</span><span class="o">=</span><span class="s2">&quot;Yamoe&quot;</span><span class="p">,</span>
3929
+ <span class="n">revision</span><span class="o">=</span><span class="s2">&quot;v0.3.0&quot;</span><span class="p">,</span>
3930
+ <span class="p">)</span>
3931
+ <span class="p">}</span>
3932
+ <span class="p">}</span>
3933
+ <span class="p">}</span>
3934
+ <span class="n">register_kernel_mapping</span><span class="p">(</span><span class="n">custom_mapping</span><span class="p">)</span>
3935
+
3936
+
3937
+ <span class="n">model</span> <span class="o">=</span> <span class="n">GptOssForCausalLM</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span>
3938
+ <span class="n">model_id</span><span class="p">,</span>
3939
+ <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;bfloat16&quot;</span><span class="p">,</span>
3940
+ <span class="n">device_map</span><span class="o">=</span><span class="s2">&quot;auto&quot;</span><span class="p">,</span>
3941
+ <span class="n">use_kernels</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
3942
+ <span class="n">quantization_config</span><span class="o">=</span><span class="n">quantization_config</span><span class="p">,</span>
3943
+ <span class="p">)</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
3944
+
3945
+ <span class="n">messages</span> <span class="o">=</span> <span class="p">[</span>
3946
+ <span class="p">{</span><span class="s2">&quot;role&quot;</span><span class="p">:</span> <span class="s2">&quot;system&quot;</span><span class="p">,</span> <span class="s2">&quot;content&quot;</span><span class="p">:</span> <span class="s2">&quot;What is Tensor Parallelism?&quot;</span><span class="p">},</span>
3947
+ <span class="p">]</span>
3948
+
3949
+ <span class="n">inputs</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">apply_chat_template</span><span class="p">(</span>
3950
+ <span class="n">messages</span><span class="p">,</span>
3951
+ <span class="n">add_generation_prompt</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
3952
+ <span class="n">return_tensors</span><span class="o">=</span><span class="s2">&quot;pt&quot;</span><span class="p">,</span>
3953
+ <span class="n">return_dict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
3954
+ <span class="n">reasoning_effort</span><span class="o">=</span><span class="s2">&quot;low&quot;</span><span class="p">,</span>
3955
+ <span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">&quot;cuda&quot;</span><span class="p">)</span>
3956
+
3957
+ <span class="n">max_tokens</span> <span class="o">=</span> <span class="mi">256</span>
3958
+
3959
+ <span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span>
3960
+ <span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
3961
+ <span class="n">generated</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span>
3962
+ <span class="o">**</span><span class="n">inputs</span><span class="p">,</span>
3963
+ <span class="n">max_new_tokens</span><span class="o">=</span><span class="n">max_tokens</span><span class="p">,</span>
3964
+ <span class="n">do_sample</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
3965
+ <span class="n">temperature</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
3966
+ <span class="p">)</span>
3967
+ <span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
3968
+
3969
+ <span class="nb">print</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">generated</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">skip_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
3970
+ <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Generation took </span><span class="si">{</span><span class="n">end_time</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">start_time</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> seconds&quot;</span><span class="p">)</span>
3971
  </pre></div>
3972
 
3973
+ <div class="code-line-highlight" id="line-highlight-setup"></div>
3974
  </div>
3975
  </div>
3976
  </div>
3977
+ <div id="output-setup" class="cell-output">
3978
+ <div class="cell-stderr">Downloading cpython-3.13.7-linux-x86_64-gnu (download) (32.0MiB)
3979
+ Downloading cpython-3.13.7-linux-x86_64-gnu (download)
3980
+ Updating https://github.com/huggingface/transformers.git (HEAD)
3981
+ Updated https://github.com/huggingface/transformers.git (449533af73874470e914a203391635e04ac2ffc8)
3982
+ × No solution found when resolving script dependencies:
3983
+ ╰─▶ Because only transformers==4.57.0.dev0 is available and
3984
+ transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1,
3985
+ we can conclude that all versions of transformers depend on
3986
+ huggingface-hub==1.0.0rc1.
3987
+ And because kernels==0.10.0 depends on huggingface-hub&gt;=0.26.0,&lt;1.0,
3988
+ we can conclude that kernels==0.10.0 and all versions of transformers
3989
+ are incompatible.
3990
+ And because you require kernels==0.10.0 and transformers, we can
3991
+ conclude that your requirements are unsatisfiable.
3992
  </div>
3993
  </div>
3994
  </div>
 
 
 
 
 
3995
  </div>
3996
 
3997
  </body>
moe_benchmarks/megablocks_yamoe/torch_profile.html CHANGED
The diff for this file is too large to render. See raw diff