diff --git a/flash_attn/benchmark.html b/flash_attn/benchmark.html new file mode 100644 index 0000000000000000000000000000000000000000..70d0e46ff13160963d3b931674761f5a5b7430a5 --- /dev/null +++ b/flash_attn/benchmark.html @@ -0,0 +1,4253 @@ + + + + + + benchmark + + + + + + + +
+
+
light
+
reset
+ +
+
+ +
+
Generated on:
+
+ Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36 +
+
+ +
+
+

title: "Flash Attention Benchmark" +author: "uvnote" +theme: "dark" +syntax_theme: "monokai" +show_line_numbers: true +collapse_code: false +custom_css: | + #output-setup { + overflow-x: auto; + } + .cell-output { + overflow: scroll; + } + .cell-stdout { + width: max-content; + overflow: scroll; + } + .cell-stderr { + width: max-content; + overflow: scroll; + max-height: 300px; + }

+
+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: benchmark | 50.28s | FAILED + | + +Raw +
+
+
+
# /// script
+# dependencies = [
+#   "numpy",
+#   "torch",
+#   "kernels",
+#   "pandas",
+#   "matplotlib"
+# ]
+# ///
+# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths
+
+import functools
+import os
+import pathlib
+
+import matplotlib.pyplot as plt
+import torch
+import torch._dynamo.config
+import triton
+import triton.language as tl
+
+try:
+    from flash_attn import flash_attn_func
+except:
+    flash_attn_func = None
+    print("Flash Attention 2 not found.")
+
+try:
+    from flash_attn_interface import flash_attn_func as flash_attn_3_func
+except:
+    flash_attn_3_func = None
+    print("Flash Attention 3 not found.")
+
+try:
+    from kernels import get_kernel
+    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
+    hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
+except:
+    hf_kernels_flash_attn = None
+    hf_kernels_flash_attn_3 = None
+    print("HF Kernels not found.")
+
+try:
+    from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
+except:
+    sageattn_qk_int8_pv_fp16_cuda = None
+    sageattn_qk_int8_pv_fp16_triton = None
+    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+    print("SageAttention not found.")
+
+try:
+    from transformer_engine.pytorch.attention import DotProductAttention
+except:
+    DotProductAttention = None
+    print("Transformer Engine not found.")
+
+try:
+    import xformers.ops as xops
+except:
+    xops = None
+    print("xFormers not found.")
+
+
+plt.rcParams.update({
+    "figure.figsize": (12, 10),
+    "figure.dpi": 120,
+    "font.size": 10,
+    "axes.titlesize": 12,
+    "axes.labelsize": 14,
+    "xtick.labelsize": 10,
+    "ytick.labelsize": 10,
+    "legend.fontsize": 8,
+    "axes.grid": True,
+    "grid.alpha": 0.3,
+    "grid.linestyle": "--",
+    "lines.linewidth": 2.0,
+    "lines.markersize": 6,
+    "legend.frameon": True,
+    "legend.framealpha": 0.9,
+    "legend.loc": "best",
+    "axes.spines.top": False,
+    "axes.spines.right": False,
+})
+
+
+# We want to compare the best compiled version for each specific shape (dynamic=False)
+torch._dynamo.config.cache_size_limit = 10000
+
+# We need to suppress_errors for FA3 to work. It makes it run in eager mode.
+# I can't seem to get it to work any other way under torch.compile, so any suggestions are welcome!
+torch._dynamo.config.suppress_errors = True
+
+output_dir = pathlib.Path("dump_attention_benchmark")
+output_dir.mkdir(parents=True, exist_ok=True)
+
+batch_size = 1
+num_attention_heads = 24
+attention_head_dim = 128
+image_sequence_length = 4096  # 1024x1024px
+text_sequence_lengths = [128, 256, 320, 384, 448, 512]
+sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]
+
+
+def _attention_torch(query, key, value, *, backend):
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    with torch.nn.attention.sdpa_kernel(backend):
+        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+    out = out.transpose(1, 2).contiguous()
+    return out
+
+
+_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
+def _attention_torch_compile_default(query, key, value, *, backend):
+    return _compiled_attention_torch_default(query, key, value, backend=backend)
+
+
+_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
+def _attention_torch_compile_max_autotune(query, key, value, *, backend):
+    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)
+
+
+def _attention_flash_attn_2(query, key, value):
+    return flash_attn_func(query, key, value)
+
+
+_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
+def _attention_flash_attn_2_compile_default(query, key, value):
+    return _compiled_flash_attn_2_default(query, key, value)
+
+
+_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
+def _attention_flash_attn_2_compile_max_autotune(query, key, value):
+    return _compiled_flash_attn_2_max_autotune(query, key, value)
+
+
+# For fullgraph=True tracing to be compatible
+@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    out, lse = flash_attn_3_func(query, key, value)
+    return out
+
+
+@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+    return torch.empty_like(query)
+
+
+def _attention_flash_attn_3(query, key, value):
+    out = _wrapped_flash_attn_3(query, key, value)
+    return out
+
+
+_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
+def _attention_flash_attn_3_compile_default(query, key, value):
+    return _compiled_flash_attn_3_default(query, key, value)
+
+
+_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
+def _attention_flash_attn_3_compile_max_autotune(query, key, value):
+    return _compiled_flash_attn_3_max_autotune(query, key, value)
+
+
+def _attention_hf_kernels_flash_attn(query, key, value):
+    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]
+
+
+def _attention_hf_kernels_flash_attn3(query, key, value):
+    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]
+
+
+def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
+    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")
+
+
+def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
+    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")
+
+
+def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
+    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")
+
+
+if DotProductAttention is not None:
+    def set_te_backend(backend):
+        # must be applied before first use of
+        # transformer_engine.pytorch.attention
+        os.environ["NVTE_FLASH_ATTN"] = '0'
+        os.environ["NVTE_FUSED_ATTN"] = '0'
+        os.environ["NVTE_UNFUSED_ATTN"] = '0'
+        if backend == 'flash':
+            os.environ["NVTE_FLASH_ATTN"] = '1'
+        if backend == 'fused':
+            os.environ["NVTE_FUSED_ATTN"] = '1'
+        if backend == 'unfused':
+            os.environ["NVTE_UNFUSED_ATTN"] = '1'
+
+    set_te_backend("fused")
+    te_attn_fn = DotProductAttention(
+        num_attention_heads=num_attention_heads,
+        kv_channels=attention_head_dim,
+        qkv_format="bshd",
+        attn_mask_type="no_mask",
+    )
+else:
+    def te_attn_fn(query, key, value):
+        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")
+
+def _attention_te(query, key, value):
+    out = te_attn_fn(query, key, value)
+    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
+    return out
+
+
+# Cannot fullgraph compile TE
+_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
+def _attention_te_compile_default(query, key, value):
+    return _compiled_te_attn_fn_default(query, key, value)
+
+
+# Cannot fullgraph compile TE
+_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
+def _attention_te_compile_max_autotune(query, key, value):
+    return _compiled_te_attn_fn_max_autotune(query, key, value)
+
+
+def _attention_xformers(query, key, value):
+    return xops.memory_efficient_attention(query, key, value)
+
+
+_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
+def _attention_xformers_compile_default(query, key, value):
+    return _compiled_xformers_default(query, key, value)
+
+
+_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
+def _attention_xformers_compile_max_autotune(query, key, value):
+    return _compiled_xformers_max_autotune(query, key, value)
+
+
+attention_ops = {}
+attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
+attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
+attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
+attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
+attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
+attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
+if hf_kernels_flash_attn is not None:
+    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
+    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
+if flash_attn_func is not None:
+    attention_ops["flash_attn_2"] = _attention_flash_attn_2
+    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
+    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
+if flash_attn_3_func is not None:
+    attention_ops["flash_attn_3"] = _attention_flash_attn_3
+    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
+    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
+if sageattn_qk_int8_pv_fp16_cuda is not None:
+    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
+    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
+    if torch.cuda.get_device_capability()[0] >= 9:
+        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
+if DotProductAttention is not None:
+    attention_ops["te_fused"] = _attention_te
+    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
+    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
+if xops is not None:
+    attention_ops["xformers"] = _attention_xformers
+    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
+    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune
+
+
+def get_color_and_linestyle(n: int) -> tuple[str, str]:
+    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
+    line_styles = ["-", ":", "-.", "--"]
+    if n > len(colors) * len(line_styles):
+        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
+    styles = []
+    for i in range(n):
+        color = colors[i % len(colors)]
+        linestyle = line_styles[i // len(colors)]
+        styles.append((color, linestyle))
+    return styles
+
+
+def correctness():
+    for seq_len in sequence_lengths:
+        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
+        print(f"\n\n===== Testing shape: {shape} =====")
+
+        query = torch.randn(shape, device="cuda", dtype=torch.float32)
+        key = torch.randn(shape, device="cuda", dtype=torch.float32)
+        value = torch.randn(shape, device="cuda", dtype=torch.float32)
+
+        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
+        query, key, value = (x.bfloat16() for x in (query, key, value))
+
+        for name, fn in attention_ops.items():
+            out = fn(query, key, value)
+            absdiff = (out - golden_truth).abs()
+            absmax = torch.max(absdiff)
+            mae = torch.mean(absdiff)
+            mse = torch.mean((golden_truth - out) ** 2)
+            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["seq_len"],
+        x_vals=sequence_lengths,
+        x_log=False,
+        line_arg="provider",
+        line_vals=list(attention_ops.keys()),
+        line_names=[x.removeprefix("solution_") for x in attention_ops.keys()],
+        ylabel="Time (ms)",
+        styles=get_color_and_linestyle(len(attention_ops)),
+        plot_name="Attention Benchmark",
+        args={},
+    )
+)
+def benchmark_fn(seq_len: int, provider: str):
+    torch.manual_seed(0)
+
+    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
+    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
+    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
+    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
+
+    fn = attention_ops[provider]
+    ms, min_ms, max_ms = triton.testing.do_bench(
+        lambda: fn(query, key, value),
+        warmup=3,
+        rep=10,
+        quantiles=[0.5, 0.2, 0.8],
+    )
+    return ms, max_ms, min_ms
+
+
+with torch.inference_mode():
+    correctness()
+    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
+
+ +
+
+
+
+
Flash Attention 2 not found. +Flash Attention 3 not found. +SageAttention not found. +Transformer Engine not found. +xFormers not found. + + +===== Testing shape: (1, 4224, 24, 128) ===== +torch_cudnn : absmax=0.001547, mae=0.000075, mse=0.000000 +
+
+
▶ UV Install Logs
+ +
+
Fetching 20 files: 0%| | 0/20 [00:00<?, ?it/s] +Fetching 20 files: 5%|▌ | 1/20 [00:00<00:08, 2.21it/s] +Fetching 20 files: 10%|█ | 2/20 [00:02<00:21, 1.17s/it] +Fetching 20 files: 100%|██████████| 20/20 [00:02<00:00, 9.41it/s] + +Fetching 4 files: 0%| | 0/4 [00:00<?, ?it/s] +Fetching 4 files: 25%|██▌ | 1/4 [00:00<00:00, 5.28it/s] +Fetching 4 files: 50%|█████ | 2/4 [00:02<00:02, 1.15s/it] +Fetching 4 files: 100%|██████████| 4/4 [00:02<00:00, 1.99it/s] +/tmp/tmpyw1le_3d/cuda_utils.c:5:10: fatal error: Python.h: No such file or directory + 5 | #include <Python.h> + | ^~~~~~~~~~ +compilation terminated. +Traceback (most recent call last): + File "/repo/flash_attn/.uvnote/cells/benchmark.py", line 340, in <module> + correctness() + File "/repo/flash_attn/.uvnote/cells/benchmark.py", line 299, in correctness + out = fn(query, key, value) + ^^^^^^^^^^^^^^^^^^^^^ + File "/repo/flash_attn/.uvnote/cells/benchmark.py", line 114, in _attention_torch_compile_default + return _compiled_attention_torch_default(query, key, value, backend=backend) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 749, in compile_wrapper + raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 923, in _compile_fx_inner + raise InductorError(e, currentframe()).with_traceback( + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 907, in _compile_fx_inner + mb_compiled_graph = fx_codegen_and_compile( + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1578, in fx_codegen_and_compile + return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1456, in codegen_and_compile + compiled_module = graph.compile_to_module() + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/graph.py", line 2293, in compile_to_module + return self._compile_to_module() + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/graph.py", line 2299, in _compile_to_module + self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() + ^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/graph.py", line 2238, in codegen + self.scheduler.codegen() + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/scheduler.py", line 4598, in codegen + else self._codegen(self.nodes) + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/scheduler.py", line 4750, in _codegen + self.get_backend(device).codegen_node(node) + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/codegen/cuda_combined_scheduling.py", line 107, in codegen_node + return self._triton_scheduling.codegen_node(node) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/codegen/simd.py", line 1371, in codegen_node + return self.codegen_node_schedule( + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/codegen/simd.py", line 1424, in codegen_node_schedule + src_code = kernel.codegen_kernel() + ^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py", line 3677, in codegen_kernel + **self.inductor_meta_common(), + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/_inductor/codegen/triton.py", line 3501, in inductor_meta_common + "backend_hash": torch.utils._triton.triton_hash_with_backend(), + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/utils/_triton.py", line 165, in triton_hash_with_backend + backend = triton_backend() + ^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/torch/utils/_triton.py", line 157, in triton_backend + target = driver.active.get_current_target() + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/runtime/driver.py", line 30, in __getattr__ + return getattr(self._initialize_obj(), name) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/runtime/driver.py", line 26, in _initialize_obj + self._obj = self._init_fn() + ^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/runtime/driver.py", line 12, in _create_driver + return active_drivers[0]() + ^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 715, in __init__ + self.utils = CudaUtils() # TODO: make static + ^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 62, in __init__ + mod = compile_module_from_src( + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/runtime/build.py", line 88, in compile_module_from_src + so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or []) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/runtime/build.py", line 51, in _build + subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + File "/usr/lib/python3.11/subprocess.py", line 413, in check_call + raise CalledProcessError(retcode, cmd) +torch._inductor.exc.InductorError: CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpyw1le_3d/cuda_utils.c', '-O3', '-shared', '-fPIC', '-Wno-psabi', '-o', '/tmp/tmpyw1le_3d/cuda_utils.cpython-311-x86_64-linux-gnu.so', '-lcuda', '-L/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/backends/nvidia/lib', '-L/usr/lib/x86_64-linux-gnu', '-I/tmp/uvnote-run-08by6gh7/home/.cache/uv/environments-v2/benchmark-bfbc462482636f25/lib/python3.11/site-packages/triton/backends/nvidia/include', '-I/tmp/tmpyw1le_3d', '-I/usr/include/python3.11']' returned non-zero exit status 1. + +Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
+
+
+
+ + + \ No newline at end of file diff --git a/flash_attn/cells/benchmark.py b/flash_attn/cells/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..e5caf7572b9a79185f0d090fff80e62027a77237 --- /dev/null +++ b/flash_attn/cells/benchmark.py @@ -0,0 +1,341 @@ +# /// script +# dependencies = [ +# "numpy", +# "torch", +# "kernels", +# "pandas", +# "matplotlib" +# ] +# /// +# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths + +import functools +import os +import pathlib + +import matplotlib.pyplot as plt +import torch +import torch._dynamo.config +import triton +import triton.language as tl + +try: + from flash_attn import flash_attn_func +except: + flash_attn_func = None + print("Flash Attention 2 not found.") + +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func +except: + flash_attn_3_func = None + print("Flash Attention 3 not found.") + +try: + from kernels import get_kernel + hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn") + hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3") +except: + hf_kernels_flash_attn = None + hf_kernels_flash_attn_3 = None + print("HF Kernels not found.") + +try: + from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90 +except: + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + print("SageAttention not found.") + +try: + from transformer_engine.pytorch.attention import DotProductAttention +except: + DotProductAttention = None + print("Transformer Engine not found.") + +try: + import xformers.ops as xops +except: + xops = None + print("xFormers not found.") + + +plt.rcParams.update({ + "figure.figsize": (12, 10), + "figure.dpi": 120, + "font.size": 10, + "axes.titlesize": 12, + "axes.labelsize": 14, + "xtick.labelsize": 10, + "ytick.labelsize": 10, + "legend.fontsize": 8, + "axes.grid": True, + "grid.alpha": 0.3, + "grid.linestyle": "--", + "lines.linewidth": 2.0, + "lines.markersize": 6, + "legend.frameon": True, + "legend.framealpha": 0.9, + "legend.loc": "best", + "axes.spines.top": False, + "axes.spines.right": False, +}) + + +# We want to compare the best compiled version for each specific shape (dynamic=False) +torch._dynamo.config.cache_size_limit = 10000 + +# We need to suppress_errors for FA3 to work. It makes it run in eager mode. +# I can't seem to get it to work any other way under torch.compile, so any suggestions are welcome! +torch._dynamo.config.suppress_errors = True + +output_dir = pathlib.Path("dump_attention_benchmark") +output_dir.mkdir(parents=True, exist_ok=True) + +batch_size = 1 +num_attention_heads = 24 +attention_head_dim = 128 +image_sequence_length = 4096 # 1024x1024px +text_sequence_lengths = [128, 256, 320, 384, 448, 512] +sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths] + + +def _attention_torch(query, key, value, *, backend): + query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) + with torch.nn.attention.sdpa_kernel(backend): + out = torch.nn.functional.scaled_dot_product_attention(query, key, value) + out = out.transpose(1, 2).contiguous() + return out + + +_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False) +def _attention_torch_compile_default(query, key, value, *, backend): + return _compiled_attention_torch_default(query, key, value, backend=backend) + + +_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False) +def _attention_torch_compile_max_autotune(query, key, value, *, backend): + return _compiled_attention_torch_max_autotune(query, key, value, backend=backend) + + +def _attention_flash_attn_2(query, key, value): + return flash_attn_func(query, key, value) + + +_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False) +def _attention_flash_attn_2_compile_default(query, key, value): + return _compiled_flash_attn_2_default(query, key, value) + + +_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False) +def _attention_flash_attn_2_compile_max_autotune(query, key, value): + return _compiled_flash_attn_2_max_autotune(query, key, value) + + +# For fullgraph=True tracing to be compatible +@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda") +def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + out, lse = flash_attn_3_func(query, key, value) + return out + + +@torch.library.register_fake("flash_attn_3::_flash_attn_forward") +def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + return torch.empty_like(query) + + +def _attention_flash_attn_3(query, key, value): + out = _wrapped_flash_attn_3(query, key, value) + return out + + +_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False) +def _attention_flash_attn_3_compile_default(query, key, value): + return _compiled_flash_attn_3_default(query, key, value) + + +_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False) +def _attention_flash_attn_3_compile_max_autotune(query, key, value): + return _compiled_flash_attn_3_max_autotune(query, key, value) + + +def _attention_hf_kernels_flash_attn(query, key, value): + return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0] + + +def _attention_hf_kernels_flash_attn3(query, key, value): + return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0] + + +def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value): + return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD") + + +def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value): + return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD") + + +def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value): + return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD") + + +if DotProductAttention is not None: + def set_te_backend(backend): + # must be applied before first use of + # transformer_engine.pytorch.attention + os.environ["NVTE_FLASH_ATTN"] = '0' + os.environ["NVTE_FUSED_ATTN"] = '0' + os.environ["NVTE_UNFUSED_ATTN"] = '0' + if backend == 'flash': + os.environ["NVTE_FLASH_ATTN"] = '1' + if backend == 'fused': + os.environ["NVTE_FUSED_ATTN"] = '1' + if backend == 'unfused': + os.environ["NVTE_UNFUSED_ATTN"] = '1' + + set_te_backend("fused") + te_attn_fn = DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=attention_head_dim, + qkv_format="bshd", + attn_mask_type="no_mask", + ) +else: + def te_attn_fn(query, key, value): + raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.") + +def _attention_te(query, key, value): + out = te_attn_fn(query, key, value) + out = out.unflatten(2, (num_attention_heads, attention_head_dim)) + return out + + +# Cannot fullgraph compile TE +_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False) +def _attention_te_compile_default(query, key, value): + return _compiled_te_attn_fn_default(query, key, value) + + +# Cannot fullgraph compile TE +_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False) +def _attention_te_compile_max_autotune(query, key, value): + return _compiled_te_attn_fn_max_autotune(query, key, value) + + +def _attention_xformers(query, key, value): + return xops.memory_efficient_attention(query, key, value) + + +_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False) +def _attention_xformers_compile_default(query, key, value): + return _compiled_xformers_default(query, key, value) + + +_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False) +def _attention_xformers_compile_max_autotune(query, key, value): + return _compiled_xformers_max_autotune(query, key, value) + + +attention_ops = {} +attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION) +attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION) +attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION) +attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION) +attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION) +attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION) +if hf_kernels_flash_attn is not None: + attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn + attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3 +if flash_attn_func is not None: + attention_ops["flash_attn_2"] = _attention_flash_attn_2 + attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default + attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune +if flash_attn_3_func is not None: + attention_ops["flash_attn_3"] = _attention_flash_attn_3 + attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default + attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune +if sageattn_qk_int8_pv_fp16_cuda is not None: + attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda + attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton + if torch.cuda.get_device_capability()[0] >= 9: + attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90 +if DotProductAttention is not None: + attention_ops["te_fused"] = _attention_te + attention_ops["te_fused_compile_d"] = _attention_te_compile_default + attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune +if xops is not None: + attention_ops["xformers"] = _attention_xformers + attention_ops["xformers_compile_d"] = _attention_xformers_compile_default + attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune + + +def get_color_and_linestyle(n: int) -> tuple[str, str]: + colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"] + line_styles = ["-", ":", "-.", "--"] + if n > len(colors) * len(line_styles): + raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}") + styles = [] + for i in range(n): + color = colors[i % len(colors)] + linestyle = line_styles[i // len(colors)] + styles.append((color, linestyle)) + return styles + + +def correctness(): + for seq_len in sequence_lengths: + shape = (batch_size, seq_len, num_attention_heads, attention_head_dim) + print(f"\n\n===== Testing shape: {shape} =====") + + query = torch.randn(shape, device="cuda", dtype=torch.float32) + key = torch.randn(shape, device="cuda", dtype=torch.float32) + value = torch.randn(shape, device="cuda", dtype=torch.float32) + + golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH) + query, key, value = (x.bfloat16() for x in (query, key, value)) + + for name, fn in attention_ops.items(): + out = fn(query, key, value) + absdiff = (out - golden_truth).abs() + absmax = torch.max(absdiff) + mae = torch.mean(absdiff) + mse = torch.mean((golden_truth - out) ** 2) + print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}") + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["seq_len"], + x_vals=sequence_lengths, + x_log=False, + line_arg="provider", + line_vals=list(attention_ops.keys()), + line_names=[x.removeprefix("solution_") for x in attention_ops.keys()], + ylabel="Time (ms)", + styles=get_color_and_linestyle(len(attention_ops)), + plot_name="Attention Benchmark", + args={}, + ) +) +def benchmark_fn(seq_len: int, provider: str): + torch.manual_seed(0) + + shape = (batch_size, seq_len, num_attention_heads, attention_head_dim) + query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) + key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) + value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16) + + fn = attention_ops[provider] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: fn(query, key, value), + warmup=3, + rep=10, + quantiles=[0.5, 0.2, 0.8], + ) + return ms, max_ms, min_ms + + +with torch.inference_mode(): + correctness() + benchmark_fn.run(print_data=True, save_path=output_dir.as_posix()) diff --git a/flash_attn/index.html b/flash_attn/index.html new file mode 100644 index 0000000000000000000000000000000000000000..398172379434c672102a2bd0e4175dcb8e06f75e --- /dev/null +++ b/flash_attn/index.html @@ -0,0 +1,24 @@ + + + + + Directory Index + + + +

