Binned PyTorch - OpenAI-style MoE

GPU Info

▼ code ▼ output ▶ uv-logs | Cell: nv | 0.24s | Raw GitHub
import subprocess

# Print the GPU inventory so it is recorded alongside the benchmark results.
# Failures are reported explicitly instead of printing an empty line:
# a non-zero exit status surfaces stderr, and a missing `nvidia-smi`
# binary (FileNotFoundError from subprocess.run) is reported rather than
# aborting the cell.
try:
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
except FileNotFoundError:
    print("nvidia-smi not found on PATH; GPU info unavailable")
else:
    print(result.stdout)
    if result.returncode != 0:
        print(f"nvidia-smi exited with status {result.returncode}: {result.stderr}")
Fri Oct 31 20:00:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   34C    P0             81W /  350W |       0MiB /  46068MiB |     18%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

OpenAI-style MoE Benchmark (Binned PyTorch)

▼ code ▼ output ▶ uv-logs | Cell: benchmark | 727.85s | Raw GitHub
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Copy token rows of ``x`` into zero-padded, fixed-capacity expert bins.

    Args:
        x: 2-D token matrix; row ``t`` is the features of token ``t``.
        indices: flattened (token, slot) positions grouped by expert, e.g. the
            sorted indices produced by ``sort_tokens_by_expert``.
        bins: cumulative assignment counts; expert ``e`` owns the entries
            ``indices[bins[e - 1]:bins[e]]`` (starting at 0 for ``e == 0``).
        expert_capacity: rows reserved per expert; assignments beyond this
            count are silently dropped.
        top_k: slots per token, used to recover ``token = flat_pos // top_k``.

    Returns:
        Tensor of shape ``(len(bins), expert_capacity, x.shape[1])`` on the
        device/dtype of ``x``; unused rows stay zero.
    """
    num_experts, hidden = bins.shape[0], x.shape[1]
    gathered = torch.zeros(
        (num_experts, expert_capacity, hidden), device=x.device, dtype=x.dtype
    )
    lo = 0
    for expert in range(num_experts):
        hi = bins[expert]
        count = min(hi - lo, expert_capacity)
        # Element-by-element tensor indexing: loop bounds and indices are
        # scalar tensor reads, which is what makes this reference path slow
        # (the profile output shows thousands of aten::item calls).
        for row in range(count):
            token = indices[lo + row] // top_k
            gathered[expert, row] = x[token]
        lo = hi
    return gathered


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Scatter per-expert rows back to their tokens and sum over top-k slots.

    Args:
        x: expert outputs of shape ``(num_experts, capacity, hidden)``; row
            ``i`` of expert ``e`` corresponds to ``indices[bins[e-1] + i]``.
        indices: flattened (token, slot) positions grouped by expert.
        weights: optional 1-D per-position scale indexed by flat position;
            ``None`` means every contribution is scaled by 1.0.
        bins: cumulative assignment counts per expert.
        expert_capacity: maximum rows read per expert; assignments beyond
            it contribute nothing (their slot stays zero).
        top_k: slots per token.

    Returns:
        Tensor of shape ``(indices.shape[0] // top_k, hidden)``: for each token,
        the sum of its (weighted) slot contributions.
    """
    num_experts, _, hidden = x.shape
    num_tokens = indices.shape[0] // top_k
    per_slot = torch.zeros(
        (num_tokens, top_k, hidden), dtype=x.dtype, device=x.device
    )
    for expert in range(num_experts):
        lo = bins[expert - 1] if expert > 0 else 0
        assigned = bins[expert] - lo
        if assigned == 0:
            continue
        for row in range(min(assigned, expert_capacity)):
            flat_pos = indices[lo + row]
            # Flat position encodes (token, slot) in row-major order.
            token, slot = flat_pos // top_k, flat_pos % top_k
            weight = 1.0 if weights is None else weights[flat_pos]
            per_slot[token, slot] = x[expert, row] * weight
    return per_slot.sum(dim=1)


