drbh committed
Commit · 30c62e2
Parent(s): 08478da

fix: remove debug build
This view is limited to 50 files because it contains too many changes. See the raw diff.
- flash_attn/artifacts/benchmark/Attention Benchmark.csv +0 -7
- flash_attn/artifacts/benchmark/Attention Benchmark.png +0 -3
- flash_attn/artifacts/benchmark/results.html +0 -3
- flash_attn/benchmark.html +0 -0
- flash_attn/cells/benchmark.py +0 -343
- flash_attn/cells/nv.py +0 -3
- flash_attn/impls/artifacts/benchmark/attn.jsonl +0 -6
- flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl +0 -6
- flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl +0 -6
- flash_attn/impls/cells/benchmark.py +0 -71
- flash_attn/impls/cells/benchmark_default.py +0 -70
- flash_attn/impls/cells/benchmark_max_autotune.py +0 -70
- flash_attn/impls/cells/nv.py +0 -3
- flash_attn/impls/compiled_variants.html +0 -0
- flash_attn/impls/flash_attention.html +0 -0
- flash_attn/impls/hf_kernels_flash_attn.html +0 -0
- flash_attn/impls/hf_kernels_flash_attn3.html +0 -0
- flash_attn/impls/index.html +0 -94
- flash_attn/impls/mem_efficient_attention.html +0 -0
- flash_attn/impls/sage_attention.html +0 -0
- flash_attn/impls/xformers.html +0 -0
- flash_attn/index.html +0 -89
- flash_attn/results/artifacts/combine/latency.csv +0 -43
- flash_attn/results/artifacts/combine/latency.png +0 -3
- flash_attn/results/artifacts/combine/latency.svg +0 -3
- flash_attn/results/cells/combine.py +0 -319
- flash_attn/results/combined_results.html +0 -0
- flash_attn/results/index.html +0 -88
- index.html +0 -85
- megablocks/cells/forward_and_backward.py +0 -196
- megablocks/cells/forward_and_backward_no_kernel.py +0 -196
- megablocks/cells/forward_only.py +0 -101
- megablocks/cells/no_kernels.py +0 -98
- megablocks/cells/nv.py +0 -3
- megablocks/index.html +0 -24
- megablocks/megablocks_only.html +0 -0
- megablocks_yamoe/artifacts/binned_run/binned_results.json +0 -24
- megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json +0 -24
- megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json +0 -24
- megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json +0 -24
- megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc +0 -0
- megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc +0 -0
- megablocks_yamoe/cells/bench_utils.py +0 -241
- megablocks_yamoe/cells/binned_run.py +0 -195
- megablocks_yamoe/cells/config.py +0 -27
- megablocks_yamoe/cells/gptoss_run.py +0 -147
- megablocks_yamoe/cells/gptoss_training_run.py +0 -138
- megablocks_yamoe/cells/megablocks_run.py +0 -103
- megablocks_yamoe/cells/nv.py +0 -3
- megablocks_yamoe/cells/save_data.py +0 -42
flash_attn/artifacts/benchmark/Attention Benchmark.csv
DELETED
@@ -1,7 +0,0 @@
seq_len,torch_cudnn,torch_cudnn_compile_d,torch_cudnn_compile_ma,torch_flash,torch_flash_compile_d,torch_flash_compile_ma,hf_flash_attn,hf_flash_attn3
4224.000000,3.801472,3.790064,4.182320,3.968000,3.957824,4.311152,3.398160,3.330400
4352.000000,4.082944,4.082912,4.413488,4.400000,4.391936,4.738048,3.837424,3.758208
4416.000000,4.142624,4.135648,4.484160,4.452304,4.446096,4.792480,3.892064,3.864128
4480.000000,4.206144,4.198752,4.551808,4.530752,4.522944,4.873760,3.949344,3.870224
4544.000000,4.438320,4.433104,4.787584,4.584160,4.576640,4.934304,4.008960,3.974672
4608.000000,4.502432,4.495456,4.871872,4.660192,4.651040,5.029792,4.065616,3.984160
flash_attn/artifacts/benchmark/Attention Benchmark.png
DELETED
Git LFS Details
flash_attn/artifacts/benchmark/results.html
DELETED
@@ -1,3 +0,0 @@
<html><body>
<image src="Attention Benchmark.png"/>
</body></html>
flash_attn/benchmark.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/cells/benchmark.py
DELETED
@@ -1,343 +0,0 @@
# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
#   "pandas",
#   "matplotlib"
# ]
# ///
# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths

import functools
import os
import pathlib

import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
import triton
import triton.language as tl

try:
    from flash_attn import flash_attn_func
except:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel
    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
except:
    hf_kernels_flash_attn = None
    hf_kernels_flash_attn_3 = None
    print("HF Kernels not found.")

try:
    from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
except:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except:
    xops = None
    print("xFormers not found.")


plt.rcParams.update({
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
})


# We want to compare the best compiled version for each specific shape (dynamic=False)
torch._dynamo.config.cache_size_limit = 10000

# We need to suppress_errors for FA3 to work. It makes it run in eager mode.
# I can't seem to get it to work any other way under torch.compile, so any suggestions are welcome!
torch._dynamo.config.suppress_errors = True

# output_dir = pathlib.Path("dump_attention_benchmark")
# output_dir.mkdir(parents=True, exist_ok=True)

output_dir = pathlib.Path(".")  # output to current directory for upload

batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]


def _attention_torch(query, key, value, *, backend):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out


_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
def _attention_torch_compile_default(query, key, value, *, backend):
    return _compiled_attention_torch_default(query, key, value, backend=backend)


_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)


def _attention_flash_attn_2(query, key, value):
    return flash_attn_func(query, key, value)


_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_default(query, key, value):
    return _compiled_flash_attn_2_default(query, key, value)


_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_2_max_autotune(query, key, value)


# For fullgraph=True tracing to be compatible
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, lse = flash_attn_3_func(query, key, value)
    return out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)


def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out


_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_default(query, key, value):
    return _compiled_flash_attn_3_default(query, key, value)


_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_3_max_autotune(query, key, value)


def _attention_hf_kernels_flash_attn(query, key, value):
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")


if DotProductAttention is not None:
    def set_te_backend(backend):
        # must be applied before first use of
        # transformer_engine.pytorch.attention
        os.environ["NVTE_FLASH_ATTN"] = '0'
        os.environ["NVTE_FUSED_ATTN"] = '0'
        os.environ["NVTE_UNFUSED_ATTN"] = '0'
        if backend == 'flash':
            os.environ["NVTE_FLASH_ATTN"] = '1'
        if backend == 'fused':
            os.environ["NVTE_FUSED_ATTN"] = '1'
        if backend == 'unfused':
            os.environ["NVTE_UNFUSED_ATTN"] = '1'

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")

def _attention_te(query, key, value):
    out = te_attn_fn(query, key, value)
    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
    return out


# Cannot fullgraph compile TE
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
def _attention_te_compile_default(query, key, value):
    return _compiled_te_attn_fn_default(query, key, value)


# Cannot fullgraph compile TE
_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
def _attention_te_compile_max_autotune(query, key, value):
    return _compiled_te_attn_fn_max_autotune(query, key, value)


def _attention_xformers(query, key, value):
    return xops.memory_efficient_attention(query, key, value)


_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
def _attention_xformers_compile_default(query, key, value):
    return _compiled_xformers_default(query, key, value)


_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_xformers_compile_max_autotune(query, key, value):
    return _compiled_xformers_max_autotune(query, key, value)


attention_ops = {}
attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
if hf_kernels_flash_attn is not None:
    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
if flash_attn_func is not None:
    attention_ops["flash_attn_2"] = _attention_flash_attn_2
    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
if flash_attn_3_func is not None:
    attention_ops["flash_attn_3"] = _attention_flash_attn_3
    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
if DotProductAttention is not None:
    attention_ops["te_fused"] = _attention_te
    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
if xops is not None:
    attention_ops["xformers"] = _attention_xformers
    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune


def get_color_and_linestyle(n: int) -> tuple[str, str]:
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = colors[i % len(colors)]
        linestyle = line_styles[i // len(colors)]
        styles.append((color, linestyle))
    return styles


def correctness():
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")

        query = torch.randn(shape, device="cuda", dtype=torch.float32)
        key = torch.randn(shape, device="cuda", dtype=torch.float32)
        value = torch.randn(shape, device="cuda", dtype=torch.float32)

        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
        query, key, value = (x.bfloat16() for x in (query, key, value))

        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = (out - golden_truth).abs()
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean((golden_truth - out) ** 2)
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=[x.removeprefix("solution_") for x in attention_ops.keys()],
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    torch.manual_seed(0)

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)

    fn = attention_ops[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, max_ms, min_ms


with torch.inference_mode():
    correctness()
    fig = benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
flash_attn/cells/nv.py
DELETED
@@ -1,3 +0,0 @@
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
flash_attn/impls/artifacts/benchmark/attn.jsonl
DELETED
@@ -1,6 +0,0 @@
{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.3603839874267578, "p50": 0.361952006816864, "p90": 0.3624640107154846, "mean": 0.3619711995124817, "reps": 5, "warmup": 2}, "compile_ms": 1.5701119899749756, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.3892799913883209, "p50": 0.3909760117530823, "p90": 0.3922559916973114, "mean": 0.3912447988986969, "reps": 5, "warmup": 2}, "compile_ms": 0.35811200737953186, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5240640044212341, "p50": 0.5248960256576538, "p90": 0.5248960256576538, "mean": 0.5258048176765442, "reps": 5, "warmup": 2}, "compile_ms": 0.4891839921474457, "peak_bytes": 99680256, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.905726432800293e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5265600085258484, "p50": 0.5277760028839111, "p90": 0.5282559990882874, "mean": 0.5276032090187073, "reps": 5, "warmup": 2}, "compile_ms": 0.4968000054359436, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003604888916015625, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5639039874076843, "p50": 0.5657920241355896, "p90": 0.5668479800224304, "mean": 0.5656383991241455, "reps": 5, "warmup": 2}, "compile_ms": 0.5312319993972778, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003566741943359375, "mse": 2.86102294921875e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5689600110054016, "p50": 0.5698239803314209, "p90": 0.5713919997215271, "mean": 0.5789952039718628, "reps": 5, "warmup": 2}, "compile_ms": 0.5350080132484436, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8759241104125977e-06, "ref": "sdpa_math_fp32"}, "err": null}
flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl
DELETED
@@ -1,6 +0,0 @@
{"ts": "2025-10-02T19:58:18Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5141760110855103, "p50": 0.5175679922103882, "p90": 0.5197759866714478, "mean": 0.5181439876556396, "reps": 5, "warmup": 2}, "compile_ms": 3084.621826171875, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000339508056640625, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5549119710922241, "p50": 0.5582720041275024, "p90": 0.5598080158233643, "mean": 0.5579584002494812, "reps": 5, "warmup": 2}, "compile_ms": 270.21795654296875, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6853119730949402, "p50": 0.687391996383667, "p90": 0.6883519887924194, "mean": 0.6872959971427918, "reps": 5, "warmup": 2}, "compile_ms": 269.78741455078125, "peak_bytes": 99876864, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7128639817237854, "p50": 0.7160959839820862, "p90": 0.7167680263519287, "mean": 0.716153597831726, "reps": 5, "warmup": 2}, "compile_ms": 269.8607177734375, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7386879920959473, "p50": 0.7400959730148315, "p90": 0.7415040135383606, "mean": 0.7418303966522217, "reps": 5, "warmup": 2}, "compile_ms": 269.20501708984375, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:58:20Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7708160281181335, "p50": 0.7740799784660339, "p90": 0.7753919959068298, "mean": 0.7745471954345703, "reps": 5, "warmup": 2}, "compile_ms": 270.93829345703125, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl
DELETED
@@ -1,6 +0,0 @@
{"ts": "2025-10-02T19:57:25Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6144000291824341, "p50": 0.6245759725570679, "p90": 0.6483200192451477, "mean": 0.6468096017837525, "reps": 5, "warmup": 2}, "compile_ms": 4407.3388671875, "peak_bytes": 70779904, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000339508056640625, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:57:27Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6689280271530151, "p50": 0.6851199865341187, "p90": 0.7184960246086121, "mean": 0.7060160160064697, "reps": 5, "warmup": 2}, "compile_ms": 1686.2735595703125, "peak_bytes": 78644224, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:57:29Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7953600287437439, "p50": 0.8155840039253235, "p90": 0.8403519988059998, "mean": 0.8332608103752136, "reps": 5, "warmup": 2}, "compile_ms": 1462.938232421875, "peak_bytes": 84280320, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:57:31Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.8470720052719116, "p50": 0.849727988243103, "p90": 0.8745279908180237, "mean": 0.8719295978546142, "reps": 5, "warmup": 2}, "compile_ms": 1689.3455810546875, "peak_bytes": 86508544, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:57:33Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.8677120208740234, "p50": 0.8835520148277283, "p90": 0.9034240245819092, "mean": 0.9034304022789001, "reps": 5, "warmup": 2}, "compile_ms": 1693.035888671875, "peak_bytes": 90440704, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null}
{"ts": "2025-10-02T19:57:34Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.9154239892959595, "p50": 0.9213759899139404, "p90": 0.9359679818153381, "mean": 0.9387519836425782, "reps": 5, "warmup": 2}, "compile_ms": 1689.36279296875, "peak_bytes": 94372864, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null}
flash_attn/impls/cells/benchmark.py
DELETED
@@ -1,71 +0,0 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels-benchmark-tools",
#   "kernels",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt
from kernels import get_kernel

hf_kernels_flash_attn3 = get_kernel("kernels-community/flash-attn3")


def hf_flash_attention3(query, key, value):
    return hf_kernels_flash_attn3.flash_attn_func(query, key, value, causal=False)[0]


kbt.add(
    "hf_kernels_flash_attn3",
    hf_flash_attention3,
    tags={"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if device == "cpu":
        print("HF Kernels Flash Attention 3 requires CUDA - skipping benchmark")
        sys.exit(0)

    dtype = "bfloat16"

    # Flux-like workloads
    base = 1024
    flux_sizes = [128, 256, 320, 384, 448, 512]
    heads = 24
    head_dim = 128

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn.jsonl"])
flash_attn/impls/cells/benchmark_default.py
DELETED
@@ -1,70 +0,0 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()


# Compile with default mode
compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False)

kbt.add(
    "torch_flash_compiled_default",
    compiled_flash_default,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_default.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn_default.jsonl"])
flash_attn/impls/cells/benchmark_max_autotune.py
DELETED
@@ -1,70 +0,0 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import torch
import sys
import os
import kernels_benchmark_tools as kbt


def torch_flash_base(q, k, v):
    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
    return o.transpose(1, 2).contiguous()


# Compile with max-autotune mode
compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False)

kbt.add(
    "torch_flash_compiled_max_autotune",
    compiled_flash_max_autotune,
    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
)

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = "float32" if device == "cpu" else "bfloat16"

    # Flux-like workloads
    base = 1024 if device == "cuda" else 512
    flux_sizes = (
        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
    )
    heads = 24 if device == "cuda" else 8
    head_dim = 128 if device == "cuda" else 64

    wl = []
    for L in flux_sizes:
        wl.append(
            {
                "name": f"flux_L{L}",
                "batch": 1,
                "seq_len": base + L,
                "heads": heads,
                "head_dim": head_dim,
                "dtype": dtype,
                "device": device,
                "seed": 0,
            }
        )

    kbt.run(
        wl,
        jsonl="attn_max_autotune.jsonl",
        reps=5,
        warmup=2,
        gen=kbt.attn.gen_qkv,
        ref=kbt.attn.ref_math,
        cmp=kbt.attn.cmp_allclose,
    )
    kbt.summarize(["attn_max_autotune.jsonl"])
flash_attn/impls/cells/nv.py
DELETED
@@ -1,3 +0,0 @@
import subprocess

print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
flash_attn/impls/compiled_variants.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/impls/flash_attention.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/impls/hf_kernels_flash_attn.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/impls/hf_kernels_flash_attn3.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/impls/index.html
DELETED
@@ -1,94 +0,0 @@
<!DOCTYPE html>
<html>
<head>
    <meta charset='UTF-8'>
    <meta name='viewport' content='width=device-width, initial-scale=1.0'>
    <title>Index of /flash_attn/impls</title>
    <style>
        :root {
            --bg-primary: #0a0a0a;
            --bg-secondary: #121212;
            --bg-tertiary: #181818;
            --text-primary: #e0e0e0;
            --text-secondary: #888888;
            --text-link: #64b5f6;
            --border-primary: #2a2a2a;
        }
        body {
            font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
            background: var(--bg-primary);
            color: var(--text-primary);
            margin: 0;
            padding: 16px;
            max-width: 900px;
            margin: 0 auto;
        }
        .controls {
            display: flex;
            justify-content: flex-end;
            margin-bottom: 1rem;
        }
        .back-button {
            background: var(--bg-secondary);
            border: 1px solid var(--border-primary);
            padding: 8px 12px;
            border-radius: 4px;
            color: var(--text-secondary);
            cursor: pointer;
            font-size: 0.9rem;
            text-decoration: none;
            display: inline-block;
        }
        .back-button:hover {
            color: var(--text-primary);
            background: var(--bg-tertiary);
        }
        h1 {
            font-size: 1.5em;
            margin: 1rem 0;
            color: var(--text-primary);
            border-bottom: 1px solid var(--border-primary);
            padding-bottom: 0.5rem;
        }
        ul {
            list-style-type: none;
            padding: 0;
        }
        li {
            margin: 0;
            border-bottom: 1px solid var(--border-primary);
        }
        li:last-child {
            border-bottom: none;
        }
        a {
            display: block;
            padding: 0.75rem 0.5rem;
            text-decoration: none;
            color: var(--text-link);
            transition: background 0.2s ease;
        }
        a:hover {
            background: var(--bg-secondary);
        }
        .dir {
            font-weight: 500;
        }
    </style>
</head>
<body>
    <div class='controls'>
        <a href='../index.html' class='back-button'>← back</a>
    </div>
    <h1>Index of /flash_attn/impls</h1>
    <ul>
        <li><a href='compiled_variants.html' class='file'>compiled_variants.html</a></li>
        <li><a href='flash_attention.html' class='file'>flash_attention.html</a></li>
        <li><a href='hf_kernels_flash_attn.html' class='file'>hf_kernels_flash_attn.html</a></li>
        <li><a href='hf_kernels_flash_attn3.html' class='file'>hf_kernels_flash_attn3.html</a></li>
        <li><a href='mem_efficient_attention.html' class='file'>mem_efficient_attention.html</a></li>
        <li><a href='sage_attention.html' class='file'>sage_attention.html</a></li>
        <li><a href='xformers.html' class='file'>xformers.html</a></li>
    </ul>
</body>
</html>
flash_attn/impls/mem_efficient_attention.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/impls/sage_attention.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/impls/xformers.html
DELETED
The diff for this file is too large to render. See the raw diff.
flash_attn/index.html
DELETED
@@ -1,89 +0,0 @@
<!DOCTYPE html>
<html>
<head>
    <meta charset='UTF-8'>
    <meta name='viewport' content='width=device-width, initial-scale=1.0'>
    <title>Index of /flash_attn</title>
    <style>
        :root {
            --bg-primary: #0a0a0a;
            --bg-secondary: #121212;
            --bg-tertiary: #181818;
            --text-primary: #e0e0e0;
            --text-secondary: #888888;
            --text-link: #64b5f6;
            --border-primary: #2a2a2a;
        }
        body {
            font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
            background: var(--bg-primary);
            color: var(--text-primary);
            margin: 0;
            padding: 16px;
            max-width: 900px;
            margin: 0 auto;
        }
        .controls {
            display: flex;
            justify-content: flex-end;
            margin-bottom: 1rem;
        }
        .back-button {
            background: var(--bg-secondary);
            border: 1px solid var(--border-primary);
            padding: 8px 12px;
            border-radius: 4px;
            color: var(--text-secondary);
            cursor: pointer;
            font-size: 0.9rem;
            text-decoration: none;
            display: inline-block;
        }
        .back-button:hover {
            color: var(--text-primary);
            background: var(--bg-tertiary);
        }
        h1 {
            font-size: 1.5em;
            margin: 1rem 0;
            color: var(--text-primary);
            border-bottom: 1px solid var(--border-primary);
            padding-bottom: 0.5rem;
        }
        ul {
            list-style-type: none;
            padding: 0;
        }
        li {
            margin: 0;
            border-bottom: 1px solid var(--border-primary);
        }
        li:last-child {
            border-bottom: none;
        }
        a {
            display: block;
            padding: 0.75rem 0.5rem;
            text-decoration: none;
            color: var(--text-link);
            transition: background 0.2s ease;
        }
        a:hover {
            background: var(--bg-secondary);
        }
        .dir {
            font-weight: 500;
        }
    </style>
</head>
<body>
    <div class='controls'>
        <a href='../index.html' class='back-button'>← back</a>
    </div>
    <h1>Index of /flash_attn</h1>
    <ul>
        <li><a href='impls/index.html' class='dir'>impls/</a></li>
        <li><a href='results/index.html' class='dir'>results/</a></li>
    </ul>
</body>
</html>
flash_attn/results/artifacts/combine/latency.csv
DELETED
@@ -1,43 +0,0 @@
Implementation,Impl ID,Workload,Batch,Seq Length,Heads,Head Dim,Dtype,Mean (ms),P10 (ms),P50 (ms),P90 (ms),Reps,Peak Mem (MB),Backend,Family
Flash (PyTorch SDPA),torch_flash_ma,flux_L128,1,1152,24,128,bfloat16,0.49411200881004336,0.48844799399375916,0.4936000108718872,0.4944640100002289,5,83.38,FLASH,torch-sdpa
Flash (PyTorch SDPA),torch_flash_ma,flux_L256,1,1280,24,128,bfloat16,0.5234112024307251,0.5224320292472839,0.5235199928283691,0.5235840082168579,5,90.62,FLASH,torch-sdpa
Flash (PyTorch SDPA),torch_flash_ma,flux_L320,1,1344,24,128,bfloat16,0.6527232170104981,0.6503040194511414,0.6524800062179565,0.6545600295066833,5,95.06,FLASH,torch-sdpa
|
| 5 |
-
Flash (PyTorch SDPA),torch_flash_ma,flux_L384,1,1408,24,128,bfloat16,0.682803213596344,0.6805760264396667,0.6828799843788147,0.6832640171051025,5,99.88,FLASH,torch-sdpa
|
| 6 |
-
Flash (PyTorch SDPA),torch_flash_ma,flux_L448,1,1472,24,128,bfloat16,0.7075456142425537,0.7057600021362305,0.7063360214233398,0.7070720195770264,5,103.81,FLASH,torch-sdpa
|
| 7 |
-
Flash (PyTorch SDPA),torch_flash_ma,flux_L512,1,1536,24,128,bfloat16,0.7379711985588073,0.7368639707565308,0.7372480034828186,0.7391039729118347,5,109.12,FLASH,torch-sdpa
|
| 8 |
-
MemEff (PyTorch SDPA),torch_mem_eff,flux_L128,1,1152,24,128,bfloat16,0.5874239921569824,0.5861759781837463,0.5873280167579651,0.5877439975738525,5,83.38,EFFICIENT,torch-sdpa
|
| 9 |
-
MemEff (PyTorch SDPA),torch_mem_eff,flux_L256,1,1280,24,128,bfloat16,0.6502719998359681,0.6490240097045898,0.649183988571167,0.6517760157585144,5,90.62,EFFICIENT,torch-sdpa
|
| 10 |
-
MemEff (PyTorch SDPA),torch_mem_eff,flux_L320,1,1344,24,128,bfloat16,0.7812095880508423,0.7761600017547607,0.7803199887275696,0.7852799892425537,5,95.94,EFFICIENT,torch-sdpa
|
| 11 |
-
MemEff (PyTorch SDPA),torch_mem_eff,flux_L384,1,1408,24,128,bfloat16,0.7948480010032654,0.7911999821662903,0.7935360074043274,0.7948480248451233,5,100.0,EFFICIENT,torch-sdpa
|
| 12 |
-
MemEff (PyTorch SDPA),torch_mem_eff,flux_L448,1,1472,24,128,bfloat16,0.8463295936584473,0.8449919819831848,0.8459839820861816,0.8461120128631592,5,103.81,EFFICIENT,torch-sdpa
|
| 13 |
-
MemEff (PyTorch SDPA),torch_mem_eff,flux_L512,1,1536,24,128,bfloat16,0.9538687944412232,0.9492800235748291,0.9518399834632874,0.9581760168075562,5,109.12,EFFICIENT,torch-sdpa
|
| 14 |
-
xFormers,xformers_meff,flux_L128,1,1152,24,128,bfloat16,0.4515071928501129,0.44364801049232483,0.4524799883365631,0.4557119905948639,5,83.38,memory_efficient,xformers
|
| 15 |
-
xFormers,xformers_meff,flux_L256,1,1280,24,128,bfloat16,0.46787199974060056,0.46489599347114563,0.4684160053730011,0.46908798813819885,5,90.62,memory_efficient,xformers
|
| 16 |
-
xFormers,xformers_meff,flux_L320,1,1344,24,128,bfloat16,0.6001471996307373,0.596992015838623,0.5984640121459961,0.6016640067100525,5,95.06,memory_efficient,xformers
|
| 17 |
-
xFormers,xformers_meff,flux_L384,1,1408,24,128,bfloat16,0.6023231983184815,0.5997440218925476,0.6031039953231812,0.6032639741897583,5,99.88,memory_efficient,xformers
|
| 18 |
-
xFormers,xformers_meff,flux_L448,1,1472,24,128,bfloat16,0.6411136031150818,0.6381760239601135,0.6414719820022583,0.6421440243721008,5,103.81,memory_efficient,xformers
|
| 19 |
-
xFormers,xformers_meff,flux_L512,1,1536,24,128,bfloat16,0.6594688057899475,0.6441280245780945,0.6496639847755432,0.6527680158615112,5,109.12,memory_efficient,xformers
|
| 20 |
-
Compiled (default),torch_flash_compiled_default,flux_L128,1,1152,24,128,bfloat16,0.5181439876556396,0.5141760110855103,0.5175679922103882,0.5197759866714478,5,83.38,FLASH,torch-sdpa
|
| 21 |
-
Compiled (default),torch_flash_compiled_default,flux_L256,1,1280,24,128,bfloat16,0.5579584002494812,0.5549119710922241,0.5582720041275024,0.5598080158233643,5,90.62,FLASH,torch-sdpa
|
| 22 |
-
Compiled (default),torch_flash_compiled_default,flux_L320,1,1344,24,128,bfloat16,0.6872959971427918,0.6853119730949402,0.687391996383667,0.6883519887924194,5,95.25,FLASH,torch-sdpa
|
| 23 |
-
Compiled (default),torch_flash_compiled_default,flux_L384,1,1408,24,128,bfloat16,0.716153597831726,0.7128639817237854,0.7160959839820862,0.7167680263519287,5,99.88,FLASH,torch-sdpa
|
| 24 |
-
Compiled (default),torch_flash_compiled_default,flux_L448,1,1472,24,128,bfloat16,0.7418303966522217,0.7386879920959473,0.7400959730148315,0.7415040135383606,5,103.81,FLASH,torch-sdpa
|
| 25 |
-
Compiled (default),torch_flash_compiled_default,flux_L512,1,1536,24,128,bfloat16,0.7745471954345703,0.7708160281181335,0.7740799784660339,0.7753919959068298,5,109.12,FLASH,torch-sdpa
|
| 26 |
-
Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L128,1,1152,24,128,bfloat16,0.6468096017837525,0.6144000291824341,0.6245759725570679,0.6483200192451477,5,67.5,FLASH,torch-sdpa
|
| 27 |
-
Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L256,1,1280,24,128,bfloat16,0.7060160160064697,0.6689280271530151,0.6851199865341187,0.7184960246086121,5,75.0,FLASH,torch-sdpa
|
| 28 |
-
Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L320,1,1344,24,128,bfloat16,0.8332608103752136,0.7953600287437439,0.8155840039253235,0.8403519988059998,5,80.38,FLASH,torch-sdpa
|
| 29 |
-
Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L384,1,1408,24,128,bfloat16,0.8719295978546142,0.8470720052719116,0.849727988243103,0.8745279908180237,5,82.5,FLASH,torch-sdpa
|
| 30 |
-
Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L448,1,1472,24,128,bfloat16,0.9034304022789001,0.8677120208740234,0.8835520148277283,0.9034240245819092,5,86.25,FLASH,torch-sdpa
|
| 31 |
-
Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L512,1,1536,24,128,bfloat16,0.9387519836425782,0.9154239892959595,0.9213759899139404,0.9359679818153381,5,90.0,FLASH,torch-sdpa
|
| 32 |
-
HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L128,1,1152,24,128,bfloat16,0.3455295979976654,0.34355199337005615,0.34563198685646057,0.34643200039863586,5,83.38,flash-attn,hf-kernels
|
| 33 |
-
HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L256,1,1280,24,128,bfloat16,0.3756160080432892,0.37411201000213623,0.3752000033855438,0.3770880103111267,5,90.62,flash-attn,hf-kernels
|
| 34 |
-
HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L320,1,1344,24,128,bfloat16,0.4953216016292572,0.49324798583984375,0.49433600902557373,0.49663999676704407,5,95.06,flash-attn,hf-kernels
|
| 35 |
-
HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L384,1,1408,24,128,bfloat16,0.5157055854797363,0.5142719745635986,0.516319990158081,0.516543984413147,5,99.88,flash-attn,hf-kernels
|
| 36 |
-
HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L448,1,1472,24,128,bfloat16,0.5356672048568726,0.5346879959106445,0.5358080267906189,0.5361599922180176,5,103.81,flash-attn,hf-kernels
|
| 37 |
-
HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L512,1,1536,24,128,bfloat16,0.5587136030197144,0.5557760000228882,0.5574079751968384,0.5581120252609253,5,109.12,flash-attn,hf-kernels
|
| 38 |
-
HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L128,1,1152,24,128,bfloat16,0.3619711995124817,0.3603839874267578,0.361952006816864,0.3624640107154846,5,83.38,flash-attn3,hf-kernels
|
| 39 |
-
HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L256,1,1280,24,128,bfloat16,0.3912447988986969,0.3892799913883209,0.3909760117530823,0.3922559916973114,5,90.62,flash-attn3,hf-kernels
|
| 40 |
-
HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L320,1,1344,24,128,bfloat16,0.5258048176765442,0.5240640044212341,0.5248960256576538,0.5248960256576538,5,95.06,flash-attn3,hf-kernels
|
| 41 |
-
HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L384,1,1408,24,128,bfloat16,0.5276032090187073,0.5265600085258484,0.5277760028839111,0.5282559990882874,5,99.88,flash-attn3,hf-kernels
|
| 42 |
-
HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L448,1,1472,24,128,bfloat16,0.5656383991241455,0.5639039874076843,0.5657920241355896,0.5668479800224304,5,103.81,flash-attn3,hf-kernels
|
| 43 |
-
HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L512,1,1536,24,128,bfloat16,0.5789952039718628,0.5689600110054016,0.5698239803314209,0.5713919997215271,5,109.12,flash-attn3,hf-kernels
|
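Note: the latency columns above are per-call means/percentiles in milliseconds over 5 reps. As an illustrative reading of the removed table (derived from the values above, not additional data): at flux_L512 the HF Kernels Flash Attn mean of ~0.559 ms is roughly 24% lower than the PyTorch SDPA Flash mean of ~0.738 ms (0.559 / 0.738 ≈ 0.76).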
flash_attn/results/artifacts/combine/latency.png DELETED (Git LFS)
flash_attn/results/artifacts/combine/latency.svg DELETED (Git LFS)
flash_attn/results/cells/combine.py DELETED
@@ -1,319 +0,0 @@
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- #   "numpy",
- #   "torch",
- #   "kernels-benchmark-tools",
- #   "matplotlib",
- # ]
- #
- # [tool.uv.sources]
- # kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
- # ///
- import os
- import sys
- from pathlib import Path
- import json
- import torch # noqa: F401 # imported because upstream may expect torch to be importable
- import kernels_benchmark_tools as kbt
-
- # --- Matplotlib setup and helpers ------------------------------------------------
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- import csv
-
-
- # Keep text as text (not paths) so CSS can style fonts, size, etc.
- mpl.rcParams["svg.fonttype"] = "none"
- # Make ids deterministic across builds
- mpl.rcParams["svg.hashsalt"] = "latency-benchmark-combined"
- # Avoid auto-closed figures interfering with our tagging
- mpl.rcParams["figure.autolayout"] = True
- # Make background transparent
- mpl.rcParams["figure.facecolor"] = "none"
- mpl.rcParams["axes.facecolor"] = "none"
- mpl.rcParams["savefig.facecolor"] = "none"
- mpl.rcParams["savefig.edgecolor"] = "none"
-
- def _slugify(s: str) -> str:
-     s = (s or "").strip().lower()
-     keep = []
-     for ch in s:
-         if ch.isalnum():
-             keep.append(ch)
-         elif ch in (" ", "-", "_", "/", ".", ":"):
-             keep.append("-")
-         else:
-             keep.append("")
-     out = "".join(keep)
-     while "--" in out:
-         out = out.replace("--", "-")
-     return out.strip("-") or "unnamed"
-
- def _tag_current_figure(default_series_prefix="series"):
-     """Attach SVG ids (gid) to key artists so they can be targeted from CSS."""
-     fig = plt.gcf()
-     if fig is None:
-         return
-
-     # Tag the figure itself
-     fig.set_gid("figure--latency")
-
-     for ax_idx, ax in enumerate(fig.get_axes(), start=1):
-         ax.set_gid(f"axes--{ax_idx}")
-
-         # Axis labels & title
-         if ax.get_title():
-             for t in ax.texts:
-                 if t.get_text() == ax.get_title():
-                     t.set_gid("title--main")
-         if ax.xaxis and ax.xaxis.get_label():
-             ax.xaxis.label.set_gid("label--x")
-         if ax.yaxis and ax.yaxis.get_label():
-             ax.yaxis.label.set_gid("label--y")
-
-         # Gridlines
-         for i, gl in enumerate(ax.get_xgridlines(), start=1):
-             gl.set_gid(f"grid-x--{i}")
-         for i, gl in enumerate(ax.get_ygridlines(), start=1):
-             gl.set_gid(f"grid-y--{i}")
-
-         # Legend block & entries
-         leg = ax.get_legend()
-         if leg is not None:
-             leg.set_gid("legend")
-             for i, txt in enumerate(leg.get_texts(), start=1):
-                 label_slug = _slugify(txt.get_text())
-                 txt.set_gid(f"legend-label--{label_slug or i}")
-
-         # Series (lines, patches)
-         # Lines
-         line_seen = {}
-         for ln in getattr(ax, "lines", []):
-             raw_label = ln.get_label() or ""
-             # Matplotlib uses labels beginning with "_" for non-legendable items
-             label = raw_label if not raw_label.startswith("_") else f"{default_series_prefix}"
-             slug = _slugify(label)
-             line_seen[slug] = line_seen.get(slug, 0) + 1
-             suffix = "" if line_seen[slug] == 1 else f"-{line_seen[slug]}"
-             ln.set_gid(f"series--{slug}{suffix}")
-
-         # Patches (bars, areas)
-         patch_seen = {}
-         for pt in getattr(ax, "patches", []):
-             label = getattr(pt, "get_label", lambda: "")() or f"{default_series_prefix}"
-             if isinstance(label, str) and label.startswith("_"):
-                 label = default_series_prefix
-             slug = _slugify(label)
-             patch_seen[slug] = patch_seen.get(slug, 0) + 1
-             suffix = "" if patch_seen[slug] == 1 else f"-{patch_seen[slug]}"
-             pt.set_gid(f"series--{slug}{suffix}")
-
- def _postprocess_svg_add_classes(svg_path: Path):
-     """Add convenient CSS classes alongside ids (e.g., class='series grid grid-x')."""
-     try:
-         import xml.etree.ElementTree as ET
-         ET.register_namespace("", "http://www.w3.org/2000/svg")
-         tree = ET.parse(svg_path)
-         root = tree.getroot()
-         for el in root.iter():
-             el_id = el.attrib.get("id", "")
-             if not el_id:
-                 continue
-             cls = []
-             if el_id.startswith("figure--"):
-                 cls.append("figure")
-             elif el_id.startswith("axes--"):
-                 cls.append("axes")
-             elif el_id.startswith("grid-x--"):
-                 cls += ["grid", "grid-x"]
-             elif el_id.startswith("grid-y--"):
-                 cls += ["grid", "grid-y"]
-             elif el_id.startswith("legend"):
-                 cls.append("legend")
-             elif el_id.startswith("label--x"):
-                 cls.append("xlabel")
-             elif el_id.startswith("label--y"):
-                 cls.append("ylabel")
-             elif el_id.startswith("title--"):
-                 cls.append("title")
-             elif el_id.startswith("series--"):
-                 cls.append("series")
-             if cls:
-                 # Preserve any existing class (unlikely from Matplotlib)
-                 existing = el.attrib.get("class", "")
-                 el.set("class", (existing + " " + " ".join(cls)).strip())
-         tree.write(svg_path, encoding="utf-8", xml_declaration=True)
-     except Exception as e:
-         print(f"✗ SVG postprocess (classes) skipped: {e}")
-
- # Monkey-patch savefig to force SVG & ensure tagging occurs even if kbt.viz saves internally.
- _orig_savefig = plt.savefig
- def _savefig_svg(fname, *args, **kwargs):
-     # Always save as SVG at a stable path for the artifact system
-     out = Path("latency.svg")
-     kwargs["format"] = "svg"
-     # Ensure everything we care about has ids before export
-     _tag_current_figure()
-     res = _orig_savefig(out, *args, **kwargs)
-     # Add helpful CSS classes on top of ids
-     _postprocess_svg_add_classes(out)
-     print(f"✓ Combined visualization saved as {out}")
-     return res
-
- plt.savefig = _savefig_svg # apply patch
-
- # Capture close calls in case kbt.viz() closes figures before we re-save
- _orig_close = plt.close
- _last_closed = {"fig": None}
- def _capture_close(arg=None):
-     try:
-         if hasattr(arg, "savefig"): # looks like a Figure
-             _last_closed["fig"] = arg
-         else:
-             _last_closed["fig"] = plt.gcf()
-     finally:
-         return _orig_close(arg)
- plt.close = _capture_close
-
- # --- Locate benchmark artifacts --------------------------------------------------
- cache_dirs = {
-     "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'),
-     "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'),
-     "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'),
-     "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'),
-     "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'),
-     "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'),
-     "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'),
-     "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'),
-     "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'),
- }
-
- print("LOADING BENCHMARK DATA")
- for name, cache_dir in cache_dirs.items():
-     print(f"{name:30s}: {cache_dir}")
- print()
-
- file_mapping = {
-     "Flash (PyTorch SDPA)": "attn.jsonl",
-     "MemEff (PyTorch SDPA)": "attn.jsonl",
-     "Flash Attn 2": "attn.jsonl",
-     "xFormers": "attn.jsonl",
-     "SageAttention": "attn.jsonl",
-     "Compiled (default)": "attn_default.jsonl",
-     "Compiled (max-autotune)": "attn_max_autotune.jsonl",
-     "HF Kernels Flash Attn": "attn.jsonl",
-     "HF Kernels Flash Attn3": "attn.jsonl",
- }
-
- all_paths = []
- for name, cache_dir in cache_dirs.items():
-     if cache_dir:
-         path = Path(cache_dir) / file_mapping[name]
-         if path.exists() and path.stat().st_size > 0:
-             all_paths.append(str(path))
-             print(f"✓ Found {name}: {path}")
-         else:
-             print(f"⊘ Empty/Missing {name}: {path}")
-     else:
-         print(f"✗ No cache dir for {name}")
- print()
-
- if not all_paths:
-     print("ERROR: No benchmark data files found!")
-     # restore patched functions before exiting
-     plt.savefig = _orig_savefig
-     plt.close = _orig_close
-     sys.exit(1)
-
- # --- Summary + Visualization -----------------------------------------------------
- print("COMBINED BENCHMARK SUMMARY\n")
- kbt.summarize(all_paths)
- print("\nGENERATING COMBINED VISUALIZATION\n")
-
- try:
-     # If kbt.viz saves internally, our patched savefig ensures SVG gets written,
-     # and it will carry ids/classes for CSS styling.
-     kbt.viz(all_paths)
-     # Safety net: if kbt.viz didn't save, save now.
-     # if not Path("latency.svg").exists():
-     #     _tag_current_figure()
-     #     plt.savefig("latency.svg")
-
-     plt.savefig("latency.svg") # ensure saved with tagging
-
-     print("✓ SVG visualization ready: latency.svg!")
- except ImportError as e:
-     print(f"✗ Visualization requires matplotlib: {e}")
- except Exception as e:
-     print(f"✗ Visualization failed: {e}")
- finally:
-     # Clean up patches to avoid side effects in later cells
-     plt.savefig = _orig_savefig
-     plt.close = _orig_close
-
- print()
- print("ANALYSIS COMPLETE")
- print(f"Total implementations analyzed: {len(all_paths)}")
- print(f"\nImplementations included:")
- for name, cache_dir in cache_dirs.items():
-     if cache_dir:
-         path = Path(cache_dir) / file_mapping[name]
-         if path.exists() and path.stat().st_size > 0:
-             print(f" ✓ {name}")
-
-
-
- # Collect all benchmark data and export to CSV
- all_data = {}
- for name, cache_dir in cache_dirs.items():
-     if cache_dir:
-         path = Path(cache_dir) / file_mapping[name]
-         if path.exists() and path.stat().st_size > 0:
-             with open(path, 'r') as f:
-                 records = [json.loads(line) for line in f]
-             all_data[name] = records
-
- # Export to CSV
- csv_path = Path("latency.csv")
- with open(csv_path, 'w', newline='') as csvfile:
-     writer = csv.writer(csvfile)
-
-     # Write header
-     header = ["Implementation", "Impl ID", "Workload", "Batch", "Seq Length", "Heads", "Head Dim", "Dtype",
-               "Mean (ms)", "P10 (ms)", "P50 (ms)", "P90 (ms)", "Reps",
-               # "Compile (ms)",
-               "Peak Mem (MB)", "Backend", "Family"]
-     writer.writerow(header)
-
-     # Write data rows
-     for impl_name, records in all_data.items():
-         for record in records:
-             wl = record.get('wl', {})
-             lat = record.get('lat_ms', {})
-             tags = record.get('tags', {})
-
-             row = [
-                 impl_name,
-                 record.get('impl', ''),
-                 wl.get('name', ''),
-                 wl.get('batch', ''),
-                 wl.get('seq_len', ''),
-                 wl.get('heads', ''),
-                 wl.get('head_dim', ''),
-                 wl.get('dtype', ''),
-                 lat.get('mean', ''),
-                 lat.get('p10', ''),
-                 lat.get('p50', ''),
-                 lat.get('p90', ''),
-                 lat.get('reps', ''),
-                 # record.get('compile_ms', ''),
-                 round(record.get('peak_bytes', 0) / 1024 / 1024, 2) if record.get('peak_bytes') else '',
-                 tags.get('backend', ''),
-                 tags.get('family', ''),
-             ]
-             writer.writerow(row)
-
- print(f"✓ CSV export complete: {csv_path}")
- print(f"Total implementations: {len(all_data)}")
- print(f"Total records: {sum(len(records) for records in all_data.values())}")
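The deleted script above carries PEP 723 inline metadata, so it can be executed standalone by a tool that understands that metadata, such as uv. A hypothetical invocation (the paths are placeholders, assuming the upstream benchmark cells have already written their attn*.jsonl artifacts):

UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK=/path/to/flash_attention/cache \
UVNOTE_FILE_XFORMERS_BENCHMARK=/path/to/xformers/cache \
uv run flash_attn/results/cells/combine.py

The script skips any implementation whose UVNOTE_FILE_* variable is unset or whose JSONL file is empty, and writes latency.svg and latency.csv into the working directory.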
flash_attn/results/combined_results.html DELETED
(The diff for this file is too large to render. See raw diff.)
flash_attn/results/index.html DELETED
@@ -1,88 +0,0 @@
- <!DOCTYPE html>
- <html>
- <head>
- <meta charset='UTF-8'>
- <meta name='viewport' content='width=device-width, initial-scale=1.0'>
- <title>Index of /flash_attn/results</title>
- <style>
- :root {
- --bg-primary: #0a0a0a;
- --bg-secondary: #121212;
- --bg-tertiary: #181818;
- --text-primary: #e0e0e0;
- --text-secondary: #888888;
- --text-link: #64b5f6;
- --border-primary: #2a2a2a;
- }
- body {
- font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
- background: var(--bg-primary);
- color: var(--text-primary);
- margin: 0;
- padding: 16px;
- max-width: 900px;
- margin: 0 auto;
- }
- .controls {
- display: flex;
- justify-content: flex-end;
- margin-bottom: 1rem;
- }
- .back-button {
- background: var(--bg-secondary);
- border: 1px solid var(--border-primary);
- padding: 8px 12px;
- border-radius: 4px;
- color: var(--text-secondary);
- cursor: pointer;
- font-size: 0.9rem;
- text-decoration: none;
- display: inline-block;
- }
- .back-button:hover {
- color: var(--text-primary);
- background: var(--bg-tertiary);
- }
- h1 {
- font-size: 1.5em;
- margin: 1rem 0;
- color: var(--text-primary);
- border-bottom: 1px solid var(--border-primary);
- padding-bottom: 0.5rem;
- }
- ul {
- list-style-type: none;
- padding: 0;
- }
- li {
- margin: 0;
- border-bottom: 1px solid var(--border-primary);
- }
- li:last-child {
- border-bottom: none;
- }
- a {
- display: block;
- padding: 0.75rem 0.5rem;
- text-decoration: none;
- color: var(--text-link);
- transition: background 0.2s ease;
- }
- a:hover {
- background: var(--bg-secondary);
- }
- .dir {
- font-weight: 500;
- }
- </style>
- </head>
- <body>
- <div class='controls'>
- <a href='../index.html' class='back-button'>← back</a>
- </div>
- <h1>Index of /flash_attn/results</h1>
- <ul>
- <li><a href='combined_results.html' class='file'>combined_results.html</a></li>
- </ul>
- </body>
- </html>
index.html DELETED
@@ -1,85 +0,0 @@
- <!DOCTYPE html>
- <html>
- <head>
- <meta charset='UTF-8'>
- <meta name='viewport' content='width=device-width, initial-scale=1.0'>
- <title>Index of /</title>
- <style>
- :root {
- --bg-primary: #0a0a0a;
- --bg-secondary: #121212;
- --bg-tertiary: #181818;
- --text-primary: #e0e0e0;
- --text-secondary: #888888;
- --text-link: #64b5f6;
- --border-primary: #2a2a2a;
- }
- body {
- font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;
- background: var(--bg-primary);
- color: var(--text-primary);
- margin: 0;
- padding: 16px;
- max-width: 900px;
- margin: 0 auto;
- }
- .controls {
- display: flex;
- justify-content: flex-end;
- margin-bottom: 1rem;
- }
- .back-button {
- background: var(--bg-secondary);
- border: 1px solid var(--border-primary);
- padding: 8px 12px;
- border-radius: 4px;
- color: var(--text-secondary);
- cursor: pointer;
- font-size: 0.9rem;
- text-decoration: none;
- display: inline-block;
- }
- .back-button:hover {
- color: var(--text-primary);
- background: var(--bg-tertiary);
- }
- h1 {
- font-size: 1.5em;
- margin: 1rem 0;
- color: var(--text-primary);
- border-bottom: 1px solid var(--border-primary);
- padding-bottom: 0.5rem;
- }
- ul {
- list-style-type: none;
- padding: 0;
- }
- li {
- margin: 0;
- border-bottom: 1px solid var(--border-primary);
- }
- li:last-child {
- border-bottom: none;
- }
- a {
- display: block;
- padding: 0.75rem 0.5rem;
- text-decoration: none;
- color: var(--text-link);
- transition: background 0.2s ease;
- }
- a:hover {
- background: var(--bg-secondary);
- }
- .dir {
- font-weight: 500;
- }
- </style>
- </head>
- <body>
- <h1>Index of /</h1>
- <ul>
- <li><a href='flash_attn/index.html' class='dir'>flash_attn/</a></li>
- </ul>
- </body>
- </html>
megablocks/cells/forward_and_backward.py DELETED
@@ -1,196 +0,0 @@
- # /// script
- # requires-python = ">=3.12"
- # dependencies = [
- #   "accelerate>=1.10.1",
- #   "torch>=2.7.0",
- #   "kernels==0.10.0",
- #   "transformers@https://github.com/huggingface/transformers.git",
- #   "ipdb>=0.13.13",
- #   "matplotlib>=3.7.2",
- #   "numpy>=1.24.3",
- # ]
- # ///
-
- import torch
- from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
- import time
- import torch.nn as nn
- from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
- import sys
- import torch.profiler
- import gc
- import logging
- from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
-
- # remove liger kernel for testing
- replace_kernel_forward_from_hub(GptOssRMSNorm, None)
-
- # set to debug logging
- logging.basicConfig(level=logging.INFO)
-
- def reset_peak_memory_stats():
-     """Clear CUDA cache and reset memory allocation counters."""
-     torch.cuda.empty_cache()
-     if torch.cuda.is_available():
-         torch.cuda.reset_peak_memory_stats()
-     gc.collect()
-
- def get_memory_stats():
-     """Get current and peak CUDA memory usage."""
-     if not torch.cuda.is_available():
-         return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
-     return {
-         "allocated_gb": torch.cuda.memory_allocated() / 1e9,
-         "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
-         "reserved_gb": torch.cuda.memory_reserved() / 1e9,
-     }
-
- def override_kernel_layer_name(cls_name: str, value) -> bool:
-     """Helper to dynamically override the kernel_layer_name in a model class."""
-     for mod in sys.modules.values():
-         if mod is None:
-             continue
-         obj = getattr(mod, cls_name, None)
-         if isinstance(obj, type) and issubclass(obj, nn.Module):
-             setattr(obj, "kernel_layer_name", value)
-             print(f"Overrode {cls_name}.kernel_layer_name to {value}")
-             return True
-     return False
-
-
- # Init the model the normal way
- model_id = "openai/gpt-oss-20b"
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
- quantization_config = Mxfp4Config(dequantize=True)
-
- model = GptOssForCausalLM.from_pretrained(
-     model_id,
-     dtype="bfloat16",
-     device_map="auto",
-     use_kernels=True,
-     quantization_config=quantization_config,
- ).eval()
-
- messages = [
-     {"role": "system", "content": "What is Tensor Parallelism?"},
- ]
-
- inputs = tokenizer.apply_chat_template(
-     messages,
-     add_generation_prompt=True,
-     return_tensors="pt",
-     return_dict=True,
-     reasoning_effort="low",
- ).to("cuda")
-
- max_tokens = 128 # Reduced to help with memory usage
-
- # Clear memory before backward pass
- reset_peak_memory_stats()
- print(f"Pre-generation memory: {get_memory_stats()}")
-
- # forward and backward pass
- with torch.autograd.set_grad_enabled(True):
-     start_time = time.perf_counter()
-     generated = model.generate(
-         **inputs,
-         max_new_tokens=max_tokens,
-         do_sample=False,
-         temperature=None,
-     )
-     end_time = time.perf_counter()
-     print(tokenizer.decode(generated[0], skip_special_tokens=False))
-     print(f"Generation took {end_time - start_time:.2f} seconds")
-     print(f"Post-generation memory: {get_memory_stats()}")
-
-     # Use gradient checkpointing to reduce memory usage
-     if hasattr(model, 'gradient_checkpointing_enable'):
-         model.gradient_checkpointing_enable()
-         print("Enabled gradient checkpointing")
-
-     # Reduce sequence length if needed for memory
-     max_seq_len = 512 # Limit sequence length for backward pass
-     if generated.size(1) > max_seq_len:
-         print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
-         full_sequence = generated[:, -max_seq_len:]
-     else:
-         full_sequence = generated
-
-     # Get model outputs for the full sequence
-     model.train() # Enable dropout and other training behaviors
-
-     try:
-         outputs = model(
-             input_ids=full_sequence,
-             labels=full_sequence, # This will compute loss internally
-             return_dict=True
-         )
-         print(f"Post-forward memory: {get_memory_stats()}")
-
-         # If model doesn't compute loss, compute it manually
-         if outputs.loss is None:
-             shift_logits = outputs.logits[..., :-1, :].contiguous()
-             shift_labels = full_sequence[..., 1:].contiguous()
-
-             # Use CrossEntropyLoss with ignore_index for padding tokens
-             loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100)
-             loss = loss_fct(
-                 shift_logits.view(-1, shift_logits.size(-1)),
-                 shift_labels.view(-1)
-             )
-         else:
-             loss = outputs.loss
-
-         print(f"Loss: {loss.item():.4f}")
-
-         # Clear intermediate tensors to save memory
-         del outputs
-         torch.cuda.empty_cache()
-
-         # Perform backward pass with memory management
-         print("Running backward pass...")
-         print(f"Pre-backward memory: {get_memory_stats()}")
-
-         loss.backward()
-         print(f"Post-backward memory: {get_memory_stats()}")
-
-     except torch.cuda.OutOfMemoryError as e:
-         print(f"OOM during forward/backward pass: {e}")
-         print("Try reducing max_tokens or max_seq_len")
-         raise
-
-     # Calculate gradient statistics and print sample gradients
-     total_norm = 0.0
-     param_count = 0
-     grad_samples = {}
-
-     for name, p in model.named_parameters():
-         if p.grad is not None:
-             param_count += 1
-             grad_norm = p.grad.data.norm(2).item()
-             total_norm += grad_norm ** 2
-
-             # Collect gradient statistics for key layers
-             if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
-                 grad_samples[name] = {
-                     'norm': grad_norm,
-                     'mean': p.grad.data.mean().item(),
-                     'std': p.grad.data.std().item(),
-                     'max': p.grad.data.max().item(),
-                     'min': p.grad.data.min().item(),
-                 }
-
-     total_norm = total_norm ** 0.5
-
-     print(f"\nGradient norm: {total_norm:.4f}")
-     print(f"Parameters with gradients: {param_count}")
-
-     # Print sample gradients from important layers
-     print("\nSample gradient statistics:")
-     for i, (name, stats) in enumerate(list(grad_samples.items())[:10]):
-         print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")
-
-     # Optional: zero gradients for next iteration
-     model.zero_grad()
-     model.eval() # Switch back to eval mode
-
megablocks/cells/forward_and_backward_no_kernel.py DELETED
@@ -1,196 +0,0 @@
(The 196 deleted lines are identical to megablocks/cells/forward_and_backward.py above, except that line 70 loads the model with use_kernels=False instead of use_kernels=True.)
megablocks/cells/forward_only.py DELETED
@@ -1,101 +0,0 @@
- # /// script
- # requires-python = ">=3.12"
- # dependencies = [
- #   "accelerate>=1.10.1",
- #   "torch>=2.7.0",
- #   "kernels==0.10.0",
- #   "transformers@https://github.com/huggingface/transformers.git",
- #   "ipdb>=0.13.13",
- #   "matplotlib>=3.7.2",
- #   "numpy>=1.24.3",
- # ]
- # ///
-
- import torch
- from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
- import time
- import torch.nn as nn
- from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
- import sys
- import torch.profiler
- import gc
- import logging
- from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
-
-
- replace_kernel_forward_from_hub(GptOssRMSNorm, None)
-
- # set to debug logging
- logging.basicConfig(level=logging.INFO)
-
- def reset_peak_memory_stats():
-     """Clear CUDA cache and reset memory allocation counters."""
-     torch.cuda.empty_cache()
-     if torch.cuda.is_available():
-         torch.cuda.reset_peak_memory_stats()
-     gc.collect()
-
- def get_memory_stats():
-     """Get current and peak CUDA memory usage."""
-     if not torch.cuda.is_available():
-         return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
-     return {
-         "allocated_gb": torch.cuda.memory_allocated() / 1e9,
-         "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
-         "reserved_gb": torch.cuda.memory_reserved() / 1e9,
-     }
-
- def override_kernel_layer_name(cls_name: str, value) -> bool:
-     """Helper to dynamically override the kernel_layer_name in a model class."""
-     for mod in sys.modules.values():
-         if mod is None:
-             continue
-         obj = getattr(mod, cls_name, None)
-         if isinstance(obj, type) and issubclass(obj, nn.Module):
-             setattr(obj, "kernel_layer_name", value)
-             print(f"Overrode {cls_name}.kernel_layer_name to {value}")
-             return True
-     return False
-
-
- # Init the model the normal way
- model_id = "openai/gpt-oss-20b"
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
- quantization_config = Mxfp4Config(dequantize=True)
-
-
-
- model = GptOssForCausalLM.from_pretrained(
-     model_id,
-     dtype="bfloat16",
-     device_map="auto",
-     use_kernels=True,
-     quantization_config=quantization_config,
- ).eval()
-
- messages = [
-     {"role": "system", "content": "What is Tensor Parallelism?"},
- ]
-
- inputs = tokenizer.apply_chat_template(
-     messages,
-     add_generation_prompt=True,
-     return_tensors="pt",
-     return_dict=True,
-     reasoning_effort="low",
- ).to("cuda")
-
- max_tokens = 256
-
- with torch.inference_mode():
-     start_time = time.perf_counter()
-     generated = model.generate(
-         **inputs,
-         max_new_tokens=max_tokens,
-         do_sample=False,
-         temperature=None,
-     )
-     end_time = time.perf_counter()
-
- print(tokenizer.decode(generated[0], skip_special_tokens=False))
- print(f"Generation took {end_time - start_time:.2f} seconds")
megablocks/cells/no_kernels.py DELETED
@@ -1,98 +0,0 @@
- # /// script
- # requires-python = ">=3.12"
- # dependencies = [
- #   "accelerate>=1.10.1",
- #   "torch>=2.7.0",
- #   "kernels==0.10.0",
- #   "transformers@https://github.com/huggingface/transformers.git",
- #   "ipdb>=0.13.13",
- #   "matplotlib>=3.7.2",
- #   "numpy>=1.24.3",
- # ]
- # ///
-
- import torch
- from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
- import time
- import torch.nn as nn
- from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
- import sys
- import torch.profiler
- import gc
- import logging
- from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
-
- # set to debug logging
- logging.basicConfig(level=logging.INFO)
-
- def reset_peak_memory_stats():
-     """Clear CUDA cache and reset memory allocation counters."""
-     torch.cuda.empty_cache()
-     if torch.cuda.is_available():
-         torch.cuda.reset_peak_memory_stats()
-     gc.collect()
-
- def get_memory_stats():
-     """Get current and peak CUDA memory usage."""
-     if not torch.cuda.is_available():
-         return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
-     return {
-         "allocated_gb": torch.cuda.memory_allocated() / 1e9,
-         "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
-         "reserved_gb": torch.cuda.memory_reserved() / 1e9,
-     }
-
- def override_kernel_layer_name(cls_name: str, value) -> bool:
-     """Helper to dynamically override the kernel_layer_name in a model class."""
-     for mod in sys.modules.values():
-         if mod is None:
-             continue
-         obj = getattr(mod, cls_name, None)
-         if isinstance(obj, type) and issubclass(obj, nn.Module):
-             setattr(obj, "kernel_layer_name", value)
-             print(f"Overrode {cls_name}.kernel_layer_name to {value}")
-             return True
-     return False
-
-
- # Init the model the normal way
- model_id = "openai/gpt-oss-20b"
- tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
- quantization_config = Mxfp4Config(dequantize=True)
-
-
-
- model = GptOssForCausalLM.from_pretrained(
-     model_id,
-     dtype="bfloat16",
-     device_map="auto",
-     use_kernels=False,
-     quantization_config=quantization_config,
- ).eval()
-
- messages = [
-     {"role": "system", "content": "What is Tensor Parallelism?"},
- ]
-
- inputs = tokenizer.apply_chat_template(
-     messages,
-     add_generation_prompt=True,
-     return_tensors="pt",
-     return_dict=True,
-     reasoning_effort="low",
- ).to("cuda")
-
- max_tokens = 256
-
- with torch.inference_mode():
-     start_time = time.perf_counter()
-     generated = model.generate(
-         **inputs,
-         max_new_tokens=max_tokens,
-         do_sample=False,
-         temperature=None,
-     )
-     end_time = time.perf_counter()
-
- print(tokenizer.decode(generated[0], skip_special_tokens=False))
- print(f"Generation took {end_time - start_time:.2f} seconds")
megablocks/cells/nv.py DELETED
@@ -1,3 +0,0 @@
- import subprocess
-
- print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
megablocks/index.html DELETED
@@ -1,24 +0,0 @@
- <!DOCTYPE html>
- <html>
- <head>
- <meta charset='UTF-8'>
- <title>Directory Index</title>
- <style>
- body { font-family: monospace; margin: 20px; }
- h1 { font-size: 1.5em; }
- ul { list-style-type: none; padding-left: 20px; }
- li { margin: 5px 0; }
- .dir { font-weight: bold; }
- .file { color: #0066cc; }
- a { text-decoration: none; }
- a:hover { text-decoration: underline; }
- </style>
- </head>
- <body>
- <h1>Index of /megablocks</h1>
- <ul>
- <li><a href='../index.html' class='dir'>../</a></li>
- <li><a href='megablocks_only.html' class='file'>megablocks_only.html</a></li>
- </ul>
- </body>
- </html>
megablocks/megablocks_only.html
DELETED
The diff for this file is too large to render. See raw diff.
megablocks_yamoe/artifacts/binned_run/binned_results.json
DELETED
@@ -1,24 +0,0 @@
-{
-  "implementation": "binned_results",
-  "config": {
-    "warmup": 10,
-    "iters": 50,
-    "device": "cuda",
-    "dtype": "torch.float32",
-    "tokens": 100,
-    "vary_inputs": true
-  },
-  "stats": {
-    "avg_ms": 36.26809924006011,
-    "min_ms": 34.103908000361116,
-    "max_ms": 37.68557000057626,
-    "std_ms": 1.1598518125118418,
-    "p50_ms": 36.52223600056459,
-    "p95_ms": 37.6427445000445,
-    "p99_ms": 37.677440410316194,
-    "num_iters": 50,
-    "tokens_per_s": 2757.2440269917565,
-    "throughput_variance": 89.13103199163609
-  },
-  "output_sum": 3.97190523147583
-}
megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json
DELETED
@@ -1,24 +0,0 @@
-{
-  "implementation": "gptoss_results",
-  "config": {
-    "warmup": 10,
-    "iters": 50,
-    "device": "cuda",
-    "dtype": "torch.float32",
-    "tokens": 100,
-    "vary_inputs": true
-  },
-  "stats": {
-    "avg_ms": 46.913985819956,
-    "min_ms": 40.44806400088419,
-    "max_ms": 51.07520399997156,
-    "std_ms": 2.9921332618008196,
-    "p50_ms": 47.418902999652346,
-    "p95_ms": 50.800493049837314,
-    "p99_ms": 50.948625239852845,
-    "num_iters": 50,
-    "tokens_per_s": 2131.560519794133,
-    "throughput_variance": 139.93911554997217
-  },
-  "output_sum": 11.53223705291748
-}
megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json
DELETED
@@ -1,24 +0,0 @@
-{
-  "implementation": "gptoss_training_results",
-  "config": {
-    "warmup": 10,
-    "iters": 50,
-    "device": "cuda",
-    "dtype": "torch.float32",
-    "tokens": 100,
-    "vary_inputs": true
-  },
-  "stats": {
-    "avg_ms": 46.289439859992854,
-    "min_ms": 39.97907499979192,
-    "max_ms": 50.58144600025116,
-    "std_ms": 2.9172154402078077,
-    "p50_ms": 46.64785849990949,
-    "p95_ms": 50.26727430031315,
-    "p99_ms": 50.5162941305025,
-    "num_iters": 50,
-    "tokens_per_s": 2160.3199412751637,
-    "throughput_variance": 139.86427060112865
-  },
-  "output_sum": 11.53223705291748
-}
megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json
DELETED
@@ -1,24 +0,0 @@
-{
-  "implementation": "yamoe_results",
-  "config": {
-    "warmup": 10,
-    "iters": 50,
-    "device": "cuda",
-    "dtype": "torch.float32",
-    "tokens": 100,
-    "vary_inputs": true
-  },
-  "stats": {
-    "avg_ms": 4.248197240067384,
-    "min_ms": 4.136622000260104,
-    "max_ms": 4.280714999367774,
-    "std_ms": 0.02141682051311511,
-    "p50_ms": 4.253484999935608,
-    "p95_ms": 4.265540049709671,
-    "p99_ms": 4.273649199667489,
-    "num_iters": 50,
-    "tokens_per_s": 23539.396677922097,
-    "throughput_variance": 120.66648678204231
-  },
-  "output_sum": 3.97190523147583
-}
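
Note (derivation, not part of the diff): the throughput entries above follow directly from the latency averages via tokens_per_s = tokens / (avg_ms / 1000). For the Yamoe run, 100 / (4.248 / 1000) ≈ 23,539 tokens/s; for the binned run, 100 / (36.268 / 1000) ≈ 2,757 tokens/s, matching the stored values. The identical output_sum for the Yamoe and binned runs (3.971905) and for the two GPT-OSS runs (11.532237) is consistent with each pair computing the same result from the shared weights and inputs.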
megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc
DELETED
Binary file (16.1 kB)

megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc
DELETED
Binary file (680 Bytes)
megablocks_yamoe/cells/bench_utils.py
DELETED
@@ -1,241 +0,0 @@
-# /// script
-# dependencies = [
-# "torch",
-# "numpy",
-# ]
-# ///
-
-"""Reusable benchmarking utilities for performance testing."""
-import time
-import numpy as np
-from contextlib import contextmanager
-from typing import Callable, Dict, Tuple, Any, Optional
-import torch
-
-def to_dtype(dtype_str: str):
-    """Convert string to torch dtype."""
-    if dtype_str == "float16":
-        return torch.float16
-    if dtype_str == "bfloat16":
-        return torch.bfloat16
-    return torch.float32
-
-def _sync(device: str):
-    """Synchronize device if CUDA."""
-    if device == "cuda":
-        torch.cuda.synchronize()
-
-def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
-    """Compute comprehensive latency and throughput statistics."""
-    lat_ms = np.array([t * 1000.0 for t in times_s])
-    lat_ms_sorted = np.sort(lat_ms)
-    n = len(lat_ms)
-
-    stats = {
-        "avg_ms": np.mean(lat_ms),
-        "min_ms": np.min(lat_ms),
-        "max_ms": np.max(lat_ms),
-        "std_ms": np.std(lat_ms),
-        "p50_ms": np.percentile(lat_ms, 50),
-        "p95_ms": np.percentile(lat_ms, 95),
-        "p99_ms": np.percentile(lat_ms, 99),
-        "num_iters": n
-    }
-
-    if tokens is not None and n > 0:
-        avg_s = np.mean(times_s)
-        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
-        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])
-
-    return stats
-
-def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
-    """Format timing statistics for display."""
-    lines = [
-        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
-        f"Iterations: {stats.get('num_iters', 0)}",
-        "\nLatency Statistics:",
-        f" Average: {stats['avg_ms']:.3f} ms",
-        f" Min: {stats['min_ms']:.3f} ms",
-        f" Max: {stats['max_ms']:.3f} ms",
-        f" Std Dev: {stats['std_ms']:.3f} ms",
-        "\nPercentiles:",
-        f" P50 (median): {stats['p50_ms']:.3f} ms",
-        f" P95: {stats['p95_ms']:.3f} ms",
-        f" P99: {stats['p99_ms']:.3f} ms",
-    ]
-
-    if tokens is not None and 'tokens_per_s' in stats:
-        lines.extend([
-            "\nThroughput:",
-            f" Tokens/sec: {stats['tokens_per_s']:.1f}",
-            f" Std Dev: {stats.get('throughput_variance', 0):.1f}",
-        ])
-
-    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
-    return "\n".join(lines)
-
-def _bench_engine(
-    call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] = None
-) -> Tuple[Any, list]:
-    """Core benchmarking engine with warmup and timing."""
-    use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
-
-    # Warmup phase
-    print(f"\nWarming up ({warmup} iterations)...")
-    with torch.inference_mode():
-        for _ in range(max(0, warmup)):
-            if use_autocast:
-                with torch.autocast(device_type="cuda", dtype=dtype):
-                    if input_gen is not None:
-                        _ = call(input_gen())
-                    else:
-                        _ = call()
-            else:
-                if input_gen is not None:
-                    _ = call(input_gen())
-                else:
-                    _ = call()
-    _sync(device)
-
-    # Measurement phase
-    print(f"Benchmarking ({iters} iterations)...")
-    times_s = []
-    last = None
-    with torch.inference_mode():
-        for i in range(max(1, iters)):
-            start = time.perf_counter()
-            if use_autocast:
-                with torch.autocast(device_type="cuda", dtype=dtype):
-                    if input_gen is not None:
-                        last = call(input_gen())
-                    else:
-                        last = call()
-            else:
-                if input_gen is not None:
-                    last = call(input_gen())
-                else:
-                    last = call()
-            _sync(device)
-            end = time.perf_counter()
-            times_s.append(end - start)
-
-            # Progress indicator every 20% of iterations
-            if i > 0 and i % max(1, iters // 5) == 0:
-                pct = (i / iters) * 100
-                avg_so_far = np.mean(times_s[:i]) * 1000
-                print(f" Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")
-
-    return last, times_s
-
-def tensor_stats(t: torch.Tensor) -> str:
-    """Generate comprehensive stats string for a tensor."""
-    return (f"shape={tuple(t.shape)}, "
-            f"dtype={t.dtype}, "
-            f"device={t.device}, "
-            f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
-            f"mean={t.mean().item():.6f}, "
-            f"std={t.std().item():.6f}, "
-            f"norm={t.norm().item():.6f}")
-
-@contextmanager
-def bench_context(
-    *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True
-):
-    """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).
-
-    If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration
-    by adding a small deterministic increment to prevent caching artifacts.
-    """
-
-    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
-        # Log configuration
-        if verbose:
-            print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐")
-            # print(f"│ Device: {device:<15} Dtype: {dtype} │")
-            print(f"│ Warmup: {warmup:<15} Iters: {iters} │")
-            if tokens:
-                print(f"│ Tokens: {tokens} │")
-            if vary_inputs:
-                print(f"│ Input Variation: Enabled (prevents caching artifacts) │")
-            print(f"└────────────────────────────────────────────────────────┘")
-
-        # Set up input generation
-        input_gen = None
-        if vary_inputs and args and isinstance(args[0], torch.Tensor):
-            base_input = args[0].clone()
-            iteration_counter = [0]  # Use list for mutable closure
-
-            def generate_varied_input():
-                """Generate input tensor varied by iteration to prevent caching."""
-                # Add small deterministic increment: 0.001 * iteration_number
-                varied_input = base_input + (iteration_counter[0] * 0.001)
-                iteration_counter[0] += 1
-                return varied_input
-
-            input_gen = generate_varied_input
-            call = lambda x: fn(x, *args[1:], **kwargs)
-
-            # Log base input stats
-            if verbose:
-                print(f"\nBase Input: {tensor_stats(base_input)}")
-                print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
-        else:
-            # Legacy mode - static inputs
-            call = lambda: fn(*args, **kwargs)
-            if verbose and args and isinstance(args[0], torch.Tensor):
-                print(f"\nInput: {tensor_stats(args[0])}")
-
-        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)
-
-        # Log output if it's a tensor or tuple with tensors
-        if verbose:
-            print("\nOutput tensors:")
-            if isinstance(result, torch.Tensor):
-                print(f" Primary: {tensor_stats(result)}")
-            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
-                print(f" Primary: {tensor_stats(result[0])}")
-                if len(result) > 1:
-                    if isinstance(result[1], torch.Tensor):
-                        print(f" Auxiliary: {tensor_stats(result[1])}")
-                    else:
-                        print(f" Auxiliary: {type(result[1]).__name__}")
-
-        # Compute and display statistics
-        stats = _compute_stats(times_s, tokens=tokens)
-        if verbose:
-            print(_format_timing_stats(stats, tokens))
-
-        # Save to JSON if requested
-        if save_json:
-            import json
-            json_data = {
-                "implementation": save_json.replace(".json", ""),
-                "config": {
-                    "warmup": warmup,
-                    "iters": iters,
-                    "device": str(device),  # Convert device to string
-                    "dtype": str(dtype),
-                    "tokens": tokens,
-                    "vary_inputs": vary_inputs
-                },
-                "stats": stats,
-                "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None
-            }
-            with open(save_json, 'w') as f:
-                json.dump(json_data, f, indent=2)
-            if verbose:
-                print(f"\nSaved benchmark results to {save_json}")
-
-        return result, stats
-
-    yield runner
-
-def set_seed(seed: int):
-    """Set seeds for reproducibility."""
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
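
Note (usage sketch, not part of the diff): the run cells below all drive bench_context the same way; `model` here stands in for any module that takes a single input tensor, and the shapes match the shared config (batch 1, 100 tokens, hidden size 1152).

    x = torch.randn(1, 100, 1152, device="cuda")
    with bench_context(warmup=10, iters=50, device="cuda", dtype=torch.float32,
                       tokens=100, save_json="example_results.json", vary_inputs=True) as bench:
        output, stats = bench(model, x)
    print(stats["avg_ms"], stats["tokens_per_s"])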
megablocks_yamoe/cells/binned_run.py
DELETED
@@ -1,195 +0,0 @@
-# /// script
-# dependencies = [
-# "torch",
-# "numpy",
-# ]
-# ///
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-def binned_gather(x, indices, bins, expert_capacity, top_k):
-    E, H = bins.shape[0], x.shape[1]
-    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
-    for e in range(E):
-        start = 0 if e == 0 else bins[e - 1]
-        end = bins[e]
-        n = min(end - start, expert_capacity)
-        for i in range(n):
-            flat_pos = indices[start + i]
-            tok = flat_pos // top_k
-            out[e, i] = x[tok]
-    return out
-
-def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
-    E, C, H = x.shape
-    N = indices.shape[0] // top_k
-    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
-    for e in range(E):
-        start = 0 if e == 0 else bins[e - 1]
-        end = bins[e]
-        n = end - start
-        if n == 0:
-            continue
-        take = min(n, expert_capacity)
-        for i in range(take):
-            flat_pos = indices[start + i]
-            tok = flat_pos // top_k
-            slot = flat_pos % top_k
-            scale = weights[flat_pos] if weights is not None else 1.0
-            out[tok, slot] = x[e, i] * scale
-    return out.sum(dim=1)
-
-def sort_tokens_by_expert(router_indices, num_experts):
-    flat_indices = router_indices.flatten()
-    sorted_values, sorted_indices = torch.sort(flat_indices)
-    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
-    bins = torch.cumsum(tokens_per_expert, dim=0)
-    return sorted_indices, sorted_values, bins, tokens_per_expert
-
-def binned_experts_ref(
-    hidden_states,
-    router_indices,
-    routing_weights,
-    gate_up_proj,
-    gate_up_proj_bias,
-    down_proj,
-    down_proj_bias,
-    expert_capacity,
-):
-    B, S, H = hidden_states.shape
-    E, K = routing_weights.shape[1], router_indices.shape[1]
-
-    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
-    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
-
-    gate_up = torch.bmm(x, gate_up_proj)
-    gate_up += gate_up_proj_bias[..., None, :]
-
-    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-
-    # clamp to limit
-    limit = 7.0
-    gate = gate.clamp(min=None, max=limit)
-    up = up.clamp(min=-limit, max=limit)
-
-    glu = gate * torch.sigmoid(gate * 1.702)
-    x = (up + 1) * glu
-    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
-
-    # build routing weights aligned to (token, slot)
-    flat_dense = routing_weights.view(-1, E)
-    flat_router = router_indices.view(-1, K)
-    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
-
-    # scatter back
-    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
-
-    return y.view(B, S, H)
-
-class BinnedRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-class BinnedMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = BinnedRouter(router_weight, router_bias)
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.top_k = TOP_K
-
-        # Expert weights - use the loaded weights
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        batch_size = hidden_states.shape[0]
-        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
-
-        output = binned_experts_ref(
-            hidden_states,
-            router_indices,
-            router_scores,
-            self.gate_up_proj,
-            self.gate_up_proj_bias,
-            self.down_proj,
-            self.down_proj_bias,
-            expert_capacity,
-        )
-
-        return output, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== Binned Implementation ===")
-# Initialize model with loaded weights
-model = BinnedMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
-
-# Generate the same input as Yamoe
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-print(f"\nOutput sum: {output[0].sum().item():.6f}")
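
Note (illustration, not part of the diff): sort_tokens_by_expert groups flattened (token, slot) assignments by expert; a tiny worked example with 2 tokens, top_k = 2 and 3 experts:

    router_indices = torch.tensor([[0, 2], [2, 1]])   # token 0 -> experts 0,2; token 1 -> experts 2,1
    flat = router_indices.flatten()                   # [0, 2, 2, 1]
    values, order = torch.sort(flat)                  # values = [0, 1, 2, 2]
    counts = torch.bincount(values, minlength=3)      # [1, 1, 2] assignments per expert
    bins = torch.cumsum(counts, dim=0)                # [1, 2, 4]: expert e owns order[bins[e-1]:bins[e]]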
megablocks_yamoe/cells/config.py
DELETED
@@ -1,27 +0,0 @@
-# /// script
-# dependencies = [
-# "torch",
-# "numpy",
-# ]
-# ///
-
-"""Shared configuration for both implementations."""
-import torch
-
-# Model configuration
-NUM_EXPERTS = 128
-HIDDEN_SIZE = 1152
-INTERMEDIATE_SIZE = 3072
-TOP_K = 4
-
-# Input configuration
-BATCH_SIZE = 1
-SEQ_LEN = 100
-DTYPE = "float32"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Seeds for reproducibility
-WEIGHT_SEED = 999
-EXPERT_SEED = 777
-INPUT_SEED = 123
-GENERAL_SEED = 42
megablocks_yamoe/cells/gptoss_run.py
DELETED
@@ -1,147 +0,0 @@
-# /// script
-# dependencies = [
-# "torch",
-# "numpy",
-# ]
-# ///
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class GptOssRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-class GptOssExperts(nn.Module):
-    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.expert_dim = self.hidden_size
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-        self.alpha = 1.702
-        self.limit = 7.0
-
-    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        if hidden_states.device.type == "cpu" or self.training:
-            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-            with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-                expert_mask = expert_mask.permute(2, 1, 0)
-                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-            for expert_idx in expert_hit[:]:
-                expert_idx = expert_idx[0]
-                with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx])
-                current_state = hidden_states[token_idx]
-                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-                gate = gate.clamp(min=None, max=self.limit)
-                up = up.clamp(min=-self.limit, max=self.limit)
-                glu = gate * torch.sigmoid(gate * self.alpha)
-                gated_output = (up + 1) * glu
-                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                weighted_output = out * routing_weights[token_idx, expert_idx, None]
-                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-            next_states = next_states.view(batch_size, -1, self.hidden_size)
-        else:
-            hidden_states = hidden_states.repeat(num_experts, 1)
-            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=self.limit)
-            up = up.clamp(min=-self.limit, max=self.limit)
-            glu = gate * torch.sigmoid(gate * self.alpha)
-            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
-            next_states = next_states + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
-            next_states = next_states.sum(dim=0)
-        return next_states
-
-class GptOssMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = GptOssRouter(router_weight, router_bias)
-        self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-        return routed_out, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== GPT-OSS Implementation ===")
-# Initialize model with loaded weights
-model = GptOssMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-print(f"\nOutput sum: {output[0].sum().item():.6f}")
megablocks_yamoe/cells/gptoss_training_run.py
DELETED
@@ -1,138 +0,0 @@
-# /// script
-# dependencies = [
-# "torch",
-# "numpy",
-# ]
-# ///
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class GptOssTrainingRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-class GptOssTrainingExperts(nn.Module):
-    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.expert_dim = self.hidden_size
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-        self.alpha = 1.702
-        self.limit = 7.0
-
-    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        # Force training mode path (expert loop instead of batched)
-        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-        with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-            expert_mask = expert_mask.permute(2, 1, 0)
-            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-        for expert_idx in expert_hit[:]:
-            expert_idx = expert_idx[0]
-            with torch.no_grad():
-                _, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states[token_idx]
-            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=self.limit)
-            up = up.clamp(min=-self.limit, max=self.limit)
-            glu = gate * torch.sigmoid(gate * self.alpha)
-            gated_output = (up + 1) * glu
-            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-            weighted_output = out * routing_weights[token_idx, expert_idx, None]
-            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-        next_states = next_states.view(batch_size, -1, self.hidden_size)
-        return next_states
-
-class GptOssTrainingMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = GptOssTrainingRouter(router_weight, router_bias)
-        self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-        return routed_out, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
-# Initialize model with loaded weights and force training mode
-model = GptOssTrainingMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-# Set to training mode to force expert loop path
-model.train()
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
-print(f"Model training mode: {model.training}")
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-print(f"\nOutput sum: {output[0].sum().item():.6f}")
megablocks_yamoe/cells/megablocks_run.py
DELETED
@@ -1,103 +0,0 @@
-# /// script
-# dependencies = [
-# "torch",
-# "numpy",
-# "kernels",
-# ]
-# ///
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-from kernels import get_kernel, get_local_kernel
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-from collections import namedtuple
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-print(f"Loading weights from: {data_dir}")
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-def build_megablocks_model(device: torch.device):
-    # Download optimized kernels from the Hugging Face hub
-    megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
-    model = megablocks.layers.MegaBlocksMoeMLP()
-
-    # Create attribute container for expert weights
-    model.experts = namedtuple(
-        "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
-    )
-
-    # Use loaded router weights for consistency
-    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
-    with torch.no_grad():
-        model.router.weight.copy_(router_weight)
-        model.router.bias.copy_(router_bias)
-
-    # Attach loaded expert weights to the experts container
-    e = model.experts
-    e.alpha = 1.702
-    e.capacity_factor = 32
-    e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
-    e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
-    e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
-    e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
-    e.hidden_size = HIDDEN_SIZE
-
-    # Log weight statistics for comparison
-    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
-    print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
-    print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
-
-    return model
-
-# Create a wrapper to match the interface of other implementations
-class MegaBlocksMoEWrapper(nn.Module):
-    def __init__(self, megablocks_model):
-        super().__init__()
-        self.model = megablocks_model
-
-    def forward(self, hidden_states):
-        # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
-        output, dummy_routing_weights = self.model(hidden_states)
-        return output, dummy_routing_weights
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== MegaBlocks Implementation ===")
-# Build MegaBlocks model with loaded weights
-megablocks_model = build_megablocks_model(device)
-model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-print(f"\nOutput sum: {output[0].sum().item():.6f}")
megablocks_yamoe/cells/nv.py
DELETED
@@ -1,3 +0,0 @@
-import subprocess
-
-print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
megablocks_yamoe/cells/save_data.py
DELETED
@@ -1,42 +0,0 @@
-# /// script
-# dependencies = [
-# "torch",
-# "numpy",
-# ]
-# ///
-
-"""
-Generate deterministic shared weights once and save as artifacts so
-both implementations load identical parameters.
-"""
-import torch
-from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
-
-def save_shared_weights():
-    # Router: Kaiming uniform as used by both, bias zeros
-    torch.manual_seed(WEIGHT_SEED)
-    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
-    torch.nn.init.kaiming_uniform_(router_weight)
-    router_bias = torch.zeros(NUM_EXPERTS)
-
-    # Experts: normal(0, 0.02), biases zeros
-    torch.manual_seed(EXPERT_SEED)
-    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
-    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)
-
-    # Save artifacts
-    torch.save(router_weight, 'router_weight.pt')
-    torch.save(router_bias, 'router_bias.pt')
-    torch.save(gate_up_proj, 'gate_up_proj.pt')
-    torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
-    torch.save(down_proj, 'down_proj.pt')
-    torch.save(down_proj_bias, 'down_proj_bias.pt')
-
-    print("Saved shared weights to artifacts")
-    print(f"Router weight sum: {router_weight.sum().item():.6f}")
-    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-    print(f"Down sum: {down_proj.sum().item():.6f}")
-
-save_shared_weights()
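
Note (optional sanity check, not part of the diff): every initialization above is seeded, so re-running the generator reproduces identical tensors; this is what lets the run cells compare output sums directly.

    ref = torch.load('router_weight.pt')
    save_shared_weights()
    assert torch.equal(ref, torch.load('router_weight.pt'))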