Index of /flash_attn

+ + + \ No newline at end of file diff --git a/index.html b/index.html index 689cacba8fe49eeacaab52239a8357160121c5b8..0c248cc83161f77a6f15eafdaf3758bde0650b7a 100644 --- a/index.html +++ b/index.html @@ -17,8 +17,8 @@

Index of /

\ No newline at end of file diff --git a/moe_benchmarks/megablocks/cells/forward_and_backward.py b/moe_benchmarks/megablocks/cells/forward_and_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ac420c8a43009eb857f3a7889b4f79ad5a1191 --- /dev/null +++ b/moe_benchmarks/megablocks/cells/forward_and_backward.py @@ -0,0 +1,196 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "accelerate>=1.10.1", +# "torch>=2.7.0", +# "kernels==0.10.0", +# "transformers@https://github.com/huggingface/transformers.git", +# "ipdb>=0.13.13", +# "matplotlib>=3.7.2", +# "numpy>=1.24.3", +# ] +# /// + +import torch +from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config +import time +import torch.nn as nn +from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub +import sys +import torch.profiler +import gc +import logging +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm + +# remove liger kernel for testing +replace_kernel_forward_from_hub(GptOssRMSNorm, None) + +# set to debug logging +logging.basicConfig(level=logging.INFO) + +def reset_peak_memory_stats(): + """Clear CUDA cache and reset memory allocation counters.""" + torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + gc.collect() + +def get_memory_stats(): + """Get current and peak CUDA memory usage.""" + if not torch.cuda.is_available(): + return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} + return { + "allocated_gb": torch.cuda.memory_allocated() / 1e9, + "peak_gb": torch.cuda.max_memory_allocated() / 1e9, + "reserved_gb": torch.cuda.memory_reserved() / 1e9, + } + +def override_kernel_layer_name(cls_name: str, value) -> bool: + """Helper to dynamically override the kernel_layer_name in a model class.""" + for mod in sys.modules.values(): + if mod is None: + continue + obj = getattr(mod, cls_name, None) + if isinstance(obj, type) and issubclass(obj, nn.Module): + setattr(obj, "kernel_layer_name", value) + print(f"Overrode {cls_name}.kernel_layer_name to {value}") + return True + return False + + +# Init the model the normal way +model_id = "openai/gpt-oss-20b" +tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) +quantization_config = Mxfp4Config(dequantize=True) + +model = GptOssForCausalLM.from_pretrained( + model_id, + dtype="bfloat16", + device_map="auto", + use_kernels=True, + quantization_config=quantization_config, +).eval() + +messages = [ + {"role": "system", "content": "What is Tensor Parallelism?"}, +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + reasoning_effort="low", +).to("cuda") + +max_tokens = 128 # Reduced to help with memory usage + +# Clear memory before backward pass +reset_peak_memory_stats() +print(f"Pre-generation memory: {get_memory_stats()}") + +# forward and backward pass +with torch.autograd.set_grad_enabled(True): + start_time = time.perf_counter() + generated = model.generate( + **inputs, + max_new_tokens=max_tokens, + do_sample=False, + temperature=None, + ) + end_time = time.perf_counter() + print(tokenizer.decode(generated[0], skip_special_tokens=False)) + print(f"Generation took {end_time - start_time:.2f} seconds") + print(f"Post-generation memory: {get_memory_stats()}") + + # Use gradient checkpointing to reduce memory usage + if hasattr(model, 'gradient_checkpointing_enable'): + model.gradient_checkpointing_enable() + print("Enabled gradient checkpointing") + + # Reduce sequence length if needed for memory + max_seq_len = 512 # Limit sequence length for backward pass + if generated.size(1) > max_seq_len: + print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens") + full_sequence = generated[:, -max_seq_len:] + else: + full_sequence = generated + + # Get model outputs for the full sequence + model.train() # Enable dropout and other training behaviors + + try: + outputs = model( + input_ids=full_sequence, + labels=full_sequence, # This will compute loss internally + return_dict=True + ) + print(f"Post-forward memory: {get_memory_stats()}") + + # If model doesn't compute loss, compute it manually + if outputs.loss is None: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = full_sequence[..., 1:].contiguous() + + # Use CrossEntropyLoss with ignore_index for padding tokens + loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100) + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1) + ) + else: + loss = outputs.loss + + print(f"Loss: {loss.item():.4f}") + + # Clear intermediate tensors to save memory + del outputs + torch.cuda.empty_cache() + + # Perform backward pass with memory management + print("Running backward pass...") + print(f"Pre-backward memory: {get_memory_stats()}") + + loss.backward() + print(f"Post-backward memory: {get_memory_stats()}") + + except torch.cuda.OutOfMemoryError as e: + print(f"OOM during forward/backward pass: {e}") + print("Try reducing max_tokens or max_seq_len") + raise + + # Calculate gradient statistics and print sample gradients + total_norm = 0.0 + param_count = 0 + grad_samples = {} + + for name, p in model.named_parameters(): + if p.grad is not None: + param_count += 1 + grad_norm = p.grad.data.norm(2).item() + total_norm += grad_norm ** 2 + + # Collect gradient statistics for key layers + if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']): + grad_samples[name] = { + 'norm': grad_norm, + 'mean': p.grad.data.mean().item(), + 'std': p.grad.data.std().item(), + 'max': p.grad.data.max().item(), + 'min': p.grad.data.min().item(), + } + + total_norm = total_norm ** 0.5 + + print(f"\nGradient norm: {total_norm:.4f}") + print(f"Parameters with gradients: {param_count}") + + # Print sample gradients from important layers + print("\nSample gradient statistics:") + for i, (name, stats) in enumerate(list(grad_samples.items())[:10]): + print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}") + + # Optional: zero gradients for next iteration + model.zero_grad() + model.eval() # Switch back to eval mode + diff --git a/moe_benchmarks/megablocks/megablocks_only.html b/moe_benchmarks/megablocks/megablocks_only.html index 8606aa9eddbd37a08f18ccfdeb910a8caa1cf0b5..2a81ff4825c4828a6c1b4b0e16548e5a18bc2114 100644 --- a/moe_benchmarks/megablocks/megablocks_only.html +++ b/moe_benchmarks/megablocks/megablocks_only.html @@ -3710,7 +3710,7 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
Generated on:
- Linux x86_64 | Linux-6.11.0-1018-azure-x86_64-with-glibc2.39 + Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
@@ -3724,122 +3724,219 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:

