drbh committed
Commit 30c62e2 · Parent: 08478da

fix: remove debug build
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. flash_attn/artifacts/benchmark/Attention Benchmark.csv +0 -7
  2. flash_attn/artifacts/benchmark/Attention Benchmark.png +0 -3
  3. flash_attn/artifacts/benchmark/results.html +0 -3
  4. flash_attn/benchmark.html +0 -0
  5. flash_attn/cells/benchmark.py +0 -343
  6. flash_attn/cells/nv.py +0 -3
  7. flash_attn/impls/artifacts/benchmark/attn.jsonl +0 -6
  8. flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl +0 -6
  9. flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl +0 -6
  10. flash_attn/impls/cells/benchmark.py +0 -71
  11. flash_attn/impls/cells/benchmark_default.py +0 -70
  12. flash_attn/impls/cells/benchmark_max_autotune.py +0 -70
  13. flash_attn/impls/cells/nv.py +0 -3
  14. flash_attn/impls/compiled_variants.html +0 -0
  15. flash_attn/impls/flash_attention.html +0 -0
  16. flash_attn/impls/hf_kernels_flash_attn.html +0 -0
  17. flash_attn/impls/hf_kernels_flash_attn3.html +0 -0
  18. flash_attn/impls/index.html +0 -94
  19. flash_attn/impls/mem_efficient_attention.html +0 -0
  20. flash_attn/impls/sage_attention.html +0 -0
  21. flash_attn/impls/xformers.html +0 -0
  22. flash_attn/index.html +0 -89
  23. flash_attn/results/artifacts/combine/latency.csv +0 -43
  24. flash_attn/results/artifacts/combine/latency.png +0 -3
  25. flash_attn/results/artifacts/combine/latency.svg +0 -3
  26. flash_attn/results/cells/combine.py +0 -319
  27. flash_attn/results/combined_results.html +0 -0
  28. flash_attn/results/index.html +0 -88
  29. index.html +0 -85
  30. megablocks/cells/forward_and_backward.py +0 -196
  31. megablocks/cells/forward_and_backward_no_kernel.py +0 -196
  32. megablocks/cells/forward_only.py +0 -101
  33. megablocks/cells/no_kernels.py +0 -98
  34. megablocks/cells/nv.py +0 -3
  35. megablocks/index.html +0 -24
  36. megablocks/megablocks_only.html +0 -0
  37. megablocks_yamoe/artifacts/binned_run/binned_results.json +0 -24
  38. megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json +0 -24
  39. megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json +0 -24
  40. megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json +0 -24
  41. megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc +0 -0
  42. megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc +0 -0
  43. megablocks_yamoe/cells/bench_utils.py +0 -241
  44. megablocks_yamoe/cells/binned_run.py +0 -195
  45. megablocks_yamoe/cells/config.py +0 -27
  46. megablocks_yamoe/cells/gptoss_run.py +0 -147
  47. megablocks_yamoe/cells/gptoss_training_run.py +0 -138
  48. megablocks_yamoe/cells/megablocks_run.py +0 -103
  49. megablocks_yamoe/cells/nv.py +0 -3
  50. megablocks_yamoe/cells/save_data.py +0 -42
flash_attn/artifacts/benchmark/Attention Benchmark.csv DELETED
@@ -1,7 +0,0 @@
- seq_len,torch_cudnn,torch_cudnn_compile_d,torch_cudnn_compile_ma,torch_flash,torch_flash_compile_d,torch_flash_compile_ma,hf_flash_attn,hf_flash_attn3
- 4224.000000,3.801472,3.790064,4.182320,3.968000,3.957824,4.311152,3.398160,3.330400
- 4352.000000,4.082944,4.082912,4.413488,4.400000,4.391936,4.738048,3.837424,3.758208
- 4416.000000,4.142624,4.135648,4.484160,4.452304,4.446096,4.792480,3.892064,3.864128
- 4480.000000,4.206144,4.198752,4.551808,4.530752,4.522944,4.873760,3.949344,3.870224
- 4544.000000,4.438320,4.433104,4.787584,4.584160,4.576640,4.934304,4.008960,3.974672
- 4608.000000,4.502432,4.495456,4.871872,4.660192,4.651040,5.029792,4.065616,3.984160
flash_attn/artifacts/benchmark/Attention Benchmark.png DELETED

Git LFS Details

  • SHA256: 69a5d2d4ac33fa06e77a599eab6cadcddb77c15ad7bde323bb07849e2aa3ac14
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
flash_attn/artifacts/benchmark/results.html DELETED
@@ -1,3 +0,0 @@
- <html><body>
- <image src="Attention Benchmark.png"/>
- </body></html>
flash_attn/benchmark.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/cells/benchmark.py DELETED
@@ -1,343 +0,0 @@
1
- # /// script
2
- # dependencies = [
3
- # "numpy",
4
- # "torch",
5
- # "kernels",
6
- # "pandas",
7
- # "matplotlib"
8
- # ]
9
- # ///
10
- # Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths
11
-
12
- import functools
13
- import os
14
- import pathlib
15
-
16
- import matplotlib.pyplot as plt
17
- import torch
18
- import torch._dynamo.config
19
- import triton
20
- import triton.language as tl
21
-
22
- try:
23
- from flash_attn import flash_attn_func
24
- except:
25
- flash_attn_func = None
26
- print("Flash Attention 2 not found.")
27
-
28
- try:
29
- from flash_attn_interface import flash_attn_func as flash_attn_3_func
30
- except:
31
- flash_attn_3_func = None
32
- print("Flash Attention 3 not found.")
33
-
34
- try:
35
- from kernels import get_kernel
36
- hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
37
- hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
38
- except:
39
- hf_kernels_flash_attn = None
40
- hf_kernels_flash_attn_3 = None
41
- print("HF Kernels not found.")
42
-
43
- try:
44
- from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
45
- except:
46
- sageattn_qk_int8_pv_fp16_cuda = None
47
- sageattn_qk_int8_pv_fp16_triton = None
48
- sageattn_qk_int8_pv_fp8_cuda_sm90 = None
49
- print("SageAttention not found.")
50
-
51
- try:
52
- from transformer_engine.pytorch.attention import DotProductAttention
53
- except:
54
- DotProductAttention = None
55
- print("Transformer Engine not found.")
56
-
57
- try:
58
- import xformers.ops as xops
59
- except:
60
- xops = None
61
- print("xFormers not found.")
62
-
63
-
64
- plt.rcParams.update({
65
- "figure.figsize": (12, 10),
66
- "figure.dpi": 120,
67
- "font.size": 10,
68
- "axes.titlesize": 12,
69
- "axes.labelsize": 14,
70
- "xtick.labelsize": 10,
71
- "ytick.labelsize": 10,
72
- "legend.fontsize": 8,
73
- "axes.grid": True,
74
- "grid.alpha": 0.3,
75
- "grid.linestyle": "--",
76
- "lines.linewidth": 2.0,
77
- "lines.markersize": 6,
78
- "legend.frameon": True,
79
- "legend.framealpha": 0.9,
80
- "legend.loc": "best",
81
- "axes.spines.top": False,
82
- "axes.spines.right": False,
83
- })
84
-
85
-
86
- # We want to compare the best compiled version for each specific shape (dynamic=False)
87
- torch._dynamo.config.cache_size_limit = 10000
88
-
89
- # We need to suppress_errors for FA3 to work. It makes it run in eager mode.
90
- # I can't seem to get it to work any other way under torch.compile, so any suggestions are welcome!
91
- torch._dynamo.config.suppress_errors = True
92
-
93
- # output_dir = pathlib.Path("dump_attention_benchmark")
94
- # output_dir.mkdir(parents=True, exist_ok=True)
95
-
96
- output_dir = pathlib.Path(".") # output to current directory for upload
97
-
98
- batch_size = 1
99
- num_attention_heads = 24
100
- attention_head_dim = 128
101
- image_sequence_length = 4096 # 1024x1024px
102
- text_sequence_lengths = [128, 256, 320, 384, 448, 512]
103
- sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]
104
-
105
-
106
- def _attention_torch(query, key, value, *, backend):
107
- query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
108
- with torch.nn.attention.sdpa_kernel(backend):
109
- out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
110
- out = out.transpose(1, 2).contiguous()
111
- return out
112
-
113
-
114
- _compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
115
- def _attention_torch_compile_default(query, key, value, *, backend):
116
- return _compiled_attention_torch_default(query, key, value, backend=backend)
117
-
118
-
119
- _compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
120
- def _attention_torch_compile_max_autotune(query, key, value, *, backend):
121
- return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)
122
-
123
-
124
- def _attention_flash_attn_2(query, key, value):
125
- return flash_attn_func(query, key, value)
126
-
127
-
128
- _compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
129
- def _attention_flash_attn_2_compile_default(query, key, value):
130
- return _compiled_flash_attn_2_default(query, key, value)
131
-
132
-
133
- _compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
134
- def _attention_flash_attn_2_compile_max_autotune(query, key, value):
135
- return _compiled_flash_attn_2_max_autotune(query, key, value)
136
-
137
-
138
- # For fullgraph=True tracing to be compatible
139
- @torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
140
- def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
141
- out, lse = flash_attn_3_func(query, key, value)
142
- return out
143
-
144
-
145
- @torch.library.register_fake("flash_attn_3::_flash_attn_forward")
146
- def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
147
- return torch.empty_like(query)
148
-
149
-
150
- def _attention_flash_attn_3(query, key, value):
151
- out = _wrapped_flash_attn_3(query, key, value)
152
- return out
153
-
154
-
155
- _compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
156
- def _attention_flash_attn_3_compile_default(query, key, value):
157
- return _compiled_flash_attn_3_default(query, key, value)
158
-
159
-
160
- _compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
161
- def _attention_flash_attn_3_compile_max_autotune(query, key, value):
162
- return _compiled_flash_attn_3_max_autotune(query, key, value)
163
-
164
-
165
- def _attention_hf_kernels_flash_attn(query, key, value):
166
- return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]
167
-
168
-
169
- def _attention_hf_kernels_flash_attn3(query, key, value):
170
- return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]
171
-
172
-
173
- def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
174
- return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")
175
-
176
-
177
- def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
178
- return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")
179
-
180
-
181
- def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
182
- return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")
183
-
184
-
185
- if DotProductAttention is not None:
186
- def set_te_backend(backend):
187
- # must be applied before first use of
188
- # transformer_engine.pytorch.attention
189
- os.environ["NVTE_FLASH_ATTN"] = '0'
190
- os.environ["NVTE_FUSED_ATTN"] = '0'
191
- os.environ["NVTE_UNFUSED_ATTN"] = '0'
192
- if backend == 'flash':
193
- os.environ["NVTE_FLASH_ATTN"] = '1'
194
- if backend == 'fused':
195
- os.environ["NVTE_FUSED_ATTN"] = '1'
196
- if backend == 'unfused':
197
- os.environ["NVTE_UNFUSED_ATTN"] = '1'
198
-
199
- set_te_backend("fused")
200
- te_attn_fn = DotProductAttention(
201
- num_attention_heads=num_attention_heads,
202
- kv_channels=attention_head_dim,
203
- qkv_format="bshd",
204
- attn_mask_type="no_mask",
205
- )
206
- else:
207
- def te_attn_fn(query, key, value):
208
- raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")
209
-
210
- def _attention_te(query, key, value):
211
- out = te_attn_fn(query, key, value)
212
- out = out.unflatten(2, (num_attention_heads, attention_head_dim))
213
- return out
214
-
215
-
216
- # Cannot fullgraph compile TE
217
- _compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
218
- def _attention_te_compile_default(query, key, value):
219
- return _compiled_te_attn_fn_default(query, key, value)
220
-
221
-
222
- # Cannot fullgraph compile TE
223
- _compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
224
- def _attention_te_compile_max_autotune(query, key, value):
225
- return _compiled_te_attn_fn_max_autotune(query, key, value)
226
-
227
-
228
- def _attention_xformers(query, key, value):
229
- return xops.memory_efficient_attention(query, key, value)
230
-
231
-
232
- _compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
233
- def _attention_xformers_compile_default(query, key, value):
234
- return _compiled_xformers_default(query, key, value)
235
-
236
-
237
- _compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
238
- def _attention_xformers_compile_max_autotune(query, key, value):
239
- return _compiled_xformers_max_autotune(query, key, value)
240
-
241
-
242
- attention_ops = {}
243
- attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
244
- attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
245
- attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
246
- attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
247
- attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
248
- attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
249
- if hf_kernels_flash_attn is not None:
250
- attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
251
- attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
252
- if flash_attn_func is not None:
253
- attention_ops["flash_attn_2"] = _attention_flash_attn_2
254
- attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
255
- attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
256
- if flash_attn_3_func is not None:
257
- attention_ops["flash_attn_3"] = _attention_flash_attn_3
258
- attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
259
- attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
260
- if sageattn_qk_int8_pv_fp16_cuda is not None:
261
- attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
262
- attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
263
- if torch.cuda.get_device_capability()[0] >= 9:
264
- attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
265
- if DotProductAttention is not None:
266
- attention_ops["te_fused"] = _attention_te
267
- attention_ops["te_fused_compile_d"] = _attention_te_compile_default
268
- attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
269
- if xops is not None:
270
- attention_ops["xformers"] = _attention_xformers
271
- attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
272
- attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune
273
-
274
-
275
- def get_color_and_linestyle(n: int) -> tuple[str, str]:
276
- colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
277
- line_styles = ["-", ":", "-.", "--"]
278
- if n > len(colors) * len(line_styles):
279
- raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
280
- styles = []
281
- for i in range(n):
282
- color = colors[i % len(colors)]
283
- linestyle = line_styles[i // len(colors)]
284
- styles.append((color, linestyle))
285
- return styles
286
-
287
-
288
- def correctness():
289
- for seq_len in sequence_lengths:
290
- shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
291
- print(f"\n\n===== Testing shape: {shape} =====")
292
-
293
- query = torch.randn(shape, device="cuda", dtype=torch.float32)
294
- key = torch.randn(shape, device="cuda", dtype=torch.float32)
295
- value = torch.randn(shape, device="cuda", dtype=torch.float32)
296
-
297
- golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
298
- query, key, value = (x.bfloat16() for x in (query, key, value))
299
-
300
- for name, fn in attention_ops.items():
301
- out = fn(query, key, value)
302
- absdiff = (out - golden_truth).abs()
303
- absmax = torch.max(absdiff)
304
- mae = torch.mean(absdiff)
305
- mse = torch.mean((golden_truth - out) ** 2)
306
- print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")
307
-
308
-
309
- @triton.testing.perf_report(
310
- triton.testing.Benchmark(
311
- x_names=["seq_len"],
312
- x_vals=sequence_lengths,
313
- x_log=False,
314
- line_arg="provider",
315
- line_vals=list(attention_ops.keys()),
316
- line_names=[x.removeprefix("solution_") for x in attention_ops.keys()],
317
- ylabel="Time (ms)",
318
- styles=get_color_and_linestyle(len(attention_ops)),
319
- plot_name="Attention Benchmark",
320
- args={},
321
- )
322
- )
323
- def benchmark_fn(seq_len: int, provider: str):
324
- torch.manual_seed(0)
325
-
326
- shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
327
- query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
328
- key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
329
- value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
330
-
331
- fn = attention_ops[provider]
332
- ms, min_ms, max_ms = triton.testing.do_bench(
333
- lambda: fn(query, key, value),
334
- warmup=3,
335
- rep=10,
336
- quantiles=[0.5, 0.2, 0.8],
337
- )
338
- return ms, max_ms, min_ms
339
-
340
-
341
- with torch.inference_mode():
342
- correctness()
343
- fig = benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
flash_attn/cells/nv.py DELETED
@@ -1,3 +0,0 @@
- import subprocess
-
- print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
flash_attn/impls/artifacts/benchmark/attn.jsonl DELETED
@@ -1,6 +0,0 @@
- {"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.3603839874267578, "p50": 0.361952006816864, "p90": 0.3624640107154846, "mean": 0.3619711995124817, "reps": 5, "warmup": 2}, "compile_ms": 1.5701119899749756, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.3892799913883209, "p50": 0.3909760117530823, "p90": 0.3922559916973114, "mean": 0.3912447988986969, "reps": 5, "warmup": 2}, "compile_ms": 0.35811200737953186, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5240640044212341, "p50": 0.5248960256576538, "p90": 0.5248960256576538, "mean": 0.5258048176765442, "reps": 5, "warmup": 2}, "compile_ms": 0.4891839921474457, "peak_bytes": 99680256, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.905726432800293e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5265600085258484, "p50": 0.5277760028839111, "p90": 0.5282559990882874, "mean": 0.5276032090187073, "reps": 5, "warmup": 2}, "compile_ms": 0.4968000054359436, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003604888916015625, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5639039874076843, "p50": 0.5657920241355896, "p90": 0.5668479800224304, "mean": 0.5656383991241455, "reps": 5, "warmup": 2}, "compile_ms": 0.5312319993972778, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003566741943359375, "mse": 2.86102294921875e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5689600110054016, "p50": 0.5698239803314209, "p90": 0.5713919997215271, "mean": 0.5789952039718628, "reps": 5, "warmup": 2}, "compile_ms": 0.5350080132484436, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8759241104125977e-06, "ref": "sdpa_math_fp32"}, "err": null}
flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl DELETED
@@ -1,6 +0,0 @@
- {"ts": "2025-10-02T19:58:18Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5141760110855103, "p50": 0.5175679922103882, "p90": 0.5197759866714478, "mean": 0.5181439876556396, "reps": 5, "warmup": 2}, "compile_ms": 3084.621826171875, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000339508056640625, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5549119710922241, "p50": 0.5582720041275024, "p90": 0.5598080158233643, "mean": 0.5579584002494812, "reps": 5, "warmup": 2}, "compile_ms": 270.21795654296875, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6853119730949402, "p50": 0.687391996383667, "p90": 0.6883519887924194, "mean": 0.6872959971427918, "reps": 5, "warmup": 2}, "compile_ms": 269.78741455078125, "peak_bytes": 99876864, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7128639817237854, "p50": 0.7160959839820862, "p90": 0.7167680263519287, "mean": 0.716153597831726, "reps": 5, "warmup": 2}, "compile_ms": 269.8607177734375, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7386879920959473, "p50": 0.7400959730148315, "p90": 0.7415040135383606, "mean": 0.7418303966522217, "reps": 5, "warmup": 2}, "compile_ms": 269.20501708984375, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:58:20Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7708160281181335, "p50": 0.7740799784660339, "p90": 0.7753919959068298, "mean": 0.7745471954345703, "reps": 5, "warmup": 2}, "compile_ms": 270.93829345703125, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl DELETED
@@ -1,6 +0,0 @@
- {"ts": "2025-10-02T19:57:25Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6144000291824341, "p50": 0.6245759725570679, "p90": 0.6483200192451477, "mean": 0.6468096017837525, "reps": 5, "warmup": 2}, "compile_ms": 4407.3388671875, "peak_bytes": 70779904, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000339508056640625, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:57:27Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6689280271530151, "p50": 0.6851199865341187, "p90": 0.7184960246086121, "mean": 0.7060160160064697, "reps": 5, "warmup": 2}, "compile_ms": 1686.2735595703125, "peak_bytes": 78644224, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:57:29Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7953600287437439, "p50": 0.8155840039253235, "p90": 0.8403519988059998, "mean": 0.8332608103752136, "reps": 5, "warmup": 2}, "compile_ms": 1462.938232421875, "peak_bytes": 84280320, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:57:31Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.8470720052719116, "p50": 0.849727988243103, "p90": 0.8745279908180237, "mean": 0.8719295978546142, "reps": 5, "warmup": 2}, "compile_ms": 1689.3455810546875, "peak_bytes": 86508544, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:57:33Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.8677120208740234, "p50": 0.8835520148277283, "p90": 0.9034240245819092, "mean": 0.9034304022789001, "reps": 5, "warmup": 2}, "compile_ms": 1693.035888671875, "peak_bytes": 90440704, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
- {"ts": "2025-10-02T19:57:34Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.9154239892959595, "p50": 0.9213759899139404, "p90": 0.9359679818153381, "mean": 0.9387519836425782, "reps": 5, "warmup": 2}, "compile_ms": 1689.36279296875, "peak_bytes": 94372864, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
flash_attn/impls/cells/benchmark.py DELETED
@@ -1,71 +0,0 @@
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "numpy",
- # "torch",
- # "kernels-benchmark-tools",
- # "kernels",
- # ]
- #
- # [tool.uv.sources]
- # kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
- # ///
- import torch
- import sys
- import os
- import kernels_benchmark_tools as kbt
- from kernels import get_kernel
-
- hf_kernels_flash_attn3 = get_kernel("kernels-community/flash-attn3")
-
-
- def hf_flash_attention3(query, key, value):
-     return hf_kernels_flash_attn3.flash_attn_func(query, key, value, causal=False)[0]
-
-
- kbt.add(
-     "hf_kernels_flash_attn3",
-     hf_flash_attention3,
-     tags={"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"},
- )
-
- if __name__ == "__main__":
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-
-     if device == "cpu":
-         print("HF Kernels Flash Attention 3 requires CUDA - skipping benchmark")
-         sys.exit(0)
-
-     dtype = "bfloat16"
-
-     # Flux-like workloads
-     base = 1024
-     flux_sizes = [128, 256, 320, 384, 448, 512]
-     heads = 24
-     head_dim = 128
-
-     wl = []
-     for L in flux_sizes:
-         wl.append(
-             {
-                 "name": f"flux_L{L}",
-                 "batch": 1,
-                 "seq_len": base + L,
-                 "heads": heads,
-                 "head_dim": head_dim,
-                 "dtype": dtype,
-                 "device": device,
-                 "seed": 0,
-             }
-         )
-
-     kbt.run(
-         wl,
-         jsonl="attn.jsonl",
-         reps=5,
-         warmup=2,
-         gen=kbt.attn.gen_qkv,
-         ref=kbt.attn.ref_math,
-         cmp=kbt.attn.cmp_allclose,
-     )
-     kbt.summarize(["attn.jsonl"])
flash_attn/impls/cells/benchmark_default.py DELETED
@@ -1,70 +0,0 @@
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "numpy",
- # "torch",
- # "kernels-benchmark-tools",
- # ]
- #
- # [tool.uv.sources]
- # kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
- # ///
- import torch
- import sys
- import os
- import kernels_benchmark_tools as kbt
-
-
- def torch_flash_base(q, k, v):
-     qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
-     with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-         o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
-     return o.transpose(1, 2).contiguous()
-
-
- # Compile with default mode
- compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False)
-
- kbt.add(
-     "torch_flash_compiled_default",
-     compiled_flash_default,
-     tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
- )
-
- if __name__ == "__main__":
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     dtype = "float32" if device == "cpu" else "bfloat16"
-
-     # Flux-like workloads
-     base = 1024 if device == "cuda" else 512
-     flux_sizes = (
-         [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
-     )
-     heads = 24 if device == "cuda" else 8
-     head_dim = 128 if device == "cuda" else 64
-
-     wl = []
-     for L in flux_sizes:
-         wl.append(
-             {
-                 "name": f"flux_L{L}",
-                 "batch": 1,
-                 "seq_len": base + L,
-                 "heads": heads,
-                 "head_dim": head_dim,
-                 "dtype": dtype,
-                 "device": device,
-                 "seed": 0,
-             }
-         )
-
-     kbt.run(
-         wl,
-         jsonl="attn_default.jsonl",
-         reps=5,
-         warmup=2,
-         gen=kbt.attn.gen_qkv,
-         ref=kbt.attn.ref_math,
-         cmp=kbt.attn.cmp_allclose,
-     )
-     kbt.summarize(["attn_default.jsonl"])
flash_attn/impls/cells/benchmark_max_autotune.py DELETED
@@ -1,70 +0,0 @@
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "numpy",
- # "torch",
- # "kernels-benchmark-tools",
- # ]
- #
- # [tool.uv.sources]
- # kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
- # ///
- import torch
- import sys
- import os
- import kernels_benchmark_tools as kbt
-
-
- def torch_flash_base(q, k, v):
-     qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
-     with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-         o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
-     return o.transpose(1, 2).contiguous()
-
-
- # Compile with max-autotune mode
- compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False)
-
- kbt.add(
-     "torch_flash_compiled_max_autotune",
-     compiled_flash_max_autotune,
-     tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
- )
-
- if __name__ == "__main__":
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     dtype = "float32" if device == "cpu" else "bfloat16"
-
-     # Flux-like workloads
-     base = 1024 if device == "cuda" else 512
-     flux_sizes = (
-         [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
-     )
-     heads = 24 if device == "cuda" else 8
-     head_dim = 128 if device == "cuda" else 64
-
-     wl = []
-     for L in flux_sizes:
-         wl.append(
-             {
-                 "name": f"flux_L{L}",
-                 "batch": 1,
-                 "seq_len": base + L,
-                 "heads": heads,
-                 "head_dim": head_dim,
-                 "dtype": dtype,
-                 "device": device,
-                 "seed": 0,
-             }
-         )
-
-     kbt.run(
-         wl,
-         jsonl="attn_max_autotune.jsonl",
-         reps=5,
-         warmup=2,
-         gen=kbt.attn.gen_qkv,
-         ref=kbt.attn.ref_math,
-         cmp=kbt.attn.cmp_allclose,
-     )
-     kbt.summarize(["attn_max_autotune.jsonl"])
flash_attn/impls/cells/nv.py DELETED
@@ -1,3 +0,0 @@
- import subprocess
-
- print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
flash_attn/impls/compiled_variants.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/impls/flash_attention.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/impls/hf_kernels_flash_attn.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/impls/hf_kernels_flash_attn3.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/impls/index.html DELETED
@@ -1,94 +0,0 @@
1
- <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <meta charset='UTF-8'>
5
- <meta name='viewport' content='width=device-width, initial-scale=1.0'>
6
- <title>Index of /flash_attn/impls</title>
7
- <style>
8
- :root {
9
- --bg-primary: #0a0a0a;
10
- --bg-secondary: #121212;
11
- --bg-tertiary: #181818;
12
- --text-primary: #e0e0e0;
13
- --text-secondary: #888888;
14
- --text-link: #64b5f6;
15
- --border-primary: #2a2a2a;
16
- }
17
- body {
18
- font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
19
- background: var(--bg-primary);
20
- color: var(--text-primary);
21
- margin: 0;
22
- padding: 16px;
23
- max-width: 900px;
24
- margin: 0 auto;
25
- }
26
- .controls {
27
- display: flex;
28
- justify-content: flex-end;
29
- margin-bottom: 1rem;
30
- }
31
- .back-button {
32
- background: var(--bg-secondary);
33
- border: 1px solid var(--border-primary);
34
- padding: 8px 12px;
35
- border-radius: 4px;
36
- color: var(--text-secondary);
37
- cursor: pointer;
38
- font-size: 0.9rem;
39
- text-decoration: none;
40
- display: inline-block;
41
- }
42
- .back-button:hover {
43
- color: var(--text-primary);
44
- background: var(--bg-tertiary);
45
- }
46
- h1 {
47
- font-size: 1.5em;
48
- margin: 1rem 0;
49
- color: var(--text-primary);
50
- border-bottom: 1px solid var(--border-primary);
51
- padding-bottom: 0.5rem;
52
- }
53
- ul {
54
- list-style-type: none;
55
- padding: 0;
56
- }
57
- li {
58
- margin: 0;
59
- border-bottom: 1px solid var(--border-primary);
60
- }
61
- li:last-child {
62
- border-bottom: none;
63
- }
64
- a {
65
- display: block;
66
- padding: 0.75rem 0.5rem;
67
- text-decoration: none;
68
- color: var(--text-link);
69
- transition: background 0.2s ease;
70
- }
71
- a:hover {
72
- background: var(--bg-secondary);
73
- }
74
- .dir {
75
- font-weight: 500;
76
- }
77
- </style>
78
- </head>
79
- <body>
80
- <div class='controls'>
81
- <a href='../index.html' class='back-button'>← back</a>
82
- </div>
83
- <h1>Index of /flash_attn/impls</h1>
84
- <ul>
85
- <li><a href='compiled_variants.html' class='file'>compiled_variants.html</a></li>
86
- <li><a href='flash_attention.html' class='file'>flash_attention.html</a></li>
87
- <li><a href='hf_kernels_flash_attn.html' class='file'>hf_kernels_flash_attn.html</a></li>
88
- <li><a href='hf_kernels_flash_attn3.html' class='file'>hf_kernels_flash_attn3.html</a></li>
89
- <li><a href='mem_efficient_attention.html' class='file'>mem_efficient_attention.html</a></li>
90
- <li><a href='sage_attention.html' class='file'>sage_attention.html</a></li>
91
- <li><a href='xformers.html' class='file'>xformers.html</a></li>
92
- </ul>
93
- </body>
94
- </html>
flash_attn/impls/mem_efficient_attention.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/impls/sage_attention.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/impls/xformers.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/index.html DELETED
@@ -1,89 +0,0 @@
1
- <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <meta charset='UTF-8'>
5
- <meta name='viewport' content='width=device-width, initial-scale=1.0'>
6
- <title>Index of /flash_attn</title>
7
- <style>
8
- :root {
9
- --bg-primary: #0a0a0a;
10
- --bg-secondary: #121212;
11
- --bg-tertiary: #181818;
12
- --text-primary: #e0e0e0;
13
- --text-secondary: #888888;
14
- --text-link: #64b5f6;
15
- --border-primary: #2a2a2a;
16
- }
17
- body {
18
- font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
19
- background: var(--bg-primary);
20
- color: var(--text-primary);
21
- margin: 0;
22
- padding: 16px;
23
- max-width: 900px;
24
- margin: 0 auto;
25
- }
26
- .controls {
27
- display: flex;
28
- justify-content: flex-end;
29
- margin-bottom: 1rem;
30
- }
31
- .back-button {
32
- background: var(--bg-secondary);
33
- border: 1px solid var(--border-primary);
34
- padding: 8px 12px;
35
- border-radius: 4px;
36
- color: var(--text-secondary);
37
- cursor: pointer;
38
- font-size: 0.9rem;
39
- text-decoration: none;
40
- display: inline-block;
41
- }
42
- .back-button:hover {
43
- color: var(--text-primary);
44
- background: var(--bg-tertiary);
45
- }
46
- h1 {
47
- font-size: 1.5em;
48
- margin: 1rem 0;
49
- color: var(--text-primary);
50
- border-bottom: 1px solid var(--border-primary);
51
- padding-bottom: 0.5rem;
52
- }
53
- ul {
54
- list-style-type: none;
55
- padding: 0;
56
- }
57
- li {
58
- margin: 0;
59
- border-bottom: 1px solid var(--border-primary);
60
- }
61
- li:last-child {
62
- border-bottom: none;
63
- }
64
- a {
65
- display: block;
66
- padding: 0.75rem 0.5rem;
67
- text-decoration: none;
68
- color: var(--text-link);
69
- transition: background 0.2s ease;
70
- }
71
- a:hover {
72
- background: var(--bg-secondary);
73
- }
74
- .dir {
75
- font-weight: 500;
76
- }
77
- </style>
78
- </head>
79
- <body>
80
- <div class='controls'>
81
- <a href='../index.html' class='back-button'>← back</a>
82
- </div>
83
- <h1>Index of /flash_attn</h1>
84
- <ul>
85
- <li><a href='impls/index.html' class='dir'>impls/</a></li>
86
- <li><a href='results/index.html' class='dir'>results/</a></li>
87
- </ul>
88
- </body>
89
- </html>
flash_attn/results/artifacts/combine/latency.csv DELETED
@@ -1,43 +0,0 @@
1
- Implementation,Impl ID,Workload,Batch,Seq Length,Heads,Head Dim,Dtype,Mean (ms),P10 (ms),P50 (ms),P90 (ms),Reps,Peak Mem (MB),Backend,Family
2
- Flash (PyTorch SDPA),torch_flash_ma,flux_L128,1,1152,24,128,bfloat16,0.49411200881004336,0.48844799399375916,0.4936000108718872,0.4944640100002289,5,83.38,FLASH,torch-sdpa
3
- Flash (PyTorch SDPA),torch_flash_ma,flux_L256,1,1280,24,128,bfloat16,0.5234112024307251,0.5224320292472839,0.5235199928283691,0.5235840082168579,5,90.62,FLASH,torch-sdpa
4
- Flash (PyTorch SDPA),torch_flash_ma,flux_L320,1,1344,24,128,bfloat16,0.6527232170104981,0.6503040194511414,0.6524800062179565,0.6545600295066833,5,95.06,FLASH,torch-sdpa
5
- Flash (PyTorch SDPA),torch_flash_ma,flux_L384,1,1408,24,128,bfloat16,0.682803213596344,0.6805760264396667,0.6828799843788147,0.6832640171051025,5,99.88,FLASH,torch-sdpa
6
- Flash (PyTorch SDPA),torch_flash_ma,flux_L448,1,1472,24,128,bfloat16,0.7075456142425537,0.7057600021362305,0.7063360214233398,0.7070720195770264,5,103.81,FLASH,torch-sdpa
7
- Flash (PyTorch SDPA),torch_flash_ma,flux_L512,1,1536,24,128,bfloat16,0.7379711985588073,0.7368639707565308,0.7372480034828186,0.7391039729118347,5,109.12,FLASH,torch-sdpa
8
- MemEff (PyTorch SDPA),torch_mem_eff,flux_L128,1,1152,24,128,bfloat16,0.5874239921569824,0.5861759781837463,0.5873280167579651,0.5877439975738525,5,83.38,EFFICIENT,torch-sdpa
9
- MemEff (PyTorch SDPA),torch_mem_eff,flux_L256,1,1280,24,128,bfloat16,0.6502719998359681,0.6490240097045898,0.649183988571167,0.6517760157585144,5,90.62,EFFICIENT,torch-sdpa
10
- MemEff (PyTorch SDPA),torch_mem_eff,flux_L320,1,1344,24,128,bfloat16,0.7812095880508423,0.7761600017547607,0.7803199887275696,0.7852799892425537,5,95.94,EFFICIENT,torch-sdpa
11
- MemEff (PyTorch SDPA),torch_mem_eff,flux_L384,1,1408,24,128,bfloat16,0.7948480010032654,0.7911999821662903,0.7935360074043274,0.7948480248451233,5,100.0,EFFICIENT,torch-sdpa
12
- MemEff (PyTorch SDPA),torch_mem_eff,flux_L448,1,1472,24,128,bfloat16,0.8463295936584473,0.8449919819831848,0.8459839820861816,0.8461120128631592,5,103.81,EFFICIENT,torch-sdpa
13
- MemEff (PyTorch SDPA),torch_mem_eff,flux_L512,1,1536,24,128,bfloat16,0.9538687944412232,0.9492800235748291,0.9518399834632874,0.9581760168075562,5,109.12,EFFICIENT,torch-sdpa
14
- xFormers,xformers_meff,flux_L128,1,1152,24,128,bfloat16,0.4515071928501129,0.44364801049232483,0.4524799883365631,0.4557119905948639,5,83.38,memory_efficient,xformers
15
- xFormers,xformers_meff,flux_L256,1,1280,24,128,bfloat16,0.46787199974060056,0.46489599347114563,0.4684160053730011,0.46908798813819885,5,90.62,memory_efficient,xformers
16
- xFormers,xformers_meff,flux_L320,1,1344,24,128,bfloat16,0.6001471996307373,0.596992015838623,0.5984640121459961,0.6016640067100525,5,95.06,memory_efficient,xformers
17
- xFormers,xformers_meff,flux_L384,1,1408,24,128,bfloat16,0.6023231983184815,0.5997440218925476,0.6031039953231812,0.6032639741897583,5,99.88,memory_efficient,xformers
18
- xFormers,xformers_meff,flux_L448,1,1472,24,128,bfloat16,0.6411136031150818,0.6381760239601135,0.6414719820022583,0.6421440243721008,5,103.81,memory_efficient,xformers
19
- xFormers,xformers_meff,flux_L512,1,1536,24,128,bfloat16,0.6594688057899475,0.6441280245780945,0.6496639847755432,0.6527680158615112,5,109.12,memory_efficient,xformers
20
- Compiled (default),torch_flash_compiled_default,flux_L128,1,1152,24,128,bfloat16,0.5181439876556396,0.5141760110855103,0.5175679922103882,0.5197759866714478,5,83.38,FLASH,torch-sdpa
21
- Compiled (default),torch_flash_compiled_default,flux_L256,1,1280,24,128,bfloat16,0.5579584002494812,0.5549119710922241,0.5582720041275024,0.5598080158233643,5,90.62,FLASH,torch-sdpa
22
- Compiled (default),torch_flash_compiled_default,flux_L320,1,1344,24,128,bfloat16,0.6872959971427918,0.6853119730949402,0.687391996383667,0.6883519887924194,5,95.25,FLASH,torch-sdpa
23
- Compiled (default),torch_flash_compiled_default,flux_L384,1,1408,24,128,bfloat16,0.716153597831726,0.7128639817237854,0.7160959839820862,0.7167680263519287,5,99.88,FLASH,torch-sdpa
24
- Compiled (default),torch_flash_compiled_default,flux_L448,1,1472,24,128,bfloat16,0.7418303966522217,0.7386879920959473,0.7400959730148315,0.7415040135383606,5,103.81,FLASH,torch-sdpa
25
- Compiled (default),torch_flash_compiled_default,flux_L512,1,1536,24,128,bfloat16,0.7745471954345703,0.7708160281181335,0.7740799784660339,0.7753919959068298,5,109.12,FLASH,torch-sdpa
26
- Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L128,1,1152,24,128,bfloat16,0.6468096017837525,0.6144000291824341,0.6245759725570679,0.6483200192451477,5,67.5,FLASH,torch-sdpa
27
- Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L256,1,1280,24,128,bfloat16,0.7060160160064697,0.6689280271530151,0.6851199865341187,0.7184960246086121,5,75.0,FLASH,torch-sdpa
28
- Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L320,1,1344,24,128,bfloat16,0.8332608103752136,0.7953600287437439,0.8155840039253235,0.8403519988059998,5,80.38,FLASH,torch-sdpa
29
- Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L384,1,1408,24,128,bfloat16,0.8719295978546142,0.8470720052719116,0.849727988243103,0.8745279908180237,5,82.5,FLASH,torch-sdpa
30
- Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L448,1,1472,24,128,bfloat16,0.9034304022789001,0.8677120208740234,0.8835520148277283,0.9034240245819092,5,86.25,FLASH,torch-sdpa
31
- Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L512,1,1536,24,128,bfloat16,0.9387519836425782,0.9154239892959595,0.9213759899139404,0.9359679818153381,5,90.0,FLASH,torch-sdpa
32
- HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L128,1,1152,24,128,bfloat16,0.3455295979976654,0.34355199337005615,0.34563198685646057,0.34643200039863586,5,83.38,flash-attn,hf-kernels
33
- HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L256,1,1280,24,128,bfloat16,0.3756160080432892,0.37411201000213623,0.3752000033855438,0.3770880103111267,5,90.62,flash-attn,hf-kernels
34
- HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L320,1,1344,24,128,bfloat16,0.4953216016292572,0.49324798583984375,0.49433600902557373,0.49663999676704407,5,95.06,flash-attn,hf-kernels
35
- HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L384,1,1408,24,128,bfloat16,0.5157055854797363,0.5142719745635986,0.516319990158081,0.516543984413147,5,99.88,flash-attn,hf-kernels
36
- HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L448,1,1472,24,128,bfloat16,0.5356672048568726,0.5346879959106445,0.5358080267906189,0.5361599922180176,5,103.81,flash-attn,hf-kernels
37
- HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L512,1,1536,24,128,bfloat16,0.5587136030197144,0.5557760000228882,0.5574079751968384,0.5581120252609253,5,109.12,flash-attn,hf-kernels
38
- HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L128,1,1152,24,128,bfloat16,0.3619711995124817,0.3603839874267578,0.361952006816864,0.3624640107154846,5,83.38,flash-attn3,hf-kernels
39
- HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L256,1,1280,24,128,bfloat16,0.3912447988986969,0.3892799913883209,0.3909760117530823,0.3922559916973114,5,90.62,flash-attn3,hf-kernels
40
- HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L320,1,1344,24,128,bfloat16,0.5258048176765442,0.5240640044212341,0.5248960256576538,0.5248960256576538,5,95.06,flash-attn3,hf-kernels
41
- HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L384,1,1408,24,128,bfloat16,0.5276032090187073,0.5265600085258484,0.5277760028839111,0.5282559990882874,5,99.88,flash-attn3,hf-kernels
42
- HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L448,1,1472,24,128,bfloat16,0.5656383991241455,0.5639039874076843,0.5657920241355896,0.5668479800224304,5,103.81,flash-attn3,hf-kernels
43
- HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L512,1,1536,24,128,bfloat16,0.5789952039718628,0.5689600110054016,0.5698239803314209,0.5713919997215271,5,109.12,flash-attn3,hf-kernels
flash_attn/results/artifacts/combine/latency.png DELETED

Git LFS Details

  • SHA256: 87dbea8f2773d7fcee9fd191cb6e67cd1e2ddd379cef90ee01bb4ac40a55b5f1
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
flash_attn/results/artifacts/combine/latency.svg DELETED

Git LFS Details

  • SHA256: 2c1da56080e7fd1a85c14295083b11d6bac981f6fb3faef98b0753eb2c1676c7
  • Pointer size: 130 Bytes
  • Size of remote file: 28.2 kB
flash_attn/results/cells/combine.py DELETED
@@ -1,319 +0,0 @@
1
- # /// script
2
- # requires-python = ">=3.10"
3
- # dependencies = [
4
- # "numpy",
5
- # "torch",
6
- # "kernels-benchmark-tools",
7
- # "matplotlib",
8
- # ]
9
- #
10
- # [tool.uv.sources]
11
- # kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
12
- # ///
13
- import os
14
- import sys
15
- from pathlib import Path
16
- import json
17
- import torch # noqa: F401 # imported because upstream may expect torch to be importable
18
- import kernels_benchmark_tools as kbt
19
-
20
- # --- Matplotlib setup and helpers ------------------------------------------------
21
- import matplotlib as mpl
22
- import matplotlib.pyplot as plt
23
- import csv
24
-
25
-
26
- # Keep text as text (not paths) so CSS can style fonts, size, etc.
27
- mpl.rcParams["svg.fonttype"] = "none"
28
- # Make ids deterministic across builds
29
- mpl.rcParams["svg.hashsalt"] = "latency-benchmark-combined"
30
- # Avoid auto-closed figures interfering with our tagging
31
- mpl.rcParams["figure.autolayout"] = True
32
- # Make background transparent
33
- mpl.rcParams["figure.facecolor"] = "none"
34
- mpl.rcParams["axes.facecolor"] = "none"
35
- mpl.rcParams["savefig.facecolor"] = "none"
36
- mpl.rcParams["savefig.edgecolor"] = "none"
37
-
38
- def _slugify(s: str) -> str:
39
- s = (s or "").strip().lower()
40
- keep = []
41
- for ch in s:
42
- if ch.isalnum():
43
- keep.append(ch)
44
- elif ch in (" ", "-", "_", "/", ".", ":"):
45
- keep.append("-")
46
- else:
47
- keep.append("")
48
- out = "".join(keep)
49
- while "--" in out:
50
- out = out.replace("--", "-")
51
- return out.strip("-") or "unnamed"
52
-
53
- def _tag_current_figure(default_series_prefix="series"):
54
- """Attach SVG ids (gid) to key artists so they can be targeted from CSS."""
55
- fig = plt.gcf()
56
- if fig is None:
57
- return
58
-
59
- # Tag the figure itself
60
- fig.set_gid("figure--latency")
61
-
62
- for ax_idx, ax in enumerate(fig.get_axes(), start=1):
63
- ax.set_gid(f"axes--{ax_idx}")
64
-
65
- # Axis labels & title
66
- if ax.get_title():
67
- for t in ax.texts:
68
- if t.get_text() == ax.get_title():
69
- t.set_gid("title--main")
70
- if ax.xaxis and ax.xaxis.get_label():
71
- ax.xaxis.label.set_gid("label--x")
72
- if ax.yaxis and ax.yaxis.get_label():
73
- ax.yaxis.label.set_gid("label--y")
74
-
75
- # Gridlines
76
- for i, gl in enumerate(ax.get_xgridlines(), start=1):
77
- gl.set_gid(f"grid-x--{i}")
78
- for i, gl in enumerate(ax.get_ygridlines(), start=1):
79
- gl.set_gid(f"grid-y--{i}")
80
-
81
- # Legend block & entries
82
- leg = ax.get_legend()
83
- if leg is not None:
84
- leg.set_gid("legend")
85
- for i, txt in enumerate(leg.get_texts(), start=1):
86
- label_slug = _slugify(txt.get_text())
87
- txt.set_gid(f"legend-label--{label_slug or i}")
88
-
89
- # Series (lines, patches)
90
- # Lines
91
- line_seen = {}
92
- for ln in getattr(ax, "lines", []):
93
- raw_label = ln.get_label() or ""
94
- # Matplotlib uses labels beginning with "_" for non-legendable items
95
- label = raw_label if not raw_label.startswith("_") else f"{default_series_prefix}"
96
- slug = _slugify(label)
97
- line_seen[slug] = line_seen.get(slug, 0) + 1
98
- suffix = "" if line_seen[slug] == 1 else f"-{line_seen[slug]}"
99
- ln.set_gid(f"series--{slug}{suffix}")
100
-
101
- # Patches (bars, areas)
102
- patch_seen = {}
103
- for pt in getattr(ax, "patches", []):
104
- label = getattr(pt, "get_label", lambda: "")() or f"{default_series_prefix}"
105
- if isinstance(label, str) and label.startswith("_"):
106
- label = default_series_prefix
107
- slug = _slugify(label)
108
- patch_seen[slug] = patch_seen.get(slug, 0) + 1
109
- suffix = "" if patch_seen[slug] == 1 else f"-{patch_seen[slug]}"
110
- pt.set_gid(f"series--{slug}{suffix}")
111
-
112
- def _postprocess_svg_add_classes(svg_path: Path):
113
- """Add convenient CSS classes alongside ids (e.g., class='series grid grid-x')."""
114
- try:
115
- import xml.etree.ElementTree as ET
116
- ET.register_namespace("", "http://www.w3.org/2000/svg")
117
- tree = ET.parse(svg_path)
118
- root = tree.getroot()
119
- for el in root.iter():
120
- el_id = el.attrib.get("id", "")
121
- if not el_id:
122
- continue
123
- cls = []
124
- if el_id.startswith("figure--"):
125
- cls.append("figure")
126
- elif el_id.startswith("axes--"):
127
- cls.append("axes")
128
- elif el_id.startswith("grid-x--"):
129
- cls += ["grid", "grid-x"]
130
- elif el_id.startswith("grid-y--"):
131
- cls += ["grid", "grid-y"]
132
- elif el_id.startswith("legend"):
133
- cls.append("legend")
134
- elif el_id.startswith("label--x"):
135
- cls.append("xlabel")
136
- elif el_id.startswith("label--y"):
137
- cls.append("ylabel")
138
- elif el_id.startswith("title--"):
139
- cls.append("title")
140
- elif el_id.startswith("series--"):
141
- cls.append("series")
142
- if cls:
143
- # Preserve any existing class (unlikely from Matplotlib)
144
- existing = el.attrib.get("class", "")
145
- el.set("class", (existing + " " + " ".join(cls)).strip())
146
- tree.write(svg_path, encoding="utf-8", xml_declaration=True)
147
- except Exception as e:
148
- print(f"✗ SVG postprocess (classes) skipped: {e}")
149
-
150
- # Monkey-patch savefig to force SVG & ensure tagging occurs even if kbt.viz saves internally.
151
- _orig_savefig = plt.savefig
152
- def _savefig_svg(fname, *args, **kwargs):
153
- # Always save as SVG at a stable path for the artifact system
154
- out = Path("latency.svg")
155
- kwargs["format"] = "svg"
156
- # Ensure everything we care about has ids before export
157
- _tag_current_figure()
158
- res = _orig_savefig(out, *args, **kwargs)
159
- # Add helpful CSS classes on top of ids
160
- _postprocess_svg_add_classes(out)
161
- print(f"✓ Combined visualization saved as {out}")
162
- return res
163
-
164
- plt.savefig = _savefig_svg # apply patch
165
-
166
- # Capture close calls in case kbt.viz() closes figures before we re-save
167
- _orig_close = plt.close
168
- _last_closed = {"fig": None}
169
- def _capture_close(arg=None):
170
- try:
171
- if hasattr(arg, "savefig"): # looks like a Figure
172
- _last_closed["fig"] = arg
173
- else:
174
- _last_closed["fig"] = plt.gcf()
175
- finally:
176
- return _orig_close(arg)
177
- plt.close = _capture_close
178
-
179
- # --- Locate benchmark artifacts --------------------------------------------------
180
- cache_dirs = {
181
- "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'),
182
- "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'),
183
- "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'),
184
- "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'),
185
- "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'),
186
- "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'),
187
- "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'),
188
- "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'),
189
- "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'),
190
- }
191
-
192
- print("LOADING BENCHMARK DATA")
193
- for name, cache_dir in cache_dirs.items():
194
- print(f"{name:30s}: {cache_dir}")
195
- print()
196
-
197
- file_mapping = {
198
- "Flash (PyTorch SDPA)": "attn.jsonl",
199
- "MemEff (PyTorch SDPA)": "attn.jsonl",
200
- "Flash Attn 2": "attn.jsonl",
201
- "xFormers": "attn.jsonl",
202
- "SageAttention": "attn.jsonl",
203
- "Compiled (default)": "attn_default.jsonl",
204
- "Compiled (max-autotune)": "attn_max_autotune.jsonl",
205
- "HF Kernels Flash Attn": "attn.jsonl",
206
- "HF Kernels Flash Attn3": "attn.jsonl",
207
- }
208
-
209
- all_paths = []
210
- for name, cache_dir in cache_dirs.items():
211
- if cache_dir:
212
- path = Path(cache_dir) / file_mapping[name]
213
- if path.exists() and path.stat().st_size > 0:
214
- all_paths.append(str(path))
215
- print(f"✓ Found {name}: {path}")
216
- else:
217
- print(f"⊘ Empty/Missing {name}: {path}")
218
- else:
219
- print(f"✗ No cache dir for {name}")
220
- print()
221
-
222
- if not all_paths:
223
- print("ERROR: No benchmark data files found!")
224
- # restore patched functions before exiting
225
- plt.savefig = _orig_savefig
226
- plt.close = _orig_close
227
- sys.exit(1)
228
-
229
- # --- Summary + Visualization -----------------------------------------------------
230
- print("COMBINED BENCHMARK SUMMARY\n")
231
- kbt.summarize(all_paths)
232
- print("\nGENERATING COMBINED VISUALIZATION\n")
233
-
234
- try:
235
- # If kbt.viz saves internally, our patched savefig ensures SVG gets written,
236
- # and it will carry ids/classes for CSS styling.
237
- kbt.viz(all_paths)
238
- # Safety net: if kbt.viz didn't save, save now.
239
- # if not Path("latency.svg").exists():
240
- # _tag_current_figure()
241
- # plt.savefig("latency.svg")
242
-
243
- plt.savefig("latency.svg") # ensure saved with tagging
244
-
245
- print("✓ SVG visualization ready: latency.svg!")
246
- except ImportError as e:
247
- print(f"✗ Visualization requires matplotlib: {e}")
248
- except Exception as e:
249
- print(f"✗ Visualization failed: {e}")
250
- finally:
251
- # Clean up patches to avoid side effects in later cells
252
- plt.savefig = _orig_savefig
253
- plt.close = _orig_close
254
-
255
- print()
256
- print("ANALYSIS COMPLETE")
257
- print(f"Total implementations analyzed: {len(all_paths)}")
258
- print(f"\nImplementations included:")
259
- for name, cache_dir in cache_dirs.items():
260
- if cache_dir:
261
- path = Path(cache_dir) / file_mapping[name]
262
- if path.exists() and path.stat().st_size > 0:
263
- print(f" ✓ {name}")
264
-
265
-
266
-
267
- # Collect all benchmark data and export to CSV
268
- all_data = {}
269
- for name, cache_dir in cache_dirs.items():
270
- if cache_dir:
271
- path = Path(cache_dir) / file_mapping[name]
272
- if path.exists() and path.stat().st_size > 0:
273
- with open(path, 'r') as f:
274
- records = [json.loads(line) for line in f]
275
- all_data[name] = records
276
-
277
- # Export to CSV
278
- csv_path = Path("latency.csv")
279
- with open(csv_path, 'w', newline='') as csvfile:
280
- writer = csv.writer(csvfile)
281
-
282
- # Write header
283
- header = ["Implementation", "Impl ID", "Workload", "Batch", "Seq Length", "Heads", "Head Dim", "Dtype",
284
- "Mean (ms)", "P10 (ms)", "P50 (ms)", "P90 (ms)", "Reps",
285
- # "Compile (ms)",
286
- "Peak Mem (MB)", "Backend", "Family"]
287
- writer.writerow(header)
288
-
289
- # Write data rows
290
- for impl_name, records in all_data.items():
291
- for record in records:
292
- wl = record.get('wl', {})
293
- lat = record.get('lat_ms', {})
294
- tags = record.get('tags', {})
295
-
296
- row = [
297
- impl_name,
298
- record.get('impl', ''),
299
- wl.get('name', ''),
300
- wl.get('batch', ''),
301
- wl.get('seq_len', ''),
302
- wl.get('heads', ''),
303
- wl.get('head_dim', ''),
304
- wl.get('dtype', ''),
305
- lat.get('mean', ''),
306
- lat.get('p10', ''),
307
- lat.get('p50', ''),
308
- lat.get('p90', ''),
309
- lat.get('reps', ''),
310
- # record.get('compile_ms', ''),
311
- round(record.get('peak_bytes', 0) / 1024 / 1024, 2) if record.get('peak_bytes') else '',
312
- tags.get('backend', ''),
313
- tags.get('family', ''),
314
- ]
315
- writer.writerow(row)
316
-
317
- print(f"✓ CSV export complete: {csv_path}")
318
- print(f"Total implementations: {len(all_data)}")
319
- print(f"Total records: {sum(len(records) for records in all_data.values())}")
flash_attn/results/combined_results.html DELETED
The diff for this file is too large to render. See raw diff
 
flash_attn/results/index.html DELETED
@@ -1,88 +0,0 @@
1
- <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <meta charset='UTF-8'>
5
- <meta name='viewport' content='width=device-width, initial-scale=1.0'>
6
- <title>Index of /flash_attn/results</title>
7
- <style>
8
- :root {
9
- --bg-primary: #0a0a0a;
10
- --bg-secondary: #121212;
11
- --bg-tertiary: #181818;
12
- --text-primary: #e0e0e0;
13
- --text-secondary: #888888;
14
- --text-link: #64b5f6;
15
- --border-primary: #2a2a2a;
16
- }
17
- body {
18
- font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
19
- background: var(--bg-primary);
20
- color: var(--text-primary);
21
- margin: 0;
22
- padding: 16px;
23
- max-width: 900px;
24
- margin: 0 auto;
25
- }
26
- .controls {
27
- display: flex;
28
- justify-content: flex-end;
29
- margin-bottom: 1rem;
30
- }
31
- .back-button {
32
- background: var(--bg-secondary);
33
- border: 1px solid var(--border-primary);
34
- padding: 8px 12px;
35
- border-radius: 4px;
36
- color: var(--text-secondary);
37
- cursor: pointer;
38
- font-size: 0.9rem;
39
- text-decoration: none;
40
- display: inline-block;
41
- }
42
- .back-button:hover {
43
- color: var(--text-primary);
44
- background: var(--bg-tertiary);
45
- }
46
- h1 {
47
- font-size: 1.5em;
48
- margin: 1rem 0;
49
- color: var(--text-primary);
50
- border-bottom: 1px solid var(--border-primary);
51
- padding-bottom: 0.5rem;
52
- }
53
- ul {
54
- list-style-type: none;
55
- padding: 0;
56
- }
57
- li {
58
- margin: 0;
59
- border-bottom: 1px solid var(--border-primary);
60
- }
61
- li:last-child {
62
- border-bottom: none;
63
- }
64
- a {
65
- display: block;
66
- padding: 0.75rem 0.5rem;
67
- text-decoration: none;
68
- color: var(--text-link);
69
- transition: background 0.2s ease;
70
- }
71
- a:hover {
72
- background: var(--bg-secondary);
73
- }
74
- .dir {
75
- font-weight: 500;
76
- }
77
- </style>
78
- </head>
79
- <body>
80
- <div class='controls'>
81
- <a href='../index.html' class='back-button'>← back</a>
82
- </div>
83
- <h1>Index of /flash_attn/results</h1>
84
- <ul>
85
- <li><a href='combined_results.html' class='file'>combined_results.html</a></li>
86
- </ul>
87
- </body>
88
- </html>
index.html DELETED
@@ -1,85 +0,0 @@
1
- <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <meta charset='UTF-8'>
5
- <meta name='viewport' content='width=device-width, initial-scale=1.0'>
6
- <title>Index of /</title>
7
- <style>
8
- :root {
9
- --bg-primary: #0a0a0a;
10
- --bg-secondary: #121212;
11
- --bg-tertiary: #181818;
12
- --text-primary: #e0e0e0;
13
- --text-secondary: #888888;
14
- --text-link: #64b5f6;
15
- --border-primary: #2a2a2a;
16
- }
17
- body {
18
- font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
19
- background: var(--bg-primary);
20
- color: var(--text-primary);
21
- margin: 0;
22
- padding: 16px;
23
- max-width: 900px;
24
- margin: 0 auto;
25
- }
26
- .controls {
27
- display: flex;
28
- justify-content: flex-end;
29
- margin-bottom: 1rem;
30
- }
31
- .back-button {
32
- background: var(--bg-secondary);
33
- border: 1px solid var(--border-primary);
34
- padding: 8px 12px;
35
- border-radius: 4px;
36
- color: var(--text-secondary);
37
- cursor: pointer;
38
- font-size: 0.9rem;
39
- text-decoration: none;
40
- display: inline-block;
41
- }
42
- .back-button:hover {
43
- color: var(--text-primary);
44
- background: var(--bg-tertiary);
45
- }
46
- h1 {
47
- font-size: 1.5em;
48
- margin: 1rem 0;
49
- color: var(--text-primary);
50
- border-bottom: 1px solid var(--border-primary);
51
- padding-bottom: 0.5rem;
52
- }
53
- ul {
54
- list-style-type: none;
55
- padding: 0;
56
- }
57
- li {
58
- margin: 0;
59
- border-bottom: 1px solid var(--border-primary);
60
- }
61
- li:last-child {
62
- border-bottom: none;
63
- }
64
- a {
65
- display: block;
66
- padding: 0.75rem 0.5rem;
67
- text-decoration: none;
68
- color: var(--text-link);
69
- transition: background 0.2s ease;
70
- }
71
- a:hover {
72
- background: var(--bg-secondary);
73
- }
74
- .dir {
75
- font-weight: 500;
76
- }
77
- </style>
78
- </head>
79
- <body>
80
- <h1>Index of /</h1>
81
- <ul>
82
- <li><a href='flash_attn/index.html' class='dir'>flash_attn/</a></li>
83
- </ul>
84
- </body>
85
- </html>
megablocks/cells/forward_and_backward.py DELETED
@@ -1,196 +0,0 @@
1
- # /// script
2
- # requires-python = ">=3.12"
3
- # dependencies = [
4
- # "accelerate>=1.10.1",
5
- # "torch>=2.7.0",
6
- # "kernels==0.10.0",
7
- # "transformers@https://github.com/huggingface/transformers.git",
8
- # "ipdb>=0.13.13",
9
- # "matplotlib>=3.7.2",
10
- # "numpy>=1.24.3",
11
- # ]
12
- # ///
13
-
14
- import torch
15
- from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
16
- import time
17
- import torch.nn as nn
18
- from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
19
- import sys
20
- import torch.profiler
21
- import gc
22
- import logging
23
- from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
24
-
25
- # remove liger kernel for testing
26
- replace_kernel_forward_from_hub(GptOssRMSNorm, None)
27
-
28
- # set to debug logging
29
- logging.basicConfig(level=logging.INFO)
30
-
31
- def reset_peak_memory_stats():
32
- """Clear CUDA cache and reset memory allocation counters."""
33
- torch.cuda.empty_cache()
34
- if torch.cuda.is_available():
35
- torch.cuda.reset_peak_memory_stats()
36
- gc.collect()
37
-
38
- def get_memory_stats():
39
- """Get current and peak CUDA memory usage."""
40
- if not torch.cuda.is_available():
41
- return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
42
- return {
43
- "allocated_gb": torch.cuda.memory_allocated() / 1e9,
44
- "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
45
- "reserved_gb": torch.cuda.memory_reserved() / 1e9,
46
- }
47
-
48
- def override_kernel_layer_name(cls_name: str, value) -> bool:
49
- """Helper to dynamically override the kernel_layer_name in a model class."""
50
- for mod in sys.modules.values():
51
- if mod is None:
52
- continue
53
- obj = getattr(mod, cls_name, None)
54
- if isinstance(obj, type) and issubclass(obj, nn.Module):
55
- setattr(obj, "kernel_layer_name", value)
56
- print(f"Overrode {cls_name}.kernel_layer_name to {value}")
57
- return True
58
- return False
59
-
60
-
61
- # Init the model the normal way
62
- model_id = "openai/gpt-oss-20b"
63
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
64
- quantization_config = Mxfp4Config(dequantize=True)
65
-
66
- model = GptOssForCausalLM.from_pretrained(
67
- model_id,
68
- dtype="bfloat16",
69
- device_map="auto",
70
- use_kernels=True,
71
- quantization_config=quantization_config,
72
- ).eval()
73
-
74
- messages = [
75
- {"role": "system", "content": "What is Tensor Parallelism?"},
76
- ]
77
-
78
- inputs = tokenizer.apply_chat_template(
79
- messages,
80
- add_generation_prompt=True,
81
- return_tensors="pt",
82
- return_dict=True,
83
- reasoning_effort="low",
84
- ).to("cuda")
85
-
86
- max_tokens = 128 # Reduced to help with memory usage
87
-
88
- # Clear memory before backward pass
89
- reset_peak_memory_stats()
90
- print(f"Pre-generation memory: {get_memory_stats()}")
91
-
92
- # forward and backward pass
93
- with torch.autograd.set_grad_enabled(True):
94
- start_time = time.perf_counter()
95
- generated = model.generate(
96
- **inputs,
97
- max_new_tokens=max_tokens,
98
- do_sample=False,
99
- temperature=None,
100
- )
101
- end_time = time.perf_counter()
102
- print(tokenizer.decode(generated[0], skip_special_tokens=False))
103
- print(f"Generation took {end_time - start_time:.2f} seconds")
104
- print(f"Post-generation memory: {get_memory_stats()}")
105
-
106
- # Use gradient checkpointing to reduce memory usage
107
- if hasattr(model, 'gradient_checkpointing_enable'):
108
- model.gradient_checkpointing_enable()
109
- print("Enabled gradient checkpointing")
110
-
111
- # Reduce sequence length if needed for memory
112
- max_seq_len = 512 # Limit sequence length for backward pass
113
- if generated.size(1) > max_seq_len:
114
- print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
115
- full_sequence = generated[:, -max_seq_len:]
116
- else:
117
- full_sequence = generated
118
-
119
- # Get model outputs for the full sequence
120
- model.train() # Enable dropout and other training behaviors
121
-
122
- try:
123
- outputs = model(
124
- input_ids=full_sequence,
125
- labels=full_sequence, # This will compute loss internally
126
- return_dict=True
127
- )
128
- print(f"Post-forward memory: {get_memory_stats()}")
129
-
130
- # If model doesn't compute loss, compute it manually
131
- if outputs.loss is None:
132
- shift_logits = outputs.logits[..., :-1, :].contiguous()
133
- shift_labels = full_sequence[..., 1:].contiguous()
134
-
135
- # Use CrossEntropyLoss with ignore_index for padding tokens
136
- loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100)
137
- loss = loss_fct(
138
- shift_logits.view(-1, shift_logits.size(-1)),
139
- shift_labels.view(-1)
140
- )
141
- else:
142
- loss = outputs.loss
143
-
144
- print(f"Loss: {loss.item():.4f}")
145
-
146
- # Clear intermediate tensors to save memory
147
- del outputs
148
- torch.cuda.empty_cache()
149
-
150
- # Perform backward pass with memory management
151
- print("Running backward pass...")
152
- print(f"Pre-backward memory: {get_memory_stats()}")
153
-
154
- loss.backward()
155
- print(f"Post-backward memory: {get_memory_stats()}")
156
-
157
- except torch.cuda.OutOfMemoryError as e:
158
- print(f"OOM during forward/backward pass: {e}")
159
- print("Try reducing max_tokens or max_seq_len")
160
- raise
161
-
162
- # Calculate gradient statistics and print sample gradients
163
- total_norm = 0.0
164
- param_count = 0
165
- grad_samples = {}
166
-
167
- for name, p in model.named_parameters():
168
- if p.grad is not None:
169
- param_count += 1
170
- grad_norm = p.grad.data.norm(2).item()
171
- total_norm += grad_norm ** 2
172
-
173
- # Collect gradient statistics for key layers
174
- if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
175
- grad_samples[name] = {
176
- 'norm': grad_norm,
177
- 'mean': p.grad.data.mean().item(),
178
- 'std': p.grad.data.std().item(),
179
- 'max': p.grad.data.max().item(),
180
- 'min': p.grad.data.min().item(),
181
- }
182
-
183
- total_norm = total_norm ** 0.5
184
-
185
- print(f"\nGradient norm: {total_norm:.4f}")
186
- print(f"Parameters with gradients: {param_count}")
187
-
188
- # Print sample gradients from important layers
189
- print("\nSample gradient statistics:")
190
- for i, (name, stats) in enumerate(list(grad_samples.items())[:10]):
191
- print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")
192
-
193
- # Optional: zero gradients for next iteration
194
- model.zero_grad()
195
- model.eval() # Switch back to eval mode
196
-
megablocks/cells/forward_and_backward_no_kernel.py DELETED
@@ -1,196 +0,0 @@
1
- # /// script
2
- # requires-python = ">=3.12"
3
- # dependencies = [
4
- # "accelerate>=1.10.1",
5
- # "torch>=2.7.0",
6
- # "kernels==0.10.0",
7
- # "transformers@https://github.com/huggingface/transformers.git",
8
- # "ipdb>=0.13.13",
9
- # "matplotlib>=3.7.2",
10
- # "numpy>=1.24.3",
11
- # ]
12
- # ///
13
-
14
- import torch
15
- from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
16
- import time
17
- import torch.nn as nn
18
- from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
19
- import sys
20
- import torch.profiler
21
- import gc
22
- import logging
23
- from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
24
-
25
- # remove liger kernel for testing
26
- replace_kernel_forward_from_hub(GptOssRMSNorm, None)
27
-
28
- # set to debug logging
29
- logging.basicConfig(level=logging.INFO)
30
-
31
- def reset_peak_memory_stats():
32
- """Clear CUDA cache and reset memory allocation counters."""
33
- torch.cuda.empty_cache()
34
- if torch.cuda.is_available():
35
- torch.cuda.reset_peak_memory_stats()
36
- gc.collect()
37
-
38
- def get_memory_stats():
39
- """Get current and peak CUDA memory usage."""
40
- if not torch.cuda.is_available():
41
- return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
42
- return {
43
- "allocated_gb": torch.cuda.memory_allocated() / 1e9,
44
- "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
45
- "reserved_gb": torch.cuda.memory_reserved() / 1e9,
46
- }
47
-
48
- def override_kernel_layer_name(cls_name: str, value) -> bool:
49
- """Helper to dynamically override the kernel_layer_name in a model class."""
50
- for mod in sys.modules.values():
51
- if mod is None:
52
- continue
53
- obj = getattr(mod, cls_name, None)
54
- if isinstance(obj, type) and issubclass(obj, nn.Module):
55
- setattr(obj, "kernel_layer_name", value)
56
- print(f"Overrode {cls_name}.kernel_layer_name to {value}")
57
- return True
58
- return False
59
-
60
-
61
- # Init the model the normal way
62
- model_id = "openai/gpt-oss-20b"
63
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
64
- quantization_config = Mxfp4Config(dequantize=True)
65
-
66
- model = GptOssForCausalLM.from_pretrained(
67
- model_id,
68
- dtype="bfloat16",
69
- device_map="auto",
70
- use_kernels=False,
71
- quantization_config=quantization_config,
72
- ).eval()
73
-
74
- messages = [
75
- {"role": "system", "content": "What is Tensor Parallelism?"},
76
- ]
77
-
78
- inputs = tokenizer.apply_chat_template(
79
- messages,
80
- add_generation_prompt=True,
81
- return_tensors="pt",
82
- return_dict=True,
83
- reasoning_effort="low",
84
- ).to("cuda")
85
-
86
- max_tokens = 128 # Reduced to help with memory usage
87
-
88
- # Clear memory before backward pass
89
- reset_peak_memory_stats()
90
- print(f"Pre-generation memory: {get_memory_stats()}")
91
-
92
- # forward and backward pass
93
- with torch.autograd.set_grad_enabled(True):
94
- start_time = time.perf_counter()
95
- generated = model.generate(
96
- **inputs,
97
- max_new_tokens=max_tokens,
98
- do_sample=False,
99
- temperature=None,
100
- )
101
- end_time = time.perf_counter()
102
- print(tokenizer.decode(generated[0], skip_special_tokens=False))
103
- print(f"Generation took {end_time - start_time:.2f} seconds")
104
- print(f"Post-generation memory: {get_memory_stats()}")
105
-
106
- # Use gradient checkpointing to reduce memory usage
107
- if hasattr(model, 'gradient_checkpointing_enable'):
108
- model.gradient_checkpointing_enable()
109
- print("Enabled gradient checkpointing")
110
-
111
- # Reduce sequence length if needed for memory
112
- max_seq_len = 512 # Limit sequence length for backward pass
113
- if generated.size(1) > max_seq_len:
114
- print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
115
- full_sequence = generated[:, -max_seq_len:]
116
- else:
117
- full_sequence = generated
118
-
119
- # Get model outputs for the full sequence
120
- model.train() # Enable dropout and other training behaviors
121
-
122
- try:
123
- outputs = model(
124
- input_ids=full_sequence,
125
- labels=full_sequence, # This will compute loss internally
126
- return_dict=True
127
- )
128
- print(f"Post-forward memory: {get_memory_stats()}")
129
-
130
- # If model doesn't compute loss, compute it manually
131
- if outputs.loss is None:
132
- shift_logits = outputs.logits[..., :-1, :].contiguous()
133
- shift_labels = full_sequence[..., 1:].contiguous()
134
-
135
- # Use CrossEntropyLoss with ignore_index for padding tokens
136
- loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100)
137
- loss = loss_fct(
138
- shift_logits.view(-1, shift_logits.size(-1)),
139
- shift_labels.view(-1)
140
- )
141
- else:
142
- loss = outputs.loss
143
-
144
- print(f"Loss: {loss.item():.4f}")
145
-
146
- # Clear intermediate tensors to save memory
147
- del outputs
148
- torch.cuda.empty_cache()
149
-
150
- # Perform backward pass with memory management
151
- print("Running backward pass...")
152
- print(f"Pre-backward memory: {get_memory_stats()}")
153
-
154
- loss.backward()
155
- print(f"Post-backward memory: {get_memory_stats()}")
156
-
157
- except torch.cuda.OutOfMemoryError as e:
158
- print(f"OOM during forward/backward pass: {e}")
159
- print("Try reducing max_tokens or max_seq_len")
160
- raise
161
-
162
- # Calculate gradient statistics and print sample gradients
163
- total_norm = 0.0
164
- param_count = 0
165
- grad_samples = {}
166
-
167
- for name, p in model.named_parameters():
168
- if p.grad is not None:
169
- param_count += 1
170
- grad_norm = p.grad.data.norm(2).item()
171
- total_norm += grad_norm ** 2
172
-
173
- # Collect gradient statistics for key layers
174
- if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
175
- grad_samples[name] = {
176
- 'norm': grad_norm,
177
- 'mean': p.grad.data.mean().item(),
178
- 'std': p.grad.data.std().item(),
179
- 'max': p.grad.data.max().item(),
180
- 'min': p.grad.data.min().item(),
181
- }
182
-
183
- total_norm = total_norm ** 0.5
184
-
185
- print(f"\nGradient norm: {total_norm:.4f}")
186
- print(f"Parameters with gradients: {param_count}")
187
-
188
- # Print sample gradients from important layers
189
- print("\nSample gradient statistics:")
190
- for i, (name, stats) in enumerate(list(grad_samples.items())[:10]):
191
- print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")
192
-
193
- # Optional: zero gradients for next iteration
194
- model.zero_grad()
195
- model.eval() # Switch back to eval mode
196
-
megablocks/cells/forward_only.py DELETED
@@ -1,101 +0,0 @@
1
- # /// script
2
- # requires-python = ">=3.12"
3
- # dependencies = [
4
- # "accelerate>=1.10.1",
5
- # "torch>=2.7.0",
6
- # "kernels==0.10.0",
7
- # "transformers@https://github.com/huggingface/transformers.git",
8
- # "ipdb>=0.13.13",
9
- # "matplotlib>=3.7.2",
10
- # "numpy>=1.24.3",
11
- # ]
12
- # ///
13
-
14
- import torch
15
- from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
16
- import time
17
- import torch.nn as nn
18
- from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
19
- import sys
20
- import torch.profiler
21
- import gc
22
- import logging
23
- from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
24
-
25
-
26
- replace_kernel_forward_from_hub(GptOssRMSNorm, None)
27
-
28
- # set to debug logging
29
- logging.basicConfig(level=logging.INFO)
30
-
31
- def reset_peak_memory_stats():
32
- """Clear CUDA cache and reset memory allocation counters."""
33
- torch.cuda.empty_cache()
34
- if torch.cuda.is_available():
35
- torch.cuda.reset_peak_memory_stats()
36
- gc.collect()
37
-
38
- def get_memory_stats():
39
- """Get current and peak CUDA memory usage."""
40
- if not torch.cuda.is_available():
41
- return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
42
- return {
43
- "allocated_gb": torch.cuda.memory_allocated() / 1e9,
44
- "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
45
- "reserved_gb": torch.cuda.memory_reserved() / 1e9,
46
- }
47
-
48
- def override_kernel_layer_name(cls_name: str, value) -> bool:
49
- """Helper to dynamically override the kernel_layer_name in a model class."""
50
- for mod in sys.modules.values():
51
- if mod is None:
52
- continue
53
- obj = getattr(mod, cls_name, None)
54
- if isinstance(obj, type) and issubclass(obj, nn.Module):
55
- setattr(obj, "kernel_layer_name", value)
56
- print(f"Overrode {cls_name}.kernel_layer_name to {value}")
57
- return True
58
- return False
59
-
60
-
61
- # Init the model the normal way
62
- model_id = "openai/gpt-oss-20b"
63
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
64
- quantization_config = Mxfp4Config(dequantize=True)
65
-
66
-
67
-
68
- model = GptOssForCausalLM.from_pretrained(
69
- model_id,
70
- dtype="bfloat16",
71
- device_map="auto",
72
- use_kernels=True,
73
- quantization_config=quantization_config,
74
- ).eval()
75
-
76
- messages = [
77
- {"role": "system", "content": "What is Tensor Parallelism?"},
78
- ]
79
-
80
- inputs = tokenizer.apply_chat_template(
81
- messages,
82
- add_generation_prompt=True,
83
- return_tensors="pt",
84
- return_dict=True,
85
- reasoning_effort="low",
86
- ).to("cuda")
87
-
88
- max_tokens = 256
89
-
90
- with torch.inference_mode():
91
- start_time = time.perf_counter()
92
- generated = model.generate(
93
- **inputs,
94
- max_new_tokens=max_tokens,
95
- do_sample=False,
96
- temperature=None,
97
- )
98
- end_time = time.perf_counter()
99
-
100
- print(tokenizer.decode(generated[0], skip_special_tokens=False))
101
- print(f"Generation took {end_time - start_time:.2f} seconds")
megablocks/cells/no_kernels.py DELETED
@@ -1,98 +0,0 @@
1
- # /// script
2
- # requires-python = ">=3.12"
3
- # dependencies = [
4
- # "accelerate>=1.10.1",
5
- # "torch>=2.7.0",
6
- # "kernels==0.10.0",
7
- # "transformers@https://github.com/huggingface/transformers.git",
8
- # "ipdb>=0.13.13",
9
- # "matplotlib>=3.7.2",
10
- # "numpy>=1.24.3",
11
- # ]
12
- # ///
13
-
14
- import torch
15
- from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
16
- import time
17
- import torch.nn as nn
18
- from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
19
- import sys
20
- import torch.profiler
21
- import gc
22
- import logging
23
- from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
24
-
25
- # set to debug logging
26
- logging.basicConfig(level=logging.INFO)
27
-
28
- def reset_peak_memory_stats():
29
- """Clear CUDA cache and reset memory allocation counters."""
30
- torch.cuda.empty_cache()
31
- if torch.cuda.is_available():
32
- torch.cuda.reset_peak_memory_stats()
33
- gc.collect()
34
-
35
- def get_memory_stats():
36
- """Get current and peak CUDA memory usage."""
37
- if not torch.cuda.is_available():
38
- return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
39
- return {
40
- "allocated_gb": torch.cuda.memory_allocated() / 1e9,
41
- "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
42
- "reserved_gb": torch.cuda.memory_reserved() / 1e9,
43
- }
44
-
45
- def override_kernel_layer_name(cls_name: str, value) -> bool:
46
- """Helper to dynamically override the kernel_layer_name in a model class."""
47
- for mod in sys.modules.values():
48
- if mod is None:
49
- continue
50
- obj = getattr(mod, cls_name, None)
51
- if isinstance(obj, type) and issubclass(obj, nn.Module):
52
- setattr(obj, "kernel_layer_name", value)
53
- print(f"Overrode {cls_name}.kernel_layer_name to {value}")
54
- return True
55
- return False
56
-
57
-
58
- # Init the model the normal way
59
- model_id = "openai/gpt-oss-20b"
60
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
61
- quantization_config = Mxfp4Config(dequantize=True)
62
-
63
-
64
-
65
- model = GptOssForCausalLM.from_pretrained(
66
- model_id,
67
- dtype="bfloat16",
68
- device_map="auto",
69
- use_kernels=False,
70
- quantization_config=quantization_config,
71
- ).eval()
72
-
73
- messages = [
74
- {"role": "system", "content": "What is Tensor Parallelism?"},
75
- ]
76
-
77
- inputs = tokenizer.apply_chat_template(
78
- messages,
79
- add_generation_prompt=True,
80
- return_tensors="pt",
81
- return_dict=True,
82
- reasoning_effort="low",
83
- ).to("cuda")
84
-
85
- max_tokens = 256
86
-
87
- with torch.inference_mode():
88
- start_time = time.perf_counter()
89
- generated = model.generate(
90
- **inputs,
91
- max_new_tokens=max_tokens,
92
- do_sample=False,
93
- temperature=None,
94
- )
95
- end_time = time.perf_counter()
96
-
97
- print(tokenizer.decode(generated[0], skip_special_tokens=False))
98
- print(f"Generation took {end_time - start_time:.2f} seconds")
megablocks/cells/nv.py DELETED
@@ -1,3 +0,0 @@
1
- import subprocess
2
-
3
- print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
megablocks/index.html DELETED
@@ -1,24 +0,0 @@
1
- <!DOCTYPE html>
2
- <html>
3
- <head>
4
- <meta charset='UTF-8'>
5
- <title>Directory Index</title>
6
- <style>
7
- body { font-family: monospace; margin: 20px; }
8
- h1 { font-size: 1.5em; }
9
- ul { list-style-type: none; padding-left: 20px; }
10
- li { margin: 5px 0; }
11
- .dir { font-weight: bold; }
12
- .file { color: #0066cc; }
13
- a { text-decoration: none; }
14
- a:hover { text-decoration: underline; }
15
- </style>
16
- </head>
17
- <body>
18
- <h1>Index of /megablocks</h1>
19
- <ul>
20
- <li><a href='../index.html' class='dir'>../</a></li>
21
- <li><a href='megablocks_only.html' class='file'>megablocks_only.html</a></li>
22
- </ul>
23
- </body>
24
- </html>
megablocks/megablocks_only.html DELETED
The diff for this file is too large to render. See raw diff
 
megablocks_yamoe/artifacts/binned_run/binned_results.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "implementation": "binned_results",
3
- "config": {
4
- "warmup": 10,
5
- "iters": 50,
6
- "device": "cuda",
7
- "dtype": "torch.float32",
8
- "tokens": 100,
9
- "vary_inputs": true
10
- },
11
- "stats": {
12
- "avg_ms": 36.26809924006011,
13
- "min_ms": 34.103908000361116,
14
- "max_ms": 37.68557000057626,
15
- "std_ms": 1.1598518125118418,
16
- "p50_ms": 36.52223600056459,
17
- "p95_ms": 37.6427445000445,
18
- "p99_ms": 37.677440410316194,
19
- "num_iters": 50,
20
- "tokens_per_s": 2757.2440269917565,
21
- "throughput_variance": 89.13103199163609
22
- },
23
- "output_sum": 3.97190523147583
24
- }
megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "implementation": "gptoss_results",
3
- "config": {
4
- "warmup": 10,
5
- "iters": 50,
6
- "device": "cuda",
7
- "dtype": "torch.float32",
8
- "tokens": 100,
9
- "vary_inputs": true
10
- },
11
- "stats": {
12
- "avg_ms": 46.913985819956,
13
- "min_ms": 40.44806400088419,
14
- "max_ms": 51.07520399997156,
15
- "std_ms": 2.9921332618008196,
16
- "p50_ms": 47.418902999652346,
17
- "p95_ms": 50.800493049837314,
18
- "p99_ms": 50.948625239852845,
19
- "num_iters": 50,
20
- "tokens_per_s": 2131.560519794133,
21
- "throughput_variance": 139.93911554997217
22
- },
23
- "output_sum": 11.53223705291748
24
- }
megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "implementation": "gptoss_training_results",
3
- "config": {
4
- "warmup": 10,
5
- "iters": 50,
6
- "device": "cuda",
7
- "dtype": "torch.float32",
8
- "tokens": 100,
9
- "vary_inputs": true
10
- },
11
- "stats": {
12
- "avg_ms": 46.289439859992854,
13
- "min_ms": 39.97907499979192,
14
- "max_ms": 50.58144600025116,
15
- "std_ms": 2.9172154402078077,
16
- "p50_ms": 46.64785849990949,
17
- "p95_ms": 50.26727430031315,
18
- "p99_ms": 50.5162941305025,
19
- "num_iters": 50,
20
- "tokens_per_s": 2160.3199412751637,
21
- "throughput_variance": 139.86427060112865
22
- },
23
- "output_sum": 11.53223705291748
24
- }
megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json DELETED
@@ -1,24 +0,0 @@
1
- {
2
- "implementation": "yamoe_results",
3
- "config": {
4
- "warmup": 10,
5
- "iters": 50,
6
- "device": "cuda",
7
- "dtype": "torch.float32",
8
- "tokens": 100,
9
- "vary_inputs": true
10
- },
11
- "stats": {
12
- "avg_ms": 4.248197240067384,
13
- "min_ms": 4.136622000260104,
14
- "max_ms": 4.280714999367774,
15
- "std_ms": 0.02141682051311511,
16
- "p50_ms": 4.253484999935608,
17
- "p95_ms": 4.265540049709671,
18
- "p99_ms": 4.273649199667489,
19
- "num_iters": 50,
20
- "tokens_per_s": 23539.396677922097,
21
- "throughput_variance": 120.66648678204231
22
- },
23
- "output_sum": 3.97190523147583
24
- }
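The four *_results.json artifacts above share the same schema (a config block plus a stats block with avg_ms, p50_ms, tokens_per_s, and so on). A minimal sketch for comparing them — assuming the files sit in the current directory under the names shown; this is not part of the commit:

import json

files = {
    "yamoe": "yamoe_results.json",
    "binned": "binned_results.json",
    "gptoss": "gptoss_results.json",
    "gptoss_training": "gptoss_training_results.json",
}

# Read the "stats" block from each result file.
stats = {name: json.load(open(path))["stats"] for name, path in files.items()}
baseline = stats["yamoe"]["avg_ms"]

# Sort fastest-first and report average latency, throughput, and relative slowdown.
for name, s in sorted(stats.items(), key=lambda kv: kv[1]["avg_ms"]):
    print(f"{name:16s} avg {s['avg_ms']:7.3f} ms | "
          f"{s['tokens_per_s']:9.1f} tok/s | "
          f"{s['avg_ms'] / baseline:5.2f}x vs yamoe")

On the numbers recorded above, this puts the Yamoe run at roughly 8.5-11x lower average latency than the binned and GPT-OSS reference paths.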
megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc DELETED
Binary file (16.1 kB)
 
megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc DELETED
Binary file (680 Bytes)
 
megablocks_yamoe/cells/bench_utils.py DELETED
@@ -1,241 +0,0 @@
1
- # /// script
2
- # dependencies = [
3
- # "torch",
4
- # "numpy",
5
- # ]
6
- # ///
7
-
8
- """Reusable benchmarking utilities for performance testing."""
9
- import time
10
- import numpy as np
11
- from contextlib import contextmanager
12
- from typing import Callable, Dict, Tuple, Any, Optional
13
- import torch
14
-
15
- def to_dtype(dtype_str: str):
16
- """Convert string to torch dtype."""
17
- if dtype_str == "float16":
18
- return torch.float16
19
- if dtype_str == "bfloat16":
20
- return torch.bfloat16
21
- return torch.float32
22
-
23
- def _sync(device: str):
24
- """Synchronize device if CUDA."""
25
- if device == "cuda":
26
- torch.cuda.synchronize()
27
-
28
- def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
29
- """Compute comprehensive latency and throughput statistics."""
30
- lat_ms = np.array([t * 1000.0 for t in times_s])
31
- lat_ms_sorted = np.sort(lat_ms)
32
- n = len(lat_ms)
33
-
34
- stats = {
35
- "avg_ms": np.mean(lat_ms),
36
- "min_ms": np.min(lat_ms),
37
- "max_ms": np.max(lat_ms),
38
- "std_ms": np.std(lat_ms),
39
- "p50_ms": np.percentile(lat_ms, 50),
40
- "p95_ms": np.percentile(lat_ms, 95),
41
- "p99_ms": np.percentile(lat_ms, 99),
42
- "num_iters": n
43
- }
44
-
45
- if tokens is not None and n > 0:
46
- avg_s = np.mean(times_s)
47
- stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
48
- stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])
49
-
50
- return stats
51
-
52
- def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
53
- """Format timing statistics for display."""
54
- lines = [
55
- "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
56
- f"Iterations: {stats.get('num_iters', 0)}",
57
- "\nLatency Statistics:",
58
- f" Average: {stats['avg_ms']:.3f} ms",
59
- f" Min: {stats['min_ms']:.3f} ms",
60
- f" Max: {stats['max_ms']:.3f} ms",
61
- f" Std Dev: {stats['std_ms']:.3f} ms",
62
- "\nPercentiles:",
63
- f" P50 (median): {stats['p50_ms']:.3f} ms",
64
- f" P95: {stats['p95_ms']:.3f} ms",
65
- f" P99: {stats['p99_ms']:.3f} ms",
66
- ]
67
-
68
- if tokens is not None and 'tokens_per_s' in stats:
69
- lines.extend([
70
- "\nThroughput:",
71
- f" Tokens/sec: {stats['tokens_per_s']:.1f}",
72
- f" Std Dev: {stats.get('throughput_variance', 0):.1f}",
73
- ])
74
-
75
- lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
76
- return "\n".join(lines)
77
-
78
- def _bench_engine(
79
- call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] = None
80
- ) -> Tuple[Any, list]:
81
- """Core benchmarking engine with warmup and timing."""
82
- use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
83
-
84
- # Warmup phase
85
- print(f"\nWarming up ({warmup} iterations)...")
86
- with torch.inference_mode():
87
- for _ in range(max(0, warmup)):
88
- if use_autocast:
89
- with torch.autocast(device_type="cuda", dtype=dtype):
90
- if input_gen is not None:
91
- _ = call(input_gen())
92
- else:
93
- _ = call()
94
- else:
95
- if input_gen is not None:
96
- _ = call(input_gen())
97
- else:
98
- _ = call()
99
- _sync(device)
100
-
101
- # Measurement phase
102
- print(f"Benchmarking ({iters} iterations)...")
103
- times_s = []
104
- last = None
105
- with torch.inference_mode():
106
- for i in range(max(1, iters)):
107
- start = time.perf_counter()
108
- if use_autocast:
109
- with torch.autocast(device_type="cuda", dtype=dtype):
110
- if input_gen is not None:
111
- last = call(input_gen())
112
- else:
113
- last = call()
114
- else:
115
- if input_gen is not None:
116
- last = call(input_gen())
117
- else:
118
- last = call()
119
- _sync(device)
120
- end = time.perf_counter()
121
- times_s.append(end - start)
122
-
123
- # Progress indicator every 20% of iterations
124
- if i > 0 and i % max(1, iters // 5) == 0:
125
- pct = (i / iters) * 100
126
- avg_so_far = np.mean(times_s[:i]) * 1000
127
- print(f" Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")
128
-
129
- return last, times_s
130
-
131
- def tensor_stats(t: torch.Tensor) -> str:
132
- """Generate comprehensive stats string for a tensor."""
133
- return (f"shape={tuple(t.shape)}, "
134
- f"dtype={t.dtype}, "
135
- f"device={t.device}, "
136
- f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
137
- f"mean={t.mean().item():.6f}, "
138
- f"std={t.std().item():.6f}, "
139
- f"norm={t.norm().item():.6f}")
140
-
141
- @contextmanager
142
- def bench_context(
143
- *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True
144
- ):
145
- """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).
146
-
147
- If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration
148
- by adding a small deterministic increment to prevent caching artifacts.
149
- """
150
-
151
- def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
152
- # Log configuration
153
- if verbose:
154
- print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐")
155
- # print(f"│ Device: {device:<15} Dtype: {dtype} │")
156
- print(f"│ Warmup: {warmup:<15} Iters: {iters} │")
157
- if tokens:
158
- print(f"│ Tokens: {tokens} │")
159
- if vary_inputs:
160
- print(f"│ Input Variation: Enabled (prevents caching artifacts) │")
161
- print(f"└────────────────────────────────────────────────────────┘")
162
-
163
- # Set up input generation
164
- input_gen = None
165
- if vary_inputs and args and isinstance(args[0], torch.Tensor):
166
- base_input = args[0].clone()
167
- iteration_counter = [0] # Use list for mutable closure
168
-
169
- def generate_varied_input():
170
- """Generate input tensor varied by iteration to prevent caching."""
171
- # Add small deterministic increment: 0.001 * iteration_number
172
- varied_input = base_input + (iteration_counter[0] * 0.001)
173
- iteration_counter[0] += 1
174
- return varied_input
175
-
176
- input_gen = generate_varied_input
177
- call = lambda x: fn(x, *args[1:], **kwargs)
178
-
179
- # Log base input stats
180
- if verbose:
181
- print(f"\nBase Input: {tensor_stats(base_input)}")
182
- print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
183
- else:
184
- # Legacy mode - static inputs
185
- call = lambda: fn(*args, **kwargs)
186
- if verbose and args and isinstance(args[0], torch.Tensor):
187
- print(f"\nInput: {tensor_stats(args[0])}")
188
-
189
- result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)
190
-
191
- # Log output if it's a tensor or tuple with tensors
192
- if verbose:
193
- print("\nOutput tensors:")
194
- if isinstance(result, torch.Tensor):
195
- print(f" Primary: {tensor_stats(result)}")
196
- elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
197
- print(f" Primary: {tensor_stats(result[0])}")
198
- if len(result) > 1:
199
- if isinstance(result[1], torch.Tensor):
200
- print(f" Auxiliary: {tensor_stats(result[1])}")
201
- else:
202
- print(f" Auxiliary: {type(result[1]).__name__}")
203
-
204
- # Compute and display statistics
205
- stats = _compute_stats(times_s, tokens=tokens)
206
- if verbose:
207
- print(_format_timing_stats(stats, tokens))
208
-
209
- # Save to JSON if requested
210
- if save_json:
211
- import json
212
- json_data = {
213
- "implementation": save_json.replace(".json", ""),
214
- "config": {
215
- "warmup": warmup,
216
- "iters": iters,
217
- "device": str(device), # Convert device to string
218
- "dtype": str(dtype),
219
- "tokens": tokens,
220
- "vary_inputs": vary_inputs
221
- },
222
- "stats": stats,
223
- "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None
224
- }
225
- with open(save_json, 'w') as f:
226
- json.dump(json_data, f, indent=2)
227
- if verbose:
228
- print(f"\nSaved benchmark results to {save_json}")
229
-
230
- return result, stats
231
-
232
- yield runner
233
-
234
- def set_seed(seed: int):
235
- """Set seeds for reproducibility."""
236
- torch.manual_seed(seed)
237
- if torch.cuda.is_available():
238
- torch.cuda.manual_seed(seed)
239
- torch.cuda.manual_seed_all(seed)
240
- torch.backends.cudnn.deterministic = True
241
- torch.backends.cudnn.benchmark = False
megablocks_yamoe/cells/binned_run.py DELETED
@@ -1,195 +0,0 @@
1
- # /// script
2
- # dependencies = [
3
- # "torch",
4
- # "numpy",
5
- # ]
6
- # ///
7
-
8
- import torch
9
- from torch import nn
10
- from torch.nn import functional as F
11
- from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
12
- from config import (
13
- NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
14
- BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
15
- WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
16
- )
17
- from pathlib import Path
18
- import os
19
-
20
- # Discover the upstream artifact directory from env
21
- data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
22
-
23
- router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
24
- router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
25
- gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
26
- gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
27
- down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
28
- down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
29
-
30
- print("Loaded shared weights from artifacts")
31
- print(f"Router weight sum: {router_weight.sum().item():.6f}")
32
- print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
33
- print(f"Down sum: {down_proj.sum().item():.6f}")
34
-
35
- def binned_gather(x, indices, bins, expert_capacity, top_k):
36
- E, H = bins.shape[0], x.shape[1]
37
- out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
38
- for e in range(E):
39
- start = 0 if e == 0 else bins[e - 1]
40
- end = bins[e]
41
- n = min(end - start, expert_capacity)
42
- for i in range(n):
43
- flat_pos = indices[start + i]
44
- tok = flat_pos // top_k
45
- out[e, i] = x[tok]
46
- return out
47
-
48
- def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
-     E, C, H = x.shape
-     N = indices.shape[0] // top_k
-     out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
-     for e in range(E):
-         start = 0 if e == 0 else bins[e - 1]
-         end = bins[e]
-         n = end - start
-         if n == 0:
-             continue
-         take = min(n, expert_capacity)
-         for i in range(take):
-             flat_pos = indices[start + i]
-             tok = flat_pos // top_k
-             slot = flat_pos % top_k
-             scale = weights[flat_pos] if weights is not None else 1.0
-             out[tok, slot] = x[e, i] * scale
-     return out.sum(dim=1)
-
- def sort_tokens_by_expert(router_indices, num_experts):
-     flat_indices = router_indices.flatten()
-     sorted_values, sorted_indices = torch.sort(flat_indices)
-     tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
-     bins = torch.cumsum(tokens_per_expert, dim=0)
-     return sorted_indices, sorted_values, bins, tokens_per_expert
-
- def binned_experts_ref(
-     hidden_states,
-     router_indices,
-     routing_weights,
-     gate_up_proj,
-     gate_up_proj_bias,
-     down_proj,
-     down_proj_bias,
-     expert_capacity,
- ):
-     B, S, H = hidden_states.shape
-     E, K = routing_weights.shape[1], router_indices.shape[1]
-
-     indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
-     x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
-
-     gate_up = torch.bmm(x, gate_up_proj)
-     gate_up += gate_up_proj_bias[..., None, :]
-
-     gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-
-     # clamp to limit
-     limit = 7.0
-     gate = gate.clamp(min=None, max=limit)
-     up = up.clamp(min=-limit, max=limit)
-
-     glu = gate * torch.sigmoid(gate * 1.702)
-     x = (up + 1) * glu
-     x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
-
-     # build routing weights aligned to (token, slot)
-     flat_dense = routing_weights.view(-1, E)
-     flat_router = router_indices.view(-1, K)
-     selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
-
-     # scatter back
-     y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
-
-     return y.view(B, S, H)
-
- class BinnedRouter(nn.Module):
-     def __init__(self, router_weight, router_bias):
-         super().__init__()
-         self.top_k = TOP_K
-         self.num_experts = NUM_EXPERTS
-         self.hidden_dim = HIDDEN_SIZE
-         self.weight = nn.Parameter(router_weight.clone())
-         self.bias = nn.Parameter(router_bias.clone())
-
-     def forward(self, hidden_states):
-         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-         router_logits = F.linear(hidden_states, self.weight, self.bias)
-         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-         return router_scores, router_indices
-
- def ceil_div(a, b):
-     return (a + b - 1) // b
-
- class BinnedMoEMLP(nn.Module):
-     def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-         super().__init__()
-         self.router = BinnedRouter(router_weight, router_bias)
-         self.num_experts = NUM_EXPERTS
-         self.hidden_size = HIDDEN_SIZE
-         self.top_k = TOP_K
-
-         # Expert weights - use the loaded weights
-         self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-         self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-         self.down_proj = nn.Parameter(down_proj.clone())
-         self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-
-     def forward(self, hidden_states):
-         router_scores, router_indices = self.router(hidden_states)
-         batch_size = hidden_states.shape[0]
-         expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
-
-         output = binned_experts_ref(
-             hidden_states,
-             router_indices,
-             router_scores,
-             self.gate_up_proj,
-             self.gate_up_proj_bias,
-             self.down_proj,
-             self.down_proj_bias,
-             expert_capacity,
-         )
-
-         return output, router_scores
-
- # Run the model
- set_seed(GENERAL_SEED)
-
- device = torch.device(DEVICE)
- dtype = to_dtype(DTYPE)
-
- print("\n=== Binned Implementation ===")
- # Initialize model with loaded weights
- model = BinnedMoEMLP(
-     router_weight.to(device),
-     router_bias.to(device),
-     gate_up_proj.to(device),
-     gate_up_proj_bias.to(device),
-     down_proj.to(device),
-     down_proj_bias.to(device)
- ).to(device=device)
-
- print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
- print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
- print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
-
- # Generate the same input as Yamoe
- set_seed(INPUT_SEED)
- x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
- # Benchmark the model with varied inputs to prevent caching artifacts
- tokens = BATCH_SIZE * SEQ_LEN
- with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
-     output, stats = bench(model, x)
-     print(f"\nOutput sum: {output[0].sum().item():.6f}")
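Note: a minimal sketch (not part of the deleted cell, assuming the definitions above) of what sort_tokens_by_expert computes for a toy routing; bincount gives per-expert token counts and the running cumsum gives the bin boundaries that binned_gather/binned_scatter walk per expert.

    import torch

    # 2 tokens, top_k = 2, 3 experts: token 0 -> experts {0, 2}, token 1 -> experts {1, 0}
    router_indices = torch.tensor([[0, 2], [1, 0]])
    flat = router_indices.flatten()                                 # tensor([0, 2, 1, 0])
    sorted_values, sorted_positions = torch.sort(flat)              # values [0, 0, 1, 2]
    tokens_per_expert = torch.bincount(sorted_values, minlength=3)  # tensor([2, 1, 1])
    bins = torch.cumsum(tokens_per_expert, dim=0)                   # tensor([2, 3, 4])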
megablocks_yamoe/cells/config.py DELETED
@@ -1,27 +0,0 @@
- # /// script
- # dependencies = [
- #     "torch",
- #     "numpy",
- # ]
- # ///
-
- """Shared configuration for both implementations."""
- import torch
-
- # Model configuration
- NUM_EXPERTS = 128
- HIDDEN_SIZE = 1152
- INTERMEDIATE_SIZE = 3072
- TOP_K = 4
-
- # Input configuration
- BATCH_SIZE = 1
- SEQ_LEN = 100
- DTYPE = "float32"
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Seeds for reproducibility
- WEIGHT_SEED = 999
- EXPERT_SEED = 777
- INPUT_SEED = 123
- GENERAL_SEED = 42
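Note: a small, assumed usage sketch of how the benchmark cells consume this module; each cell imports the same constants and seeds, so every implementation sees an identical (1, 100, 1152) input.

    import torch
    from config import BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, DEVICE, INPUT_SEED

    torch.manual_seed(INPUT_SEED)
    x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=DEVICE) * 0.1  # shape (1, 100, 1152)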
megablocks_yamoe/cells/gptoss_run.py DELETED
@@ -1,147 +0,0 @@
- # /// script
- # dependencies = [
- #     "torch",
- #     "numpy",
- # ]
- # ///
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
- from config import (
-     NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-     BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-     WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
- )
- from pathlib import Path
- import os
-
- # Discover the upstream artifact directory from env
- data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
- router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
- router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
- gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
- gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
- down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
- down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
- print("Loaded shared weights from artifacts")
- print(f"Router weight sum: {router_weight.sum().item():.6f}")
- print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
- print(f"Down sum: {down_proj.sum().item():.6f}")
-
- class GptOssRouter(nn.Module):
-     def __init__(self, router_weight, router_bias):
-         super().__init__()
-         self.top_k = TOP_K
-         self.num_experts = NUM_EXPERTS
-         self.hidden_dim = HIDDEN_SIZE
-         self.weight = nn.Parameter(router_weight.clone())
-         self.bias = nn.Parameter(router_bias.clone())
-
-     def forward(self, hidden_states):
-         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-         router_logits = F.linear(hidden_states, self.weight, self.bias)
-         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-         return router_scores, router_indices
-
- class GptOssExperts(nn.Module):
-     def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-         super().__init__()
-         self.num_experts = NUM_EXPERTS
-         self.hidden_size = HIDDEN_SIZE
-         self.expert_dim = self.hidden_size
-         self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-         self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-         self.down_proj = nn.Parameter(down_proj.clone())
-         self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-         self.alpha = 1.702
-         self.limit = 7.0
-
-     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-         batch_size = hidden_states.shape[0]
-         hidden_states = hidden_states.reshape(-1, self.hidden_size)
-         num_experts = routing_weights.shape[1]
-
-         if hidden_states.device.type == "cpu" or self.training:
-             next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-             with torch.no_grad():
-                 expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-                 expert_mask = expert_mask.permute(2, 1, 0)
-                 expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-             for expert_idx in expert_hit[:]:
-                 expert_idx = expert_idx[0]
-                 with torch.no_grad():
-                     _, token_idx = torch.where(expert_mask[expert_idx])
-                 current_state = hidden_states[token_idx]
-                 gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-                 gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-                 gate = gate.clamp(min=None, max=self.limit)
-                 up = up.clamp(min=-self.limit, max=self.limit)
-                 glu = gate * torch.sigmoid(gate * self.alpha)
-                 gated_output = (up + 1) * glu
-                 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                 weighted_output = out * routing_weights[token_idx, expert_idx, None]
-                 next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-             next_states = next_states.view(batch_size, -1, self.hidden_size)
-         else:
-             hidden_states = hidden_states.repeat(num_experts, 1)
-             hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-             gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
-             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-             gate = gate.clamp(min=None, max=self.limit)
-             up = up.clamp(min=-self.limit, max=self.limit)
-             glu = gate * torch.sigmoid(gate * self.alpha)
-             next_states = torch.bmm(((up + 1) * glu), self.down_proj)
-             next_states = next_states + self.down_proj_bias[..., None, :]
-             next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-             next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
-             next_states = next_states.sum(dim=0)
-         return next_states
-
- class GptOssMoEMLP(nn.Module):
-     def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-         super().__init__()
-         self.router = GptOssRouter(router_weight, router_bias)
-         self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-     def forward(self, hidden_states):
-         router_scores, router_indices = self.router(hidden_states)
-         routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-         return routed_out, router_scores
-
- # Run the model
- set_seed(GENERAL_SEED)
-
- device = torch.device(DEVICE)
- dtype = to_dtype(DTYPE)
-
- print("\n=== GPT-OSS Implementation ===")
- # Initialize model with loaded weights
- model = GptOssMoEMLP(
-     router_weight.to(device),
-     router_bias.to(device),
-     gate_up_proj.to(device),
-     gate_up_proj_bias.to(device),
-     down_proj.to(device),
-     down_proj_bias.to(device)
- ).to(device=device)
-
- print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
- print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
- print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
-
- # Generate the same input as other implementations
- set_seed(INPUT_SEED)
- x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
- # Benchmark the model with varied inputs to prevent caching artifacts
- tokens = BATCH_SIZE * SEQ_LEN
- with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
-     output, stats = bench(model, x)
-     print(f"\nOutput sum: {output[0].sum().item():.6f}")
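Note: because this cell and the binned cell load the same weight artifacts and reseed the input identically, a hypothetical follow-up cell (not part of this commit; binned_model and gptoss_model are assumed handles to the two modules in one process) could diff the two forward passes directly:

    with torch.no_grad():
        y_binned, _ = binned_model(x)
        y_gptoss, _ = gptoss_model(x)
    # max absolute gap: how far the capacity-limited binned reference drifts from the dense path
    print((y_binned - y_gptoss).abs().max().item())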
megablocks_yamoe/cells/gptoss_training_run.py DELETED
@@ -1,138 +0,0 @@
- # /// script
- # dependencies = [
- #     "torch",
- #     "numpy",
- # ]
- # ///
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
- from config import (
-     NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-     BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-     WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
- )
- from pathlib import Path
- import os
-
- # Discover the upstream artifact directory from env
- data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
- router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
- router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
- gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
- gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
- down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
- down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
- print("Loaded shared weights from artifacts")
- print(f"Router weight sum: {router_weight.sum().item():.6f}")
- print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
- print(f"Down sum: {down_proj.sum().item():.6f}")
-
- class GptOssTrainingRouter(nn.Module):
-     def __init__(self, router_weight, router_bias):
-         super().__init__()
-         self.top_k = TOP_K
-         self.num_experts = NUM_EXPERTS
-         self.hidden_dim = HIDDEN_SIZE
-         self.weight = nn.Parameter(router_weight.clone())
-         self.bias = nn.Parameter(router_bias.clone())
-
-     def forward(self, hidden_states):
-         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-         router_logits = F.linear(hidden_states, self.weight, self.bias)
-         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-         router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-         return router_scores, router_indices
-
- class GptOssTrainingExperts(nn.Module):
-     def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-         super().__init__()
-         self.num_experts = NUM_EXPERTS
-         self.hidden_size = HIDDEN_SIZE
-         self.expert_dim = self.hidden_size
-         self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-         self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-         self.down_proj = nn.Parameter(down_proj.clone())
-         self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-         self.alpha = 1.702
-         self.limit = 7.0
-
-     def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-         batch_size = hidden_states.shape[0]
-         hidden_states = hidden_states.reshape(-1, self.hidden_size)
-         num_experts = routing_weights.shape[1]
-
-         # Force training mode path (expert loop instead of batched)
-         next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-         with torch.no_grad():
-             expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-             expert_mask = expert_mask.permute(2, 1, 0)
-             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-         for expert_idx in expert_hit[:]:
-             expert_idx = expert_idx[0]
-             with torch.no_grad():
-                 _, token_idx = torch.where(expert_mask[expert_idx])
-             current_state = hidden_states[token_idx]
-             gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-             gate = gate.clamp(min=None, max=self.limit)
-             up = up.clamp(min=-self.limit, max=self.limit)
-             glu = gate * torch.sigmoid(gate * self.alpha)
-             gated_output = (up + 1) * glu
-             out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-             weighted_output = out * routing_weights[token_idx, expert_idx, None]
-             next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-         next_states = next_states.view(batch_size, -1, self.hidden_size)
-         return next_states
-
- class GptOssTrainingMoEMLP(nn.Module):
-     def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-         super().__init__()
-         self.router = GptOssTrainingRouter(router_weight, router_bias)
-         self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-     def forward(self, hidden_states):
-         router_scores, router_indices = self.router(hidden_states)
-         routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-         return routed_out, router_scores
-
- # Run the model
- set_seed(GENERAL_SEED)
-
- device = torch.device(DEVICE)
- dtype = to_dtype(DTYPE)
-
- print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
- # Initialize model with loaded weights and force training mode
- model = GptOssTrainingMoEMLP(
-     router_weight.to(device),
-     router_bias.to(device),
-     gate_up_proj.to(device),
-     gate_up_proj_bias.to(device),
-     down_proj.to(device),
-     down_proj_bias.to(device)
- ).to(device=device)
-
- # Set to training mode to force expert loop path
- model.train()
-
- print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
- print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
- print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
- print(f"Model training mode: {model.training}")
-
- # Generate the same input as other implementations
- set_seed(INPUT_SEED)
- x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
- # Benchmark the model with varied inputs to prevent caching artifacts
- tokens = BATCH_SIZE * SEQ_LEN
- with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
-     output, stats = bench(model, x)
-     print(f"\nOutput sum: {output[0].sum().item():.6f}")
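Note: the only functional change versus the previous cell is that the expert loop runs unconditionally; in GptOssExperts above the same choice hinges on device and module mode (hidden_states.device.type == "cpu" or self.training), so on GPU the path can also be toggled like this:

    model.eval()   # GPU + eval  -> batched bmm over all experts
    model.train()  # training    -> per-expert gather/scatter loop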
megablocks_yamoe/cells/megablocks_run.py DELETED
@@ -1,103 +0,0 @@
- # /// script
- # dependencies = [
- #     "torch",
- #     "numpy",
- #     "kernels",
- # ]
- # ///
-
- import torch
- from torch import nn
- from torch.nn import functional as F
- from kernels import get_kernel, get_local_kernel
- from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
- from config import (
-     NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-     BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-     WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
- )
- from pathlib import Path
- from collections import namedtuple
- import os
-
- # Discover the upstream artifact directory from env
- data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
- print(f"Loading weights from: {data_dir}")
-
- router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
- router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
- gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
- gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
- down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
- down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
- print("Loaded shared weights from artifacts")
- print(f"Router weight sum: {router_weight.sum().item():.6f}")
- print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
- print(f"Down sum: {down_proj.sum().item():.6f}")
-
- def build_megablocks_model(device: torch.device):
-     # Download optimized kernels from the Hugging Face hub
-     megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
-     model = megablocks.layers.MegaBlocksMoeMLP()
-
-     # Create attribute container for expert weights
-     model.experts = namedtuple(
-         "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
-     )
-
-     # Use loaded router weights for consistency
-     model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
-     with torch.no_grad():
-         model.router.weight.copy_(router_weight)
-         model.router.bias.copy_(router_bias)
-
-     # Attach loaded expert weights to the experts container
-     e = model.experts
-     e.alpha = 1.702
-     e.capacity_factor = 32
-     e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
-     e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
-     e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
-     e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
-     e.hidden_size = HIDDEN_SIZE
-
-     # Log weight statistics for comparison
-     print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
-     print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
-     print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
-
-     return model
-
- # Create a wrapper to match the interface of other implementations
- class MegaBlocksMoEWrapper(nn.Module):
-     def __init__(self, megablocks_model):
-         super().__init__()
-         self.model = megablocks_model
-
-     def forward(self, hidden_states):
-         # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
-         output, dummy_routing_weights = self.model(hidden_states)
-         return output, dummy_routing_weights
-
- # Run the model
- set_seed(GENERAL_SEED)
-
- device = torch.device(DEVICE)
- dtype = to_dtype(DTYPE)
-
- print("\n=== MegaBlocks Implementation ===")
- # Build MegaBlocks model with loaded weights
- megablocks_model = build_megablocks_model(device)
- model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
-
- # Generate the same input as other implementations
- set_seed(INPUT_SEED)
- x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
- # Benchmark the model with varied inputs to prevent caching artifacts
- tokens = BATCH_SIZE * SEQ_LEN
- with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
-     output, stats = bench(model, x)
-     print(f"\nOutput sum: {output[0].sum().item():.6f}")
megablocks_yamoe/cells/nv.py DELETED
@@ -1,3 +0,0 @@
- import subprocess
-
- print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
megablocks_yamoe/cells/save_data.py DELETED
@@ -1,42 +0,0 @@
- # /// script
- # dependencies = [
- #     "torch",
- #     "numpy",
- # ]
- # ///
-
- """
- Generate deterministic shared weights once and save as artifacts so
- both implementations load identical parameters.
- """
- import torch
- from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
-
- def save_shared_weights():
-     # Router: Kaiming uniform as used by both, bias zeros
-     torch.manual_seed(WEIGHT_SEED)
-     router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
-     torch.nn.init.kaiming_uniform_(router_weight)
-     router_bias = torch.zeros(NUM_EXPERTS)
-
-     # Experts: normal(0, 0.02), biases zeros
-     torch.manual_seed(EXPERT_SEED)
-     gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-     gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
-     down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-     down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)
-
-     # Save artifacts
-     torch.save(router_weight, 'router_weight.pt')
-     torch.save(router_bias, 'router_bias.pt')
-     torch.save(gate_up_proj, 'gate_up_proj.pt')
-     torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
-     torch.save(down_proj, 'down_proj.pt')
-     torch.save(down_proj_bias, 'down_proj_bias.pt')
-
-     print("Saved shared weights to artifacts")
-     print(f"Router weight sum: {router_weight.sum().item():.6f}")
-     print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-     print(f"Down sum: {down_proj.sum().item():.6f}")
-
- save_shared_weights()
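Note: the run cells above locate these .pt artifacts through the UVNOTE_INPUT_SAVE_DATA environment variable; a minimal loader sketch mirroring them:

    import os
    from pathlib import Path
    import torch

    # Falls back to the current directory when the env var is unset, as in the run cells.
    data_dir = Path(os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.'))
    router_weight = torch.load(data_dir / 'router_weight.pt')
    gate_up_proj = torch.load(data_dir / 'gate_up_proj.pt')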