diff --git a/flash_attn/artifacts/benchmark/Attention Benchmark.csv b/flash_attn/artifacts/benchmark/Attention Benchmark.csv
deleted file mode 100644
index 0b8db9d25de250af246cc667ea3a1c80b37942ac..0000000000000000000000000000000000000000
--- a/flash_attn/artifacts/benchmark/Attention Benchmark.csv
+++ /dev/null
@@ -1,7 +0,0 @@
-seq_len,torch_cudnn,torch_cudnn_compile_d,torch_cudnn_compile_ma,torch_flash,torch_flash_compile_d,torch_flash_compile_ma,hf_flash_attn,hf_flash_attn3
-4224.000000,3.801472,3.790064,4.182320,3.968000,3.957824,4.311152,3.398160,3.330400
-4352.000000,4.082944,4.082912,4.413488,4.400000,4.391936,4.738048,3.837424,3.758208
-4416.000000,4.142624,4.135648,4.484160,4.452304,4.446096,4.792480,3.892064,3.864128
-4480.000000,4.206144,4.198752,4.551808,4.530752,4.522944,4.873760,3.949344,3.870224
-4544.000000,4.438320,4.433104,4.787584,4.584160,4.576640,4.934304,4.008960,3.974672
-4608.000000,4.502432,4.495456,4.871872,4.660192,4.651040,5.029792,4.065616,3.984160
diff --git a/flash_attn/artifacts/benchmark/Attention Benchmark.png b/flash_attn/artifacts/benchmark/Attention Benchmark.png
deleted file mode 100644
index 12e70febc74744001584a043800503cee0e0bf08..0000000000000000000000000000000000000000
--- a/flash_attn/artifacts/benchmark/Attention Benchmark.png
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:69a5d2d4ac33fa06e77a599eab6cadcddb77c15ad7bde323bb07849e2aa3ac14
-size 141768
diff --git a/flash_attn/artifacts/benchmark/results.html b/flash_attn/artifacts/benchmark/results.html
deleted file mode 100644
index 5535383951dbeead49ab170ff228324321c07c2a..0000000000000000000000000000000000000000
--- a/flash_attn/artifacts/benchmark/results.html
+++ /dev/null
@@ -1,3 +0,0 @@
diff --git a/flash_attn/benchmark.html b/flash_attn/benchmark.html
deleted file mode 100644
index 53d1e5a43809334e0e7582be97c5ddb4a1e3cac0..0000000000000000000000000000000000000000
--- a/flash_attn/benchmark.html
+++ /dev/null
@@ -1,4652 +0,0 @@
-Flash Attention Benchmark
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-Cell: nv | 0.67s
import subprocess
-
-print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Sep 26 03:53:23 2025
-+-----------------------------------------------------------------------------------------+
-| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8      |
-|-----------------------------------------+------------------------+----------------------+
-| GPU  Name                Persistence-M  | Bus-Id          Disp.A | Volatile Uncorr. ECC  |
-| Fan  Temp   Perf         Pwr:Usage/Cap  |           Memory-Usage | GPU-Util  Compute M.  |
-|                                         |                        |               MIG M.  |
-|=========================================+========================+======================|
-|   0  NVIDIA A10G                   On   |   00000000:00:1B.0 Off |                    0  |
-|  0%   38C    P0            51W /  300W  |       0MiB /  23028MiB |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-|   1  NVIDIA A10G                   On   |   00000000:00:1C.0 Off |                    0  |
-|  0%   31C    P8            24W /  300W  |       0MiB /  23028MiB |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-|   2  NVIDIA A10G                   On   |   00000000:00:1D.0 Off |                    0  |
-|  0%   31C    P8            25W /  300W  |       0MiB /  23028MiB |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-|   3  NVIDIA A10G                   On   |   00000000:00:1E.0 Off |                    0  |
-|  0%   31C    P8            25W /  300W  |       0MiB /  23028MiB |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-
-+-----------------------------------------------------------------------------------------+
-| Processes:                                                                               |
-|  GPU   GI   CI              PID   Type   Process name                         GPU Memory |
-|        ID   ID                                                                Usage      |
-|=========================================================================================|
-|  No running processes found                                                              |
-+-----------------------------------------------------------------------------------------+
-Cell: benchmark | 75.46s
# /// script
-# dependencies = [
-#   "numpy",
-#   "torch",
-#   "kernels",
-#   "pandas",
-#   "matplotlib"
-# ]
-# ///
-# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths
-
-import functools
-import os
-import pathlib
-
-import matplotlib.pyplot as plt
-import torch
-import torch._dynamo.config
-import triton
-import triton.language as tl
-
-try:
-    from flash_attn import flash_attn_func
-except Exception:
-    flash_attn_func = None
-    print("Flash Attention 2 not found.")
-
-try:
-    from flash_attn_interface import flash_attn_func as flash_attn_3_func
-except Exception:
-    flash_attn_3_func = None
-    print("Flash Attention 3 not found.")
-
-try:
-    from kernels import get_kernel
-    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
-    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
-except Exception:
-    hf_kernels_flash_attn = None
-    hf_kernels_flash_attn_3 = None
-    print("HF Kernels not found.")
-
-try:
-    from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
-except Exception:
-    sageattn_qk_int8_pv_fp16_cuda = None
-    sageattn_qk_int8_pv_fp16_triton = None
-    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
-    print("SageAttention not found.")
-
-try:
-    from transformer_engine.pytorch.attention import DotProductAttention
-except Exception:
-    DotProductAttention = None
-    print("Transformer Engine not found.")
-
-try:
-    import xformers.ops as xops
-except Exception:
-    xops = None
-    print("xFormers not found.")
-
-
-plt.rcParams.update({
-    "figure.figsize": (12, 10),
-    "figure.dpi": 120,
-    "font.size": 10,
-    "axes.titlesize": 12,
-    "axes.labelsize": 14,
-    "xtick.labelsize": 10,
-    "ytick.labelsize": 10,
-    "legend.fontsize": 8,
-    "axes.grid": True,
-    "grid.alpha": 0.3,
-    "grid.linestyle": "--",
-    "lines.linewidth": 2.0,
-    "lines.markersize": 6,
-    "legend.frameon": True,
-    "legend.framealpha": 0.9,
-    "legend.loc": "best",
-    "axes.spines.top": False,
-    "axes.spines.right": False,
-})
-
-
-# We want to compare the best compiled version for each specific shape (dynamic=False)
-torch._dynamo.config.cache_size_limit = 10000
-
-# We need to suppress_errors for FA3 to work. It makes it run in eager mode.
-# I can't seem to get it to work any other way under torch.compile, so any suggestions are welcome!
-torch._dynamo.config.suppress_errors = True
-
-# output_dir = pathlib.Path("dump_attention_benchmark")
-# output_dir.mkdir(parents=True, exist_ok=True)
-
-output_dir = pathlib.Path(".") # output to current directory for upload
-
-batch_size = 1
-num_attention_heads = 24
-attention_head_dim = 128
-image_sequence_length = 4096  # 1024x1024px
-text_sequence_lengths = [128, 256, 320, 384, 448, 512]
-sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]
-
-
-def _attention_torch(query, key, value, *, backend):
-    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
-    with torch.nn.attention.sdpa_kernel(backend):
-        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
-    out = out.transpose(1, 2).contiguous()
-    return out
-
-
-_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
-def _attention_torch_compile_default(query, key, value, *, backend):
-    return _compiled_attention_torch_default(query, key, value, backend=backend)
-
-
-_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
-def _attention_torch_compile_max_autotune(query, key, value, *, backend):
-    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)
-
-
-def _attention_flash_attn_2(query, key, value):
-    return flash_attn_func(query, key, value)
-
-
-_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
-def _attention_flash_attn_2_compile_default(query, key, value):
-    return _compiled_flash_attn_2_default(query, key, value)
-
-
-_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
-def _attention_flash_attn_2_compile_max_autotune(query, key, value):
-    return _compiled_flash_attn_2_max_autotune(query, key, value)
-
-
-# For fullgraph=True tracing to be compatible
-@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
-    out, lse = flash_attn_3_func(query, key, value)
-    return out
-
-
-@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
-    return torch.empty_like(query)
-
-
-def _attention_flash_attn_3(query, key, value):
-    out = _wrapped_flash_attn_3(query, key, value)
-    return out
-
-
-_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
-def _attention_flash_attn_3_compile_default(query, key, value):
-    return _compiled_flash_attn_3_default(query, key, value)
-
-
-_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
-def _attention_flash_attn_3_compile_max_autotune(query, key, value):
-    return _compiled_flash_attn_3_max_autotune(query, key, value)
-
-
-def _attention_hf_kernels_flash_attn(query, key, value):
-    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]
-
-
-def _attention_hf_kernels_flash_attn3(query, key, value):
-    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]
-
-
-def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
-    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")
-
-
-def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
-    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")
-
-
-def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
-    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")
-
-
-if DotProductAttention is not None:
-    def set_te_backend(backend):
-        # must be applied before first use of
-        # transformer_engine.pytorch.attention
-        os.environ["NVTE_FLASH_ATTN"] = '0'
-        os.environ["NVTE_FUSED_ATTN"] = '0'
-        os.environ["NVTE_UNFUSED_ATTN"] = '0'
-        if backend == 'flash':
-            os.environ["NVTE_FLASH_ATTN"] = '1'
-        if backend == 'fused':
-            os.environ["NVTE_FUSED_ATTN"] = '1'
-        if backend == 'unfused':
-            os.environ["NVTE_UNFUSED_ATTN"] = '1'
-
-    set_te_backend("fused")
-    te_attn_fn = DotProductAttention(
-        num_attention_heads=num_attention_heads,
-        kv_channels=attention_head_dim,
-        qkv_format="bshd",
-        attn_mask_type="no_mask",
-    )
-else:
-    def te_attn_fn(query, key, value):
-        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")
-
-def _attention_te(query, key, value):
-    out = te_attn_fn(query, key, value)
-    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
-    return out
-
-
-# Cannot fullgraph compile TE
-_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
-def _attention_te_compile_default(query, key, value):
-    return _compiled_te_attn_fn_default(query, key, value)
-
-
-# Cannot fullgraph compile TE
-_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
-def _attention_te_compile_max_autotune(query, key, value):
-    return _compiled_te_attn_fn_max_autotune(query, key, value)
-
-
-def _attention_xformers(query, key, value):
-    return xops.memory_efficient_attention(query, key, value)
-
-
-_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
-def _attention_xformers_compile_default(query, key, value):
-    return _compiled_xformers_default(query, key, value)
-
-
-_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
-def _attention_xformers_compile_max_autotune(query, key, value):
-    return _compiled_xformers_max_autotune(query, key, value)
-
-
-attention_ops = {}
-attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
-attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
-attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
-attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
-attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
-attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
-if hf_kernels_flash_attn is not None:
-    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
-    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
-if flash_attn_func is not None:
-    attention_ops["flash_attn_2"] = _attention_flash_attn_2
-    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
-    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
-if flash_attn_3_func is not None:
-    attention_ops["flash_attn_3"] = _attention_flash_attn_3
-    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
-    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
-if sageattn_qk_int8_pv_fp16_cuda is not None:
-    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
-    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
-    if torch.cuda.get_device_capability()[0] >= 9:
-        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
-if DotProductAttention is not None:
-    attention_ops["te_fused"] = _attention_te
-    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
-    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
-if xops is not None:
-    attention_ops["xformers"] = _attention_xformers
-    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
-    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune
-
-
-def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
-    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
-    line_styles = ["-", ":", "-.", "--"]
-    if n > len(colors) * len(line_styles):
-        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
-    styles = []
-    for i in range(n):
-        color = colors[i % len(colors)]
-        linestyle = line_styles[i // len(colors)]
-        styles.append((color, linestyle))
-    return styles
-
-
-def correctness():
-    for seq_len in sequence_lengths:
-        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
-        print(f"\n\n===== Testing shape: {shape} =====")
-
-        query = torch.randn(shape, device="cuda", dtype=torch.float32)
-        key = torch.randn(shape, device="cuda", dtype=torch.float32)
-        value = torch.randn(shape, device="cuda", dtype=torch.float32)
-
-        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
-        query, key, value = (x.bfloat16() for x in (query, key, value))
-
-        for name, fn in attention_ops.items():
-            out = fn(query, key, value)
-            absdiff = (out - golden_truth).abs()
-            absmax = torch.max(absdiff)
-            mae = torch.mean(absdiff)
-            mse = torch.mean((golden_truth - out) ** 2)
-            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["seq_len"],
-        x_vals=sequence_lengths,
-        x_log=False,
-        line_arg="provider",
-        line_vals=list(attention_ops.keys()),
-        line_names=[x.removeprefix("solution_") for x in attention_ops.keys()],
-        ylabel="Time (ms)",
-        styles=get_color_and_linestyle(len(attention_ops)),
-        plot_name="Attention Benchmark",
-        args={},
-    )
-)
-def benchmark_fn(seq_len: int, provider: str):
-    torch.manual_seed(0)
-
-    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
-    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
-    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
-    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
-
-    fn = attention_ops[provider]
-    ms, min_ms, max_ms = triton.testing.do_bench(
-        lambda: fn(query, key, value),
-        warmup=3,
-        rep=10,
-        quantiles=[0.5, 0.2, 0.8],
-    )
-    return ms, max_ms, min_ms
-
-
-with torch.inference_mode():
-    correctness()
-    fig = benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
-
Flash Attention 2 not found.
-Flash Attention 3 not found.
-SageAttention not found.
-Transformer Engine not found.
-xFormers not found.
-
-
-===== Testing shape: (1, 4224, 24, 128) =====
-torch_cudnn                   : absmax=0.001524, mae=0.000075, mse=0.000000
-torch_cudnn_compile_d         : absmax=0.001524, mae=0.000075, mse=0.000000
-torch_cudnn_compile_ma        : absmax=0.001524, mae=0.000075, mse=0.000000
-torch_flash                   : absmax=0.001524, mae=0.000075, mse=0.000000
-torch_flash_compile_d         : absmax=0.001524, mae=0.000075, mse=0.000000
-torch_flash_compile_ma        : absmax=0.001524, mae=0.000075, mse=0.000000
-hf_flash_attn                 : absmax=0.001524, mae=0.000075, mse=0.000000
-hf_flash_attn3                : absmax=0.001524, mae=0.000075, mse=0.000000
-
-
-===== Testing shape: (1, 4352, 24, 128) =====
-torch_cudnn                   : absmax=0.001335, mae=0.000074, mse=0.000000
-torch_cudnn_compile_d         : absmax=0.001335, mae=0.000074, mse=0.000000
-torch_cudnn_compile_ma        : absmax=0.001335, mae=0.000074, mse=0.000000
-torch_flash                   : absmax=0.001321, mae=0.000074, mse=0.000000
-torch_flash_compile_d         : absmax=0.001321, mae=0.000074, mse=0.000000
-torch_flash_compile_ma        : absmax=0.001321, mae=0.000074, mse=0.000000
-hf_flash_attn                 : absmax=0.001321, mae=0.000074, mse=0.000000
-hf_flash_attn3                : absmax=0.001321, mae=0.000074, mse=0.000000
-
-
-===== Testing shape: (1, 4416, 24, 128) =====
-torch_cudnn                   : absmax=0.000897, mae=0.000073, mse=0.000000
-torch_cudnn_compile_d         : absmax=0.000897, mae=0.000073, mse=0.000000
-torch_cudnn_compile_ma        : absmax=0.000897, mae=0.000073, mse=0.000000
-torch_flash                   : absmax=0.000897, mae=0.000073, mse=0.000000
-torch_flash_compile_d         : absmax=0.000897, mae=0.000073, mse=0.000000
-torch_flash_compile_ma        : absmax=0.000897, mae=0.000073, mse=0.000000
-hf_flash_attn                 : absmax=0.000897, mae=0.000073, mse=0.000000
-hf_flash_attn3                : absmax=0.000897, mae=0.000073, mse=0.000000
-
-
-===== Testing shape: (1, 4480, 24, 128) =====
-torch_cudnn                   : absmax=0.001691, mae=0.000073, mse=0.000000
-torch_cudnn_compile_d         : absmax=0.001691, mae=0.000073, mse=0.000000
-torch_cudnn_compile_ma        : absmax=0.001691, mae=0.000073, mse=0.000000
-torch_flash                   : absmax=0.001691, mae=0.000073, mse=0.000000
-torch_flash_compile_d         : absmax=0.001691, mae=0.000073, mse=0.000000
-torch_flash_compile_ma        : absmax=0.001691, mae=0.000073, mse=0.000000
-hf_flash_attn                 : absmax=0.001691, mae=0.000073, mse=0.000000
-hf_flash_attn3                : absmax=0.001691, mae=0.000073, mse=0.000000
-
-
-===== Testing shape: (1, 4544, 24, 128) =====
-torch_cudnn                   : absmax=0.001201, mae=0.000072, mse=0.000000
-torch_cudnn_compile_d         : absmax=0.001201, mae=0.000072, mse=0.000000
-torch_cudnn_compile_ma        : absmax=0.001201, mae=0.000072, mse=0.000000
-torch_flash                   : absmax=0.001201, mae=0.000072, mse=0.000000
-torch_flash_compile_d         : absmax=0.001201, mae=0.000072, mse=0.000000
-torch_flash_compile_ma        : absmax=0.001201, mae=0.000072, mse=0.000000
-hf_flash_attn                 : absmax=0.001201, mae=0.000072, mse=0.000000
-hf_flash_attn3                : absmax=0.001201, mae=0.000072, mse=0.000000
-
-
-===== Testing shape: (1, 4608, 24, 128) =====
-torch_cudnn                   : absmax=0.001150, mae=0.000071, mse=0.000000
-torch_cudnn_compile_d         : absmax=0.001150, mae=0.000071, mse=0.000000
-torch_cudnn_compile_ma        : absmax=0.001150, mae=0.000071, mse=0.000000
-torch_flash                   : absmax=0.001150, mae=0.000071, mse=0.000000
-torch_flash_compile_d         : absmax=0.001150, mae=0.000071, mse=0.000000
-torch_flash_compile_ma        : absmax=0.001150, mae=0.000071, mse=0.000000
-hf_flash_attn                 : absmax=0.001150, mae=0.000071, mse=0.000000
-hf_flash_attn3                : absmax=0.001150, mae=0.000071, mse=0.000000
-
-Attention Benchmark:
-   seq_len  torch_cudnn  torch_cudnn_compile_d  torch_cudnn_compile_ma  torch_flash  torch_flash_compile_d  torch_flash_compile_ma  hf_flash_attn  hf_flash_attn3
-0   4224.0     3.801472               3.790064                4.182320     3.968000               3.957824                4.311152       3.398160        3.330400
-1   4352.0     4.082944               4.082912                4.413488     4.400000               4.391936                4.738048       3.837424        3.758208
-2   4416.0     4.142624               4.135648                4.484160     4.452304               4.446096                4.792480       3.892064        3.864128
-3   4480.0     4.206144               4.198752                4.551808     4.530752               4.522944                4.873760       3.949344        3.870224
-4   4544.0     4.438320               4.433104                4.787584     4.584160               4.576640                4.934304       4.008960        3.974672
-5   4608.0     4.502432               4.495456                4.871872     4.660192               4.651040                5.029792       4.065616        3.984160
- - - \ No newline at end of file diff --git a/flash_attn/cells/benchmark.py b/flash_attn/cells/benchmark.py deleted file mode 100644 index 808f95f44a1a4079bfcb3d075e6afbf8fddbbfd7..0000000000000000000000000000000000000000 --- a/flash_attn/cells/benchmark.py +++ /dev/null @@ -1,343 +0,0 @@ -# /// script -# dependencies = [ -# "numpy", -# "torch", -# "kernels", -# "pandas", -# "matplotlib" -# ] -# /// -# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths - -import functools -import os -import pathlib - -import matplotlib.pyplot as plt -import torch -import torch._dynamo.config -import triton -import triton.language as tl - -try: - from flash_attn import flash_attn_func -except: - flash_attn_func = None - print("Flash Attention 2 not found.") - -try: - from flash_attn_interface import flash_attn_func as flash_attn_3_func -except: - flash_attn_3_func = None - print("Flash Attention 3 not found.") - -try: - from kernels import get_kernel - hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn") - hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3") -except: - hf_kernels_flash_attn = None - hf_kernels_flash_attn_3 = None - print("HF Kernels not found.") - -try: - from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90 -except: - sageattn_qk_int8_pv_fp16_cuda = None - sageattn_qk_int8_pv_fp16_triton = None - sageattn_qk_int8_pv_fp8_cuda_sm90 = None - print("SageAttention not found.") - -try: - from transformer_engine.pytorch.attention import DotProductAttention -except: - DotProductAttention = None - print("Transformer Engine not found.") - -try: - import xformers.ops as xops -except: - xops = None - print("xFormers not found.") - - -plt.rcParams.update({ - "figure.figsize": (12, 10), - "figure.dpi": 120, - "font.size": 10, - "axes.titlesize": 12, - "axes.labelsize": 14, - "xtick.labelsize": 10, - "ytick.labelsize": 10, - "legend.fontsize": 8, - "axes.grid": True, - "grid.alpha": 0.3, - "grid.linestyle": "--", - "lines.linewidth": 2.0, - "lines.markersize": 6, - "legend.frameon": True, - "legend.framealpha": 0.9, - "legend.loc": "best", - "axes.spines.top": False, - "axes.spines.right": False, -}) - - -# We want to compare the best compiled version for each specific shape (dynamic=False) -torch._dynamo.config.cache_size_limit = 10000 - -# We need to suppress_errors for FA3 to work. It makes it run in eager mode. -# I can't seem to get it to work any other way under torch.compile, so any suggestions are welcome! 
-torch._dynamo.config.suppress_errors = True - -# output_dir = pathlib.Path("dump_attention_benchmark") -# output_dir.mkdir(parents=True, exist_ok=True) - -output_dir = pathlib.Path(".") # output to current directory for upload - -batch_size = 1 -num_attention_heads = 24 -attention_head_dim = 128 -image_sequence_length = 4096 # 1024x1024px -text_sequence_lengths = [128, 256, 320, 384, 448, 512] -sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths] - - -def _attention_torch(query, key, value, *, backend): - query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) - with torch.nn.attention.sdpa_kernel(backend): - out = torch.nn.functional.scaled_dot_product_attention(query, key, value) - out = out.transpose(1, 2).contiguous() - return out - - -_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False) -def _attention_torch_compile_default(query, key, value, *, backend): - return _compiled_attention_torch_default(query, key, value, backend=backend) - - -_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False) -def _attention_torch_compile_max_autotune(query, key, value, *, backend): - return _compiled_attention_torch_max_autotune(query, key, value, backend=backend) - - -def _attention_flash_attn_2(query, key, value): - return flash_attn_func(query, key, value) - - -_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False) -def _attention_flash_attn_2_compile_default(query, key, value): - return _compiled_flash_attn_2_default(query, key, value) - - -_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False) -def _attention_flash_attn_2_compile_max_autotune(query, key, value): - return _compiled_flash_attn_2_max_autotune(query, key, value) - - -# For fullgraph=True tracing to be compatible -@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") -def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - out, lse = flash_attn_3_func(query, key, value) - return out - - -@torch.library.register_fake("flash_attn_3::_flash_attn_forward") -def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: - return torch.empty_like(query) - - -def _attention_flash_attn_3(query, key, value): - out = _wrapped_flash_attn_3(query, key, value) - return out - - -_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False) -def _attention_flash_attn_3_compile_default(query, key, value): - return _compiled_flash_attn_3_default(query, key, value) - - -_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False) -def _attention_flash_attn_3_compile_max_autotune(query, key, value): - return _compiled_flash_attn_3_max_autotune(query, key, value) - - -def _attention_hf_kernels_flash_attn(query, key, value): - return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0] - - -def _attention_hf_kernels_flash_attn3(query, key, value): - return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0] - - -def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value): - return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD") - - -def 
_attention_sageattn_qk_int8_pv_fp16_triton(query, key, value): - return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD") - - -def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value): - return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD") - - -if DotProductAttention is not None: - def set_te_backend(backend): - # must be applied before first use of - # transformer_engine.pytorch.attention - os.environ["NVTE_FLASH_ATTN"] = '0' - os.environ["NVTE_FUSED_ATTN"] = '0' - os.environ["NVTE_UNFUSED_ATTN"] = '0' - if backend == 'flash': - os.environ["NVTE_FLASH_ATTN"] = '1' - if backend == 'fused': - os.environ["NVTE_FUSED_ATTN"] = '1' - if backend == 'unfused': - os.environ["NVTE_UNFUSED_ATTN"] = '1' - - set_te_backend("fused") - te_attn_fn = DotProductAttention( - num_attention_heads=num_attention_heads, - kv_channels=attention_head_dim, - qkv_format="bshd", - attn_mask_type="no_mask", - ) -else: - def te_attn_fn(query, key, value): - raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.") - -def _attention_te(query, key, value): - out = te_attn_fn(query, key, value) - out = out.unflatten(2, (num_attention_heads, attention_head_dim)) - return out - - -# Cannot fullgraph compile TE -_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False) -def _attention_te_compile_default(query, key, value): - return _compiled_te_attn_fn_default(query, key, value) - - -# Cannot fullgraph compile TE -_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False) -def _attention_te_compile_max_autotune(query, key, value): - return _compiled_te_attn_fn_max_autotune(query, key, value) - - -def _attention_xformers(query, key, value): - return xops.memory_efficient_attention(query, key, value) - - -_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False) -def _attention_xformers_compile_default(query, key, value): - return _compiled_xformers_default(query, key, value) - - -_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False) -def _attention_xformers_compile_max_autotune(query, key, value): - return _compiled_xformers_max_autotune(query, key, value) - - -attention_ops = {} -attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION) -attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION) -attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION) -attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION) -attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION) -attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION) -if hf_kernels_flash_attn is not None: - attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn - attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3 -if flash_attn_func is not None: - attention_ops["flash_attn_2"] = _attention_flash_attn_2 - 
attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default - attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune -if flash_attn_3_func is not None: - attention_ops["flash_attn_3"] = _attention_flash_attn_3 - attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default - attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune -if sageattn_qk_int8_pv_fp16_cuda is not None: - attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda - attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton - if torch.cuda.get_device_capability()[0] >= 9: - attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90 -if DotProductAttention is not None: - attention_ops["te_fused"] = _attention_te - attention_ops["te_fused_compile_d"] = _attention_te_compile_default - attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune -if xops is not None: - attention_ops["xformers"] = _attention_xformers - attention_ops["xformers_compile_d"] = _attention_xformers_compile_default - attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune - - -def get_color_and_linestyle(n: int) -> tuple[str, str]: - colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"] - line_styles = ["-", ":", "-.", "--"] - if n > len(colors) * len(line_styles): - raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}") - styles = [] - for i in range(n): - color = colors[i % len(colors)] - linestyle = line_styles[i // len(colors)] - styles.append((color, linestyle)) - return styles - - -def correctness(): - for seq_len in sequence_lengths: - shape = (batch_size, seq_len, num_attention_heads, attention_head_dim) - print(f"\n\n===== Testing shape: {shape} =====") - - query = torch.randn(shape, device="cuda", dtype=torch.float32) - key = torch.randn(shape, device="cuda", dtype=torch.float32) - value = torch.randn(shape, device="cuda", dtype=torch.float32) - - golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH) - query, key, value = (x.bfloat16() for x in (query, key, value)) - - for name, fn in attention_ops.items(): - out = fn(query, key, value) - absdiff = (out - golden_truth).abs() - absmax = torch.max(absdiff) - mae = torch.mean(absdiff) - mse = torch.mean((golden_truth - out) ** 2) - print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}") - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=sequence_lengths, - x_log=False, - line_arg="provider", - line_vals=list(attention_ops.keys()), - line_names=[x.removeprefix("solution_") for x in attention_ops.keys()], - ylabel="Time (ms)", - styles=get_color_and_linestyle(len(attention_ops)), - plot_name="Attention Benchmark", - args={}, - ) -) -def benchmark_fn(seq_len: int, provider: str): - torch.manual_seed(0) - - shape = (batch_size, seq_len, num_attention_heads, attention_head_dim) - query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) - key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) - value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) - - fn = 
attention_ops[provider] - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: fn(query, key, value), - warmup=3, - rep=10, - quantiles=[0.5, 0.2, 0.8], - ) - return ms, max_ms, min_ms - - -with torch.inference_mode(): - correctness() - fig = benchmark_fn.run(print_data=True, save_path=output_dir.as_posix()) diff --git a/flash_attn/cells/nv.py b/flash_attn/cells/nv.py deleted file mode 100644 index 80eef60a7536ed875fb21731ab2d059458bd20b4..0000000000000000000000000000000000000000 --- a/flash_attn/cells/nv.py +++ /dev/null @@ -1,3 +0,0 @@ -import subprocess - -print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) \ No newline at end of file diff --git a/flash_attn/impls/artifacts/benchmark/attn.jsonl b/flash_attn/impls/artifacts/benchmark/attn.jsonl deleted file mode 100644 index 5acfbcfc5bc3cd28bec9cacd18308543bc7864e5..0000000000000000000000000000000000000000 --- a/flash_attn/impls/artifacts/benchmark/attn.jsonl +++ /dev/null @@ -1,6 +0,0 @@ -{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.3603839874267578, "p50": 0.361952006816864, "p90": 0.3624640107154846, "mean": 0.3619711995124817, "reps": 5, "warmup": 2}, "compile_ms": 1.5701119899749756, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.3892799913883209, "p50": 0.3909760117530823, "p90": 0.3922559916973114, "mean": 0.3912447988986969, "reps": 5, "warmup": 2}, "compile_ms": 0.35811200737953186, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5240640044212341, "p50": 0.5248960256576538, "p90": 0.5248960256576538, "mean": 0.5258048176765442, "reps": 5, "warmup": 2}, "compile_ms": 0.4891839921474457, "peak_bytes": 99680256, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, 
"absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.905726432800293e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5265600085258484, "p50": 0.5277760028839111, "p90": 0.5282559990882874, "mean": 0.5276032090187073, "reps": 5, "warmup": 2}, "compile_ms": 0.4968000054359436, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003604888916015625, "mse": 2.8908252716064453e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5639039874076843, "p50": 0.5657920241355896, "p90": 0.5668479800224304, "mean": 0.5656383991241455, "reps": 5, "warmup": 2}, "compile_ms": 0.5312319993972778, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003566741943359375, "mse": 2.86102294921875e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:59:35Z", "run": "8bc1bbc1e0504355abbb1f58e69828d3", "impl": "hf_kernels_flash_attn3", "tags": {"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5689600110054016, "p50": 0.5698239803314209, "p90": 0.5713919997215271, "mean": 0.5789952039718628, "reps": 5, "warmup": 2}, "compile_ms": 0.5350080132484436, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00035858154296875, "mse": 2.8759241104125977e-06, "ref": "sdpa_math_fp32"}, "err": null} diff --git a/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl b/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl deleted file mode 100644 index f0bbce330d32dfe0cbd3c869a8ad5f2aaa045e94..0000000000000000000000000000000000000000 --- a/flash_attn/impls/artifacts/benchmark_default/attn_default.jsonl +++ /dev/null @@ -1,6 +0,0 @@ -{"ts": "2025-10-02T19:58:18Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA 
A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5141760110855103, "p50": 0.5175679922103882, "p90": 0.5197759866714478, "mean": 0.5181439876556396, "reps": 5, "warmup": 2}, "compile_ms": 3084.621826171875, "peak_bytes": 87425024, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000339508056640625, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.5549119710922241, "p50": 0.5582720041275024, "p90": 0.5598080158233643, "mean": 0.5579584002494812, "reps": 5, "warmup": 2}, "compile_ms": 270.21795654296875, "peak_bytes": 95027200, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6853119730949402, "p50": 0.687391996383667, "p90": 0.6883519887924194, "mean": 0.6872959971427918, "reps": 5, "warmup": 2}, "compile_ms": 269.78741455078125, "peak_bytes": 99876864, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7128639817237854, "p50": 0.7160959839820862, "p90": 0.7167680263519287, "mean": 0.716153597831726, "reps": 5, "warmup": 2}, "compile_ms": 269.8607177734375, "peak_bytes": 104726528, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:58:19Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": 
"2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7386879920959473, "p50": 0.7400959730148315, "p90": 0.7415040135383606, "mean": 0.7418303966522217, "reps": 5, "warmup": 2}, "compile_ms": 269.20501708984375, "peak_bytes": 108855296, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:58:20Z", "run": "9ebc449a917f4f2196503654e5549239", "impl": "torch_flash_compiled_default", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7708160281181335, "p50": 0.7740799784660339, "p90": 0.7753919959068298, "mean": 0.7745471954345703, "reps": 5, "warmup": 2}, "compile_ms": 270.93829345703125, "peak_bytes": 114425856, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null} diff --git a/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl b/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl deleted file mode 100644 index 87baeccdc6ab98529881effd3a62d0dfe5029b5a..0000000000000000000000000000000000000000 --- a/flash_attn/impls/artifacts/benchmark_max_autotune/attn_max_autotune.jsonl +++ /dev/null @@ -1,6 +0,0 @@ -{"ts": "2025-10-02T19:57:25Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L128", "batch": 1, "seq_len": 1152, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6144000291824341, "p50": 0.6245759725570679, "p90": 0.6483200192451477, "mean": 0.6468096017837525, "reps": 5, "warmup": 2}, "compile_ms": 4407.3388671875, "peak_bytes": 70779904, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.000339508056640625, "mse": 2.726912498474121e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:57:27Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L256", "batch": 1, "seq_len": 1280, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.6689280271530151, "p50": 0.6851199865341187, "p90": 0.7184960246086121, "mean": 0.7060160160064697, "reps": 5, "warmup": 2}, "compile_ms": 1686.2735595703125, "peak_bytes": 78644224, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, 
"mae": 0.0003414154052734375, "mse": 2.7418136596679688e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:57:29Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L320", "batch": 1, "seq_len": 1344, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.7953600287437439, "p50": 0.8155840039253235, "p90": 0.8403519988059998, "mean": 0.8332608103752136, "reps": 5, "warmup": 2}, "compile_ms": 1462.938232421875, "peak_bytes": 84280320, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:57:31Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L384", "batch": 1, "seq_len": 1408, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.8470720052719116, "p50": 0.849727988243103, "p90": 0.8745279908180237, "mean": 0.8719295978546142, "reps": 5, "warmup": 2}, "compile_ms": 1689.3455810546875, "peak_bytes": 86508544, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:57:33Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L448", "batch": 1, "seq_len": 1472, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.8677120208740234, "p50": 0.8835520148277283, "p90": 0.9034240245819092, "mean": 0.9034304022789001, "reps": 5, "warmup": 2}, "compile_ms": 1693.035888671875, "peak_bytes": 90440704, "ok": true, "absmax": 0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.00034332275390625, "mse": 2.7567148208618164e-06, "ref": "sdpa_math_fp32"}, "err": null} -{"ts": "2025-10-02T19:57:34Z", "run": "edb73be653834cdf8524ee78b403db7f", "impl": "torch_flash_compiled_max_autotune", "tags": {"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, "wl": {"name": "flux_L512", "batch": 1, "seq_len": 1536, "heads": 24, "head_dim": 128, "dtype": "bfloat16", "device": "cuda", "seed": 0}, "env": {"torch": "2.8.0+cu128", "cuda": "12.8", "gpu": "NVIDIA A10G", "sm": "8.6", "py": "3.11.13", "plat": "Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36"}, "lat_ms": {"p10": 0.9154239892959595, "p50": 0.9213759899139404, "p90": 0.9359679818153381, "mean": 0.9387519836425782, "reps": 5, "warmup": 2}, "compile_ms": 1689.36279296875, "peak_bytes": 94372864, "ok": true, "absmax": 
0.0625, "corr": {"ok": true, "rtol": 0.02, "atol": 0.02, "absmax": 0.0625, "mae": 0.0003452301025390625, "mse": 2.771615982055664e-06, "ref": "sdpa_math_fp32"}, "err": null} diff --git a/flash_attn/impls/cells/benchmark.py b/flash_attn/impls/cells/benchmark.py deleted file mode 100644 index f471434e8cb2d02f0e3b081d6133c17ea46bc373..0000000000000000000000000000000000000000 --- a/flash_attn/impls/cells/benchmark.py +++ /dev/null @@ -1,71 +0,0 @@ -# /// script -# requires-python = ">=3.10" -# dependencies = [ -# "numpy", -# "torch", -# "kernels-benchmark-tools", -# "kernels", -# ] -# -# [tool.uv.sources] -# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } -# /// -import torch -import sys -import os -import kernels_benchmark_tools as kbt -from kernels import get_kernel - -hf_kernels_flash_attn3 = get_kernel("kernels-community/flash-attn3") - - -def hf_flash_attention3(query, key, value): - return hf_kernels_flash_attn3.flash_attn_func(query, key, value, causal=False)[0] - - -kbt.add( - "hf_kernels_flash_attn3", - hf_flash_attention3, - tags={"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"}, -) - -if __name__ == "__main__": - device = "cuda" if torch.cuda.is_available() else "cpu" - - if device == "cpu": - print("HF Kernels Flash Attention 3 requires CUDA - skipping benchmark") - sys.exit(0) - - dtype = "bfloat16" - - # Flux-like workloads - base = 1024 - flux_sizes = [128, 256, 320, 384, 448, 512] - heads = 24 - head_dim = 128 - - wl = [] - for L in flux_sizes: - wl.append( - { - "name": f"flux_L{L}", - "batch": 1, - "seq_len": base + L, - "heads": heads, - "head_dim": head_dim, - "dtype": dtype, - "device": device, - "seed": 0, - } - ) - - kbt.run( - wl, - jsonl="attn.jsonl", - reps=5, - warmup=2, - gen=kbt.attn.gen_qkv, - ref=kbt.attn.ref_math, - cmp=kbt.attn.cmp_allclose, - ) - kbt.summarize(["attn.jsonl"]) \ No newline at end of file diff --git a/flash_attn/impls/cells/benchmark_default.py b/flash_attn/impls/cells/benchmark_default.py deleted file mode 100644 index cc2fd06ac69ffe1f5bc88d1821b17447dc90c846..0000000000000000000000000000000000000000 --- a/flash_attn/impls/cells/benchmark_default.py +++ /dev/null @@ -1,70 +0,0 @@ -# /// script -# requires-python = ">=3.10" -# dependencies = [ -# "numpy", -# "torch", -# "kernels-benchmark-tools", -# ] -# -# [tool.uv.sources] -# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } -# /// -import torch -import sys -import os -import kernels_benchmark_tools as kbt - - -def torch_flash_base(q, k, v): - qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): - o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt) - return o.transpose(1, 2).contiguous() - - -# Compile with default mode -compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False) - -kbt.add( - "torch_flash_compiled_default", - compiled_flash_default, - tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"}, -) - -if __name__ == "__main__": - device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = "float32" if device == "cpu" else "bfloat16" - - # Flux-like workloads - base = 1024 if device == "cuda" else 512 - flux_sizes = ( - [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256] - ) - heads = 24 if device == "cuda" else 8 - head_dim = 128 if device == "cuda" 
else 64 - - wl = [] - for L in flux_sizes: - wl.append( - { - "name": f"flux_L{L}", - "batch": 1, - "seq_len": base + L, - "heads": heads, - "head_dim": head_dim, - "dtype": dtype, - "device": device, - "seed": 0, - } - ) - - kbt.run( - wl, - jsonl="attn_default.jsonl", - reps=5, - warmup=2, - gen=kbt.attn.gen_qkv, - ref=kbt.attn.ref_math, - cmp=kbt.attn.cmp_allclose, - ) - kbt.summarize(["attn_default.jsonl"]) \ No newline at end of file diff --git a/flash_attn/impls/cells/benchmark_max_autotune.py b/flash_attn/impls/cells/benchmark_max_autotune.py deleted file mode 100644 index bd96e676c4d9ebdf709b701a7b9a71b9d51774fd..0000000000000000000000000000000000000000 --- a/flash_attn/impls/cells/benchmark_max_autotune.py +++ /dev/null @@ -1,70 +0,0 @@ -# /// script -# requires-python = ">=3.10" -# dependencies = [ -# "numpy", -# "torch", -# "kernels-benchmark-tools", -# ] -# -# [tool.uv.sources] -# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } -# /// -import torch -import sys -import os -import kernels_benchmark_tools as kbt - - -def torch_flash_base(q, k, v): - qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v)) - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): - o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt) - return o.transpose(1, 2).contiguous() - - -# Compile with max-autotune mode -compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False) - -kbt.add( - "torch_flash_compiled_max_autotune", - compiled_flash_max_autotune, - tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"}, -) - -if __name__ == "__main__": - device = "cuda" if torch.cuda.is_available() else "cpu" - dtype = "float32" if device == "cpu" else "bfloat16" - - # Flux-like workloads - base = 1024 if device == "cuda" else 512 - flux_sizes = ( - [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256] - ) - heads = 24 if device == "cuda" else 8 - head_dim = 128 if device == "cuda" else 64 - - wl = [] - for L in flux_sizes: - wl.append( - { - "name": f"flux_L{L}", - "batch": 1, - "seq_len": base + L, - "heads": heads, - "head_dim": head_dim, - "dtype": dtype, - "device": device, - "seed": 0, - } - ) - - kbt.run( - wl, - jsonl="attn_max_autotune.jsonl", - reps=5, - warmup=2, - gen=kbt.attn.gen_qkv, - ref=kbt.attn.ref_math, - cmp=kbt.attn.cmp_allclose, - ) - kbt.summarize(["attn_max_autotune.jsonl"]) \ No newline at end of file diff --git a/flash_attn/impls/cells/nv.py b/flash_attn/impls/cells/nv.py deleted file mode 100644 index 80eef60a7536ed875fb21731ab2d059458bd20b4..0000000000000000000000000000000000000000 --- a/flash_attn/impls/cells/nv.py +++ /dev/null @@ -1,3 +0,0 @@ -import subprocess - -print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) \ No newline at end of file diff --git a/flash_attn/impls/compiled_variants.html b/flash_attn/impls/compiled_variants.html deleted file mode 100644 index df7b4a9560dee2b9cc43fb2840e4da8927f7cde9..0000000000000000000000000000000000000000 --- a/flash_attn/impls/compiled_variants.html +++ /dev/null @@ -1,4077 +0,0 @@ - - - - - - compiled_variants - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

Torch Compile Variants

-

This file benchmarks Flash Attention with different torch.compile modes.
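A minimal sketch of the pattern shared by the two cells below (illustrative only; it reuses the same public PyTorch APIs as the deleted scripts and omits their layout transpose to [batch, heads, seq, head_dim]): one eager SDPA wrapper, compiled twice, differing only in the mode passed to torch.compile.

    import torch
    import torch.nn.functional as F

    def sdpa_flash(q, k, v):
        # Pin SDPA to the FlashAttention backend, as the cells below do.
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
            return F.scaled_dot_product_attention(q, k, v)

    # "default" compiles quickly; "max-autotune" spends longer searching Inductor configurations.
    compiled_default = torch.compile(sdpa_flash, mode="default", fullgraph=True, dynamic=False)
    compiled_max_autotune = torch.compile(sdpa_flash, mode="max-autotune", fullgraph=True, dynamic=False)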

-

Flash Attention with torch.compile(mode="default")

-
-
-Cell: benchmark_default | 45.23s
-
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels-benchmark-tools",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-
-
-def torch_flash_base(q, k, v):
-    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
-    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
-    return o.transpose(1, 2).contiguous()
-
-
-# Compile with default mode
-compiled_flash_default = torch.compile(torch_flash_base, mode="default", fullgraph=True, dynamic=False)
-
-kbt.add(
-    "torch_flash_compiled_default",
-    compiled_flash_default,
-    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "default"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dtype = "float32" if device == "cpu" else "bfloat16"
-
-    # Flux-like workloads
-    base = 1024 if device == "cuda" else 512
-    flux_sizes = (
-        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
-    )
-    heads = 24 if device == "cuda" else 8
-    head_dim = 128 if device == "cuda" else 64
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn_default.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn_default.jsonl"])
-
- -
-
-
-
-
impl                          wl        p50(ms)  ok
-torch_flash_compiled_default flux_L128 0.52     True
-torch_flash_compiled_default flux_L256 0.56     True
-torch_flash_compiled_default flux_L320 0.69     True
-torch_flash_compiled_default flux_L384 0.72     True
-torch_flash_compiled_default flux_L448 0.74     True
-torch_flash_compiled_default flux_L512 0.77     True
-
-
- -
-
-

Artifacts:

-attn_default.jsonl -
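The JSONL artifact can be read back record by record; a small sketch, assuming the field names used by the combine script later in this diff (impl, wl, lat_ms, tags, peak_bytes):

    import json

    with open("attn_default.jsonl") as f:
        for line in f:
            rec = json.loads(line)
            wl = rec.get("wl", {})        # workload: name, batch, seq_len, heads, head_dim, dtype
            lat = rec.get("lat_ms", {})   # latency stats: mean, p10, p50, p90, reps
            print(rec.get("impl"), wl.get("name"), lat.get("p50"))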
-
-
- -

Flash Attention with torch.compile(mode="max-autotune")

-
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: benchmark_max_autotune | 54.06s - | - -Raw -
-Cell: benchmark_max_autotune | 54.06s
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels-benchmark-tools",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-
-
-def torch_flash_base(q, k, v):
-    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
-    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
-    return o.transpose(1, 2).contiguous()
-
-
-# Compile with max-autotune mode
-compiled_flash_max_autotune = torch.compile(torch_flash_base, mode="max-autotune", fullgraph=True, dynamic=False)
-
-kbt.add(
-    "torch_flash_compiled_max_autotune",
-    compiled_flash_max_autotune,
-    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dtype = "float32" if device == "cpu" else "bfloat16"
-
-    # Flux-like workloads
-    base = 1024 if device == "cuda" else 512
-    flux_sizes = (
-        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
-    )
-    heads = 24 if device == "cuda" else 8
-    head_dim = 128 if device == "cuda" else 64
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn_max_autotune.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn_max_autotune.jsonl"])
-
- -
-
-
-
-
impl                               wl        p50(ms)  ok
-torch_flash_compiled_max_autotune flux_L128 0.62     True
-torch_flash_compiled_max_autotune flux_L256 0.69     True
-torch_flash_compiled_max_autotune flux_L320 0.82     True
-torch_flash_compiled_max_autotune flux_L384 0.85     True
-torch_flash_compiled_max_autotune flux_L448 0.88     True
-torch_flash_compiled_max_autotune flux_L512 0.92     True
-
-
- -
-
-

Artifacts:

-attn_max_autotune.jsonl -
-
-
-
- - - \ No newline at end of file diff --git a/flash_attn/impls/flash_attention.html b/flash_attn/impls/flash_attention.html deleted file mode 100644 index 6c4655b3e27d9fb741de565f4276ebd9e9f5beef..0000000000000000000000000000000000000000 --- a/flash_attn/impls/flash_attention.html +++ /dev/null @@ -1,3973 +0,0 @@ - - - - - - flash_attention - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

Flash Attention Implementation

-

GPU Info

-
-
-Cell: nv | 0.67s
-
-
-
import subprocess
-
-print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
-
- -
-
-
-
-
Thu Oct  2 19:58:23 2025
-+-----------------------------------------------------------------------------------------+
-| NVIDIA-SMI 570.172.08             Driver Version: 570.172.08     CUDA Version: 12.8      |
-|-----------------------------------------+------------------------+----------------------+
-| GPU  Name                Persistence-M  | Bus-Id         Disp.A  | Volatile Uncorr. ECC  |
-| Fan  Temp  Perf          Pwr:Usage/Cap  |          Memory-Usage  | GPU-Util  Compute M.  |
-|                                         |                        |               MIG M.  |
-|=========================================+========================+======================|
-|   0  NVIDIA A10G                   On   |   00000000:00:1B.0 Off |                    0  |
-|  0%   37C    P0            92W / 300W   |      0MiB / 23028MiB   |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-|   1  NVIDIA A10G                   On   |   00000000:00:1C.0 Off |                    0  |
-|  0%   29C    P8            24W / 300W   |      0MiB / 23028MiB   |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-|   2  NVIDIA A10G                   On   |   00000000:00:1D.0 Off |                    0  |
-|  0%   29C    P8            24W / 300W   |      0MiB / 23028MiB   |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-|   3  NVIDIA A10G                   On   |   00000000:00:1E.0 Off |                    0  |
-|  0%   30C    P8            24W / 300W   |      0MiB / 23028MiB   |      0%      Default  |
-|                                         |                        |                  N/A  |
-+-----------------------------------------+------------------------+----------------------+
-
-+-----------------------------------------------------------------------------------------+
-| Processes:                                                                               |
-|  GPU   GI   CI        PID   Type   Process name                             GPU Memory   |
-|        ID   ID                                                              Usage        |
-|=========================================================================================|
-|  No running processes found                                                              |
-+-----------------------------------------------------------------------------------------+
-
-
-
- -

Flash Attention Benchmark

-
-
-Cell: benchmark | 35.41s
-
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels-benchmark-tools",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-
-
-def torch_flash(q, k, v):
-    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
-    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
-        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
-    return o.transpose(1, 2).contiguous()
-
-kbt.add(
-    "torch_flash_ma",
-    torch_flash,
-    tags={"family": "torch-sdpa", "backend": "FLASH", "compile": "max-autotune"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dtype = "float32" if device == "cpu" else "bfloat16"
-
-    # Flux-like workloads scaled down for CPU testing
-    base = 1024 if device == "cuda" else 512
-    flux_sizes = (
-        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
-    )
-    heads = 24 if device == "cuda" else 8
-    head_dim = 128 if device == "cuda" else 64
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn.jsonl"])
-
- -
-
-
-
-
impl            wl        p50(ms)  ok
-torch_flash_ma flux_L128 0.49     True
-torch_flash_ma flux_L256 0.52     True
-torch_flash_ma flux_L320 0.65     True
-torch_flash_ma flux_L384 0.68     True
-torch_flash_ma flux_L448 0.71     True
-torch_flash_ma flux_L512 0.74     True
-
-
- -
-
-

Artifacts:

-attn.jsonl -
-
-
-
- - - \ No newline at end of file diff --git a/flash_attn/impls/hf_kernels_flash_attn.html b/flash_attn/impls/hf_kernels_flash_attn.html deleted file mode 100644 index eb96379b9cca8d14f8e80d4141cdae4ecd113c07..0000000000000000000000000000000000000000 --- a/flash_attn/impls/hf_kernels_flash_attn.html +++ /dev/null @@ -1,3924 +0,0 @@ - - - - - - hf_kernels_flash_attn - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

HF Kernels - Flash Attention

-

HuggingFace Kernels Flash Attention Benchmark

-
-
-Cell: benchmark | 38.65s
-
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels-benchmark-tools",
-#     "kernels",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-from kernels import get_kernel
-
-hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn", revision="v0.0.2")
-
-
-def hf_flash_attention(query, key, value):
-    """HuggingFace Kernels Flash Attention"""
-    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]
-
-
-kbt.add(
-    "hf_kernels_flash_attn",
-    hf_flash_attention,
-    tags={"family": "hf-kernels", "backend": "flash-attn", "compile": "none"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    if device == "cpu":
-        print("HF Kernels Flash Attention requires CUDA - skipping benchmark")
-        sys.exit(0)
-
-    dtype = "bfloat16"
-
-    # Flux-like workloads
-    base = 1024
-    flux_sizes = [128, 256, 320, 384, 448, 512]
-    heads = 24
-    head_dim = 128
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn.jsonl"])
-
- -
-
-
-
-
impl                   wl        p50(ms)  ok
-hf_kernels_flash_attn flux_L128 0.35     True
-hf_kernels_flash_attn flux_L256 0.38     True
-hf_kernels_flash_attn flux_L320 0.49     True
-hf_kernels_flash_attn flux_L384 0.52     True
-hf_kernels_flash_attn flux_L448 0.54     True
-hf_kernels_flash_attn flux_L512 0.56     True
-
-
- -
-
Fetching 20 files: 0%| | 0/20 [00:00<?, ?it/s] -Fetching 20 files: 5%|▌ | 1/20 [00:00<00:03, 5.70it/s] -Fetching 20 files: 10%|█ | 2/20 [00:01<00:13, 1.36it/s] -Fetching 20 files: 100%|██████████| 20/20 [00:01<00:00, 15.31it/s]
-
-

Artifacts:

-attn.jsonl -
-
-
-
- - - \ No newline at end of file diff --git a/flash_attn/impls/hf_kernels_flash_attn3.html b/flash_attn/impls/hf_kernels_flash_attn3.html deleted file mode 100644 index 8f2225d6d312b7248216f259f9d91f6f5d1a2bff..0000000000000000000000000000000000000000 --- a/flash_attn/impls/hf_kernels_flash_attn3.html +++ /dev/null @@ -1,3923 +0,0 @@ - - - - - - hf_kernels_flash_attn3 - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

HF Kernels - Flash Attention 3

-

HuggingFace Kernels Flash Attention 3 Benchmark

-
-
-Cell: benchmark | 38.16s
-
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels-benchmark-tools",
-#     "kernels",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-from kernels import get_kernel
-
-hf_kernels_flash_attn3 = get_kernel("kernels-community/flash-attn3")
-
-
-def hf_flash_attention3(query, key, value):
-    return hf_kernels_flash_attn3.flash_attn_func(query, key, value, causal=False)[0]
-
-
-kbt.add(
-    "hf_kernels_flash_attn3",
-    hf_flash_attention3,
-    tags={"family": "hf-kernels", "backend": "flash-attn3", "compile": "none"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    if device == "cpu":
-        print("HF Kernels Flash Attention 3 requires CUDA - skipping benchmark")
-        sys.exit(0)
-
-    dtype = "bfloat16"
-
-    # Flux-like workloads
-    base = 1024
-    flux_sizes = [128, 256, 320, 384, 448, 512]
-    heads = 24
-    head_dim = 128
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn.jsonl"])
-
- -
-
-
-
-
impl                    wl        p50(ms)  ok
-hf_kernels_flash_attn3 flux_L128 0.36     True
-hf_kernels_flash_attn3 flux_L256 0.39     True
-hf_kernels_flash_attn3 flux_L320 0.52     True
-hf_kernels_flash_attn3 flux_L384 0.53     True
-hf_kernels_flash_attn3 flux_L448 0.57     True
-hf_kernels_flash_attn3 flux_L512 0.57     True
-
-
- -
-
Fetching 4 files: 0%| | 0/4 [00:00<?, ?it/s] -Fetching 4 files: 25%|██▌ | 1/4 [00:00<00:00, 5.17it/s] -Fetching 4 files: 50%|█████ | 2/4 [00:01<00:01, 1.22it/s] -Fetching 4 files: 100%|██████████| 4/4 [00:01<00:00, 2.76it/s]
-
-

Artifacts:

-attn.jsonl -
-
-
-
- - - \ No newline at end of file diff --git a/flash_attn/impls/index.html b/flash_attn/impls/index.html deleted file mode 100644 index fbf5a24033c4d68d64ce4f1be0ba13266dd4e89d..0000000000000000000000000000000000000000 --- a/flash_attn/impls/index.html +++ /dev/null @@ -1,94 +0,0 @@ - - - - - - Index of /flash_attn/impls - - - -
-

Index of /flash_attn/impls

- - - \ No newline at end of file diff --git a/flash_attn/impls/mem_efficient_attention.html b/flash_attn/impls/mem_efficient_attention.html deleted file mode 100644 index 7213e66ff6d02989f90efc2869171fe1cad361a3..0000000000000000000000000000000000000000 --- a/flash_attn/impls/mem_efficient_attention.html +++ /dev/null @@ -1,3914 +0,0 @@ - - - - - - mem_efficient_attention - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

Memory Efficient Attention Implementation

-

Memory Efficient SDPA Benchmark

-
-
-Cell: benchmark | 36.80s
-
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels-benchmark-tools",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-
-
-def torch_mem_eff(q, k, v):
-    qt, kt, vt = (x.transpose(1, 2).contiguous() for x in (q, k, v))
-    with torch.nn.attention.sdpa_kernel(
-        torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION
-    ):
-        o = torch.nn.functional.scaled_dot_product_attention(qt, kt, vt)
-    return o.transpose(1, 2).contiguous()
-
-kbt.add(
-    "torch_mem_eff",
-    torch_mem_eff,
-    tags={"family": "torch-sdpa", "backend": "EFFICIENT", "compile": "none"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dtype = "float32" if device == "cpu" else "bfloat16"
-
-    # Flux-like workloads scaled down for CPU testing
-    base = 1024 if device == "cuda" else 512
-    flux_sizes = (
-        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
-    )
-    heads = 24 if device == "cuda" else 8
-    head_dim = 128 if device == "cuda" else 64
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn.jsonl"])
-
- -
-
-
-
-
impl           wl        p50(ms)  ok
-torch_mem_eff flux_L128 0.59     True
-torch_mem_eff flux_L256 0.65     True
-torch_mem_eff flux_L320 0.78     True
-torch_mem_eff flux_L384 0.79     True
-torch_mem_eff flux_L448 0.85     True
-torch_mem_eff flux_L512 0.95     True
-
-
- -
-
-

Artifacts:

-attn.jsonl -
-
-
-
- - - \ No newline at end of file diff --git a/flash_attn/impls/sage_attention.html b/flash_attn/impls/sage_attention.html deleted file mode 100644 index 5a357f2e63a095cefb67f14c37e15b66b39d007a..0000000000000000000000000000000000000000 --- a/flash_attn/impls/sage_attention.html +++ /dev/null @@ -1,3937 +0,0 @@ - - - - - - sage_attention - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

SageAttention Implementation

-

SageAttention Benchmark (INT8 Quantized)

-
-
-Cell: benchmark | 40.58s
-
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels",
-#     "kernels-benchmark-tools",
-#     "sageattention",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-# from sageattention import sageattn_qk_int8_pv_fp16_cuda
-
-
-# def sage_attention(q, k, v):
-#     """SageAttention with INT8 Q/K quantization and FP16 P/V"""
-#     return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout="NHD")
-
-from kernels import get_kernel
-
-hf_kernels_sage_attn = get_kernel("kernels-community/sage_attention")
-
-
-def sage_attention(query, key, value):
-    """HuggingFace Kernels Flash Attention"""
-    return hf_kernels_sage_attn.fwd(query, key, value, is_causal=False)[0]
-
-kbt.add(
-    "sage_int8_fp16",
-    sage_attention,
-    tags={"family": "sageattention", "backend": "int8_fp16_cuda", "compile": "none"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    if device == "cpu":
-        print("SageAttention requires CUDA - skipping benchmark")
-        sys.exit(0)
-
-    dtype = "bfloat16"
-
-    # Flux-like workloads
-    base = 1024
-    flux_sizes = [128, 256, 320, 384, 448, 512]
-    heads = 24
-    head_dim = 128
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn.jsonl"])
-
- -
-
-
-
-
impl            wl        p50(ms)  ok
-sage_int8_fp16 flux_L128 FAIL     False
-  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
-sage_int8_fp16 flux_L256 FAIL     False
-  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
-sage_int8_fp16 flux_L320 FAIL     False
-  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
-sage_int8_fp16 flux_L384 FAIL     False
-  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
-sage_int8_fp16 flux_L448 FAIL     False
-  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
-sage_int8_fp16 flux_L512 FAIL     False
-  Error: module 'sage_attention_1863f4c92418f0f6' has no attribute 'fwd'
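The failures above come from the wrapper calling .fwd(...) on the fetched kernel, which this build evidently does not expose. A small debugging sketch, assuming only that get_kernel returns an ordinary Python module object, to list what the loaded kernel actually provides before wiring the wrapper:

    from kernels import get_kernel

    sage = get_kernel("kernels-community/sage_attention")
    # The entry point is not guaranteed to be named `fwd` for every kernel repo,
    # so inspect the public attributes first.
    print([name for name in dir(sage) if not name.startswith("_")])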
-
-
- -
-
Fetching 11 files: 0%| | 0/11 [00:00<?, ?it/s] -Fetching 11 files: 9%|▉ | 1/11 [00:00<00:01, 5.59it/s] -Fetching 11 files: 73%|███████▎ | 8/11 [00:00<00:00, 12.79it/s] -Fetching 11 files: 100%|██████████| 11/11 [00:00<00:00, 16.77it/s]
-
-

Artifacts:

-attn.jsonl -
-
-
-
- - - \ No newline at end of file diff --git a/flash_attn/impls/xformers.html b/flash_attn/impls/xformers.html deleted file mode 100644 index b4b072cd702fba0ae3c8cc918c12746908e8d2e3..0000000000000000000000000000000000000000 --- a/flash_attn/impls/xformers.html +++ /dev/null @@ -1,3916 +0,0 @@ - - - - - - xformers - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

xFormers Memory Efficient Attention

-

xFormers Benchmark

-
-
-Cell: benchmark | 42.08s
-
-
-
# /// script
-# requires-python = ">=3.10"
-# dependencies = [
-#     "numpy",
-#     "torch",
-#     "kernels-benchmark-tools",
-#     "xformers",
-# ]
-#
-# [tool.uv.sources]
-# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
-# ///
-import torch
-import sys
-import os
-import kernels_benchmark_tools as kbt
-import xformers.ops as xops
-
-
-def xformers_attention(q, k, v):
-    """xFormers memory efficient attention"""
-    # xFormers expects [batch, seq_len, heads, head_dim]
-    return xops.memory_efficient_attention(q, k, v)
-
-
-kbt.add(
-    "xformers_meff",
-    xformers_attention,
-    tags={"family": "xformers", "backend": "memory_efficient", "compile": "none"},
-)
-
-if __name__ == "__main__":
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    dtype = "float32" if device == "cpu" else "bfloat16"
-
-    # Flux-like workloads
-    base = 1024 if device == "cuda" else 512
-    flux_sizes = (
-        [128, 256, 320, 384, 448, 512] if device == "cuda" else [64, 128, 192, 256]
-    )
-    heads = 24 if device == "cuda" else 8
-    head_dim = 128 if device == "cuda" else 64
-
-    wl = []
-    for L in flux_sizes:
-        wl.append(
-            {
-                "name": f"flux_L{L}",
-                "batch": 1,
-                "seq_len": base + L,
-                "heads": heads,
-                "head_dim": head_dim,
-                "dtype": dtype,
-                "device": device,
-                "seed": 0,
-            }
-        )
-
-    kbt.run(
-        wl,
-        jsonl="attn.jsonl",
-        reps=5,
-        warmup=2,
-        gen=kbt.attn.gen_qkv,
-        ref=kbt.attn.ref_math,
-        cmp=kbt.attn.cmp_allclose,
-    )
-    kbt.summarize(["attn.jsonl"])
-
- -
-
-
-
-
impl           wl        p50(ms)  ok
-xformers_meff flux_L128 0.45     True
-xformers_meff flux_L256 0.47     True
-xformers_meff flux_L320 0.60     True
-xformers_meff flux_L384 0.60     True
-xformers_meff flux_L448 0.64     True
-xformers_meff flux_L512 0.65     True
-
-
- -
-
-

Artifacts:

-attn.jsonl -
-
-
-
- - - \ No newline at end of file diff --git a/flash_attn/index.html b/flash_attn/index.html deleted file mode 100644 index eea7df846d9f2d44c6c6e03a5ac30d00cecd90cf..0000000000000000000000000000000000000000 --- a/flash_attn/index.html +++ /dev/null @@ -1,89 +0,0 @@ - - - - - - Index of /flash_attn - - - -
-

Index of /flash_attn

- - - \ No newline at end of file diff --git a/flash_attn/results/artifacts/combine/latency.csv b/flash_attn/results/artifacts/combine/latency.csv deleted file mode 100644 index 3d8227b97b600492490413c0bfcaae4e1e8d53d2..0000000000000000000000000000000000000000 --- a/flash_attn/results/artifacts/combine/latency.csv +++ /dev/null @@ -1,43 +0,0 @@ -Implementation,Impl ID,Workload,Batch,Seq Length,Heads,Head Dim,Dtype,Mean (ms),P10 (ms),P50 (ms),P90 (ms),Reps,Peak Mem (MB),Backend,Family -Flash (PyTorch SDPA),torch_flash_ma,flux_L128,1,1152,24,128,bfloat16,0.49411200881004336,0.48844799399375916,0.4936000108718872,0.4944640100002289,5,83.38,FLASH,torch-sdpa -Flash (PyTorch SDPA),torch_flash_ma,flux_L256,1,1280,24,128,bfloat16,0.5234112024307251,0.5224320292472839,0.5235199928283691,0.5235840082168579,5,90.62,FLASH,torch-sdpa -Flash (PyTorch SDPA),torch_flash_ma,flux_L320,1,1344,24,128,bfloat16,0.6527232170104981,0.6503040194511414,0.6524800062179565,0.6545600295066833,5,95.06,FLASH,torch-sdpa -Flash (PyTorch SDPA),torch_flash_ma,flux_L384,1,1408,24,128,bfloat16,0.682803213596344,0.6805760264396667,0.6828799843788147,0.6832640171051025,5,99.88,FLASH,torch-sdpa -Flash (PyTorch SDPA),torch_flash_ma,flux_L448,1,1472,24,128,bfloat16,0.7075456142425537,0.7057600021362305,0.7063360214233398,0.7070720195770264,5,103.81,FLASH,torch-sdpa -Flash (PyTorch SDPA),torch_flash_ma,flux_L512,1,1536,24,128,bfloat16,0.7379711985588073,0.7368639707565308,0.7372480034828186,0.7391039729118347,5,109.12,FLASH,torch-sdpa -MemEff (PyTorch SDPA),torch_mem_eff,flux_L128,1,1152,24,128,bfloat16,0.5874239921569824,0.5861759781837463,0.5873280167579651,0.5877439975738525,5,83.38,EFFICIENT,torch-sdpa -MemEff (PyTorch SDPA),torch_mem_eff,flux_L256,1,1280,24,128,bfloat16,0.6502719998359681,0.6490240097045898,0.649183988571167,0.6517760157585144,5,90.62,EFFICIENT,torch-sdpa -MemEff (PyTorch SDPA),torch_mem_eff,flux_L320,1,1344,24,128,bfloat16,0.7812095880508423,0.7761600017547607,0.7803199887275696,0.7852799892425537,5,95.94,EFFICIENT,torch-sdpa -MemEff (PyTorch SDPA),torch_mem_eff,flux_L384,1,1408,24,128,bfloat16,0.7948480010032654,0.7911999821662903,0.7935360074043274,0.7948480248451233,5,100.0,EFFICIENT,torch-sdpa -MemEff (PyTorch SDPA),torch_mem_eff,flux_L448,1,1472,24,128,bfloat16,0.8463295936584473,0.8449919819831848,0.8459839820861816,0.8461120128631592,5,103.81,EFFICIENT,torch-sdpa -MemEff (PyTorch SDPA),torch_mem_eff,flux_L512,1,1536,24,128,bfloat16,0.9538687944412232,0.9492800235748291,0.9518399834632874,0.9581760168075562,5,109.12,EFFICIENT,torch-sdpa -xFormers,xformers_meff,flux_L128,1,1152,24,128,bfloat16,0.4515071928501129,0.44364801049232483,0.4524799883365631,0.4557119905948639,5,83.38,memory_efficient,xformers -xFormers,xformers_meff,flux_L256,1,1280,24,128,bfloat16,0.46787199974060056,0.46489599347114563,0.4684160053730011,0.46908798813819885,5,90.62,memory_efficient,xformers -xFormers,xformers_meff,flux_L320,1,1344,24,128,bfloat16,0.6001471996307373,0.596992015838623,0.5984640121459961,0.6016640067100525,5,95.06,memory_efficient,xformers -xFormers,xformers_meff,flux_L384,1,1408,24,128,bfloat16,0.6023231983184815,0.5997440218925476,0.6031039953231812,0.6032639741897583,5,99.88,memory_efficient,xformers -xFormers,xformers_meff,flux_L448,1,1472,24,128,bfloat16,0.6411136031150818,0.6381760239601135,0.6414719820022583,0.6421440243721008,5,103.81,memory_efficient,xformers 
-xFormers,xformers_meff,flux_L512,1,1536,24,128,bfloat16,0.6594688057899475,0.6441280245780945,0.6496639847755432,0.6527680158615112,5,109.12,memory_efficient,xformers -Compiled (default),torch_flash_compiled_default,flux_L128,1,1152,24,128,bfloat16,0.5181439876556396,0.5141760110855103,0.5175679922103882,0.5197759866714478,5,83.38,FLASH,torch-sdpa -Compiled (default),torch_flash_compiled_default,flux_L256,1,1280,24,128,bfloat16,0.5579584002494812,0.5549119710922241,0.5582720041275024,0.5598080158233643,5,90.62,FLASH,torch-sdpa -Compiled (default),torch_flash_compiled_default,flux_L320,1,1344,24,128,bfloat16,0.6872959971427918,0.6853119730949402,0.687391996383667,0.6883519887924194,5,95.25,FLASH,torch-sdpa -Compiled (default),torch_flash_compiled_default,flux_L384,1,1408,24,128,bfloat16,0.716153597831726,0.7128639817237854,0.7160959839820862,0.7167680263519287,5,99.88,FLASH,torch-sdpa -Compiled (default),torch_flash_compiled_default,flux_L448,1,1472,24,128,bfloat16,0.7418303966522217,0.7386879920959473,0.7400959730148315,0.7415040135383606,5,103.81,FLASH,torch-sdpa -Compiled (default),torch_flash_compiled_default,flux_L512,1,1536,24,128,bfloat16,0.7745471954345703,0.7708160281181335,0.7740799784660339,0.7753919959068298,5,109.12,FLASH,torch-sdpa -Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L128,1,1152,24,128,bfloat16,0.6468096017837525,0.6144000291824341,0.6245759725570679,0.6483200192451477,5,67.5,FLASH,torch-sdpa -Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L256,1,1280,24,128,bfloat16,0.7060160160064697,0.6689280271530151,0.6851199865341187,0.7184960246086121,5,75.0,FLASH,torch-sdpa -Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L320,1,1344,24,128,bfloat16,0.8332608103752136,0.7953600287437439,0.8155840039253235,0.8403519988059998,5,80.38,FLASH,torch-sdpa -Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L384,1,1408,24,128,bfloat16,0.8719295978546142,0.8470720052719116,0.849727988243103,0.8745279908180237,5,82.5,FLASH,torch-sdpa -Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L448,1,1472,24,128,bfloat16,0.9034304022789001,0.8677120208740234,0.8835520148277283,0.9034240245819092,5,86.25,FLASH,torch-sdpa -Compiled (max-autotune),torch_flash_compiled_max_autotune,flux_L512,1,1536,24,128,bfloat16,0.9387519836425782,0.9154239892959595,0.9213759899139404,0.9359679818153381,5,90.0,FLASH,torch-sdpa -HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L128,1,1152,24,128,bfloat16,0.3455295979976654,0.34355199337005615,0.34563198685646057,0.34643200039863586,5,83.38,flash-attn,hf-kernels -HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L256,1,1280,24,128,bfloat16,0.3756160080432892,0.37411201000213623,0.3752000033855438,0.3770880103111267,5,90.62,flash-attn,hf-kernels -HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L320,1,1344,24,128,bfloat16,0.4953216016292572,0.49324798583984375,0.49433600902557373,0.49663999676704407,5,95.06,flash-attn,hf-kernels -HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L384,1,1408,24,128,bfloat16,0.5157055854797363,0.5142719745635986,0.516319990158081,0.516543984413147,5,99.88,flash-attn,hf-kernels -HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L448,1,1472,24,128,bfloat16,0.5356672048568726,0.5346879959106445,0.5358080267906189,0.5361599922180176,5,103.81,flash-attn,hf-kernels -HF Kernels Flash Attn,hf_kernels_flash_attn,flux_L512,1,1536,24,128,bfloat16,0.5587136030197144,0.5557760000228882,0.5574079751968384,0.5581120252609253,5,109.12,flash-attn,hf-kernels -HF Kernels Flash 
Attn3,hf_kernels_flash_attn3,flux_L128,1,1152,24,128,bfloat16,0.3619711995124817,0.3603839874267578,0.361952006816864,0.3624640107154846,5,83.38,flash-attn3,hf-kernels -HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L256,1,1280,24,128,bfloat16,0.3912447988986969,0.3892799913883209,0.3909760117530823,0.3922559916973114,5,90.62,flash-attn3,hf-kernels -HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L320,1,1344,24,128,bfloat16,0.5258048176765442,0.5240640044212341,0.5248960256576538,0.5248960256576538,5,95.06,flash-attn3,hf-kernels -HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L384,1,1408,24,128,bfloat16,0.5276032090187073,0.5265600085258484,0.5277760028839111,0.5282559990882874,5,99.88,flash-attn3,hf-kernels -HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L448,1,1472,24,128,bfloat16,0.5656383991241455,0.5639039874076843,0.5657920241355896,0.5668479800224304,5,103.81,flash-attn3,hf-kernels -HF Kernels Flash Attn3,hf_kernels_flash_attn3,flux_L512,1,1536,24,128,bfloat16,0.5789952039718628,0.5689600110054016,0.5698239803314209,0.5713919997215271,5,109.12,flash-attn3,hf-kernels diff --git a/flash_attn/results/artifacts/combine/latency.png b/flash_attn/results/artifacts/combine/latency.png deleted file mode 100644 index 2f751d720a7fc9c0b347c634988f6ad8d0824a42..0000000000000000000000000000000000000000 --- a/flash_attn/results/artifacts/combine/latency.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:87dbea8f2773d7fcee9fd191cb6e67cd1e2ddd379cef90ee01bb4ac40a55b5f1 -size 110313 diff --git a/flash_attn/results/artifacts/combine/latency.svg b/flash_attn/results/artifacts/combine/latency.svg deleted file mode 100644 index 242a7df95a0a79e47f8fa9f6f7a99913358dab23..0000000000000000000000000000000000000000 --- a/flash_attn/results/artifacts/combine/latency.svg +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2c1da56080e7fd1a85c14295083b11d6bac981f6fb3faef98b0753eb2c1676c7 -size 28243 diff --git a/flash_attn/results/cells/combine.py b/flash_attn/results/cells/combine.py deleted file mode 100644 index f703ae3d1403d602560c9d3b36d51fab69b7f3a5..0000000000000000000000000000000000000000 --- a/flash_attn/results/cells/combine.py +++ /dev/null @@ -1,319 +0,0 @@ -# /// script -# requires-python = ">=3.10" -# dependencies = [ -# "numpy", -# "torch", -# "kernels-benchmark-tools", -# "matplotlib", -# ] -# -# [tool.uv.sources] -# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" } -# /// -import os -import sys -from pathlib import Path -import json -import torch # noqa: F401 # imported because upstream may expect torch to be importable -import kernels_benchmark_tools as kbt - -# --- Matplotlib setup and helpers ------------------------------------------------ -import matplotlib as mpl -import matplotlib.pyplot as plt -import csv - - -# Keep text as text (not paths) so CSS can style fonts, size, etc. 
-mpl.rcParams["svg.fonttype"] = "none" -# Make ids deterministic across builds -mpl.rcParams["svg.hashsalt"] = "latency-benchmark-combined" -# Avoid auto-closed figures interfering with our tagging -mpl.rcParams["figure.autolayout"] = True -# Make background transparent -mpl.rcParams["figure.facecolor"] = "none" -mpl.rcParams["axes.facecolor"] = "none" -mpl.rcParams["savefig.facecolor"] = "none" -mpl.rcParams["savefig.edgecolor"] = "none" - -def _slugify(s: str) -> str: - s = (s or "").strip().lower() - keep = [] - for ch in s: - if ch.isalnum(): - keep.append(ch) - elif ch in (" ", "-", "_", "/", ".", ":"): - keep.append("-") - else: - keep.append("") - out = "".join(keep) - while "--" in out: - out = out.replace("--", "-") - return out.strip("-") or "unnamed" - -def _tag_current_figure(default_series_prefix="series"): - """Attach SVG ids (gid) to key artists so they can be targeted from CSS.""" - fig = plt.gcf() - if fig is None: - return - - # Tag the figure itself - fig.set_gid("figure--latency") - - for ax_idx, ax in enumerate(fig.get_axes(), start=1): - ax.set_gid(f"axes--{ax_idx}") - - # Axis labels & title - if ax.get_title(): - for t in ax.texts: - if t.get_text() == ax.get_title(): - t.set_gid("title--main") - if ax.xaxis and ax.xaxis.get_label(): - ax.xaxis.label.set_gid("label--x") - if ax.yaxis and ax.yaxis.get_label(): - ax.yaxis.label.set_gid("label--y") - - # Gridlines - for i, gl in enumerate(ax.get_xgridlines(), start=1): - gl.set_gid(f"grid-x--{i}") - for i, gl in enumerate(ax.get_ygridlines(), start=1): - gl.set_gid(f"grid-y--{i}") - - # Legend block & entries - leg = ax.get_legend() - if leg is not None: - leg.set_gid("legend") - for i, txt in enumerate(leg.get_texts(), start=1): - label_slug = _slugify(txt.get_text()) - txt.set_gid(f"legend-label--{label_slug or i}") - - # Series (lines, patches) - # Lines - line_seen = {} - for ln in getattr(ax, "lines", []): - raw_label = ln.get_label() or "" - # Matplotlib uses labels beginning with "_" for non-legendable items - label = raw_label if not raw_label.startswith("_") else f"{default_series_prefix}" - slug = _slugify(label) - line_seen[slug] = line_seen.get(slug, 0) + 1 - suffix = "" if line_seen[slug] == 1 else f"-{line_seen[slug]}" - ln.set_gid(f"series--{slug}{suffix}") - - # Patches (bars, areas) - patch_seen = {} - for pt in getattr(ax, "patches", []): - label = getattr(pt, "get_label", lambda: "")() or f"{default_series_prefix}" - if isinstance(label, str) and label.startswith("_"): - label = default_series_prefix - slug = _slugify(label) - patch_seen[slug] = patch_seen.get(slug, 0) + 1 - suffix = "" if patch_seen[slug] == 1 else f"-{patch_seen[slug]}" - pt.set_gid(f"series--{slug}{suffix}") - -def _postprocess_svg_add_classes(svg_path: Path): - """Add convenient CSS classes alongside ids (e.g., class='series grid grid-x').""" - try: - import xml.etree.ElementTree as ET - ET.register_namespace("", "http://www.w3.org/2000/svg") - tree = ET.parse(svg_path) - root = tree.getroot() - for el in root.iter(): - el_id = el.attrib.get("id", "") - if not el_id: - continue - cls = [] - if el_id.startswith("figure--"): - cls.append("figure") - elif el_id.startswith("axes--"): - cls.append("axes") - elif el_id.startswith("grid-x--"): - cls += ["grid", "grid-x"] - elif el_id.startswith("grid-y--"): - cls += ["grid", "grid-y"] - elif el_id.startswith("legend"): - cls.append("legend") - elif el_id.startswith("label--x"): - cls.append("xlabel") - elif el_id.startswith("label--y"): - cls.append("ylabel") - elif 
el_id.startswith("title--"): - cls.append("title") - elif el_id.startswith("series--"): - cls.append("series") - if cls: - # Preserve any existing class (unlikely from Matplotlib) - existing = el.attrib.get("class", "") - el.set("class", (existing + " " + " ".join(cls)).strip()) - tree.write(svg_path, encoding="utf-8", xml_declaration=True) - except Exception as e: - print(f"✗ SVG postprocess (classes) skipped: {e}") - -# Monkey-patch savefig to force SVG & ensure tagging occurs even if kbt.viz saves internally. -_orig_savefig = plt.savefig -def _savefig_svg(fname, *args, **kwargs): - # Always save as SVG at a stable path for the artifact system - out = Path("latency.svg") - kwargs["format"] = "svg" - # Ensure everything we care about has ids before export - _tag_current_figure() - res = _orig_savefig(out, *args, **kwargs) - # Add helpful CSS classes on top of ids - _postprocess_svg_add_classes(out) - print(f"✓ Combined visualization saved as {out}") - return res - -plt.savefig = _savefig_svg # apply patch - -# Capture close calls in case kbt.viz() closes figures before we re-save -_orig_close = plt.close -_last_closed = {"fig": None} -def _capture_close(arg=None): - try: - if hasattr(arg, "savefig"): # looks like a Figure - _last_closed["fig"] = arg - else: - _last_closed["fig"] = plt.gcf() - finally: - return _orig_close(arg) -plt.close = _capture_close - -# --- Locate benchmark artifacts -------------------------------------------------- -cache_dirs = { - "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'), - "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'), - "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'), - "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'), - "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'), - "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'), - "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'), - "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'), - "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'), -} - -print("LOADING BENCHMARK DATA") -for name, cache_dir in cache_dirs.items(): - print(f"{name:30s}: {cache_dir}") -print() - -file_mapping = { - "Flash (PyTorch SDPA)": "attn.jsonl", - "MemEff (PyTorch SDPA)": "attn.jsonl", - "Flash Attn 2": "attn.jsonl", - "xFormers": "attn.jsonl", - "SageAttention": "attn.jsonl", - "Compiled (default)": "attn_default.jsonl", - "Compiled (max-autotune)": "attn_max_autotune.jsonl", - "HF Kernels Flash Attn": "attn.jsonl", - "HF Kernels Flash Attn3": "attn.jsonl", -} - -all_paths = [] -for name, cache_dir in cache_dirs.items(): - if cache_dir: - path = Path(cache_dir) / file_mapping[name] - if path.exists() and path.stat().st_size > 0: - all_paths.append(str(path)) - print(f"✓ Found {name}: {path}") - else: - print(f"⊘ Empty/Missing {name}: {path}") - else: - print(f"✗ No cache dir for {name}") -print() - -if not all_paths: - print("ERROR: No benchmark data files found!") - # restore patched functions before exiting - plt.savefig = _orig_savefig - plt.close = _orig_close - sys.exit(1) - -# --- Summary + Visualization ----------------------------------------------------- -print("COMBINED BENCHMARK SUMMARY\n") -kbt.summarize(all_paths) -print("\nGENERATING COMBINED VISUALIZATION\n") - -try: - # If kbt.viz saves internally, our patched savefig ensures SVG gets 
written, - # and it will carry ids/classes for CSS styling. - kbt.viz(all_paths) - # Safety net: if kbt.viz didn't save, save now. - # if not Path("latency.svg").exists(): - # _tag_current_figure() - # plt.savefig("latency.svg") - - plt.savefig("latency.svg") # ensure saved with tagging - - print("✓ SVG visualization ready: latency.svg!") -except ImportError as e: - print(f"✗ Visualization requires matplotlib: {e}") -except Exception as e: - print(f"✗ Visualization failed: {e}") -finally: - # Clean up patches to avoid side effects in later cells - plt.savefig = _orig_savefig - plt.close = _orig_close - -print() -print("ANALYSIS COMPLETE") -print(f"Total implementations analyzed: {len(all_paths)}") -print(f"\nImplementations included:") -for name, cache_dir in cache_dirs.items(): - if cache_dir: - path = Path(cache_dir) / file_mapping[name] - if path.exists() and path.stat().st_size > 0: - print(f" ✓ {name}") - - - -# Collect all benchmark data and export to CSV -all_data = {} -for name, cache_dir in cache_dirs.items(): - if cache_dir: - path = Path(cache_dir) / file_mapping[name] - if path.exists() and path.stat().st_size > 0: - with open(path, 'r') as f: - records = [json.loads(line) for line in f] - all_data[name] = records - -# Export to CSV -csv_path = Path("latency.csv") -with open(csv_path, 'w', newline='') as csvfile: - writer = csv.writer(csvfile) - - # Write header - header = ["Implementation", "Impl ID", "Workload", "Batch", "Seq Length", "Heads", "Head Dim", "Dtype", - "Mean (ms)", "P10 (ms)", "P50 (ms)", "P90 (ms)", "Reps", - # "Compile (ms)", - "Peak Mem (MB)", "Backend", "Family"] - writer.writerow(header) - - # Write data rows - for impl_name, records in all_data.items(): - for record in records: - wl = record.get('wl', {}) - lat = record.get('lat_ms', {}) - tags = record.get('tags', {}) - - row = [ - impl_name, - record.get('impl', ''), - wl.get('name', ''), - wl.get('batch', ''), - wl.get('seq_len', ''), - wl.get('heads', ''), - wl.get('head_dim', ''), - wl.get('dtype', ''), - lat.get('mean', ''), - lat.get('p10', ''), - lat.get('p50', ''), - lat.get('p90', ''), - lat.get('reps', ''), - # record.get('compile_ms', ''), - round(record.get('peak_bytes', 0) / 1024 / 1024, 2) if record.get('peak_bytes') else '', - tags.get('backend', ''), - tags.get('family', ''), - ] - writer.writerow(row) - -print(f"✓ CSV export complete: {csv_path}") -print(f"Total implementations: {len(all_data)}") -print(f"Total records: {sum(len(records) for records in all_data.values())}") diff --git a/flash_attn/results/combined_results.html b/flash_attn/results/combined_results.html deleted file mode 100644 index c3231e96afd1eeac602f702cc3b1291530a0a629..0000000000000000000000000000000000000000 --- a/flash_attn/results/combined_results.html +++ /dev/null @@ -1,7236 +0,0 @@ - - - - - - Flash Attention Benchmark - Combined Results - - - - - - - -
-
-Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-

Flash Attention Benchmarks - Aggregated Results

-

This document combines benchmark results from multiple attention implementations using cross-file dependencies.
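The combine cell (flash_attn/results/cells/combine.py, deleted above) locates each implementation's output directory through UVNOTE_FILE_* environment variables, presumably injected by the notebook build; a minimal sketch of that lookup, with the variable and file names taken from the script:

    import os
    from pathlib import Path

    # One variable per upstream benchmark cell, e.g. the PyTorch SDPA flash run.
    cache_dir = os.environ.get("UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK")
    jsonl = Path(cache_dir) / "attn.jsonl" if cache_dir else None
    if jsonl is not None and jsonl.exists() and jsonl.stat().st_size > 0:
        print(f"Found benchmark records: {jsonl}")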

-

Combined Summary and Visualization

-
-[Figure: "Attention Implementation Latency" (latency.svg): Latency P50 (ms) vs Workload, flux_L128 through flux_L512, one series per implementation: torch_flash_ma, torch_mem_eff, xformers_meff, torch_flash_compiled_default, torch_flash_compiled_max_autotune, hf_kernels_flash_attn, hf_kernels_flash_attn3]
- -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
ImplementationImpl IDWorkloadBatchSeq LengthHeadsHead DimDtypeMean (ms)P10 (ms)P50 (ms)P90 (ms)RepsPeak Mem (MB)BackendFamily
Flash (PyTorch SDPA)torch_flash_maflux_L1281115224128bfloat160.494112008810043360.488447993993759160.49360001087188720.4944640100002289583.38FLASHtorch-sdpa
Flash (PyTorch SDPA)torch_flash_maflux_L2561128024128bfloat160.52341120243072510.52243202924728390.52351999282836910.5235840082168579590.62FLASHtorch-sdpa
Flash (PyTorch SDPA)torch_flash_maflux_L3201134424128bfloat160.65272321701049810.65030401945114140.65248000621795650.6545600295066833595.06FLASHtorch-sdpa
Flash (PyTorch SDPA)torch_flash_maflux_L3841140824128bfloat160.6828032135963440.68057602643966670.68287998437881470.6832640171051025599.88FLASHtorch-sdpa
Flash (PyTorch SDPA)torch_flash_maflux_L4481147224128bfloat160.70754561424255370.70576000213623050.70633602142333980.70707201957702645103.81FLASHtorch-sdpa
Flash (PyTorch SDPA)torch_flash_maflux_L5121153624128bfloat160.73797119855880730.73686397075653080.73724800348281860.73910397291183475109.12FLASHtorch-sdpa
MemEff (PyTorch SDPA)torch_mem_effflux_L1281115224128bfloat160.58742399215698240.58617597818374630.58732801675796510.5877439975738525583.38EFFICIENTtorch-sdpa
MemEff (PyTorch SDPA)torch_mem_effflux_L2561128024128bfloat160.65027199983596810.64902400970458980.6491839885711670.6517760157585144590.62EFFICIENTtorch-sdpa
MemEff (PyTorch SDPA)torch_mem_effflux_L3201134424128bfloat160.78120958805084230.77616000175476070.78031998872756960.7852799892425537595.94EFFICIENTtorch-sdpa
MemEff (PyTorch SDPA)torch_mem_effflux_L3841140824128bfloat160.79484800100326540.79119998216629030.79353600740432740.79484802484512335100.0EFFICIENTtorch-sdpa
MemEff (PyTorch SDPA)torch_mem_effflux_L4481147224128bfloat160.84632959365844730.84499198198318480.84598398208618160.84611201286315925103.81EFFICIENTtorch-sdpa
MemEff (PyTorch SDPA)torch_mem_effflux_L5121153624128bfloat160.95386879444122320.94928002357482910.95183998346328740.95817601680755625109.12EFFICIENTtorch-sdpa
xFormersxformers_meffflux_L1281115224128bfloat160.45150719285011290.443648010492324830.45247998833656310.4557119905948639583.38memory_efficientxformers
xFormersxformers_meffflux_L2561128024128bfloat160.467871999740600560.464895993471145630.46841600537300110.46908798813819885590.62memory_efficientxformers
xFormersxformers_meffflux_L3201134424128bfloat160.60014719963073730.5969920158386230.59846401214599610.6016640067100525595.06memory_efficientxformers
xFormersxformers_meffflux_L3841140824128bfloat160.60232319831848150.59974402189254760.60310399532318120.6032639741897583599.88memory_efficientxformers
xFormersxformers_meffflux_L4481147224128bfloat160.64111360311508180.63817602396011350.64147198200225830.64214402437210085103.81memory_efficientxformers
xFormersxformers_meffflux_L5121153624128bfloat160.65946880578994750.64412802457809450.64966398477554320.65276801586151125109.12memory_efficientxformers
Compiled (default)torch_flash_compiled_defaultflux_L1281115224128bfloat160.51814398765563960.51417601108551030.51756799221038820.5197759866714478583.38FLASHtorch-sdpa
Compiled (default)torch_flash_compiled_defaultflux_L2561128024128bfloat160.55795840024948120.55491197109222410.55827200412750240.5598080158233643590.62FLASHtorch-sdpa
Compiled (default)torch_flash_compiled_defaultflux_L3201134424128bfloat160.68729599714279180.68531197309494020.6873919963836670.6883519887924194595.25FLASHtorch-sdpa
Compiled (default)torch_flash_compiled_defaultflux_L3841140824128bfloat160.7161535978317260.71286398172378540.71609598398208620.7167680263519287599.88FLASHtorch-sdpa
Compiled (default)torch_flash_compiled_defaultflux_L4481147224128bfloat160.74183039665222170.73868799209594730.74009597301483150.74150401353836065103.81FLASHtorch-sdpa
Compiled (default)torch_flash_compiled_defaultflux_L5121153624128bfloat160.77454719543457030.77081602811813350.77407997846603390.77539199590682985109.12FLASHtorch-sdpa
Compiled (max-autotune)torch_flash_compiled_max_autotuneflux_L1281115224128bfloat160.64680960178375250.61440002918243410.62457597255706790.6483200192451477567.5FLASHtorch-sdpa
Compiled (max-autotune)torch_flash_compiled_max_autotuneflux_L2561128024128bfloat160.70601601600646970.66892802715301510.68511998653411870.7184960246086121575.0FLASHtorch-sdpa
Compiled (max-autotune)torch_flash_compiled_max_autotuneflux_L3201134424128bfloat160.83326081037521360.79536002874374390.81558400392532350.8403519988059998580.38FLASHtorch-sdpa
Compiled (max-autotune)torch_flash_compiled_max_autotuneflux_L3841140824128bfloat160.87192959785461420.84707200527191160.8497279882431030.8745279908180237582.5FLASHtorch-sdpa
Compiled (max-autotune)torch_flash_compiled_max_autotuneflux_L4481147224128bfloat160.90343040227890010.86771202087402340.88355201482772830.9034240245819092586.25FLASHtorch-sdpa
Compiled (max-autotune)torch_flash_compiled_max_autotuneflux_L5121153624128bfloat160.93875198364257820.91542398929595950.92137598991394040.9359679818153381590.0FLASHtorch-sdpa
HF Kernels Flash Attnhf_kernels_flash_attnflux_L1281115224128bfloat160.34552959799766540.343551993370056150.345631986856460570.34643200039863586583.38flash-attnhf-kernels
HF Kernels Flash Attnhf_kernels_flash_attnflux_L2561128024128bfloat160.37561600804328920.374112010002136230.37520000338554380.3770880103111267590.62flash-attnhf-kernels
HF Kernels Flash Attnhf_kernels_flash_attnflux_L3201134424128bfloat160.49532160162925720.493247985839843750.494336009025573730.49663999676704407595.06flash-attnhf-kernels
HF Kernels Flash Attnhf_kernels_flash_attnflux_L3841140824128bfloat160.51570558547973630.51427197456359860.5163199901580810.516543984413147599.88flash-attnhf-kernels
HF Kernels Flash Attnhf_kernels_flash_attnflux_L4481147224128bfloat160.53566720485687260.53468799591064450.53580802679061890.53615999221801765103.81flash-attnhf-kernels
HF Kernels Flash Attnhf_kernels_flash_attnflux_L5121153624128bfloat160.55871360301971440.55577600002288820.55740797519683840.55811202526092535109.12flash-attnhf-kernels
HF Kernels Flash Attn3hf_kernels_flash_attn3flux_L1281115224128bfloat160.36197119951248170.36038398742675780.3619520068168640.3624640107154846583.38flash-attn3hf-kernels
HF Kernels Flash Attn3hf_kernels_flash_attn3flux_L2561128024128bfloat160.39124479889869690.38927999138832090.39097601175308230.3922559916973114590.62flash-attn3hf-kernels
HF Kernels Flash Attn3hf_kernels_flash_attn3flux_L3201134424128bfloat160.52580481767654420.52406400442123410.52489602565765380.5248960256576538595.06flash-attn3hf-kernels
HF Kernels Flash Attn3hf_kernels_flash_attn3flux_L3841140824128bfloat160.52760320901870730.52656000852584840.52777600288391110.5282559990882874599.88flash-attn3hf-kernels
HF Kernels Flash Attn3hf_kernels_flash_attn3flux_L4481147224128bfloat160.56563839912414550.56390398740768430.56579202413558960.56684798002243045103.81flash-attn3hf-kernels
HF Kernels Flash Attn3hf_kernels_flash_attn3flux_L5121153624128bfloat160.57899520397186280.56896001100540160.56982398033142090.57139199972152715109.12flash-attn3hf-kernels
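The rows above are the flattened export of the attention benchmark table: each row pairs an implementation (PyTorch SDPA flash and memory-efficient backends, xFormers, two torch.compile variants, and the HF-kernels flash-attn / flash-attn3 ports) with a FLUX-style workload that appears to be batch 1, 24 heads, head dim 128, bfloat16 at sequence lengths 1152-1536, plus its latency statistics. A minimal sketch of how one such measurement could be collected with CUDA events is shown below; the harness is an assumption for illustration, not the benchmark code that produced these numbers.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def time_sdpa(seq_len: int, backend=SDPBackend.FLASH_ATTENTION, iters: int = 50) -> float:
    """Average ms per SDPA call for one shape from the table (assumed harness)."""
    q, k, v = (torch.randn(1, 24, seq_len, 128, device="cuda", dtype=torch.bfloat16)
               for _ in range(3))
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with sdpa_kernel(backend):
        for _ in range(10):                      # warmup
            F.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v)
        end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters       # ms per call

print(f"{time_sdpa(1152):.3f} ms")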
-
- -
-
- -▶ code -▶ output - ▶ uv-logs - | -Cell: combine | 36.89s - | - -Raw -
- - -
- - -
- - - \ No newline at end of file diff --git a/flash_attn/results/index.html b/flash_attn/results/index.html deleted file mode 100644 index b87b6002f4b781572dbb50f91850e50ee98130ab..0000000000000000000000000000000000000000 --- a/flash_attn/results/index.html +++ /dev/null @@ -1,88 +0,0 @@ - - - - - - Index of /flash_attn/results - - - -
- ← back -
-

Index of /flash_attn/results

- - - \ No newline at end of file diff --git a/index.html b/index.html deleted file mode 100644 index df44040e2dd9e1e4a0fc2d5ee08453d4b9953f11..0000000000000000000000000000000000000000 --- a/index.html +++ /dev/null @@ -1,85 +0,0 @@ - - - - - - Index of / - - - -

Index of /

- - - \ No newline at end of file diff --git a/megablocks/cells/forward_and_backward.py b/megablocks/cells/forward_and_backward.py deleted file mode 100644 index a8ac420c8a43009eb857f3a7889b4f79ad5a1191..0000000000000000000000000000000000000000 --- a/megablocks/cells/forward_and_backward.py +++ /dev/null @@ -1,196 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub -import sys -import torch.profiler -import gc -import logging -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm - -# remove liger kernel for testing -replace_kernel_forward_from_hub(GptOssRMSNorm, None) - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 128 # Reduced to help with memory usage - -# Clear memory before backward pass -reset_peak_memory_stats() -print(f"Pre-generation memory: {get_memory_stats()}") - -# forward and backward pass -with torch.autograd.set_grad_enabled(True): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - print(tokenizer.decode(generated[0], skip_special_tokens=False)) - print(f"Generation took {end_time - start_time:.2f} seconds") - print(f"Post-generation memory: {get_memory_stats()}") - - # Use gradient checkpointing to reduce memory usage - if hasattr(model, 'gradient_checkpointing_enable'): - model.gradient_checkpointing_enable() - print("Enabled gradient checkpointing") - - # 
Reduce sequence length if needed for memory - max_seq_len = 512 # Limit sequence length for backward pass - if generated.size(1) > max_seq_len: - print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens") - full_sequence = generated[:, -max_seq_len:] - else: - full_sequence = generated - - # Get model outputs for the full sequence - model.train() # Enable dropout and other training behaviors - - try: - outputs = model( - input_ids=full_sequence, - labels=full_sequence, # This will compute loss internally - return_dict=True - ) - print(f"Post-forward memory: {get_memory_stats()}") - - # If model doesn't compute loss, compute it manually - if outputs.loss is None: - shift_logits = outputs.logits[..., :-1, :].contiguous() - shift_labels = full_sequence[..., 1:].contiguous() - - # Use CrossEntropyLoss with ignore_index for padding tokens - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100) - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1) - ) - else: - loss = outputs.loss - - print(f"Loss: {loss.item():.4f}") - - # Clear intermediate tensors to save memory - del outputs - torch.cuda.empty_cache() - - # Perform backward pass with memory management - print("Running backward pass...") - print(f"Pre-backward memory: {get_memory_stats()}") - - loss.backward() - print(f"Post-backward memory: {get_memory_stats()}") - - except torch.cuda.OutOfMemoryError as e: - print(f"OOM during forward/backward pass: {e}") - print("Try reducing max_tokens or max_seq_len") - raise - - # Calculate gradient statistics and print sample gradients - total_norm = 0.0 - param_count = 0 - grad_samples = {} - - for name, p in model.named_parameters(): - if p.grad is not None: - param_count += 1 - grad_norm = p.grad.data.norm(2).item() - total_norm += grad_norm ** 2 - - # Collect gradient statistics for key layers - if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']): - grad_samples[name] = { - 'norm': grad_norm, - 'mean': p.grad.data.mean().item(), - 'std': p.grad.data.std().item(), - 'max': p.grad.data.max().item(), - 'min': p.grad.data.min().item(), - } - - total_norm = total_norm ** 0.5 - - print(f"\nGradient norm: {total_norm:.4f}") - print(f"Parameters with gradients: {param_count}") - - # Print sample gradients from important layers - print("\nSample gradient statistics:") - for i, (name, stats) in enumerate(list(grad_samples.items())[:10]): - print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}") - - # Optional: zero gradients for next iteration - model.zero_grad() - model.eval() # Switch back to eval mode - diff --git a/megablocks/cells/forward_and_backward_no_kernel.py b/megablocks/cells/forward_and_backward_no_kernel.py deleted file mode 100644 index d56805f64c56b484df98c41b9e62d3b6f27ff088..0000000000000000000000000000000000000000 --- a/megablocks/cells/forward_and_backward_no_kernel.py +++ /dev/null @@ -1,196 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, 
LayerRepository, replace_kernel_forward_from_hub -import sys -import torch.profiler -import gc -import logging -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm - -# remove liger kernel for testing -replace_kernel_forward_from_hub(GptOssRMSNorm, None) - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=False, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 128 # Reduced to help with memory usage - -# Clear memory before backward pass -reset_peak_memory_stats() -print(f"Pre-generation memory: {get_memory_stats()}") - -# forward and backward pass -with torch.autograd.set_grad_enabled(True): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - print(tokenizer.decode(generated[0], skip_special_tokens=False)) - print(f"Generation took {end_time - start_time:.2f} seconds") - print(f"Post-generation memory: {get_memory_stats()}") - - # Use gradient checkpointing to reduce memory usage - if hasattr(model, 'gradient_checkpointing_enable'): - model.gradient_checkpointing_enable() - print("Enabled gradient checkpointing") - - # Reduce sequence length if needed for memory - max_seq_len = 512 # Limit sequence length for backward pass - if generated.size(1) > max_seq_len: - print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens") - full_sequence = generated[:, -max_seq_len:] - else: - full_sequence = generated - - # Get model outputs for the full sequence - model.train() # Enable dropout and other training behaviors - - try: - outputs = model( - input_ids=full_sequence, - labels=full_sequence, # This will compute loss internally - return_dict=True - ) - print(f"Post-forward memory: {get_memory_stats()}") - - # If model doesn't compute loss, compute it manually - if outputs.loss is None: - shift_logits = outputs.logits[..., :-1, :].contiguous() - shift_labels = full_sequence[..., 
1:].contiguous() - - # Use CrossEntropyLoss with ignore_index for padding tokens - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100) - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1) - ) - else: - loss = outputs.loss - - print(f"Loss: {loss.item():.4f}") - - # Clear intermediate tensors to save memory - del outputs - torch.cuda.empty_cache() - - # Perform backward pass with memory management - print("Running backward pass...") - print(f"Pre-backward memory: {get_memory_stats()}") - - loss.backward() - print(f"Post-backward memory: {get_memory_stats()}") - - except torch.cuda.OutOfMemoryError as e: - print(f"OOM during forward/backward pass: {e}") - print("Try reducing max_tokens or max_seq_len") - raise - - # Calculate gradient statistics and print sample gradients - total_norm = 0.0 - param_count = 0 - grad_samples = {} - - for name, p in model.named_parameters(): - if p.grad is not None: - param_count += 1 - grad_norm = p.grad.data.norm(2).item() - total_norm += grad_norm ** 2 - - # Collect gradient statistics for key layers - if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']): - grad_samples[name] = { - 'norm': grad_norm, - 'mean': p.grad.data.mean().item(), - 'std': p.grad.data.std().item(), - 'max': p.grad.data.max().item(), - 'min': p.grad.data.min().item(), - } - - total_norm = total_norm ** 0.5 - - print(f"\nGradient norm: {total_norm:.4f}") - print(f"Parameters with gradients: {param_count}") - - # Print sample gradients from important layers - print("\nSample gradient statistics:") - for i, (name, stats) in enumerate(list(grad_samples.items())[:10]): - print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}") - - # Optional: zero gradients for next iteration - model.zero_grad() - model.eval() # Switch back to eval mode - diff --git a/megablocks/cells/forward_only.py b/megablocks/cells/forward_only.py deleted file mode 100644 index c72358d0eef5e1f993aef1e76dfb0f26761c4881..0000000000000000000000000000000000000000 --- a/megablocks/cells/forward_only.py +++ /dev/null @@ -1,101 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub -import sys -import torch.profiler -import gc -import logging -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm - - -replace_kernel_forward_from_hub(GptOssRMSNorm, None) - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": 
torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - - - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 256 - -with torch.inference_mode(): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - -print(tokenizer.decode(generated[0], skip_special_tokens=False)) -print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/megablocks/cells/no_kernels.py b/megablocks/cells/no_kernels.py deleted file mode 100644 index 28857794c6623cd7e1ff39059ae75b06dde2be33..0000000000000000000000000000000000000000 --- a/megablocks/cells/no_kernels.py +++ /dev/null @@ -1,98 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub -import sys -import torch.profiler -import gc -import logging -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) 
-quantization_config = Mxfp4Config(dequantize=True) - - - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=False, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 256 - -with torch.inference_mode(): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - -print(tokenizer.decode(generated[0], skip_special_tokens=False)) -print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/megablocks/cells/nv.py b/megablocks/cells/nv.py deleted file mode 100644 index 80eef60a7536ed875fb21731ab2d059458bd20b4..0000000000000000000000000000000000000000 --- a/megablocks/cells/nv.py +++ /dev/null @@ -1,3 +0,0 @@ -import subprocess - -print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) \ No newline at end of file diff --git a/megablocks/index.html b/megablocks/index.html deleted file mode 100644 index 1f92e85585db921abce97173a251f506939fe8e8..0000000000000000000000000000000000000000 --- a/megablocks/index.html +++ /dev/null @@ -1,24 +0,0 @@ - - - - - Directory Index - - - -

Index of /megablocks

- - - \ No newline at end of file diff --git a/megablocks/megablocks_only.html b/megablocks/megablocks_only.html deleted file mode 100644 index 441195330aafcc34612fc835c6b554a1f98cb9cb..0000000000000000000000000000000000000000 --- a/megablocks/megablocks_only.html +++ /dev/null @@ -1,3970 +0,0 @@ - - - - - - Megablocks Only Test - - - - - - - -
-
-
light
-
reset
- -
-
- -
-
Generated on:
-
- Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36 -
-
- -
-

No Kernels

-

First, we run the model without any custom kernels to get a reference point.

-

Forward

-
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: no_kernels | 19.21s | FAILED - | - -Raw -
-
-
-
-1 -2 -3 -4 -5 -6 -7 -8 -9 -10 -11 -12 -13 -14 -15 -16 -17 -18 -19 -20 -21 -22 -23 -24 -25 -26 -27 -28 -29 -30 -31 -32 -33 -34 -35 -36 -37 -38 -39 -40 -41 -42 -43 -44 -45 -46 -47 -48 -49 -50 -51 -52 -53 -54 -55 -56 -57 -58 -59 -60 -61 -62 -63 -64 -65 -66 -67 -68 -69 -70 -71 -72 -73 -74 -75 -76 -77 -78 -79 -80 -81 -82 -83 -84 -85 -86 -87 -88 -89 -90 -91 -92 -93 -94 -95 -96 -97 -98 -
-
-
# /// script
-# requires-python = ">=3.12"
-# dependencies = [
-#     "accelerate>=1.10.1",
-#     "torch>=2.7.0",
-#     "kernels==0.10.0",
-#     "transformers@https://github.com/huggingface/transformers.git",
-#     "ipdb>=0.13.13",
-#     "matplotlib>=3.7.2",
-#     "numpy>=1.24.3",
-# ]
-# ///
-
-import torch
-from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
-import time
-import torch.nn as nn
-from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
-import sys
-import torch.profiler
-import gc
-import logging
-from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
-
-# set to debug logging
-logging.basicConfig(level=logging.INFO)
-
-def reset_peak_memory_stats():
-    """Clear CUDA cache and reset memory allocation counters."""
-    torch.cuda.empty_cache()
-    if torch.cuda.is_available():
-        torch.cuda.reset_peak_memory_stats()
-    gc.collect()
-
-def get_memory_stats():
-    """Get current and peak CUDA memory usage."""
-    if not torch.cuda.is_available():
-        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
-    return {
-        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
-        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
-        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
-    }
-
-def override_kernel_layer_name(cls_name: str, value) -> bool:
-    """Helper to dynamically override the kernel_layer_name in a model class."""
-    for mod in sys.modules.values():
-        if mod is None:
-            continue
-        obj = getattr(mod, cls_name, None)
-        if isinstance(obj, type) and issubclass(obj, nn.Module):
-            setattr(obj, "kernel_layer_name", value)
-            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
-            return True
-    return False
-
-
-# Init the model the normal way
-model_id = "openai/gpt-oss-20b"
-tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
-quantization_config = Mxfp4Config(dequantize=True)
-
-
-
-model = GptOssForCausalLM.from_pretrained(
-    model_id,
-    dtype="bfloat16",
-    device_map="auto",
-    use_kernels=False,
-    quantization_config=quantization_config,
-).eval()
-
-messages = [
-    {"role": "system", "content": "What is Tensor Parallelism?"},
-]
-
-inputs = tokenizer.apply_chat_template(
-    messages,
-    add_generation_prompt=True,
-    return_tensors="pt",
-    return_dict=True,
-    reasoning_effort="low",
-).to("cuda")
-
-max_tokens = 256
-
-with torch.inference_mode():
-    start_time = time.perf_counter()
-    generated = model.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
-        do_sample=False,
-        temperature=None,
-    )
-    end_time = time.perf_counter()
-
-print(tokenizer.decode(generated[0], skip_special_tokens=False))
-print(f"Generation took {end_time - start_time:.2f} seconds")
-
- -
-
-
-
-
-
Downloading cpython-3.13.7-linux-x86_64-gnu (download) (32.0MiB) - Downloading cpython-3.13.7-linux-x86_64-gnu (download) - Updating https://github.com/huggingface/transformers.git (HEAD) - Updated https://github.com/huggingface/transformers.git (e691f84412563b6abca098f3e044980725d8daa3) - × No solution found when resolving script dependencies: - ╰─▶ Because only transformers==4.57.0.dev0 is available and - transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1, - we can conclude that all versions of transformers depend on - huggingface-hub==1.0.0rc1. - And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0, - we can conclude that kernels==0.10.0 and all versions of transformers - are incompatible. - And because you require kernels==0.10.0 and transformers, we can - conclude that your requirements are unsatisfiable. -
-
-
- -
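The failure above is not a CUDA or model issue: uv cannot resolve the cell's script dependencies because transformers installed from git HEAD requires huggingface-hub==1.0.0rc1, while kernels==0.10.0 requires huggingface-hub<1.0. One hedged workaround, assuming a released transformers build whose huggingface-hub pin stays below 1.0, is to drop the git source from the PEP 723 header; the exact version floor below is an assumption, not taken from the original cell.

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "accelerate>=1.10.1",
#     "torch>=2.7.0",
#     "kernels==0.10.0",
#     # Assumption: a released transformers keeps huggingface-hub<1.0, which is
#     # compatible with kernels==0.10.0; the floor version is illustrative only.
#     "transformers>=4.55",
#     "ipdb>=0.13.13",
#     "matplotlib>=3.7.2",
#     "numpy>=1.24.3",
# ]
# ///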

Forward and Backward

-

Next, we'll attempt to run a forward and backward pass without any custom kernels. This will likely run out of memory since the default implementation is not optimized for memory usage.

-
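The full forward_and_backward_no_kernel.py cell appears earlier in this diff; a condensed sketch of what it does is below, assuming `model` and `inputs` are set up as in the forward-only cell with use_kernels=False. Without memory-efficient MoE kernels, the backward pass is the step most likely to hit an OutOfMemoryError on a 23 GB A10G.

import torch

# Condensed sketch of the no-kernels forward/backward cell shown earlier in
# this diff; assumes `model` and `inputs` already exist (use_kernels=False).
torch.cuda.reset_peak_memory_stats()

with torch.autograd.set_grad_enabled(True):
    generated = model.generate(**inputs, max_new_tokens=128, do_sample=False)

    # Trade compute for memory on the backward pass and cap the sequence length.
    model.gradient_checkpointing_enable()
    model.train()
    seq = generated[:, -512:]

    try:
        out = model(input_ids=seq, labels=seq, return_dict=True)
        out.loss.backward()
        print(f"loss={out.loss.item():.4f}  "
              f"peak={torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    except torch.cuda.OutOfMemoryError:
        # Expected failure mode without optimized kernels; shrink max_new_tokens
        # or the 512-token cap and retry.
        print("OOM during forward/backward")

model.zero_grad()
model.eval()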

Kernels

-

Next, we can run with Megablocks kernels enabled; the only change relative to the baseline is sketched below.

-
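The single change relative to the no-kernels baseline is use_kernels=True at load time, which lets transformers pull the MegaBlocks MoE kernels from the Hub via the kernels package; everything else in the cell stays the same.

from transformers import GptOssForCausalLM, Mxfp4Config

# Identical to the baseline cell except for use_kernels=True.
model = GptOssForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,                      # fetch MoE kernels from the Hub
    quantization_config=Mxfp4Config(dequantize=True),
).eval()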

Forward

-

First, we run a forward pass with Megablocks kernels.

-
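A condensed version of the forward-only cell: greedy decoding under inference_mode with wall-clock timing and peak-memory reporting; `model`, `tokenizer`, and `inputs` are assumed to be the objects built above.

import time
import torch

# Timed greedy decode, mirroring the forward-only cell in this diff.
torch.cuda.reset_peak_memory_stats()
with torch.inference_mode():
    start = time.perf_counter()
    generated = model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {elapsed:.2f} s "
      f"(peak {torch.cuda.max_memory_allocated() / 1e9:.2f} GB)")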

Forward and Backward

-

Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory-efficient and allow the backward pass to complete without running out of memory; a quick gradient sanity check is sketched below.

-
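Once the backward pass completes with the MegaBlocks kernels, the cells in this diff verify that gradients actually flowed by accumulating a global gradient norm. A minimal sketch of that check, assuming loss.backward() has just been called on `model`:

# Sketch of the gradient sanity check used by the forward/backward cells;
# assumes loss.backward() has just run on `model`.
total_sq, with_grad = 0.0, 0
for name, p in model.named_parameters():
    if p.grad is not None:
        with_grad += 1
        total_sq += p.grad.norm(2).item() ** 2
print(f"parameters with gradients: {with_grad}")
print(f"global gradient norm: {total_sq ** 0.5:.4f}")
model.zero_grad()  # reset before the next iteration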
- - - \ No newline at end of file diff --git a/megablocks_yamoe/artifacts/binned_run/binned_results.json b/megablocks_yamoe/artifacts/binned_run/binned_results.json deleted file mode 100644 index 95525a705c5a65d02e8ee4abf5eea6c2bd1a3607..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/artifacts/binned_run/binned_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "binned_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 36.26809924006011, - "min_ms": 34.103908000361116, - "max_ms": 37.68557000057626, - "std_ms": 1.1598518125118418, - "p50_ms": 36.52223600056459, - "p95_ms": 37.6427445000445, - "p99_ms": 37.677440410316194, - "num_iters": 50, - "tokens_per_s": 2757.2440269917565, - "throughput_variance": 89.13103199163609 - }, - "output_sum": 3.97190523147583 -} \ No newline at end of file diff --git a/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json b/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json deleted file mode 100644 index 7591730a256735e996b1d5635614593e3fac9b3b..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "gptoss_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 46.913985819956, - "min_ms": 40.44806400088419, - "max_ms": 51.07520399997156, - "std_ms": 2.9921332618008196, - "p50_ms": 47.418902999652346, - "p95_ms": 50.800493049837314, - "p99_ms": 50.948625239852845, - "num_iters": 50, - "tokens_per_s": 2131.560519794133, - "throughput_variance": 139.93911554997217 - }, - "output_sum": 11.53223705291748 -} \ No newline at end of file diff --git a/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json b/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json deleted file mode 100644 index 65743bcf0b442940f85835ec369b26e629ee5583..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "gptoss_training_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 46.289439859992854, - "min_ms": 39.97907499979192, - "max_ms": 50.58144600025116, - "std_ms": 2.9172154402078077, - "p50_ms": 46.64785849990949, - "p95_ms": 50.26727430031315, - "p99_ms": 50.5162941305025, - "num_iters": 50, - "tokens_per_s": 2160.3199412751637, - "throughput_variance": 139.86427060112865 - }, - "output_sum": 11.53223705291748 -} \ No newline at end of file diff --git a/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json b/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json deleted file mode 100644 index 87dcfce09fbf9fb7eeb5f0496295c417f4a85841..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "yamoe_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 4.248197240067384, - "min_ms": 4.136622000260104, - "max_ms": 4.280714999367774, - "std_ms": 0.02141682051311511, - "p50_ms": 4.253484999935608, - "p95_ms": 
4.265540049709671, - "p99_ms": 4.273649199667489, - "num_iters": 50, - "tokens_per_s": 23539.396677922097, - "throughput_variance": 120.66648678204231 - }, - "output_sum": 3.97190523147583 -} \ No newline at end of file diff --git a/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc b/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc deleted file mode 100644 index 28c4e7d268ce2e8e7fea8922428f2cda6a03b8a3..0000000000000000000000000000000000000000 Binary files a/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc and /dev/null differ diff --git a/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc b/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc deleted file mode 100644 index 0bcc722fe3898e8b68642a05144a24fa037c6a9c..0000000000000000000000000000000000000000 Binary files a/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc and /dev/null differ diff --git a/megablocks_yamoe/cells/bench_utils.py b/megablocks_yamoe/cells/bench_utils.py deleted file mode 100644 index 6bb3706118149df02c1f7ebaaa6fbba84e71cd5e..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/bench_utils.py +++ /dev/null @@ -1,241 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -"""Reusable benchmarking utilities for performance testing.""" -import time -import numpy as np -from contextlib import contextmanager -from typing import Callable, Dict, Tuple, Any, Optional -import torch - -def to_dtype(dtype_str: str): - """Convert string to torch dtype.""" - if dtype_str == "float16": - return torch.float16 - if dtype_str == "bfloat16": - return torch.bfloat16 - return torch.float32 - -def _sync(device: str): - """Synchronize device if CUDA.""" - if device == "cuda": - torch.cuda.synchronize() - -def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]: - """Compute comprehensive latency and throughput statistics.""" - lat_ms = np.array([t * 1000.0 for t in times_s]) - lat_ms_sorted = np.sort(lat_ms) - n = len(lat_ms) - - stats = { - "avg_ms": np.mean(lat_ms), - "min_ms": np.min(lat_ms), - "max_ms": np.max(lat_ms), - "std_ms": np.std(lat_ms), - "p50_ms": np.percentile(lat_ms, 50), - "p95_ms": np.percentile(lat_ms, 95), - "p99_ms": np.percentile(lat_ms, 99), - "num_iters": n - } - - if tokens is not None and n > 0: - avg_s = np.mean(times_s) - stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf") - stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0]) - - return stats - -def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str: - """Format timing statistics for display.""" - lines = [ - "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━", - f"Iterations: {stats.get('num_iters', 0)}", - "\nLatency Statistics:", - f" Average: {stats['avg_ms']:.3f} ms", - f" Min: {stats['min_ms']:.3f} ms", - f" Max: {stats['max_ms']:.3f} ms", - f" Std Dev: {stats['std_ms']:.3f} ms", - "\nPercentiles:", - f" P50 (median): {stats['p50_ms']:.3f} ms", - f" P95: {stats['p95_ms']:.3f} ms", - f" P99: {stats['p99_ms']:.3f} ms", - ] - - if tokens is not None and 'tokens_per_s' in stats: - lines.extend([ - "\nThroughput:", - f" Tokens/sec: {stats['tokens_per_s']:.1f}", - f" Std Dev: {stats.get('throughput_variance', 0):.1f}", - ]) - - lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - return "\n".join(lines) - -def _bench_engine( - call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] 
= None -) -> Tuple[Any, list]: - """Core benchmarking engine with warmup and timing.""" - use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16) - - # Warmup phase - print(f"\nWarming up ({warmup} iterations)...") - with torch.inference_mode(): - for _ in range(max(0, warmup)): - if use_autocast: - with torch.autocast(device_type="cuda", dtype=dtype): - if input_gen is not None: - _ = call(input_gen()) - else: - _ = call() - else: - if input_gen is not None: - _ = call(input_gen()) - else: - _ = call() - _sync(device) - - # Measurement phase - print(f"Benchmarking ({iters} iterations)...") - times_s = [] - last = None - with torch.inference_mode(): - for i in range(max(1, iters)): - start = time.perf_counter() - if use_autocast: - with torch.autocast(device_type="cuda", dtype=dtype): - if input_gen is not None: - last = call(input_gen()) - else: - last = call() - else: - if input_gen is not None: - last = call(input_gen()) - else: - last = call() - _sync(device) - end = time.perf_counter() - times_s.append(end - start) - - # Progress indicator every 20% of iterations - if i > 0 and i % max(1, iters // 5) == 0: - pct = (i / iters) * 100 - avg_so_far = np.mean(times_s[:i]) * 1000 - print(f" Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)") - - return last, times_s - -def tensor_stats(t: torch.Tensor) -> str: - """Generate comprehensive stats string for a tensor.""" - return (f"shape={tuple(t.shape)}, " - f"dtype={t.dtype}, " - f"device={t.device}, " - f"range=[{t.min().item():.6f}, {t.max().item():.6f}], " - f"mean={t.mean().item():.6f}, " - f"std={t.std().item():.6f}, " - f"norm={t.norm().item():.6f}") - -@contextmanager -def bench_context( - *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True -): - """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats). - - If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration - by adding a small deterministic increment to prevent caching artifacts. 
- """ - - def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]: - # Log configuration - if verbose: - print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐") - # print(f"│ Device: {device:<15} Dtype: {dtype} │") - print(f"│ Warmup: {warmup:<15} Iters: {iters} │") - if tokens: - print(f"│ Tokens: {tokens} │") - if vary_inputs: - print(f"│ Input Variation: Enabled (prevents caching artifacts) │") - print(f"└────────────────────────────────────────────────────────┘") - - # Set up input generation - input_gen = None - if vary_inputs and args and isinstance(args[0], torch.Tensor): - base_input = args[0].clone() - iteration_counter = [0] # Use list for mutable closure - - def generate_varied_input(): - """Generate input tensor varied by iteration to prevent caching.""" - # Add small deterministic increment: 0.001 * iteration_number - varied_input = base_input + (iteration_counter[0] * 0.001) - iteration_counter[0] += 1 - return varied_input - - input_gen = generate_varied_input - call = lambda x: fn(x, *args[1:], **kwargs) - - # Log base input stats - if verbose: - print(f"\nBase Input: {tensor_stats(base_input)}") - print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)") - else: - # Legacy mode - static inputs - call = lambda: fn(*args, **kwargs) - if verbose and args and isinstance(args[0], torch.Tensor): - print(f"\nInput: {tensor_stats(args[0])}") - - result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen) - - # Log output if it's a tensor or tuple with tensors - if verbose: - print("\nOutput tensors:") - if isinstance(result, torch.Tensor): - print(f" Primary: {tensor_stats(result)}") - elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor): - print(f" Primary: {tensor_stats(result[0])}") - if len(result) > 1: - if isinstance(result[1], torch.Tensor): - print(f" Auxiliary: {tensor_stats(result[1])}") - else: - print(f" Auxiliary: {type(result[1]).__name__}") - - # Compute and display statistics - stats = _compute_stats(times_s, tokens=tokens) - if verbose: - print(_format_timing_stats(stats, tokens)) - - # Save to JSON if requested - if save_json: - import json - json_data = { - "implementation": save_json.replace(".json", ""), - "config": { - "warmup": warmup, - "iters": iters, - "device": str(device), # Convert device to string - "dtype": str(dtype), - "tokens": tokens, - "vary_inputs": vary_inputs - }, - "stats": stats, - "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None - } - with open(save_json, 'w') as f: - json.dump(json_data, f, indent=2) - if verbose: - print(f"\nSaved benchmark results to {save_json}") - - return result, stats - - yield runner - -def set_seed(seed: int): - """Set seeds for reproducibility.""" - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False \ No newline at end of file diff --git a/megablocks_yamoe/cells/binned_run.py b/megablocks_yamoe/cells/binned_run.py deleted file mode 100644 index fe9e54316e7380bc60d7bb62459498e450575b31..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/binned_run.py +++ /dev/null @@ -1,195 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -import torch -from torch 
import nn -from torch.nn import functional as F -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -def binned_gather(x, indices, bins, expert_capacity, top_k): - E, H = bins.shape[0], x.shape[1] - out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype) - for e in range(E): - start = 0 if e == 0 else bins[e - 1] - end = bins[e] - n = min(end - start, expert_capacity) - for i in range(n): - flat_pos = indices[start + i] - tok = flat_pos // top_k - out[e, i] = x[tok] - return out - -def binned_scatter(x, indices, weights, bins, expert_capacity, top_k): - E, C, H = x.shape - N = indices.shape[0] // top_k - out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device) - for e in range(E): - start = 0 if e == 0 else bins[e - 1] - end = bins[e] - n = end - start - if n == 0: - continue - take = min(n, expert_capacity) - for i in range(take): - flat_pos = indices[start + i] - tok = flat_pos // top_k - slot = flat_pos % top_k - scale = weights[flat_pos] if weights is not None else 1.0 - out[tok, slot] = x[e, i] * scale - return out.sum(dim=1) - -def sort_tokens_by_expert(router_indices, num_experts): - flat_indices = router_indices.flatten() - sorted_values, sorted_indices = torch.sort(flat_indices) - tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts) - bins = torch.cumsum(tokens_per_expert, dim=0) - return sorted_indices, sorted_values, bins, tokens_per_expert - -def binned_experts_ref( - hidden_states, - router_indices, - routing_weights, - gate_up_proj, - gate_up_proj_bias, - down_proj, - down_proj_bias, - expert_capacity, -): - B, S, H = hidden_states.shape - E, K = routing_weights.shape[1], router_indices.shape[1] - - indices, _, bins, _ = sort_tokens_by_expert(router_indices, E) - x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K) - - gate_up = torch.bmm(x, gate_up_proj) - gate_up += gate_up_proj_bias[..., None, :] - - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - - # clamp to limit - limit = 7.0 - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - - glu = gate * torch.sigmoid(gate * 1.702) - x = (up + 1) * glu - x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :] - - # build routing weights aligned to (token, slot) - flat_dense = routing_weights.view(-1, E) - flat_router = router_indices.view(-1, K) - selected = torch.gather(flat_dense, 1, flat_router).reshape(-1) - - # scatter back - y = binned_scatter(x, indices, selected, bins, expert_capacity, K) - - return y.view(B, S, H) - -class BinnedRouter(nn.Module): - def __init__(self, router_weight, router_bias): - 
super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -def ceil_div(a, b): - return (a + b - 1) // b - -class BinnedMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = BinnedRouter(router_weight, router_bias) - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.top_k = TOP_K - - # Expert weights - use the loaded weights - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - - def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) - batch_size = hidden_states.shape[0] - expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts) - - output = binned_experts_ref( - hidden_states, - router_indices, - router_scores, - self.gate_up_proj, - self.gate_up_proj_bias, - self.down_proj, - self.down_proj_bias, - expert_capacity, - ) - - return output, router_scores - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== Binned Implementation ===") -# Initialize model with loaded weights -model = BinnedMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.down_proj.sum().item():.6f}") - -# Generate the same input as Yamoe -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/megablocks_yamoe/cells/config.py b/megablocks_yamoe/cells/config.py deleted file mode 100644 index 747a7224106854e57904aa10edc15f4d5f0c4a17..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/config.py +++ /dev/null @@ -1,27 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -"""Shared configuration for both implementations.""" -import torch - -# Model configuration -NUM_EXPERTS = 128 -HIDDEN_SIZE = 1152 -INTERMEDIATE_SIZE = 3072 -TOP_K = 4 - -# Input configuration -BATCH_SIZE = 1 -SEQ_LEN = 100 -DTYPE = "float32" -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - -# Seeds for reproducibility 
-WEIGHT_SEED = 999 -EXPERT_SEED = 777 -INPUT_SEED = 123 -GENERAL_SEED = 42 \ No newline at end of file diff --git a/megablocks_yamoe/cells/gptoss_run.py b/megablocks_yamoe/cells/gptoss_run.py deleted file mode 100644 index 5a1532dabff53ecb068ddd4354c545f0cea2d72b..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/gptoss_run.py +++ /dev/null @@ -1,147 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -class GptOssRouter(nn.Module): - def __init__(self, router_weight, router_bias): - super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -class GptOssExperts(nn.Module): - def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.expert_dim = self.hidden_size - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - self.alpha = 1.702 - self.limit = 7.0 - - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) - num_experts = routing_weights.shape[1] - - if hidden_states.device.type == "cpu" or self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit[:]: - expert_idx = expert_idx[0] - with torch.no_grad(): - _, token_idx = 
torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, expert_idx, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - else: - hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) - next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) - next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] - next_states = next_states.sum(dim=0) - return next_states - -class GptOssMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = GptOssRouter(router_weight, router_bias) - self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias) - - def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out, router_scores - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== GPT-OSS Implementation ===") -# Initialize model with loaded weights -model = GptOssMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}") - -# Generate the same input as other implementations -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/megablocks_yamoe/cells/gptoss_training_run.py b/megablocks_yamoe/cells/gptoss_training_run.py deleted file mode 100644 index f18731a74bfa546e612addbaab9e3ff5ec5d26dc..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/gptoss_training_run.py +++ /dev/null @@ -1,138 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] 
-# /// - -import torch -from torch import nn -from torch.nn import functional as F -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -class GptOssTrainingRouter(nn.Module): - def __init__(self, router_weight, router_bias): - super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -class GptOssTrainingExperts(nn.Module): - def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.expert_dim = self.hidden_size - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - self.alpha = 1.702 - self.limit = 7.0 - - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) - num_experts = routing_weights.shape[1] - - # Force training mode path (expert loop instead of batched) - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit[:]: - expert_idx = expert_idx[0] - with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ 
self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, expert_idx, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states - -class GptOssTrainingMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = GptOssTrainingRouter(router_weight, router_bias) - self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias) - - def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out, router_scores - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===") -# Initialize model with loaded weights and force training mode -model = GptOssTrainingMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -# Set to training mode to force expert loop path -model.train() - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}") -print(f"Model training mode: {model.training}") - -# Generate the same input as other implementations -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/megablocks_yamoe/cells/megablocks_run.py b/megablocks_yamoe/cells/megablocks_run.py deleted file mode 100644 index a18723cb66c892119c0a9e88d8c2a140a6354a00..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/megablocks_run.py +++ /dev/null @@ -1,103 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# "kernels", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from kernels import get_kernel, get_local_kernel -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -from collections import namedtuple -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -print(f"Loading weights from: {data_dir}") - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - 
-print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -def build_megablocks_model(device: torch.device): - # Download optimized kernels from the Hugging Face hub - megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2") - model = megablocks.layers.MegaBlocksMoeMLP() - - # Create attribute container for expert weights - model.experts = namedtuple( - "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"] - ) - - # Use loaded router weights for consistency - model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device) - with torch.no_grad(): - model.router.weight.copy_(router_weight) - model.router.bias.copy_(router_bias) - - # Attach loaded expert weights to the experts container - e = model.experts - e.alpha = 1.702 - e.capacity_factor = 32 - e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device)) - e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device)) - e.down_proj = torch.nn.Parameter(down_proj.clone().to(device)) - e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device)) - e.hidden_size = HIDDEN_SIZE - - # Log weight statistics for comparison - print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}") - print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}") - print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}") - - return model - -# Create a wrapper to match the interface of other implementations -class MegaBlocksMoEWrapper(nn.Module): - def __init__(self, megablocks_model): - super().__init__() - self.model = megablocks_model - - def forward(self, hidden_states): - # MegaBlocks expects input in the format (batch, seq_len, hidden_dim) - output, dummy_routing_weights = self.model(hidden_states) - return output, dummy_routing_weights - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== MegaBlocks Implementation ===") -# Build MegaBlocks model with loaded weights -megablocks_model = build_megablocks_model(device) -model = MegaBlocksMoEWrapper(megablocks_model).to(device=device) - -# Generate the same input as other implementations -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/megablocks_yamoe/cells/nv.py b/megablocks_yamoe/cells/nv.py deleted file mode 100644 index 80eef60a7536ed875fb21731ab2d059458bd20b4..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/nv.py +++ /dev/null @@ -1,3 +0,0 @@ -import subprocess - -print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) \ No newline at end of file diff --git a/megablocks_yamoe/cells/save_data.py b/megablocks_yamoe/cells/save_data.py deleted file mode 100644 index b15750dce52da48651ccd9805cdab51af88503d5..0000000000000000000000000000000000000000 --- 
a/megablocks_yamoe/cells/save_data.py +++ /dev/null @@ -1,42 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -""" -Generate deterministic shared weights once and save as artifacts so -both implementations load identical parameters. -""" -import torch -from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED - -def save_shared_weights(): - # Router: Kaiming uniform as used by both, bias zeros - torch.manual_seed(WEIGHT_SEED) - router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE) - torch.nn.init.kaiming_uniform_(router_weight) - router_bias = torch.zeros(NUM_EXPERTS) - - # Experts: normal(0, 0.02), biases zeros - torch.manual_seed(EXPERT_SEED) - gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02) - gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE) - down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02) - down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE) - - # Save artifacts - torch.save(router_weight, 'router_weight.pt') - torch.save(router_bias, 'router_bias.pt') - torch.save(gate_up_proj, 'gate_up_proj.pt') - torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt') - torch.save(down_proj, 'down_proj.pt') - torch.save(down_proj_bias, 'down_proj_bias.pt') - - print("Saved shared weights to artifacts") - print(f"Router weight sum: {router_weight.sum().item():.6f}") - print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") - print(f"Down sum: {down_proj.sum().item():.6f}") - -save_shared_weights() \ No newline at end of file diff --git a/megablocks_yamoe/cells/setup.py b/megablocks_yamoe/cells/setup.py deleted file mode 100644 index 6d7f386417ca59470f5e6404d26b64a6d1fd6f39..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/setup.py +++ /dev/null @@ -1,116 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository -import sys -import torch.profiler -import gc -import logging - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = 
PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - - -from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode - -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm - -replace_kernel_forward_from_hub(GptOssMLP, "Yamoe") -replace_kernel_forward_from_hub(GptOssRMSNorm, None) -custom_mapping = { - "Yamoe": { - "cuda": { - Mode.INFERENCE: LayerRepository( - repo_id="drbh/yamoe", - layer_name="Yamoe", - revision="v0.3.0", - ) - } - } -} -register_kernel_mapping(custom_mapping) - - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 256 - -with torch.inference_mode(): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - -print(tokenizer.decode(generated[0], skip_special_tokens=False)) -print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/megablocks_yamoe/cells/setup2.py b/megablocks_yamoe/cells/setup2.py deleted file mode 100644 index b67054a2580d775875fd3f0382d5820f3076236b..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/setup2.py +++ /dev/null @@ -1,115 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository -import sys -import torch.profiler -import gc -import logging - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - - -from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode - 
-from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm - -replace_kernel_forward_from_hub(GptOssRMSNorm, None) # direct, type-safe -custom_mapping = { - "Yamoe": { - "cuda": { - Mode.INFERENCE: LayerRepository( - repo_id="drbh/yamoe", - layer_name="Yamoe", - revision="v0.3.0", - ) - } - } -} -register_kernel_mapping(custom_mapping) - - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 256 - -with torch.inference_mode(): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - -print(tokenizer.decode(generated[0], skip_special_tokens=False)) -print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/megablocks_yamoe/cells/utils.py b/megablocks_yamoe/cells/utils.py deleted file mode 100644 index f1f83f42e002602ff034c10cdc3f2f598c779e1f..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/utils.py +++ /dev/null @@ -1,34 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -"""Simple utilities for running the models.""" -import torch - -def to_dtype(dtype_str: str): - """Convert string to torch dtype.""" - if dtype_str == "float16": - return torch.float16 - if dtype_str == "bfloat16": - return torch.bfloat16 - return torch.float32 - -def tensor_stats(t: torch.Tensor) -> str: - """Generate stats string for a tensor.""" - return (f"shape={tuple(t.shape)}, " - f"dtype={t.dtype}, " - f"device={t.device}, " - f"mean={t.mean().item():.6f}, " - f"std={t.std().item():.6f}") - -def set_seed(seed: int): - """Set seeds for reproducibility.""" - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False \ No newline at end of file diff --git a/megablocks_yamoe/cells/yamoe_run.py b/megablocks_yamoe/cells/yamoe_run.py deleted file mode 100644 index b3e73c4cb44433286cab638f8faae2623c5a5030..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/cells/yamoe_run.py +++ /dev/null @@ -1,135 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "kernels", -# "numpy", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from kernels import get_kernel, get_local_kernel -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') -print(f"Loading weights from: {data_dir}") - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = 
torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -class YamoeRouter(nn.Module): - def __init__(self, router_weight, router_bias): - super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -def ceil_div(a, b): - return (a + b - 1) // b - -class YamoeMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = YamoeRouter(router_weight, router_bias) - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.top_k = TOP_K - - # Load Yamoe kernel - # self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe") - self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0") - - # Expert weights - use the loaded weights - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - - def forward(self, hidden_states): - batch_size, seq_len, hidden_dim = hidden_states.shape - - # Get routing decisions - routing_weights, router_indices = self.router(hidden_states) - - # Reshape for Yamoe kernel - hidden_states_flat = hidden_states.view(-1, hidden_dim) - routing_weights_flat = routing_weights.view(-1, self.num_experts) - expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts) - - # Call Yamoe optimized kernel - output = self.yamoe.experts( - hidden_states_flat, - router_indices, - routing_weights_flat, - self.gate_up_proj, - self.gate_up_proj_bias, - self.down_proj, - self.down_proj_bias, - expert_capacity, - self.num_experts, - self.top_k, - ) - - # Reshape output back - output = output.view(batch_size, seq_len, hidden_dim) - - return output, routing_weights - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE if DEVICE == "cuda" else "cuda") -dtype = to_dtype(DTYPE) - -print("\n=== Yamoe Implementation ===") -# Initialize model with loaded weights -model = YamoeMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.down_proj.sum().item():.6f}") - -# Generate input -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, 
device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/megablocks_yamoe/index.html b/megablocks_yamoe/index.html deleted file mode 100644 index e2b56801bc9da4ad11d988da7e1af7718213fee7..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/index.html +++ /dev/null @@ -1,25 +0,0 @@ - - - - - Directory Index - - - -

Index of /megablocks_yamoe

- - - \ No newline at end of file diff --git a/megablocks_yamoe/megablocks_yamoe.html b/megablocks_yamoe/megablocks_yamoe.html deleted file mode 100644 index a8cf5d4877e09a7a05950f6858ecd6bd231a819f..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/megablocks_yamoe.html +++ /dev/null @@ -1,3997 +0,0 @@ - - - - - - uvnote Integration Test Report - - - - - - - -
-
-
- -
-
- -
-
Generated on:
-
- Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36 -
-
- -
-

Comparison of Megablocks and Yamoe Kernels

-

This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.
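For orientation, both kernels are pulled from the Hugging Face Hub with the kernels library; a minimal sketch using only the repo ids and revisions that appear in the cells below (everything else is illustrative):

from kernels import get_kernel

# download the pre-built kernels (revisions as pinned in the benchmark cells)
megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")

# megablocks exposes a full MoE MLP layer; yamoe exposes a fused experts op
moe_mlp = megablocks.layers.MegaBlocksMoeMLP()
experts_op = yamoe.experts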

-

Megablocks Kernel

-
-
-Cell: setup2 | 18.93s | FAILED
-
-
-
-
-
# /// script
-# requires-python = ">=3.12"
-# dependencies = [
-#     "accelerate>=1.10.1",
-#     "torch>=2.7.0",
-#     "kernels==0.10.0",
-#     "transformers@https://github.com/huggingface/transformers.git",
-#     "ipdb>=0.13.13",
-#     "matplotlib>=3.7.2",
-#     "numpy>=1.24.3",
-# ]
-# ///
-
-import torch
-from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
-import time
-import torch.nn as nn
-from kernels import register_kernel_mapping, Mode, LayerRepository
-import sys
-import torch.profiler
-import gc
-import logging
-
-# set to debug logging
-logging.basicConfig(level=logging.INFO)
-
-def reset_peak_memory_stats():
-    """Clear CUDA cache and reset memory allocation counters."""
-    torch.cuda.empty_cache()
-    if torch.cuda.is_available():
-        torch.cuda.reset_peak_memory_stats()
-    gc.collect()
-
-def get_memory_stats():
-    """Get current and peak CUDA memory usage."""
-    if not torch.cuda.is_available():
-        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
-    return {
-        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
-        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
-        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
-    }
-
-def override_kernel_layer_name(cls_name: str, value) -> bool:
-    """Helper to dynamically override the kernel_layer_name in a model class."""
-    for mod in sys.modules.values():
-        if mod is None:
-            continue
-        obj = getattr(mod, cls_name, None)
-        if isinstance(obj, type) and issubclass(obj, nn.Module):
-            setattr(obj, "kernel_layer_name", value)
-            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
-            return True
-    return False
-
-
-# Init the model the normal way
-model_id = "openai/gpt-oss-20b"
-tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
-quantization_config = Mxfp4Config(dequantize=True)
-
-
-from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode
-
-from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm
-
-replace_kernel_forward_from_hub(GptOssRMSNorm, None)  # direct, type-safe
-custom_mapping = {
-    "Yamoe": {
-        "cuda": {
-            Mode.INFERENCE: LayerRepository(
-                repo_id="drbh/yamoe",
-                layer_name="Yamoe",
-                revision="v0.3.0",
-            )
-        }
-    }
-}
-register_kernel_mapping(custom_mapping)
-
-
-model = GptOssForCausalLM.from_pretrained(
-    model_id,
-    dtype="bfloat16",
-    device_map="auto",
-    use_kernels=True,
-    quantization_config=quantization_config,
-).eval()
-
-messages = [
-    {"role": "system", "content": "What is Tensor Parallelism?"},
-]
-
-inputs = tokenizer.apply_chat_template(
-    messages,
-    add_generation_prompt=True,
-    return_tensors="pt",
-    return_dict=True,
-    reasoning_effort="low",
-).to("cuda")
-
-max_tokens = 256
-
-with torch.inference_mode():
-    start_time = time.perf_counter()
-    generated = model.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
-        do_sample=False,
-        temperature=None,
-    )
-    end_time = time.perf_counter()
-
-print(tokenizer.decode(generated[0], skip_special_tokens=False))
-print(f"Generation took {end_time - start_time:.2f} seconds")
-
- -
-
-
-
-
-
Downloading cpython-3.13.7-linux-x86_64-gnu (download) (32.0MiB) - Downloading cpython-3.13.7-linux-x86_64-gnu (download) - Updating https://github.com/huggingface/transformers.git (HEAD) - Updated https://github.com/huggingface/transformers.git (e691f84412563b6abca098f3e044980725d8daa3) - × No solution found when resolving script dependencies: - ╰─▶ Because only transformers==4.57.0.dev0 is available and - transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1, - we can conclude that all versions of transformers depend on - huggingface-hub==1.0.0rc1. - And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0, - we can conclude that kernels==0.10.0 and all versions of transformers - are incompatible. - And because you require kernels==0.10.0 and transformers, we can - conclude that your requirements are unsatisfiable. -
-
-
- -

Yamoe Kernel

-
- - - \ No newline at end of file diff --git a/megablocks_yamoe/torch_profile.html b/megablocks_yamoe/torch_profile.html deleted file mode 100644 index 9da1188cfcfd30377673bfc35c64018b27f7114e..0000000000000000000000000000000000000000 --- a/megablocks_yamoe/torch_profile.html +++ /dev/null @@ -1,6728 +0,0 @@ - - - - - - Compare Yamoe and Binned MoE Implementations - - - - - - - -
-
-
- -
-
- -
-
Generated on:
-
- Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36 -
-
- -
-
-
-Cell: utils | deps: torch, numpy | 34.88s
-
-
-
-
-
"""Simple utilities for running the models."""
-import torch
-
-def to_dtype(dtype_str: str):
-    """Convert string to torch dtype."""
-    if dtype_str == "float16":
-        return torch.float16
-    if dtype_str == "bfloat16":
-        return torch.bfloat16
-    return torch.float32
-
-def tensor_stats(t: torch.Tensor) -> str:
-    """Generate stats string for a tensor."""
-    return (f"shape={tuple(t.shape)}, "
-            f"dtype={t.dtype}, "
-            f"device={t.device}, "
-            f"mean={t.mean().item():.6f}, "
-            f"std={t.std().item():.6f}")
-
-def set_seed(seed: int):
-    """Set seeds for reproducibility."""
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
- -
-
-
-
-
-
-
▶ UV Install Logs
- -
-
-
- -
-
-Cell: bench_utils | deps: torch, numpy | 34.66s
-
-
-
-
-
"""Reusable benchmarking utilities for performance testing."""
-import time
-import numpy as np
-from contextlib import contextmanager
-from typing import Callable, Dict, Tuple, Any, Optional
-import torch
-
-def to_dtype(dtype_str: str):
-    """Convert string to torch dtype."""
-    if dtype_str == "float16":
-        return torch.float16
-    if dtype_str == "bfloat16":
-        return torch.bfloat16
-    return torch.float32
-
-def _sync(device: str):
-    """Synchronize device if CUDA."""
-    if device == "cuda":
-        torch.cuda.synchronize()
-
-def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
-    """Compute comprehensive latency and throughput statistics."""
-    lat_ms = np.array([t * 1000.0 for t in times_s])
-    lat_ms_sorted = np.sort(lat_ms)
-    n = len(lat_ms)
-
-    stats = {
-        "avg_ms": np.mean(lat_ms),
-        "min_ms": np.min(lat_ms),
-        "max_ms": np.max(lat_ms),
-        "std_ms": np.std(lat_ms),
-        "p50_ms": np.percentile(lat_ms, 50),
-        "p95_ms": np.percentile(lat_ms, 95),
-        "p99_ms": np.percentile(lat_ms, 99),
-        "num_iters": n
-    }
-
-    if tokens is not None and n > 0:
-        avg_s = np.mean(times_s)
-        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
-        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])
-
-    return stats
-
-def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
-    """Format timing statistics for display."""
-    lines = [
-        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
-        f"Iterations: {stats.get('num_iters', 0)}",
-        "\nLatency Statistics:",
-        f"  Average: {stats['avg_ms']:.3f} ms",
-        f"  Min:     {stats['min_ms']:.3f} ms",
-        f"  Max:     {stats['max_ms']:.3f} ms", 
-        f"  Std Dev: {stats['std_ms']:.3f} ms",
-        "\nPercentiles:",
-        f"  P50 (median): {stats['p50_ms']:.3f} ms",
-        f"  P95:          {stats['p95_ms']:.3f} ms",
-        f"  P99:          {stats['p99_ms']:.3f} ms",
-    ]
-
-    if tokens is not None and 'tokens_per_s' in stats:
-        lines.extend([
-            "\nThroughput:",
-            f"  Tokens/sec: {stats['tokens_per_s']:.1f}",
-            f"  Std Dev:    {stats.get('throughput_variance', 0):.1f}",
-        ])
-
-    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
-    return "\n".join(lines)
-
-def _bench_engine(
-    call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] = None
-) -> Tuple[Any, list]:
-    """Core benchmarking engine with warmup and timing."""
-    use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
-
-    # Warmup phase
-    print(f"\nWarming up ({warmup} iterations)...")
-    with torch.inference_mode():
-        for _ in range(max(0, warmup)):
-            if use_autocast:
-                with torch.autocast(device_type="cuda", dtype=dtype):
-                    if input_gen is not None:
-                        _ = call(input_gen())
-                    else:
-                        _ = call()
-            else:
-                if input_gen is not None:
-                    _ = call(input_gen())
-                else:
-                    _ = call()
-        _sync(device)
-
-    # Measurement phase
-    print(f"Benchmarking ({iters} iterations)...")
-    times_s = []
-    last = None
-    with torch.inference_mode():
-        for i in range(max(1, iters)):
-            start = time.perf_counter()
-            if use_autocast:
-                with torch.autocast(device_type="cuda", dtype=dtype):
-                    if input_gen is not None:
-                        last = call(input_gen())
-                    else:
-                        last = call()
-            else:
-                if input_gen is not None:
-                    last = call(input_gen())
-                else:
-                    last = call()
-            _sync(device)
-            end = time.perf_counter()
-            times_s.append(end - start)
-
-            # Progress indicator every 20% of iterations
-            if i > 0 and i % max(1, iters // 5) == 0:
-                pct = (i / iters) * 100
-                avg_so_far = np.mean(times_s[:i]) * 1000
-                print(f"  Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")
-
-    return last, times_s
-
-def tensor_stats(t: torch.Tensor) -> str:
-    """Generate comprehensive stats string for a tensor."""
-    return (f"shape={tuple(t.shape)}, "
-            f"dtype={t.dtype}, "
-            f"device={t.device}, "
-            f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
-            f"mean={t.mean().item():.6f}, "
-            f"std={t.std().item():.6f}, "
-            f"norm={t.norm().item():.6f}")
-
-@contextmanager
-def bench_context(
-    *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True
-):
-    """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).
-
-    If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration
-    by adding a small deterministic increment to prevent caching artifacts.
-    """
-
-    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
-        # Log configuration
-        if verbose:
-            print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐")
-            # print(f"│ Device: {device:<15} Dtype: {dtype}              │")
-            print(f"│ Warmup: {warmup:<15} Iters: {iters}              │")
-            if tokens:
-                print(f"│ Tokens: {tokens}                                        │")
-            if vary_inputs:
-                print(f"│ Input Variation: Enabled (prevents caching artifacts)  │")
-            print(f"└────────────────────────────────────────────────────────┘")
-
-        # Set up input generation
-        input_gen = None
-        if vary_inputs and args and isinstance(args[0], torch.Tensor):
-            base_input = args[0].clone()
-            iteration_counter = [0]  # Use list for mutable closure
-
-            def generate_varied_input():
-                """Generate input tensor varied by iteration to prevent caching."""
-                # Add small deterministic increment: 0.001 * iteration_number
-                varied_input = base_input + (iteration_counter[0] * 0.001)
-                iteration_counter[0] += 1
-                return varied_input
-
-            input_gen = generate_varied_input
-            call = lambda x: fn(x, *args[1:], **kwargs)
-
-            # Log base input stats
-            if verbose:
-                print(f"\nBase Input: {tensor_stats(base_input)}")
-                print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
-        else:
-            # Legacy mode - static inputs
-            call = lambda: fn(*args, **kwargs)
-            if verbose and args and isinstance(args[0], torch.Tensor):
-                print(f"\nInput: {tensor_stats(args[0])}")
-
-        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)
-
-        # Log output if it's a tensor or tuple with tensors
-        if verbose:
-            print("\nOutput tensors:")
-            if isinstance(result, torch.Tensor):
-                print(f"  Primary: {tensor_stats(result)}")
-            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
-                print(f"  Primary: {tensor_stats(result[0])}")
-                if len(result) > 1:
-                    if isinstance(result[1], torch.Tensor):
-                        print(f"  Auxiliary: {tensor_stats(result[1])}")
-                    else:
-                        print(f"  Auxiliary: {type(result[1]).__name__}")
-
-        # Compute and display statistics
-        stats = _compute_stats(times_s, tokens=tokens)
-        if verbose:
-            print(_format_timing_stats(stats, tokens))
-
-        # Save to JSON if requested
-        if save_json:
-            import json
-            json_data = {
-                "implementation": save_json.replace(".json", ""),
-                "config": {
-                    "warmup": warmup,
-                    "iters": iters,
-                    "device": str(device),  # Convert device to string
-                    "dtype": str(dtype),
-                    "tokens": tokens,
-                    "vary_inputs": vary_inputs
-                },
-                "stats": stats,
-                "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None
-            }
-            with open(save_json, 'w') as f:
-                json.dump(json_data, f, indent=2)
-            if verbose:
-                print(f"\nSaved benchmark results to {save_json}")
-
-        return result, stats
-
-    yield runner
-
-def set_seed(seed: int):
-    """Set seeds for reproducibility."""
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
- -
-
-
-
-
-
-
▶ UV Install Logs
- -
-
-
- -

This notebook benchmarks multiple MoE implementations with varied inputs across iterations to prevent unrealistic caching artifacts and measure true performance characteristics.
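A minimal sketch of that input-variation scheme, condensed from the bench_utils cell above (the 0.001 step is the value the code uses; the rest is simplified):

import torch

base_input = torch.randn(1, 100, 1152) * 0.1   # same shape/scale as the benchmark input
for i in range(50):
    # each iteration sees a slightly shifted copy of the base tensor,
    # so no iteration can reuse results cached for an identical input
    varied_input = base_input + i * 0.001
    # ... run the model on varied_input and record the timing ...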

-
-
-Cell: config | deps: torch, numpy | 35.36s
-
-
-
-
-
"""Shared configuration for both implementations."""
-import torch
-
-# Model configuration
-NUM_EXPERTS = 128
-HIDDEN_SIZE = 1152
-INTERMEDIATE_SIZE = 3072
-TOP_K = 4
-
-# Input configuration
-BATCH_SIZE = 1
-SEQ_LEN = 100
-DTYPE = "float32"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Seeds for reproducibility
-WEIGHT_SEED = 999
-EXPERT_SEED = 777
-INPUT_SEED = 123
-GENERAL_SEED = 42
-
- -
-
-
-
-
-
-
▶ UV Install Logs
- -
-
-
- -
-
-Cell: save_data | deps: torch, numpy | 39.03s
-
-
-
-
-
"""
-Generate deterministic shared weights once and save as artifacts so
-both implementations load identical parameters.
-"""
-import torch
-from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
-
-def save_shared_weights():
-    # Router: Kaiming uniform as used by both, bias zeros
-    torch.manual_seed(WEIGHT_SEED)
-    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
-    torch.nn.init.kaiming_uniform_(router_weight)
-    router_bias = torch.zeros(NUM_EXPERTS)
-
-    # Experts: normal(0, 0.02), biases zeros
-    torch.manual_seed(EXPERT_SEED)
-    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
-    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)
-
-    # Save artifacts
-    torch.save(router_weight, 'router_weight.pt')
-    torch.save(router_bias, 'router_bias.pt')
-    torch.save(gate_up_proj, 'gate_up_proj.pt')
-    torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
-    torch.save(down_proj, 'down_proj.pt')
-    torch.save(down_proj_bias, 'down_proj_bias.pt')
-
-    print("Saved shared weights to artifacts")
-    print(f"Router weight sum: {router_weight.sum().item():.6f}")
-    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-    print(f"Down sum: {down_proj.sum().item():.6f}")
-
-save_shared_weights()
-
- -
-
-
-
-
-
Saved shared weights to artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 -
-
-
▶ UV Install Logs
- -
- -
-
- -

Yamoe Implementation

-

This section runs the Yamoe MoE implementation with optimized Triton kernels.
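The kernel consumes the two tensors produced by the router defined below: the top-k expert indices per token and a dense score matrix. A toy illustration of that routing step (toy sizes; values are random):

import torch
import torch.nn.functional as F

logits = torch.randn(5, 8)                                   # 5 tokens, 8 experts
top_vals, router_indices = torch.topk(logits, k=4, dim=-1)   # (5, 4) expert ids per token
top_vals = F.softmax(top_vals, dim=1)                        # normalize only the selected experts
routing_weights = torch.zeros_like(logits).scatter_(1, router_indices, top_vals)  # dense (5, 8)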

-
-
-Cell: yamoe_run | deps: torch, kernels, numpy | 39.06s
-
-
-
-
-
import torch
-from torch import nn
-from torch.nn import functional as F
-from kernels import get_kernel, get_local_kernel
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-print(f"Loading weights from: {data_dir}")
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class YamoeRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-class YamoeMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = YamoeRouter(router_weight, router_bias)
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.top_k = TOP_K
-
-        # Load Yamoe kernel
-        # self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe")
-        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
-
-        # Expert weights - use the loaded weights
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-
-    def forward(self, hidden_states):
-        batch_size, seq_len, hidden_dim = hidden_states.shape
-
-        # Get routing decisions
-        routing_weights, router_indices = self.router(hidden_states)
-
-        # Reshape for Yamoe kernel
-        hidden_states_flat = hidden_states.view(-1, hidden_dim)
-        routing_weights_flat = routing_weights.view(-1, self.num_experts)
-        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
-
-        # Call Yamoe optimized kernel
-        output = self.yamoe.experts(
-            hidden_states_flat,
-            router_indices,
-            routing_weights_flat,
-            self.gate_up_proj,
-            self.gate_up_proj_bias,
-            self.down_proj,
-            self.down_proj_bias,
-            expert_capacity,
-            self.num_experts,
-            self.top_k,
-        )
-
-        # Reshape output back
-        output = output.view(batch_size, seq_len, hidden_dim)
-
-        return output, routing_weights
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE if DEVICE == "cuda" else "cuda")
-dtype = to_dtype(DTYPE)
-
-print("\n=== Yamoe Implementation ===")
-# Initialize model with loaded weights
-model = YamoeMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
-
-# Generate input
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
- -
-
-
-
-
-
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/f8744f31d9cf720409852d42748815c6d61f005a2a9b297b7b9bf986ed98bb90 -Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== Yamoe Implementation === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 4.247 ms) - Progress: 40% complete (avg: 4.244 ms) - Progress: 60% complete (avg: 4.246 ms) - Progress: 80% complete (avg: 4.246 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 4.248 ms - Min: 4.137 ms - Max: 4.281 ms - Std Dev: 0.021 ms - -Percentiles: - P50 (median): 4.253 ms - P95: 4.266 ms - P99: 4.274 ms - -Throughput: - Tokens/sec: 23539.4 - Std Dev: 120.7 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to yamoe_results.json - -Output sum: 3.971905 -
-
-
▶ UV Install Logs
- -
-
Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s] -Fetching 6 files: 17%|█▋ | 1/6 [00:00<00:01, 2.76it/s] -Fetching 6 files: 50%|█████ | 3/6 [00:00<00:00, 3.03it/s] -Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 6.01it/s]
-
-

Artifacts:

-yamoe_results.json -
-
-
- -

Binned Implementation

-

This section runs the binned implementation that manually handles token gathering/scattering.
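The bookkeeping below rests on sorting the flattened expert assignments and taking a cumulative histogram; a small hand-worked example (toy sizes, not the benchmark configuration):

import torch

router_indices = torch.tensor([[1, 0], [0, 2], [1, 2]])          # 3 tokens, top_k=2, 3 experts
flat = router_indices.flatten()                                  # [1, 0, 0, 2, 1, 2]
sorted_values, sorted_indices = torch.sort(flat, stable=True)
tokens_per_expert = torch.bincount(sorted_values, minlength=3)   # [2, 2, 2]
bins = torch.cumsum(tokens_per_expert, dim=0)                    # [2, 4, 6]
# bins[e-1]:bins[e] delimits expert e's slots in sorted_indices;
# a flat position p maps back to token p // top_k and slot p % top_k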

-
-
-Cell: binned_run | deps: torch, numpy | 39.51s
-
-
-
-
-
import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-def binned_gather(x, indices, bins, expert_capacity, top_k):
-    E, H = bins.shape[0], x.shape[1]
-    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
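-    # bins holds the cumulative token count per expert; the loop below copies
-    # at most expert_capacity of each expert's routed tokens into its rows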
-    for e in range(E):
-        start = 0 if e == 0 else bins[e - 1]
-        end = bins[e]
-        n = min(end - start, expert_capacity)
-        for i in range(n):
-            flat_pos = indices[start + i]
-            tok = flat_pos // top_k
-            out[e, i] = x[tok]
-    return out
-
-def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
-    E, C, H = x.shape
-    N = indices.shape[0] // top_k
-    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
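-    # write each expert's outputs back to their (token, slot) positions,
-    # scaled by the routing weight; summing over slots merges the top_k experts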
-    for e in range(E):
-        start = 0 if e == 0 else bins[e - 1]
-        end = bins[e]
-        n = end - start
-        if n == 0:
-            continue
-        take = min(n, expert_capacity)
-        for i in range(take):
-            flat_pos = indices[start + i]
-            tok = flat_pos // top_k
-            slot = flat_pos % top_k
-            scale = weights[flat_pos] if weights is not None else 1.0
-            out[tok, slot] = x[e, i] * scale
-    return out.sum(dim=1)
-
-def sort_tokens_by_expert(router_indices, num_experts):
-    flat_indices = router_indices.flatten()
-    sorted_values, sorted_indices = torch.sort(flat_indices)
-    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
-    bins = torch.cumsum(tokens_per_expert, dim=0)
-    return sorted_indices, sorted_values, bins, tokens_per_expert
-
-def binned_experts_ref(
-    hidden_states,
-    router_indices,
-    routing_weights,
-    gate_up_proj,
-    gate_up_proj_bias,
-    down_proj,
-    down_proj_bias,
-    expert_capacity,
-):
-    B, S, H = hidden_states.shape
-    E, K = routing_weights.shape[1], router_indices.shape[1]
-
-    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
-    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
-
-    gate_up = torch.bmm(x, gate_up_proj) 
-    gate_up += gate_up_proj_bias[..., None, :]
-
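-    # gate and up projections are interleaved along the last dim (even cols = gate, odd cols = up)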
-    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-
-    # clamp to limit
-    limit = 7.0
-    gate = gate.clamp(min=None, max=limit)
-    up = up.clamp(min=-limit, max=limit)
-
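-    # sigmoid(1.702 * x) is the usual sigmoid approximation of the GELU gate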
-    glu = gate * torch.sigmoid(gate * 1.702)
-    x = (up + 1) * glu
-    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
-
-    # build routing weights aligned to (token, slot)
-    flat_dense = routing_weights.view(-1, E)
-    flat_router = router_indices.view(-1, K)
-    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
-
-    # scatter back
-    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
-
-    return y.view(B, S, H)
-
-class BinnedRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-class BinnedMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = BinnedRouter(router_weight, router_bias)
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.top_k = TOP_K
-
-        # Expert weights - use the loaded weights
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        batch_size = hidden_states.shape[0]
-        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
-
-        output = binned_experts_ref(
-            hidden_states,
-            router_indices,
-            router_scores,
-            self.gate_up_proj,
-            self.gate_up_proj_bias,
-            self.down_proj,
-            self.down_proj_bias,
-            expert_capacity,
-        )
-
-        return output, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== Binned Implementation ===")
-# Initialize model with loaded weights
-model = BinnedMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
-
-# Generate the same input as Yamoe
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== Binned Implementation === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 37.524 ms) - Progress: 40% complete (avg: 37.442 ms) - Progress: 60% complete (avg: 37.122 ms) - Progress: 80% complete (avg: 36.627 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 36.268 ms - Min: 34.104 ms - Max: 37.686 ms - Std Dev: 1.160 ms - -Percentiles: - P50 (median): 36.522 ms - P95: 37.643 ms - P99: 37.677 ms - -Throughput: - Tokens/sec: 2757.2 - Std Dev: 89.1 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to binned_results.json - -Output sum: 3.971905 -
▶ UV Install Logs

Artifacts:
-binned_results.json

GPT-OSS Implementation

-

This section runs the GPT-OSS MoE implementation, which dispatches tokens through a manual per-expert loop on CPU or in training mode and otherwise falls back to a batched bmm over all experts.

Cell: gptoss_run | deps: torch, numpy | 40.20s
import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class GptOssRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-class GptOssExperts(nn.Module):
-    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.expert_dim = self.hidden_size
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-        self.alpha = 1.702
-        self.limit = 7.0
-
-    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        if hidden_states.device.type == "cpu" or self.training:
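-            # Sparse path: visit only experts that received at least one token, scale each
-            # expert's output by its router score, and accumulate with index_add_.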
-            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-            with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-                expert_mask = expert_mask.permute(2, 1, 0)
-                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-            for expert_idx in expert_hit[:]:
-                expert_idx = expert_idx[0]
-                with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx])
-                current_state = hidden_states[token_idx]
-                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-                gate = gate.clamp(min=None, max=self.limit)
-                up = up.clamp(min=-self.limit, max=self.limit)
-                glu = gate * torch.sigmoid(gate * self.alpha)
-                gated_output = (up + 1) * glu
-                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                weighted_output = out * routing_weights[token_idx, expert_idx, None]
-                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-            next_states = next_states.view(batch_size, -1, self.hidden_size)
-        else:
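-            # Batched path: every expert processes every token via bmm; outputs are then
-            # weighted by the dense router scores and summed over the expert dimension.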
-            hidden_states = hidden_states.repeat(num_experts, 1)
-            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=self.limit)
-            up = up.clamp(min=-self.limit, max=self.limit)
-            glu = gate * torch.sigmoid(gate * self.alpha)
-            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
-            next_states = next_states + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
-            next_states = next_states.sum(dim=0)
-        return next_states
-
-class GptOssMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = GptOssRouter(router_weight, router_bias)
-        self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-        return routed_out, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== GPT-OSS Implementation ===")
-# Initialize model with loaded weights
-model = GptOssMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== GPT-OSS Implementation === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 50.493 ms) - Progress: 40% complete (avg: 49.981 ms) - Progress: 60% complete (avg: 49.061 ms) - Progress: 80% complete (avg: 47.981 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 46.914 ms - Min: 40.448 ms - Max: 51.075 ms - Std Dev: 2.992 ms - -Percentiles: - P50 (median): 47.419 ms - P95: 50.800 ms - P99: 50.949 ms - -Throughput: - Tokens/sec: 2131.6 - Std Dev: 139.9 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to gptoss_results.json - -Output sum: 11.532237 -
▶ UV Install Logs

Artifacts:
-gptoss_results.json

GPT-OSS Implementation (Training Mode)

-

This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.

Cell: gptoss_training_run | deps: torch, numpy | 40.63s
import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class GptOssTrainingRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-class GptOssTrainingExperts(nn.Module):
-    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.expert_dim = self.hidden_size
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-        self.alpha = 1.702
-        self.limit = 7.0
-
-    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        # Force training mode path (expert loop instead of batched)
-        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-        with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-            expert_mask = expert_mask.permute(2, 1, 0)
-            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-        for expert_idx in expert_hit[:]:
-            expert_idx = expert_idx[0]
-            with torch.no_grad():
-                _, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states[token_idx]
-            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=self.limit)
-            up = up.clamp(min=-self.limit, max=self.limit)
-            glu = gate * torch.sigmoid(gate * self.alpha)
-            gated_output = (up + 1) * glu
-            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-            weighted_output = out * routing_weights[token_idx, expert_idx, None]
-            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-        next_states = next_states.view(batch_size, -1, self.hidden_size)
-        return next_states
-
-class GptOssTrainingMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = GptOssTrainingRouter(router_weight, router_bias)
-        self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-        return routed_out, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
-# Initialize model with loaded weights and force training mode
-model = GptOssTrainingMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-# Set to training mode to force expert loop path
-model.train()
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
-print(f"Model training mode: {model.training}")
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== GPT-OSS Implementation (Training Mode - Expert Loop) === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 -Model training mode: True - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 49.824 ms) - Progress: 40% complete (avg: 49.309 ms) - Progress: 60% complete (avg: 48.365 ms) - Progress: 80% complete (avg: 47.278 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 46.289 ms - Min: 39.979 ms - Max: 50.581 ms - Std Dev: 2.917 ms - -Percentiles: - P50 (median): 46.648 ms - P95: 50.267 ms - P99: 50.516 ms - -Throughput: - Tokens/sec: 2160.3 - Std Dev: 139.9 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to gptoss_training_results.json - -Output sum: 11.532237 -
▶ UV Install Logs

MegaBlocks Implementation

-

This section runs the MegaBlocks MoE implementation with optimized kernels from the Hugging Face hub.

Cell: megablocks_run | deps: torch, numpy, kernels | 40.35s | FAILED
import torch
-from torch import nn
-from torch.nn import functional as F
-from kernels import get_kernel, get_local_kernel
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-from collections import namedtuple
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-print(f"Loading weights from: {data_dir}")
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-def build_megablocks_model(device: torch.device):
-    # Download optimized kernels from the Hugging Face hub
-    megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
-    model = megablocks.layers.MegaBlocksMoeMLP()
-
-    # Create attribute container for expert weights
-    model.experts = namedtuple(
-        "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
-    )
-
-    # Use loaded router weights for consistency
-    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
-    with torch.no_grad():
-        model.router.weight.copy_(router_weight)
-        model.router.bias.copy_(router_bias)
-
-    # Attach loaded expert weights to the experts container
-    e = model.experts
-    e.alpha = 1.702
-    e.capacity_factor = 32
-    e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
-    e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
-    e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
-    e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
-    e.hidden_size = HIDDEN_SIZE
-
-    # Log weight statistics for comparison
-    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
-    print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
-    print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
-
-    return model
-
-# Create a wrapper to match the interface of other implementations
-class MegaBlocksMoEWrapper(nn.Module):
-    def __init__(self, megablocks_model):
-        super().__init__()
-        self.model = megablocks_model
-
-    def forward(self, hidden_states):
-        # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
-        output, dummy_routing_weights = self.model(hidden_states)
-        return output, dummy_routing_weights
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== MegaBlocks Implementation ===")
-# Build MegaBlocks model with loaded weights
-megablocks_model = build_megablocks_model(device)
-model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/f8744f31d9cf720409852d42748815c6d61f005a2a9b297b7b9bf986ed98bb90 -Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== MegaBlocks Implementation === -[MegaBlocks] Router weight sum: 12.588732 -[MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807 -[MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -
▶ UV Install Logs
Fetching 66 files: 0%| | 0/66 [00:00<?, ?it/s] -Fetching 66 files: 2%|▏ | 1/66 [00:00<00:28, 2.31it/s] -Fetching 66 files: 14%|█▎ | 9/66 [00:00<00:03, 18.19it/s] -Fetching 66 files: 26%|██▌ | 17/66 [00:01<00:02, 16.61it/s] -Fetching 66 files: 52%|█████▏ | 34/66 [00:01<00:00, 38.17it/s] -Fetching 66 files: 64%|██████▎ | 42/66 [00:01<00:00, 36.62it/s] -Fetching 66 files: 73%|███████▎ | 48/66 [00:01<00:00, 28.57it/s] -Fetching 66 files: 92%|█████████▏| 61/66 [00:01<00:00, 39.67it/s] -Fetching 66 files: 100%|██████████| 66/66 [00:02<00:00, 32.91it/s] -/tmp/tmp1397kafx/cuda_utils.c:5:10: fatal error: Python.h: No such file or directory - 5 | #include <Python.h> - | ^~~~~~~~~~ -compilation terminated. -Traceback (most recent call last): - File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/megablocks_run.py", line 102, in <module> - output, stats = bench(model, x) - ^^^^^^^^^^^^^^^ - File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 189, in runner - result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 96, in _bench_engine - _ = call(input_gen()) - ^^^^^^^^^^^^^^^^^ - File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 177, in <lambda> - call = lambda x: fn(x, *args[1:], **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/megablocks_run.py", line 81, in forward - output, dummy_routing_weights = self.model(hidden_states) - ^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl - return self._call_impl(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl - return forward_call(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 896, in forward - output, expert_weights_out, *_ = moe_forward( - ^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 730, in moe_forward - x, tokens_per_expert = forward_fn(**forward_args) - ^^^^^^^^^^^^^^^^^^^^^^^^^^ - File 
"/tmp/uvnote-run-g9v2jr6r/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 457, in forward_once - x = permute_and_compute( - ^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 401, in permute_and_compute - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/autograd/function.py", line 576, in apply - return super().apply(*args, **kwargs) # type: ignore[misc] - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py", line 30, in decorate_fwd - return fwd(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py", line 26, in forward - return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py", line 419, in binned_gather - _binned_copy[(num_experts, expert_capacity)]( - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/jit.py", line 390, in <lambda> - return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 239, in run - benchmark() - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 228, in benchmark - timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 228, in <dictcomp> - timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 160, in _bench - return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) - ^^^^^^^^^^^^^ - File "/usr/lib/python3.11/functools.py", line 1001, in __get__ - val = self.func(instance) - ^^^^^^^^^^^^^^^^^^^ - File 
"/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 121, in do_bench - return driver.active.get_benchmarker() - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 30, in __getattr__ - return getattr(self._initialize_obj(), name) - ^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 26, in _initialize_obj - self._obj = self._init_fn() - ^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 12, in _create_driver - return active_drivers[0]() - ^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 715, in __init__ - self.utils = CudaUtils() # TODO: make static - ^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 62, in __init__ - mod = compile_module_from_src( - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/build.py", line 88, in compile_module_from_src - so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or []) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/build.py", line 51, in _build - subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) - File "/usr/lib/python3.11/subprocess.py", line 413, in check_call - raise CalledProcessError(retcode, cmd) -subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmp1397kafx/cuda_utils.c', '-O3', '-shared', '-fPIC', '-Wno-psabi', '-o', '/tmp/tmp1397kafx/cuda_utils.cpython-311-x86_64-linux-gnu.so', '-lcuda', '-L/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/lib', '-L/usr/lib/x86_64-linux-gnu', '-I/tmp/uvnote-run-g9v2jr6r/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/include', '-I/tmp/tmp1397kafx', '-I/usr/include/python3.11']' returned non-zero exit status 1.

Performance Visualization

-

This section reads all benchmark results and creates a comprehensive performance comparison chart.

-
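The plotting cell itself does not appear in this excerpt. As a rough sketch only (not the original cell), the chart could be assembled along the lines below; the JSON field name "avg_ms" is an assumption about what bench_utils writes and may differ from the real schema.

# Sketch only: aggregate the saved benchmark JSONs and plot average latency.
# Assumes each results file exposes an "avg_ms" field (hypothetical key name).
import json
from pathlib import Path

import matplotlib.pyplot as plt

result_files = {
    "Binned": "binned_results.json",
    "GPT-OSS": "gptoss_results.json",
    "GPT-OSS (train)": "gptoss_training_results.json",
    "MegaBlocks": "megablocks_results.json",  # absent here because that cell failed
}

labels, avg_ms = [], []
for label, filename in result_files.items():
    path = Path(filename)
    if not path.exists():
        continue  # skip implementations whose benchmark did not finish
    data = json.loads(path.read_text())
    labels.append(label)
    avg_ms.append(data["avg_ms"])  # assumed key

plt.figure(figsize=(8, 4))
plt.bar(labels, avg_ms)
plt.ylabel("Average latency (ms)")
plt.title("MoE MLP benchmark comparison")
plt.tight_layout()
plt.savefig("moe_comparison.png", dpi=150)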
- - - \ No newline at end of file diff --git a/moe_benchmarks/index.html b/moe_benchmarks/index.html deleted file mode 100644 index 459357c72e0afcc10921ae5c13c251e999b35f06..0000000000000000000000000000000000000000 --- a/moe_benchmarks/index.html +++ /dev/null @@ -1,25 +0,0 @@ - - - - - Directory Index - - - -

Index of /moe_benchmarks

- - - \ No newline at end of file diff --git a/moe_benchmarks/megablocks/cells/forward_and_backward.py b/moe_benchmarks/megablocks/cells/forward_and_backward.py deleted file mode 100644 index a8ac420c8a43009eb857f3a7889b4f79ad5a1191..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks/cells/forward_and_backward.py +++ /dev/null @@ -1,196 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub -import sys -import torch.profiler -import gc -import logging -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm - -# remove liger kernel for testing -replace_kernel_forward_from_hub(GptOssRMSNorm, None) - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 128 # Reduced to help with memory usage - -# Clear memory before backward pass -reset_peak_memory_stats() -print(f"Pre-generation memory: {get_memory_stats()}") - -# forward and backward pass -with torch.autograd.set_grad_enabled(True): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - print(tokenizer.decode(generated[0], skip_special_tokens=False)) - print(f"Generation took {end_time - start_time:.2f} seconds") - print(f"Post-generation memory: {get_memory_stats()}") - - # Use gradient checkpointing to reduce memory usage - if hasattr(model, 'gradient_checkpointing_enable'): - model.gradient_checkpointing_enable() - 
print("Enabled gradient checkpointing") - - # Reduce sequence length if needed for memory - max_seq_len = 512 # Limit sequence length for backward pass - if generated.size(1) > max_seq_len: - print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens") - full_sequence = generated[:, -max_seq_len:] - else: - full_sequence = generated - - # Get model outputs for the full sequence - model.train() # Enable dropout and other training behaviors - - try: - outputs = model( - input_ids=full_sequence, - labels=full_sequence, # This will compute loss internally - return_dict=True - ) - print(f"Post-forward memory: {get_memory_stats()}") - - # If model doesn't compute loss, compute it manually - if outputs.loss is None: - shift_logits = outputs.logits[..., :-1, :].contiguous() - shift_labels = full_sequence[..., 1:].contiguous() - - # Use CrossEntropyLoss with ignore_index for padding tokens - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100) - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1) - ) - else: - loss = outputs.loss - - print(f"Loss: {loss.item():.4f}") - - # Clear intermediate tensors to save memory - del outputs - torch.cuda.empty_cache() - - # Perform backward pass with memory management - print("Running backward pass...") - print(f"Pre-backward memory: {get_memory_stats()}") - - loss.backward() - print(f"Post-backward memory: {get_memory_stats()}") - - except torch.cuda.OutOfMemoryError as e: - print(f"OOM during forward/backward pass: {e}") - print("Try reducing max_tokens or max_seq_len") - raise - - # Calculate gradient statistics and print sample gradients - total_norm = 0.0 - param_count = 0 - grad_samples = {} - - for name, p in model.named_parameters(): - if p.grad is not None: - param_count += 1 - grad_norm = p.grad.data.norm(2).item() - total_norm += grad_norm ** 2 - - # Collect gradient statistics for key layers - if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']): - grad_samples[name] = { - 'norm': grad_norm, - 'mean': p.grad.data.mean().item(), - 'std': p.grad.data.std().item(), - 'max': p.grad.data.max().item(), - 'min': p.grad.data.min().item(), - } - - total_norm = total_norm ** 0.5 - - print(f"\nGradient norm: {total_norm:.4f}") - print(f"Parameters with gradients: {param_count}") - - # Print sample gradients from important layers - print("\nSample gradient statistics:") - for i, (name, stats) in enumerate(list(grad_samples.items())[:10]): - print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}") - - # Optional: zero gradients for next iteration - model.zero_grad() - model.eval() # Switch back to eval mode - diff --git a/moe_benchmarks/megablocks/cells/forward_and_backward_no_kernel.py b/moe_benchmarks/megablocks/cells/forward_and_backward_no_kernel.py deleted file mode 100644 index d56805f64c56b484df98c41b9e62d3b6f27ff088..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks/cells/forward_and_backward_no_kernel.py +++ /dev/null @@ -1,196 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import 
time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub -import sys -import torch.profiler -import gc -import logging -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm - -# remove liger kernel for testing -replace_kernel_forward_from_hub(GptOssRMSNorm, None) - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=False, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 128 # Reduced to help with memory usage - -# Clear memory before backward pass -reset_peak_memory_stats() -print(f"Pre-generation memory: {get_memory_stats()}") - -# forward and backward pass -with torch.autograd.set_grad_enabled(True): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - print(tokenizer.decode(generated[0], skip_special_tokens=False)) - print(f"Generation took {end_time - start_time:.2f} seconds") - print(f"Post-generation memory: {get_memory_stats()}") - - # Use gradient checkpointing to reduce memory usage - if hasattr(model, 'gradient_checkpointing_enable'): - model.gradient_checkpointing_enable() - print("Enabled gradient checkpointing") - - # Reduce sequence length if needed for memory - max_seq_len = 512 # Limit sequence length for backward pass - if generated.size(1) > max_seq_len: - print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens") - full_sequence = generated[:, -max_seq_len:] - else: - full_sequence = generated - - # Get model outputs for the full sequence - model.train() # Enable dropout and other training behaviors - - try: - outputs = model( - input_ids=full_sequence, - labels=full_sequence, # This will compute loss internally - return_dict=True - ) - print(f"Post-forward memory: {get_memory_stats()}") - - # If model doesn't compute loss, compute it manually - if outputs.loss is None: - shift_logits = 
outputs.logits[..., :-1, :].contiguous() - shift_labels = full_sequence[..., 1:].contiguous() - - # Use CrossEntropyLoss with ignore_index for padding tokens - loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100) - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1) - ) - else: - loss = outputs.loss - - print(f"Loss: {loss.item():.4f}") - - # Clear intermediate tensors to save memory - del outputs - torch.cuda.empty_cache() - - # Perform backward pass with memory management - print("Running backward pass...") - print(f"Pre-backward memory: {get_memory_stats()}") - - loss.backward() - print(f"Post-backward memory: {get_memory_stats()}") - - except torch.cuda.OutOfMemoryError as e: - print(f"OOM during forward/backward pass: {e}") - print("Try reducing max_tokens or max_seq_len") - raise - - # Calculate gradient statistics and print sample gradients - total_norm = 0.0 - param_count = 0 - grad_samples = {} - - for name, p in model.named_parameters(): - if p.grad is not None: - param_count += 1 - grad_norm = p.grad.data.norm(2).item() - total_norm += grad_norm ** 2 - - # Collect gradient statistics for key layers - if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']): - grad_samples[name] = { - 'norm': grad_norm, - 'mean': p.grad.data.mean().item(), - 'std': p.grad.data.std().item(), - 'max': p.grad.data.max().item(), - 'min': p.grad.data.min().item(), - } - - total_norm = total_norm ** 0.5 - - print(f"\nGradient norm: {total_norm:.4f}") - print(f"Parameters with gradients: {param_count}") - - # Print sample gradients from important layers - print("\nSample gradient statistics:") - for i, (name, stats) in enumerate(list(grad_samples.items())[:10]): - print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}") - - # Optional: zero gradients for next iteration - model.zero_grad() - model.eval() # Switch back to eval mode - diff --git a/moe_benchmarks/megablocks/cells/forward_only.py b/moe_benchmarks/megablocks/cells/forward_only.py deleted file mode 100644 index c72358d0eef5e1f993aef1e76dfb0f26761c4881..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks/cells/forward_only.py +++ /dev/null @@ -1,101 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub -import sys -import torch.profiler -import gc -import logging -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm - - -replace_kernel_forward_from_hub(GptOssRMSNorm, None) - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": 
torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - - - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 256 - -with torch.inference_mode(): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - -print(tokenizer.decode(generated[0], skip_special_tokens=False)) -print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/moe_benchmarks/megablocks/cells/nv.py b/moe_benchmarks/megablocks/cells/nv.py deleted file mode 100644 index 80eef60a7536ed875fb21731ab2d059458bd20b4..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks/cells/nv.py +++ /dev/null @@ -1,3 +0,0 @@ -import subprocess - -print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) \ No newline at end of file diff --git a/moe_benchmarks/megablocks/index.html b/moe_benchmarks/megablocks/index.html deleted file mode 100644 index 5058977b1559a20266c8982c064cfc3de010bb13..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks/index.html +++ /dev/null @@ -1,24 +0,0 @@ - - - - - Directory Index - - - -

Index of /moe_benchmarks/megablocks

- - - \ No newline at end of file diff --git a/moe_benchmarks/megablocks/megablocks_only.html b/moe_benchmarks/megablocks/megablocks_only.html deleted file mode 100644 index 125dc951bf7ca625c9935f3da09b02fd8b36cd19..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks/megablocks_only.html +++ /dev/null @@ -1,4164 +0,0 @@ - - - - - - Megablocks Only Test - - - - - - - -
Generated on: Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36

No Kernels

-

First, we run the model without any custom kernels to get a reference point.

-

Forward

-
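The forward-only cell body is not reproduced in this excerpt. A minimal sketch of what it could look like, assembled from the surrounding cells (same model id, chat template, and generation flags; only use_kernels=False marks the no-kernel variant), is shown below.

# Sketch only (assumed from the surrounding cells, not the original cell body):
# forward-only generation without custom kernels, timed under inference_mode.
import time
import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config

model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=False,  # reference point: no hub kernels
    quantization_config=Mxfp4Config(dequantize=True),
).eval()

inputs = tokenizer.apply_chat_template(
    [{"role": "system", "content": "What is Tensor Parallelism?"}],
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

with torch.inference_mode():
    start = time.perf_counter()
    generated = model.generate(**inputs, max_new_tokens=128, do_sample=False, temperature=None)
    print(f"Generation took {time.perf_counter() - start:.2f} seconds")
print(tokenizer.decode(generated[0], skip_special_tokens=False))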

Forward and Backward

-

Next, we'll attempt to run a forward and backward pass without any custom kernels. This will likely run out of memory since the default implementation is not optimized for memory usage.

Cell: forward_and_backward_no_kernel | 17.10s | FAILED
# /// script
-# requires-python = ">=3.12"
-# dependencies = [
-#     "accelerate>=1.10.1",
-#     "torch>=2.7.0",
-#     "kernels==0.10.0",
-#     "transformers@https://github.com/huggingface/transformers.git",
-#     "ipdb>=0.13.13",
-#     "matplotlib>=3.7.2",
-#     "numpy>=1.24.3",
-# ]
-# ///
-
-import torch
-from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
-import time
-import torch.nn as nn
-from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
-import sys
-import torch.profiler
-import gc
-import logging
-from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
-
-# remove liger kernel for testing 
-replace_kernel_forward_from_hub(GptOssRMSNorm, None)
-
-# set to debug logging
-logging.basicConfig(level=logging.INFO)
-
-def reset_peak_memory_stats():
-    """Clear CUDA cache and reset memory allocation counters."""
-    torch.cuda.empty_cache()
-    if torch.cuda.is_available():
-        torch.cuda.reset_peak_memory_stats()
-    gc.collect()
-
-def get_memory_stats():
-    """Get current and peak CUDA memory usage."""
-    if not torch.cuda.is_available():
-        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
-    return {
-        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
-        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
-        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
-    }
-
-def override_kernel_layer_name(cls_name: str, value) -> bool:
-    """Helper to dynamically override the kernel_layer_name in a model class."""
-    for mod in sys.modules.values():
-        if mod is None:
-            continue
-        obj = getattr(mod, cls_name, None)
-        if isinstance(obj, type) and issubclass(obj, nn.Module):
-            setattr(obj, "kernel_layer_name", value)
-            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
-            return True
-    return False
-
-
-# Init the model the normal way
-model_id = "openai/gpt-oss-20b"
-tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
-quantization_config = Mxfp4Config(dequantize=True)
-
-model = GptOssForCausalLM.from_pretrained(
-    model_id,
-    dtype="bfloat16",
-    device_map="auto",
-    use_kernels=False,
-    quantization_config=quantization_config,
-).eval()
-
-messages = [
-    {"role": "system", "content": "What is Tensor Parallelism?"},
-]
-
-inputs = tokenizer.apply_chat_template(
-    messages,
-    add_generation_prompt=True,
-    return_tensors="pt",
-    return_dict=True,
-    reasoning_effort="low",
-).to("cuda")
-
-max_tokens = 128  # Reduced to help with memory usage
-
-# Clear memory before backward pass
-reset_peak_memory_stats()
-print(f"Pre-generation memory: {get_memory_stats()}")
-
-# forward and backward pass
-with torch.autograd.set_grad_enabled(True):
-    start_time = time.perf_counter()
-    generated = model.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
-        do_sample=False,
-        temperature=None,
-    )
-    end_time = time.perf_counter()
-    print(tokenizer.decode(generated[0], skip_special_tokens=False))
-    print(f"Generation took {end_time - start_time:.2f} seconds")
-    print(f"Post-generation memory: {get_memory_stats()}")
-
-    # Use gradient checkpointing to reduce memory usage
-    if hasattr(model, 'gradient_checkpointing_enable'):
-        model.gradient_checkpointing_enable()
-        print("Enabled gradient checkpointing")
-
-    # Reduce sequence length if needed for memory
-    max_seq_len = 512  # Limit sequence length for backward pass
-    if generated.size(1) > max_seq_len:
-        print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
-        full_sequence = generated[:, -max_seq_len:]
-    else:
-        full_sequence = generated
-
-    # Get model outputs for the full sequence
-    model.train()  # Enable dropout and other training behaviors
-
-    try:
-        outputs = model(
-            input_ids=full_sequence,
-            labels=full_sequence,  # This will compute loss internally
-            return_dict=True
-        )
-        print(f"Post-forward memory: {get_memory_stats()}")
-
-        # If model doesn't compute loss, compute it manually
-        if outputs.loss is None:
-            shift_logits = outputs.logits[..., :-1, :].contiguous()
-            shift_labels = full_sequence[..., 1:].contiguous()
-
-            # Use CrossEntropyLoss with ignore_index for padding tokens
-            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100)
-            loss = loss_fct(
-                shift_logits.view(-1, shift_logits.size(-1)),
-                shift_labels.view(-1)
-            )
-        else:
-            loss = outputs.loss
-
-        print(f"Loss: {loss.item():.4f}")
-
-        # Clear intermediate tensors to save memory
-        del outputs
-        torch.cuda.empty_cache()
-
-        # Perform backward pass with memory management
-        print("Running backward pass...")
-        print(f"Pre-backward memory: {get_memory_stats()}")
-
-        loss.backward()
-        print(f"Post-backward memory: {get_memory_stats()}")
-
-    except torch.cuda.OutOfMemoryError as e:
-        print(f"OOM during forward/backward pass: {e}")
-        print("Try reducing max_tokens or max_seq_len")
-        raise
-
-    # Calculate gradient statistics and print sample gradients
-    total_norm = 0.0
-    param_count = 0
-    grad_samples = {}
-
-    for name, p in model.named_parameters():
-        if p.grad is not None:
-            param_count += 1
-            grad_norm = p.grad.data.norm(2).item()
-            total_norm += grad_norm ** 2
-
-            # Collect gradient statistics for key layers
-            if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
-                grad_samples[name] = {
-                    'norm': grad_norm,
-                    'mean': p.grad.data.mean().item(),
-                    'std': p.grad.data.std().item(),
-                    'max': p.grad.data.max().item(),
-                    'min': p.grad.data.min().item(),
-                }
-
-    total_norm = total_norm ** 0.5
-
-    print(f"\nGradient norm: {total_norm:.4f}")
-    print(f"Parameters with gradients: {param_count}")
-
-    # Print sample gradients from important layers
-    print("\nSample gradient statistics:")
-    for i, (name, stats) in enumerate(list(grad_samples.items())[:10]):
-        print(f"  {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")
-
-    # Optional: zero gradients for next iteration
-    model.zero_grad()
-    model.eval()  # Switch back to eval mode
-
- -
-
-
-
-
-
warning: The requested interpreter resolved to Python 3.11.13, which is incompatible with the script's Python requirement: `>=3.12` - Updating https://github.com/huggingface/transformers.git (HEAD) - Updated https://github.com/huggingface/transformers.git (53838edde77cb10f3a360150aa85a457637e9ac3) - × No solution found when resolving script dependencies: - ╰─▶ Because only transformers==4.57.0.dev0 is available and - transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1, - we can conclude that all versions of transformers depend on - huggingface-hub==1.0.0rc1. - And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0, - we can conclude that kernels==0.10.0 and all versions of transformers - are incompatible. - And because you require kernels==0.10.0 and transformers, we can - conclude that your requirements are unsatisfiable. -
-
-
- -

Kernels

-

Next, we can run with Megablocks kernels enabled.
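
Two ways of switching the kernels on appear in this diff, sketched below under those assumptions: the transformers-level use_kernels=True flag used by the setup cells, and fetching the Megablocks MoE layer directly from the hub as megablocks_run.py does. Repo ids and revisions are copied from those cells, not introduced here.

import torch
from transformers import GptOssForCausalLM, Mxfp4Config
from kernels import get_kernel

# (1) Transformers route: let from_pretrained wire in hub kernel layers.
model = GptOssForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,
    quantization_config=Mxfp4Config(dequantize=True),
).eval()

# (2) Standalone route: pull the Megablocks MoE layer itself, as in megablocks_run.py.
megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
moe = megablocks.layers.MegaBlocksMoeMLP()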

-

Forward

-

First, we run a forward pass with Megablocks kernels.
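
A forward-only sketch, reusing MegaBlocksMoEWrapper and build_megablocks_model exactly as they are defined in megablocks_run.py further down this diff; the input shape follows the shared config there (batch 1, sequence length 100, hidden size 1152).

import torch

device = torch.device("cuda")
# Wrapper and builder are defined in megablocks_run.py (shown later in this diff).
model = MegaBlocksMoEWrapper(build_megablocks_model(device)).to(device)

x = torch.randn(1, 100, 1152, device=device, dtype=torch.float32) * 0.1

with torch.inference_mode():
    output, routing_weights = model(x)
    torch.cuda.synchronize()
print(f"Output sum: {output.sum().item():.6f}")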

-

Forward and Backward

-

Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.
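
A minimal sketch of that step, mirroring the no-kernel cell above; model is assumed to be the kernels-enabled model and full_sequence the (optionally truncated) generated batch, both prepared exactly as in that cell.

model.train()
outputs = model(input_ids=full_sequence, labels=full_sequence, return_dict=True)
loss = outputs.loss
print(f"Loss: {loss.item():.4f}")

loss.backward()  # expected to fit in memory now that the MoE path uses Megablocks kernels
grad_norm = sum(
    p.grad.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None
) ** 0.5
print(f"Gradient norm: {grad_norm:.4f}")

model.zero_grad()
model.eval()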

-
- - - \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json deleted file mode 100644 index 896e5f026a2f9eb24f3d40ecc94f9a01dac8c742..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "binned_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 35.79408963999981, - "min_ms": 33.22658100000808, - "max_ms": 37.58223699998098, - "std_ms": 1.260985811405264, - "p50_ms": 36.03647150001166, - "p95_ms": 37.377484250018256, - "p99_ms": 37.52526078000528, - "num_iters": 50, - "tokens_per_s": 2793.7573215509365, - "throughput_variance": 99.68321642463675 - }, - "output_sum": 3.97190523147583 -} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json deleted file mode 100644 index ae210db04e7f3222c805dfcf2696c3adda1792ab..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "gptoss_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 45.9334355199951, - "min_ms": 40.05551199998081, - "max_ms": 49.51232600001276, - "std_ms": 2.4709340263031536, - "p50_ms": 46.49940249998963, - "p95_ms": 49.05830289997937, - "p99_ms": 49.3528599099983, - "num_iters": 50, - "tokens_per_s": 2177.0633715492368, - "throughput_variance": 121.25434497483073 - }, - "output_sum": 11.53223705291748 -} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json deleted file mode 100644 index 8d525d761504e223064d2ee782dd2504d7cc758b..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "gptoss_training_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 45.94743435999817, - "min_ms": 38.690121999991334, - "max_ms": 51.193351999984316, - "std_ms": 3.91507100876056, - "p50_ms": 45.20909099997539, - "p95_ms": 51.039028550002286, - "p99_ms": 51.14429515998495, - "num_iters": 50, - "tokens_per_s": 2176.4000839851024, - "throughput_variance": 188.75969966024954 - }, - "output_sum": 11.53223705291748 -} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/megablocks_run/megablocks_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/megablocks_run/megablocks_results.json deleted file mode 100644 index dfa1ef72a20f95f6c1c67c811de254883dcfa50a..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/artifacts/megablocks_run/megablocks_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "megablocks_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - 
"vary_inputs": true - }, - "stats": { - "avg_ms": 3.8478457200017147, - "min_ms": 0.8121239999354657, - "max_ms": 8.535666000057063, - "std_ms": 3.697659288553723, - "p50_ms": 0.8394504999955643, - "p95_ms": 8.499624499950187, - "p99_ms": 8.528520820026415, - "num_iters": 50, - "tokens_per_s": 25988.567961595778, - "throughput_variance": 53035.39729321811 - }, - "output_sum": 6.4738850593566895 -} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/visualization/moe_performance_comparison.png b/moe_benchmarks/megablocks_yamoe/artifacts/visualization/moe_performance_comparison.png deleted file mode 100644 index 8226b6253ae72ac85dd88be88cc9d18ca596e436..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/artifacts/visualization/moe_performance_comparison.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c42b489158de5c37f27cbdb0e80f1dabfdc2341274809a6555fac2944d83349d -size 309063 diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json deleted file mode 100644 index ea5e05d56334a58146b9b2359db281538a210f41..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "implementation": "yamoe_results", - "config": { - "warmup": 10, - "iters": 50, - "device": "cuda", - "dtype": "torch.float32", - "tokens": 100, - "vary_inputs": true - }, - "stats": { - "avg_ms": 4.246394759998111, - "min_ms": 4.066528999999264, - "max_ms": 4.294285000014497, - "std_ms": 0.033808054217192726, - "p50_ms": 4.2530110000313925, - "p95_ms": 4.267295049984909, - "p99_ms": 4.287134920007816, - "num_iters": 50, - "tokens_per_s": 23549.38851705923, - "throughput_variance": 193.18069406896424 - }, - "output_sum": 3.97190523147583 -} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc deleted file mode 100644 index d06c1e4d6458b8cf941570beedefd99839dd37d4..0000000000000000000000000000000000000000 Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc and /dev/null differ diff --git a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc deleted file mode 100644 index 9e45b0f236650cb5461fb63b26508ecae56af697..0000000000000000000000000000000000000000 Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc and /dev/null differ diff --git a/moe_benchmarks/megablocks_yamoe/cells/bench_utils.py b/moe_benchmarks/megablocks_yamoe/cells/bench_utils.py deleted file mode 100644 index 6bb3706118149df02c1f7ebaaa6fbba84e71cd5e..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/bench_utils.py +++ /dev/null @@ -1,241 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -"""Reusable benchmarking utilities for performance testing.""" -import time -import numpy as np -from contextlib import contextmanager -from typing import Callable, Dict, Tuple, Any, Optional -import torch - -def to_dtype(dtype_str: str): - """Convert string to torch dtype.""" - if dtype_str == "float16": - return torch.float16 - if dtype_str == "bfloat16": - return torch.bfloat16 - return torch.float32 - -def 
_sync(device: str): - """Synchronize device if CUDA.""" - if device == "cuda": - torch.cuda.synchronize() - -def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]: - """Compute comprehensive latency and throughput statistics.""" - lat_ms = np.array([t * 1000.0 for t in times_s]) - lat_ms_sorted = np.sort(lat_ms) - n = len(lat_ms) - - stats = { - "avg_ms": np.mean(lat_ms), - "min_ms": np.min(lat_ms), - "max_ms": np.max(lat_ms), - "std_ms": np.std(lat_ms), - "p50_ms": np.percentile(lat_ms, 50), - "p95_ms": np.percentile(lat_ms, 95), - "p99_ms": np.percentile(lat_ms, 99), - "num_iters": n - } - - if tokens is not None and n > 0: - avg_s = np.mean(times_s) - stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf") - stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0]) - - return stats - -def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str: - """Format timing statistics for display.""" - lines = [ - "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━", - f"Iterations: {stats.get('num_iters', 0)}", - "\nLatency Statistics:", - f" Average: {stats['avg_ms']:.3f} ms", - f" Min: {stats['min_ms']:.3f} ms", - f" Max: {stats['max_ms']:.3f} ms", - f" Std Dev: {stats['std_ms']:.3f} ms", - "\nPercentiles:", - f" P50 (median): {stats['p50_ms']:.3f} ms", - f" P95: {stats['p95_ms']:.3f} ms", - f" P99: {stats['p99_ms']:.3f} ms", - ] - - if tokens is not None and 'tokens_per_s' in stats: - lines.extend([ - "\nThroughput:", - f" Tokens/sec: {stats['tokens_per_s']:.1f}", - f" Std Dev: {stats.get('throughput_variance', 0):.1f}", - ]) - - lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") - return "\n".join(lines) - -def _bench_engine( - call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] = None -) -> Tuple[Any, list]: - """Core benchmarking engine with warmup and timing.""" - use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16) - - # Warmup phase - print(f"\nWarming up ({warmup} iterations)...") - with torch.inference_mode(): - for _ in range(max(0, warmup)): - if use_autocast: - with torch.autocast(device_type="cuda", dtype=dtype): - if input_gen is not None: - _ = call(input_gen()) - else: - _ = call() - else: - if input_gen is not None: - _ = call(input_gen()) - else: - _ = call() - _sync(device) - - # Measurement phase - print(f"Benchmarking ({iters} iterations)...") - times_s = [] - last = None - with torch.inference_mode(): - for i in range(max(1, iters)): - start = time.perf_counter() - if use_autocast: - with torch.autocast(device_type="cuda", dtype=dtype): - if input_gen is not None: - last = call(input_gen()) - else: - last = call() - else: - if input_gen is not None: - last = call(input_gen()) - else: - last = call() - _sync(device) - end = time.perf_counter() - times_s.append(end - start) - - # Progress indicator every 20% of iterations - if i > 0 and i % max(1, iters // 5) == 0: - pct = (i / iters) * 100 - avg_so_far = np.mean(times_s[:i]) * 1000 - print(f" Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)") - - return last, times_s - -def tensor_stats(t: torch.Tensor) -> str: - """Generate comprehensive stats string for a tensor.""" - return (f"shape={tuple(t.shape)}, " - f"dtype={t.dtype}, " - f"device={t.device}, " - f"range=[{t.min().item():.6f}, {t.max().item():.6f}], " - f"mean={t.mean().item():.6f}, " - f"std={t.std().item():.6f}, " - f"norm={t.norm().item():.6f}") - -@contextmanager 
-def bench_context( - *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True -): - """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats). - - If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration - by adding a small deterministic increment to prevent caching artifacts. - """ - - def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]: - # Log configuration - if verbose: - print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐") - # print(f"│ Device: {device:<15} Dtype: {dtype} │") - print(f"│ Warmup: {warmup:<15} Iters: {iters} │") - if tokens: - print(f"│ Tokens: {tokens} │") - if vary_inputs: - print(f"│ Input Variation: Enabled (prevents caching artifacts) │") - print(f"└────────────────────────────────────────────────────────┘") - - # Set up input generation - input_gen = None - if vary_inputs and args and isinstance(args[0], torch.Tensor): - base_input = args[0].clone() - iteration_counter = [0] # Use list for mutable closure - - def generate_varied_input(): - """Generate input tensor varied by iteration to prevent caching.""" - # Add small deterministic increment: 0.001 * iteration_number - varied_input = base_input + (iteration_counter[0] * 0.001) - iteration_counter[0] += 1 - return varied_input - - input_gen = generate_varied_input - call = lambda x: fn(x, *args[1:], **kwargs) - - # Log base input stats - if verbose: - print(f"\nBase Input: {tensor_stats(base_input)}") - print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)") - else: - # Legacy mode - static inputs - call = lambda: fn(*args, **kwargs) - if verbose and args and isinstance(args[0], torch.Tensor): - print(f"\nInput: {tensor_stats(args[0])}") - - result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen) - - # Log output if it's a tensor or tuple with tensors - if verbose: - print("\nOutput tensors:") - if isinstance(result, torch.Tensor): - print(f" Primary: {tensor_stats(result)}") - elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor): - print(f" Primary: {tensor_stats(result[0])}") - if len(result) > 1: - if isinstance(result[1], torch.Tensor): - print(f" Auxiliary: {tensor_stats(result[1])}") - else: - print(f" Auxiliary: {type(result[1]).__name__}") - - # Compute and display statistics - stats = _compute_stats(times_s, tokens=tokens) - if verbose: - print(_format_timing_stats(stats, tokens)) - - # Save to JSON if requested - if save_json: - import json - json_data = { - "implementation": save_json.replace(".json", ""), - "config": { - "warmup": warmup, - "iters": iters, - "device": str(device), # Convert device to string - "dtype": str(dtype), - "tokens": tokens, - "vary_inputs": vary_inputs - }, - "stats": stats, - "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None - } - with open(save_json, 'w') as f: - json.dump(json_data, f, indent=2) - if verbose: - print(f"\nSaved benchmark results to {save_json}") - - return result, stats - - yield runner - -def set_seed(seed: int): - """Set seeds for reproducibility.""" - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - 
torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/binned_run.py b/moe_benchmarks/megablocks_yamoe/cells/binned_run.py deleted file mode 100644 index fe9e54316e7380bc60d7bb62459498e450575b31..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/binned_run.py +++ /dev/null @@ -1,195 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -def binned_gather(x, indices, bins, expert_capacity, top_k): - E, H = bins.shape[0], x.shape[1] - out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype) - for e in range(E): - start = 0 if e == 0 else bins[e - 1] - end = bins[e] - n = min(end - start, expert_capacity) - for i in range(n): - flat_pos = indices[start + i] - tok = flat_pos // top_k - out[e, i] = x[tok] - return out - -def binned_scatter(x, indices, weights, bins, expert_capacity, top_k): - E, C, H = x.shape - N = indices.shape[0] // top_k - out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device) - for e in range(E): - start = 0 if e == 0 else bins[e - 1] - end = bins[e] - n = end - start - if n == 0: - continue - take = min(n, expert_capacity) - for i in range(take): - flat_pos = indices[start + i] - tok = flat_pos // top_k - slot = flat_pos % top_k - scale = weights[flat_pos] if weights is not None else 1.0 - out[tok, slot] = x[e, i] * scale - return out.sum(dim=1) - -def sort_tokens_by_expert(router_indices, num_experts): - flat_indices = router_indices.flatten() - sorted_values, sorted_indices = torch.sort(flat_indices) - tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts) - bins = torch.cumsum(tokens_per_expert, dim=0) - return sorted_indices, sorted_values, bins, tokens_per_expert - -def binned_experts_ref( - hidden_states, - router_indices, - routing_weights, - gate_up_proj, - gate_up_proj_bias, - down_proj, - down_proj_bias, - expert_capacity, -): - B, S, H = hidden_states.shape - E, K = routing_weights.shape[1], router_indices.shape[1] - - indices, _, bins, _ = sort_tokens_by_expert(router_indices, E) - x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K) - - gate_up = torch.bmm(x, gate_up_proj) - gate_up += gate_up_proj_bias[..., None, :] - - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - - # clamp to limit - limit = 7.0 - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - - 
glu = gate * torch.sigmoid(gate * 1.702) - x = (up + 1) * glu - x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :] - - # build routing weights aligned to (token, slot) - flat_dense = routing_weights.view(-1, E) - flat_router = router_indices.view(-1, K) - selected = torch.gather(flat_dense, 1, flat_router).reshape(-1) - - # scatter back - y = binned_scatter(x, indices, selected, bins, expert_capacity, K) - - return y.view(B, S, H) - -class BinnedRouter(nn.Module): - def __init__(self, router_weight, router_bias): - super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -def ceil_div(a, b): - return (a + b - 1) // b - -class BinnedMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = BinnedRouter(router_weight, router_bias) - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.top_k = TOP_K - - # Expert weights - use the loaded weights - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - - def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) - batch_size = hidden_states.shape[0] - expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts) - - output = binned_experts_ref( - hidden_states, - router_indices, - router_scores, - self.gate_up_proj, - self.gate_up_proj_bias, - self.down_proj, - self.down_proj_bias, - expert_capacity, - ) - - return output, router_scores - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== Binned Implementation ===") -# Initialize model with loaded weights -model = BinnedMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.down_proj.sum().item():.6f}") - -# Generate the same input as Yamoe -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/config.py b/moe_benchmarks/megablocks_yamoe/cells/config.py deleted file mode 100644 index 
747a7224106854e57904aa10edc15f4d5f0c4a17..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/config.py +++ /dev/null @@ -1,27 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -"""Shared configuration for both implementations.""" -import torch - -# Model configuration -NUM_EXPERTS = 128 -HIDDEN_SIZE = 1152 -INTERMEDIATE_SIZE = 3072 -TOP_K = 4 - -# Input configuration -BATCH_SIZE = 1 -SEQ_LEN = 100 -DTYPE = "float32" -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - -# Seeds for reproducibility -WEIGHT_SEED = 999 -EXPERT_SEED = 777 -INPUT_SEED = 123 -GENERAL_SEED = 42 \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py b/moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py deleted file mode 100644 index 5a1532dabff53ecb068ddd4354c545f0cea2d72b..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py +++ /dev/null @@ -1,147 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -class GptOssRouter(nn.Module): - def __init__(self, router_weight, router_bias): - super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -class GptOssExperts(nn.Module): - def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.expert_dim = self.hidden_size - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - self.alpha = 1.702 - self.limit = 7.0 - - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - batch_size = 
hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) - num_experts = routing_weights.shape[1] - - if hidden_states.device.type == "cpu" or self.training: - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit[:]: - expert_idx = expert_idx[0] - with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, expert_idx, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - else: - hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) - gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.bmm(((up + 1) * glu), self.down_proj) - next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) - next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] - next_states = next_states.sum(dim=0) - return next_states - -class GptOssMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = GptOssRouter(router_weight, router_bias) - self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias) - - def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out, router_scores - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== GPT-OSS Implementation ===") -# Initialize model with loaded weights -model = GptOssMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}") - -# Generate the same input as other implementations -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with 
bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py b/moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py deleted file mode 100644 index f18731a74bfa546e612addbaab9e3ff5ec5d26dc..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py +++ /dev/null @@ -1,138 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -class GptOssTrainingRouter(nn.Module): - def __init__(self, router_weight, router_bias): - super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -class GptOssTrainingExperts(nn.Module): - def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.expert_dim = self.hidden_size - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - self.alpha = 1.702 - self.limit = 7.0 - - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) - num_experts = routing_weights.shape[1] - - # Force training mode path (expert loop instead of batched) - next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(router_indices, 
num_classes=num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit[:]: - expert_idx = expert_idx[0] - with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - gated_output = (up + 1) * glu - out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, expert_idx, None] - next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) - next_states = next_states.view(batch_size, -1, self.hidden_size) - return next_states - -class GptOssTrainingMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = GptOssTrainingRouter(router_weight, router_bias) - self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias) - - def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) - routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out, router_scores - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===") -# Initialize model with loaded weights and force training mode -model = GptOssTrainingMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -# Set to training mode to force expert loop path -model.train() - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}") -print(f"Model training mode: {model.training}") - -# Generate the same input as other implementations -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py b/moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py deleted file mode 100644 index a18723cb66c892119c0a9e88d8c2a140a6354a00..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py +++ /dev/null @@ -1,103 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# "kernels", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from kernels import get_kernel, get_local_kernel -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - 
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -from collections import namedtuple -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') - -print(f"Loading weights from: {data_dir}") - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -def build_megablocks_model(device: torch.device): - # Download optimized kernels from the Hugging Face hub - megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2") - model = megablocks.layers.MegaBlocksMoeMLP() - - # Create attribute container for expert weights - model.experts = namedtuple( - "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"] - ) - - # Use loaded router weights for consistency - model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device) - with torch.no_grad(): - model.router.weight.copy_(router_weight) - model.router.bias.copy_(router_bias) - - # Attach loaded expert weights to the experts container - e = model.experts - e.alpha = 1.702 - e.capacity_factor = 32 - e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device)) - e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device)) - e.down_proj = torch.nn.Parameter(down_proj.clone().to(device)) - e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device)) - e.hidden_size = HIDDEN_SIZE - - # Log weight statistics for comparison - print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}") - print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}") - print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}") - - return model - -# Create a wrapper to match the interface of other implementations -class MegaBlocksMoEWrapper(nn.Module): - def __init__(self, megablocks_model): - super().__init__() - self.model = megablocks_model - - def forward(self, hidden_states): - # MegaBlocks expects input in the format (batch, seq_len, hidden_dim) - output, dummy_routing_weights = self.model(hidden_states) - return output, dummy_routing_weights - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE) -dtype = to_dtype(DTYPE) - -print("\n=== MegaBlocks Implementation ===") -# Build MegaBlocks model with loaded weights -megablocks_model = build_megablocks_model(device) -model = MegaBlocksMoEWrapper(megablocks_model).to(device=device) - -# Generate the same input as other implementations -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench: - output, stats = 
bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/nv.py b/moe_benchmarks/megablocks_yamoe/cells/nv.py deleted file mode 100644 index 80eef60a7536ed875fb21731ab2d059458bd20b4..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/nv.py +++ /dev/null @@ -1,3 +0,0 @@ -import subprocess - -print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/save_data.py b/moe_benchmarks/megablocks_yamoe/cells/save_data.py deleted file mode 100644 index b15750dce52da48651ccd9805cdab51af88503d5..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/save_data.py +++ /dev/null @@ -1,42 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -""" -Generate deterministic shared weights once and save as artifacts so -both implementations load identical parameters. -""" -import torch -from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED - -def save_shared_weights(): - # Router: Kaiming uniform as used by both, bias zeros - torch.manual_seed(WEIGHT_SEED) - router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE) - torch.nn.init.kaiming_uniform_(router_weight) - router_bias = torch.zeros(NUM_EXPERTS) - - # Experts: normal(0, 0.02), biases zeros - torch.manual_seed(EXPERT_SEED) - gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02) - gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE) - down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02) - down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE) - - # Save artifacts - torch.save(router_weight, 'router_weight.pt') - torch.save(router_bias, 'router_bias.pt') - torch.save(gate_up_proj, 'gate_up_proj.pt') - torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt') - torch.save(down_proj, 'down_proj.pt') - torch.save(down_proj_bias, 'down_proj_bias.pt') - - print("Saved shared weights to artifacts") - print(f"Router weight sum: {router_weight.sum().item():.6f}") - print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") - print(f"Down sum: {down_proj.sum().item():.6f}") - -save_shared_weights() \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/setup.py b/moe_benchmarks/megablocks_yamoe/cells/setup.py deleted file mode 100644 index 6d7f386417ca59470f5e6404d26b64a6d1fd6f39..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/setup.py +++ /dev/null @@ -1,116 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository -import sys -import torch.profiler -import gc -import logging - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory 
usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - - -from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode - -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm - -replace_kernel_forward_from_hub(GptOssMLP, "Yamoe") -replace_kernel_forward_from_hub(GptOssRMSNorm, None) -custom_mapping = { - "Yamoe": { - "cuda": { - Mode.INFERENCE: LayerRepository( - repo_id="drbh/yamoe", - layer_name="Yamoe", - revision="v0.3.0", - ) - } - } -} -register_kernel_mapping(custom_mapping) - - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 256 - -with torch.inference_mode(): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - -print(tokenizer.decode(generated[0], skip_special_tokens=False)) -print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/moe_benchmarks/megablocks_yamoe/cells/setup2.py b/moe_benchmarks/megablocks_yamoe/cells/setup2.py deleted file mode 100644 index b67054a2580d775875fd3f0382d5820f3076236b..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/setup2.py +++ /dev/null @@ -1,115 +0,0 @@ -# /// script -# requires-python = ">=3.12" -# dependencies = [ -# "accelerate>=1.10.1", -# "torch>=2.7.0", -# "kernels==0.10.0", -# "transformers@https://github.com/huggingface/transformers.git", -# "ipdb>=0.13.13", -# "matplotlib>=3.7.2", -# "numpy>=1.24.3", -# ] -# /// - -import torch -from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config -import time -import torch.nn as nn -from kernels import register_kernel_mapping, Mode, LayerRepository -import sys -import torch.profiler -import gc -import logging - -# set to debug logging -logging.basicConfig(level=logging.INFO) - -def reset_peak_memory_stats(): - """Clear CUDA cache and reset memory allocation counters.""" - torch.cuda.empty_cache() - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - gc.collect() - -def get_memory_stats(): - """Get current and peak CUDA memory usage.""" - if not torch.cuda.is_available(): - return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} - return { - "allocated_gb": torch.cuda.memory_allocated() 
/ 1e9, - "peak_gb": torch.cuda.max_memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - } - -def override_kernel_layer_name(cls_name: str, value) -> bool: - """Helper to dynamically override the kernel_layer_name in a model class.""" - for mod in sys.modules.values(): - if mod is None: - continue - obj = getattr(mod, cls_name, None) - if isinstance(obj, type) and issubclass(obj, nn.Module): - setattr(obj, "kernel_layer_name", value) - print(f"Overrode {cls_name}.kernel_layer_name to {value}") - return True - return False - - -# Init the model the normal way -model_id = "openai/gpt-oss-20b" -tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) -quantization_config = Mxfp4Config(dequantize=True) - - -from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode - -from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm - -replace_kernel_forward_from_hub(GptOssRMSNorm, None) # direct, type-safe -custom_mapping = { - "Yamoe": { - "cuda": { - Mode.INFERENCE: LayerRepository( - repo_id="drbh/yamoe", - layer_name="Yamoe", - revision="v0.3.0", - ) - } - } -} -register_kernel_mapping(custom_mapping) - - -model = GptOssForCausalLM.from_pretrained( - model_id, - dtype="bfloat16", - device_map="auto", - use_kernels=True, - quantization_config=quantization_config, -).eval() - -messages = [ - {"role": "system", "content": "What is Tensor Parallelism?"}, -] - -inputs = tokenizer.apply_chat_template( - messages, - add_generation_prompt=True, - return_tensors="pt", - return_dict=True, - reasoning_effort="low", -).to("cuda") - -max_tokens = 256 - -with torch.inference_mode(): - start_time = time.perf_counter() - generated = model.generate( - **inputs, - max_new_tokens=max_tokens, - do_sample=False, - temperature=None, - ) - end_time = time.perf_counter() - -print(tokenizer.decode(generated[0], skip_special_tokens=False)) -print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/moe_benchmarks/megablocks_yamoe/cells/utils.py b/moe_benchmarks/megablocks_yamoe/cells/utils.py deleted file mode 100644 index f1f83f42e002602ff034c10cdc3f2f598c779e1f..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/utils.py +++ /dev/null @@ -1,34 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "numpy", -# ] -# /// - -"""Simple utilities for running the models.""" -import torch - -def to_dtype(dtype_str: str): - """Convert string to torch dtype.""" - if dtype_str == "float16": - return torch.float16 - if dtype_str == "bfloat16": - return torch.bfloat16 - return torch.float32 - -def tensor_stats(t: torch.Tensor) -> str: - """Generate stats string for a tensor.""" - return (f"shape={tuple(t.shape)}, " - f"dtype={t.dtype}, " - f"device={t.device}, " - f"mean={t.mean().item():.6f}, " - f"std={t.std().item():.6f}") - -def set_seed(seed: int): - """Set seeds for reproducibility.""" - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/visualization.py b/moe_benchmarks/megablocks_yamoe/cells/visualization.py deleted file mode 100644 index 5240d0bce8b3ecb0ae1c431b40a8eafedc3e044b..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/visualization.py +++ /dev/null @@ -1,116 +0,0 @@ -# /// script -# dependencies = [ -# 
"matplotlib", -# ] -# /// - -import json -import matplotlib.pyplot as plt -import numpy as np -from pathlib import Path -import os - -# List of expected result files -yamoe_dir = os.environ.get('UVNOTE_INPUT_YAMOE_RUN', '.') -binned_dir = os.environ.get('UVNOTE_INPUT_BINNED_RUN', '.') -gptoss_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_RUN', '.') -gptoss_training_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_TRAINING_RUN', '.') -megablocks_dir = os.environ.get('UVNOTE_INPUT_MEGABLOCKS_RUN', '.') - -result_files = [ - Path(yamoe_dir) / "yamoe_results.json", - Path(binned_dir) / "binned_results.json", - Path(gptoss_dir) / "gptoss_results.json", - Path(gptoss_training_dir) / "gptoss_training_results.json", - Path(megablocks_dir) / "megablocks_results.json" -] - -# Load all benchmark results -results = {} -for file in result_files: - if Path(file).exists(): - with open(file, 'r') as f: - data = json.load(f) - results[data['implementation']] = data - print(f"Loaded {file}") - else: - print(f"Missing {file}") - -if not results: - print("No benchmark results found. Run the benchmark cells first.") -else: - # Extract data for plotting - implementations = list(results.keys()) - avg_latencies = [results[impl]['stats']['avg_ms'] for impl in implementations] - p95_latencies = [results[impl]['stats']['p95_ms'] for impl in implementations] - throughputs = [results[impl]['stats'].get('tokens_per_s', 0) for impl in implementations] - - # Create figure with subplots - fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6)) - fig.suptitle('MoE Implementation Performance Comparison', fontsize=16, fontweight='bold') - - # Colors for each implementation - colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57'][:len(implementations)] - - # 1. Average Latency Chart - bars1 = ax1.bar(implementations, avg_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1) - ax1.set_title('Average Latency', fontweight='bold', fontsize=14) - ax1.set_ylabel('Latency (ms)', fontweight='bold') - ax1.tick_params(axis='x', rotation=45) - ax1.grid(axis='y', alpha=0.3) - - # Add value labels on bars - for bar, val in zip(bars1, avg_latencies): - ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(avg_latencies)*0.01, - f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold') - - # 2. P95 Latency Chart - bars2 = ax2.bar(implementations, p95_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1) - ax2.set_title('95th Percentile Latency', fontweight='bold', fontsize=14) - ax2.set_ylabel('Latency (ms)', fontweight='bold') - ax2.tick_params(axis='x', rotation=45) - ax2.grid(axis='y', alpha=0.3) - - # Add value labels on bars - for bar, val in zip(bars2, p95_latencies): - ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(p95_latencies)*0.01, - f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold') - - # 3. 
Throughput Chart - bars3 = ax3.bar(implementations, throughputs, color=colors, alpha=0.8, edgecolor='black', linewidth=1) - ax3.set_title('Throughput', fontweight='bold', fontsize=14) - ax3.set_ylabel('Tokens/sec', fontweight='bold') - ax3.tick_params(axis='x', rotation=45) - ax3.grid(axis='y', alpha=0.3) - - # Add value labels on bars - for bar, val in zip(bars3, throughputs): - if val > 0: # Only show label if throughput was calculated - ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(throughputs)*0.01, - f'{val:.0f}', ha='center', va='bottom', fontweight='bold') - - plt.tight_layout() - plt.savefig("moe_performance_comparison.png", dpi=300) - - # Print summary table - print("\nPerformance Summary:") - print(f"{'Implementation':<30} {'Avg (ms)':<12} {'P95 (ms)':<12} {'Tokens/sec':<12} {'Relative Speed':<15}") - print("-"*80) - - # Sort by average latency for relative speed calculation - sorted_results = sorted(results.items(), key=lambda x: x[1]['stats']['avg_ms']) - fastest_latency = sorted_results[0][1]['stats']['avg_ms'] - - for impl, data in sorted_results: - avg_ms = data['stats']['avg_ms'] - p95_ms = data['stats']['p95_ms'] - tokens_s = data['stats'].get('tokens_per_s', 0) - relative_speed = fastest_latency / avg_ms - - print(f"{impl:<30} {avg_ms:>8.2f} {p95_ms:>8.2f} {tokens_s:>8.0f} {relative_speed:>6.2f}x") - - print(f"\nFastest: {sorted_results[0][0]} ({sorted_results[0][1]['stats']['avg_ms']:.2f}ms avg)") - if len(sorted_results) > 1: - print(f"Slowest: {sorted_results[-1][0]} ({sorted_results[-1][1]['stats']['avg_ms']:.2f}ms avg)") - speedup = sorted_results[-1][1]['stats']['avg_ms'] / sorted_results[0][1]['stats']['avg_ms'] - print(f"Max Speedup: {speedup:.1f}x") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/yamoe_run.py b/moe_benchmarks/megablocks_yamoe/cells/yamoe_run.py deleted file mode 100644 index b3e73c4cb44433286cab638f8faae2623c5a5030..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/cells/yamoe_run.py +++ /dev/null @@ -1,135 +0,0 @@ -# /// script -# dependencies = [ -# "torch", -# "kernels", -# "numpy", -# ] -# /// - -import torch -from torch import nn -from torch.nn import functional as F -from kernels import get_kernel, get_local_kernel -from bench_utils import to_dtype, tensor_stats, set_seed, bench_context -from config import ( - NUM_EXPERTS, HIDDEN_SIZE, TOP_K, - BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, - WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED -) -from pathlib import Path -import os - -# Discover the upstream artifact directory from env -data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') -print(f"Loading weights from: {data_dir}") - -router_weight = torch.load(Path(data_dir) / 'router_weight.pt') -router_bias = torch.load(Path(data_dir) / 'router_bias.pt') -gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') -gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') -down_proj = torch.load(Path(data_dir) / 'down_proj.pt') -down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') - -print("Loaded shared weights from artifacts") -print(f"Router weight sum: {router_weight.sum().item():.6f}") -print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") -print(f"Down sum: {down_proj.sum().item():.6f}") - -class YamoeRouter(nn.Module): - def __init__(self, router_weight, router_bias): - super().__init__() - self.top_k = TOP_K - self.num_experts = NUM_EXPERTS - self.hidden_dim = HIDDEN_SIZE - self.weight = nn.Parameter(router_weight.clone()) - 
self.bias = nn.Parameter(router_bias.clone()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight, self.bias) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - -def ceil_div(a, b): - return (a + b - 1) // b - -class YamoeMoEMLP(nn.Module): - def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): - super().__init__() - self.router = YamoeRouter(router_weight, router_bias) - self.num_experts = NUM_EXPERTS - self.hidden_size = HIDDEN_SIZE - self.top_k = TOP_K - - # Load Yamoe kernel - # self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe") - self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0") - - # Expert weights - use the loaded weights - self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) - self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) - self.down_proj = nn.Parameter(down_proj.clone()) - self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) - - def forward(self, hidden_states): - batch_size, seq_len, hidden_dim = hidden_states.shape - - # Get routing decisions - routing_weights, router_indices = self.router(hidden_states) - - # Reshape for Yamoe kernel - hidden_states_flat = hidden_states.view(-1, hidden_dim) - routing_weights_flat = routing_weights.view(-1, self.num_experts) - expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts) - - # Call Yamoe optimized kernel - output = self.yamoe.experts( - hidden_states_flat, - router_indices, - routing_weights_flat, - self.gate_up_proj, - self.gate_up_proj_bias, - self.down_proj, - self.down_proj_bias, - expert_capacity, - self.num_experts, - self.top_k, - ) - - # Reshape output back - output = output.view(batch_size, seq_len, hidden_dim) - - return output, routing_weights - -# Run the model -set_seed(GENERAL_SEED) - -device = torch.device(DEVICE if DEVICE == "cuda" else "cuda") -dtype = to_dtype(DTYPE) - -print("\n=== Yamoe Implementation ===") -# Initialize model with loaded weights -model = YamoeMoEMLP( - router_weight.to(device), - router_bias.to(device), - gate_up_proj.to(device), - gate_up_proj_bias.to(device), - down_proj.to(device), - down_proj_bias.to(device) -).to(device=device) - -print(f"Router weight sum: {model.router.weight.sum().item():.6f}") -print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}") -print(f"Down proj sum: {model.down_proj.sum().item():.6f}") - -# Generate input -set_seed(INPUT_SEED) -x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 - -# Benchmark the model with varied inputs to prevent caching artifacts -tokens = BATCH_SIZE * SEQ_LEN -with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json", vary_inputs=True) as bench: - output, stats = bench(model, x) - print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/index.html b/moe_benchmarks/megablocks_yamoe/index.html deleted file mode 100644 index eb7b2fa8f6dfc6dacc0572fe71072184ac1d81ea..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/index.html +++ /dev/null @@ -1,25 +0,0 
@@ - - - - - Directory Index - - - -

Index of /moe_benchmarks/megablocks_yamoe

- - - \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html b/moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html deleted file mode 100644 index 52333e778e360a2f70f7e4fa369233a0b4523f61..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html +++ /dev/null @@ -1,3997 +0,0 @@ - - - - - - uvnote Integration Test Report - - - - - - - -
-
-
- -
-
- -
-
Generated on:
-
- Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36 -
-
- -
-

Comparison of Megablocks and Yamoe Kernels

-

This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.

-

Megablocks Kernel

-

Yamoe Kernel

-
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: setup | 16.96s | FAILED - | - -Raw -
-
-
-
-
-
# /// script
-# requires-python = ">=3.12"
-# dependencies = [
-#     "accelerate>=1.10.1",
-#     "torch>=2.7.0",
-#     "kernels==0.10.0",
-#     "transformers@https://github.com/huggingface/transformers.git",
-#     "ipdb>=0.13.13",
-#     "matplotlib>=3.7.2",
-#     "numpy>=1.24.3",
-# ]
-# ///
-
-import torch
-from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
-import time
-import torch.nn as nn
-from kernels import register_kernel_mapping, Mode, LayerRepository
-import sys
-import torch.profiler
-import gc
-import logging
-
-# set to debug logging
-logging.basicConfig(level=logging.INFO)
-
-def reset_peak_memory_stats():
-    """Clear CUDA cache and reset memory allocation counters."""
-    torch.cuda.empty_cache()
-    if torch.cuda.is_available():
-        torch.cuda.reset_peak_memory_stats()
-    gc.collect()
-
-def get_memory_stats():
-    """Get current and peak CUDA memory usage."""
-    if not torch.cuda.is_available():
-        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
-    return {
-        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
-        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
-        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
-    }
-
-def override_kernel_layer_name(cls_name: str, value) -> bool:
-    """Helper to dynamically override the kernel_layer_name in a model class."""
-    for mod in sys.modules.values():
-        if mod is None:
-            continue
-        obj = getattr(mod, cls_name, None)
-        if isinstance(obj, type) and issubclass(obj, nn.Module):
-            setattr(obj, "kernel_layer_name", value)
-            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
-            return True
-    return False
-
-
-# Init the model the normal way
-model_id = "openai/gpt-oss-20b"
-tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
-quantization_config = Mxfp4Config(dequantize=True)
-
-
-from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode
-
-from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm
-
-replace_kernel_forward_from_hub(GptOssMLP, "Yamoe")
-replace_kernel_forward_from_hub(GptOssRMSNorm, None)
-custom_mapping = {
-    "Yamoe": {
-        "cuda": {
-            Mode.INFERENCE: LayerRepository(
-                repo_id="drbh/yamoe",
-                layer_name="Yamoe",
-                revision="v0.3.0",
-            )
-        }
-    }
-}
-register_kernel_mapping(custom_mapping)
-
-
-model = GptOssForCausalLM.from_pretrained(
-    model_id,
-    dtype="bfloat16",
-    device_map="auto",
-    use_kernels=True,
-    quantization_config=quantization_config,
-).eval()
-
-messages = [
-    {"role": "system", "content": "What is Tensor Parallelism?"},
-]
-
-inputs = tokenizer.apply_chat_template(
-    messages,
-    add_generation_prompt=True,
-    return_tensors="pt",
-    return_dict=True,
-    reasoning_effort="low",
-).to("cuda")
-
-max_tokens = 256
-
-with torch.inference_mode():
-    start_time = time.perf_counter()
-    generated = model.generate(
-        **inputs,
-        max_new_tokens=max_tokens,
-        do_sample=False,
-        temperature=None,
-    )
-    end_time = time.perf_counter()
-
-print(tokenizer.decode(generated[0], skip_special_tokens=False))
-print(f"Generation took {end_time - start_time:.2f} seconds")
-
- -
-
-
-
-
-
warning: The requested interpreter resolved to Python 3.11.13, which is incompatible with the script's Python requirement: `>=3.12` - Updating https://github.com/huggingface/transformers.git (HEAD) - Updated https://github.com/huggingface/transformers.git (53838edde77cb10f3a360150aa85a457637e9ac3) - × No solution found when resolving script dependencies: - ╰─▶ Because only transformers==4.57.0.dev0 is available and - transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1, - we can conclude that all versions of transformers depend on - huggingface-hub==1.0.0rc1. - And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0, - we can conclude that kernels==0.10.0 and all versions of transformers - are incompatible. - And because you require kernels==0.10.0 and transformers, we can - conclude that your requirements are unsatisfiable. -
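The setup cell fails on dependency resolution: the transformers build pulled from git pins huggingface-hub==1.0.0rc1, while kernels==0.10.0 requires huggingface-hub&lt;1.0, so no solution exists. A minimal sketch of one possible workaround for the PEP 723 script header, assuming a released transformers version whose huggingface-hub requirement still overlaps with kernels==0.10.0 (the unpinned "transformers" entry here is illustrative, not a verified fix):

```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "accelerate>=1.10.1",
#     "torch>=2.7.0",
#     "kernels==0.10.0",
#     # Assumption: use a released transformers instead of the git HEAD so its
#     # huggingface-hub pin stays below 1.0 and remains compatible with kernels==0.10.0.
#     "transformers",
#     "ipdb>=0.13.13",
#     "matplotlib>=3.7.2",
#     "numpy>=1.24.3",
# ]
# ///

print("resolver sanity check only; model code unchanged")
```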
-
-
-
- - - \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/torch_profile.html b/moe_benchmarks/megablocks_yamoe/torch_profile.html deleted file mode 100644 index 984f73501b96fe8890308b7db3afdaba7529dee5..0000000000000000000000000000000000000000 --- a/moe_benchmarks/megablocks_yamoe/torch_profile.html +++ /dev/null @@ -1,6956 +0,0 @@ - - - - - - Compare Yamoe and Binned MoE Implementations - - - - - - - -
-
-
- -
-
- -
-
Generated on:
-
- Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36 -
-
- -
-
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: utils | deps: torch, numpy | 34.47s - | - -Raw -
-
-
-
-
-
"""Simple utilities for running the models."""
-import torch
-
-def to_dtype(dtype_str: str):
-    """Convert string to torch dtype."""
-    if dtype_str == "float16":
-        return torch.float16
-    if dtype_str == "bfloat16":
-        return torch.bfloat16
-    return torch.float32
-
-def tensor_stats(t: torch.Tensor) -> str:
-    """Generate stats string for a tensor."""
-    return (f"shape={tuple(t.shape)}, "
-            f"dtype={t.dtype}, "
-            f"device={t.device}, "
-            f"mean={t.mean().item():.6f}, "
-            f"std={t.std().item():.6f}")
-
-def set_seed(seed: int):
-    """Set seeds for reproducibility."""
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
- -
-
-
-
-
-
-
▶ UV Install Logs
- -
-
-
- -
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: bench_utils | deps: torch, numpy | 34.94s - | - -Raw -
-
-
-
-
-
"""Reusable benchmarking utilities for performance testing."""
-import time
-import numpy as np
-from contextlib import contextmanager
-from typing import Callable, Dict, Tuple, Any, Optional
-import torch
-
-def to_dtype(dtype_str: str):
-    """Convert string to torch dtype."""
-    if dtype_str == "float16":
-        return torch.float16
-    if dtype_str == "bfloat16":
-        return torch.bfloat16
-    return torch.float32
-
-def _sync(device: str):
-    """Synchronize device if CUDA."""
-    if device == "cuda":
-        torch.cuda.synchronize()
-
-def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
-    """Compute comprehensive latency and throughput statistics."""
-    lat_ms = np.array([t * 1000.0 for t in times_s])
-    lat_ms_sorted = np.sort(lat_ms)
-    n = len(lat_ms)
-
-    stats = {
-        "avg_ms": np.mean(lat_ms),
-        "min_ms": np.min(lat_ms),
-        "max_ms": np.max(lat_ms),
-        "std_ms": np.std(lat_ms),
-        "p50_ms": np.percentile(lat_ms, 50),
-        "p95_ms": np.percentile(lat_ms, 95),
-        "p99_ms": np.percentile(lat_ms, 99),
-        "num_iters": n
-    }
-
-    if tokens is not None and n > 0:
-        avg_s = np.mean(times_s)
-        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
-        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])
-
-    return stats
-
-def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
-    """Format timing statistics for display."""
-    lines = [
-        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
-        f"Iterations: {stats.get('num_iters', 0)}",
-        "\nLatency Statistics:",
-        f"  Average: {stats['avg_ms']:.3f} ms",
-        f"  Min:     {stats['min_ms']:.3f} ms",
-        f"  Max:     {stats['max_ms']:.3f} ms", 
-        f"  Std Dev: {stats['std_ms']:.3f} ms",
-        "\nPercentiles:",
-        f"  P50 (median): {stats['p50_ms']:.3f} ms",
-        f"  P95:          {stats['p95_ms']:.3f} ms",
-        f"  P99:          {stats['p99_ms']:.3f} ms",
-    ]
-
-    if tokens is not None and 'tokens_per_s' in stats:
-        lines.extend([
-            "\nThroughput:",
-            f"  Tokens/sec: {stats['tokens_per_s']:.1f}",
-            f"  Std Dev:    {stats.get('throughput_variance', 0):.1f}",
-        ])
-
-    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
-    return "\n".join(lines)
-
-def _bench_engine(
-    call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] = None
-) -> Tuple[Any, list]:
-    """Core benchmarking engine with warmup and timing."""
-    use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
-
-    # Warmup phase
-    print(f"\nWarming up ({warmup} iterations)...")
-    with torch.inference_mode():
-        for _ in range(max(0, warmup)):
-            if use_autocast:
-                with torch.autocast(device_type="cuda", dtype=dtype):
-                    if input_gen is not None:
-                        _ = call(input_gen())
-                    else:
-                        _ = call()
-            else:
-                if input_gen is not None:
-                    _ = call(input_gen())
-                else:
-                    _ = call()
-        _sync(device)
-
-    # Measurement phase
-    print(f"Benchmarking ({iters} iterations)...")
-    times_s = []
-    last = None
-    with torch.inference_mode():
-        for i in range(max(1, iters)):
-            start = time.perf_counter()
-            if use_autocast:
-                with torch.autocast(device_type="cuda", dtype=dtype):
-                    if input_gen is not None:
-                        last = call(input_gen())
-                    else:
-                        last = call()
-            else:
-                if input_gen is not None:
-                    last = call(input_gen())
-                else:
-                    last = call()
-            _sync(device)
-            end = time.perf_counter()
-            times_s.append(end - start)
-
-            # Progress indicator every 20% of iterations
-            if i > 0 and i % max(1, iters // 5) == 0:
-                pct = (i / iters) * 100
-                avg_so_far = np.mean(times_s[:i]) * 1000
-                print(f"  Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")
-
-    return last, times_s
-
-def tensor_stats(t: torch.Tensor) -> str:
-    """Generate comprehensive stats string for a tensor."""
-    return (f"shape={tuple(t.shape)}, "
-            f"dtype={t.dtype}, "
-            f"device={t.device}, "
-            f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
-            f"mean={t.mean().item():.6f}, "
-            f"std={t.std().item():.6f}, "
-            f"norm={t.norm().item():.6f}")
-
-@contextmanager
-def bench_context(
-    *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True
-):
-    """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).
-
-    If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration
-    by adding a small deterministic increment to prevent caching artifacts.
-    """
-
-    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
-        # Log configuration
-        if verbose:
-            print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐")
-            # print(f"│ Device: {device:<15} Dtype: {dtype}              │")
-            print(f"│ Warmup: {warmup:<15} Iters: {iters}              │")
-            if tokens:
-                print(f"│ Tokens: {tokens}                                        │")
-            if vary_inputs:
-                print(f"│ Input Variation: Enabled (prevents caching artifacts)  │")
-            print(f"└────────────────────────────────────────────────────────┘")
-
-        # Set up input generation
-        input_gen = None
-        if vary_inputs and args and isinstance(args[0], torch.Tensor):
-            base_input = args[0].clone()
-            iteration_counter = [0]  # Use list for mutable closure
-
-            def generate_varied_input():
-                """Generate input tensor varied by iteration to prevent caching."""
-                # Add small deterministic increment: 0.001 * iteration_number
-                varied_input = base_input + (iteration_counter[0] * 0.001)
-                iteration_counter[0] += 1
-                return varied_input
-
-            input_gen = generate_varied_input
-            call = lambda x: fn(x, *args[1:], **kwargs)
-
-            # Log base input stats
-            if verbose:
-                print(f"\nBase Input: {tensor_stats(base_input)}")
-                print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
-        else:
-            # Legacy mode - static inputs
-            call = lambda: fn(*args, **kwargs)
-            if verbose and args and isinstance(args[0], torch.Tensor):
-                print(f"\nInput: {tensor_stats(args[0])}")
-
-        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)
-
-        # Log output if it's a tensor or tuple with tensors
-        if verbose:
-            print("\nOutput tensors:")
-            if isinstance(result, torch.Tensor):
-                print(f"  Primary: {tensor_stats(result)}")
-            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
-                print(f"  Primary: {tensor_stats(result[0])}")
-                if len(result) > 1:
-                    if isinstance(result[1], torch.Tensor):
-                        print(f"  Auxiliary: {tensor_stats(result[1])}")
-                    else:
-                        print(f"  Auxiliary: {type(result[1]).__name__}")
-
-        # Compute and display statistics
-        stats = _compute_stats(times_s, tokens=tokens)
-        if verbose:
-            print(_format_timing_stats(stats, tokens))
-
-        # Save to JSON if requested
-        if save_json:
-            import json
-            json_data = {
-                "implementation": save_json.replace(".json", ""),
-                "config": {
-                    "warmup": warmup,
-                    "iters": iters,
-                    "device": str(device),  # Convert device to string
-                    "dtype": str(dtype),
-                    "tokens": tokens,
-                    "vary_inputs": vary_inputs
-                },
-                "stats": stats,
-                "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None
-            }
-            with open(save_json, 'w') as f:
-                json.dump(json_data, f, indent=2)
-            if verbose:
-                print(f"\nSaved benchmark results to {save_json}")
-
-        return result, stats
-
-    yield runner
-
-def set_seed(seed: int):
-    """Set seeds for reproducibility."""
-    torch.manual_seed(seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed)
-        torch.cuda.manual_seed_all(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
- -
-
-
-
-
-
-
▶ UV Install Logs
- -
-
-
- -

This notebook benchmarks multiple MoE implementations with varied inputs across iterations to prevent unrealistic caching artifacts and measure true performance characteristics.

-
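As a minimal sketch of the input-variation scheme implemented in `bench_context` above: each timed iteration adds a small deterministic offset to the base tensor, so no two calls see an identical input and cached results cannot inflate throughput.

```python
import torch

# Same shape as the benchmark input (BATCH_SIZE=1, SEQ_LEN=100, HIDDEN_SIZE=1152).
base_input = torch.randn(1, 100, 1152) * 0.1

def varied_input(iteration: int) -> torch.Tensor:
    # bench_context adds +0.001 * iteration each call (deterministic, reproducible).
    return base_input + iteration * 0.001

x7 = varied_input(7)  # e.g. the tensor seen by the 8th timed call
```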
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: config | deps: torch, numpy | 35.62s - | - -Raw -
-
-
-
-
-
"""Shared configuration for both implementations."""
-import torch
-
-# Model configuration
-NUM_EXPERTS = 128
-HIDDEN_SIZE = 1152
-INTERMEDIATE_SIZE = 3072
-TOP_K = 4
-
-# Input configuration
-BATCH_SIZE = 1
-SEQ_LEN = 100
-DTYPE = "float32"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Seeds for reproducibility
-WEIGHT_SEED = 999
-EXPERT_SEED = 777
-INPUT_SEED = 123
-GENERAL_SEED = 42
-
- -
-
-
-
-
-
-
▶ UV Install Logs
- -
-
-
- -
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: save_data | deps: torch, numpy | 39.76s - | - -Raw -
-
-
-
-
-
"""
-Generate deterministic shared weights once and save as artifacts so
-both implementations load identical parameters.
-"""
-import torch
-from config import NUM_EXPERTS, HIDDEN_SIZE, WEIGHT_SEED, EXPERT_SEED
-
-def save_shared_weights():
-    # Router: Kaiming uniform as used by both, bias zeros
-    torch.manual_seed(WEIGHT_SEED)
-    router_weight = torch.empty(NUM_EXPERTS, HIDDEN_SIZE)
-    torch.nn.init.kaiming_uniform_(router_weight)
-    router_bias = torch.zeros(NUM_EXPERTS)
-
-    # Experts: normal(0, 0.02), biases zeros
-    torch.manual_seed(EXPERT_SEED)
-    gate_up_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, 2 * HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-    gate_up_proj_bias = torch.zeros(NUM_EXPERTS, 2 * HIDDEN_SIZE)
-    down_proj = torch.empty(NUM_EXPERTS, HIDDEN_SIZE, HIDDEN_SIZE).normal_(mean=0.0, std=0.02)
-    down_proj_bias = torch.zeros(NUM_EXPERTS, HIDDEN_SIZE)
-
-    # Save artifacts
-    torch.save(router_weight, 'router_weight.pt')
-    torch.save(router_bias, 'router_bias.pt')
-    torch.save(gate_up_proj, 'gate_up_proj.pt')
-    torch.save(gate_up_proj_bias, 'gate_up_proj_bias.pt')
-    torch.save(down_proj, 'down_proj.pt')
-    torch.save(down_proj_bias, 'down_proj_bias.pt')
-
-    print("Saved shared weights to artifacts")
-    print(f"Router weight sum: {router_weight.sum().item():.6f}")
-    print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-    print(f"Down sum: {down_proj.sum().item():.6f}")
-
-save_shared_weights()
-
- -
-
-
-
-
-
Saved shared weights to artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 -
-
-
▶ UV Install Logs
- -
- -
-
- -

Yamoe Implementation

-

This section runs the Yamoe MoE implementation with optimized Triton kernels.

-
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: yamoe_run | deps: torch, kernels, numpy | 38.79s - | - -Raw -
-
-
-
-
-
import torch
-from torch import nn
-from torch.nn import functional as F
-from kernels import get_kernel, get_local_kernel
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-print(f"Loading weights from: {data_dir}")
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class YamoeRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-class YamoeMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = YamoeRouter(router_weight, router_bias)
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.top_k = TOP_K
-
-        # Load Yamoe kernel
-        # self.yamoe = get_local_kernel(Path("/home/ubuntu/Projects/yamoe/result"), "yamoe")
-        self.yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
-
-        # Expert weights - use the loaded weights
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-
-    def forward(self, hidden_states):
-        batch_size, seq_len, hidden_dim = hidden_states.shape
-
-        # Get routing decisions
-        routing_weights, router_indices = self.router(hidden_states)
-
-        # Reshape for Yamoe kernel
-        hidden_states_flat = hidden_states.view(-1, hidden_dim)
-        routing_weights_flat = routing_weights.view(-1, self.num_experts)
-        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
-
-        # Call Yamoe optimized kernel
-        output = self.yamoe.experts(
-            hidden_states_flat,
-            router_indices,
-            routing_weights_flat,
-            self.gate_up_proj,
-            self.gate_up_proj_bias,
-            self.down_proj,
-            self.down_proj_bias,
-            expert_capacity,
-            self.num_experts,
-            self.top_k,
-        )
-
-        # Reshape output back
-        output = output.view(batch_size, seq_len, hidden_dim)
-
-        return output, routing_weights
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE if DEVICE == "cuda" else "cuda")
-dtype = to_dtype(DTYPE)
-
-print("\n=== Yamoe Implementation ===")
-# Initialize model with loaded weights
-model = YamoeMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
-
-# Generate input
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="yamoe_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
- -
-
-
-
-
-
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/b398a2853af91970392ae37f0d53a0eda463df639220863fbd38f33605bf9cbb -Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== Yamoe Implementation === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 4.251 ms) - Progress: 40% complete (avg: 4.249 ms) - Progress: 60% complete (avg: 4.244 ms) - Progress: 80% complete (avg: 4.246 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 4.246 ms - Min: 4.067 ms - Max: 4.294 ms - Std Dev: 0.034 ms - -Percentiles: - P50 (median): 4.253 ms - P95: 4.267 ms - P99: 4.287 ms - -Throughput: - Tokens/sec: 23549.4 - Std Dev: 193.2 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to yamoe_results.json - -Output sum: 3.971905 -
-
-
▶ UV Install Logs
- -
-
Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s] -Fetching 6 files: 17%|█▋ | 1/6 [00:00<00:01, 3.75it/s] -Fetching 6 files: 50%|█████ | 3/6 [00:00<00:00, 3.39it/s] -Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 6.84it/s]
-
-

Artifacts:

-yamoe_results.json -
-
-
- -

Binned Implementation

-

This section runs the binned implementation that manually handles token gathering/scattering.

-
-
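Before the full cell below, here is a minimal sketch of the sort/bin step the reference implementation relies on: flatten the top-k expert indices, sort them, and take a cumulative bincount so each expert owns a contiguous slice of token slots. Shapes and values here are illustrative only.

```python
import torch

num_experts = 8
router_indices = torch.tensor([[1, 3], [3, 5], [0, 1]])  # (tokens, top_k), illustrative

flat = router_indices.flatten()                      # one entry per (token, slot)
sorted_vals, sorted_pos = torch.sort(flat)           # group assignments by expert id
counts = torch.bincount(sorted_vals, minlength=num_experts)
bins = torch.cumsum(counts, dim=0)                   # bins[e] = end offset of expert e

# Expert e's assignments live in sorted_pos[start:end] with
# start = 0 if e == 0 else bins[e - 1] and end = bins[e].
# binned_gather copies those tokens into an (E, capacity, H) buffer, and
# binned_scatter writes the expert outputs back, scaled by the routing weights.
print(bins)
```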
- -▼ code -▼ output - ▶ uv-logs - | -Cell: binned_run | deps: torch, numpy | 39.53s - | - -Raw -
-
-
-
-
-
import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-def binned_gather(x, indices, bins, expert_capacity, top_k):
-    E, H = bins.shape[0], x.shape[1]
-    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
-    for e in range(E):
-        start = 0 if e == 0 else bins[e - 1]
-        end = bins[e]
-        n = min(end - start, expert_capacity)
-        for i in range(n):
-            flat_pos = indices[start + i]
-            tok = flat_pos // top_k
-            out[e, i] = x[tok]
-    return out
-
-def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
-    E, C, H = x.shape
-    N = indices.shape[0] // top_k
-    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
-    for e in range(E):
-        start = 0 if e == 0 else bins[e - 1]
-        end = bins[e]
-        n = end - start
-        if n == 0:
-            continue
-        take = min(n, expert_capacity)
-        for i in range(take):
-            flat_pos = indices[start + i]
-            tok = flat_pos // top_k
-            slot = flat_pos % top_k
-            scale = weights[flat_pos] if weights is not None else 1.0
-            out[tok, slot] = x[e, i] * scale
-    return out.sum(dim=1)
-
-def sort_tokens_by_expert(router_indices, num_experts):
-    flat_indices = router_indices.flatten()
-    sorted_values, sorted_indices = torch.sort(flat_indices)
-    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
-    bins = torch.cumsum(tokens_per_expert, dim=0)
-    return sorted_indices, sorted_values, bins, tokens_per_expert
-
-def binned_experts_ref(
-    hidden_states,
-    router_indices,
-    routing_weights,
-    gate_up_proj,
-    gate_up_proj_bias,
-    down_proj,
-    down_proj_bias,
-    expert_capacity,
-):
-    B, S, H = hidden_states.shape
-    E, K = routing_weights.shape[1], router_indices.shape[1]
-
-    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
-    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
-
-    gate_up = torch.bmm(x, gate_up_proj) 
-    gate_up += gate_up_proj_bias[..., None, :]
-
-    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-
-    # clamp to limit
-    limit = 7.0
-    gate = gate.clamp(min=None, max=limit)
-    up = up.clamp(min=-limit, max=limit)
-
-    glu = gate * torch.sigmoid(gate * 1.702)
-    x = (up + 1) * glu
-    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
-
-    # build routing weights aligned to (token, slot)
-    flat_dense = routing_weights.view(-1, E)
-    flat_router = router_indices.view(-1, K)
-    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
-
-    # scatter back
-    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
-
-    return y.view(B, S, H)
-
-class BinnedRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-def ceil_div(a, b):
-    return (a + b - 1) // b
-
-class BinnedMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = BinnedRouter(router_weight, router_bias)
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.top_k = TOP_K
-
-        # Expert weights - use the loaded weights
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        batch_size = hidden_states.shape[0]
-        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
-
-        output = binned_experts_ref(
-            hidden_states,
-            router_indices,
-            router_scores,
-            self.gate_up_proj,
-            self.gate_up_proj_bias,
-            self.down_proj,
-            self.down_proj_bias,
-            expert_capacity,
-        )
-
-        return output, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== Binned Implementation ===")
-# Initialize model with loaded weights
-model = BinnedMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
-
-# Generate the same input as Yamoe
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
- -
-
-
-
-
-
Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== Binned Implementation === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 37.247 ms) - Progress: 40% complete (avg: 37.082 ms) - Progress: 60% complete (avg: 36.706 ms) - Progress: 80% complete (avg: 36.240 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 35.794 ms - Min: 33.227 ms - Max: 37.582 ms - Std Dev: 1.261 ms - -Percentiles: - P50 (median): 36.036 ms - P95: 37.377 ms - P99: 37.525 ms - -Throughput: - Tokens/sec: 2793.8 - Std Dev: 99.7 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to binned_results.json - -Output sum: 3.971905 -
-
-
▶ UV Install Logs
- -
-
-

Artifacts:

-binned_results.json -
-
-
- -

GPT-OSS Implementation

-

This section runs the reference GPT-OSS MoE implementation. On CUDA in eval mode it takes the batched (bmm) expert path; the manual per-expert loop only runs on CPU or in training mode, which the next section forces explicitly.

-
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: gptoss_run | deps: torch, numpy | 40.29s - | - -Raw -
-
-
-
-
-
import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class GptOssRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-class GptOssExperts(nn.Module):
-    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.expert_dim = self.hidden_size
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-        self.alpha = 1.702
-        self.limit = 7.0
-
-    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        if hidden_states.device.type == "cpu" or self.training:
-            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-            with torch.no_grad():
-                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-                expert_mask = expert_mask.permute(2, 1, 0)
-                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-            for expert_idx in expert_hit[:]:
-                expert_idx = expert_idx[0]
-                with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx])
-                current_state = hidden_states[token_idx]
-                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-                gate = gate.clamp(min=None, max=self.limit)
-                up = up.clamp(min=-self.limit, max=self.limit)
-                glu = gate * torch.sigmoid(gate * self.alpha)
-                gated_output = (up + 1) * glu
-                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                weighted_output = out * routing_weights[token_idx, expert_idx, None]
-                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-            next_states = next_states.view(batch_size, -1, self.hidden_size)
-        else:
-            hidden_states = hidden_states.repeat(num_experts, 1)
-            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
-            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=self.limit)
-            up = up.clamp(min=-self.limit, max=self.limit)
-            glu = gate * torch.sigmoid(gate * self.alpha)
-            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
-            next_states = next_states + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
-            next_states = next_states.sum(dim=0)
-        return next_states
-
-class GptOssMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = GptOssRouter(router_weight, router_bias)
-        self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-        return routed_out, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== GPT-OSS Implementation ===")
-# Initialize model with loaded weights
-model = GptOssMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
- -
-
-
-
-
-
Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== GPT-OSS Implementation === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 48.814 ms) - Progress: 40% complete (avg: 48.182 ms) - Progress: 60% complete (avg: 47.686 ms) - Progress: 80% complete (avg: 46.880 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 45.933 ms - Min: 40.056 ms - Max: 49.512 ms - Std Dev: 2.471 ms - -Percentiles: - P50 (median): 46.499 ms - P95: 49.058 ms - P99: 49.353 ms - -Throughput: - Tokens/sec: 2177.1 - Std Dev: 121.3 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to gptoss_results.json - -Output sum: 11.532237 -
-
-
▶ UV Install Logs
- -
-
-

Artifacts:

-gptoss_results.json -
-
-
- -

GPT-OSS Implementation (Training Mode)

-

This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.

-
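The loop is forced because the reference experts dispatch on device type and training state; a minimal sketch of that check is below (the `GptOssTrainingExperts` class in the cell simply inlines the loop branch, so it exercises the per-expert path even in eval mode on CUDA).

```python
def uses_expert_loop(device_type: str, training: bool) -> bool:
    """Mirror of the dispatch check in GptOssExperts.forward above."""
    return device_type == "cpu" or training

print(uses_expert_loop("cuda", training=False))  # False -> batched bmm path
print(uses_expert_loop("cuda", training=True))   # True  -> per-expert loop (this section)
```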
-
- -▼ code -▼ output - ▶ uv-logs - | -Cell: gptoss_training_run | deps: torch, numpy | 39.76s - | - -Raw -
-
-
-
-
-
import torch
-from torch import nn
-from torch.nn import functional as F
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-class GptOssTrainingRouter(nn.Module):
-    def __init__(self, router_weight, router_bias):
-        super().__init__()
-        self.top_k = TOP_K
-        self.num_experts = NUM_EXPERTS
-        self.hidden_dim = HIDDEN_SIZE
-        self.weight = nn.Parameter(router_weight.clone())
-        self.bias = nn.Parameter(router_bias.clone())
-
-    def forward(self, hidden_states):
-        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = F.linear(hidden_states, self.weight, self.bias)
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
-        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
-
-class GptOssTrainingExperts(nn.Module):
-    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.num_experts = NUM_EXPERTS
-        self.hidden_size = HIDDEN_SIZE
-        self.expert_dim = self.hidden_size
-        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
-        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
-        self.down_proj = nn.Parameter(down_proj.clone())
-        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
-        self.alpha = 1.702
-        self.limit = 7.0
-
-    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-        hidden_states = hidden_states.reshape(-1, self.hidden_size)
-        num_experts = routing_weights.shape[1]
-
-        # Force training mode path (expert loop instead of batched)
-        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
-        with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
-            expert_mask = expert_mask.permute(2, 1, 0)
-            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
-
-        for expert_idx in expert_hit[:]:
-            expert_idx = expert_idx[0]
-            with torch.no_grad():
-                _, token_idx = torch.where(expert_mask[expert_idx])
-            current_state = hidden_states[token_idx]
-            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
-            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
-            gate = gate.clamp(min=None, max=self.limit)
-            up = up.clamp(min=-self.limit, max=self.limit)
-            glu = gate * torch.sigmoid(gate * self.alpha)
-            gated_output = (up + 1) * glu
-            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-            weighted_output = out * routing_weights[token_idx, expert_idx, None]
-            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
-        next_states = next_states.view(batch_size, -1, self.hidden_size)
-        return next_states
-
-class GptOssTrainingMoEMLP(nn.Module):
-    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
-        super().__init__()
-        self.router = GptOssTrainingRouter(router_weight, router_bias)
-        self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
-
-    def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
-        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
-        return routed_out, router_scores
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
-# Initialize model with loaded weights and force training mode
-model = GptOssTrainingMoEMLP(
-    router_weight.to(device),
-    router_bias.to(device),
-    gate_up_proj.to(device),
-    gate_up_proj_bias.to(device),
-    down_proj.to(device),
-    down_proj_bias.to(device)
-).to(device=device)
-
-# Set to training mode to force expert loop path
-model.train()
-
-print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
-print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
-print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
-print(f"Model training mode: {model.training}")
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
- -
-
-
-
-
-
Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== GPT-OSS Implementation (Training Mode - Expert Loop) === -Router weight sum: 12.588732 -Gate/up proj sum: 1026.601807 -Down proj sum: 206.729340 -Model training mode: True - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 50.744 ms) - Progress: 40% complete (avg: 50.240 ms) - Progress: 60% complete (avg: 48.683 ms) - Progress: 80% complete (avg: 47.222 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560 - Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 45.947 ms - Min: 38.690 ms - Max: 51.193 ms - Std Dev: 3.915 ms - -Percentiles: - P50 (median): 45.209 ms - P95: 51.039 ms - P99: 51.144 ms - -Throughput: - Tokens/sec: 2176.4 - Std Dev: 188.8 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to gptoss_training_results.json - -Output sum: 11.532237 -
-
-
- -
- -
-
- -

MegaBlocks Implementation

-

This section runs the MegaBlocks MoE implementation using the optimized kernels published on the Hugging Face Hub (kernels-community/megablocks).
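For orientation, here is a minimal sketch of the kernel download step, using the same calls the cell below makes (run outside the benchmark; it requires network access to the Hub):

from kernels import get_kernel

# Fetches the MegaBlocks kernel package from the Hugging Face Hub and exposes
# its Python modules; layers.MegaBlocksMoeMLP is the MoE layer used below.
megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
mlp = megablocks.layers.MegaBlocksMoeMLP()
print(type(mlp))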

-
-
-Cell: megablocks_run | deps: torch, numpy, kernels | 47.11s
-
-
-
-
-
import torch
-from torch import nn
-from torch.nn import functional as F
-from kernels import get_kernel
-from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
-from config import (
-    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
-    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
-    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
-)
-from pathlib import Path
-from collections import namedtuple
-import os
-
-# Discover the upstream artifact directory from env
-data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
-
-print(f"Loading weights from: {data_dir}")
-
-router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
-router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
-gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
-gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
-down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
-down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
-
-print("Loaded shared weights from artifacts")
-print(f"Router weight sum: {router_weight.sum().item():.6f}")
-print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
-print(f"Down sum: {down_proj.sum().item():.6f}")
-
-def build_megablocks_model(device: torch.device):
-    # Download optimized kernels from the Hugging Face hub
-    megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
-    model = megablocks.layers.MegaBlocksMoeMLP()
-
-    # Create attribute container for expert weights
-    model.experts = namedtuple(
-        "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
-    )
-
-    # Use loaded router weights for consistency
-    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
-    with torch.no_grad():
-        model.router.weight.copy_(router_weight)
-        model.router.bias.copy_(router_bias)
-
-    # Attach loaded expert weights to the experts container
-    e = model.experts
-    e.alpha = 1.702
-    e.capacity_factor = 32
-    e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
-    e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
-    e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
-    e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
-    e.hidden_size = HIDDEN_SIZE
-
-    # Log weight statistics for comparison
-    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
-    print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
-    print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
-
-    return model
-
-# Create a wrapper to match the interface of other implementations
-class MegaBlocksMoEWrapper(nn.Module):
-    def __init__(self, megablocks_model):
-        super().__init__()
-        self.model = megablocks_model
-
-    def forward(self, hidden_states):
-        # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
-        output, dummy_routing_weights = self.model(hidden_states)
-        return output, dummy_routing_weights
-
-# Run the model
-set_seed(GENERAL_SEED)
-
-device = torch.device(DEVICE)
-dtype = to_dtype(DTYPE)
-
-print("\n=== MegaBlocks Implementation ===")
-# Build MegaBlocks model with loaded weights
-megablocks_model = build_megablocks_model(device)
-model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
-
-# Generate the same input as other implementations
-set_seed(INPUT_SEED)
-x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
-
-# Benchmark the model with varied inputs to prevent caching artifacts
-tokens = BATCH_SIZE * SEQ_LEN
-with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
-    output, stats = bench(model, x)
-    print(f"\nOutput sum: {output[0].sum().item():.6f}")
-
- -
-
-
-
-
-
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/b398a2853af91970392ae37f0d53a0eda463df639220863fbd38f33605bf9cbb -Loaded shared weights from artifacts -Router weight sum: 12.588732 -Gate/up sum: 1026.601807 -Down sum: 206.729263 - -=== MegaBlocks Implementation === -[MegaBlocks] Router weight sum: 12.588732 -[MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807 -[MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340 - -┌─ Benchmark Configuration ─────────────────────────────┐ -│ Warmup: 10 Iters: 50 │ -│ Tokens: 100 │ -│ Input Variation: Enabled (prevents caching artifacts) │ -└────────────────────────────────────────────────────────┘ - -Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 -Input Variation: +0.001 * iteration (deterministic) - -Warming up (10 iterations)... -Benchmarking (50 iterations)... - Progress: 20% complete (avg: 0.855 ms) - Progress: 40% complete (avg: 0.840 ms) - Progress: 60% complete (avg: 0.838 ms) - Progress: 80% complete (avg: 2.699 ms) - -Output tensors: - Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.061104, 0.055115], mean=0.000056, std=0.013535, norm=4.593927 - Auxiliary: shape=(100, 4), dtype=torch.float32, device=cuda:0, range=[0.220999, 0.302948], mean=0.250000, std=0.012156, norm=5.005893 - -━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ -Iterations: 50 - -Latency Statistics: - Average: 3.848 ms - Min: 0.812 ms - Max: 8.536 ms - Std Dev: 3.698 ms - -Percentiles: - P50 (median): 0.839 ms - P95: 8.500 ms - P99: 8.529 ms - -Throughput: - Tokens/sec: 25988.6 - Std Dev: 53035.4 -━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ - -Saved benchmark results to megablocks_results.json - -Output sum: 6.473885 -
-
-
- -
-
Fetching 66 files: 100%|██████████| 66/66 [00:01<00:00, 37.06it/s]
-
-

Artifacts:

-megablocks_results.json -
-
-
- -

Performance Visualization

-

This section reads all benchmark results and creates a comprehensive performance comparison chart.
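Each results file is expected to follow the layout consumed by the plotting cell below: an "implementation" name plus a "stats" dict containing avg_ms, p95_ms, and tokens_per_s. A quick standalone check (a sketch that assumes the JSON file sits in the working directory; in this notebook the paths are resolved via the UVNOTE_INPUT_* environment variables as shown below):

import json

# Load one result file and print the fields the comparison chart relies on.
with open("megablocks_results.json") as f:   # any of the *_results.json files
    data = json.load(f)
stats = data["stats"]
print(data["implementation"], stats["avg_ms"], stats["p95_ms"], stats.get("tokens_per_s", 0))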

-
-
-Cell: visualization | deps: matplotlib | 3.13s
-
-
-
-
-
import json
-import matplotlib.pyplot as plt
-import numpy as np
-from pathlib import Path
-import os
-
-# List of expected result files
-yamoe_dir = os.environ.get('UVNOTE_INPUT_YAMOE_RUN', '.')
-binned_dir = os.environ.get('UVNOTE_INPUT_BINNED_RUN', '.')
-gptoss_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_RUN', '.')
-gptoss_training_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_TRAINING_RUN', '.')
-megablocks_dir = os.environ.get('UVNOTE_INPUT_MEGABLOCKS_RUN', '.')
-
-result_files = [
-    Path(yamoe_dir) / "yamoe_results.json",
-    Path(binned_dir) / "binned_results.json", 
-    Path(gptoss_dir) / "gptoss_results.json",
-    Path(gptoss_training_dir) / "gptoss_training_results.json",
-    Path(megablocks_dir) / "megablocks_results.json"
-]
-
-# Load all benchmark results
-results = {}
-for file in result_files:
-    if Path(file).exists():
-        with open(file, 'r') as f:
-            data = json.load(f)
-            results[data['implementation']] = data
-        print(f"Loaded {file}")
-    else:
-        print(f"Missing {file}")
-
-if not results:
-    print("No benchmark results found. Run the benchmark cells first.")
-else:
-    # Extract data for plotting
-    implementations = list(results.keys())
-    avg_latencies = [results[impl]['stats']['avg_ms'] for impl in implementations]
-    p95_latencies = [results[impl]['stats']['p95_ms'] for impl in implementations]
-    throughputs = [results[impl]['stats'].get('tokens_per_s', 0) for impl in implementations]
-
-    # Create figure with subplots
-    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
-    fig.suptitle('MoE Implementation Performance Comparison', fontsize=16, fontweight='bold')
-
-    # Colors for each implementation
-    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57'][:len(implementations)]
-
-    # 1. Average Latency Chart
-    bars1 = ax1.bar(implementations, avg_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
-    ax1.set_title('Average Latency', fontweight='bold', fontsize=14)
-    ax1.set_ylabel('Latency (ms)', fontweight='bold')
-    ax1.tick_params(axis='x', rotation=45)
-    ax1.grid(axis='y', alpha=0.3)
-
-    # Add value labels on bars
-    for bar, val in zip(bars1, avg_latencies):
-        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(avg_latencies)*0.01,
-                f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')
-
-    # 2. P95 Latency Chart
-    bars2 = ax2.bar(implementations, p95_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
-    ax2.set_title('95th Percentile Latency', fontweight='bold', fontsize=14)
-    ax2.set_ylabel('Latency (ms)', fontweight='bold')
-    ax2.tick_params(axis='x', rotation=45)
-    ax2.grid(axis='y', alpha=0.3)
-
-    # Add value labels on bars
-    for bar, val in zip(bars2, p95_latencies):
-        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(p95_latencies)*0.01,
-                f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')
-
-    # 3. Throughput Chart
-    bars3 = ax3.bar(implementations, throughputs, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
-    ax3.set_title('Throughput', fontweight='bold', fontsize=14)
-    ax3.set_ylabel('Tokens/sec', fontweight='bold')
-    ax3.tick_params(axis='x', rotation=45)
-    ax3.grid(axis='y', alpha=0.3)
-
-    # Add value labels on bars
-    for bar, val in zip(bars3, throughputs):
-        if val > 0:  # Only show label if throughput was calculated
-            ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(throughputs)*0.01,
-                    f'{val:.0f}', ha='center', va='bottom', fontweight='bold')
-
-    plt.tight_layout()
-    plt.savefig("moe_performance_comparison.png", dpi=300)
-
-    # Print summary table
-    print("\nPerformance Summary:")
-    print(f"{'Implementation':<30} {'Avg (ms)':<12} {'P95 (ms)':<12} {'Tokens/sec':<12} {'Relative Speed':<15}")
-    print("-"*80)
-
-    # Sort by average latency for relative speed calculation
-    sorted_results = sorted(results.items(), key=lambda x: x[1]['stats']['avg_ms'])
-    fastest_latency = sorted_results[0][1]['stats']['avg_ms']
-
-    for impl, data in sorted_results:
-        avg_ms = data['stats']['avg_ms']
-        p95_ms = data['stats']['p95_ms']
-        tokens_s = data['stats'].get('tokens_per_s', 0)
-        relative_speed = fastest_latency / avg_ms
-
-        print(f"{impl:<30} {avg_ms:>8.2f}    {p95_ms:>8.2f}    {tokens_s:>8.0f}      {relative_speed:>6.2f}x")
-
-    print(f"\nFastest: {sorted_results[0][0]} ({sorted_results[0][1]['stats']['avg_ms']:.2f}ms avg)")
-    if len(sorted_results) > 1:
-        print(f"Slowest: {sorted_results[-1][0]} ({sorted_results[-1][1]['stats']['avg_ms']:.2f}ms avg)")
-        speedup = sorted_results[-1][1]['stats']['avg_ms'] / sorted_results[0][1]['stats']['avg_ms']
-        print(f"Max Speedup: {speedup:.1f}x")
-
- -
-
-
-
-
-
Loaded /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/274d1d4e0722f5affb811112832e03d26daafb5eaa96259e7ec575eb43a40f12/yamoe_results.json -Loaded /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/0e2a9f24cc405bb3c4ccb37530405ffe7cae24c59066185a87e856b3ac7344b3/binned_results.json -Loaded /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/b40a0492fc99c75ce021114ee849e7db60a33cfdf61891ace614b748953db1eb/gptoss_results.json -Loaded /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/ab389cf3b8cc56969604061ec8bc29a5701c53cdc24bd2682cf630b5e1eeb7bb/gptoss_training_results.json -Loaded /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/0febdf3420999533bc2e14bb2a4bffaba4af699a19ddf644f24806180c8347e1/megablocks_results.json - -Performance Summary: -Implementation Avg (ms) P95 (ms) Tokens/sec Relative Speed --------------------------------------------------------------------------------- -megablocks_results 3.85 8.50 25989 1.00x -yamoe_results 4.25 4.27 23549 0.91x -binned_results 35.79 37.38 2794 0.11x -gptoss_results 45.93 49.06 2177 0.08x -gptoss_training_results 45.95 51.04 2176 0.08x - -Fastest: megablocks_results (3.85ms avg) -Slowest: gptoss_training_results (45.95ms avg) -Max Speedup: 11.9x -
-
-
- -
-
-

Artifacts:

-moe_performance_comparison.png
-
-
-
-
- - - \ No newline at end of file