Next we can run with Megablocks kernels enabled.

Forward

First, we run a forward pass with Megablocks kernels.

-
+

Forward and Backward

+

Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.

+
-▼ code -▼ output - ▶ uv-logs +▼ code +▼ output + ▶ uv-logs | -Cell: forward_only | 118.48s | FAILED - | - -Raw +Cell: forward_and_backward | 19.43s | FAILED + | + +Raw
-
+
-
-1 -2 -3 -4 -5 -6 -7 -8 -9 -10 -11 -12 -13 -14 -15 -16 -17 -18 -19 -20 -21 -22 -23 -24 -25 -26 -27 -28 -29 -30 -31 -32 -33 -34 -35 -36 -37 -38 -39 -40 -41 -42 -43 -44 -45 -46 -47 -48 -49 -50 -51 -52 -53 -54 -55 -56 -57 -58 -59 -60 -61 -62 -63 -64 -65 -66 -67 -68 -69 -70 -71 -72 -73 -74 -75 -76 -77 -78 -79 -80 -81 -82 -83 -84 -85 -86 -87 -88 -89 -90 -91 -92 -93 -94 -95 -96 -97 -98 -99 -100 -101 +
+1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196
# /// script
@@ -3866,7 +3963,7 @@ Cell: forward_only | 118.48s | FAILED
 import logging
 from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
 
-
+# remove liger kernel for testing 
 replace_kernel_forward_from_hub(GptOssRMSNorm, None)
 
 # set to debug logging
@@ -3907,8 +4004,6 @@ Cell: forward_only | 118.48s | FAILED
 tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
 quantization_config = Mxfp4Config(dequantize=True)
 
-
-
 model = GptOssForCausalLM.from_pretrained(
     model_id,
     dtype="bfloat16",
@@ -3929,9 +4024,14 @@ Cell: forward_only | 118.48s | FAILED
     reasoning_effort="low",
 ).to("cuda")
 
-max_tokens = 256
+max_tokens = 128  # Reduced to help with memory usage
+
+# Clear memory before backward pass
+reset_peak_memory_stats()
+print(f"Pre-generation memory: {get_memory_stats()}")
 
-with torch.inference_mode():
+# forward and backward pass
+with torch.autograd.set_grad_enabled(True):
     start_time = time.perf_counter()
     generated = model.generate(
         **inputs,
@@ -3940,144 +4040,124 @@ Cell: forward_only | 118.48s | FAILED
         temperature=None,
     )
     end_time = time.perf_counter()
-
-print(tokenizer.decode(generated[0], skip_special_tokens=False))
-print(f"Generation took {end_time - start_time:.2f} seconds")
+    print(tokenizer.decode(generated[0], skip_special_tokens=False))
+    print(f"Generation took {end_time - start_time:.2f} seconds")
+    print(f"Post-generation memory: {get_memory_stats()}")
+
+    # Use gradient checkpointing to reduce memory usage
+    if hasattr(model, 'gradient_checkpointing_enable'):
+        model.gradient_checkpointing_enable()
+        print("Enabled gradient checkpointing")
+
+    # Reduce sequence length if needed for memory
+    max_seq_len = 512  # Limit sequence length for backward pass
+    if generated.size(1) > max_seq_len:
+        print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
+        full_sequence = generated[:, -max_seq_len:]
+    else:
+        full_sequence = generated
+
+    # Get model outputs for the full sequence
+    model.train()  # Enable dropout and other training behaviors
+
+    try:
+        outputs = model(
+            input_ids=full_sequence,
+            labels=full_sequence,  # This will compute loss internally
+            return_dict=True
+        )
+        print(f"Post-forward memory: {get_memory_stats()}")
+
+        # If model doesn't compute loss, compute it manually
+        if outputs.loss is None:
+            shift_logits = outputs.logits[..., :-1, :].contiguous()
+            shift_labels = full_sequence[..., 1:].contiguous()
+
+            # Use CrossEntropyLoss with ignore_index for padding tokens
+            loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100)
+            loss = loss_fct(
+                shift_logits.view(-1, shift_logits.size(-1)),
+                shift_labels.view(-1)
+            )
+        else:
+            loss = outputs.loss
+
+        print(f"Loss: {loss.item():.4f}")
+
+        # Clear intermediate tensors to save memory
+        del outputs
+        torch.cuda.empty_cache()
+
+        # Perform backward pass with memory management
+        print("Running backward pass...")
+        print(f"Pre-backward memory: {get_memory_stats()}")
+
+        loss.backward()
+        print(f"Post-backward memory: {get_memory_stats()}")
+
+    except torch.cuda.OutOfMemoryError as e:
+        print(f"OOM during forward/backward pass: {e}")
+        print("Try reducing max_tokens or max_seq_len")
+        raise
+
+    # Calculate gradient statistics and print sample gradients
+    total_norm = 0.0
+    param_count = 0
+    grad_samples = {}
+
+    for name, p in model.named_parameters():
+        if p.grad is not None:
+            param_count += 1
+            grad_norm = p.grad.data.norm(2).item()
+            total_norm += grad_norm ** 2
+
+            # Collect gradient statistics for key layers
+            if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
+                grad_samples[name] = {
+                    'norm': grad_norm,
+                    'mean': p.grad.data.mean().item(),
+                    'std': p.grad.data.std().item(),
+                    'max': p.grad.data.max().item(),
+                    'min': p.grad.data.min().item(),
+                }
+
+    total_norm = total_norm ** 0.5
+
+    print(f"\nGradient norm: {total_norm:.4f}")
+    print(f"Parameters with gradients: {param_count}")
+
+    # Print sample gradients from important layers
+    print("\nSample gradient statistics:")
+    for i, (name, stats) in enumerate(list(grad_samples.items())[:10]):
+        print(f"  {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")
+
+    # Optional: zero gradients for next iteration
+    model.zero_grad()
+    model.eval()  # Switch back to eval mode
 
