Binned PyTorch - OpenAI-style MoE

GPU Info

▼ code ▼ output ▶ uv-logs | Cell: nv | 0.22s | Raw GitHub
import subprocess
print(subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout)
Mon Nov 10 21:58:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    On  |   00000000:4D:00.0 Off |                    0 |
| N/A   31C    P0             78W /  350W |       0MiB /  46068MiB |     17%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

OpenAI-style MoE Benchmark (Binned PyTorch)

▼ code ▼ output ▶ uv-logs | Cell: benchmark | 727.18s | Raw GitHub
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
import torch
import sys
from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def binned_gather(x, indices, bins, expert_capacity, top_k):
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            out[e, i] = x[tok]
    return out


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]  # flattened (token, slot)
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)


def sort_tokens_by_expert(router_indices, num_experts):
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[2], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)  # [B*S, E]
    flat_router = router_indices.view(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]

    return y.view(B, S, H)


def binned_torch_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    Binned PyTorch implementation of OpenAI-style MoE.
    Sorts tokens by expert assignment for more efficient batched processing.
    """
    B, S = hidden_states.shape[0], hidden_states.shape[1]
    K = router_indices.shape[1]

    # Set expert_capacity to a reasonable value (max tokens per expert)
    # Use 2x the average to handle imbalance
    expert_capacity = (B * S * K * 2) // routing_weights.shape[2]

    return binned_experts_ref(
        hidden_states,
        router_indices,
        routing_weights,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        expert_capacity,
    )


run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="binned_torch",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=binned_torch_openai_moe,
    dtype="float32",
)
Running openai_moe benchmark on cuda with 8 workloads.

======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     931.122ms      1835.78%     931.122ms     931.122ms             1  
                                           binned_torch        25.32%     236.300ms       100.00%     933.185ms     933.185ms       0.000us         0.00%      50.723ms      50.723ms             1  
                                             aten::item         1.92%      17.916ms        25.08%     234.061ms      15.253us       0.000us         0.00%      15.750ms       1.026us         15345  
                              aten::_local_scalar_dense         5.72%      53.357ms        23.16%     216.145ms      14.086us      15.749ms        31.05%      15.750ms       1.026us         15345  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      15.749ms        31.05%      15.749ms       1.026us         15345  
                                     aten::floor_divide         5.56%      51.923ms        13.14%     122.652ms      19.963us       7.815ms        15.41%       7.815ms       1.272us          6144  
                                              aten::bmm         0.02%     190.442us         0.02%     231.383us      38.564us       7.780ms        15.34%       7.780ms       1.297ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.780ms        15.34%       7.780ms       1.297ms             6  
                                            aten::copy_         3.79%      35.401ms         9.18%      85.713ms      13.923us       6.584ms        12.98%       6.585ms       1.070us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.579ms        12.97%       6.579ms       1.069us          6153  
                                              aten::mul         3.06%      28.578ms         5.54%      51.726ms      16.789us       4.711ms         9.29%       4.711ms       1.529us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.480ms         8.83%       4.480ms       1.458us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.161ms         8.20%       4.161ms       1.354us          3072  
                                        aten::remainder         3.12%      29.137ms         4.83%      45.065ms      14.669us       3.840ms         7.57%       3.840ms       1.250us          3072  
                                              aten::add         2.80%      26.083ms         4.76%      44.381ms      14.633us       3.757ms         7.41%       3.757ms       1.239us          3033  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.656ms         7.21%       3.656ms       1.190us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.366ms         6.64%       3.366ms       1.111us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.023ms         3.99%       2.023ms       1.317us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.817ms         3.58%       1.817ms       1.183us          1536  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us     283.649us         0.56%     283.649us      47.275us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 933.193ms
Self CUDA time total: 50.721ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us     938.961ms      1720.32%     938.961ms     938.961ms             1  
                                           binned_torch        25.07%     235.565ms       100.00%     939.473ms     939.473ms       0.000us         0.00%      54.589ms      54.589ms             1  
                                             aten::item         1.76%      16.540ms        26.46%     248.589ms      14.679us       0.000us         0.00%      17.855ms       1.054us         16935  
                              aten::_local_scalar_dense         5.69%      53.475ms        24.70%     232.048ms      13.702us      17.853ms        32.71%      17.855ms       1.054us         16935  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      17.853ms        32.71%      17.853ms       1.054us         16935  
                                              aten::bmm         0.02%     182.580us         0.02%     223.522us      37.254us       7.981ms        14.62%       7.981ms       1.330ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us       7.981ms        14.62%       7.981ms       1.330ms             6  
                                     aten::floor_divide         5.18%      48.644ms        12.51%     117.515ms      19.127us       7.813ms        14.31%       7.816ms       1.272us          6144  
                                            aten::copy_         3.69%      34.686ms         8.73%      82.032ms      13.325us       6.629ms        12.15%       6.630ms       1.077us          6156  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       6.626ms        12.14%       6.626ms       1.077us          6153  
                                              aten::add         3.97%      37.266ms         6.91%      64.908ms      14.132us       5.261ms         9.64%       5.261ms       1.145us          4593  
                                              aten::mul         2.87%      26.992ms         5.23%      49.129ms      15.946us       4.699ms         8.61%       4.699ms       1.525us          3081  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       4.475ms         8.20%       4.475ms       1.457us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.158ms         7.62%       4.158ms       1.353us          3072  
                                        aten::remainder         2.85%      26.773ms         4.50%      42.318ms      13.775us       3.852ms         7.06%       3.852ms       1.254us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.655ms         6.70%       3.655ms       1.190us          3072  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       3.271ms         5.99%       3.271ms       1.080us          3030  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.030ms         3.72%       2.030ms       1.322us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.822ms         3.34%       1.822ms       1.186us          1536  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       1.585ms         2.90%       1.585ms       1.016us          1560  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 939.480ms
Self CUDA time total: 54.581ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.710s      1645.94%        1.710s        1.710s             1  
                                           binned_torch        23.47%     401.594ms       100.00%        1.711s        1.711s       0.000us         0.00%     103.932ms     103.932ms             1  
                                             aten::item         1.77%      30.361ms        27.00%     461.971ms      15.140us       0.000us         0.00%      31.541ms       1.034us         30513  
                              aten::_local_scalar_dense         5.97%     102.153ms        25.22%     431.610ms      14.145us      31.538ms        30.35%      31.541ms       1.034us         30513  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      31.538ms        30.35%      31.538ms       1.034us         30513  
                                     aten::floor_divide         5.77%      98.697ms        13.68%     234.018ms      19.044us      15.598ms        15.01%      15.600ms       1.270us         12288  
                                              aten::bmm         0.01%     219.084us         0.02%     260.723us      43.454us      15.235ms        14.66%      15.235ms       2.539ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.235ms        14.66%      15.235ms       2.539ms             6  
                                            aten::copy_         3.97%      67.926ms         9.38%     160.451ms      13.045us      13.315ms        12.81%      13.316ms       1.083us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.311ms        12.81%      13.311ms       1.083us         12294  
                                              aten::mul         3.19%      54.637ms         5.82%      99.678ms      16.200us      11.250ms        10.83%      11.252ms       1.829us          6153  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.903ms         9.53%       9.903ms       1.612us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.304ms         7.99%       8.304ms       1.352us          6144  
                                        aten::remainder         3.07%      52.461ms         4.79%      82.008ms      13.348us       7.670ms         7.38%       7.671ms       1.249us          6144  
                                              aten::add         2.76%      47.163ms         4.86%      83.106ms      14.055us       7.632ms         7.34%       7.633ms       1.291us          5913  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.294ms         7.02%       7.294ms       1.187us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.354ms         6.11%       6.354ms       1.075us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.041ms         3.89%       4.041ms       1.316us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.629ms         3.49%       3.629ms       1.181us          3072  
                                            aten::clamp         0.00%      71.350us         0.01%     113.931us      18.988us       1.190ms         1.15%       1.190ms     198.366us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.711s
Self CUDA time total: 103.922ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B1_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        1.831s      1659.19%        1.831s        1.831s             1  
                                           binned_torch        23.77%     435.469ms       100.00%        1.832s        1.832s       0.000us         0.00%     110.361ms     110.361ms             1  
                                             aten::item         1.74%      31.875ms        27.52%     504.183ms      14.948us       0.000us         0.00%      34.964ms       1.037us         33729  
                              aten::_local_scalar_dense         6.20%     113.521ms        25.78%     472.309ms      14.003us      34.961ms        31.68%      34.964ms       1.037us         33729  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      34.961ms        31.68%      34.961ms       1.037us         33729  
                                     aten::floor_divide         5.21%      95.369ms        12.55%     229.877ms      18.707us      15.595ms        14.13%      15.597ms       1.269us         12288  
                                              aten::bmm         0.01%     225.035us         0.01%     267.825us      44.638us      15.231ms        13.80%      15.231ms       2.539ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      15.231ms        13.80%      15.231ms       2.539ms             6  
                                            aten::copy_         3.69%      67.648ms         8.80%     161.241ms      13.109us      13.343ms        12.09%      13.347ms       1.085us         12300  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      13.340ms        12.09%      13.340ms       1.085us         12297  
                                              aten::mul         2.99%      54.761ms         5.39%      98.799ms      16.057us      10.934ms         9.91%      10.936ms       1.777us          6153  
                                              aten::add         3.91%      71.612ms         6.90%     126.397ms      13.891us      10.863ms         9.84%      10.863ms       1.194us          9099  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       9.586ms         8.69%       9.586ms       1.560us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.308ms         7.53%       8.308ms       1.352us          6144  
                                        aten::remainder         2.81%      51.395ms         4.41%      80.796ms      13.150us       7.688ms         6.97%       7.688ms       1.251us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.287ms         6.60%       7.287ms       1.186us          6144  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       6.364ms         5.77%       6.364ms       1.077us          5910  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       4.054ms         3.67%       4.054ms       1.320us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.634ms         3.29%       3.634ms       1.183us          3072  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       3.232ms         2.93%       3.232ms       1.014us          3186  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.832s
Self CUDA time total: 110.351ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.493s      1641.52%        3.493s        3.493s             1  
                                           binned_torch        23.72%     828.141ms       100.00%        3.492s        3.492s       0.000us         0.00%     212.777ms     212.777ms             1  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      63.619ms        29.90%      63.619ms       1.033us         61586  
                                             aten::item         1.76%      61.470ms        26.76%     934.319ms      15.171us       0.000us         0.00%      63.619ms       1.033us         61587  
                              aten::_local_scalar_dense         5.95%     207.894ms        25.00%     872.849ms      14.173us      63.616ms        29.90%      63.619ms       1.033us         61587  
                                     aten::floor_divide         5.53%     193.077ms        13.34%     465.879ms      18.957us      31.606ms        14.86%      31.612ms       1.286us         24576  
                                              aten::bmm         0.01%     236.694us         0.01%     284.594us      47.432us      29.067ms        13.66%      29.067ms       4.844ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      29.067ms        13.66%      29.067ms       4.844ms             6  
                                            aten::copy_         3.89%     135.756ms         9.33%     325.881ms      13.254us      26.713ms        12.56%      26.714ms       1.086us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.711ms        12.55%      26.711ms       1.087us         24582  
                                              aten::mul         3.15%     110.066ms         5.73%     199.944ms      16.260us      25.593ms        12.03%      25.595ms       2.081us         12297  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.131ms        10.40%      22.131ms       1.801us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      17.009ms         7.99%      17.009ms       1.384us         12288  
                                              aten::add         2.82%      98.495ms         4.98%     173.932ms      14.014us      16.658ms         7.83%      16.659ms       1.342us         12411  
                                        aten::remainder         3.04%     106.037ms         4.77%     166.563ms      13.555us      15.433ms         7.25%      15.435ms       1.256us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.597ms         6.86%      14.597ms       1.188us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      13.527ms         6.36%      13.527ms       1.090us         12408  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.132ms         3.82%       8.132ms       1.324us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.300ms         3.43%       7.300ms       1.188us          6144  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       2.623ms         1.23%       2.623ms     437.201us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.492s
Self CUDA time total: 212.763ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S512_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        3.669s      1629.04%        3.669s        3.669s             1  
                                           binned_torch        23.71%     870.025ms       100.00%        3.670s        3.670s       0.000us         0.00%     225.217ms     225.217ms             1  
                                             aten::item         1.74%      63.801ms        26.98%     990.130ms      14.594us       0.000us         0.00%      69.736ms       1.028us         67845  
                              aten::_local_scalar_dense         5.93%     217.737ms        25.24%     926.329ms      13.654us      69.731ms        30.96%      69.736ms       1.028us         67845  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us      69.731ms        30.96%      69.731ms       1.028us         67841  
                                     aten::floor_divide         5.15%     189.112ms        12.36%     453.770ms      18.464us      31.523ms        14.00%      31.529ms       1.283us         24576  
                                              aten::bmm         0.01%     229.594us         0.01%     272.075us      45.346us      28.926ms        12.84%      28.926ms       4.821ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      28.926ms        12.84%      28.926ms       4.821ms             6  
                                            aten::copy_         3.90%     143.149ms         8.93%     327.628ms      13.325us      26.721ms        11.87%      26.722ms       1.087us         24588  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      26.719ms        11.86%      26.719ms       1.087us         24581  
                                              aten::mul         3.13%     114.822ms         5.47%     200.852ms      16.333us      25.594ms        11.37%      25.596ms       2.081us         12297  
                                              aten::add         3.87%     141.881ms         6.78%     248.742ms      13.345us      23.243ms        10.32%      23.243ms       1.247us         18639  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      22.132ms         9.83%      22.132ms       1.801us         12288  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.988ms         7.54%      16.988ms       1.383us         12287  
                                        aten::remainder         2.85%     104.729ms         4.42%     162.304ms      13.208us      15.354ms         6.82%      15.355ms       1.250us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.535ms         6.45%      14.535ms       1.183us         12287  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      13.676ms         6.07%      13.676ms       1.102us         12407  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us       8.096ms         3.60%       8.096ms       1.318us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       7.258ms         3.22%       7.258ms       1.181us          6144  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       6.475ms         2.88%       6.475ms       1.040us          6228  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.670s
Self CUDA time total: 225.199ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E2
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        6.859s      1611.59%        6.859s        6.859s             1  
                                           binned_torch        24.10%        1.655s       100.00%        6.866s        6.866s       0.000us         0.00%     425.661ms     425.661ms             1  
                                             aten::item         1.68%     115.068ms        26.29%        1.805s      14.704us       0.000us         0.00%     127.116ms       1.035us        122763  
                              aten::_local_scalar_dense         5.74%     393.879ms        24.61%        1.690s      13.764us     127.109ms        29.86%     127.116ms       1.035us        122763  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     127.110ms        29.86%     127.110ms       1.035us        122762  
                                     aten::floor_divide         5.46%     374.656ms        13.09%     898.826ms      18.287us      63.404ms        14.90%      63.408ms       1.290us         49152  
                                              aten::bmm         0.00%     234.973us         0.00%     276.793us      46.132us      56.971ms        13.38%      56.971ms       9.495ms             6  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      56.971ms        13.38%      56.971ms       9.495ms             6  
                                            aten::copy_         4.17%     286.167ms         9.49%     651.750ms      13.258us      53.615ms        12.60%      53.616ms       1.091us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.612ms        12.60%      53.612ms       1.091us         49154  
                                              aten::mul         3.34%     229.543ms         5.86%     402.465ms      16.370us      51.556ms        12.11%      51.561ms       2.097us         24585  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.609ms        10.48%      44.609ms       1.815us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      34.184ms         8.03%      34.184ms       1.391us         24576  
                                              aten::add         2.69%     184.813ms         4.71%     323.308ms      13.231us      33.584ms         7.89%      33.588ms       1.375us         24435  
                                        aten::remainder         3.06%     210.055ms         4.75%     326.044ms      13.267us      30.927ms         7.27%      30.931ms       1.259us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      29.221ms         6.87%      29.221ms       1.189us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      26.946ms         6.33%      26.946ms       1.103us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.291ms         3.83%      16.291ms       1.326us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.637ms         3.44%      14.637ms       1.191us         12288  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       5.222ms         1.23%       5.222ms     870.407us             6  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.866s
Self CUDA time total: 425.634ms



======================================================================
PROFILE TRACE: binned_torch | cuda_B4_S1024_E4
======================================================================
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           binned_torch         0.00%       0.000us         0.00%       0.000us       0.000us        7.331s      1630.84%        7.331s        7.331s             1  
                                           binned_torch        23.92%        1.754s       100.00%        7.333s        7.333s       0.000us         0.00%     449.578ms     449.578ms             1  
                                             aten::item         1.73%     127.153ms        27.44%        2.013s      14.940us       0.000us         0.00%     139.264ms       1.034us        134715  
                              aten::_local_scalar_dense         6.23%     456.926ms        25.71%        1.885s      13.996us     139.253ms        30.98%     139.264ms       1.034us        134715  
                         Memcpy DtoH (Device -> Pinned)         0.00%       0.000us         0.00%       0.000us       0.000us     139.255ms        30.98%     139.255ms       1.034us        134707  
                                     aten::floor_divide         5.02%     368.091ms        12.28%     900.843ms      18.328us      63.383ms        14.10%      63.388ms       1.290us         49152  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      56.831ms        12.64%      56.831ms       9.472ms             6  
                                              aten::bmm         0.00%     231.002us         0.00%     273.424us      45.571us      56.831ms        12.64%      56.831ms       9.472ms             6  
                                            aten::copy_         3.67%     268.957ms         8.71%     638.523ms      12.989us      53.771ms        11.96%      53.773ms       1.094us         49158  
                         Memcpy DtoD (Device -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us      53.768ms        11.96%      53.768ms       1.094us         49149  
                                              aten::mul         2.96%     217.228ms         5.34%     391.576ms      15.927us      51.518ms        11.46%      51.524ms       2.096us         24585  
                                              aten::add         3.83%     280.607ms         6.79%     497.692ms      13.689us      45.514ms        10.12%      45.518ms       1.252us         36357  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us      44.542ms         9.91%      44.542ms       1.812us         24576  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      34.127ms         7.59%      34.127ms       1.389us         24573  
                                        aten::remainder         2.85%     209.203ms         4.50%     330.314ms      13.441us      30.793ms         6.85%      30.795ms       1.253us         24576  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      29.257ms         6.51%      29.257ms       1.191us         24573  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      26.610ms         5.92%      26.610ms       1.089us         24431  
void at::native::vectorized_elementwise_kernel<2, at...         0.00%       0.000us         0.00%       0.000us       0.000us      16.233ms         3.61%      16.233ms       1.321us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      14.559ms         3.24%      14.559ms       1.185us         12288  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us      12.261ms         2.73%      12.261ms       1.028us         11922  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 7.333s
Self CUDA time total: 449.542ms


impl                     wl                  p50(ms)  ok
binned_torch             cuda_B1_S1024_E2     367.62  True
binned_torch             cuda_B1_S1024_E4     394.19  True
binned_torch             cuda_B1_S512_E2      154.67  True
binned_torch             cuda_B1_S512_E4      201.50  True
binned_torch             cuda_B4_S1024_E2    1483.54  True
binned_torch             cuda_B4_S1024_E4    1601.90  True
binned_torch             cuda_B4_S512_E2      736.26  True
binned_torch             cuda_B4_S512_E4      798.88  True

Artifacts:

openai_moe.jsonl