GptOssExperts - OpenAI-style MoE

GPU Info

import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Fri Oct 31 20:00:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   34C    P0             81W /  350W |       0MiB /  46068MiB |     18%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

OpenAI-style MoE Benchmark (GptOssExperts Reference)

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
#     "kernels",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# kernels = { git = "https://github.com/huggingface/kernels.git" }
# ///
import torch
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark
from kernels import get_kernel

# Load yamoe to get GptOssExperts reference
yamoe = get_kernel("drbh/yamoe", revision="v0.2.0")
GptOssExperts = yamoe.vendored.gpt_oss_mlp.GptOssExperts


def gpt_oss_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    Reference OpenAI-style MoE forward pass using GptOssExperts, the expert
    module vendored in drbh/yamoe from the original GPT-OSS model code.
    """
    B, S, H = hidden_states.shape
    E = routing_weights.shape[2]

    # Create a config object for GptOssExperts
    config = type("Config", (), {})()
    config.hidden_size = H
    config.intermediate_size = gate_up_proj.shape[2] // 2  # gate_up_proj is [E, H, 2 * I]
    config.num_local_experts = E

    # Initialize model
    model = GptOssExperts(config)

    # Set weights from benchmark inputs
    model.gate_up_proj.data = gate_up_proj
    model.gate_up_proj_bias.data = gate_up_proj_bias
    model.down_proj.data = down_proj
    model.down_proj_bias.data = down_proj_bias

    model = model.to(hidden_states.device)

    # Put the module in training mode so GptOssExperts.forward takes its sparse
    # per-expert loop (the "CPU path", which it also uses for CPU inputs).
    # That loop dispatches tokens explicitly via router_indices, matching
    # naive_moe_ref; the dense inference path runs every expert on every token,
    # which can introduce small numerical differences.
    # torch.no_grad() below keeps autograd off despite training mode.
    model.train()

    # Flatten routing_weights to [batch_seq, num_experts]
    routing_weights_flat = routing_weights.view(-1, E)

    # Run forward pass
    with torch.no_grad():
        output = model(hidden_states, router_indices, routing_weights_flat)

    model.eval()  # Reset to eval mode

    return output


run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="gpt_oss_experts",
    impl_tags={"family": "reference", "backend": "pytorch"},
    impl_func=gpt_oss_openai_moe,
    dtype="float32",
)
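The per-expert path that the benchmark forces can be pictured with a minimal pure-Python sketch. It assumes the vendored GptOssExperts follows the Hugging Face transformers implementation: select the tokens whose `router_indices` include an expert, run them through that expert, scale by the dense `[tokens, E]` routing weights (the `view(-1, E)` above), and accumulate into the output (the `index_add_` kernels in the traces below). `sparse_moe` and `toy_expert` are hypothetical names, and the toy expert stands in for the real gate_up/down MLP.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def sparse_moe(
    hidden: Sequence[Vector],                    # [tokens][H]
    router_indices: Sequence[Sequence[int]],     # [tokens][top_k] chosen experts
    routing_weights: Sequence[Vector],           # [tokens][E] dense weights
    expert_fn: Callable[[int, Vector], Vector],  # hypothetical expert MLP
) -> List[Vector]:
    num_tokens = len(hidden)
    num_experts = len(routing_weights[0]) if routing_weights else 0
    hidden_size = len(hidden[0]) if hidden else 0
    out = [[0.0] * hidden_size for _ in range(num_tokens)]

    for e in range(num_experts):
        # Select the tokens whose router_indices include expert e.
        token_idx = [t for t in range(num_tokens) if e in router_indices[t]]
        if not token_idx:
            continue  # experts that received no tokens are skipped
        for t in token_idx:
            y = expert_fn(e, list(hidden[t]))
            w = routing_weights[t][e]
            # Accumulate the weighted expert output into the token's row.
            out[t] = [acc + w * v for acc, v in zip(out[t], y)]
    return out


def toy_expert(e: int, x: Vector) -> Vector:
    """Stand-in expert: scales the input by (e + 1)."""
    return [(e + 1) * v for v in x]


result = sparse_moe(
    hidden=[[1.0, 2.0], [3.0, 4.0]],
    router_indices=[[0, 1], [1, 0]],
    routing_weights=[[0.25, 0.75], [0.5, 0.5]],
    expert_fn=toy_expert,
)
print(result)  # [[1.75, 3.5], [4.5, 6.0]]
```

The real module gathers all of an expert's tokens at once and runs them through one pair of GEMMs; the inner per-token loop here only keeps the sketch dependency-free.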
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      10.211ms       197.81%      10.211ms      10.211ms             1  
                                        gpt_oss_experts        16.48%       2.023ms        99.94%      12.270ms      12.270ms       0.000us         0.00%       5.165ms       5.165ms             1  
                                           aten::matmul         0.22%      26.489us         3.82%     468.520us      39.043us       0.000us         0.00%       4.540ms     378.357us            12  
                                               aten::mm         2.36%     289.825us         3.60%     442.031us      36.836us       4.540ms        87.96%       4.540ms     378.357us            12  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       3.078ms        59.62%       3.078ms     341.948us             9  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       1.457ms        28.23%       1.457ms     485.813us             3  
                                              aten::mul         1.42%     174.948us         2.34%     287.701us      11.988us     109.119us         2.11%     109.119us       4.547us            24  
                                              aten::add         1.61%     197.786us         3.85%     472.357us      26.242us     103.039us         2.00%     103.039us       5.724us            18  
                                            aten::index         1.73%     212.127us         2.86%     350.900us      29.242us      86.591us         1.68%      86.591us       7.216us            12  
                                       aten::index_add_         0.51%      62.499us         0.79%      97.312us      16.219us      82.688us         1.60%      82.688us      13.781us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us      82.688us         1.60%      82.688us      13.781us             6  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      80.511us         1.56%      80.511us       6.709us            12  
                                          aten::nonzero         2.20%     270.146us         6.58%     808.380us      89.820us      63.743us         1.23%      74.368us       8.263us             9  
                                            aten::clamp         0.98%     120.045us         1.63%     200.026us      16.669us      64.705us         1.25%      64.705us       5.392us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      64.705us         1.25%      64.705us       5.392us            12  
                                            aten::where         0.06%       7.400us         5.25%     644.007us     107.334us       0.000us         0.00%      60.384us      10.064us             6  
                                    aten::nonzero_numpy         0.11%      13.320us         5.19%     636.607us     106.101us       0.000us         0.00%      60.384us      10.064us             6  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us      60.063us         1.16%      60.063us      10.011us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us      56.800us         1.10%      56.800us       4.733us            12  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      50.911us         0.99%      50.911us       1.131us            45  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 12.278ms
Self CUDA time total: 5.162ms
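The "# of Calls" column is consistent with a fixed per-expert tally multiplied by the number of experts and a constant factor of 3. The sketch below encodes that tally; the per-expert counts and the factor of 3 are inferred from the tables (presumably three forward passes inside the profiled region, though the harness code is not shown here), not read from the GptOssExperts source.

```python
from typing import Dict

# Per-expert ATen call counts for one pass of the per-expert loop.
# Inferred from the profiler tables above, not read from the GptOssExperts source.
PER_EXPERT_OPS: Dict[str, int] = {
    "aten::matmul": 2,      # gate_up projection and down projection
    "aten::mul": 4,
    "aten::add": 3,
    "aten::clamp": 2,
    "aten::sigmoid": 1,
    "aten::where": 1,       # select the tokens routed to this expert
    "aten::index": 2,       # gather those tokens and their routing weights
    "aten::index_add_": 1,  # scatter weighted outputs back
}


def expected_calls(num_experts: int, repeats: int = 3) -> Dict[str, int]:
    """Expected '# of Calls' per op, assuming every expert receives tokens.

    `repeats` is the constant factor seen in the traces (inferred, presumably
    the number of forward passes inside the profiled region).
    """
    return {op: n * num_experts * repeats for op, n in PER_EXPERT_OPS.items()}


print(expected_calls(2)["aten::matmul"], expected_calls(4)["aten::matmul"])  # 12 24
```

Against the tables: the E2 trace shows 12 matmul, 24 mul, 18 add, 12 clamp, 6 where, 12 index and 6 index_add_ calls, the E4 trace shows exactly twice each, and the B4_S512_E2 trace lists 6 sigmoid calls.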



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      13.933ms       229.38%      13.933ms      13.933ms             1  
                                        gpt_oss_experts        16.29%       2.560ms        99.97%      15.712ms      15.712ms       0.000us         0.00%       6.077ms       6.077ms             1  
                                           aten::matmul         0.30%      47.223us         5.17%     812.581us      33.858us       0.000us         0.00%       5.268ms     219.512us            24  
                                               aten::mm         3.09%     485.951us         4.87%     765.358us      31.890us       5.268ms        86.73%       5.268ms     219.512us            24  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       5.213ms        85.81%       5.213ms     217.198us            24  
                                          aten::nonzero         2.45%     385.408us         7.89%       1.240ms      82.649us     112.163us         1.85%     134.498us       8.967us            15  
                                              aten::mul         2.03%     318.275us         3.36%     528.222us      11.005us     130.496us         2.15%     130.496us       2.719us            48  
                                              aten::add         2.25%     353.820us         3.74%     587.771us      16.327us     127.072us         2.09%     127.072us       3.530us            36  
                                            aten::where         0.08%      11.882us         7.49%       1.177ms      98.080us       0.000us         0.00%     120.705us      10.059us            12  
                                    aten::nonzero_numpy         0.15%      24.083us         7.41%       1.165ms      97.090us       0.000us         0.00%     120.705us      10.059us            12  
                                            aten::index         2.31%     363.442us         3.93%     617.030us      25.710us     110.145us         1.81%     110.145us       4.589us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     101.312us         1.67%     101.312us       4.221us            24  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      91.447us         1.51%      91.447us       1.051us            87  
                                            aten::clamp         1.32%     207.076us         2.26%     355.011us      14.792us      85.793us         1.41%      85.793us       3.575us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      85.793us         1.41%      85.793us       3.575us            24  
                                             aten::item         0.52%      81.620us        38.60%       6.066ms      84.255us       0.000us         0.00%      75.446us       1.048us            72  
                              aten::_local_scalar_dense         2.00%     315.046us        38.08%       5.985ms      83.122us      75.446us         1.24%      75.446us       1.048us            72  
                                       aten::index_add_         0.75%     118.511us         1.16%     182.084us      15.174us      72.926us         1.20%      72.926us       6.077us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us      72.926us         1.20%      72.926us       6.077us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us      65.857us         1.08%      65.857us       5.488us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 15.717ms
Self CUDA time total: 6.074ms
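To compare kernels across workloads without reading these tables by eye, a row can be parsed by splitting off the ten numeric columns from the right, since kernel names may contain spaces or be truncated with "...". This is a minimal sketch for working from the saved text, assuming the column order in the header above; `parse_profile_row` is a hypothetical helper, not part of kernels-benchmark-tools. If you hold the `torch.profiler` object itself, `key_averages()` exposes these fields directly.

```python
from typing import Dict, Union

# Multipliers to microseconds; "us" and "ms" are checked before the bare "s".
_UNITS = (("us", 1.0), ("ms", 1_000.0), ("s", 1_000_000.0))


def _to_us(field: str) -> float:
    for suffix, scale in _UNITS:
        if field.endswith(suffix):
            return float(field[: -len(suffix)]) * scale
    raise ValueError(f"unrecognised time field: {field!r}")


def parse_profile_row(line: str) -> Dict[str, Union[str, float, int]]:
    """Parse one row of the profiler table (hypothetical helper).

    The last ten whitespace-separated fields are numeric columns; everything
    before them is the (possibly space-containing, possibly truncated) name.
    """
    parts = line.strip().rsplit(maxsplit=10)
    if len(parts) != 11:
        raise ValueError("expected a name followed by ten numeric columns")
    name, cols = parts[0], parts[1:]
    return {
        "name": name,
        "self_cuda_us": _to_us(cols[5]),
        "cuda_total_us": _to_us(cols[7]),
        "calls": int(cols[9]),
    }


row = parse_profile_row(
    "aten::mm  2.36%  289.825us  3.60%  442.031us  36.836us  "
    "4.540ms  87.96%  4.540ms  378.357us  12"
)
print(row["name"], row["calls"])  # aten::mm 12
```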



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      12.540ms       148.48%      12.540ms      12.540ms             1  
                                        gpt_oss_experts        11.83%       1.734ms        99.96%      14.654ms      14.654ms       0.000us         0.00%       8.451ms       8.451ms             1  
                                           aten::matmul         0.16%      23.602us         3.00%     439.592us      36.633us       0.000us         0.00%       7.417ms     618.087us            12  
                                               aten::mm         1.78%     261.037us         2.84%     415.990us      34.666us       7.417ms        87.82%       7.417ms     618.087us            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       4.532ms        53.65%       4.532ms     755.263us             6  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       1.475ms        17.46%       1.475ms     491.509us             3  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       1.405ms        16.64%       1.405ms     468.490us             3  
                                              aten::mul         1.05%     153.262us         1.78%     261.173us      10.882us     197.791us         2.34%     197.791us       8.241us            24  
                                              aten::add         1.26%     184.574us         2.07%     304.007us      16.889us     188.543us         2.23%     188.543us      10.475us            18  
                                       aten::index_add_         0.35%      50.951us         0.57%      83.553us      13.925us     169.408us         2.01%     169.408us      28.235us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     169.408us         2.01%     169.408us      28.235us             6  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     149.663us         1.77%     149.663us      12.472us            12  
                                            aten::index         1.27%     186.102us         2.16%     316.927us      26.411us     146.942us         1.74%     146.942us      12.245us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     117.440us         1.39%     117.440us      19.573us             6  
                                            aten::clamp         0.71%     104.743us         1.22%     178.924us      14.910us     110.912us         1.31%     110.912us       9.243us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     110.912us         1.31%     110.912us       9.243us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     104.864us         1.24%     104.864us       8.739us            12  
                                          aten::nonzero         1.58%     232.211us         4.94%     724.348us      80.483us      69.633us         0.82%      81.377us       9.042us             9  
                                            aten::where         0.04%       6.259us         4.08%     597.684us      99.614us       0.000us         0.00%      66.816us      11.136us             6  
                                    aten::nonzero_numpy         0.08%      11.999us         4.03%     591.425us      98.571us       0.000us         0.00%      66.816us      11.136us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 14.659ms
Self CUDA time total: 8.446ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      18.317ms       174.31%      18.317ms      18.317ms             1  
                                        gpt_oss_experts        13.54%       2.761ms        99.97%      20.385ms      20.385ms       0.000us         0.00%      10.514ms      10.514ms             1  
                                           aten::matmul         0.23%      47.082us         4.02%     819.853us      34.161us       0.000us         0.00%       9.237ms     384.865us            24  
                                               aten::mm         2.37%     482.255us         3.79%     772.771us      32.199us       9.237ms        87.90%       9.237ms     384.865us            24  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       6.282ms        59.78%       6.282ms     349.001us            18  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       2.944ms        28.01%       2.944ms     490.655us             6  
                                              aten::mul         1.50%     305.331us         2.55%     520.818us      10.850us     235.298us         2.24%     235.298us       4.902us            48  
                                              aten::add         1.72%     351.113us         2.86%     584.036us      16.223us     213.502us         2.03%     213.502us       5.931us            36  
                                            aten::index         1.95%     397.314us         3.28%     668.454us      27.852us     205.349us         1.95%     205.349us       8.556us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     166.720us         1.59%     166.720us       6.947us            24  
                                       aten::index_add_         0.50%     101.340us         0.81%     165.573us      13.798us     155.585us         1.48%     155.585us      12.965us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     155.585us         1.48%     155.585us      12.965us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     146.947us         1.40%     146.947us      12.246us            12  
                                          aten::nonzero         1.95%     398.176us         6.26%       1.276ms      85.090us     121.380us         1.16%     145.668us       9.711us            15  
                                            aten::clamp         1.04%     212.193us         1.79%     365.180us      15.216us     134.239us         1.28%     134.239us       5.593us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     134.239us         1.28%     134.239us       5.593us            24  
                                            aten::where         0.06%      11.340us         5.97%       1.216ms     101.373us       0.000us         0.00%     131.522us      10.960us            12  
                                    aten::nonzero_numpy         0.12%      24.140us         5.91%       1.205ms     100.428us       0.000us         0.00%     131.522us      10.960us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     119.840us         1.14%     119.840us       4.993us            24  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     100.830us         0.96%     100.830us       1.159us            87  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.390ms
Self CUDA time total: 10.509ms
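The four B=1 traces form a 2x2 grid over S and E, so their "Self CUDA time total" values can be compared directly. The sketch below hard-codes those totals and computes scaling ratios; it is arithmetic on these traces only and assumes each trace captures the same number of forward passes.

```python
from typing import Dict, Tuple

# "Self CUDA time total" in ms from the four B=1 traces above, keyed by (S, E).
CUDA_MS: Dict[Tuple[int, int], float] = {
    (512, 2): 5.162,
    (512, 4): 6.074,
    (1024, 2): 8.446,
    (1024, 4): 10.509,
}


def ratio(num: Tuple[int, int], den: Tuple[int, int]) -> float:
    """How many times longer workload `num` took than workload `den`."""
    return CUDA_MS[num] / CUDA_MS[den]


for e in (2, 4):
    print(f"E={e}: S 512 -> 1024 costs {ratio((1024, e), (512, e)):.2f}x")
for s in (512, 1024):
    print(f"S={s}: E 2 -> 4 costs {ratio((s, 4), (s, 2)):.2f}x")
```

With these totals, doubling S costs about 1.64x (E=2) and 1.73x (E=4), and doubling E costs about 1.18x (S=512) and 1.24x (S=1024), so both scale below 2x on this L40S for the reference path.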



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      21.031ms       119.92%      21.031ms      21.031ms             1  
                                        gpt_oss_experts         7.59%       1.747ms        99.98%      23.024ms      23.024ms       0.000us         0.00%      17.548ms      17.548ms             1  
                                           aten::matmul         0.10%      23.660us         1.94%     446.020us      37.168us       0.000us         0.00%      14.659ms       1.222ms            12  
                                               aten::mm         1.17%     268.524us         1.83%     422.360us      35.197us      14.659ms        83.59%      14.659ms       1.222ms            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       8.967ms        51.13%       8.967ms       1.495ms             6  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       5.685ms        32.42%       5.685ms     947.562us             6  
                                              aten::add         0.82%     187.722us         1.36%     312.616us      17.368us     785.408us         4.48%     785.408us      43.634us            18  
                                              aten::mul         0.68%     156.369us         1.15%     264.222us      11.009us     674.688us         3.85%     674.688us      28.112us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     504.575us         2.88%     504.575us      42.048us            12  
                                       aten::index_add_         0.22%      50.951us         0.37%      86.132us      14.355us     448.545us         2.56%     448.545us      74.757us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     448.545us         2.56%     448.545us      74.757us             6  
                                            aten::clamp         0.46%     107.053us         0.80%     183.295us      15.275us     336.000us         1.92%     336.000us      28.000us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     336.000us         1.92%     336.000us      28.000us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     314.239us         1.79%     314.239us      52.373us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     280.833us         1.60%     280.833us      46.806us             6  
                                            aten::index         0.81%     185.806us         1.39%     320.548us      26.712us     259.102us         1.48%     259.102us      21.592us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     258.944us         1.48%     258.944us      21.579us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     225.407us         1.29%     225.407us      37.568us             6  
                                          aten::sigmoid         0.16%      36.131us         0.27%      61.901us      10.317us     175.073us         1.00%     175.073us      29.179us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     175.073us         1.00%     175.073us      29.179us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 23.030ms
Self CUDA time total: 17.537ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      24.377ms       140.11%      24.377ms      24.377ms             1  
                                        gpt_oss_experts        10.50%       2.651ms        99.98%      25.237ms      25.237ms       0.000us         0.00%      17.408ms      17.408ms             1  
                                           aten::matmul         0.19%      47.519us         3.41%     860.801us      35.867us       0.000us         0.00%      15.185ms     632.705us            24  
                                               aten::mm         2.06%     521.061us         3.22%     813.282us      33.887us      15.185ms        87.28%      15.185ms     632.705us            24  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us       9.179ms        52.76%       9.179ms     764.922us            12  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       3.124ms        17.96%       3.124ms     520.682us             6  
void cutlass::Kernel2<cutlass_80_simt_sgemm_128x64_8...         0.00%       0.000us         0.00%       0.000us       0.000us       2.871ms        16.50%       2.871ms     478.432us             6  
                                              aten::add         1.42%     359.495us         2.37%     598.003us      16.611us     427.713us         2.46%     427.713us      11.881us            36  
                                              aten::mul         1.23%     309.946us         2.09%     527.073us      10.981us     420.510us         2.42%     420.510us       8.761us            48  
                                       aten::index_add_         0.40%     101.283us         0.66%     166.886us      13.907us     383.489us         2.20%     383.489us      31.957us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     383.489us         2.20%     383.489us      31.957us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     343.712us         1.98%     343.712us      14.321us            24  
                                            aten::index         1.56%     393.991us         2.62%     662.158us      27.590us     337.086us         1.94%     337.086us      14.045us            24  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     272.926us         1.57%     272.926us      22.744us            12  
                                            aten::clamp         0.84%     212.993us         1.44%     363.038us      15.127us     230.431us         1.32%     230.431us       9.601us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     230.431us         1.32%     230.431us       9.601us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     223.071us         1.28%     223.071us       9.295us            24  
                                          aten::nonzero         1.57%     395.401us         5.00%       1.262ms      84.127us     128.836us         0.74%     156.164us      10.411us            15  
                                            aten::where         0.05%      12.011us         4.77%       1.205ms     100.378us       0.000us         0.00%     140.900us      11.742us            12  
                                    aten::nonzero_numpy         0.10%      25.021us         4.72%       1.193ms      99.377us       0.000us         0.00%     140.900us      11.742us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 25.242ms
Self CUDA time total: 17.398ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      40.556ms       109.47%      40.556ms      40.556ms             1  
                                        gpt_oss_experts         4.33%       1.794ms        99.85%      41.353ms      41.353ms       0.000us         0.00%      37.080ms      37.080ms             1  
                                           aten::matmul         0.06%      24.371us         1.08%     445.903us      37.159us       0.000us         0.00%      27.082ms       2.257ms            12  
                                               aten::mm         0.70%     291.738us         1.02%     421.532us      35.128us      27.082ms        73.10%      27.082ms       2.257ms            12  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us      27.079ms        73.09%      27.079ms       2.257ms            12  
                                              aten::mul         0.38%     159.199us         0.65%     268.178us      11.174us       2.983ms         8.05%       2.983ms     124.287us            24  
                                              aten::add         0.48%     198.424us         1.09%     451.763us      25.098us       2.404ms         6.49%       2.404ms     133.559us            18  
                                            aten::clamp         0.27%     112.290us         0.46%     189.433us      15.786us       2.392ms         6.46%       2.392ms     199.373us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.392ms         6.46%       2.392ms     199.373us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.988ms         5.37%       1.988ms     165.669us            12  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       1.629ms         4.40%       1.629ms     135.763us            12  
                                       aten::index_add_         0.12%      50.103us         0.20%      84.453us      14.076us     899.456us         2.43%     899.456us     149.909us             6  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     899.456us         2.43%     899.456us     149.909us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     774.912us         2.09%     774.912us     129.152us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     733.217us         1.98%     733.217us     122.203us             6  
                                            aten::index         0.45%     187.302us         0.77%     318.787us      26.566us     712.767us         1.92%     712.767us      59.397us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     678.496us         1.83%     678.496us     113.083us             6  
                                          aten::sigmoid         0.09%      36.082us         0.15%      63.023us      10.504us     323.008us         0.87%     323.008us      53.835us             6  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     323.008us         0.87%     323.008us      53.835us             6  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     261.631us         0.71%     261.631us      43.605us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 41.415ms
Self CUDA time total: 37.046ms



======================================================================
PROFILE TRACE: gpt_oss_experts | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        gpt_oss_experts         0.00%       0.000us         0.00%       0.000us       0.000us      41.050ms       117.27%      41.050ms      41.050ms             1  
                                        gpt_oss_experts         6.46%       2.709ms        99.99%      41.912ms      41.912ms       0.000us         0.00%      35.025ms      35.025ms             1  
                                           aten::matmul         0.11%      47.590us         2.12%     888.873us      37.036us       0.000us         0.00%      29.051ms       1.210ms            24  
                                               aten::mm         1.28%     536.727us         2.01%     841.283us      35.053us      29.051ms        82.99%      29.051ms       1.210ms            24  
void cutlass::Kernel2<cutlass_80_simt_sgemm_256x128_...         0.00%       0.000us         0.00%       0.000us       0.000us      20.585ms        58.81%      20.585ms       1.372ms            15  
                                 ampere_sgemm_128x64_nn         0.00%       0.000us         0.00%       0.000us       0.000us       8.453ms        24.15%       8.453ms     939.204us             9  
                                              aten::add         0.88%     367.610us         1.45%     609.056us      16.918us       1.486ms         4.24%       1.486ms      41.264us            36  
                                              aten::mul         0.74%     309.128us         1.24%     518.283us      10.798us       1.380ms         3.94%       1.380ms      28.757us            48  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     925.695us         2.64%     925.695us      38.571us            24  
                                       aten::index_add_         0.24%      99.111us         0.40%     167.273us      13.939us     903.487us         2.58%     903.487us      75.291us            12  
void at::native::indexFuncLargeIndex<float, long, un...         0.00%       0.000us         0.00%       0.000us       0.000us     903.487us         2.58%     903.487us      75.291us            12  
                                            aten::clamp         0.51%     214.986us         0.87%     364.790us      15.200us     775.806us         2.22%     775.806us      32.325us            24  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     775.806us         2.22%     775.806us      32.325us            24  
                                            aten::index         0.89%     373.269us         1.50%     629.207us      26.217us     670.881us         1.92%     670.881us      27.953us            24  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     631.200us         1.80%     631.200us      52.600us            12  
void at::native::vectorized_gather_kernel<16, long>(...         0.00%       0.000us         0.00%       0.000us       0.000us     600.224us         1.71%     600.224us      50.019us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     559.808us         1.60%     559.808us      46.651us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     540.611us         1.54%     540.611us      22.525us            24  
                                          aten::sigmoid         0.17%      72.182us         0.29%     123.582us      10.298us     351.039us         1.00%     351.039us      29.253us            12  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     351.039us         1.00%     351.039us      29.253us            12  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 41.917ms
Self CUDA time total: 35.005ms
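You can recompute how much device time `aten::mm` takes from the two labeled traces. Divide its `Self CUDA` value by the `Self CUDA time total` line. This sketch hard-codes those numbers in milliseconds instead of parsing the profiler output. The dictionary layout and the `gemm_share` helper are illustrative and are not part of the benchmark harness.

```python
# Values copied by hand from the two labeled PROFILE TRACE tables (milliseconds).
# Hypothetical layout: not produced by the benchmark itself.
profiles = {
    "cuda_B4_S1024_E2": {"aten::mm": 27.082, "self_cuda_total": 37.046, "mm_calls": 12},
    "cuda_B4_S1024_E4": {"aten::mm": 29.051, "self_cuda_total": 35.005, "mm_calls": 24},
}


def gemm_share(row: dict) -> float:
    """Fraction of total self CUDA time spent in aten::mm."""
    return row["aten::mm"] / row["self_cuda_total"]


def ms_per_mm_call(row: dict) -> float:
    """Mean device time per aten::mm call."""
    return row["aten::mm"] / row["mm_calls"]


for wl, row in profiles.items():
    print(f"{wl}: {gemm_share(row):.2%} of CUDA time in aten::mm, "
          f"{ms_per_mm_call(row):.3f} ms per call")
```

Going from E2 to E4 at B4_S1024, the GEMM share rises from about 73% to about 83% of CUDA time. Over the same step, the number of `aten::mm` calls doubles (12 to 24) and the mean cost per call roughly halves (2.257 ms to 1.210 ms).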


impl                     wl                  p50(ms)  ok
gpt_oss_experts          cuda_B1_S1024_E2       3.79  True
gpt_oss_experts          cuda_B1_S1024_E4       5.24  True
gpt_oss_experts          cuda_B1_S512_E2        2.63  True
gpt_oss_experts          cuda_B1_S512_E4        3.89  True
gpt_oss_experts          cuda_B4_S1024_E2      13.28  True
gpt_oss_experts          cuda_B4_S1024_E4      13.19  True
gpt_oss_experts          cuda_B4_S512_E2        6.74  True
gpt_oss_experts          cuda_B4_S512_E4        7.36  True
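The p50 latencies above can be converted into a rough throughput figure. This sketch assumes the workload labels encode batch size (`B`), sequence length (`S`) and number of experts (`E`), so each call processes `B * S` tokens. That reading of the label is an assumption, and `parse_workload` and `tokens_per_ms` are hypothetical helpers.

```python
import re

# p50 latencies (ms) copied from the summary table above.
P50_MS = {
    "cuda_B1_S1024_E2": 3.79,
    "cuda_B1_S1024_E4": 5.24,
    "cuda_B1_S512_E2": 2.63,
    "cuda_B1_S512_E4": 3.89,
    "cuda_B4_S1024_E2": 13.28,
    "cuda_B4_S1024_E4": 13.19,
    "cuda_B4_S512_E2": 6.74,
    "cuda_B4_S512_E4": 7.36,
}

# Assumed label format: cuda_B<batch>_S<seq_len>_E<experts>
WL_RE = re.compile(r"cuda_B(\d+)_S(\d+)_E(\d+)")


def parse_workload(name: str) -> tuple[int, int, int]:
    """Split a workload label into (batch, seq_len, experts)."""
    m = WL_RE.fullmatch(name)
    if m is None:
        raise ValueError(f"unrecognized workload label: {name}")
    b, s, e = (int(g) for g in m.groups())
    return b, s, e


def tokens_per_ms(name: str) -> float:
    """Tokens processed per millisecond at the p50 latency."""
    b, s, _ = parse_workload(name)
    return b * s / P50_MS[name]


for name in sorted(P50_MS):
    b, s, e = parse_workload(name)
    print(f"{name}: {b * s:5d} tokens, {tokens_per_ms(name):7.1f} tokens/ms")
```

Under this reading, cuda_B4_S1024_E2 handles roughly 308 tokens/ms, compared with roughly 270 tokens/ms for cuda_B1_S1024_E2. Larger batches therefore amortize the per-call overhead somewhat.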

Artifacts:

openai_moe.jsonl