-
+
-
-
-
▶ UV Install Logs
-
- -

Forward and Backward

-

Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.

diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json new file mode 100644 index 0000000000000000000000000000000000000000..b222e58061a4ff0233f1bab85f7d2c289d162f3e --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json @@ -0,0 +1,24 @@ +{ + "implementation": "binned_results", + "config": { + "warmup": 10, + "iters": 50, + "device": "cuda", + "dtype": "torch.float32", + "tokens": 100, + "vary_inputs": true + }, + "stats": { + "avg_ms": 36.06324691992995, + "min_ms": 33.29206800026441, + "max_ms": 38.40615900026023, + "std_ms": 1.258567678508065, + "p50_ms": 36.21510599987232, + "p95_ms": 37.524451049966956, + "p99_ms": 38.03603995002959, + "num_iters": 50, + "tokens_per_s": 2772.906172925215, + "throughput_variance": 98.28636435515342 + }, + "output_sum": 3.97190523147583 +} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json new file mode 100644 index 0000000000000000000000000000000000000000..491d61dbedf6649b5665666f8408fd1d61d51144 --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json @@ -0,0 +1,24 @@ +{ + "implementation": "gptoss_results", + "config": { + "warmup": 10, + "iters": 50, + "device": "cuda", + "dtype": "torch.float32", + "tokens": 100, + "vary_inputs": true + }, + "stats": { + "avg_ms": 45.286630379978305, + "min_ms": 38.91367899996112, + "max_ms": 49.84392799997295, + "std_ms": 3.2326168009526866, + "p50_ms": 45.42240999990099, + "p95_ms": 49.729684149951936, + "p99_ms": 49.82545450991893, + "num_iters": 50, + "tokens_per_s": 2208.1572234663554, + "throughput_variance": 161.27578702324564 + }, + "output_sum": 11.53223705291748 +} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json new file mode 100644 index 0000000000000000000000000000000000000000..c0d899fe9a06bb4ca74cd58fe9746a5942bbd236 --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json @@ -0,0 +1,24 @@ +{ + "implementation": "gptoss_training_results", + "config": { + "warmup": 10, + "iters": 50, + "device": "cuda", + "dtype": "torch.float32", + "tokens": 100, + "vary_inputs": true + }, + "stats": { + "avg_ms": 46.01034353989235, + "min_ms": 39.20698799993261, + "max_ms": 51.09754699969926, + "std_ms": 3.2594474712819497, + "p50_ms": 46.132551999562565, + "p95_ms": 50.721096600273086, + "p99_ms": 51.0080171399477, + "num_iters": 50, + "tokens_per_s": 2173.4243282338675, + "throughput_variance": 158.68467070353637 + }, + "output_sum": 11.53223705291748 +} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json b/moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json new file mode 100644 index 0000000000000000000000000000000000000000..ec2f20c34ce683f571a322b29e917480b9e73939 --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json @@ -0,0 +1,24 @@ +{ + "implementation": "yamoe_results", + "config": { + "warmup": 10, + "iters": 50, + "device": "cuda", + "dtype": "torch.float32", + "tokens": 100, + "vary_inputs": true + }, + "stats": { + "avg_ms": 4.2510544400101935, + "min_ms": 4.144352999901457, + "max_ms": 4.320155999266717, + "std_ms": 0.02873328656403644, + "p50_ms": 4.2539659998510615, + "p95_ms": 4.2857709999225335, + "p99_ms": 4.306132199617423, + "num_iters": 50, + "tokens_per_s": 23523.575482547854, + "throughput_variance": 160.28680309512873 + }, + "output_sum": 3.97190523147583 +} \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc index d5df252f2e6edaec8717d49f0fe7d72b278c362e..d4071be109b86c510b35c06dedc5c8c3e35bfe86 100644 Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc and b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc differ diff --git a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc index be5ea0a48cedabb22eac9d1ef3f5b0422d87c5c2..8eaaeeb579d95786f08f8a033ac563063f8e58ba 100644 Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc and b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc differ diff --git a/moe_benchmarks/megablocks_yamoe/cells/binned_run.py b/moe_benchmarks/megablocks_yamoe/cells/binned_run.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9e54316e7380bc60d7bb62459498e450575b31 --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/cells/binned_run.py @@ -0,0 +1,195 @@ +# /// script +# dependencies = [ +# "torch", +# "numpy", +# ] +# /// + +import torch +from torch import nn +from torch.nn import functional as F +from bench_utils import to_dtype, tensor_stats, set_seed, bench_context +from config import ( + NUM_EXPERTS, HIDDEN_SIZE, TOP_K, + BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, + WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED +) +from pathlib import Path +import os + +# Discover the upstream artifact directory from env +data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') + +router_weight = torch.load(Path(data_dir) / 'router_weight.pt') +router_bias = torch.load(Path(data_dir) / 'router_bias.pt') +gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') +gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') +down_proj = torch.load(Path(data_dir) / 'down_proj.pt') +down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') + +print("Loaded shared weights from artifacts") +print(f"Router weight sum: {router_weight.sum().item():.6f}") +print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") +print(f"Down sum: {down_proj.sum().item():.6f}") + +def binned_gather(x, indices, bins, expert_capacity, top_k): + E, H = bins.shape[0], x.shape[1] + out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype) + for e in range(E): + start = 0 if e == 0 else bins[e - 1] + end = bins[e] + n = min(end - start, expert_capacity) + for i in range(n): + flat_pos = indices[start + i] + tok = flat_pos // top_k + out[e, i] = x[tok] + return out + +def binned_scatter(x, indices, weights, bins, expert_capacity, top_k): + E, C, H = x.shape + N = indices.shape[0] // top_k + out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device) + for e in range(E): + start = 0 if e == 0 else bins[e - 1] + end = bins[e] + n = end - start + if n == 0: + continue + take = min(n, expert_capacity) + for i in range(take): + flat_pos = indices[start + i] + tok = flat_pos // top_k + slot = flat_pos % top_k + scale = weights[flat_pos] if weights is not None else 1.0 + out[tok, slot] = x[e, i] * scale + return out.sum(dim=1) + +def sort_tokens_by_expert(router_indices, num_experts): + flat_indices = router_indices.flatten() + sorted_values, sorted_indices = torch.sort(flat_indices) + tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts) + bins = torch.cumsum(tokens_per_expert, dim=0) + return sorted_indices, sorted_values, bins, tokens_per_expert + +def binned_experts_ref( + hidden_states, + router_indices, + routing_weights, + gate_up_proj, + gate_up_proj_bias, + down_proj, + down_proj_bias, + expert_capacity, +): + B, S, H = hidden_states.shape + E, K = routing_weights.shape[1], router_indices.shape[1] + + indices, _, bins, _ = sort_tokens_by_expert(router_indices, E) + x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K) + + gate_up = torch.bmm(x, gate_up_proj) + gate_up += gate_up_proj_bias[..., None, :] + + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + + # clamp to limit + limit = 7.0 + gate = gate.clamp(min=None, max=limit) + up = up.clamp(min=-limit, max=limit) + + glu = gate * torch.sigmoid(gate * 1.702) + x = (up + 1) * glu + x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :] + + # build routing weights aligned to (token, slot) + flat_dense = routing_weights.view(-1, E) + flat_router = router_indices.view(-1, K) + selected = torch.gather(flat_dense, 1, flat_router).reshape(-1) + + # scatter back + y = binned_scatter(x, indices, selected, bins, expert_capacity, K) + + return y.view(B, S, H) + +class BinnedRouter(nn.Module): + def __init__(self, router_weight, router_bias): + super().__init__() + self.top_k = TOP_K + self.num_experts = NUM_EXPERTS + self.hidden_dim = HIDDEN_SIZE + self.weight = nn.Parameter(router_weight.clone()) + self.bias = nn.Parameter(router_bias.clone()) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight, self.bias) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + +def ceil_div(a, b): + return (a + b - 1) // b + +class BinnedMoEMLP(nn.Module): + def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): + super().__init__() + self.router = BinnedRouter(router_weight, router_bias) + self.num_experts = NUM_EXPERTS + self.hidden_size = HIDDEN_SIZE + self.top_k = TOP_K + + # Expert weights - use the loaded weights + self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) + self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) + self.down_proj = nn.Parameter(down_proj.clone()) + self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) + + def forward(self, hidden_states): + router_scores, router_indices = self.router(hidden_states) + batch_size = hidden_states.shape[0] + expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts) + + output = binned_experts_ref( + hidden_states, + router_indices, + router_scores, + self.gate_up_proj, + self.gate_up_proj_bias, + self.down_proj, + self.down_proj_bias, + expert_capacity, + ) + + return output, router_scores + +# Run the model +set_seed(GENERAL_SEED) + +device = torch.device(DEVICE) +dtype = to_dtype(DTYPE) + +print("\n=== Binned Implementation ===") +# Initialize model with loaded weights +model = BinnedMoEMLP( + router_weight.to(device), + router_bias.to(device), + gate_up_proj.to(device), + gate_up_proj_bias.to(device), + down_proj.to(device), + down_proj_bias.to(device) +).to(device=device) + +print(f"Router weight sum: {model.router.weight.sum().item():.6f}") +print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}") +print(f"Down proj sum: {model.down_proj.sum().item():.6f}") + +# Generate the same input as Yamoe +set_seed(INPUT_SEED) +x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 + +# Benchmark the model with varied inputs to prevent caching artifacts +tokens = BATCH_SIZE * SEQ_LEN +with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench: + output, stats = bench(model, x) + print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py b/moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py new file mode 100644 index 0000000000000000000000000000000000000000..5a1532dabff53ecb068ddd4354c545f0cea2d72b --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py @@ -0,0 +1,147 @@ +# /// script +# dependencies = [ +# "torch", +# "numpy", +# ] +# /// + +import torch +from torch import nn +from torch.nn import functional as F +from bench_utils import to_dtype, tensor_stats, set_seed, bench_context +from config import ( + NUM_EXPERTS, HIDDEN_SIZE, TOP_K, + BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, + WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED +) +from pathlib import Path +import os + +# Discover the upstream artifact directory from env +data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') + +router_weight = torch.load(Path(data_dir) / 'router_weight.pt') +router_bias = torch.load(Path(data_dir) / 'router_bias.pt') +gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') +gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') +down_proj = torch.load(Path(data_dir) / 'down_proj.pt') +down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') + +print("Loaded shared weights from artifacts") +print(f"Router weight sum: {router_weight.sum().item():.6f}") +print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") +print(f"Down sum: {down_proj.sum().item():.6f}") + +class GptOssRouter(nn.Module): + def __init__(self, router_weight, router_bias): + super().__init__() + self.top_k = TOP_K + self.num_experts = NUM_EXPERTS + self.hidden_dim = HIDDEN_SIZE + self.weight = nn.Parameter(router_weight.clone()) + self.bias = nn.Parameter(router_bias.clone()) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight, self.bias) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + +class GptOssExperts(nn.Module): + def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): + super().__init__() + self.num_experts = NUM_EXPERTS + self.hidden_size = HIDDEN_SIZE + self.expert_dim = self.hidden_size + self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) + self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) + self.down_proj = nn.Parameter(down_proj.clone()) + self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) + self.alpha = 1.702 + self.limit = 7.0 + + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + num_experts = routing_weights.shape[1] + + if hidden_states.device.type == "cpu" or self.training: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit[:]: + expert_idx = expert_idx[0] + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = torch.bmm(((up + 1) * glu), self.down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + +class GptOssMoEMLP(nn.Module): + def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): + super().__init__() + self.router = GptOssRouter(router_weight, router_bias) + self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias) + + def forward(self, hidden_states): + router_scores, router_indices = self.router(hidden_states) + routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) + return routed_out, router_scores + +# Run the model +set_seed(GENERAL_SEED) + +device = torch.device(DEVICE) +dtype = to_dtype(DTYPE) + +print("\n=== GPT-OSS Implementation ===") +# Initialize model with loaded weights +model = GptOssMoEMLP( + router_weight.to(device), + router_bias.to(device), + gate_up_proj.to(device), + gate_up_proj_bias.to(device), + down_proj.to(device), + down_proj_bias.to(device) +).to(device=device) + +print(f"Router weight sum: {model.router.weight.sum().item():.6f}") +print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}") +print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}") + +# Generate the same input as other implementations +set_seed(INPUT_SEED) +x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 + +# Benchmark the model with varied inputs to prevent caching artifacts +tokens = BATCH_SIZE * SEQ_LEN +with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench: + output, stats = bench(model, x) + print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py b/moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py new file mode 100644 index 0000000000000000000000000000000000000000..f18731a74bfa546e612addbaab9e3ff5ec5d26dc --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py @@ -0,0 +1,138 @@ +# /// script +# dependencies = [ +# "torch", +# "numpy", +# ] +# /// + +import torch +from torch import nn +from torch.nn import functional as F +from bench_utils import to_dtype, tensor_stats, set_seed, bench_context +from config import ( + NUM_EXPERTS, HIDDEN_SIZE, TOP_K, + BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, + WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED +) +from pathlib import Path +import os + +# Discover the upstream artifact directory from env +data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') + +router_weight = torch.load(Path(data_dir) / 'router_weight.pt') +router_bias = torch.load(Path(data_dir) / 'router_bias.pt') +gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') +gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') +down_proj = torch.load(Path(data_dir) / 'down_proj.pt') +down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') + +print("Loaded shared weights from artifacts") +print(f"Router weight sum: {router_weight.sum().item():.6f}") +print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") +print(f"Down sum: {down_proj.sum().item():.6f}") + +class GptOssTrainingRouter(nn.Module): + def __init__(self, router_weight, router_bias): + super().__init__() + self.top_k = TOP_K + self.num_experts = NUM_EXPERTS + self.hidden_dim = HIDDEN_SIZE + self.weight = nn.Parameter(router_weight.clone()) + self.bias = nn.Parameter(router_bias.clone()) + + def forward(self, hidden_states): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = F.linear(hidden_states, self.weight, self.bias) + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) + router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) + return router_scores, router_indices + +class GptOssTrainingExperts(nn.Module): + def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): + super().__init__() + self.num_experts = NUM_EXPERTS + self.hidden_size = HIDDEN_SIZE + self.expert_dim = self.hidden_size + self.gate_up_proj = nn.Parameter(gate_up_proj.clone()) + self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone()) + self.down_proj = nn.Parameter(down_proj.clone()) + self.down_proj_bias = nn.Parameter(down_proj_bias.clone()) + self.alpha = 1.702 + self.limit = 7.0 + + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: + batch_size = hidden_states.shape[0] + hidden_states = hidden_states.reshape(-1, self.hidden_size) + num_experts = routing_weights.shape[1] + + # Force training mode path (expert loop instead of batched) + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit[:]: + expert_idx = expert_idx[0] + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + return next_states + +class GptOssTrainingMoEMLP(nn.Module): + def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias): + super().__init__() + self.router = GptOssTrainingRouter(router_weight, router_bias) + self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias) + + def forward(self, hidden_states): + router_scores, router_indices = self.router(hidden_states) + routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) + return routed_out, router_scores + +# Run the model +set_seed(GENERAL_SEED) + +device = torch.device(DEVICE) +dtype = to_dtype(DTYPE) + +print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===") +# Initialize model with loaded weights and force training mode +model = GptOssTrainingMoEMLP( + router_weight.to(device), + router_bias.to(device), + gate_up_proj.to(device), + gate_up_proj_bias.to(device), + down_proj.to(device), + down_proj_bias.to(device) +).to(device=device) + +# Set to training mode to force expert loop path +model.train() + +print(f"Router weight sum: {model.router.weight.sum().item():.6f}") +print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}") +print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}") +print(f"Model training mode: {model.training}") + +# Generate the same input as other implementations +set_seed(INPUT_SEED) +x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 + +# Benchmark the model with varied inputs to prevent caching artifacts +tokens = BATCH_SIZE * SEQ_LEN +with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench: + output, stats = bench(model, x) + print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py b/moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py new file mode 100644 index 0000000000000000000000000000000000000000..a18723cb66c892119c0a9e88d8c2a140a6354a00 --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py @@ -0,0 +1,103 @@ +# /// script +# dependencies = [ +# "torch", +# "numpy", +# "kernels", +# ] +# /// + +import torch +from torch import nn +from torch.nn import functional as F +from kernels import get_kernel, get_local_kernel +from bench_utils import to_dtype, tensor_stats, set_seed, bench_context +from config import ( + NUM_EXPERTS, HIDDEN_SIZE, TOP_K, + BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, + WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED +) +from pathlib import Path +from collections import namedtuple +import os + +# Discover the upstream artifact directory from env +data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.') + +print(f"Loading weights from: {data_dir}") + +router_weight = torch.load(Path(data_dir) / 'router_weight.pt') +router_bias = torch.load(Path(data_dir) / 'router_bias.pt') +gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt') +gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt') +down_proj = torch.load(Path(data_dir) / 'down_proj.pt') +down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt') + +print("Loaded shared weights from artifacts") +print(f"Router weight sum: {router_weight.sum().item():.6f}") +print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}") +print(f"Down sum: {down_proj.sum().item():.6f}") + +def build_megablocks_model(device: torch.device): + # Download optimized kernels from the Hugging Face hub + megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2") + model = megablocks.layers.MegaBlocksMoeMLP() + + # Create attribute container for expert weights + model.experts = namedtuple( + "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"] + ) + + # Use loaded router weights for consistency + model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device) + with torch.no_grad(): + model.router.weight.copy_(router_weight) + model.router.bias.copy_(router_bias) + + # Attach loaded expert weights to the experts container + e = model.experts + e.alpha = 1.702 + e.capacity_factor = 32 + e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device)) + e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device)) + e.down_proj = torch.nn.Parameter(down_proj.clone().to(device)) + e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device)) + e.hidden_size = HIDDEN_SIZE + + # Log weight statistics for comparison + print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}") + print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}") + print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}") + + return model + +# Create a wrapper to match the interface of other implementations +class MegaBlocksMoEWrapper(nn.Module): + def __init__(self, megablocks_model): + super().__init__() + self.model = megablocks_model + + def forward(self, hidden_states): + # MegaBlocks expects input in the format (batch, seq_len, hidden_dim) + output, dummy_routing_weights = self.model(hidden_states) + return output, dummy_routing_weights + +# Run the model +set_seed(GENERAL_SEED) + +device = torch.device(DEVICE) +dtype = to_dtype(DTYPE) + +print("\n=== MegaBlocks Implementation ===") +# Build MegaBlocks model with loaded weights +megablocks_model = build_megablocks_model(device) +model = MegaBlocksMoEWrapper(megablocks_model).to(device=device) + +# Generate the same input as other implementations +set_seed(INPUT_SEED) +x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1 + +# Benchmark the model with varied inputs to prevent caching artifacts +tokens = BATCH_SIZE * SEQ_LEN +with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench: + output, stats = bench(model, x) + print(f"\nOutput sum: {output[0].sum().item():.6f}") \ No newline at end of file diff --git a/moe_benchmarks/megablocks_yamoe/cells/setup.py b/moe_benchmarks/megablocks_yamoe/cells/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7f386417ca59470f5e6404d26b64a6d1fd6f39 --- /dev/null +++ b/moe_benchmarks/megablocks_yamoe/cells/setup.py @@ -0,0 +1,116 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [ +# "accelerate>=1.10.1", +# "torch>=2.7.0", +# "kernels==0.10.0", +# "transformers@https://github.com/huggingface/transformers.git", +# "ipdb>=0.13.13", +# "matplotlib>=3.7.2", +# "numpy>=1.24.3", +# ] +# /// + +import torch +from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config +import time +import torch.nn as nn +from kernels import register_kernel_mapping, Mode, LayerRepository +import sys +import torch.profiler +import gc +import logging + +# set to debug logging +logging.basicConfig(level=logging.INFO) + +def reset_peak_memory_stats(): + """Clear CUDA cache and reset memory allocation counters.""" + torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + gc.collect() + +def get_memory_stats(): + """Get current and peak CUDA memory usage.""" + if not torch.cuda.is_available(): + return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0} + return { + "allocated_gb": torch.cuda.memory_allocated() / 1e9, + "peak_gb": torch.cuda.max_memory_allocated() / 1e9, + "reserved_gb": torch.cuda.memory_reserved() / 1e9, + } + +def override_kernel_layer_name(cls_name: str, value) -> bool: + """Helper to dynamically override the kernel_layer_name in a model class.""" + for mod in sys.modules.values(): + if mod is None: + continue + obj = getattr(mod, cls_name, None) + if isinstance(obj, type) and issubclass(obj, nn.Module): + setattr(obj, "kernel_layer_name", value) + print(f"Overrode {cls_name}.kernel_layer_name to {value}") + return True + return False + + +# Init the model the normal way +model_id = "openai/gpt-oss-20b" +tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id) +quantization_config = Mxfp4Config(dequantize=True) + + +from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode + +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm + +replace_kernel_forward_from_hub(GptOssMLP, "Yamoe") +replace_kernel_forward_from_hub(GptOssRMSNorm, None) +custom_mapping = { + "Yamoe": { + "cuda": { + Mode.INFERENCE: LayerRepository( + repo_id="drbh/yamoe", + layer_name="Yamoe", + revision="v0.3.0", + ) + } + } +} +register_kernel_mapping(custom_mapping) + + +model = GptOssForCausalLM.from_pretrained( + model_id, + dtype="bfloat16", + device_map="auto", + use_kernels=True, + quantization_config=quantization_config, +).eval() + +messages = [ + {"role": "system", "content": "What is Tensor Parallelism?"}, +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, + reasoning_effort="low", +).to("cuda") + +max_tokens = 256 + +with torch.inference_mode(): + start_time = time.perf_counter() + generated = model.generate( + **inputs, + max_new_tokens=max_tokens, + do_sample=False, + temperature=None, + ) + end_time = time.perf_counter() + +print(tokenizer.decode(generated[0], skip_special_tokens=False)) +print(f"Generation took {end_time - start_time:.2f} seconds") diff --git a/moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html b/moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html index d483be109634d9f2c6ca41723356d82e1bf2cfa1..c4126222a6acb3c8c1746f92615d27aaf4909fb6 100644 --- a/moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html +++ b/moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html @@ -3710,61 +3710,288 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
Generated on:
- Linux x86_64 | Linux-6.11.0-1018-azure-x86_64-with-glibc2.39 + Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
-
+

