# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch==2.8.0",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { path = "../../../../../tools", editable = true }
# ///
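
# Assumed usage: with the inline metadata above, this script can be run standalone,
# e.g. `uv run <path-to-this-file>` (provided the relative `tools` path exists in the checkout).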

import torch

from kernels_benchmark_tools import KernelTypeEnum, run_benchmark


def binned_gather(x, indices, bins, expert_capacity, top_k):
    """Gather tokens into per-expert bins of fixed capacity."""
    E, H = bins.shape[0], x.shape[1]
    out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = min(end - start, expert_capacity)
        for i in range(n):
            flat_pos = indices[start + i]
            tok = flat_pos // top_k
            out[e, i] = x[tok]
    return out
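
# binned_gather sketch: x is the flattened token matrix [num_tokens, H] and the result is
# [E, expert_capacity, H]; tokens routed to an expert beyond its capacity are silently dropped.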


def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
    """Scatter expert outputs back to token order, weighting and summing the top-k slots."""
    E, C, H = x.shape
    N = indices.shape[0] // top_k
    out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
    for e in range(E):
        start = 0 if e == 0 else bins[e - 1]
        end = bins[e]
        n = end - start
        if n == 0:
            continue
        take = min(n, expert_capacity)
        for i in range(take):
            flat_pos = indices[start + i]  # flattened (token, slot) position
            tok = flat_pos // top_k
            slot = flat_pos % top_k
            scale = weights[flat_pos] if weights is not None else 1.0
            out[tok, slot] = x[e, i] * scale
    return out.sum(dim=1)
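
# binned_scatter sketch: the inverse of binned_gather. For each token t it accumulates
#   y[t] = sum over its top_k slots of routing_weight * expert_output
# and tokens dropped by the capacity limit simply contribute zeros for that slot.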


def sort_tokens_by_expert(router_indices, num_experts):
    """Sort flattened routing assignments by expert id and compute per-expert bin boundaries."""
    flat_indices = router_indices.flatten()
    sorted_values, sorted_indices = torch.sort(flat_indices)
    tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)
    return sorted_indices, sorted_values, bins, tokens_per_expert
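
# Worked micro-example (illustrative values; ordering among equal expert ids may differ):
#   router_indices = [[0, 2], [1, 0]], num_experts = 3
#   flat_indices      = [0, 2, 1, 0]
#   sorted_values     = [0, 0, 1, 2]    sorted_indices = [0, 3, 2, 1]
#   tokens_per_expert = [2, 1, 1]       bins (cumsum)  = [2, 3, 4]
# Expert e then owns sorted_indices[bins[e - 1]:bins[e]] (with start = 0 for e == 0).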


def binned_experts_ref(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
    expert_capacity,
):
    B, S, H = hidden_states.shape
    E, K = routing_weights.shape[2], router_indices.shape[1]

    indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
    x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)

    gate_up = torch.bmm(x, gate_up_proj) + gate_up_proj_bias[..., None, :]
    gate, up = gate_up[..., ::2], gate_up[..., 1::2]

    # clamp to limit
    limit = 7.0
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)

    glu = gate * torch.sigmoid(gate * 1.702)
    x = (up + 1) * glu
    x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]

    # build routing weights aligned to (token, slot)
    flat_dense = routing_weights.view(-1, E)  # [B*S, E]
    flat_router = router_indices.view(-1, K)  # [B*S, K]
    selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)  # [B*S*K]

    # scatter back
    y = binned_scatter(x, indices, selected, bins, expert_capacity, K)  # [B*S, H]
    return y.view(B, S, H)
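
# Per-expert MLP used above, written out (as implemented here, not a general spec):
#   gate = min(x @ W_gate + b_gate, 7.0)
#   up   = clamp(x @ W_up + b_up, -7.0, 7.0)
#   y    = ((up + 1) * gate * sigmoid(1.702 * gate)) @ W_down + b_down
# where W_gate / W_up are the even / odd interleaved columns of gate_up_proj.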


def binned_torch_openai_moe(
    hidden_states,
    router_indices,
    routing_weights,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
):
    """
    Binned PyTorch implementation of OpenAI-style MoE.

    Sorts tokens by expert assignment for more efficient batched processing.
    """
    B, S = hidden_states.shape[0], hidden_states.shape[1]
    K = router_indices.shape[1]
    # Set expert_capacity to a reasonable value (max tokens per expert):
    # use 2x the average load per expert to absorb routing imbalance.
    expert_capacity = (B * S * K * 2) // routing_weights.shape[2]
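    # e.g. with (illustrative) B=1, S=128, K=4, E=32: capacity = 1 * 128 * 4 * 2 // 32 = 32 slots per expert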
    return binned_experts_ref(
        hidden_states,
        router_indices,
        routing_weights,
        gate_up_proj,
        gate_up_proj_bias,
        down_proj,
        down_proj_bias,
        expert_capacity,
    )
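
# Assumed tensor shapes (as implied by the reference code; actual sizes come from the harness):
#   hidden_states   [B, S, H]    router_indices  [B*S, K] (top-k expert ids per token)
#   routing_weights [B, S, E]    gate_up_proj    [E, H, 2*I] with bias [E, 2*I]
#   down_proj       [E, I, H]    with bias [E, H]            output [B, S, H]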


run_benchmark(
    kernel_type=KernelTypeEnum.OPENAI_MOE,
    impl_name="binned_torch",
    impl_tags={"family": "pytorch", "backend": "eager"},
    impl_func=binned_torch_openai_moe,
    dtype="float32",
)