GptOssExperts - OpenAI-style MoE

GPU Info

import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Mon Nov 10 21:58:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   31C    P0             78W /  350W |       0MiB /  46068MiB |     17%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
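The full nvidia-smi table is easy to read but awkward to log alongside benchmark results. Below is a machine-readable variant. It is a minimal sketch that assumes you only want the GPU name, driver version, and total memory. It uses nvidia-smi's `--query-gpu` and `--format=csv,noheader,nounits` options, and the helper names are hypothetical:

```python
import subprocess

# Fields requested from nvidia-smi (an illustrative choice, not what the cell above prints)
QUERY = "name,driver_version,memory.total"


def parse_gpu_csv(text: str) -> list[dict]:
    """Parse `nvidia-smi --query-gpu=... --format=csv,noheader,nounits` output."""
    gpus = []
    for line in text.strip().splitlines():
        name, driver, mem_mib = [field.strip() for field in line.split(",")]
        gpus.append({"name": name, "driver": driver, "memory_mib": int(mem_mib)})
    return gpus


def query_gpus() -> list[dict]:
    """Run nvidia-smi once and return one dict per visible GPU."""
    result = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    )
    return parse_gpu_csv(result.stdout)
```

On the L40S machine above, this should return a single entry with `memory_mib` equal to 46068.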

OpenAI-style MoE Benchmark (GptOssExperts Reference)

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
#     "kernels",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# kernels = { git = "https://github.com/huggingface/kernels.git" }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark
from kernels import get_kernel

# Load yamoe to get GptOssExperts reference
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
GptOssExperts = yamoe.vendored.gpt_oss_mlp.GptOssExperts


def gpt_oss_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    OpenAI-style MoE forward pass using the GptOssExperts reference module.
    GptOssExperts is the reference model implementation from the original GPT OSS
    codebase, loaded here from the copy vendored in yamoe.
    """
    B, S, H = hidden_states.shape
    E = routing_weights.shape[2]

    # Build a minimal config object with the attributes GptOssExperts reads
    config = type("Config", (), {})()
    config.hidden_size = H
    config.intermediate_size = gate_up_proj.shape[2] // 2  # gate_up_proj is [E, H, 2 * I]
    config.num_local_experts = E

    # Initialize the module
    model = GptOssExperts(config)

    # Load the benchmark weights into it
    model.gate_up_proj.data = gate_up_proj
    model.gate_up_proj_bias.data = gate_up_proj_bias
    model.down_proj.data = down_proj
    model.down_proj_bias.data = down_proj_bias

    model = model.to(hidden_states.device)

    # Force the "CPU path" (taken in training mode) for correctness. It loops over
    # the experts selected in router_indices, matching naive_moe_ref.
    # The GPU path processes all experts for every token, which can introduce
    # small numerical differences.
    model.train()

    # Flatten routing_weights from [B, S, E] to [B * S, E]
    routing_weights_flat = routing_weights.reshape(-1, E)

    # Run forward pass
    with torch.no_grad():
        output = model(hidden_states, router_indices, routing_weights_flat)

    model.eval()  # Reset to eval mode

    return output


run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="gpt_oss_experts",
    impl_tags={"family": "reference", "backend": "pytorch"},
    impl_func=gpt_oss_openai_moe,
    dtype="float32",
)
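The wrapper builds the GptOssExperts config from the weight shapes alone. Below is a minimal sketch of that inference with explicit shape checks. The layouts in the docstring follow the GptOssExperts parameters in Hugging Face transformers and are an assumption about what the benchmark passes in. `infer_expert_config` is a hypothetical helper, and `SimpleNamespace` stands in for the ad hoc `type("Config", (), {})()` object:

```python
from types import SimpleNamespace

import torch


def infer_expert_config(gate_up_proj: torch.Tensor,
                        down_proj: torch.Tensor,
                        routing_weights: torch.Tensor) -> SimpleNamespace:
    """Derive GptOssExperts config fields from weight shapes.

    Assumed layouts (not confirmed by the benchmark source):
      gate_up_proj:    [E, H, 2 * I]  gate and up projections packed together
      down_proj:       [E, I, H]
      routing_weights: [B, S, E]      one weight per token and expert
    """
    num_experts, hidden_size, packed = gate_up_proj.shape
    if packed % 2:
        raise ValueError("gate_up_proj last dim must be 2 * intermediate_size")
    intermediate_size = packed // 2
    if down_proj.shape != (num_experts, intermediate_size, hidden_size):
        raise ValueError(f"unexpected down_proj shape {tuple(down_proj.shape)}")
    if routing_weights.shape[-1] != num_experts:
        raise ValueError("routing_weights last dim must equal num_experts")
    return SimpleNamespace(hidden_size=hidden_size,
                           intermediate_size=intermediate_size,
                           num_local_experts=num_experts)


# Tiny example: B=1, S=4, H=8, I=8, E=2
cfg = infer_expert_config(torch.zeros(2, 8, 16), torch.zeros(2, 8, 8), torch.zeros(1, 4, 2))
flat = torch.zeros(1, 4, 2).reshape(-1, cfg.num_local_experts)  # [B*S, E] = [4, 2]
print(cfg, flat.shape)
```

Failing fast on a shape mismatch surfaces layout bugs before `.data` assignment silently installs wrongly shaped weights.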
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      10.360ms       190.98%      10.360ms      10.360ms             1  
                                        gpt_oss_experts        15.12%       1.924ms        99.94%      12.713ms      12.713ms       0.000us         0.00%       5.428ms       5.428ms             1  
                                           aten::matmul         0.18%      22.311us         3.73%     473.846us      39.487us       0.000us         0.00%       4.800ms     400.041us            12  
                                               aten::mm         2.34%     297.100us         3.55%     451.535us      37.628us       4.800ms        88.50%       4.800ms     400.041us            12  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       3.258ms        60.07%       3.258ms     362.028us             9  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       1.536ms        28.31%       1.536ms     511.862us             3  
                                              aten::mul         1.29%     163.978us         2.14%     271.630us      11.318us     109.411us         2.02%     109.411us       4.559us            24  
                                              aten::add         1.51%     192.130us         3.80%     483.423us      26.857us     103.358us         1.91%     103.358us       5.742us            18  
                                            aten::index         1.52%     193.374us         2.62%     333.164us      27.764us      88.224us         1.63%      88.224us       7.352us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      80.864us         1.49%      80.864us       6.739us            12  
                                       aten::index_add_         0.46%      58.130us         0.76%      97.241us      16.207us      80.064us         1.48%      80.064us      13.344us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us      80.064us         1.48%      80.064us      13.344us             6  
                                          aten::nonzero         2.05%     260.439us         6.29%     799.492us      88.832us      65.278us         1.20%      76.093us       8.455us             9  
                                            aten::clamp         0.99%     126.442us         1.60%     203.852us      16.988us      63.456us         1.17%      63.456us       5.288us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      63.456us         1.17%      63.456us       5.288us            12  
                                            aten::where         0.06%       7.391us         5.01%     637.190us     106.198us       0.000us         0.00%      61.533us      10.256us             6  
                                    aten::nonzero_numpy         0.09%      11.880us         4.95%     629.799us     104.967us       0.000us         0.00%      61.533us      10.256us             6  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us      60.544us         1.12%      60.544us      10.091us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      56.929us         1.05%      56.929us       4.744us            12  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      51.073us         0.94%      51.073us       1.135us            45  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 12.720ms
Self CUDA time total: 5.425ms
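In the trace above, aten::mm accounts for about 88% of self CUDA time. The remainder is spread over small gather, elementwise, and scatter ops (aten::index, aten::clamp, aten::mul, aten::add, aten::index_add_), plus aten::nonzero/aten::where and device-to-host copies. That mix comes from the per-expert loop selected by `model.train()`.

Below is a minimal plain-PyTorch sketch of such a loop. It rests on several assumptions, none of which were read from the yamoe source:

- `router_indices` is a `[tokens, top_k]` tensor.
- Gate and up columns are interleaved.
- The GPT-OSS clamp limit is 7.0 and the sigmoid scale is 1.702, as in the Hugging Face transformers GptOssExperts.

`per_expert_moe` is a hypothetical name:

```python
import torch

# Assumed GPT-OSS activation constants (from transformers' GptOssExperts)
ALPHA = 1.702
LIMIT = 7.0


def per_expert_moe(hidden_states, router_indices, routing_weights,
                   gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
    """Loop over experts that received at least one token.

    hidden_states:   [T, H] tokens already flattened from [B, S, H]
    router_indices:  [T, K] top-k expert ids per token (assumed layout)
    routing_weights: [T, E] per-token weight for every expert
    """
    num_experts = routing_weights.shape[1]
    out = torch.zeros_like(hidden_states)
    # expert_mask[e, k, t] == 1 when token t picked expert e in slot k
    expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
    expert_mask = expert_mask.permute(2, 1, 0)
    # One nonzero (with a device-to-host sync) to find the experts that were hit
    expert_hit = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero()
    for e in expert_hit[:, 0].tolist():
        _, token_idx = torch.where(expert_mask[e])         # tokens routed to expert e
        x = hidden_states[token_idx]                        # gather
        gate_up = x @ gate_up_proj[e] + gate_up_proj_bias[e]
        gate, up = gate_up[..., ::2], gate_up[..., 1::2]    # interleaved columns
        gate = gate.clamp(max=LIMIT)
        up = up.clamp(min=-LIMIT, max=LIMIT)
        glu = gate * torch.sigmoid(gate * ALPHA)
        y = ((up + 1) * glu) @ down_proj[e] + down_proj_bias[e]
        y = y * routing_weights[token_idx, e].unsqueeze(-1)
        out.index_add_(0, token_idx, y)                     # scatter-add back
    return out
```

Each hit expert costs two GEMMs plus a handful of tiny kernels and a host sync. This is why the small-shape traces show far more CPU time than CUDA time.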



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      13.942ms       218.38%      13.942ms      13.942ms             1  
                                        gpt_oss_experts        15.57%       2.499ms        99.97%      16.048ms      16.048ms       0.000us         0.00%       6.387ms       6.387ms             1  
                                           aten::matmul         0.25%      39.461us         4.79%     769.170us      32.049us       0.000us         0.00%       5.570ms     232.102us            24  
                                               aten::mm         2.77%     444.894us         4.55%     729.709us      30.405us       5.570ms        87.25%       5.570ms     232.102us            24  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       5.515ms        86.38%       5.515ms     229.794us            24  
                                          aten::nonzero         2.34%     374.919us         7.60%       1.220ms      81.308us     114.786us         1.80%     137.349us       9.157us            15  
                                              aten::mul         1.86%     298.668us         3.09%     496.508us      10.344us     131.614us         2.06%     131.614us       2.742us            48  
                                              aten::add         2.06%     330.439us         3.47%     556.980us      15.472us     127.904us         2.00%     127.904us       3.553us            36  
                                            aten::where         0.07%      11.120us         7.17%       1.151ms      95.939us       0.000us         0.00%     123.109us      10.259us            12  
                                    aten::nonzero_numpy         0.13%      20.771us         7.10%       1.140ms      95.012us       0.000us         0.00%     123.109us      10.259us            12  
                                            aten::index         2.15%     344.365us         3.72%     597.667us      24.903us     111.391us         1.74%     111.391us       4.641us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     101.985us         1.60%     101.985us       4.249us            24  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      91.395us         1.43%      91.395us       1.051us            87  
                                            aten::clamp         1.30%     208.833us         2.21%     355.215us      14.801us      88.257us         1.38%      88.257us       3.677us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      88.257us         1.38%      88.257us       3.677us            24  
                                             aten::item         0.49%      78.042us        39.66%       6.367ms      88.433us       0.000us         0.00%      75.297us       1.046us            72  
                              aten::_local_scalar_dense         1.92%     308.797us        39.18%       6.289ms      87.349us      75.297us         1.18%      75.297us       1.046us            72  
                                       aten::index_add_         0.59%      94.029us         0.99%     158.640us      13.220us      71.454us         1.12%      71.454us       5.954us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us      71.454us         1.12%      71.454us       5.954us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us      66.271us         1.04%      66.271us       5.523us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 16.053ms
Self CUDA time total: 6.384ms
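Going from E2 to E4, the call counts scale with the number of experts, as expected from the per-expert loop that GptOssExperts runs in training mode. Below is a small worked check. It assumes every expert receives at least one token. It also assumes the profiled region covers three passes through the loop, which is an inference from the counts rather than something the trace states. `expected_op_counts` is a hypothetical helper:

```python
def expected_op_counts(num_experts: int, num_forwards: int) -> dict[str, int]:
    """Op calls made by the per-expert loop, assuming every expert is hit.

    Per expert and forward pass: two matmuls (gate_up and down), two clamps,
    three adds (two biases and `up + 1`), four muls (alpha scale, sigmoid gate,
    GLU product, routing weight), one sigmoid, one where, one index_add_.
    nonzero: one per forward to find hit experts, plus one inside each where.
    """
    per_expert = {
        "aten::mm": 2, "aten::clamp": 2, "aten::add": 3, "aten::mul": 4,
        "aten::sigmoid": 1, "aten::where": 1, "aten::index_add_": 1,
    }
    counts = {op: n * num_experts * num_forwards for op, n in per_expert.items()}
    counts["aten::nonzero"] = (1 + num_experts) * num_forwards
    return counts


# "# of Calls" read from the cuda_B1_S512_E2 and cuda_B1_S512_E4 traces above
observed_e2 = {"aten::mm": 12, "aten::clamp": 12, "aten::add": 18, "aten::mul": 24,
               "aten::where": 6, "aten::index_add_": 6, "aten::nonzero": 9}
observed_e4 = {"aten::mm": 24, "aten::clamp": 24, "aten::add": 36, "aten::mul": 48,
               "aten::where": 12, "aten::index_add_": 12, "aten::nonzero": 15}

for observed, experts in ((observed_e2, 2), (observed_e4, 4)):
    expected = expected_op_counts(experts, num_forwards=3)
    for op, n in observed.items():
        print(f"E={experts} {op:18s} observed={n:3d} expected={expected[op]:3d}")
```

Every listed count matches. The predicted 6 aten::sigmoid calls for E=2 also match the cuda_B4_S512_E2 trace.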



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      12.597ms       146.28%      12.597ms      12.597ms             1  
                                        gpt_oss_experts        11.26%       1.671ms        99.96%      14.835ms      14.835ms       0.000us         0.00%       8.616ms       8.616ms             1  
                                           aten::matmul         0.13%      19.980us         2.85%     423.596us      35.300us       0.000us         0.00%       7.614ms     634.486us            12  
                                               aten::mm         1.70%     251.563us         2.72%     403.616us      33.635us       7.614ms        88.42%       7.614ms     634.486us            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       4.628ms        53.74%       4.628ms     771.312us             6  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       1.524ms        17.70%       1.524ms     508.107us             3  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       1.455ms        16.90%       1.455ms     485.046us             3  
                                              aten::mul         1.00%     148.488us         1.71%     253.960us      10.582us     188.737us         2.19%     188.737us       7.864us            24  
                                              aten::add         1.14%     169.821us         1.97%     292.395us      16.244us     180.606us         2.10%     180.606us      10.034us            18  
                                       aten::index_add_         0.32%      47.691us         0.57%      84.001us      14.000us     164.000us         1.90%     164.000us      27.333us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     164.000us         1.90%     164.000us      27.333us             6  
                                            aten::index         1.23%     181.951us         2.12%     314.145us      26.179us     144.608us         1.68%     144.608us      12.051us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     142.815us         1.66%     142.815us      11.901us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     114.816us         1.33%     114.816us      19.136us             6  
                                            aten::clamp         0.72%     107.083us         1.24%     184.134us      15.345us     106.818us         1.24%     106.818us       8.902us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     106.818us         1.24%     106.818us       8.902us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     100.513us         1.17%     100.513us       8.376us            12  
                                          aten::nonzero         1.51%     224.830us         4.84%     718.263us      79.807us      68.894us         0.80%      80.029us       8.892us             9  
                                            aten::where         0.04%       5.681us         3.95%     586.411us      97.735us       0.000us         0.00%      65.405us      10.901us             6  
                                    aten::nonzero_numpy         0.07%      10.160us         3.91%     580.730us      96.788us       0.000us         0.00%      65.405us      10.901us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 14.841ms
Self CUDA time total: 8.611ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      18.460ms       171.74%      18.460ms      18.460ms             1  
                                        gpt_oss_experts        12.58%       2.618ms        99.97%      20.806ms      20.806ms       0.000us         0.00%      10.754ms      10.754ms             1  
                                           aten::matmul         0.19%      39.724us         3.85%     801.313us      33.388us       0.000us         0.00%       9.496ms     395.681us            24  
                                               aten::mm         2.21%     460.813us         3.66%     761.589us      31.733us       9.496ms        88.35%       9.496ms     395.681us            24  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       6.491ms        60.39%       6.491ms     360.603us            18  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       2.993ms        27.84%       2.993ms     498.774us             6  
                                              aten::mul         2.25%     467.369us         3.28%     683.452us      14.239us     226.014us         2.10%     226.014us       4.709us            48  
                                              aten::add         1.60%     332.210us         2.74%     569.351us      15.815us     207.013us         1.93%     207.013us       5.750us            36  
                                            aten::index         1.72%     357.427us         2.99%     622.664us      25.944us     203.329us         1.89%     203.329us       8.472us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     162.243us         1.51%     162.243us       6.760us            24  
                                       aten::index_add_         0.45%      94.395us         0.78%     161.485us      13.457us     155.167us         1.44%     155.167us      12.931us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     155.167us         1.44%     155.167us      12.931us            12  
                                          aten::nonzero         1.86%     386.184us         6.07%       1.263ms      84.202us     120.989us         1.13%     144.894us       9.660us            15  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     144.769us         1.35%     144.769us      12.064us            12  
                                            aten::where         0.05%      10.779us         5.71%       1.188ms      99.031us       0.000us         0.00%     130.270us      10.856us            12  
                                    aten::nonzero_numpy         0.10%      20.452us         5.66%       1.178ms      98.133us       0.000us         0.00%     130.270us      10.856us            12  
                                            aten::clamp         1.04%     217.185us         1.79%     373.407us      15.559us     129.252us         1.20%     129.252us       5.386us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     129.252us         1.20%     129.252us       5.386us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     115.584us         1.08%     115.584us       4.816us            24  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     107.234us         1.00%     107.234us       1.233us            87  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.812ms
Self CUDA time total: 10.749ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      21.083ms       119.21%      21.083ms      21.083ms             1  
                                        gpt_oss_experts         7.12%       1.665ms        99.98%      23.365ms      23.365ms       0.000us         0.00%      17.695ms      17.695ms             1  
                                           aten::matmul         0.09%      20.129us         1.89%     441.429us      36.786us       0.000us         0.00%      14.828ms       1.236ms            12  
                                               aten::mm         1.11%     260.517us         1.80%     421.300us      35.108us      14.828ms        83.84%      14.828ms       1.236ms            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       9.047ms        51.15%       9.047ms       1.508ms             6  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       5.773ms        32.64%       5.773ms     962.167us             6  
                                              aten::add         0.74%     174.025us         1.27%     296.156us      16.453us     776.579us         4.39%     776.579us      43.143us            18  
                                              aten::mul         0.64%     149.555us         1.10%     257.226us      10.718us     654.338us         3.70%     654.338us      27.264us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     499.874us         2.83%     499.874us      41.656us            12  
                                       aten::index_add_         0.21%      48.400us         0.36%      84.241us      14.040us     449.985us         2.54%     449.985us      74.998us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     449.985us         2.54%     449.985us      74.998us             6  
                                            aten::clamp         0.46%     107.321us         0.79%     185.253us      15.438us     329.054us         1.86%     329.054us      27.421us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     329.054us         1.86%     329.054us      27.421us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     300.737us         1.70%     300.737us      50.123us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     276.705us         1.56%     276.705us      46.117us             6  
                                            aten::index         0.76%     178.051us         1.32%     309.462us      25.788us     268.800us         1.52%     268.800us      22.400us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     253.889us         1.44%     253.889us      21.157us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     236.095us         1.33%     236.095us      39.349us             6  
                                          aten::sigmoid         0.16%      36.571us         0.27%      63.572us      10.595us     176.833us         1.00%     176.833us      29.472us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     176.833us         1.00%     176.833us      29.472us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 23.371ms
Self CUDA time total: 17.686ms
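To compare workloads side by side, you can pull the per-workload totals out of the trace text. Below is a minimal standard-library sketch. It assumes the `PROFILE TRACE: impl | workload` and `Self CUDA time total:` line formats shown above, with `ms` or `us` units. `cuda_totals_ms` is a hypothetical helper:

```python
import re

HEADER = re.compile(r"PROFILE TRACE: (\S+) \| (\S+)")
CUDA_TOTAL = re.compile(r"Self CUDA time total: ([\d.]+)(ms|us)")


def cuda_totals_ms(trace_text: str) -> dict[str, float]:
    """Map each workload name to its 'Self CUDA time total' in milliseconds."""
    totals, current = {}, None
    for line in trace_text.splitlines():
        if m := HEADER.search(line):
            current = m.group(2)                 # e.g. "cuda_B1_S512_E2"
        elif (m := CUDA_TOTAL.search(line)) and current:
            value = float(m.group(1))
            totals[current] = value if m.group(2) == "ms" else value / 1000.0
            current = None
    return totals
```

Applied to the traces above, cuda_B4_S512_E2 (17.686 ms) takes about 3.3 times as long as cuda_B1_S512_E2 (5.425 ms), with four times as many tokens.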



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      24.709ms       139.35%      24.709ms      24.709ms             1  
                                        gpt_oss_experts         9.76%       2.650ms        99.98%      27.156ms      27.156ms       0.000us         0.00%      17.741ms      17.741ms             1  
                                           aten::matmul         0.15%      40.162us         3.17%     860.144us      35.839us       0.000us         0.00%      15.537ms     647.383us            24  
                                               aten::mm         1.90%     517.331us         3.02%     819.982us      34.166us      15.537ms        87.63%      15.537ms     647.383us            24  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       9.352ms        52.74%       9.352ms     779.317us            12  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       3.225ms        18.19%       3.225ms     537.452us             6  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       2.947ms        16.62%       2.947ms     491.169us             6  
                                              aten::add         1.29%     349.077us         2.22%     601.999us      16.722us     419.552us         2.37%     419.552us      11.654us            36  
                                              aten::mul         1.15%     311.953us         1.98%     539.014us      11.229us     410.371us         2.31%     410.371us       8.549us            48  
                                       aten::index_add_         0.36%      97.270us         0.61%     164.412us      13.701us     379.682us         2.14%     379.682us      31.640us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     379.682us         2.14%     379.682us      31.640us            12  
                                            aten::index         1.31%     354.897us         2.36%     641.129us      26.714us     344.639us         1.94%     344.639us      14.360us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     337.056us         1.90%     337.056us      14.044us            24  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     280.607us         1.58%     280.607us      23.384us            12  
                                            aten::clamp         0.78%     212.661us         1.36%     368.626us      15.359us     225.662us         1.27%     225.662us       9.403us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     225.662us         1.27%     225.662us       9.403us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     218.112us         1.23%     218.112us       9.088us            24  
                                          aten::nonzero         1.41%     383.824us         4.68%       1.271ms      84.702us     127.715us         0.72%     153.604us      10.240us            15  
                                            aten::where         0.04%      11.073us         4.43%       1.203ms     100.252us       0.000us         0.00%     138.052us      11.504us            12  
                                    aten::nonzero_numpy         0.07%      20.230us         4.39%       1.192ms      99.329us       0.000us         0.00%     138.052us      11.504us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 27.162ms
Self CUDA time total: 17.731ms
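In these tables, each row's `Self CUDA %` appears to be its `Self CUDA` time divided by the `Self CUDA time total` line. The first `gpt_oss_experts` row looks like the GPU-side span of the profiler annotation, not a kernel. It is divided by the same kernel-time total, so it can exceed 100%: the span presumably also covers idle gaps between kernels. A quick check with the numbers from the trace above:

```python
# Values copied from the profile trace above (milliseconds).
SELF_CUDA_TOTAL_MS = 17.731

ROWS = {
    "gpt_oss_experts (GPU annotation)": 24.709,
    "aten::mm": 15.537,
    "cutlass_80_simt_sgemm_256x128": 9.352,
}


def self_cuda_percent(self_cuda_ms: float, total_ms: float) -> float:
    """Recompute the printed Self CUDA %: row time relative to total kernel time."""
    return 100.0 * self_cuda_ms / total_ms


for name, ms in ROWS.items():
    # Prints 139.35%, 87.63% and 52.74%, matching the table.
    print(f"{name:36s} {self_cuda_percent(ms, SELF_CUDA_TOTAL_MS):7.2f}%")
```

The same arithmetic explains the 109.03% and 116.52% annotation rows in the two traces that follow.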



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      40.750ms       109.03%      40.750ms      40.750ms             1  
                                        gpt_oss_experts         4.08%       1.695ms        99.82%      41.512ms      41.512ms       0.000us         0.00%      37.407ms      37.407ms             1  
                                           aten::matmul         0.05%      20.951us         1.02%     424.118us      35.343us       0.000us         0.00%      27.409ms       2.284ms            12  
                                               aten::mm         0.67%     277.566us         0.97%     403.167us      33.597us      27.409ms        73.34%      27.409ms       2.284ms            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us      27.406ms        73.33%      27.406ms       2.284ms            12  
                                              aten::mul         0.37%     154.550us         0.63%     261.852us      10.911us       2.976ms         7.96%       2.976ms     124.014us            24  
                                              aten::add         0.45%     185.160us         1.07%     445.895us      24.772us       2.401ms         6.42%       2.401ms     133.369us            18  
                                            aten::clamp         0.28%     116.599us         0.48%     198.482us      16.540us       2.391ms         6.40%       2.391ms     199.291us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.391ms         6.40%       2.391ms     199.291us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.983ms         5.30%       1.983ms     165.222us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       1.625ms         4.35%       1.625ms     135.419us            12  
                                       aten::index_add_         0.12%      48.080us         0.21%      86.751us      14.459us     910.402us         2.44%     910.402us     151.734us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     910.402us         2.44%     910.402us     151.734us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     775.618us         2.08%     775.618us     129.270us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     740.611us         1.98%     740.611us     123.435us             6  
                                            aten::index         0.44%     181.234us         0.76%     317.848us      26.487us     714.884us         1.91%     714.884us      59.574us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     681.379us         1.82%     681.379us     113.563us             6  
                                          aten::sigmoid         0.09%      38.611us         0.16%      65.922us      10.987us     320.927us         0.86%     320.927us      53.488us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     320.927us         0.86%     320.927us      53.488us             6  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     253.057us         0.68%     253.057us      42.176us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 41.585ms
Self CUDA time total: 37.374ms
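The elementwise rows in this trace (`aten::clamp`, `aten::sigmoid`, `aten::mul`, `aten::add`) match the clamped, gated activation used by the Hugging Face `transformers` `GptOssExperts` module. The scalar sketch below assumes that module's constants, `alpha = 1.702` and `limit = 7.0`. It only illustrates which elementwise ops each expert runs between its two matmuls. It is not the benchmarked code path.

```python
import math

# Assumed constants, taken to match the transformers GptOssExperts defaults.
ALPHA = 1.702
LIMIT = 7.0


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids overflow for large |x|)."""
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gpt_oss_glu(gate: float, up: float) -> float:
    """Scalar version of the clamped gated activation applied per expert."""
    gate = min(gate, LIMIT)                  # clamp: upper bound only
    up = max(-LIMIT, min(up, LIMIT))         # clamp: both bounds
    glu = gate * _sigmoid(gate * ALPHA)      # sigmoid + mul
    return (up + 1.0) * glu                  # add + mul


print(gpt_oss_glu(1.0, 0.5))
```

Note the asymmetry: `gate` is clamped only from above, while `up` is clamped on both sides. That accounts for two `aten::clamp` calls per expert, and the `+ 1` on `up` contributes an extra `aten::add`.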



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      41.218ms       116.52%      41.218ms      41.218ms             1  
                                        gpt_oss_experts         6.00%       2.524ms        99.99%      42.088ms      42.088ms       0.000us         0.00%      35.395ms      35.395ms             1  
                                           aten::matmul         0.10%      40.160us         2.08%     875.043us      36.460us       0.000us         0.00%      29.436ms       1.226ms            24  
                                               aten::mm         1.24%     520.099us         1.98%     834.883us      34.787us      29.436ms        83.21%      29.436ms       1.226ms            24  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us      20.785ms        58.75%      20.785ms       1.386ms            15  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       8.635ms        24.41%       8.635ms     959.410us             9  
                                              aten::add         0.83%     349.812us         1.43%     602.505us      16.736us       1.482ms         4.19%       1.482ms      41.161us            36  
                                              aten::mul         0.72%     302.661us         1.25%     525.878us      10.956us       1.369ms         3.87%       1.369ms      28.527us            48  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     928.163us         2.62%     928.163us      38.673us            24  
                                       aten::index_add_         0.23%      95.791us         0.40%     170.382us      14.198us     908.198us         2.57%     908.198us      75.683us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     908.198us         2.57%     908.198us      75.683us            12  
                                            aten::clamp         0.52%     220.263us         0.90%     378.355us      15.765us     771.551us         2.18%     771.551us      32.148us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     771.551us         2.18%     771.551us      32.148us            24  
                                            aten::index         0.83%     351.191us         1.46%     613.487us      25.562us     665.121us         1.88%     665.121us      27.713us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     648.065us         1.83%     648.065us      54.005us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     594.560us         1.68%     594.560us      49.547us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     553.635us         1.57%     553.635us      46.136us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     519.010us         1.47%     519.010us      21.625us            24  
                                          aten::sigmoid         0.17%      72.451us         0.30%     125.701us      10.475us     356.257us         1.01%     356.257us      29.688us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     356.257us         1.01%     356.257us      29.688us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 42.094ms
Self CUDA time total: 35.375ms
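From the E2 trace to this E4 trace (both B4_S1024), the call counts for `aten::mm`, `aten::index_add_`, `aten::sigmoid`, `aten::clamp`, `aten::mul` and `aten::add` all exactly double. That pattern fits a reference implementation that loops over experts and launches a separate set of kernels for each one. Below is a minimal pure-Python sketch of that per-expert dispatch, assuming top-k routing given as per-token lists of expert ids and weights. `moe_dispatch` and its parameters are hypothetical names for illustration. The gather step plays the role of `aten::where`/`aten::index`, and the weighted accumulation plays the role of `aten::index_add_`.

```python
from typing import Callable, List, Sequence


def moe_dispatch(
    tokens: Sequence[float],
    router_indices: Sequence[Sequence[int]],     # top-k expert ids per token
    routing_weights: Sequence[Sequence[float]],  # weight for each chosen expert
    experts: Sequence[Callable[[float], float]],
) -> List[float]:
    """Loop over experts, gather the tokens routed to each, and scatter-add results."""
    out = [0.0] * len(tokens)
    for expert_id, expert_fn in enumerate(experts):
        # Gather: find (token, slot) pairs routed to this expert.
        hits = [
            (t, k)
            for t, ids in enumerate(router_indices)
            for k, e in enumerate(ids)
            if e == expert_id
        ]
        if not hits:
            continue  # an expert that receives no tokens does no work
        # Scatter-add: accumulate the weighted expert output per token.
        for t, k in hits:
            out[t] += routing_weights[t][k] * expert_fn(tokens[t])
    return out


experts = [lambda x: 2.0 * x, lambda x: x + 1.0]
print(moe_dispatch([1.0, 3.0], [[0, 1], [1]], [[0.5, 0.5], [1.0]], experts))
```

Each iteration of the expert loop is a separate round of launches, so kernel count grows linearly with the number of experts.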


impl                     workload            p50(ms)  ok
gpt_oss_experts          cuda_B1_S1024_E2       3.84  True
gpt_oss_experts          cuda_B1_S1024_E4       5.30  True
gpt_oss_experts          cuda_B1_S512_E2        2.68  True
gpt_oss_experts          cuda_B1_S512_E4        3.91  True
gpt_oss_experts          cuda_B4_S1024_E2      13.35  True
gpt_oss_experts          cuda_B4_S1024_E4      13.35  True
gpt_oss_experts          cuda_B4_S512_E2        6.80  True
gpt_oss_experts          cuda_B4_S512_E4        7.46  True
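Assuming the workload names encode batch size (`B`), sequence length (`S`) and number of experts (`E`), the p50 column can be normalized to latency per token. A small sketch using the values from the table above (`parse_workload` and `us_per_token` are hypothetical helpers):

```python
import re
from typing import Tuple

# p50 latencies (ms) copied from the summary table above.
P50_MS = {
    "cuda_B1_S1024_E2": 3.84,
    "cuda_B1_S1024_E4": 5.30,
    "cuda_B1_S512_E2": 2.68,
    "cuda_B1_S512_E4": 3.91,
    "cuda_B4_S1024_E2": 13.35,
    "cuda_B4_S1024_E4": 13.35,
    "cuda_B4_S512_E2": 6.80,
    "cuda_B4_S512_E4": 7.46,
}


def parse_workload(name: str) -> Tuple[int, int, int]:
    """Split 'cuda_B4_S1024_E2' into (batch, seq_len, experts); the naming is assumed."""
    m = re.fullmatch(r"cuda_B(\d+)_S(\d+)_E(\d+)", name)
    if m is None:
        raise ValueError(f"unrecognized workload name: {name}")
    return int(m.group(1)), int(m.group(2)), int(m.group(3))


def us_per_token(name: str) -> float:
    """Median latency divided by the number of tokens (batch * seq_len), in microseconds."""
    batch, seq_len, _ = parse_workload(name)
    return 1000.0 * P50_MS[name] / (batch * seq_len)


for name in sorted(P50_MS):
    print(f"{name:18s} {us_per_token(name):6.2f} us/token")
```

By this measure the B4 runs are cheaper per token than their B1 counterparts, for example about 3.26 us/token for `cuda_B4_S1024_E2` versus 3.75 us/token for `cuda_B1_S1024_E2`.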

Artifacts:

openai_moe.jsonl