Comparison of Megablocks and Yamoe Kernels

+

This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.

+

Megablocks kernel

+

Yamoe Kernel

+
-▼ code -▼ output - ▶ uv-logs +▼ code +▼ output + ▶ uv-logs | -Cell: nv | 0.07s | FAILED - | - -Raw +Cell: setup | 19.20s | FAILED + | + +Raw
-
+
-
-1 -2 -3 +
+1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116
-
import subprocess
-
-print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
+
# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "accelerate>=1.10.1",
+#     "torch>=2.7.0",
+#     "kernels==0.10.0",
+#     "transformers@https://github.com/huggingface/transformers.git",
+#     "ipdb>=0.13.13",
+#     "matplotlib>=3.7.2",
+#     "numpy>=1.24.3",
+# ]
+# ///
+
+import torch
+from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
+import time
+import torch.nn as nn
+from kernels import register_kernel_mapping, Mode, LayerRepository
+import sys
+import torch.profiler
+import gc
+import logging
+
+# set to debug logging
+logging.basicConfig(level=logging.INFO)
+
+def reset_peak_memory_stats():
+    """Clear CUDA cache and reset memory allocation counters."""
+    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
+    gc.collect()
+
+def get_memory_stats():
+    """Get current and peak CUDA memory usage."""
+    if not torch.cuda.is_available():
+        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
+    return {
+        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
+        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
+        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
+    }
+
+def override_kernel_layer_name(cls_name: str, value) -> bool:
+    """Helper to dynamically override the kernel_layer_name in a model class."""
+    for mod in sys.modules.values():
+        if mod is None:
+            continue
+        obj = getattr(mod, cls_name, None)
+        if isinstance(obj, type) and issubclass(obj, nn.Module):
+            setattr(obj, "kernel_layer_name", value)
+            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
+            return True
+    return False
+
+
+# Init the model the normal way
+model_id = "openai/gpt-oss-20b"
+tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
+quantization_config = Mxfp4Config(dequantize=True)
+
+
+from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode
+
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm
+
+replace_kernel_forward_from_hub(GptOssMLP, "Yamoe")
+replace_kernel_forward_from_hub(GptOssRMSNorm, None)
+custom_mapping = {
+    "Yamoe": {
+        "cuda": {
+            Mode.INFERENCE: LayerRepository(
+                repo_id="drbh/yamoe",
+                layer_name="Yamoe",
+                revision="v0.3.0",
+            )
+        }
+    }
+}
+register_kernel_mapping(custom_mapping)
+
+
+model = GptOssForCausalLM.from_pretrained(
+    model_id,
+    dtype="bfloat16",
+    device_map="auto",
+    use_kernels=True,
+    quantization_config=quantization_config,
+).eval()
+
+messages = [
+    {"role": "system", "content": "What is Tensor Parallelism?"},
+]
+
+inputs = tokenizer.apply_chat_template(
+    messages,
+    add_generation_prompt=True,
+    return_tensors="pt",
+    return_dict=True,
+    reasoning_effort="low",
+).to("cuda")
+
+max_tokens = 256
+
+with torch.inference_mode():
+    start_time = time.perf_counter()
+    generated = model.generate(
+        **inputs,
+        max_new_tokens=max_tokens,
+        do_sample=False,
+        temperature=None,
+    )
+    end_time = time.perf_counter()
+
+print(tokenizer.decode(generated[0], skip_special_tokens=False))
+print(f"Generation took {end_time - start_time:.2f} seconds")
 