def sort_tokens_by_expert(router_indices, num_experts):
    """Group flattened (token, slot) routing assignments by expert id.

    Args:
        router_indices: integer tensor of expert ids. It is flattened in
            row-major order, so for a ``(num_tokens, top_k)`` input the flat
            position ``p`` corresponds to token ``p // top_k``, slot ``p % top_k``.
        num_experts: minimum number of expert bins to report (``bincount``
            ``minlength``).

    Returns:
        A tuple ``(sorted_indices, sorted_values, bins, tokens_per_expert)``:
        - ``sorted_indices``: flat positions ordered by expert id; within one
          expert they keep their original (token, slot) order;
        - ``sorted_values``: the expert ids in that order;
        - ``bins``: inclusive cumulative count per expert, so expert ``e`` owns
          ``sorted_indices[bins[e - 1]:bins[e]]``;
        - ``tokens_per_expert``: number of assignments per expert.
    """
    flat_indices = router_indices.flatten()
    # torch.sort is not stable by default, so the order of equal expert ids
    # is unspecified. That order decides which assignments land past
    # `expert_capacity` and get dropped downstream. A stable sort keeps
    # first-come (token order) priority and makes the result deterministic.
    sorted_values, sorted_indices = torch.sort(flat_indices, stable=True)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
    *,
    limit=7.0,
    alpha=1.702,
):
    """Reference binned MoE forward: gather -> per-expert GLU MLP -> scatter.

    Args:
        hidden_states: tensor unpacked as ``(B, S, H)``.
        router_indices: expert ids with ``K = router_indices.shape[1]``;
            presumably ``(B * S, K)`` since it is reshaped to ``(-1, K)`` and
            gathered row-wise against ``(B * S, E)`` weights -- confirm with caller.
        routing_weights: dense routing weights with ``E`` experts on dim 2
            (reshaped to ``(B * S, E)``).
        gate_up_proj, gate_up_proj_bias: per-expert first projection used in
            ``bmm``; output columns are interleaved gate (even) / up (odd).
        down_proj, down_proj_bias: per-expert second projection used in ``bmm``.
        expert_capacity: rows reserved per expert; overflow assignments are
            dropped by the gather/scatter helpers.
        limit: clamp bound for the gate (upper only) and up (both sides)
            activations. Defaults to the previously hard-coded 7.0.
        alpha: sigmoid scale in ``gate * sigmoid(alpha * gate)``. Defaults to
            the previously hard-coded 1.702.

    Returns:
        Tensor of shape ``(B, S, H)``: the routing-weighted sum of each token's
        expert outputs.
    """
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[2], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    # reshape (not view) so non-contiguous caller tensors are accepted too.
    x = binned_gather(hidden_states.reshape(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    # Gate and up projections are interleaved along the last dimension.
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # Clamp activations: gate only from above, up symmetrically.
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    glu = gate * torch.sigmoid(gate * alpha)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # Pick the routing weight of each (token, slot) assignment so it can be
    # indexed by the same flat position used in `indices`.
    flat_dense = routing_weights.reshape(-1, E)  # [B*S, E]
    flat_router = router_indices.reshape(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]

    return y.reshape(B, S, H)


def binned_torch_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """Binned eager-PyTorch OpenAI-style MoE, the entry point being benchmarked.

    Tokens are grouped by their assigned expert, and each expert processes a
    zero-padded bin of rows with batched matmuls (see ``binned_experts_ref``).

    The per-expert capacity is a heuristic: twice the average number of
    assignments per expert, ``(B * S * K * 2) // E``. It is not the observed
    maximum, so an expert receiving more assignments than that has the
    excess dropped.
    """
    batch = hidden_states.shape[0]
    seq = hidden_states.shape[1]
    top_k = router_indices.shape[1]
    num_experts = routing_weights.shape[2]

    # Twice the mean per-expert load leaves headroom for imbalanced routing.
    capacity = (batch * seq * top_k * 2) // num_experts

    return binned_experts_ref(
        hidden_states,
        router_indices,
        routing_weights,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        capacity,
    )


# Register the binned eager implementation with the shared MoE benchmark
# harness and run it in float32.
run_benchmark(
    impl_name="binned_torch",
    impl_func=binned_torch_openai_moe,
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_tags={"family": "pytorch", "backend": "eager"},
    dtype="float32",
)
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     906.550ms      1808.50%     906.550ms     906.550ms             1  
                                           binned_torch        25.29%     229.728ms       100.00%     908.308ms     908.308ms       0.000us         0.00%      50.129ms      50.129ms             1  
                                             aten::item         1.81%      16.434ms        25.66%     233.033ms      15.186us       0.000us         0.00%      15.809ms       1.030us         15345  
                              aten::_local_scalar_dense         6.08%      55.189ms        23.85%     216.599ms      14.115us      15.808ms        31.54%      15.809ms       1.030us         15345  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      15.808ms        31.54%      15.808ms       1.030us         15345  
                                              aten::bmm         0.02%     187.925us         0.02%     226.636us      37.773us       7.688ms        15.34%       7.688ms       1.281ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.688ms        15.34%       7.688ms       1.281ms             6  
                                     aten::floor_divide         5.37%      48.789ms        13.13%     119.247ms      19.409us       7.554ms        15.07%       7.554ms       1.230us          6144  
                                            aten::copy_         3.71%      33.699ms         9.08%      82.451ms      13.394us       6.606ms        13.18%       6.607ms       1.073us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.602ms        13.17%       6.602ms       1.073us          6153  
                                              aten::mul         3.08%      27.972ms         5.49%      49.893ms      16.194us       4.718ms         9.41%       4.718ms       1.531us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.471ms         8.92%       4.471ms       1.456us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.032ms         8.04%       4.032ms       1.312us          3072  
                                        aten::remainder         3.03%      27.567ms         4.66%      42.309ms      13.772us       3.722ms         7.42%       3.722ms       1.212us          3072  
                                              aten::add         2.91%      26.436ms         4.87%      44.207ms      14.575us       3.546ms         7.07%       3.546ms       1.169us          3033  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.524ms         7.03%       3.524ms       1.147us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.156ms         6.30%       3.156ms       1.042us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.964ms         3.92%       1.964ms       1.279us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.758ms         3.51%       1.758ms       1.145us          1536  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     286.305us         0.57%     286.305us      47.718us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 908.315ms
Self CUDA time total: 50.127ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     939.657ms      1760.51%     939.657ms     939.657ms             1  
                                           binned_torch        24.72%     232.366ms       100.00%     940.175ms     940.175ms       0.000us         0.00%      53.379ms      53.379ms             1  
                                             aten::item         1.65%      15.471ms        26.56%     249.752ms      14.748us       0.000us         0.00%      17.339ms       1.024us         16935  
                              aten::_local_scalar_dense         6.16%      57.893ms        24.92%     234.282ms      13.834us      17.337ms        32.48%      17.339ms       1.024us         16935  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      17.337ms        32.48%      17.337ms       1.024us         16935  
                                              aten::bmm         0.02%     191.684us         0.02%     230.777us      38.463us       7.882ms        14.77%       7.882ms       1.314ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.882ms        14.77%       7.882ms       1.314ms             6  
                                     aten::floor_divide         5.10%      47.974ms        12.37%     116.337ms      18.935us       7.540ms        14.13%       7.541ms       1.227us          6144  
                                            aten::copy_         3.80%      35.738ms         9.00%      84.586ms      13.740us       6.593ms        12.35%       6.595ms       1.071us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.590ms        12.35%       6.590ms       1.071us          6153  
                                              aten::add         4.16%      39.066ms         7.01%      65.874ms      14.342us       5.113ms         9.58%       5.113ms       1.113us          4593  
                                              aten::mul         2.92%      27.472ms         5.20%      48.883ms      15.866us       4.715ms         8.83%       4.715ms       1.530us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.472ms         8.38%       4.472ms       1.456us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.021ms         7.53%       4.021ms       1.309us          3072  
                                        aten::remainder         2.73%      25.664ms         4.27%      40.147ms      13.069us       3.707ms         6.95%       3.707ms       1.207us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.519ms         6.59%       3.519ms       1.146us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.178ms         5.95%       3.178ms       1.049us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.958ms         3.67%       1.958ms       1.275us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.749ms         3.28%       1.749ms       1.139us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.537ms         2.88%       1.537ms       0.985us          1560  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 940.182ms
Self CUDA time total: 53.374ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.751s      1703.41%        1.751s        1.751s             1  
                                           binned_torch        24.63%     431.727ms       100.00%        1.753s        1.753s       0.000us         0.00%     102.829ms     102.829ms             1  
                                             aten::item         1.69%      29.621ms        25.96%     455.095ms      14.915us       0.000us         0.00%      31.387ms       1.029us         30513  
                              aten::_local_scalar_dense         5.96%     104.552ms        24.27%     425.474ms      13.944us      31.383ms        30.52%      31.387ms       1.029us         30513  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      31.383ms        30.52%      31.383ms       1.029us         30513  
                                              aten::bmm         0.01%     224.614us         0.02%     267.595us      44.599us      15.143ms        14.73%      15.143ms       2.524ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.143ms        14.73%      15.143ms       2.524ms             6  
                                     aten::floor_divide         5.56%      97.549ms        13.34%     233.779ms      19.025us      15.089ms        14.68%      15.090ms       1.228us         12288  
                                            aten::copy_         4.01%      70.283ms         9.47%     166.011ms      13.497us      13.317ms        12.95%      13.317ms       1.083us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.313ms        12.95%      13.313ms       1.083us         12294  
                                              aten::mul         3.14%      55.060ms         5.66%      99.236ms      16.128us      11.295ms        10.99%      11.297ms       1.836us          6153  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.940ms         9.67%       9.940ms       1.618us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.059ms         7.84%       8.059ms       1.312us          6144  
                                              aten::add         2.85%      49.952ms         4.90%      85.866ms      14.522us       7.505ms         7.30%       7.506ms       1.269us          5913  
                                        aten::remainder         3.02%      53.015ms         4.74%      83.117ms      13.528us       7.414ms         7.21%       7.416ms       1.207us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.031ms         6.84%       7.031ms       1.144us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.224ms         6.05%       6.224ms       1.053us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.914ms         3.81%       3.914ms       1.274us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.500ms         3.40%       3.500ms       1.139us          3072  
                                            aten::clamp         0.00%      71.603us         0.01%     117.833us      19.639us       1.180ms         1.15%       1.180ms     196.722us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.753s
Self CUDA time total: 102.819ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.834s      1680.90%        1.834s        1.834s             1  
                                           binned_torch        24.76%     454.393ms       100.00%        1.835s        1.835s       0.000us         0.00%     109.119ms     109.119ms             1  
                                             aten::item         1.65%      30.229ms        26.42%     484.819ms      14.374us       0.000us         0.00%      34.734ms       1.030us         33729  
                              aten::_local_scalar_dense         6.08%     111.551ms        24.77%     454.590ms      13.478us      34.731ms        31.83%      34.734ms       1.030us         33729  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      34.731ms        31.83%      34.731ms       1.030us         33729  
                                              aten::bmm         0.01%     219.836us         0.01%     260.868us      43.478us      15.243ms        13.97%      15.243ms       2.540ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.243ms        13.97%      15.243ms       2.540ms             6  
                                     aten::floor_divide         5.37%      98.619ms        12.62%     231.581ms      18.846us      15.065ms        13.81%      15.065ms       1.226us         12288  
                                            aten::copy_         3.65%      66.986ms         8.64%     158.623ms      12.896us      13.313ms        12.20%      13.316ms       1.083us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.309ms        12.20%      13.309ms       1.082us         12297  
                                              aten::mul         2.96%      54.365ms         5.27%      96.616ms      15.702us      10.967ms        10.05%      10.969ms       1.783us          6153  
                                              aten::add         4.05%      74.247ms         6.97%     127.934ms      14.060us      10.631ms         9.74%      10.631ms       1.168us          9099  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.613ms         8.81%       9.613ms       1.565us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.047ms         7.37%       8.047ms       1.310us          6144  
                                        aten::remainder         2.81%      51.641ms         4.37%      80.193ms      13.052us       7.438ms         6.82%       7.438ms       1.211us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.018ms         6.43%       7.018ms       1.142us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.225ms         5.71%       6.225ms       1.053us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.928ms         3.60%       3.928ms       1.279us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.510ms         3.22%       3.510ms       1.143us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.154ms         2.89%       3.154ms       0.990us          3186  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.835s
Self CUDA time total: 109.111ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.518s      1672.53%        3.518s        3.518s             1  
                                           binned_torch        24.37%     858.118ms       100.00%        3.521s        3.521s       0.000us         0.00%     210.357ms     210.357ms             1  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      63.177ms        30.04%      63.177ms       1.026us         61586  
                                             aten::item         1.69%      59.432ms        26.02%     916.275ms      14.878us       0.000us         0.00%      63.177ms       1.026us         61587  
                              aten::_local_scalar_dense         5.96%     209.806ms        24.34%     856.843ms      13.913us      63.176ms        30.03%      63.177ms       1.026us         61587  
                                     aten::floor_divide         5.42%     190.698ms        13.50%     475.217ms      19.337us      30.482ms        14.49%      30.486ms       1.240us         24576  
                                              aten::bmm         0.01%     235.397us         0.01%     281.998us      47.000us      29.291ms        13.93%      29.291ms       4.882ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      29.291ms        13.93%      29.291ms       4.882ms             6  
                                            aten::copy_         3.77%     132.744ms         9.15%     322.282ms      13.107us      26.808ms        12.75%      26.810ms       1.090us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.805ms        12.74%      26.805ms       1.090us         24582  
                                              aten::mul         3.15%     110.895ms         5.78%     203.457ms      16.545us      25.566ms        12.15%      25.568ms       2.079us         12297  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.101ms        10.51%      22.101ms       1.799us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.470ms         7.83%      16.470ms       1.340us         12288  
                                              aten::add         2.99%     105.439ms         5.15%     181.211ms      14.601us      16.115ms         7.66%      16.116ms       1.298us         12411  
                                        aten::remainder         2.99%     105.111ms         4.72%     166.195ms      13.525us      14.836ms         7.05%      14.838ms       1.208us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.014ms         6.66%      14.014ms       1.140us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      12.996ms         6.18%      12.996ms       1.047us         12408  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       7.830ms         3.72%       7.830ms       1.274us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.006ms         3.33%       7.006ms       1.140us          6144  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.626ms         1.25%       2.626ms     437.595us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.521s
Self CUDA time total: 210.342ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.742s      1679.57%        3.742s        3.742s             1  
                                           binned_torch        24.42%     914.204ms       100.00%        3.744s        3.744s       0.000us         0.00%     222.834ms     222.834ms             1  
                                             aten::item         1.73%      64.729ms        26.53%     993.125ms      14.638us       0.000us         0.00%      69.848ms       1.030us         67845  
                              aten::_local_scalar_dense         6.14%     229.850ms        24.80%     928.396ms      13.684us      69.844ms        31.35%      69.848ms       1.030us         67845  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      69.844ms        31.35%      69.844ms       1.030us         67841  
                                     aten::floor_divide         5.29%     197.931ms        12.52%     468.921ms      19.080us      30.509ms        13.69%      30.515ms       1.242us         24576  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      29.140ms        13.08%      29.140ms       4.857ms             6  
                                              aten::bmm         0.01%     232.675us         0.01%     273.538us      45.590us      29.140ms        13.08%      29.140ms       4.857ms             6  
                                            aten::copy_         3.66%     136.881ms         8.73%     326.908ms      13.295us      26.646ms        11.96%      26.647ms       1.084us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.643ms        11.96%      26.643ms       1.084us         24581  
                                              aten::mul         2.96%     110.832ms         5.24%     196.253ms      15.959us      25.520ms        11.45%      25.522ms       2.075us         12297  
                                              aten::add         4.16%     155.619ms         7.13%     266.948ms      14.322us      22.169ms         9.95%      22.169ms       1.189us         18639  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.076ms         9.91%      22.076ms       1.797us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.462ms         7.39%      16.462ms       1.340us         12287  
                                        aten::remainder         2.77%     103.887ms         4.33%     162.240ms      13.203us      14.877ms         6.68%      14.879ms       1.211us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.047ms         6.30%      14.047ms       1.143us         12287  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      12.957ms         5.82%      12.957ms       1.044us         12407  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       7.856ms         3.53%       7.856ms       1.279us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.021ms         3.15%       7.021ms       1.143us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       6.109ms         2.74%       6.109ms       0.981us          6228  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.744s
Self CUDA time total: 222.814ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        6.967s      1665.27%        6.967s        6.967s             1  
                                           binned_torch        24.68%        1.721s       100.00%        6.973s        6.973s       0.000us         0.00%     418.392ms     418.392ms             1  
                                             aten::item         1.64%     114.231ms        25.94%        1.809s      14.732us       0.000us         0.00%     125.163ms       1.020us        122763  
                              aten::_local_scalar_dense         5.97%     416.624ms        24.30%        1.694s      13.802us     125.151ms        29.91%     125.163ms       1.020us        122763  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     125.151ms        29.91%     125.151ms       1.019us        122762  
                                     aten::floor_divide         5.62%     391.846ms        13.33%     929.253ms      18.906us      61.051ms        14.59%      61.053ms       1.242us         49152  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      57.281ms        13.69%      57.281ms       9.547ms             6  
                                              aten::bmm         0.00%     234.996us         0.00%     276.787us      46.131us      57.281ms        13.69%      57.281ms       9.547ms             6  
                                            aten::copy_         3.92%     273.517ms         9.35%     652.240ms      13.268us      53.435ms        12.77%      53.437ms       1.087us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.433ms        12.77%      53.433ms       1.087us         49154  
                                              aten::mul         3.15%     219.950ms         5.62%     391.612ms      15.929us      51.411ms        12.29%      51.419ms       2.091us         24585  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.451ms        10.62%      44.451ms       1.809us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      32.993ms         7.89%      32.993ms       1.343us         24576  
                                              aten::add         2.87%     200.428ms         4.94%     344.166ms      14.085us      31.887ms         7.62%      31.889ms       1.305us         24435  
                                        aten::remainder         3.00%     208.953ms         4.67%     325.902ms      13.261us      29.680ms         7.09%      29.684ms       1.208us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      28.059ms         6.71%      28.059ms       1.142us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      25.247ms         6.03%      25.247ms       1.033us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      15.667ms         3.74%      15.667ms       1.275us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.014ms         3.35%      14.014ms       1.140us         12288  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.233ms         1.25%       5.233ms     872.184us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.973s
Self CUDA time total: 418.361ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        7.368s      1660.72%        7.368s        7.368s             1  
                                           binned_torch        24.39%        1.797s       100.00%        7.370s        7.370s       0.000us         0.00%     443.698ms     443.698ms             1  
                                             aten::item         1.69%     124.742ms        26.51%        1.954s      14.504us       0.000us         0.00%     137.717ms       1.022us        134715  
                              aten::_local_scalar_dense         6.11%     450.407ms        24.82%        1.829s      13.577us     137.708ms        31.04%     137.717ms       1.022us        134715  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     137.710ms        31.04%     137.710ms       1.022us        134711  
                                     aten::floor_divide         5.42%     399.563ms        12.65%     932.414ms      18.970us      61.071ms        13.77%      61.077ms       1.243us         49152  
                                              aten::bmm         0.00%     230.664us         0.00%     272.466us      45.411us      57.304ms        12.92%      57.304ms       9.551ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      57.304ms        12.92%      57.304ms       9.551ms             6  
                                            aten::copy_         3.65%     269.132ms         8.67%     639.259ms      13.004us      54.065ms        12.19%      54.067ms       1.100us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      54.062ms        12.19%      54.062ms       1.100us         49153  
                                              aten::mul         2.96%     217.959ms         5.26%     387.551ms      15.764us      51.653ms        11.64%      51.660ms       2.101us         24585  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.653ms        10.06%      44.653ms       1.817us         24576  
                                              aten::add         4.03%     296.962ms         6.96%     512.647ms      14.100us      43.690ms         9.85%      43.694ms       1.202us         36357  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      32.954ms         7.43%      32.954ms       1.341us         24575  
                                        aten::remainder         2.83%     208.527ms         4.40%     323.906ms      13.180us      29.662ms         6.69%      29.664ms       1.207us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      28.119ms         6.34%      28.119ms       1.144us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      25.409ms         5.73%      25.409ms       1.040us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      15.666ms         3.53%      15.666ms       1.275us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      13.995ms         3.15%      13.995ms       1.139us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      11.644ms         2.62%      11.644ms       0.977us         11922  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.370s
Self CUDA time total: 443.660ms


impl                     wl                  p50(ms)  ok
binned_torch             cuda_B1_S1024_E2     372.79  True
binned_torch             cuda_B1_S1024_E4     382.68  True
binned_torch             cuda_B1_S512_E2      150.05  True
binned_torch             cuda_B1_S512_E4      200.26  True
binned_torch             cuda_B4_S1024_E2    1486.48  True
binned_torch             cuda_B4_S1024_E4    1524.50  True
binned_torch             cuda_B4_S512_E2      742.02  True
binned_torch             cuda_B4_S512_E4      801.90  True

Artifacts:

openai_moe.jsonl