-
+
-
-
Traceback (most recent call last): - File "/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks_yamoe/.uvnote/cells/nv.py", line 3, in <module> - print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout) - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/subprocess.py", line 548, in run - with Popen(*popenargs, **kwargs) as process: - ^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File "/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/subprocess.py", line 1026, in __init__ - self._execute_child(args, executable, preexec_fn, close_fds, - File "/opt/hostedtoolcache/Python/3.11.13/x64/lib/python3.11/subprocess.py", line 1955, in _execute_child - raise child_exception_type(errno_num, err_msg, err_filename) -FileNotFoundError: [Errno 2] No such file or directory: 'nvidia-smi' +
+
Downloading cpython-3.13.7-linux-x86_64-gnu (download) (32.0MiB) + Downloading cpython-3.13.7-linux-x86_64-gnu (download) + Updating https://github.com/huggingface/transformers.git (HEAD) + Updated https://github.com/huggingface/transformers.git (449533af73874470e914a203391635e04ac2ffc8) + × No solution found when resolving script dependencies: + ╰─▶ Because only transformers==4.57.0.dev0 is available and + transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1, + we can conclude that all versions of transformers depend on + huggingface-hub==1.0.0rc1. + And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0, + we can conclude that kernels==0.10.0 and all versions of transformers + are incompatible. + And because you require kernels==0.10.0 and transformers, we can + conclude that your requirements are unsatisfiable.
- -

Comparison of Megablocks and Yamoe Kernels

-

This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.

-

Megablocks kernel

-

Yamoe Kernel

diff --git a/moe_benchmarks/megablocks_yamoe/torch_profile.html b/moe_benchmarks/megablocks_yamoe/torch_profile.html index 03274be1af151bba4833da45e7954d7de1f9a558..ec3f276d4f2ffdf0354ae0f539751c7b01f73f61 100644 --- a/moe_benchmarks/megablocks_yamoe/torch_profile.html +++ b/moe_benchmarks/megablocks_yamoe/torch_profile.html @@ -3708,7 +3708,7 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
Generated on:
- Linux x86_64 | Linux-6.11.0-1018-azure-x86_64-with-glibc2.39 + Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
@@ -3720,7 +3720,7 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: ▼ output ▶ uv-logs | -Cell: utils | deps: torch, numpy | 3.06s +Cell: utils | deps: torch, numpy | 34.59s | Raw @@ -3794,7 +3794,43 @@ Cell: utils | deps: torch, numpy | 3.06s
▶ UV Install Logs
@@ -3807,7 +3843,7 @@ Installed 26 packages in 253ms ▼ output ▶ uv-logs | -Cell: bench_utils | deps: torch, numpy | 13.67s +Cell: bench_utils | deps: torch, numpy | 35.65s | Raw @@ -4295,13 +4331,43 @@ Cell: bench_utils | deps: torch, numpy | 13.67s
▶ UV Install Logs
@@ -4315,7 +4381,7 @@ Installed 26 packages in 259ms ▼ output ▶ uv-logs | -Cell: config | deps: torch, numpy | 3.02s +Cell: config | deps: torch, numpy | 34.53s | Raw @@ -4375,7 +4441,43 @@ Cell: config | deps: torch, numpy | 3.02s
▶ UV Install Logs
@@ -4388,7 +4490,7 @@ Installed 26 packages in 243ms ▼ output ▶ uv-logs | -Cell: save_data | deps: torch, numpy | 11.90s +Cell: save_data | deps: torch, numpy | 39.05s | Raw @@ -4476,38 +4578,74 @@ Cell: save_data | deps: torch, numpy | 11.90s
Saved shared weights to artifacts -Router weight sum: 12.588735 +Router weight sum: 12.588732 Gate/up sum: 1026.601807 -Down sum: 206.729279 +Down sum: 206.729263
▶ UV Install Logs

Yamoe Implementation

This section runs the Yamoe MoE implementation with optimized Triton kernels.

-
+
▼ code ▼ output ▶ uv-logs | -Cell: yamoe_run | deps: torch, kernels, numpy | 4.02s | FAILED +Cell: yamoe_run | deps: torch, kernels, numpy | 39.19s | Raw @@ -4778,38 +4916,1811 @@ Cell: yamoe_run | deps: torch, kernels, numpy | 4.02s | FAILED
-
Loading weights from: /home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks_yamoe/.uvnote/cache/57bbe537b6c3412d45373a8967728666b60b8687c5d1f5d0decc3ba51923edde +
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/f8744f31d9cf720409852d42748815c6d61f005a2a9b297b7b9bf986ed98bb90 Loaded shared weights from artifacts -Router weight sum: 12.588735 +Router weight sum: 12.588732 Gate/up sum: 1026.601807 -Down sum: 206.729279 +Down sum: 206.729263 === Yamoe Implementation === +Router weight sum: 12.588732 +Gate/up proj sum: 1026.601807 +Down proj sum: 206.729340 + +┌─ Benchmark Configuration ─────────────────────────────┐ +│ Warmup: 10 Iters: 50 │ +│ Tokens: 100 │ +│ Input Variation: Enabled (prevents caching artifacts) │ +└────────────────────────────────────────────────────────┘ + +Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 +Input Variation: +0.001 * iteration (deterministic) + +Warming up (10 iterations)... +Benchmarking (50 iterations)... + Progress: 20% complete (avg: 4.253 ms) + Progress: 40% complete (avg: 4.250 ms) + Progress: 60% complete (avg: 4.250 ms) + Progress: 80% complete (avg: 4.251 ms) + +Output tensors: + Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791 + Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 + +━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ +Iterations: 50 + +Latency Statistics: + Average: 4.251 ms + Min: 4.144 ms + Max: 4.320 ms + Std Dev: 0.029 ms + +Percentiles: + P50 (median): 4.254 ms + P95: 4.286 ms + P99: 4.306 ms + +Throughput: + Tokens/sec: 23523.6 + Std Dev: 160.3 +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +Saved benchmark results to yamoe_results.json + +Output sum: 3.971905
▶ UV Install Logs
+
Fetching 6 files: 0%| | 0/6 [00:00<?, ?it/s] +Fetching 6 files: 17%|█▋ | 1/6 [00:00<00:01, 3.18it/s] +Fetching 6 files: 50%|█████ | 3/6 [00:00<00:00, 3.84it/s] +Fetching 6 files: 100%|██████████| 6/6 [00:00<00:00, 7.53it/s]
+
+

Artifacts:

+yamoe_results.json
-
Traceback (most recent call last): - File "/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks_yamoe/.uvnote/cells/yamoe_run.py", line 115, in <module> - router_weight.to(device), - ^^^^^^^^^^^^^^^^^^^^^^^^ - File "/home/runner/work/_temp/setup-uv-cache/environments-v2/yamoe-run-07f6c9b004377cec/lib/python3.11/site-packages/torch/cuda/__init__.py", line 412, in _lazy_init - torch._C._cuda_init() -RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

Binned Implementation

This section runs the binned implementation that manually handles token gathering/scattering.

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: binned_run | deps: torch, numpy | 39.23s + | + +Raw +
+
+
+
+1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +
+
+
import torch
+from torch import nn
+from torch.nn import functional as F
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+def binned_gather(x, indices, bins, expert_capacity, top_k):
+    E, H = bins.shape[0], x.shape[1]
+    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
+    for e in range(E):
+        start = 0 if e == 0 else bins[e - 1]
+        end = bins[e]
+        n = min(end - start, expert_capacity)
+        for i in range(n):
+            flat_pos = indices[start + i]
+            tok = flat_pos // top_k
+            out[e, i] = x[tok]
+    return out
+
+def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
+    E, C, H = x.shape
+    N = indices.shape[0] // top_k
+    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
+    for e in range(E):
+        start = 0 if e == 0 else bins[e - 1]
+        end = bins[e]
+        n = end - start
+        if n == 0:
+            continue
+        take = min(n, expert_capacity)
+        for i in range(take):
+            flat_pos = indices[start + i]
+            tok = flat_pos // top_k
+            slot = flat_pos % top_k
+            scale = weights[flat_pos] if weights is not None else 1.0
+            out[tok, slot] = x[e, i] * scale
+    return out.sum(dim=1)
+
+def sort_tokens_by_expert(router_indices, num_experts):
+    flat_indices = router_indices.flatten()
+    sorted_values, sorted_indices = torch.sort(flat_indices)
+    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
+    bins = torch.cumsum(tokens_per_expert, dim=0)
+    return sorted_indices, sorted_values, bins, tokens_per_expert
+
+def binned_experts_ref(
+    hidden_states,
+    router_indices,
+    routing_weights,
+    gate_up_proj,
+    gate_up_proj_bias,
+    down_proj,
+    down_proj_bias,
+    expert_capacity,
+):
+    B, S, H = hidden_states.shape
+    E, K = routing_weights.shape[1], router_indices.shape[1]
+
+    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
+    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
+
+    gate_up = torch.bmm(x, gate_up_proj) 
+    gate_up += gate_up_proj_bias[..., None, :]
+
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+
+    # clamp to limit
+    limit = 7.0
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+
+    glu = gate * torch.sigmoid(gate * 1.702)
+    x = (up + 1) * glu
+    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
+
+    # build routing weights aligned to (token, slot)
+    flat_dense = routing_weights.view(-1, E)
+    flat_router = router_indices.view(-1, K)
+    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
+
+    # scatter back
+    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
+
+    return y.view(B, S, H)
+
+class BinnedRouter(nn.Module):
+    def __init__(self, router_weight, router_bias):
+        super().__init__()
+        self.top_k = TOP_K
+        self.num_experts = NUM_EXPERTS
+        self.hidden_dim = HIDDEN_SIZE
+        self.weight = nn.Parameter(router_weight.clone())
+        self.bias = nn.Parameter(router_bias.clone())
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+        return router_scores, router_indices
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+class BinnedMoEMLP(nn.Module):
+    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+        super().__init__()
+        self.router = BinnedRouter(router_weight, router_bias)
+        self.num_experts = NUM_EXPERTS
+        self.hidden_size = HIDDEN_SIZE
+        self.top_k = TOP_K
+
+        # Expert weights - use the loaded weights
+        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
+        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
+        self.down_proj = nn.Parameter(down_proj.clone())
+        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
+
+    def forward(self, hidden_states):
+        router_scores, router_indices = self.router(hidden_states)
+        batch_size = hidden_states.shape[0]
+        expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
+
+        output = binned_experts_ref(
+            hidden_states,
+            router_indices,
+            router_scores,
+            self.gate_up_proj,
+            self.gate_up_proj_bias,
+            self.down_proj,
+            self.down_proj_bias,
+            expert_capacity,
+        )
+
+        return output, router_scores
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== Binned Implementation ===")
+# Initialize model with loaded weights
+model = BinnedMoEMLP(
+    router_weight.to(device),
+    router_bias.to(device),
+    gate_up_proj.to(device),
+    gate_up_proj_bias.to(device),
+    down_proj.to(device),
+    down_proj_bias.to(device)
+).to(device=device)
+
+print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
+print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
+print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
+
+# Generate the same input as Yamoe
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
+    output, stats = bench(model, x)
+    print(f"\nOutput sum: {output[0].sum().item():.6f}")
+
+ +
+
+
+
+
+
Loaded shared weights from artifacts +Router weight sum: 12.588732 +Gate/up sum: 1026.601807 +Down sum: 206.729263 + +=== Binned Implementation === +Router weight sum: 12.588732 +Gate/up proj sum: 1026.601807 +Down proj sum: 206.729340 + +┌─ Benchmark Configuration ─────────────────────────────┐ +│ Warmup: 10 Iters: 50 │ +│ Tokens: 100 │ +│ Input Variation: Enabled (prevents caching artifacts) │ +└────────────────────────────────────────────────────────┘ + +Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 +Input Variation: +0.001 * iteration (deterministic) + +Warming up (10 iterations)... +Benchmarking (50 iterations)... + Progress: 20% complete (avg: 37.503 ms) + Progress: 40% complete (avg: 37.304 ms) + Progress: 60% complete (avg: 36.964 ms) + Progress: 80% complete (avg: 36.508 ms) + +Output tensors: + Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.049506, 0.054984], mean=0.000034, std=0.006508, norm=2.208791 + Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 + +━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ +Iterations: 50 + +Latency Statistics: + Average: 36.063 ms + Min: 33.292 ms + Max: 38.406 ms + Std Dev: 1.259 ms + +Percentiles: + P50 (median): 36.215 ms + P95: 37.524 ms + P99: 38.036 ms + +Throughput: + Tokens/sec: 2772.9 + Std Dev: 98.3 +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +Saved benchmark results to binned_results.json + +Output sum: 3.971905 +
+
+
▶ UV Install Logs
+ +
+
+

Artifacts:

+binned_results.json +
+
+
+

GPT-OSS Implementation

This section runs the GPT-OSS MoE implementation with manual expert loop handling.

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: gptoss_run | deps: torch, numpy | 39.77s + | + +Raw +
+
+
+
+1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +
+
+
import torch
+from torch import nn
+from torch.nn import functional as F
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+class GptOssRouter(nn.Module):
+    def __init__(self, router_weight, router_bias):
+        super().__init__()
+        self.top_k = TOP_K
+        self.num_experts = NUM_EXPERTS
+        self.hidden_dim = HIDDEN_SIZE
+        self.weight = nn.Parameter(router_weight.clone())
+        self.bias = nn.Parameter(router_bias.clone())
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+        return router_scores, router_indices
+
+class GptOssExperts(nn.Module):
+    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+        super().__init__()
+        self.num_experts = NUM_EXPERTS
+        self.hidden_size = HIDDEN_SIZE
+        self.expert_dim = self.hidden_size
+        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
+        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
+        self.down_proj = nn.Parameter(down_proj.clone())
+        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
+        self.alpha = 1.702
+        self.limit = 7.0
+
+    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        num_experts = routing_weights.shape[1]
+
+        if hidden_states.device.type == "cpu" or self.training:
+            next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+            with torch.no_grad():
+                expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+                expert_mask = expert_mask.permute(2, 1, 0)
+                expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+            for expert_idx in expert_hit[:]:
+                expert_idx = expert_idx[0]
+                with torch.no_grad():
+                    _, token_idx = torch.where(expert_mask[expert_idx])
+                current_state = hidden_states[token_idx]
+                gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+                gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+                gate = gate.clamp(min=None, max=self.limit)
+                up = up.clamp(min=-self.limit, max=self.limit)
+                glu = gate * torch.sigmoid(gate * self.alpha)
+                gated_output = (up + 1) * glu
+                out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+            next_states = next_states.view(batch_size, -1, self.hidden_size)
+        else:
+            hidden_states = hidden_states.repeat(num_experts, 1)
+            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+            gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=self.limit)
+            up = up.clamp(min=-self.limit, max=self.limit)
+            glu = gate * torch.sigmoid(gate * self.alpha)
+            next_states = torch.bmm(((up + 1) * glu), self.down_proj)
+            next_states = next_states + self.down_proj_bias[..., None, :]
+            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+            next_states = next_states.sum(dim=0)
+        return next_states
+
+class GptOssMoEMLP(nn.Module):
+    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+        super().__init__()
+        self.router = GptOssRouter(router_weight, router_bias)
+        self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
+
+    def forward(self, hidden_states):
+        router_scores, router_indices = self.router(hidden_states)
+        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+        return routed_out, router_scores
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== GPT-OSS Implementation ===")
+# Initialize model with loaded weights
+model = GptOssMoEMLP(
+    router_weight.to(device),
+    router_bias.to(device),
+    gate_up_proj.to(device),
+    gate_up_proj_bias.to(device),
+    down_proj.to(device),
+    down_proj_bias.to(device)
+).to(device=device)
+
+print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
+print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
+print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
+
+# Generate the same input as other implementations
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
+    output, stats = bench(model, x)
+    print(f"\nOutput sum: {output[0].sum().item():.6f}")
+
+ +
+
+
+
+
+
Loaded shared weights from artifacts +Router weight sum: 12.588732 +Gate/up sum: 1026.601807 +Down sum: 206.729263 + +=== GPT-OSS Implementation === +Router weight sum: 12.588732 +Gate/up proj sum: 1026.601807 +Down proj sum: 206.729340 + +┌─ Benchmark Configuration ─────────────────────────────┐ +│ Warmup: 10 Iters: 50 │ +│ Tokens: 100 │ +│ Input Variation: Enabled (prevents caching artifacts) │ +└────────────────────────────────────────────────────────┘ + +Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 +Input Variation: +0.001 * iteration (deterministic) + +Warming up (10 iterations)... +Benchmarking (50 iterations)... + Progress: 20% complete (avg: 48.905 ms) + Progress: 40% complete (avg: 48.717 ms) + Progress: 60% complete (avg: 47.570 ms) + Progress: 80% complete (avg: 46.370 ms) + +Output tensors: + Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560 + Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 + +━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ +Iterations: 50 + +Latency Statistics: + Average: 45.287 ms + Min: 38.914 ms + Max: 49.844 ms + Std Dev: 3.233 ms + +Percentiles: + P50 (median): 45.422 ms + P95: 49.730 ms + P99: 49.825 ms + +Throughput: + Tokens/sec: 2208.2 + Std Dev: 161.3 +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +Saved benchmark results to gptoss_results.json + +Output sum: 11.532237 +
+
+
▶ UV Install Logs
+ +
+
+

Artifacts:

+gptoss_results.json +
+
+
+

GPT-OSS Implementation (Training Mode)

This section runs the GPT-OSS MoE implementation with training mode enabled to force the expert loop path.

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: gptoss_training_run | deps: torch, numpy | 40.24s + | + +Raw +
+
+
+
+1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +
+
+
import torch
+from torch import nn
+from torch.nn import functional as F
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+class GptOssTrainingRouter(nn.Module):
+    def __init__(self, router_weight, router_bias):
+        super().__init__()
+        self.top_k = TOP_K
+        self.num_experts = NUM_EXPERTS
+        self.hidden_dim = HIDDEN_SIZE
+        self.weight = nn.Parameter(router_weight.clone())
+        self.bias = nn.Parameter(router_bias.clone())
+
+    def forward(self, hidden_states):
+        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+        router_logits = F.linear(hidden_states, self.weight, self.bias)
+        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
+        return router_scores, router_indices
+
+class GptOssTrainingExperts(nn.Module):
+    def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+        super().__init__()
+        self.num_experts = NUM_EXPERTS
+        self.hidden_size = HIDDEN_SIZE
+        self.expert_dim = self.hidden_size
+        self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
+        self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
+        self.down_proj = nn.Parameter(down_proj.clone())
+        self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
+        self.alpha = 1.702
+        self.limit = 7.0
+
+    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)
+        num_experts = routing_weights.shape[1]
+
+        # Force training mode path (expert loop instead of batched)
+        next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
+        with torch.no_grad():
+            expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+        for expert_idx in expert_hit[:]:
+            expert_idx = expert_idx[0]
+            with torch.no_grad():
+                _, token_idx = torch.where(expert_mask[expert_idx])
+            current_state = hidden_states[token_idx]
+            gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
+            gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+            gate = gate.clamp(min=None, max=self.limit)
+            up = up.clamp(min=-self.limit, max=self.limit)
+            glu = gate * torch.sigmoid(gate * self.alpha)
+            gated_output = (up + 1) * glu
+            out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
+            weighted_output = out * routing_weights[token_idx, expert_idx, None]
+            next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
+        next_states = next_states.view(batch_size, -1, self.hidden_size)
+        return next_states
+
+class GptOssTrainingMoEMLP(nn.Module):
+    def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
+        super().__init__()
+        self.router = GptOssTrainingRouter(router_weight, router_bias)
+        self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
+
+    def forward(self, hidden_states):
+        router_scores, router_indices = self.router(hidden_states)
+        routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
+        return routed_out, router_scores
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
+# Initialize model with loaded weights and force training mode
+model = GptOssTrainingMoEMLP(
+    router_weight.to(device),
+    router_bias.to(device),
+    gate_up_proj.to(device),
+    gate_up_proj_bias.to(device),
+    down_proj.to(device),
+    down_proj_bias.to(device)
+).to(device=device)
+
+# Set to training mode to force expert loop path
+model.train()
+
+print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
+print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
+print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
+print(f"Model training mode: {model.training}")
+
+# Generate the same input as other implementations
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
+    output, stats = bench(model, x)
+    print(f"\nOutput sum: {output[0].sum().item():.6f}")
+
+ +
+
+
+
+
+
Loaded shared weights from artifacts +Router weight sum: 12.588732 +Gate/up sum: 1026.601807 +Down sum: 206.729263 + +=== GPT-OSS Implementation (Training Mode - Expert Loop) === +Router weight sum: 12.588732 +Gate/up proj sum: 1026.601807 +Down proj sum: 206.729340 +Model training mode: True + +┌─ Benchmark Configuration ─────────────────────────────┐ +│ Warmup: 10 Iters: 50 │ +│ Tokens: 100 │ +│ Input Variation: Enabled (prevents caching artifacts) │ +└────────────────────────────────────────────────────────┘ + +Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 +Input Variation: +0.001 * iteration (deterministic) + +Warming up (10 iterations)... +Benchmarking (50 iterations)... + Progress: 20% complete (avg: 49.963 ms) + Progress: 40% complete (avg: 49.344 ms) + Progress: 60% complete (avg: 48.274 ms) + Progress: 80% complete (avg: 47.165 ms) + +Output tensors: + Primary: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.064982, 0.061193], mean=0.000100, std=0.013510, norm=4.585560 + Auxiliary: shape=(100, 128), dtype=torch.float32, device=cuda:0, range=[0.000000, 0.302948], mean=0.007812, std=0.043553, norm=5.005893 + +━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━ +Iterations: 50 + +Latency Statistics: + Average: 46.010 ms + Min: 39.207 ms + Max: 51.098 ms + Std Dev: 3.259 ms + +Percentiles: + P50 (median): 46.133 ms + P95: 50.721 ms + P99: 51.008 ms + +Throughput: + Tokens/sec: 2173.4 + Std Dev: 158.7 +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ + +Saved benchmark results to gptoss_training_results.json + +Output sum: 11.532237 +
+
+
▶ UV Install Logs
+ +
+ +
+
+

MegaBlocks Implementation

This section runs the MegaBlocks MoE implementation with optimized kernels from the Hugging Face hub.

+
+
+ +▼ code +▼ output + ▶ uv-logs + | +Cell: megablocks_run | deps: torch, numpy, kernels | 40.58s | FAILED + | + +Raw +
+
+
+
+1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +
+
+
import torch
+from torch import nn
+from torch.nn import functional as F
+from kernels import get_kernel, get_local_kernel
+from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
+from config import (
+    NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
+    BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
+    WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
+)
+from pathlib import Path
+from collections import namedtuple
+import os
+
+# Discover the upstream artifact directory from env
+data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
+
+print(f"Loading weights from: {data_dir}")
+
+router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
+router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
+gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
+gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
+down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
+down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
+
+print("Loaded shared weights from artifacts")
+print(f"Router weight sum: {router_weight.sum().item():.6f}")
+print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
+print(f"Down sum: {down_proj.sum().item():.6f}")
+
+def build_megablocks_model(device: torch.device):
+    # Download optimized kernels from the Hugging Face hub
+    megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
+    model = megablocks.layers.MegaBlocksMoeMLP()
+
+    # Create attribute container for expert weights
+    model.experts = namedtuple(
+        "Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
+    )
+
+    # Use loaded router weights for consistency
+    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
+    with torch.no_grad():
+        model.router.weight.copy_(router_weight)
+        model.router.bias.copy_(router_bias)
+
+    # Attach loaded expert weights to the experts container
+    e = model.experts
+    e.alpha = 1.702
+    e.capacity_factor = 32
+    e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
+    e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
+    e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
+    e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
+    e.hidden_size = HIDDEN_SIZE
+
+    # Log weight statistics for comparison
+    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
+    print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
+    print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
+
+    return model
+
+# Create a wrapper to match the interface of other implementations
+class MegaBlocksMoEWrapper(nn.Module):
+    def __init__(self, megablocks_model):
+        super().__init__()
+        self.model = megablocks_model
+
+    def forward(self, hidden_states):
+        # MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
+        output, dummy_routing_weights = self.model(hidden_states)
+        return output, dummy_routing_weights
+
+# Run the model
+set_seed(GENERAL_SEED)
+
+device = torch.device(DEVICE)
+dtype = to_dtype(DTYPE)
+
+print("\n=== MegaBlocks Implementation ===")
+# Build MegaBlocks model with loaded weights
+megablocks_model = build_megablocks_model(device)
+model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
+
+# Generate the same input as other implementations
+set_seed(INPUT_SEED)
+x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
+
+# Benchmark the model with varied inputs to prevent caching artifacts
+tokens = BATCH_SIZE * SEQ_LEN
+with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
+    output, stats = bench(model, x)
+    print(f"\nOutput sum: {output[0].sum().item():.6f}")
+
+ +
+
+
+
+
+
Loading weights from: /repo/moe_benchmarks/megablocks_yamoe/.uvnote/cache/f8744f31d9cf720409852d42748815c6d61f005a2a9b297b7b9bf986ed98bb90 +Loaded shared weights from artifacts +Router weight sum: 12.588732 +Gate/up sum: 1026.601807 +Down sum: 206.729263 + +=== MegaBlocks Implementation === +[MegaBlocks] Router weight sum: 12.588732 +[MegaBlocks] Gate/up projection shape: (128, 1152, 2304), sum: 1026.601807 +[MegaBlocks] Down projection shape: (128, 1152, 1152), sum: 206.729340 + +┌─ Benchmark Configuration ─────────────────────────────┐ +│ Warmup: 10 Iters: 50 │ +│ Tokens: 100 │ +│ Input Variation: Enabled (prevents caching artifacts) │ +└────────────────────────────────────────────────────────┘ + +Base Input: shape=(1, 100, 1152), dtype=torch.float32, device=cuda:0, range=[-0.486445, 0.446746], mean=-0.000048, std=0.099986, norm=33.936142 +Input Variation: +0.001 * iteration (deterministic) + +Warming up (10 iterations)... +
+
+
▶ UV Install Logs
+ +
+
Fetching 66 files: 0%| | 0/66 [00:00<?, ?it/s] +Fetching 66 files: 2%|▏ | 1/66 [00:00<00:24, 2.66it/s] +Fetching 66 files: 14%|█▎ | 9/66 [00:00<00:02, 20.99it/s] +Fetching 66 files: 24%|██▍ | 16/66 [00:00<00:01, 31.57it/s] +Fetching 66 files: 32%|███▏ | 21/66 [00:01<00:02, 17.74it/s] +Fetching 66 files: 53%|█████▎ | 35/66 [00:01<00:01, 29.20it/s] +Fetching 66 files: 71%|███████ | 47/66 [00:01<00:00, 40.39it/s] +Fetching 66 files: 85%|████████▍ | 56/66 [00:01<00:00, 43.01it/s] +Fetching 66 files: 97%|█████████▋| 64/66 [00:01<00:00, 47.82it/s] +Fetching 66 files: 100%|██████████| 66/66 [00:01<00:00, 35.14it/s] +/tmp/tmpsyirxqys/cuda_utils.c:5:10: fatal error: Python.h: No such file or directory + 5 | #include <Python.h> + | ^~~~~~~~~~ +compilation terminated. +Traceback (most recent call last): + File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/megablocks_run.py", line 102, in <module> + output, stats = bench(model, x) + ^^^^^^^^^^^^^^^ + File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 189, in runner + result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 96, in _bench_engine + _ = call(input_gen()) + ^^^^^^^^^^^^^^^^^ + File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/bench_utils.py", line 177, in <lambda> + call = lambda x: fn(x, *args[1:], **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/repo/moe_benchmarks/megablocks_yamoe/.uvnote/cells/megablocks_run.py", line 81, in forward + output, dummy_routing_weights = self.model(hidden_states) + ^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl + return self._call_impl(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl + return forward_call(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 896, in forward + output, expert_weights_out, *_ = moe_forward( + ^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 730, in moe_forward + x, tokens_per_expert = forward_fn(**forward_args) + ^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 457, in forward_once + x = permute_and_compute( + ^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/layers.py", line 401, in permute_and_compute + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/torch/autograd/function.py", line 576, in apply + return super().apply(*args, **kwargs) # type: ignore[misc] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/ops/stk_autocast.py", line 30, in decorate_fwd + return fwd(*args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/ops/binned_gather.py", line 26, in forward + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/huggingface/hub/models--kernels-community--megablocks/snapshots/e0fb1437de3f8d7079c4da13be8cb64dc0cfcdd5/build/torch28-cxx11-cu128-x86_64-linux/megablocks/backend/kernels.py", line 419, in binned_gather + _binned_copy[(num_experts, expert_capacity)]( + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/jit.py", line 390, in <lambda> + return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 239, in run + benchmark() + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 228, in benchmark + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 228, in <dictcomp> + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 160, in _bench + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + ^^^^^^^^^^^^^ + File "/usr/lib/python3.11/functools.py", line 1001, in __get__ + val = self.func(instance) + ^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/autotuner.py", line 121, in do_bench + return driver.active.get_benchmarker() + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 30, in __getattr__ + return getattr(self._initialize_obj(), name) + ^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 26, in _initialize_obj + self._obj = self._init_fn() + ^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/driver.py", line 12, in _create_driver + return active_drivers[0]() + ^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 715, in __init__ + self.utils = CudaUtils() # TODO: make static + ^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 62, in __init__ + mod = compile_module_from_src( + ^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/build.py", line 88, in compile_module_from_src + so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or []) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/runtime/build.py", line 51, in _build + subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + File "/usr/lib/python3.11/subprocess.py", line 413, in check_call + raise CalledProcessError(retcode, cmd) +subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpsyirxqys/cuda_utils.c', '-O3', '-shared', '-fPIC', '-Wno-psabi', '-o', '/tmp/tmpsyirxqys/cuda_utils.cpython-311-x86_64-linux-gnu.so', '-lcuda', '-L/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/lib', '-L/usr/lib/x86_64-linux-gnu', '-I/tmp/uvnote-run-4n1mby1e/home/.cache/uv/environments-v2/megablocks-run-8802ebf6d3566120/lib/python3.11/site-packages/triton/backends/nvidia/include', '-I/tmp/tmpsyirxqys', '-I/usr/include/python3.11']' returned non-zero exit status 1.
+
+
+

Performance Visualization

This section reads all benchmark results and creates a comprehensive performance comparison chart.