Upload folder using huggingface_hub
Browse files
- flash_attn/benchmark.html +0 -0
- flash_attn/cells/benchmark.py +341 -0
- flash_attn/index.html +24 -0
- index.html +2 -2
- moe_benchmarks/megablocks/cells/forward_and_backward.py +196 -0
- moe_benchmarks/megablocks/megablocks_only.html +326 -246
- moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json +24 -0
- moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json +24 -0
- moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json +24 -0
- moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json +24 -0
- moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc +0 -0
- moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc +0 -0
- moe_benchmarks/megablocks_yamoe/cells/binned_run.py +195 -0
- moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py +147 -0
- moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py +138 -0
- moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py +103 -0
- moe_benchmarks/megablocks_yamoe/cells/setup.py +116 -0
- moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html +263 -36
- moe_benchmarks/megablocks_yamoe/torch_profile.html +0 -0
flash_attn/benchmark.html
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
flash_attn/cells/benchmark.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# dependencies = [
|
| 3 |
+
# "numpy",
|
| 4 |
+
# "torch",
|
| 5 |
+
# "kernels",
|
| 6 |
+
# "pandas",
|
| 7 |
+
# "matplotlib"
|
| 8 |
+
# ]
|
| 9 |
+
# ///
|
| 10 |
+
# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths
|
| 11 |
+
|
| 12 |
+
import functools
|
| 13 |
+
import os
|
| 14 |
+
import pathlib
|
| 15 |
+
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
import torch
|
| 18 |
+
import torch._dynamo.config
|
| 19 |
+
import triton
|
| 20 |
+
import triton.language as tl
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from flash_attn import flash_attn_func
|
| 24 |
+
except:
|
| 25 |
+
flash_attn_func = None
|
| 26 |
+
print("Flash Attention 2 not found.")
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
| 30 |
+
except:
|
| 31 |
+
flash_attn_3_func = None
|
| 32 |
+
print("Flash Attention 3 not found.")
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from kernels import get_kernel
|
| 36 |
+
hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
|
| 37 |
+
hf_kernels_flash_attn_3 = get_kernel("kernels-community/flash-attn3")
|
| 38 |
+
except:
|
| 39 |
+
hf_kernels_flash_attn = None
|
| 40 |
+
hf_kernels_flash_attn_3 = None
|
| 41 |
+
print("HF Kernels not found.")
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
from sageattention import sageattn_qk_int8_pv_fp16_cuda, sageattn_qk_int8_pv_fp16_triton, sageattn_qk_int8_pv_fp8_cuda_sm90
|
| 45 |
+
except:
|
| 46 |
+
sageattn_qk_int8_pv_fp16_cuda = None
|
| 47 |
+
sageattn_qk_int8_pv_fp16_triton = None
|
| 48 |
+
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
| 49 |
+
print("SageAttention not found.")
|
| 50 |
+
|
| 51 |
+
try:
|
| 52 |
+
from transformer_engine.pytorch.attention import DotProductAttention
|
| 53 |
+
except:
|
| 54 |
+
DotProductAttention = None
|
| 55 |
+
print("Transformer Engine not found.")
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
import xformers.ops as xops
|
| 59 |
+
except:
|
| 60 |
+
xops = None
|
| 61 |
+
print("xFormers not found.")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Global Matplotlib style shared by every plot produced by this benchmark.
plt.rcParams.update({
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
})


# We want to compare the best compiled version for each specific shape (dynamic=False)
torch._dynamo.config.cache_size_limit = 10000

# We need to suppress_errors for FA3 to work. It makes it run in eager mode.
# I can't seem to get it to work any other way under torch.compile, so any suggestions are welcome!
torch._dynamo.config.suppress_errors = True

# All benchmark artifacts (plot + CSV) are written here.
output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)

# Problem sizes: Flux-style image tokens plus a sweep of text sequence lengths.
batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _attention_torch(query, key, value, *, backend):
|
| 105 |
+
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
|
| 106 |
+
with torch.nn.attention.sdpa_kernel(backend):
|
| 107 |
+
out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
|
| 108 |
+
out = out.transpose(1, 2).contiguous()
|
| 109 |
+
return out
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)
|
| 113 |
+
def _attention_torch_compile_default(query, key, value, *, backend):
|
| 114 |
+
return _compiled_attention_torch_default(query, key, value, backend=backend)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)
|
| 118 |
+
def _attention_torch_compile_max_autotune(query, key, value, *, backend):
|
| 119 |
+
return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _attention_flash_attn_2(query, key, value):
    # Thin wrapper so the op registry and torch.compile see a stable callable.
    # NOTE(review): only registered when the optional flash_attn import succeeded.
    return flash_attn_func(query, key, value)


_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_default(query, key, value):
    # FA2 wrapper compiled with the default inductor mode.
    return _compiled_flash_attn_2_default(query, key, value)


_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    # FA2 wrapper compiled with max-autotune (per-shape tuning, dynamic=False).
    return _compiled_flash_attn_2_max_autotune(query, key, value)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
# For fullgraph=True tracing to be compatible
# Wrap the FA3 extension call in a custom op so torch.compile can treat it as
# an opaque node instead of trying (and failing) to trace into it.
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # flash_attn_3_func returns (output, logsumexp); only the output is needed.
    out, lse = flash_attn_3_func(query, key, value)
    return out


@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Fake (meta) kernel for tracing: output has the query's shape and dtype.
    return torch.empty_like(query)


def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out


_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_default(query, key, value):
    # FA3 wrapper compiled with the default inductor mode.
    return _compiled_flash_attn_3_default(query, key, value)


_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    # FA3 wrapper compiled with max-autotune.
    return _compiled_flash_attn_3_max_autotune(query, key, value)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def _attention_hf_kernels_flash_attn(query, key, value):
    # kernels-community/flash-attn: fwd returns a tuple; index 0 is the output.
    return hf_kernels_flash_attn.fwd(query, key, value, is_causal=False)[0]


def _attention_hf_kernels_flash_attn3(query, key, value):
    # kernels-community/flash-attn3: same tuple convention as above.
    return hf_kernels_flash_attn_3.flash_attn_func(query, key, value, causal=False)[0]


def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    # tensor_layout="NHD" matches the (batch, seq, heads, dim) tensors used here.
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")


def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    # FP8 PV variant; registered only on SM90+ devices (see registry below).
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
if DotProductAttention is not None:
    def set_te_backend(backend):
        # must be applied before first use of
        # transformer_engine.pytorch.attention
        os.environ["NVTE_FLASH_ATTN"] = '0'
        os.environ["NVTE_FUSED_ATTN"] = '0'
        os.environ["NVTE_UNFUSED_ATTN"] = '0'
        if backend == 'flash':
            os.environ["NVTE_FLASH_ATTN"] = '1'
        if backend == 'fused':
            os.environ["NVTE_FUSED_ATTN"] = '1'
        if backend == 'unfused':
            os.environ["NVTE_UNFUSED_ATTN"] = '1'

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        # Stub so the module still imports when TE is missing; raises on use.
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")

def _attention_te(query, key, value):
    out = te_attn_fn(query, key, value)
    # TE output is flat on dim 2; restore the separate (heads, head_dim) axes.
    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
    return out


# Cannot fullgraph compile TE
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)
def _attention_te_compile_default(query, key, value):
    return _compiled_te_attn_fn_default(query, key, value)


# Cannot fullgraph compile TE
_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)
def _attention_te_compile_max_autotune(query, key, value):
    return _compiled_te_attn_fn_max_autotune(query, key, value)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def _attention_xformers(query, key, value):
    # xFormers memory-efficient attention; consumes the inputs as given.
    return xops.memory_efficient_attention(query, key, value)


_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)
def _attention_xformers_compile_default(query, key, value):
    # xFormers wrapper compiled with the default inductor mode.
    return _compiled_xformers_default(query, key, value)


_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)
def _attention_xformers_compile_max_autotune(query, key, value):
    # xFormers wrapper compiled with max-autotune.
    return _compiled_xformers_max_autotune(query, key, value)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# Registry of every attention implementation to test and benchmark, keyed by
# the provider name shown in the plots. Optional backends are added only when
# their import succeeded above.
attention_ops = {}
attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
if hf_kernels_flash_attn is not None:
    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
    attention_ops["hf_flash_attn3"] = _attention_hf_kernels_flash_attn3
if flash_attn_func is not None:
    attention_ops["flash_attn_2"] = _attention_flash_attn_2
    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
if flash_attn_3_func is not None:
    attention_ops["flash_attn_3"] = _attention_flash_attn_3
    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
    # NOTE(review): nesting reconstructed from a de-indented dump — the SM90
    # capability check is assumed to sit inside the SageAttention guard; the
    # fp8 kernel's name suggests it requires compute capability >= 9.
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
if DotProductAttention is not None:
    attention_ops["te_fused"] = _attention_te
    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
if xops is not None:
    attention_ops["xformers"] = _attention_xformers
    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    """Return *n* distinct (color, linestyle) pairs for plotting.

    Colors cycle fastest; the line style advances once per full color cycle,
    yielding up to ``len(colors) * len(line_styles)`` unique combinations.

    Args:
        n: number of style pairs requested.

    Returns:
        A list of ``(color, linestyle)`` tuples of length *n*.

    Raises:
        ValueError: if *n* exceeds the number of available combinations.
    """
    # Fix: the original annotated the return as tuple[str, str], but the
    # function has always returned a list of (color, linestyle) pairs.
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
    return [(colors[i % len(colors)], line_styles[i // len(colors)]) for i in range(n)]
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def correctness():
    """Compare every registered attention op against a float32 math-SDPA reference.

    For each benchmark sequence length, builds random float32 q/k/v on CUDA,
    computes a "golden" output with the exact MATH backend, then runs each op
    on bfloat16 copies and prints absmax / MAE / MSE against the reference.
    """
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")

        query = torch.randn(shape, device="cuda", dtype=torch.float32)
        key = torch.randn(shape, device="cuda", dtype=torch.float32)
        value = torch.randn(shape, device="cuda", dtype=torch.float32)

        # Golden output in full precision with the deterministic math backend.
        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
        # All candidate ops are evaluated in bfloat16 (the benchmark dtype).
        query, key, value = (x.bfloat16() for x in (query, key, value))

        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = (out - golden_truth).abs()
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean((golden_truth - out) ** 2)
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=[x.removeprefix("solution_") for x in attention_ops.keys()],
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    """Time one attention op (*provider*) at *seq_len* using triton's do_bench."""
    torch.manual_seed(0)

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
    # Multiply by random integers in [1, 5) to widen the value range beyond
    # plain randn, stressing the low-precision kernels a bit harder.
    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)

    fn = attention_ops[provider]
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    # quantiles=[0.5, 0.2, 0.8] yields (median, p20, p80); perf_report expects
    # (y, y_max, y_min), hence the reordering on return.
    return ms, max_ms, min_ms
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
# Entry point: run the correctness sweep first, then the timing benchmark
# (which saves the plot and data under output_dir). inference_mode disables
# autograd bookkeeping for the whole run.
with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
|
flash_attn/index.html
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<!-- Static directory index page for /flash_attn; links back to the root index
     and to the rendered benchmark page. -->
<html>
<head>
<meta charset='UTF-8'>
<title>Directory Index</title>
<style>
body { font-family: monospace; margin: 20px; }
h1 { font-size: 1.5em; }
ul { list-style-type: none; padding-left: 20px; }
li { margin: 5px 0; }
.dir { font-weight: bold; }
.file { color: #0066cc; }
a { text-decoration: none; }
a:hover { text-decoration: underline; }
</style>
</head>
<body>
<h1>Index of /flash_attn</h1>
<ul>
<li><a href='../index.html' class='dir'>../</a></li>
<li><a href='benchmark.html' class='file'>benchmark.html</a></li>
</ul>
</body>
</html>
|
index.html
CHANGED
|
@@ -17,8 +17,8 @@
|
|
| 17 |
<body>
|
| 18 |
<h1>Index of /</h1>
|
| 19 |
<ul>
|
| 20 |
-
<li><a href='
|
| 21 |
-
<li><a href='
|
| 22 |
</ul>
|
| 23 |
</body>
|
| 24 |
</html>
|
|
|
|
| 17 |
<body>
|
| 18 |
<h1>Index of /</h1>
|
| 19 |
<ul>
|
| 20 |
+
<li><a href='flash_attn/index.html' class='dir'>flash_attn/</a></li>
|
| 21 |
+
<li><a href='moe_benchmarks/index.html' class='dir'>moe_benchmarks/</a></li>
|
| 22 |
</ul>
|
| 23 |
</body>
|
| 24 |
</html>
|
moe_benchmarks/megablocks/cells/forward_and_backward.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# requires-python = ">=3.12"
|
| 3 |
+
# dependencies = [
|
| 4 |
+
# "accelerate>=1.10.1",
|
| 5 |
+
# "torch>=2.7.0",
|
| 6 |
+
# "kernels==0.10.0",
|
| 7 |
+
# "transformers@https://github.com/huggingface/transformers.git",
|
| 8 |
+
# "ipdb>=0.13.13",
|
| 9 |
+
# "matplotlib>=3.7.2",
|
| 10 |
+
# "numpy>=1.24.3",
|
| 11 |
+
# ]
|
| 12 |
+
# ///
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
|
| 16 |
+
import time
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
|
| 19 |
+
import sys
|
| 20 |
+
import torch.profiler
|
| 21 |
+
import gc
|
| 22 |
+
import logging
|
| 23 |
+
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
|
| 24 |
+
|
| 25 |
+
# remove liger kernel for testing
replace_kernel_forward_from_hub(GptOssRMSNorm, None)

# Configure INFO-level logging (the original comment said "debug", but the
# level actually set here is INFO).
logging.basicConfig(level=logging.INFO)
|
| 30 |
+
|
| 31 |
+
def reset_peak_memory_stats():
    """Clear the CUDA cache, reset peak-memory counters, and run the GC.

    Fix: the original called ``torch.cuda.empty_cache()`` unconditionally
    while guarding only ``reset_peak_memory_stats()`` behind
    ``torch.cuda.is_available()`` — both CUDA calls now sit inside the guard.
    """
    if torch.cuda.is_available():
        # Free cached blocks first so the reset peak reflects live
        # allocations only.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    gc.collect()
|
| 37 |
+
|
| 38 |
+
def get_memory_stats():
    """Report current, peak, and reserved CUDA memory usage in gigabytes.

    Returns a dict with keys ``allocated_gb``, ``peak_gb``, ``reserved_gb``;
    all three are 0 when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
    bytes_per_gb = 1e9
    return {
        "allocated_gb": torch.cuda.memory_allocated() / bytes_per_gb,
        "peak_gb": torch.cuda.max_memory_allocated() / bytes_per_gb,
        "reserved_gb": torch.cuda.memory_reserved() / bytes_per_gb,
    }
|
| 47 |
+
|
| 48 |
+
def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Set ``kernel_layer_name`` on the first nn.Module subclass named
    *cls_name* found among all loaded modules.

    Returns True if a matching class was found and patched, False otherwise.
    """
    for module in sys.modules.values():
        if module is None:
            continue
        candidate = getattr(module, cls_name, None)
        if isinstance(candidate, type) and issubclass(candidate, nn.Module):
            candidate.kernel_layer_name = value
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# NOTE(review): this script body was reconstructed from a de-indented dump;
# the exact nesting of the `with` / `try` blocks below should be confirmed
# against the original notebook cell. Gradient mode is enabled by default at
# module level, so the chosen nesting does not change autograd behavior.

# Init the model the normal way
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
# Dequantize MXFP4 checkpoint weights on load.
quantization_config = Mxfp4Config(dequantize=True)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=True,
    quantization_config=quantization_config,
).eval()

messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

max_tokens = 128  # Reduced to help with memory usage

# Clear memory before backward pass
reset_peak_memory_stats()
print(f"Pre-generation memory: {get_memory_stats()}")

# forward and backward pass
with torch.autograd.set_grad_enabled(True):
    start_time = time.perf_counter()
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    end_time = time.perf_counter()
    print(tokenizer.decode(generated[0], skip_special_tokens=False))
    print(f"Generation took {end_time - start_time:.2f} seconds")
    print(f"Post-generation memory: {get_memory_stats()}")

# Use gradient checkpointing to reduce memory usage
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()
    print("Enabled gradient checkpointing")

# Reduce sequence length if needed for memory
max_seq_len = 512  # Limit sequence length for backward pass
if generated.size(1) > max_seq_len:
    print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
    full_sequence = generated[:, -max_seq_len:]
else:
    full_sequence = generated

# Get model outputs for the full sequence
model.train()  # Enable dropout and other training behaviors

try:
    outputs = model(
        input_ids=full_sequence,
        labels=full_sequence,  # This will compute loss internally
        return_dict=True
    )
    print(f"Post-forward memory: {get_memory_stats()}")

    # If model doesn't compute loss, compute it manually
    if outputs.loss is None:
        # Standard causal-LM shift: predict token t+1 from tokens <= t.
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = full_sequence[..., 1:].contiguous()

        # Use CrossEntropyLoss with ignore_index for padding tokens
        loss_fct = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100)
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1)
        )
    else:
        loss = outputs.loss

    print(f"Loss: {loss.item():.4f}")

    # Clear intermediate tensors to save memory
    del outputs
    torch.cuda.empty_cache()

    # Perform backward pass with memory management
    print("Running backward pass...")
    print(f"Pre-backward memory: {get_memory_stats()}")

    loss.backward()
    print(f"Post-backward memory: {get_memory_stats()}")

except torch.cuda.OutOfMemoryError as e:
    print(f"OOM during forward/backward pass: {e}")
    print("Try reducing max_tokens or max_seq_len")
    raise

# Calculate gradient statistics and print sample gradients
total_norm = 0.0
param_count = 0
grad_samples = {}

for name, p in model.named_parameters():
    if p.grad is not None:
        param_count += 1
        grad_norm = p.grad.data.norm(2).item()
        # Accumulate squared L2 norms; the global norm is the sqrt below.
        total_norm += grad_norm ** 2

        # Collect gradient statistics for key layers
        if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
            grad_samples[name] = {
                'norm': grad_norm,
                'mean': p.grad.data.mean().item(),
                'std': p.grad.data.std().item(),
                'max': p.grad.data.max().item(),
                'min': p.grad.data.min().item(),
            }

total_norm = total_norm ** 0.5

print(f"\nGradient norm: {total_norm:.4f}")
print(f"Parameters with gradients: {param_count}")

# Print sample gradients from important layers
print("\nSample gradient statistics:")
for i, (name, stats) in enumerate(list(grad_samples.items())[:10]):
    print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")

# Optional: zero gradients for next iteration
model.zero_grad()
model.eval()  # Switch back to eval mode
|
| 196 |
+
|
moe_benchmarks/megablocks/megablocks_only.html
CHANGED
|
@@ -3710,7 +3710,7 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
|
|
| 3710 |
<div class="system-info">
|
| 3711 |
<div class="system-info-header">Generated on:</div>
|
| 3712 |
<div class="system-info-content">
|
| 3713 |
-
Linux x86_64 | Linux-6.
|
| 3714 |
</div>
|
| 3715 |
</div>
|
| 3716 |
|
|
@@ -3724,122 +3724,219 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
|
|
| 3724 |
<p>Next we can run with Megablocks kernels enabled.</p>
|
| 3725 |
<h3>Forward</h3>
|
| 3726 |
<p>First, we run a forward pass with Megablocks kernels.</p>
|
| 3727 |
-
<
|
|
|
|
|
|
|
| 3728 |
<div class="cell-header">
|
| 3729 |
<span class="collapse-indicators">
|
| 3730 |
-
<span onclick="toggleCode('
|
| 3731 |
-
<span onclick="toggleOutput('
|
| 3732 |
-
<span id="uv-indicator-
|
| 3733 |
</span> |
|
| 3734 |
-
Cell:
|
| 3735 |
-
| <button class="run-btn" onclick="runCell('
|
| 3736 |
-
<button class="copy-btn" onclick="copyCell('
|
| 3737 |
-
<a href="cells/
|
| 3738 |
</div>
|
| 3739 |
-
<div id="code-
|
| 3740 |
<div class="highlight-with-lines">
|
| 3741 |
-
<div class="line-numbers" id="lines-
|
| 3742 |
-
<a class="line-number" data-cell="
|
| 3743 |
-
<a class="line-number" data-cell="
|
| 3744 |
-
<a class="line-number" data-cell="
|
| 3745 |
-
<a class="line-number" data-cell="
|
| 3746 |
-
<a class="line-number" data-cell="
|
| 3747 |
-
<a class="line-number" data-cell="
|
| 3748 |
-
<a class="line-number" data-cell="
|
| 3749 |
-
<a class="line-number" data-cell="
|
| 3750 |
-
<a class="line-number" data-cell="
|
| 3751 |
-
<a class="line-number" data-cell="
|
| 3752 |
-
<a class="line-number" data-cell="
|
| 3753 |
-
<a class="line-number" data-cell="
|
| 3754 |
-
<a class="line-number" data-cell="
|
| 3755 |
-
<a class="line-number" data-cell="
|
| 3756 |
-
<a class="line-number" data-cell="
|
| 3757 |
-
<a class="line-number" data-cell="
|
| 3758 |
-
<a class="line-number" data-cell="
|
| 3759 |
-
<a class="line-number" data-cell="
|
| 3760 |
-
<a class="line-number" data-cell="
|
| 3761 |
-
<a class="line-number" data-cell="
|
| 3762 |
-
<a class="line-number" data-cell="
|
| 3763 |
-
<a class="line-number" data-cell="
|
| 3764 |
-
<a class="line-number" data-cell="
|
| 3765 |
-
<a class="line-number" data-cell="
|
| 3766 |
-
<a class="line-number" data-cell="
|
| 3767 |
-
<a class="line-number" data-cell="
|
| 3768 |
-
<a class="line-number" data-cell="
|
| 3769 |
-
<a class="line-number" data-cell="
|
| 3770 |
-
<a class="line-number" data-cell="
|
| 3771 |
-
<a class="line-number" data-cell="
|
| 3772 |
-
<a class="line-number" data-cell="
|
| 3773 |
-
<a class="line-number" data-cell="
|
| 3774 |
-
<a class="line-number" data-cell="
|
| 3775 |
-
<a class="line-number" data-cell="
|
| 3776 |
-
<a class="line-number" data-cell="
|
| 3777 |
-
<a class="line-number" data-cell="
|
| 3778 |
-
<a class="line-number" data-cell="
|
| 3779 |
-
<a class="line-number" data-cell="
|
| 3780 |
-
<a class="line-number" data-cell="
|
| 3781 |
-
<a class="line-number" data-cell="
|
| 3782 |
-
<a class="line-number" data-cell="
|
| 3783 |
-
<a class="line-number" data-cell="
|
| 3784 |
-
<a class="line-number" data-cell="
|
| 3785 |
-
<a class="line-number" data-cell="
|
| 3786 |
-
<a class="line-number" data-cell="
|
| 3787 |
-
<a class="line-number" data-cell="
|
| 3788 |
-
<a class="line-number" data-cell="
|
| 3789 |
-
<a class="line-number" data-cell="
|
| 3790 |
-
<a class="line-number" data-cell="
|
| 3791 |
-
<a class="line-number" data-cell="
|
| 3792 |
-
<a class="line-number" data-cell="
|
| 3793 |
-
<a class="line-number" data-cell="
|
| 3794 |
-
<a class="line-number" data-cell="
|
| 3795 |
-
<a class="line-number" data-cell="
|
| 3796 |
-
<a class="line-number" data-cell="
|
| 3797 |
-
<a class="line-number" data-cell="
|
| 3798 |
-
<a class="line-number" data-cell="
|
| 3799 |
-
<a class="line-number" data-cell="
|
| 3800 |
-
<a class="line-number" data-cell="
|
| 3801 |
-
<a class="line-number" data-cell="
|
| 3802 |
-
<a class="line-number" data-cell="
|
| 3803 |
-
<a class="line-number" data-cell="
|
| 3804 |
-
<a class="line-number" data-cell="
|
| 3805 |
-
<a class="line-number" data-cell="
|
| 3806 |
-
<a class="line-number" data-cell="
|
| 3807 |
-
<a class="line-number" data-cell="
|
| 3808 |
-
<a class="line-number" data-cell="
|
| 3809 |
-
<a class="line-number" data-cell="
|
| 3810 |
-
<a class="line-number" data-cell="
|
| 3811 |
-
<a class="line-number" data-cell="
|
| 3812 |
-
<a class="line-number" data-cell="
|
| 3813 |
-
<a class="line-number" data-cell="
|
| 3814 |
-
<a class="line-number" data-cell="
|
| 3815 |
-
<a class="line-number" data-cell="
|
| 3816 |
-
<a class="line-number" data-cell="
|
| 3817 |
-
<a class="line-number" data-cell="
|
| 3818 |
-
<a class="line-number" data-cell="
|
| 3819 |
-
<a class="line-number" data-cell="
|
| 3820 |
-
<a class="line-number" data-cell="
|
| 3821 |
-
<a class="line-number" data-cell="
|
| 3822 |
-
<a class="line-number" data-cell="
|
| 3823 |
-
<a class="line-number" data-cell="
|
| 3824 |
-
<a class="line-number" data-cell="
|
| 3825 |
-
<a class="line-number" data-cell="
|
| 3826 |
-
<a class="line-number" data-cell="
|
| 3827 |
-
<a class="line-number" data-cell="
|
| 3828 |
-
<a class="line-number" data-cell="
|
| 3829 |
-
<a class="line-number" data-cell="
|
| 3830 |
-
<a class="line-number" data-cell="
|
| 3831 |
-
<a class="line-number" data-cell="
|
| 3832 |
-
<a class="line-number" data-cell="
|
| 3833 |
-
<a class="line-number" data-cell="
|
| 3834 |
-
<a class="line-number" data-cell="
|
| 3835 |
-
<a class="line-number" data-cell="
|
| 3836 |
-
<a class="line-number" data-cell="
|
| 3837 |
-
<a class="line-number" data-cell="
|
| 3838 |
-
<a class="line-number" data-cell="
|
| 3839 |
-
<a class="line-number" data-cell="
|
| 3840 |
-
<a class="line-number" data-cell="
|
| 3841 |
-
<a class="line-number" data-cell="
|
| 3842 |
-
<a class="line-number" data-cell="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3843 |
</div>
|
| 3844 |
<div class="code-wrap">
|
| 3845 |
<div class="highlight"><pre><span></span><span class="c1"># /// script</span>
|
|
@@ -3866,7 +3963,7 @@ Cell: forward_only | 118.48s | FAILED
|
|
| 3866 |
<span class="kn">import</span><span class="w"> </span><span class="nn">logging</span>
|
| 3867 |
<span class="kn">from</span><span class="w"> </span><span class="nn">transformers.models.gpt_oss.modeling_gpt_oss</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssRMSNorm</span>
|
| 3868 |
|
| 3869 |
-
|
| 3870 |
<span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssRMSNorm</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
| 3871 |
|
| 3872 |
<span class="c1"># set to debug logging</span>
|
|
@@ -3907,8 +4004,6 @@ Cell: forward_only | 118.48s | FAILED
|
|
| 3907 |
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">PreTrainedTokenizerFast</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_id</span><span class="p">)</span>
|
| 3908 |
<span class="n">quantization_config</span> <span class="o">=</span> <span class="n">Mxfp4Config</span><span class="p">(</span><span class="n">dequantize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
| 3909 |
|
| 3910 |
-
|
| 3911 |
-
|
| 3912 |
<span class="n">model</span> <span class="o">=</span> <span class="n">GptOssForCausalLM</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span>
|
| 3913 |
<span class="n">model_id</span><span class="p">,</span>
|
| 3914 |
<span class="n">dtype</span><span class="o">=</span><span class="s2">"bfloat16"</span><span class="p">,</span>
|
|
@@ -3929,9 +4024,14 @@ Cell: forward_only | 118.48s | FAILED
|
|
| 3929 |
<span class="n">reasoning_effort</span><span class="o">=</span><span class="s2">"low"</span><span class="p">,</span>
|
| 3930 |
<span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">"cuda"</span><span class="p">)</span>
|
| 3931 |
|
| 3932 |
-
<span class="n">max_tokens</span> <span class="o">=</span> <span class="mi">
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3933 |
|
| 3934 |
-
<span class="
|
|
|
|
| 3935 |
<span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
|
| 3936 |
<span class="n">generated</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span>
|
| 3937 |
<span class="o">**</span><span class="n">inputs</span><span class="p">,</span>
|
|
@@ -3940,144 +4040,124 @@ Cell: forward_only | 118.48s | FAILED
|
|
| 3940 |
<span class="n">temperature</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
| 3941 |
<span class="p">)</span>
|
| 3942 |
<span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
|
| 3943 |
-
|
| 3944 |
-
<span class="nb">print</span><span class="p">(</span><span class="
|
| 3945 |
-
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3946 |
</pre></div>
|
| 3947 |
|
| 3948 |
-
<div class="code-line-highlight" id="line-highlight-
|
| 3949 |
</div>
|
| 3950 |
</div>
|
| 3951 |
</div>
|
| 3952 |
-
<div id="output-
|
| 3953 |
-
<div class="
|
| 3954 |
-
|
| 3955 |
-
<div class="uv-logs-content" style="display: none;">
|
| 3956 |
Updating https://github.com/huggingface/transformers.git (HEAD)
|
| 3957 |
-
Updated https://github.com/huggingface/transformers.git (
|
| 3958 |
-
|
| 3959 |
-
|
| 3960 |
-
|
| 3961 |
-
|
| 3962 |
-
|
| 3963 |
-
|
| 3964 |
-
|
| 3965 |
-
|
| 3966 |
-
|
| 3967 |
-
|
| 3968 |
-
Downloading nvidia-cusparselt-cu12 (273.9MiB)
|
| 3969 |
-
Downloading nvidia-cusolver-cu12 (255.1MiB)
|
| 3970 |
-
Downloading nvidia-cufft-cu12 (184.2MiB)
|
| 3971 |
-
Downloading nvidia-curand-cu12 (60.7MiB)
|
| 3972 |
-
Downloading nvidia-cublas-cu12 (566.8MiB)
|
| 3973 |
-
Downloading nvidia-nccl-cu12 (307.4MiB)
|
| 3974 |
-
Downloading nvidia-cudnn-cu12 (674.0MiB)
|
| 3975 |
-
Downloading nvidia-cuda-nvrtc-cu12 (84.0MiB)
|
| 3976 |
-
Downloading sympy (6.0MiB)
|
| 3977 |
-
Downloading fonttools (4.7MiB)
|
| 3978 |
-
Downloading nvidia-cusparse-cu12 (274.9MiB)
|
| 3979 |
-
Downloading torch (846.8MiB)
|
| 3980 |
-
Downloading numpy (15.9MiB)
|
| 3981 |
-
Downloading matplotlib (8.3MiB)
|
| 3982 |
-
Downloading kiwisolver (1.4MiB)
|
| 3983 |
-
Downloading nvidia-cufile-cu12
|
| 3984 |
-
Downloading kiwisolver
|
| 3985 |
-
Downloading hf-xet
|
| 3986 |
-
Downloading tokenizers
|
| 3987 |
-
Downloading networkx
|
| 3988 |
-
Downloading fonttools
|
| 3989 |
-
Downloading pillow
|
| 3990 |
-
Downloading matplotlib
|
| 3991 |
-
Downloading nvidia-cuda-cupti-cu12
|
| 3992 |
-
Downloading sympy
|
| 3993 |
-
Downloading numpy
|
| 3994 |
-
Downloading jedi
|
| 3995 |
-
Built transformers @ git+https://github.com/huggingface/transformers.git@7258ea44bc0c0a425a468f66f8559d1de8c4126d
|
| 3996 |
-
Downloading nvidia-nvjitlink-cu12
|
| 3997 |
-
Downloading nvidia-curand-cu12
|
| 3998 |
-
Downloading nvidia-cuda-nvrtc-cu12
|
| 3999 |
-
Downloading triton
|
| 4000 |
-
Downloading nvidia-cufft-cu12
|
| 4001 |
-
Downloading nvidia-cusolver-cu12
|
| 4002 |
-
Downloading nvidia-cusparselt-cu12
|
| 4003 |
-
Downloading nvidia-cusparse-cu12
|
| 4004 |
-
Downloading nvidia-nccl-cu12
|
| 4005 |
-
Downloading nvidia-cublas-cu12
|
| 4006 |
-
Downloading nvidia-cudnn-cu12
|
| 4007 |
-
Downloading torch
|
| 4008 |
-
Installed 69 packages in 321ms
|
| 4009 |
-
</div>
|
| 4010 |
</div>
|
| 4011 |
-
<div class="cell-stderr">Fetching 3 files: 0%| | 0/3 [00:00<?, ?it/s]
|
| 4012 |
-
Fetching 3 files: 0%| | 0/3 [00:50<?, ?it/s]
|
| 4013 |
-
Traceback (most recent call last):
|
| 4014 |
-
File "/home/runner/work/kernels-uvnotes/kernels-uvnotes/moe_benchmarks/megablocks/.uvnote/cells/forward_only.py", line 68, in <module>
|
| 4015 |
-
model = GptOssForCausalLM.from_pretrained(
|
| 4016 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4017 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 285, in _wrapper
|
| 4018 |
-
return func(*args, **kwargs)
|
| 4019 |
-
^^^^^^^^^^^^^^^^^^^^^
|
| 4020 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4904, in from_pretrained
|
| 4021 |
-
checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
|
| 4022 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4023 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/modeling_utils.py", line 1239, in _get_resolved_checkpoint_files
|
| 4024 |
-
checkpoint_files, sharded_metadata = get_checkpoint_shard_files(
|
| 4025 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4026 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 1116, in get_checkpoint_shard_files
|
| 4027 |
-
cached_filenames = cached_files(
|
| 4028 |
-
^^^^^^^^^^^^^
|
| 4029 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 564, in cached_files
|
| 4030 |
-
raise e
|
| 4031 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/transformers/utils/hub.py", line 491, in cached_files
|
| 4032 |
-
snapshot_download(
|
| 4033 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
|
| 4034 |
-
return fn(*args, **kwargs)
|
| 4035 |
-
^^^^^^^^^^^^^^^^^^^
|
| 4036 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 332, in snapshot_download
|
| 4037 |
-
thread_map(
|
| 4038 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 69, in thread_map
|
| 4039 |
-
return _executor_map(ThreadPoolExecutor, fn, *iterables, **tqdm_kwargs)
|
| 4040 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4041 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/contrib/concurrent.py", line 51, in _executor_map
|
| 4042 |
-
return list(tqdm_class(ex.map(fn, *iterables, chunksize=chunksize), **kwargs))
|
| 4043 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4044 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/tqdm/std.py", line 1181, in __iter__
|
| 4045 |
-
for obj in iterable:
|
| 4046 |
-
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 619, in result_iterator
|
| 4047 |
-
yield _result_or_cancel(fs.pop())
|
| 4048 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4049 |
-
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 317, in _result_or_cancel
|
| 4050 |
-
return fut.result(timeout)
|
| 4051 |
-
^^^^^^^^^^^^^^^^^^^
|
| 4052 |
-
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 456, in result
|
| 4053 |
-
return self.__get_result()
|
| 4054 |
-
^^^^^^^^^^^^^^^^^^^
|
| 4055 |
-
File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
|
| 4056 |
-
raise self._exception
|
| 4057 |
-
File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run
|
| 4058 |
-
result = self.fn(*self.args, **self.kwargs)
|
| 4059 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4060 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/_snapshot_download.py", line 306, in _inner_hf_hub_download
|
| 4061 |
-
return hf_hub_download(
|
| 4062 |
-
^^^^^^^^^^^^^^^^
|
| 4063 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
|
| 4064 |
-
return fn(*args, **kwargs)
|
| 4065 |
-
^^^^^^^^^^^^^^^^^^^
|
| 4066 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1010, in hf_hub_download
|
| 4067 |
-
return _hf_hub_download_to_cache_dir(
|
| 4068 |
-
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 4069 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1171, in _hf_hub_download_to_cache_dir
|
| 4070 |
-
_download_to_tmp_and_move(
|
| 4071 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 1723, in _download_to_tmp_and_move
|
| 4072 |
-
xet_get(
|
| 4073 |
-
File "/home/runner/work/_temp/setup-uv-cache/environments-v2/forward-only-b65004b2d0cb4ca8/lib/python3.12/site-packages/huggingface_hub/file_download.py", line 629, in xet_get
|
| 4074 |
-
download_files(
|
| 4075 |
-
RuntimeError: Data processing error: CAS service error : IO Error: No space left on device (os error 28)</div>
|
| 4076 |
</div>
|
| 4077 |
</div>
|
| 4078 |
-
|
| 4079 |
-
<h2>Forward and Backward</h2>
|
| 4080 |
-
<p>Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.</p>
|
| 4081 |
</div>
|
| 4082 |
|
| 4083 |
</body>
|
|
|
|
| 3710 |
<div class="system-info">
|
| 3711 |
<div class="system-info-header">Generated on:</div>
|
| 3712 |
<div class="system-info-content">
|
| 3713 |
+
Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
|
| 3714 |
</div>
|
| 3715 |
</div>
|
| 3716 |
|
|
|
|
| 3724 |
<p>Next we can run with Megablocks kernels enabled.</p>
|
| 3725 |
<h3>Forward</h3>
|
| 3726 |
<p>First, we run a forward pass with Megablocks kernels.</p>
|
| 3727 |
+
<h2>Forward and Backward</h2>
|
| 3728 |
+
<p>Next, we run a forward and backward pass with Megablocks kernels enabled. This should be more memory efficient and allow us to complete the backward pass without running out of memory.</p>
|
| 3729 |
+
<div class="cell cell-failed" id="cell-forward_and_backward">
|
| 3730 |
<div class="cell-header">
|
| 3731 |
<span class="collapse-indicators">
|
| 3732 |
+
<span onclick="toggleCode('forward_and_backward')" style="cursor: pointer;">▼ code</span>
|
| 3733 |
+
<span onclick="toggleOutput('forward_and_backward')" style="cursor: pointer;">▼ output</span>
|
| 3734 |
+
<span id="uv-indicator-forward_and_backward" style="cursor: default; opacity: 0.3;">▶ uv-logs</span>
|
| 3735 |
</span> |
|
| 3736 |
+
Cell: forward_and_backward | 19.43s | FAILED
|
| 3737 |
+
| <button class="run-btn" onclick="runCell('forward_and_backward')">▶ run</button>
|
| 3738 |
+
<button class="copy-btn" onclick="copyCell('forward_and_backward')">Copy</button>
|
| 3739 |
+
<a href="cells/forward_and_backward.py" target="_blank" class="raw-btn">Raw</a>
|
| 3740 |
</div>
|
| 3741 |
+
<div id="code-forward_and_backward" class="cell-code" data-lines="196">
|
| 3742 |
<div class="highlight-with-lines">
|
| 3743 |
+
<div class="line-numbers" id="lines-forward_and_backward">
|
| 3744 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="1" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 1, true);">1</a>
|
| 3745 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="2" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 2, true);">2</a>
|
| 3746 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="3" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 3, true);">3</a>
|
| 3747 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="4" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 4, true);">4</a>
|
| 3748 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="5" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 5, true);">5</a>
|
| 3749 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="6" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 6, true);">6</a>
|
| 3750 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="7" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 7, true);">7</a>
|
| 3751 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="8" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 8, true);">8</a>
|
| 3752 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="9" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 9, true);">9</a>
|
| 3753 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="10" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 10, true);">10</a>
|
| 3754 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="11" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 11, true);">11</a>
|
| 3755 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="12" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 12, true);">12</a>
|
| 3756 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="13" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 13, true);">13</a>
|
| 3757 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="14" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 14, true);">14</a>
|
| 3758 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="15" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 15, true);">15</a>
|
| 3759 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="16" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 16, true);">16</a>
|
| 3760 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="17" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 17, true);">17</a>
|
| 3761 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="18" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 18, true);">18</a>
|
| 3762 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="19" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 19, true);">19</a>
|
| 3763 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="20" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 20, true);">20</a>
|
| 3764 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="21" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 21, true);">21</a>
|
| 3765 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="22" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 22, true);">22</a>
|
| 3766 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="23" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 23, true);">23</a>
|
| 3767 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="24" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 24, true);">24</a>
|
| 3768 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="25" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 25, true);">25</a>
|
| 3769 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="26" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 26, true);">26</a>
|
| 3770 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="27" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 27, true);">27</a>
|
| 3771 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="28" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 28, true);">28</a>
|
| 3772 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="29" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 29, true);">29</a>
|
| 3773 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="30" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 30, true);">30</a>
|
| 3774 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="31" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 31, true);">31</a>
|
| 3775 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="32" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 32, true);">32</a>
|
| 3776 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="33" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 33, true);">33</a>
|
| 3777 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="34" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 34, true);">34</a>
|
| 3778 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="35" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 35, true);">35</a>
|
| 3779 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="36" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 36, true);">36</a>
|
| 3780 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="37" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 37, true);">37</a>
|
| 3781 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="38" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 38, true);">38</a>
|
| 3782 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="39" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 39, true);">39</a>
|
| 3783 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="40" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 40, true);">40</a>
|
| 3784 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="41" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 41, true);">41</a>
|
| 3785 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="42" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 42, true);">42</a>
|
| 3786 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="43" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 43, true);">43</a>
|
| 3787 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="44" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 44, true);">44</a>
|
| 3788 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="45" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 45, true);">45</a>
|
| 3789 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="46" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 46, true);">46</a>
|
| 3790 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="47" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 47, true);">47</a>
|
| 3791 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="48" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 48, true);">48</a>
|
| 3792 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="49" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 49, true);">49</a>
|
| 3793 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="50" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 50, true);">50</a>
|
| 3794 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="51" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 51, true);">51</a>
|
| 3795 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="52" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 52, true);">52</a>
|
| 3796 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="53" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 53, true);">53</a>
|
| 3797 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="54" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 54, true);">54</a>
|
| 3798 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="55" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 55, true);">55</a>
|
| 3799 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="56" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 56, true);">56</a>
|
| 3800 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="57" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 57, true);">57</a>
|
| 3801 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="58" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 58, true);">58</a>
|
| 3802 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="59" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 59, true);">59</a>
|
| 3803 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="60" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 60, true);">60</a>
|
| 3804 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="61" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 61, true);">61</a>
|
| 3805 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="62" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 62, true);">62</a>
|
| 3806 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="63" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 63, true);">63</a>
|
| 3807 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="64" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 64, true);">64</a>
|
| 3808 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="65" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 65, true);">65</a>
|
| 3809 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="66" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 66, true);">66</a>
|
| 3810 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="67" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 67, true);">67</a>
|
| 3811 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="68" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 68, true);">68</a>
|
| 3812 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="69" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 69, true);">69</a>
|
| 3813 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="70" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 70, true);">70</a>
|
| 3814 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="71" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 71, true);">71</a>
|
| 3815 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="72" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 72, true);">72</a>
|
| 3816 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="73" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 73, true);">73</a>
|
| 3817 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="74" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 74, true);">74</a>
|
| 3818 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="75" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 75, true);">75</a>
|
| 3819 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="76" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 76, true);">76</a>
|
| 3820 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="77" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 77, true);">77</a>
|
| 3821 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="78" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 78, true);">78</a>
|
| 3822 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="79" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 79, true);">79</a>
|
| 3823 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="80" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 80, true);">80</a>
|
| 3824 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="81" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 81, true);">81</a>
|
| 3825 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="82" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 82, true);">82</a>
|
| 3826 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="83" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 83, true);">83</a>
|
| 3827 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="84" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 84, true);">84</a>
|
| 3828 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="85" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 85, true);">85</a>
|
| 3829 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="86" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 86, true);">86</a>
|
| 3830 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="87" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 87, true);">87</a>
|
| 3831 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="88" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 88, true);">88</a>
|
| 3832 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="89" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 89, true);">89</a>
|
| 3833 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="90" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 90, true);">90</a>
|
| 3834 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="91" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 91, true);">91</a>
|
| 3835 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="92" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 92, true);">92</a>
|
| 3836 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="93" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 93, true);">93</a>
|
| 3837 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="94" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 94, true);">94</a>
|
| 3838 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="95" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 95, true);">95</a>
|
| 3839 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="96" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 96, true);">96</a>
|
| 3840 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="97" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 97, true);">97</a>
|
| 3841 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="98" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 98, true);">98</a>
|
| 3842 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="99" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 99, true);">99</a>
|
| 3843 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="100" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 100, true);">100</a>
|
| 3844 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="101" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 101, true);">101</a>
|
| 3845 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="102" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 102, true);">102</a>
|
| 3846 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="103" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 103, true);">103</a>
|
| 3847 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="104" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 104, true);">104</a>
|
| 3848 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="105" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 105, true);">105</a>
|
| 3849 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="106" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 106, true);">106</a>
|
| 3850 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="107" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 107, true);">107</a>
|
| 3851 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="108" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 108, true);">108</a>
|
| 3852 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="109" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 109, true);">109</a>
|
| 3853 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="110" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 110, true);">110</a>
|
| 3854 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="111" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 111, true);">111</a>
|
| 3855 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="112" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 112, true);">112</a>
|
| 3856 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="113" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 113, true);">113</a>
|
| 3857 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="114" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 114, true);">114</a>
|
| 3858 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="115" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 115, true);">115</a>
|
| 3859 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="116" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 116, true);">116</a>
|
| 3860 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="117" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 117, true);">117</a>
|
| 3861 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="118" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 118, true);">118</a>
|
| 3862 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="119" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 119, true);">119</a>
|
| 3863 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="120" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 120, true);">120</a>
|
| 3864 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="121" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 121, true);">121</a>
|
| 3865 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="122" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 122, true);">122</a>
|
| 3866 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="123" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 123, true);">123</a>
|
| 3867 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="124" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 124, true);">124</a>
|
| 3868 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="125" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 125, true);">125</a>
|
| 3869 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="126" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 126, true);">126</a>
|
| 3870 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="127" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 127, true);">127</a>
|
| 3871 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="128" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 128, true);">128</a>
|
| 3872 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="129" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 129, true);">129</a>
|
| 3873 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="130" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 130, true);">130</a>
|
| 3874 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="131" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 131, true);">131</a>
|
| 3875 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="132" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 132, true);">132</a>
|
| 3876 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="133" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 133, true);">133</a>
|
| 3877 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="134" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 134, true);">134</a>
|
| 3878 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="135" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 135, true);">135</a>
|
| 3879 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="136" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 136, true);">136</a>
|
| 3880 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="137" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 137, true);">137</a>
|
| 3881 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="138" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 138, true);">138</a>
|
| 3882 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="139" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 139, true);">139</a>
|
| 3883 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="140" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 140, true);">140</a>
|
| 3884 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="141" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 141, true);">141</a>
|
| 3885 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="142" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 142, true);">142</a>
|
| 3886 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="143" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 143, true);">143</a>
|
| 3887 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="144" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 144, true);">144</a>
|
| 3888 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="145" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 145, true);">145</a>
|
| 3889 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="146" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 146, true);">146</a>
|
| 3890 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="147" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 147, true);">147</a>
|
| 3891 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="148" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 148, true);">148</a>
|
| 3892 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="149" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 149, true);">149</a>
|
| 3893 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="150" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 150, true);">150</a>
|
| 3894 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="151" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 151, true);">151</a>
|
| 3895 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="152" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 152, true);">152</a>
|
| 3896 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="153" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 153, true);">153</a>
|
| 3897 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="154" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 154, true);">154</a>
|
| 3898 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="155" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 155, true);">155</a>
|
| 3899 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="156" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 156, true);">156</a>
|
| 3900 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="157" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 157, true);">157</a>
|
| 3901 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="158" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 158, true);">158</a>
|
| 3902 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="159" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 159, true);">159</a>
|
| 3903 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="160" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 160, true);">160</a>
|
| 3904 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="161" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 161, true);">161</a>
|
| 3905 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="162" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 162, true);">162</a>
|
| 3906 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="163" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 163, true);">163</a>
|
| 3907 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="164" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 164, true);">164</a>
|
| 3908 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="165" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 165, true);">165</a>
|
| 3909 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="166" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 166, true);">166</a>
|
| 3910 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="167" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 167, true);">167</a>
|
| 3911 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="168" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 168, true);">168</a>
|
| 3912 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="169" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 169, true);">169</a>
|
| 3913 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="170" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 170, true);">170</a>
|
| 3914 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="171" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 171, true);">171</a>
|
| 3915 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="172" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 172, true);">172</a>
|
| 3916 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="173" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 173, true);">173</a>
|
| 3917 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="174" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 174, true);">174</a>
|
| 3918 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="175" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 175, true);">175</a>
|
| 3919 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="176" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 176, true);">176</a>
|
| 3920 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="177" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 177, true);">177</a>
|
| 3921 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="178" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 178, true);">178</a>
|
| 3922 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="179" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 179, true);">179</a>
|
| 3923 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="180" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 180, true);">180</a>
|
| 3924 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="181" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 181, true);">181</a>
|
| 3925 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="182" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 182, true);">182</a>
|
| 3926 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="183" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 183, true);">183</a>
|
| 3927 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="184" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 184, true);">184</a>
|
| 3928 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="185" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 185, true);">185</a>
|
| 3929 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="186" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 186, true);">186</a>
|
| 3930 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="187" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 187, true);">187</a>
|
| 3931 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="188" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 188, true);">188</a>
|
| 3932 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="189" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 189, true);">189</a>
|
| 3933 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="190" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 190, true);">190</a>
|
| 3934 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="191" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 191, true);">191</a>
|
| 3935 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="192" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 192, true);">192</a>
|
| 3936 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="193" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 193, true);">193</a>
|
| 3937 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="194" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 194, true);">194</a>
|
| 3938 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="195" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 195, true);">195</a>
|
| 3939 |
+
<a class="line-number" data-cell="forward_and_backward" data-line="196" href="#cell-forward_and_backward" onclick="event.preventDefault(); selectCellLine('forward_and_backward', 196, true);">196</a>
|
| 3940 |
</div>
|
| 3941 |
<div class="code-wrap">
|
| 3942 |
<div class="highlight"><pre><span></span><span class="c1"># /// script</span>
|
|
|
|
| 3963 |
<span class="kn">import</span><span class="w"> </span><span class="nn">logging</span>
|
| 3964 |
<span class="kn">from</span><span class="w"> </span><span class="nn">transformers.models.gpt_oss.modeling_gpt_oss</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssRMSNorm</span>
|
| 3965 |
|
| 3966 |
+
<span class="c1"># remove liger kernel for testing </span>
|
| 3967 |
<span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssRMSNorm</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
| 3968 |
|
| 3969 |
<span class="c1"># set to debug logging</span>
|
|
|
|
| 4004 |
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">PreTrainedTokenizerFast</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_id</span><span class="p">)</span>
|
| 4005 |
<span class="n">quantization_config</span> <span class="o">=</span> <span class="n">Mxfp4Config</span><span class="p">(</span><span class="n">dequantize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
| 4006 |
|
|
|
|
|
|
|
| 4007 |
<span class="n">model</span> <span class="o">=</span> <span class="n">GptOssForCausalLM</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span>
|
| 4008 |
<span class="n">model_id</span><span class="p">,</span>
|
| 4009 |
<span class="n">dtype</span><span class="o">=</span><span class="s2">"bfloat16"</span><span class="p">,</span>
|
|
|
|
| 4024 |
<span class="n">reasoning_effort</span><span class="o">=</span><span class="s2">"low"</span><span class="p">,</span>
|
| 4025 |
<span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">"cuda"</span><span class="p">)</span>
|
| 4026 |
|
| 4027 |
+
<span class="n">max_tokens</span> <span class="o">=</span> <span class="mi">128</span> <span class="c1"># Reduced to help with memory usage</span>
|
| 4028 |
+
|
| 4029 |
+
<span class="c1"># Clear memory before backward pass</span>
|
| 4030 |
+
<span class="n">reset_peak_memory_stats</span><span class="p">()</span>
|
| 4031 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Pre-generation memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4032 |
|
| 4033 |
+
<span class="c1"># forward and backward pass</span>
|
| 4034 |
+
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">set_grad_enabled</span><span class="p">(</span><span class="kc">True</span><span class="p">):</span>
|
| 4035 |
<span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
|
| 4036 |
<span class="n">generated</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span>
|
| 4037 |
<span class="o">**</span><span class="n">inputs</span><span class="p">,</span>
|
|
|
|
| 4040 |
<span class="n">temperature</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
| 4041 |
<span class="p">)</span>
|
| 4042 |
<span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
|
| 4043 |
+
<span class="nb">print</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">generated</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">skip_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
|
| 4044 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Generation took </span><span class="si">{</span><span class="n">end_time</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">start_time</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> seconds"</span><span class="p">)</span>
|
| 4045 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Post-generation memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4046 |
+
|
| 4047 |
+
<span class="c1"># Use gradient checkpointing to reduce memory usage</span>
|
| 4048 |
+
<span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s1">'gradient_checkpointing_enable'</span><span class="p">):</span>
|
| 4049 |
+
<span class="n">model</span><span class="o">.</span><span class="n">gradient_checkpointing_enable</span><span class="p">()</span>
|
| 4050 |
+
<span class="nb">print</span><span class="p">(</span><span class="s2">"Enabled gradient checkpointing"</span><span class="p">)</span>
|
| 4051 |
+
|
| 4052 |
+
<span class="c1"># Reduce sequence length if needed for memory</span>
|
| 4053 |
+
<span class="n">max_seq_len</span> <span class="o">=</span> <span class="mi">512</span> <span class="c1"># Limit sequence length for backward pass</span>
|
| 4054 |
+
<span class="k">if</span> <span class="n">generated</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">></span> <span class="n">max_seq_len</span><span class="p">:</span>
|
| 4055 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Truncating sequence from </span><span class="si">{</span><span class="n">generated</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="si">}</span><span class="s2"> to </span><span class="si">{</span><span class="n">max_seq_len</span><span class="si">}</span><span class="s2"> tokens"</span><span class="p">)</span>
|
| 4056 |
+
<span class="n">full_sequence</span> <span class="o">=</span> <span class="n">generated</span><span class="p">[:,</span> <span class="o">-</span><span class="n">max_seq_len</span><span class="p">:]</span>
|
| 4057 |
+
<span class="k">else</span><span class="p">:</span>
|
| 4058 |
+
<span class="n">full_sequence</span> <span class="o">=</span> <span class="n">generated</span>
|
| 4059 |
+
|
| 4060 |
+
<span class="c1"># Get model outputs for the full sequence</span>
|
| 4061 |
+
<span class="n">model</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="c1"># Enable dropout and other training behaviors</span>
|
| 4062 |
+
|
| 4063 |
+
<span class="k">try</span><span class="p">:</span>
|
| 4064 |
+
<span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span>
|
| 4065 |
+
<span class="n">input_ids</span><span class="o">=</span><span class="n">full_sequence</span><span class="p">,</span>
|
| 4066 |
+
<span class="n">labels</span><span class="o">=</span><span class="n">full_sequence</span><span class="p">,</span> <span class="c1"># This will compute loss internally</span>
|
| 4067 |
+
<span class="n">return_dict</span><span class="o">=</span><span class="kc">True</span>
|
| 4068 |
+
<span class="p">)</span>
|
| 4069 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Post-forward memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4070 |
+
|
| 4071 |
+
<span class="c1"># If model doesn't compute loss, compute it manually</span>
|
| 4072 |
+
<span class="k">if</span> <span class="n">outputs</span><span class="o">.</span><span class="n">loss</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
| 4073 |
+
<span class="n">shift_logits</span> <span class="o">=</span> <span class="n">outputs</span><span class="o">.</span><span class="n">logits</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">:]</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
| 4074 |
+
<span class="n">shift_labels</span> <span class="o">=</span> <span class="n">full_sequence</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>
|
| 4075 |
+
|
| 4076 |
+
<span class="c1"># Use CrossEntropyLoss with ignore_index for padding tokens</span>
|
| 4077 |
+
<span class="n">loss_fct</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">(</span><span class="n">ignore_index</span><span class="o">=</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token_id</span> <span class="k">if</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">pad_token_id</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span class="o">-</span><span class="mi">100</span><span class="p">)</span>
|
| 4078 |
+
<span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fct</span><span class="p">(</span>
|
| 4079 |
+
<span class="n">shift_logits</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">shift_logits</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)),</span>
|
| 4080 |
+
<span class="n">shift_labels</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
|
| 4081 |
+
<span class="p">)</span>
|
| 4082 |
+
<span class="k">else</span><span class="p">:</span>
|
| 4083 |
+
<span class="n">loss</span> <span class="o">=</span> <span class="n">outputs</span><span class="o">.</span><span class="n">loss</span>
|
| 4084 |
+
|
| 4085 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Loss: </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4086 |
+
|
| 4087 |
+
<span class="c1"># Clear intermediate tensors to save memory</span>
|
| 4088 |
+
<span class="k">del</span> <span class="n">outputs</span>
|
| 4089 |
+
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
|
| 4090 |
+
|
| 4091 |
+
<span class="c1"># Perform backward pass with memory management</span>
|
| 4092 |
+
<span class="nb">print</span><span class="p">(</span><span class="s2">"Running backward pass..."</span><span class="p">)</span>
|
| 4093 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Pre-backward memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4094 |
+
|
| 4095 |
+
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
|
| 4096 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Post-backward memory: </span><span class="si">{</span><span class="n">get_memory_stats</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4097 |
+
|
| 4098 |
+
<span class="k">except</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">OutOfMemoryError</span> <span class="k">as</span> <span class="n">e</span><span class="p">:</span>
|
| 4099 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"OOM during forward/backward pass: </span><span class="si">{</span><span class="n">e</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4100 |
+
<span class="nb">print</span><span class="p">(</span><span class="s2">"Try reducing max_tokens or max_seq_len"</span><span class="p">)</span>
|
| 4101 |
+
<span class="k">raise</span>
|
| 4102 |
+
|
| 4103 |
+
<span class="c1"># Calculate gradient statistics and print sample gradients</span>
|
| 4104 |
+
<span class="n">total_norm</span> <span class="o">=</span> <span class="mf">0.0</span>
|
| 4105 |
+
<span class="n">param_count</span> <span class="o">=</span> <span class="mi">0</span>
|
| 4106 |
+
<span class="n">grad_samples</span> <span class="o">=</span> <span class="p">{}</span>
|
| 4107 |
+
|
| 4108 |
+
<span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">named_parameters</span><span class="p">():</span>
|
| 4109 |
+
<span class="k">if</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
|
| 4110 |
+
<span class="n">param_count</span> <span class="o">+=</span> <span class="mi">1</span>
|
| 4111 |
+
<span class="n">grad_norm</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
|
| 4112 |
+
<span class="n">total_norm</span> <span class="o">+=</span> <span class="n">grad_norm</span> <span class="o">**</span> <span class="mi">2</span>
|
| 4113 |
+
|
| 4114 |
+
<span class="c1"># Collect gradient statistics for key layers</span>
|
| 4115 |
+
<span class="k">if</span> <span class="nb">any</span><span class="p">(</span><span class="n">key</span> <span class="ow">in</span> <span class="n">name</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="p">[</span><span class="s1">'embed'</span><span class="p">,</span> <span class="s1">'lm_head'</span><span class="p">,</span> <span class="s1">'mlp.up'</span><span class="p">,</span> <span class="s1">'mlp.down'</span><span class="p">,</span> <span class="s1">'self_attn.q_proj'</span><span class="p">,</span> <span class="s1">'norm'</span><span class="p">]):</span>
|
| 4116 |
+
<span class="n">grad_samples</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span>
|
| 4117 |
+
<span class="s1">'norm'</span><span class="p">:</span> <span class="n">grad_norm</span><span class="p">,</span>
|
| 4118 |
+
<span class="s1">'mean'</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
|
| 4119 |
+
<span class="s1">'std'</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">std</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
|
| 4120 |
+
<span class="s1">'max'</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">max</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
|
| 4121 |
+
<span class="s1">'min'</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">min</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
|
| 4122 |
+
<span class="p">}</span>
|
| 4123 |
+
|
| 4124 |
+
<span class="n">total_norm</span> <span class="o">=</span> <span class="n">total_norm</span> <span class="o">**</span> <span class="mf">0.5</span>
|
| 4125 |
+
|
| 4126 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Gradient norm: </span><span class="si">{</span><span class="n">total_norm</span><span class="si">:</span><span class="s2">.4f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4127 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Parameters with gradients: </span><span class="si">{</span><span class="n">param_count</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4128 |
+
|
| 4129 |
+
<span class="c1"># Print sample gradients from important layers</span>
|
| 4130 |
+
<span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Sample gradient statistics:"</span><span class="p">)</span>
|
| 4131 |
+
<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">name</span><span class="p">,</span> <span class="n">stats</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">grad_samples</span><span class="o">.</span><span class="n">items</span><span class="p">())[:</span><span class="mi">10</span><span class="p">]):</span>
|
| 4132 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">" </span><span class="si">{</span><span class="n">name</span><span class="p">[:</span><span class="mi">60</span><span class="p">]</span><span class="si">:</span><span class="s2"><60</span><span class="si">}</span><span class="s2"> | norm: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'norm'</span><span class="p">]</span><span class="si">:</span><span class="s2">.4e</span><span class="si">}</span><span class="s2"> | mean: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'mean'</span><span class="p">]</span><span class="si">:</span><span class="s2">.4e</span><span class="si">}</span><span class="s2"> | std: </span><span class="si">{</span><span class="n">stats</span><span class="p">[</span><span class="s1">'std'</span><span class="p">]</span><span class="si">:</span><span class="s2">.4e</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 4133 |
+
|
| 4134 |
+
<span class="c1"># Optional: zero gradients for next iteration</span>
|
| 4135 |
+
<span class="n">model</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
|
| 4136 |
+
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="c1"># Switch back to eval mode</span>
|
| 4137 |
</pre></div>
|
| 4138 |
|
| 4139 |
+
<div class="code-line-highlight" id="line-highlight-forward_and_backward"></div>
|
| 4140 |
</div>
|
| 4141 |
</div>
|
| 4142 |
</div>
|
| 4143 |
+
<div id="output-forward_and_backward" class="cell-output">
|
| 4144 |
+
<div class="cell-stderr">Downloading cpython-3.13.7-linux-x86_64-gnu (download) (32.0MiB)
|
| 4145 |
+
Downloading cpython-3.13.7-linux-x86_64-gnu (download)
|
|
|
|
| 4146 |
Updating https://github.com/huggingface/transformers.git (HEAD)
|
| 4147 |
+
Updated https://github.com/huggingface/transformers.git (449533af73874470e914a203391635e04ac2ffc8)
|
| 4148 |
+
× No solution found when resolving script dependencies:
|
| 4149 |
+
╰─▶ Because only transformers==4.57.0.dev0 is available and
|
| 4150 |
+
transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1,
|
| 4151 |
+
we can conclude that all versions of transformers depend on
|
| 4152 |
+
huggingface-hub==1.0.0rc1.
|
| 4153 |
+
And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0,
|
| 4154 |
+
we can conclude that kernels==0.10.0 and all versions of transformers
|
| 4155 |
+
are incompatible.
|
| 4156 |
+
And because you require kernels==0.10.0 and transformers, we can
|
| 4157 |
+
conclude that your requirements are unsatisfiable.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4158 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4159 |
</div>
|
| 4160 |
</div>
|
|
|
|
|
|
|
|
|
|
| 4161 |
</div>
|
| 4162 |
|
| 4163 |
</body>
|
moe_benchmarks/megablocks_yamoe/artifacts/binned_run/binned_results.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"implementation": "binned_results",
|
| 3 |
+
"config": {
|
| 4 |
+
"warmup": 10,
|
| 5 |
+
"iters": 50,
|
| 6 |
+
"device": "cuda",
|
| 7 |
+
"dtype": "torch.float32",
|
| 8 |
+
"tokens": 100,
|
| 9 |
+
"vary_inputs": true
|
| 10 |
+
},
|
| 11 |
+
"stats": {
|
| 12 |
+
"avg_ms": 36.06324691992995,
|
| 13 |
+
"min_ms": 33.29206800026441,
|
| 14 |
+
"max_ms": 38.40615900026023,
|
| 15 |
+
"std_ms": 1.258567678508065,
|
| 16 |
+
"p50_ms": 36.21510599987232,
|
| 17 |
+
"p95_ms": 37.524451049966956,
|
| 18 |
+
"p99_ms": 38.03603995002959,
|
| 19 |
+
"num_iters": 50,
|
| 20 |
+
"tokens_per_s": 2772.906172925215,
|
| 21 |
+
"throughput_variance": 98.28636435515342
|
| 22 |
+
},
|
| 23 |
+
"output_sum": 3.97190523147583
|
| 24 |
+
}
|
moe_benchmarks/megablocks_yamoe/artifacts/gptoss_run/gptoss_results.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"implementation": "gptoss_results",
|
| 3 |
+
"config": {
|
| 4 |
+
"warmup": 10,
|
| 5 |
+
"iters": 50,
|
| 6 |
+
"device": "cuda",
|
| 7 |
+
"dtype": "torch.float32",
|
| 8 |
+
"tokens": 100,
|
| 9 |
+
"vary_inputs": true
|
| 10 |
+
},
|
| 11 |
+
"stats": {
|
| 12 |
+
"avg_ms": 45.286630379978305,
|
| 13 |
+
"min_ms": 38.91367899996112,
|
| 14 |
+
"max_ms": 49.84392799997295,
|
| 15 |
+
"std_ms": 3.2326168009526866,
|
| 16 |
+
"p50_ms": 45.42240999990099,
|
| 17 |
+
"p95_ms": 49.729684149951936,
|
| 18 |
+
"p99_ms": 49.82545450991893,
|
| 19 |
+
"num_iters": 50,
|
| 20 |
+
"tokens_per_s": 2208.1572234663554,
|
| 21 |
+
"throughput_variance": 161.27578702324564
|
| 22 |
+
},
|
| 23 |
+
"output_sum": 11.53223705291748
|
| 24 |
+
}
|
moe_benchmarks/megablocks_yamoe/artifacts/gptoss_training_run/gptoss_training_results.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"implementation": "gptoss_training_results",
|
| 3 |
+
"config": {
|
| 4 |
+
"warmup": 10,
|
| 5 |
+
"iters": 50,
|
| 6 |
+
"device": "cuda",
|
| 7 |
+
"dtype": "torch.float32",
|
| 8 |
+
"tokens": 100,
|
| 9 |
+
"vary_inputs": true
|
| 10 |
+
},
|
| 11 |
+
"stats": {
|
| 12 |
+
"avg_ms": 46.01034353989235,
|
| 13 |
+
"min_ms": 39.20698799993261,
|
| 14 |
+
"max_ms": 51.09754699969926,
|
| 15 |
+
"std_ms": 3.2594474712819497,
|
| 16 |
+
"p50_ms": 46.132551999562565,
|
| 17 |
+
"p95_ms": 50.721096600273086,
|
| 18 |
+
"p99_ms": 51.0080171399477,
|
| 19 |
+
"num_iters": 50,
|
| 20 |
+
"tokens_per_s": 2173.4243282338675,
|
| 21 |
+
"throughput_variance": 158.68467070353637
|
| 22 |
+
},
|
| 23 |
+
"output_sum": 11.53223705291748
|
| 24 |
+
}
|
moe_benchmarks/megablocks_yamoe/artifacts/yamoe_run/yamoe_results.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"implementation": "yamoe_results",
|
| 3 |
+
"config": {
|
| 4 |
+
"warmup": 10,
|
| 5 |
+
"iters": 50,
|
| 6 |
+
"device": "cuda",
|
| 7 |
+
"dtype": "torch.float32",
|
| 8 |
+
"tokens": 100,
|
| 9 |
+
"vary_inputs": true
|
| 10 |
+
},
|
| 11 |
+
"stats": {
|
| 12 |
+
"avg_ms": 4.2510544400101935,
|
| 13 |
+
"min_ms": 4.144352999901457,
|
| 14 |
+
"max_ms": 4.320155999266717,
|
| 15 |
+
"std_ms": 0.02873328656403644,
|
| 16 |
+
"p50_ms": 4.2539659998510615,
|
| 17 |
+
"p95_ms": 4.2857709999225335,
|
| 18 |
+
"p99_ms": 4.306132199617423,
|
| 19 |
+
"num_iters": 50,
|
| 20 |
+
"tokens_per_s": 23523.575482547854,
|
| 21 |
+
"throughput_variance": 160.28680309512873
|
| 22 |
+
},
|
| 23 |
+
"output_sum": 3.97190523147583
|
| 24 |
+
}
|
moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc
CHANGED
|
Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc and b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/bench_utils.cpython-311.pyc differ
|
|
|
moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc
CHANGED
|
Binary files a/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc and b/moe_benchmarks/megablocks_yamoe/cells/__pycache__/config.cpython-311.pyc differ
|
|
|
moe_benchmarks/megablocks_yamoe/cells/binned_run.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# dependencies = [
|
| 3 |
+
# "torch",
|
| 4 |
+
# "numpy",
|
| 5 |
+
# ]
|
| 6 |
+
# ///
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
|
| 12 |
+
from config import (
|
| 13 |
+
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
|
| 14 |
+
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
|
| 15 |
+
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
|
| 16 |
+
)
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
# Discover the upstream artifact directory from env
|
| 21 |
+
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
|
| 22 |
+
|
| 23 |
+
router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
|
| 24 |
+
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
|
| 25 |
+
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
|
| 26 |
+
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
|
| 27 |
+
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
|
| 28 |
+
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
|
| 29 |
+
|
| 30 |
+
print("Loaded shared weights from artifacts")
|
| 31 |
+
print(f"Router weight sum: {router_weight.sum().item():.6f}")
|
| 32 |
+
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
|
| 33 |
+
print(f"Down sum: {down_proj.sum().item():.6f}")
|
| 34 |
+
|
| 35 |
+
def binned_gather(x, indices, bins, expert_capacity, top_k):
|
| 36 |
+
E, H = bins.shape[0], x.shape[1]
|
| 37 |
+
out = torch.zeros((E, expert_capacity, H), device=x.device, dtype=x.dtype)
|
| 38 |
+
for e in range(E):
|
| 39 |
+
start = 0 if e == 0 else bins[e - 1]
|
| 40 |
+
end = bins[e]
|
| 41 |
+
n = min(end - start, expert_capacity)
|
| 42 |
+
for i in range(n):
|
| 43 |
+
flat_pos = indices[start + i]
|
| 44 |
+
tok = flat_pos // top_k
|
| 45 |
+
out[e, i] = x[tok]
|
| 46 |
+
return out
|
| 47 |
+
|
| 48 |
+
def binned_scatter(x, indices, weights, bins, expert_capacity, top_k):
|
| 49 |
+
E, C, H = x.shape
|
| 50 |
+
N = indices.shape[0] // top_k
|
| 51 |
+
out = torch.zeros((N, top_k, H), dtype=x.dtype, device=x.device)
|
| 52 |
+
for e in range(E):
|
| 53 |
+
start = 0 if e == 0 else bins[e - 1]
|
| 54 |
+
end = bins[e]
|
| 55 |
+
n = end - start
|
| 56 |
+
if n == 0:
|
| 57 |
+
continue
|
| 58 |
+
take = min(n, expert_capacity)
|
| 59 |
+
for i in range(take):
|
| 60 |
+
flat_pos = indices[start + i]
|
| 61 |
+
tok = flat_pos // top_k
|
| 62 |
+
slot = flat_pos % top_k
|
| 63 |
+
scale = weights[flat_pos] if weights is not None else 1.0
|
| 64 |
+
out[tok, slot] = x[e, i] * scale
|
| 65 |
+
return out.sum(dim=1)
|
| 66 |
+
|
| 67 |
+
def sort_tokens_by_expert(router_indices, num_experts):
|
| 68 |
+
flat_indices = router_indices.flatten()
|
| 69 |
+
sorted_values, sorted_indices = torch.sort(flat_indices)
|
| 70 |
+
tokens_per_expert = torch.bincount(sorted_values, minlength=num_experts)
|
| 71 |
+
bins = torch.cumsum(tokens_per_expert, dim=0)
|
| 72 |
+
return sorted_indices, sorted_values, bins, tokens_per_expert
|
| 73 |
+
|
| 74 |
+
def binned_experts_ref(
|
| 75 |
+
hidden_states,
|
| 76 |
+
router_indices,
|
| 77 |
+
routing_weights,
|
| 78 |
+
gate_up_proj,
|
| 79 |
+
gate_up_proj_bias,
|
| 80 |
+
down_proj,
|
| 81 |
+
down_proj_bias,
|
| 82 |
+
expert_capacity,
|
| 83 |
+
):
|
| 84 |
+
B, S, H = hidden_states.shape
|
| 85 |
+
E, K = routing_weights.shape[1], router_indices.shape[1]
|
| 86 |
+
|
| 87 |
+
indices, _, bins, _ = sort_tokens_by_expert(router_indices, E)
|
| 88 |
+
x = binned_gather(hidden_states.view(-1, H), indices, bins, expert_capacity, K)
|
| 89 |
+
|
| 90 |
+
gate_up = torch.bmm(x, gate_up_proj)
|
| 91 |
+
gate_up += gate_up_proj_bias[..., None, :]
|
| 92 |
+
|
| 93 |
+
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
| 94 |
+
|
| 95 |
+
# clamp to limit
|
| 96 |
+
limit = 7.0
|
| 97 |
+
gate = gate.clamp(min=None, max=limit)
|
| 98 |
+
up = up.clamp(min=-limit, max=limit)
|
| 99 |
+
|
| 100 |
+
glu = gate * torch.sigmoid(gate * 1.702)
|
| 101 |
+
x = (up + 1) * glu
|
| 102 |
+
x = torch.bmm(x, down_proj) + down_proj_bias[..., None, :]
|
| 103 |
+
|
| 104 |
+
# build routing weights aligned to (token, slot)
|
| 105 |
+
flat_dense = routing_weights.view(-1, E)
|
| 106 |
+
flat_router = router_indices.view(-1, K)
|
| 107 |
+
selected = torch.gather(flat_dense, 1, flat_router).reshape(-1)
|
| 108 |
+
|
| 109 |
+
# scatter back
|
| 110 |
+
y = binned_scatter(x, indices, selected, bins, expert_capacity, K)
|
| 111 |
+
|
| 112 |
+
return y.view(B, S, H)
|
| 113 |
+
|
| 114 |
+
class BinnedRouter(nn.Module):
|
| 115 |
+
def __init__(self, router_weight, router_bias):
|
| 116 |
+
super().__init__()
|
| 117 |
+
self.top_k = TOP_K
|
| 118 |
+
self.num_experts = NUM_EXPERTS
|
| 119 |
+
self.hidden_dim = HIDDEN_SIZE
|
| 120 |
+
self.weight = nn.Parameter(router_weight.clone())
|
| 121 |
+
self.bias = nn.Parameter(router_bias.clone())
|
| 122 |
+
|
| 123 |
+
def forward(self, hidden_states):
|
| 124 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
| 125 |
+
router_logits = F.linear(hidden_states, self.weight, self.bias)
|
| 126 |
+
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
|
| 127 |
+
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
|
| 128 |
+
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
|
| 129 |
+
return router_scores, router_indices
|
| 130 |
+
|
| 131 |
+
def ceil_div(a, b):
|
| 132 |
+
return (a + b - 1) // b
|
| 133 |
+
|
| 134 |
+
class BinnedMoEMLP(nn.Module):
|
| 135 |
+
def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
|
| 136 |
+
super().__init__()
|
| 137 |
+
self.router = BinnedRouter(router_weight, router_bias)
|
| 138 |
+
self.num_experts = NUM_EXPERTS
|
| 139 |
+
self.hidden_size = HIDDEN_SIZE
|
| 140 |
+
self.top_k = TOP_K
|
| 141 |
+
|
| 142 |
+
# Expert weights - use the loaded weights
|
| 143 |
+
self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
|
| 144 |
+
self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
|
| 145 |
+
self.down_proj = nn.Parameter(down_proj.clone())
|
| 146 |
+
self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
|
| 147 |
+
|
| 148 |
+
def forward(self, hidden_states):
|
| 149 |
+
router_scores, router_indices = self.router(hidden_states)
|
| 150 |
+
batch_size = hidden_states.shape[0]
|
| 151 |
+
expert_capacity = ceil_div(batch_size * self.top_k, self.num_experts)
|
| 152 |
+
|
| 153 |
+
output = binned_experts_ref(
|
| 154 |
+
hidden_states,
|
| 155 |
+
router_indices,
|
| 156 |
+
router_scores,
|
| 157 |
+
self.gate_up_proj,
|
| 158 |
+
self.gate_up_proj_bias,
|
| 159 |
+
self.down_proj,
|
| 160 |
+
self.down_proj_bias,
|
| 161 |
+
expert_capacity,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
return output, router_scores
|
| 165 |
+
|
| 166 |
+
# Run the model
|
| 167 |
+
set_seed(GENERAL_SEED)
|
| 168 |
+
|
| 169 |
+
device = torch.device(DEVICE)
|
| 170 |
+
dtype = to_dtype(DTYPE)
|
| 171 |
+
|
| 172 |
+
print("\n=== Binned Implementation ===")
|
| 173 |
+
# Initialize model with loaded weights
|
| 174 |
+
model = BinnedMoEMLP(
|
| 175 |
+
router_weight.to(device),
|
| 176 |
+
router_bias.to(device),
|
| 177 |
+
gate_up_proj.to(device),
|
| 178 |
+
gate_up_proj_bias.to(device),
|
| 179 |
+
down_proj.to(device),
|
| 180 |
+
down_proj_bias.to(device)
|
| 181 |
+
).to(device=device)
|
| 182 |
+
|
| 183 |
+
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
|
| 184 |
+
print(f"Gate/up proj sum: {model.gate_up_proj.sum().item():.6f}")
|
| 185 |
+
print(f"Down proj sum: {model.down_proj.sum().item():.6f}")
|
| 186 |
+
|
| 187 |
+
# Generate the same input as Yamoe
|
| 188 |
+
set_seed(INPUT_SEED)
|
| 189 |
+
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
|
| 190 |
+
|
| 191 |
+
# Benchmark the model with varied inputs to prevent caching artifacts
|
| 192 |
+
tokens = BATCH_SIZE * SEQ_LEN
|
| 193 |
+
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="binned_results.json", vary_inputs=True) as bench:
|
| 194 |
+
output, stats = bench(model, x)
|
| 195 |
+
print(f"\nOutput sum: {output[0].sum().item():.6f}")
|
moe_benchmarks/megablocks_yamoe/cells/gptoss_run.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# dependencies = [
|
| 3 |
+
# "torch",
|
| 4 |
+
# "numpy",
|
| 5 |
+
# ]
|
| 6 |
+
# ///
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
|
| 12 |
+
from config import (
|
| 13 |
+
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
|
| 14 |
+
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
|
| 15 |
+
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
|
| 16 |
+
)
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
# Discover the upstream artifact directory from env
|
| 21 |
+
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
|
| 22 |
+
|
| 23 |
+
router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
|
| 24 |
+
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
|
| 25 |
+
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
|
| 26 |
+
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
|
| 27 |
+
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
|
| 28 |
+
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
|
| 29 |
+
|
| 30 |
+
print("Loaded shared weights from artifacts")
|
| 31 |
+
print(f"Router weight sum: {router_weight.sum().item():.6f}")
|
| 32 |
+
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
|
| 33 |
+
print(f"Down sum: {down_proj.sum().item():.6f}")
|
| 34 |
+
|
| 35 |
+
class GptOssRouter(nn.Module):
|
| 36 |
+
def __init__(self, router_weight, router_bias):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.top_k = TOP_K
|
| 39 |
+
self.num_experts = NUM_EXPERTS
|
| 40 |
+
self.hidden_dim = HIDDEN_SIZE
|
| 41 |
+
self.weight = nn.Parameter(router_weight.clone())
|
| 42 |
+
self.bias = nn.Parameter(router_bias.clone())
|
| 43 |
+
|
| 44 |
+
def forward(self, hidden_states):
|
| 45 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
| 46 |
+
router_logits = F.linear(hidden_states, self.weight, self.bias)
|
| 47 |
+
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
|
| 48 |
+
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
|
| 49 |
+
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
|
| 50 |
+
return router_scores, router_indices
|
| 51 |
+
|
| 52 |
+
class GptOssExperts(nn.Module):
|
| 53 |
+
def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.num_experts = NUM_EXPERTS
|
| 56 |
+
self.hidden_size = HIDDEN_SIZE
|
| 57 |
+
self.expert_dim = self.hidden_size
|
| 58 |
+
self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
|
| 59 |
+
self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
|
| 60 |
+
self.down_proj = nn.Parameter(down_proj.clone())
|
| 61 |
+
self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
|
| 62 |
+
self.alpha = 1.702
|
| 63 |
+
self.limit = 7.0
|
| 64 |
+
|
| 65 |
+
def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
|
| 66 |
+
batch_size = hidden_states.shape[0]
|
| 67 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_size)
|
| 68 |
+
num_experts = routing_weights.shape[1]
|
| 69 |
+
|
| 70 |
+
if hidden_states.device.type == "cpu" or self.training:
|
| 71 |
+
next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
|
| 72 |
+
with torch.no_grad():
|
| 73 |
+
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
|
| 74 |
+
expert_mask = expert_mask.permute(2, 1, 0)
|
| 75 |
+
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
| 76 |
+
|
| 77 |
+
for expert_idx in expert_hit[:]:
|
| 78 |
+
expert_idx = expert_idx[0]
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
_, token_idx = torch.where(expert_mask[expert_idx])
|
| 81 |
+
current_state = hidden_states[token_idx]
|
| 82 |
+
gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
|
| 83 |
+
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
| 84 |
+
gate = gate.clamp(min=None, max=self.limit)
|
| 85 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
| 86 |
+
glu = gate * torch.sigmoid(gate * self.alpha)
|
| 87 |
+
gated_output = (up + 1) * glu
|
| 88 |
+
out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
|
| 89 |
+
weighted_output = out * routing_weights[token_idx, expert_idx, None]
|
| 90 |
+
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
|
| 91 |
+
next_states = next_states.view(batch_size, -1, self.hidden_size)
|
| 92 |
+
else:
|
| 93 |
+
hidden_states = hidden_states.repeat(num_experts, 1)
|
| 94 |
+
hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
|
| 95 |
+
gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
|
| 96 |
+
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
| 97 |
+
gate = gate.clamp(min=None, max=self.limit)
|
| 98 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
| 99 |
+
glu = gate * torch.sigmoid(gate * self.alpha)
|
| 100 |
+
next_states = torch.bmm(((up + 1) * glu), self.down_proj)
|
| 101 |
+
next_states = next_states + self.down_proj_bias[..., None, :]
|
| 102 |
+
next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
|
| 103 |
+
next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
|
| 104 |
+
next_states = next_states.sum(dim=0)
|
| 105 |
+
return next_states
|
| 106 |
+
|
| 107 |
+
class GptOssMoEMLP(nn.Module):
|
| 108 |
+
def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.router = GptOssRouter(router_weight, router_bias)
|
| 111 |
+
self.experts = GptOssExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
|
| 112 |
+
|
| 113 |
+
def forward(self, hidden_states):
|
| 114 |
+
router_scores, router_indices = self.router(hidden_states)
|
| 115 |
+
routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
|
| 116 |
+
return routed_out, router_scores
|
| 117 |
+
|
| 118 |
+
# Run the model
|
| 119 |
+
set_seed(GENERAL_SEED)
|
| 120 |
+
|
| 121 |
+
device = torch.device(DEVICE)
|
| 122 |
+
dtype = to_dtype(DTYPE)
|
| 123 |
+
|
| 124 |
+
print("\n=== GPT-OSS Implementation ===")
|
| 125 |
+
# Initialize model with loaded weights
|
| 126 |
+
model = GptOssMoEMLP(
|
| 127 |
+
router_weight.to(device),
|
| 128 |
+
router_bias.to(device),
|
| 129 |
+
gate_up_proj.to(device),
|
| 130 |
+
gate_up_proj_bias.to(device),
|
| 131 |
+
down_proj.to(device),
|
| 132 |
+
down_proj_bias.to(device)
|
| 133 |
+
).to(device=device)
|
| 134 |
+
|
| 135 |
+
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
|
| 136 |
+
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
|
| 137 |
+
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
|
| 138 |
+
|
| 139 |
+
# Generate the same input as other implementations
|
| 140 |
+
set_seed(INPUT_SEED)
|
| 141 |
+
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
|
| 142 |
+
|
| 143 |
+
# Benchmark the model with varied inputs to prevent caching artifacts
|
| 144 |
+
tokens = BATCH_SIZE * SEQ_LEN
|
| 145 |
+
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_results.json", vary_inputs=True) as bench:
|
| 146 |
+
output, stats = bench(model, x)
|
| 147 |
+
print(f"\nOutput sum: {output[0].sum().item():.6f}")
|
moe_benchmarks/megablocks_yamoe/cells/gptoss_training_run.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# dependencies = [
|
| 3 |
+
# "torch",
|
| 4 |
+
# "numpy",
|
| 5 |
+
# ]
|
| 6 |
+
# ///
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
|
| 12 |
+
from config import (
|
| 13 |
+
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
|
| 14 |
+
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
|
| 15 |
+
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
|
| 16 |
+
)
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
# Discover the upstream artifact directory from env
|
| 21 |
+
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
|
| 22 |
+
|
| 23 |
+
router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
|
| 24 |
+
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
|
| 25 |
+
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
|
| 26 |
+
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
|
| 27 |
+
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
|
| 28 |
+
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
|
| 29 |
+
|
| 30 |
+
print("Loaded shared weights from artifacts")
|
| 31 |
+
print(f"Router weight sum: {router_weight.sum().item():.6f}")
|
| 32 |
+
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
|
| 33 |
+
print(f"Down sum: {down_proj.sum().item():.6f}")
|
| 34 |
+
|
| 35 |
+
class GptOssTrainingRouter(nn.Module):
|
| 36 |
+
def __init__(self, router_weight, router_bias):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.top_k = TOP_K
|
| 39 |
+
self.num_experts = NUM_EXPERTS
|
| 40 |
+
self.hidden_dim = HIDDEN_SIZE
|
| 41 |
+
self.weight = nn.Parameter(router_weight.clone())
|
| 42 |
+
self.bias = nn.Parameter(router_bias.clone())
|
| 43 |
+
|
| 44 |
+
def forward(self, hidden_states):
|
| 45 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
|
| 46 |
+
router_logits = F.linear(hidden_states, self.weight, self.bias)
|
| 47 |
+
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)
|
| 48 |
+
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
|
| 49 |
+
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
|
| 50 |
+
return router_scores, router_indices
|
| 51 |
+
|
| 52 |
+
class GptOssTrainingExperts(nn.Module):
|
| 53 |
+
def __init__(self, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.num_experts = NUM_EXPERTS
|
| 56 |
+
self.hidden_size = HIDDEN_SIZE
|
| 57 |
+
self.expert_dim = self.hidden_size
|
| 58 |
+
self.gate_up_proj = nn.Parameter(gate_up_proj.clone())
|
| 59 |
+
self.gate_up_proj_bias = nn.Parameter(gate_up_proj_bias.clone())
|
| 60 |
+
self.down_proj = nn.Parameter(down_proj.clone())
|
| 61 |
+
self.down_proj_bias = nn.Parameter(down_proj_bias.clone())
|
| 62 |
+
self.alpha = 1.702
|
| 63 |
+
self.limit = 7.0
|
| 64 |
+
|
| 65 |
+
def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
|
| 66 |
+
batch_size = hidden_states.shape[0]
|
| 67 |
+
hidden_states = hidden_states.reshape(-1, self.hidden_size)
|
| 68 |
+
num_experts = routing_weights.shape[1]
|
| 69 |
+
|
| 70 |
+
# Force training mode path (expert loop instead of batched)
|
| 71 |
+
next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
|
| 72 |
+
with torch.no_grad():
|
| 73 |
+
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts)
|
| 74 |
+
expert_mask = expert_mask.permute(2, 1, 0)
|
| 75 |
+
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
| 76 |
+
|
| 77 |
+
for expert_idx in expert_hit[:]:
|
| 78 |
+
expert_idx = expert_idx[0]
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
_, token_idx = torch.where(expert_mask[expert_idx])
|
| 81 |
+
current_state = hidden_states[token_idx]
|
| 82 |
+
gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
|
| 83 |
+
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
|
| 84 |
+
gate = gate.clamp(min=None, max=self.limit)
|
| 85 |
+
up = up.clamp(min=-self.limit, max=self.limit)
|
| 86 |
+
glu = gate * torch.sigmoid(gate * self.alpha)
|
| 87 |
+
gated_output = (up + 1) * glu
|
| 88 |
+
out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
|
| 89 |
+
weighted_output = out * routing_weights[token_idx, expert_idx, None]
|
| 90 |
+
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
|
| 91 |
+
next_states = next_states.view(batch_size, -1, self.hidden_size)
|
| 92 |
+
return next_states
|
| 93 |
+
|
| 94 |
+
class GptOssTrainingMoEMLP(nn.Module):
|
| 95 |
+
def __init__(self, router_weight, router_bias, gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias):
|
| 96 |
+
super().__init__()
|
| 97 |
+
self.router = GptOssTrainingRouter(router_weight, router_bias)
|
| 98 |
+
self.experts = GptOssTrainingExperts(gate_up_proj, gate_up_proj_bias, down_proj, down_proj_bias)
|
| 99 |
+
|
| 100 |
+
def forward(self, hidden_states):
|
| 101 |
+
router_scores, router_indices = self.router(hidden_states)
|
| 102 |
+
routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores)
|
| 103 |
+
return routed_out, router_scores
|
| 104 |
+
|
| 105 |
+
# Run the model
|
| 106 |
+
set_seed(GENERAL_SEED)
|
| 107 |
+
|
| 108 |
+
device = torch.device(DEVICE)
|
| 109 |
+
dtype = to_dtype(DTYPE)
|
| 110 |
+
|
| 111 |
+
print("\n=== GPT-OSS Implementation (Training Mode - Expert Loop) ===")
|
| 112 |
+
# Initialize model with loaded weights and force training mode
|
| 113 |
+
model = GptOssTrainingMoEMLP(
|
| 114 |
+
router_weight.to(device),
|
| 115 |
+
router_bias.to(device),
|
| 116 |
+
gate_up_proj.to(device),
|
| 117 |
+
gate_up_proj_bias.to(device),
|
| 118 |
+
down_proj.to(device),
|
| 119 |
+
down_proj_bias.to(device)
|
| 120 |
+
).to(device=device)
|
| 121 |
+
|
| 122 |
+
# Set to training mode to force expert loop path
|
| 123 |
+
model.train()
|
| 124 |
+
|
| 125 |
+
print(f"Router weight sum: {model.router.weight.sum().item():.6f}")
|
| 126 |
+
print(f"Gate/up proj sum: {model.experts.gate_up_proj.sum().item():.6f}")
|
| 127 |
+
print(f"Down proj sum: {model.experts.down_proj.sum().item():.6f}")
|
| 128 |
+
print(f"Model training mode: {model.training}")
|
| 129 |
+
|
| 130 |
+
# Generate the same input as other implementations
|
| 131 |
+
set_seed(INPUT_SEED)
|
| 132 |
+
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
|
| 133 |
+
|
| 134 |
+
# Benchmark the model with varied inputs to prevent caching artifacts
|
| 135 |
+
tokens = BATCH_SIZE * SEQ_LEN
|
| 136 |
+
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="gptoss_training_results.json", vary_inputs=True) as bench:
|
| 137 |
+
output, stats = bench(model, x)
|
| 138 |
+
print(f"\nOutput sum: {output[0].sum().item():.6f}")
|
moe_benchmarks/megablocks_yamoe/cells/megablocks_run.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# dependencies = [
|
| 3 |
+
# "torch",
|
| 4 |
+
# "numpy",
|
| 5 |
+
# "kernels",
|
| 6 |
+
# ]
|
| 7 |
+
# ///
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from torch.nn import functional as F
|
| 12 |
+
from kernels import get_kernel, get_local_kernel
|
| 13 |
+
from bench_utils import to_dtype, tensor_stats, set_seed, bench_context
|
| 14 |
+
from config import (
|
| 15 |
+
NUM_EXPERTS, HIDDEN_SIZE, TOP_K,
|
| 16 |
+
BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE,
|
| 17 |
+
WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED
|
| 18 |
+
)
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from collections import namedtuple
|
| 21 |
+
import os
|
| 22 |
+
|
| 23 |
+
# Discover the upstream artifact directory from env
|
| 24 |
+
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
|
| 25 |
+
|
| 26 |
+
print(f"Loading weights from: {data_dir}")
|
| 27 |
+
|
| 28 |
+
router_weight = torch.load(Path(data_dir) / 'router_weight.pt')
|
| 29 |
+
router_bias = torch.load(Path(data_dir) / 'router_bias.pt')
|
| 30 |
+
gate_up_proj = torch.load(Path(data_dir) / 'gate_up_proj.pt')
|
| 31 |
+
gate_up_proj_bias = torch.load(Path(data_dir) / 'gate_up_proj_bias.pt')
|
| 32 |
+
down_proj = torch.load(Path(data_dir) / 'down_proj.pt')
|
| 33 |
+
down_proj_bias = torch.load(Path(data_dir) / 'down_proj_bias.pt')
|
| 34 |
+
|
| 35 |
+
print("Loaded shared weights from artifacts")
|
| 36 |
+
print(f"Router weight sum: {router_weight.sum().item():.6f}")
|
| 37 |
+
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
|
| 38 |
+
print(f"Down sum: {down_proj.sum().item():.6f}")
|
| 39 |
+
|
| 40 |
+
def build_megablocks_model(device: torch.device):
|
| 41 |
+
# Download optimized kernels from the Hugging Face hub
|
| 42 |
+
megablocks = get_kernel("kernels-community/megablocks", revision="v0.0.2")
|
| 43 |
+
model = megablocks.layers.MegaBlocksMoeMLP()
|
| 44 |
+
|
| 45 |
+
# Create attribute container for expert weights
|
| 46 |
+
model.experts = namedtuple(
|
| 47 |
+
"Experts", ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"]
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Use loaded router weights for consistency
|
| 51 |
+
model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device)
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
model.router.weight.copy_(router_weight)
|
| 54 |
+
model.router.bias.copy_(router_bias)
|
| 55 |
+
|
| 56 |
+
# Attach loaded expert weights to the experts container
|
| 57 |
+
e = model.experts
|
| 58 |
+
e.alpha = 1.702
|
| 59 |
+
e.capacity_factor = 32
|
| 60 |
+
e.gate_up_proj = torch.nn.Parameter(gate_up_proj.clone().to(device))
|
| 61 |
+
e.gate_up_proj_bias = torch.nn.Parameter(gate_up_proj_bias.clone().to(device))
|
| 62 |
+
e.down_proj = torch.nn.Parameter(down_proj.clone().to(device))
|
| 63 |
+
e.down_proj_bias = torch.nn.Parameter(down_proj_bias.clone().to(device))
|
| 64 |
+
e.hidden_size = HIDDEN_SIZE
|
| 65 |
+
|
| 66 |
+
# Log weight statistics for comparison
|
| 67 |
+
print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
|
| 68 |
+
print(f"[MegaBlocks] Gate/up projection shape: {tuple(e.gate_up_proj.shape)}, sum: {e.gate_up_proj.sum().item():.6f}")
|
| 69 |
+
print(f"[MegaBlocks] Down projection shape: {tuple(e.down_proj.shape)}, sum: {e.down_proj.sum().item():.6f}")
|
| 70 |
+
|
| 71 |
+
return model
|
| 72 |
+
|
| 73 |
+
# Create a wrapper to match the interface of other implementations
|
| 74 |
+
class MegaBlocksMoEWrapper(nn.Module):
|
| 75 |
+
def __init__(self, megablocks_model):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.model = megablocks_model
|
| 78 |
+
|
| 79 |
+
def forward(self, hidden_states):
|
| 80 |
+
# MegaBlocks expects input in the format (batch, seq_len, hidden_dim)
|
| 81 |
+
output, dummy_routing_weights = self.model(hidden_states)
|
| 82 |
+
return output, dummy_routing_weights
|
| 83 |
+
|
| 84 |
+
# Run the model
|
| 85 |
+
set_seed(GENERAL_SEED)
|
| 86 |
+
|
| 87 |
+
device = torch.device(DEVICE)
|
| 88 |
+
dtype = to_dtype(DTYPE)
|
| 89 |
+
|
| 90 |
+
print("\n=== MegaBlocks Implementation ===")
|
| 91 |
+
# Build MegaBlocks model with loaded weights
|
| 92 |
+
megablocks_model = build_megablocks_model(device)
|
| 93 |
+
model = MegaBlocksMoEWrapper(megablocks_model).to(device=device)
|
| 94 |
+
|
| 95 |
+
# Generate the same input as other implementations
|
| 96 |
+
set_seed(INPUT_SEED)
|
| 97 |
+
x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE, device=device, dtype=dtype) * 0.1
|
| 98 |
+
|
| 99 |
+
# Benchmark the model with varied inputs to prevent caching artifacts
|
| 100 |
+
tokens = BATCH_SIZE * SEQ_LEN
|
| 101 |
+
with bench_context(warmup=10, iters=50, device=device, dtype=dtype, tokens=tokens, save_json="megablocks_results.json", vary_inputs=True) as bench:
|
| 102 |
+
output, stats = bench(model, x)
|
| 103 |
+
print(f"\nOutput sum: {output[0].sum().item():.6f}")
|
moe_benchmarks/megablocks_yamoe/cells/setup.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# /// script
|
| 2 |
+
# requires-python = ">=3.12"
|
| 3 |
+
# dependencies = [
|
| 4 |
+
# "accelerate>=1.10.1",
|
| 5 |
+
# "torch>=2.7.0",
|
| 6 |
+
# "kernels==0.10.0",
|
| 7 |
+
# "transformers@https://github.com/huggingface/transformers.git",
|
| 8 |
+
# "ipdb>=0.13.13",
|
| 9 |
+
# "matplotlib>=3.7.2",
|
| 10 |
+
# "numpy>=1.24.3",
|
| 11 |
+
# ]
|
| 12 |
+
# ///
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
|
| 16 |
+
import time
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from kernels import register_kernel_mapping, Mode, LayerRepository
|
| 19 |
+
import sys
|
| 20 |
+
import torch.profiler
|
| 21 |
+
import gc
|
| 22 |
+
import logging
|
| 23 |
+
|
| 24 |
+
# set to debug logging
|
| 25 |
+
logging.basicConfig(level=logging.INFO)
|
| 26 |
+
|
| 27 |
+
def reset_peak_memory_stats():
|
| 28 |
+
"""Clear CUDA cache and reset memory allocation counters."""
|
| 29 |
+
torch.cuda.empty_cache()
|
| 30 |
+
if torch.cuda.is_available():
|
| 31 |
+
torch.cuda.reset_peak_memory_stats()
|
| 32 |
+
gc.collect()
|
| 33 |
+
|
| 34 |
+
def get_memory_stats():
|
| 35 |
+
"""Get current and peak CUDA memory usage."""
|
| 36 |
+
if not torch.cuda.is_available():
|
| 37 |
+
return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
|
| 38 |
+
return {
|
| 39 |
+
"allocated_gb": torch.cuda.memory_allocated() / 1e9,
|
| 40 |
+
"peak_gb": torch.cuda.max_memory_allocated() / 1e9,
|
| 41 |
+
"reserved_gb": torch.cuda.memory_reserved() / 1e9,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
def override_kernel_layer_name(cls_name: str, value) -> bool:
|
| 45 |
+
"""Helper to dynamically override the kernel_layer_name in a model class."""
|
| 46 |
+
for mod in sys.modules.values():
|
| 47 |
+
if mod is None:
|
| 48 |
+
continue
|
| 49 |
+
obj = getattr(mod, cls_name, None)
|
| 50 |
+
if isinstance(obj, type) and issubclass(obj, nn.Module):
|
| 51 |
+
setattr(obj, "kernel_layer_name", value)
|
| 52 |
+
print(f"Overrode {cls_name}.kernel_layer_name to {value}")
|
| 53 |
+
return True
|
| 54 |
+
return False
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# Init the model the normal way
|
| 58 |
+
model_id = "openai/gpt-oss-20b"
|
| 59 |
+
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
|
| 60 |
+
quantization_config = Mxfp4Config(dequantize=True)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode
|
| 64 |
+
|
| 65 |
+
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm
|
| 66 |
+
|
| 67 |
+
replace_kernel_forward_from_hub(GptOssMLP, "Yamoe")
|
| 68 |
+
replace_kernel_forward_from_hub(GptOssRMSNorm, None)
|
| 69 |
+
custom_mapping = {
|
| 70 |
+
"Yamoe": {
|
| 71 |
+
"cuda": {
|
| 72 |
+
Mode.INFERENCE: LayerRepository(
|
| 73 |
+
repo_id="drbh/yamoe",
|
| 74 |
+
layer_name="Yamoe",
|
| 75 |
+
revision="v0.3.0",
|
| 76 |
+
)
|
| 77 |
+
}
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
register_kernel_mapping(custom_mapping)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
model = GptOssForCausalLM.from_pretrained(
|
| 84 |
+
model_id,
|
| 85 |
+
dtype="bfloat16",
|
| 86 |
+
device_map="auto",
|
| 87 |
+
use_kernels=True,
|
| 88 |
+
quantization_config=quantization_config,
|
| 89 |
+
).eval()
|
| 90 |
+
|
| 91 |
+
messages = [
|
| 92 |
+
{"role": "system", "content": "What is Tensor Parallelism?"},
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
inputs = tokenizer.apply_chat_template(
|
| 96 |
+
messages,
|
| 97 |
+
add_generation_prompt=True,
|
| 98 |
+
return_tensors="pt",
|
| 99 |
+
return_dict=True,
|
| 100 |
+
reasoning_effort="low",
|
| 101 |
+
).to("cuda")
|
| 102 |
+
|
| 103 |
+
max_tokens = 256
|
| 104 |
+
|
| 105 |
+
with torch.inference_mode():
|
| 106 |
+
start_time = time.perf_counter()
|
| 107 |
+
generated = model.generate(
|
| 108 |
+
**inputs,
|
| 109 |
+
max_new_tokens=max_tokens,
|
| 110 |
+
do_sample=False,
|
| 111 |
+
temperature=None,
|
| 112 |
+
)
|
| 113 |
+
end_time = time.perf_counter()
|
| 114 |
+
|
| 115 |
+
print(tokenizer.decode(generated[0], skip_special_tokens=False))
|
| 116 |
+
print(f"Generation took {end_time - start_time:.2f} seconds")
|
moe_benchmarks/megablocks_yamoe/megablocks_yamoe.html
CHANGED
|
@@ -3710,61 +3710,288 @@ span.linenos.special { color: #000000; background-color: #ffffc0; padding-left:
|
|
| 3710 |
<div class="system-info">
|
| 3711 |
<div class="system-info-header">Generated on:</div>
|
| 3712 |
<div class="system-info-content">
|
| 3713 |
-
Linux x86_64 | Linux-6.
|
| 3714 |
</div>
|
| 3715 |
</div>
|
| 3716 |
|
| 3717 |
<div class="main-content">
|
| 3718 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3719 |
<div class="cell-header">
|
| 3720 |
<span class="collapse-indicators">
|
| 3721 |
-
<span onclick="toggleCode('
|
| 3722 |
-
<span onclick="toggleOutput('
|
| 3723 |
-
<span id="uv-indicator-
|
| 3724 |
</span> |
|
| 3725 |
-
Cell:
|
| 3726 |
-
| <button class="run-btn" onclick="runCell('
|
| 3727 |
-
<button class="copy-btn" onclick="copyCell('
|
| 3728 |
-
<a href="cells/
|
| 3729 |
</div>
|
| 3730 |
-
<div id="code-
|
| 3731 |
<div class="highlight-with-lines">
|
| 3732 |
-
<div class="line-numbers" id="lines-
|
| 3733 |
-
<a class="line-number" data-cell="
|
| 3734 |
-
<a class="line-number" data-cell="
|
| 3735 |
-
<a class="line-number" data-cell="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3736 |
</div>
|
| 3737 |
<div class="code-wrap">
|
| 3738 |
-
<div class="highlight"><pre><span></span><span class="
|
| 3739 |
-
|
| 3740 |
-
<span class="
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3741 |
</pre></div>
|
| 3742 |
|
| 3743 |
-
<div class="code-line-highlight" id="line-highlight-
|
| 3744 |
</div>
|
| 3745 |
</div>
|
| 3746 |
</div>
|
| 3747 |
-
<div id="output-
|
| 3748 |
-
<div class="cell-stderr">
|
| 3749 |
-
|
| 3750 |
-
|
| 3751 |
-
|
| 3752 |
-
|
| 3753 |
-
|
| 3754 |
-
|
| 3755 |
-
|
| 3756 |
-
|
| 3757 |
-
|
| 3758 |
-
|
| 3759 |
-
|
|
|
|
|
|
|
| 3760 |
</div>
|
| 3761 |
</div>
|
| 3762 |
</div>
|
| 3763 |
-
|
| 3764 |
-
<h1>Comparison of Megablocks and Yamoe Kernels</h1>
|
| 3765 |
-
<p>This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.</p>
|
| 3766 |
-
<h2>Megablocks kernel</h2>
|
| 3767 |
-
<h2>Yamoe Kernel</h2>
|
| 3768 |
</div>
|
| 3769 |
|
| 3770 |
</body>
|
|
|
|
| 3710 |
<div class="system-info">
|
| 3711 |
<div class="system-info-header">Generated on:</div>
|
| 3712 |
<div class="system-info-content">
|
| 3713 |
+
Linux x86_64 | Linux-6.12.40-64.114.amzn2023.x86_64-x86_64-with-glibc2.36
|
| 3714 |
</div>
|
| 3715 |
</div>
|
| 3716 |
|
| 3717 |
<div class="main-content">
|
| 3718 |
+
<h1>Comparison of Megablocks and Yamoe Kernels</h1>
|
| 3719 |
+
<p>This note compares the performance of the Megablocks and Yamoe kernels on the GPT-OSS-20B model.</p>
|
| 3720 |
+
<h2>Megablocks kernel</h2>
|
| 3721 |
+
<h2>Yamoe Kernel</h2>
|
| 3722 |
+
<div class="cell cell-failed" id="cell-setup">
|
| 3723 |
<div class="cell-header">
|
| 3724 |
<span class="collapse-indicators">
|
| 3725 |
+
<span onclick="toggleCode('setup')" style="cursor: pointer;">▼ code</span>
|
| 3726 |
+
<span onclick="toggleOutput('setup')" style="cursor: pointer;">▼ output</span>
|
| 3727 |
+
<span id="uv-indicator-setup" style="cursor: default; opacity: 0.3;">▶ uv-logs</span>
|
| 3728 |
</span> |
|
| 3729 |
+
Cell: setup | 19.20s | FAILED
|
| 3730 |
+
| <button class="run-btn" onclick="runCell('setup')">▶ run</button>
|
| 3731 |
+
<button class="copy-btn" onclick="copyCell('setup')">Copy</button>
|
| 3732 |
+
<a href="cells/setup.py" target="_blank" class="raw-btn">Raw</a>
|
| 3733 |
</div>
|
| 3734 |
+
<div id="code-setup" class="cell-code" data-lines="116">
|
| 3735 |
<div class="highlight-with-lines">
|
| 3736 |
+
<div class="line-numbers" id="lines-setup">
|
| 3737 |
+
<a class="line-number" data-cell="setup" data-line="1" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 1, true);">1</a>
|
| 3738 |
+
<a class="line-number" data-cell="setup" data-line="2" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 2, true);">2</a>
|
| 3739 |
+
<a class="line-number" data-cell="setup" data-line="3" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 3, true);">3</a>
|
| 3740 |
+
<a class="line-number" data-cell="setup" data-line="4" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 4, true);">4</a>
|
| 3741 |
+
<a class="line-number" data-cell="setup" data-line="5" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 5, true);">5</a>
|
| 3742 |
+
<a class="line-number" data-cell="setup" data-line="6" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 6, true);">6</a>
|
| 3743 |
+
<a class="line-number" data-cell="setup" data-line="7" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 7, true);">7</a>
|
| 3744 |
+
<a class="line-number" data-cell="setup" data-line="8" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 8, true);">8</a>
|
| 3745 |
+
<a class="line-number" data-cell="setup" data-line="9" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 9, true);">9</a>
|
| 3746 |
+
<a class="line-number" data-cell="setup" data-line="10" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 10, true);">10</a>
|
| 3747 |
+
<a class="line-number" data-cell="setup" data-line="11" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 11, true);">11</a>
|
| 3748 |
+
<a class="line-number" data-cell="setup" data-line="12" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 12, true);">12</a>
|
| 3749 |
+
<a class="line-number" data-cell="setup" data-line="13" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 13, true);">13</a>
|
| 3750 |
+
<a class="line-number" data-cell="setup" data-line="14" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 14, true);">14</a>
|
| 3751 |
+
<a class="line-number" data-cell="setup" data-line="15" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 15, true);">15</a>
|
| 3752 |
+
<a class="line-number" data-cell="setup" data-line="16" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 16, true);">16</a>
|
| 3753 |
+
<a class="line-number" data-cell="setup" data-line="17" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 17, true);">17</a>
|
| 3754 |
+
<a class="line-number" data-cell="setup" data-line="18" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 18, true);">18</a>
|
| 3755 |
+
<a class="line-number" data-cell="setup" data-line="19" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 19, true);">19</a>
|
| 3756 |
+
<a class="line-number" data-cell="setup" data-line="20" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 20, true);">20</a>
|
| 3757 |
+
<a class="line-number" data-cell="setup" data-line="21" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 21, true);">21</a>
|
| 3758 |
+
<a class="line-number" data-cell="setup" data-line="22" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 22, true);">22</a>
|
| 3759 |
+
<a class="line-number" data-cell="setup" data-line="23" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 23, true);">23</a>
|
| 3760 |
+
<a class="line-number" data-cell="setup" data-line="24" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 24, true);">24</a>
|
| 3761 |
+
<a class="line-number" data-cell="setup" data-line="25" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 25, true);">25</a>
|
| 3762 |
+
<a class="line-number" data-cell="setup" data-line="26" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 26, true);">26</a>
|
| 3763 |
+
<a class="line-number" data-cell="setup" data-line="27" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 27, true);">27</a>
|
| 3764 |
+
<a class="line-number" data-cell="setup" data-line="28" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 28, true);">28</a>
|
| 3765 |
+
<a class="line-number" data-cell="setup" data-line="29" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 29, true);">29</a>
|
| 3766 |
+
<a class="line-number" data-cell="setup" data-line="30" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 30, true);">30</a>
|
| 3767 |
+
<a class="line-number" data-cell="setup" data-line="31" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 31, true);">31</a>
|
| 3768 |
+
<a class="line-number" data-cell="setup" data-line="32" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 32, true);">32</a>
|
| 3769 |
+
<a class="line-number" data-cell="setup" data-line="33" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 33, true);">33</a>
|
| 3770 |
+
<a class="line-number" data-cell="setup" data-line="34" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 34, true);">34</a>
|
| 3771 |
+
<a class="line-number" data-cell="setup" data-line="35" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 35, true);">35</a>
|
| 3772 |
+
<a class="line-number" data-cell="setup" data-line="36" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 36, true);">36</a>
|
| 3773 |
+
<a class="line-number" data-cell="setup" data-line="37" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 37, true);">37</a>
|
| 3774 |
+
<a class="line-number" data-cell="setup" data-line="38" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 38, true);">38</a>
|
| 3775 |
+
<a class="line-number" data-cell="setup" data-line="39" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 39, true);">39</a>
|
| 3776 |
+
<a class="line-number" data-cell="setup" data-line="40" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 40, true);">40</a>
|
| 3777 |
+
<a class="line-number" data-cell="setup" data-line="41" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 41, true);">41</a>
|
| 3778 |
+
<a class="line-number" data-cell="setup" data-line="42" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 42, true);">42</a>
|
| 3779 |
+
<a class="line-number" data-cell="setup" data-line="43" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 43, true);">43</a>
|
| 3780 |
+
<a class="line-number" data-cell="setup" data-line="44" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 44, true);">44</a>
|
| 3781 |
+
<a class="line-number" data-cell="setup" data-line="45" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 45, true);">45</a>
|
| 3782 |
+
<a class="line-number" data-cell="setup" data-line="46" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 46, true);">46</a>
|
| 3783 |
+
<a class="line-number" data-cell="setup" data-line="47" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 47, true);">47</a>
|
| 3784 |
+
<a class="line-number" data-cell="setup" data-line="48" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 48, true);">48</a>
|
| 3785 |
+
<a class="line-number" data-cell="setup" data-line="49" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 49, true);">49</a>
|
| 3786 |
+
<a class="line-number" data-cell="setup" data-line="50" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 50, true);">50</a>
|
| 3787 |
+
<a class="line-number" data-cell="setup" data-line="51" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 51, true);">51</a>
|
| 3788 |
+
<a class="line-number" data-cell="setup" data-line="52" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 52, true);">52</a>
|
| 3789 |
+
<a class="line-number" data-cell="setup" data-line="53" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 53, true);">53</a>
|
| 3790 |
+
<a class="line-number" data-cell="setup" data-line="54" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 54, true);">54</a>
|
| 3791 |
+
<a class="line-number" data-cell="setup" data-line="55" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 55, true);">55</a>
|
| 3792 |
+
<a class="line-number" data-cell="setup" data-line="56" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 56, true);">56</a>
|
| 3793 |
+
<a class="line-number" data-cell="setup" data-line="57" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 57, true);">57</a>
|
| 3794 |
+
<a class="line-number" data-cell="setup" data-line="58" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 58, true);">58</a>
|
| 3795 |
+
<a class="line-number" data-cell="setup" data-line="59" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 59, true);">59</a>
|
| 3796 |
+
<a class="line-number" data-cell="setup" data-line="60" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 60, true);">60</a>
|
| 3797 |
+
<a class="line-number" data-cell="setup" data-line="61" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 61, true);">61</a>
|
| 3798 |
+
<a class="line-number" data-cell="setup" data-line="62" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 62, true);">62</a>
|
| 3799 |
+
<a class="line-number" data-cell="setup" data-line="63" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 63, true);">63</a>
|
| 3800 |
+
<a class="line-number" data-cell="setup" data-line="64" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 64, true);">64</a>
|
| 3801 |
+
<a class="line-number" data-cell="setup" data-line="65" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 65, true);">65</a>
|
| 3802 |
+
<a class="line-number" data-cell="setup" data-line="66" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 66, true);">66</a>
|
| 3803 |
+
<a class="line-number" data-cell="setup" data-line="67" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 67, true);">67</a>
|
| 3804 |
+
<a class="line-number" data-cell="setup" data-line="68" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 68, true);">68</a>
|
| 3805 |
+
<a class="line-number" data-cell="setup" data-line="69" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 69, true);">69</a>
|
| 3806 |
+
<a class="line-number" data-cell="setup" data-line="70" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 70, true);">70</a>
|
| 3807 |
+
<a class="line-number" data-cell="setup" data-line="71" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 71, true);">71</a>
|
| 3808 |
+
<a class="line-number" data-cell="setup" data-line="72" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 72, true);">72</a>
|
| 3809 |
+
<a class="line-number" data-cell="setup" data-line="73" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 73, true);">73</a>
|
| 3810 |
+
<a class="line-number" data-cell="setup" data-line="74" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 74, true);">74</a>
|
| 3811 |
+
<a class="line-number" data-cell="setup" data-line="75" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 75, true);">75</a>
|
| 3812 |
+
<a class="line-number" data-cell="setup" data-line="76" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 76, true);">76</a>
|
| 3813 |
+
<a class="line-number" data-cell="setup" data-line="77" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 77, true);">77</a>
|
| 3814 |
+
<a class="line-number" data-cell="setup" data-line="78" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 78, true);">78</a>
|
| 3815 |
+
<a class="line-number" data-cell="setup" data-line="79" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 79, true);">79</a>
|
| 3816 |
+
<a class="line-number" data-cell="setup" data-line="80" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 80, true);">80</a>
|
| 3817 |
+
<a class="line-number" data-cell="setup" data-line="81" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 81, true);">81</a>
|
| 3818 |
+
<a class="line-number" data-cell="setup" data-line="82" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 82, true);">82</a>
|
| 3819 |
+
<a class="line-number" data-cell="setup" data-line="83" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 83, true);">83</a>
|
| 3820 |
+
<a class="line-number" data-cell="setup" data-line="84" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 84, true);">84</a>
|
| 3821 |
+
<a class="line-number" data-cell="setup" data-line="85" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 85, true);">85</a>
|
| 3822 |
+
<a class="line-number" data-cell="setup" data-line="86" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 86, true);">86</a>
|
| 3823 |
+
<a class="line-number" data-cell="setup" data-line="87" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 87, true);">87</a>
|
| 3824 |
+
<a class="line-number" data-cell="setup" data-line="88" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 88, true);">88</a>
|
| 3825 |
+
<a class="line-number" data-cell="setup" data-line="89" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 89, true);">89</a>
|
| 3826 |
+
<a class="line-number" data-cell="setup" data-line="90" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 90, true);">90</a>
|
| 3827 |
+
<a class="line-number" data-cell="setup" data-line="91" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 91, true);">91</a>
|
| 3828 |
+
<a class="line-number" data-cell="setup" data-line="92" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 92, true);">92</a>
|
| 3829 |
+
<a class="line-number" data-cell="setup" data-line="93" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 93, true);">93</a>
|
| 3830 |
+
<a class="line-number" data-cell="setup" data-line="94" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 94, true);">94</a>
|
| 3831 |
+
<a class="line-number" data-cell="setup" data-line="95" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 95, true);">95</a>
|
| 3832 |
+
<a class="line-number" data-cell="setup" data-line="96" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 96, true);">96</a>
|
| 3833 |
+
<a class="line-number" data-cell="setup" data-line="97" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 97, true);">97</a>
|
| 3834 |
+
<a class="line-number" data-cell="setup" data-line="98" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 98, true);">98</a>
|
| 3835 |
+
<a class="line-number" data-cell="setup" data-line="99" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 99, true);">99</a>
|
| 3836 |
+
<a class="line-number" data-cell="setup" data-line="100" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 100, true);">100</a>
|
| 3837 |
+
<a class="line-number" data-cell="setup" data-line="101" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 101, true);">101</a>
|
| 3838 |
+
<a class="line-number" data-cell="setup" data-line="102" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 102, true);">102</a>
|
| 3839 |
+
<a class="line-number" data-cell="setup" data-line="103" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 103, true);">103</a>
|
| 3840 |
+
<a class="line-number" data-cell="setup" data-line="104" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 104, true);">104</a>
|
| 3841 |
+
<a class="line-number" data-cell="setup" data-line="105" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 105, true);">105</a>
|
| 3842 |
+
<a class="line-number" data-cell="setup" data-line="106" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 106, true);">106</a>
|
| 3843 |
+
<a class="line-number" data-cell="setup" data-line="107" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 107, true);">107</a>
|
| 3844 |
+
<a class="line-number" data-cell="setup" data-line="108" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 108, true);">108</a>
|
| 3845 |
+
<a class="line-number" data-cell="setup" data-line="109" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 109, true);">109</a>
|
| 3846 |
+
<a class="line-number" data-cell="setup" data-line="110" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 110, true);">110</a>
|
| 3847 |
+
<a class="line-number" data-cell="setup" data-line="111" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 111, true);">111</a>
|
| 3848 |
+
<a class="line-number" data-cell="setup" data-line="112" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 112, true);">112</a>
|
| 3849 |
+
<a class="line-number" data-cell="setup" data-line="113" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 113, true);">113</a>
|
| 3850 |
+
<a class="line-number" data-cell="setup" data-line="114" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 114, true);">114</a>
|
| 3851 |
+
<a class="line-number" data-cell="setup" data-line="115" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 115, true);">115</a>
|
| 3852 |
+
<a class="line-number" data-cell="setup" data-line="116" href="#cell-setup" onclick="event.preventDefault(); selectCellLine('setup', 116, true);">116</a>
|
| 3853 |
</div>
|
| 3854 |
<div class="code-wrap">
|
| 3855 |
+
<div class="highlight"><pre><span></span><span class="c1"># /// script</span>
|
| 3856 |
+
<span class="c1"># requires-python = ">=3.12"</span>
|
| 3857 |
+
<span class="c1"># dependencies = [</span>
|
| 3858 |
+
<span class="c1"># "accelerate>=1.10.1",</span>
|
| 3859 |
+
<span class="c1"># "torch>=2.7.0",</span>
|
| 3860 |
+
<span class="c1"># "kernels==0.10.0",</span>
|
| 3861 |
+
<span class="c1"># "transformers@https://github.com/huggingface/transformers.git",</span>
|
| 3862 |
+
<span class="c1"># "ipdb>=0.13.13",</span>
|
| 3863 |
+
<span class="c1"># "matplotlib>=3.7.2",</span>
|
| 3864 |
+
<span class="c1"># "numpy>=1.24.3",</span>
|
| 3865 |
+
<span class="c1"># ]</span>
|
| 3866 |
+
<span class="c1"># ///</span>
|
| 3867 |
+
|
| 3868 |
+
<span class="kn">import</span><span class="w"> </span><span class="nn">torch</span>
|
| 3869 |
+
<span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssForCausalLM</span><span class="p">,</span> <span class="n">PreTrainedTokenizerFast</span><span class="p">,</span> <span class="n">Mxfp4Config</span>
|
| 3870 |
+
<span class="kn">import</span><span class="w"> </span><span class="nn">time</span>
|
| 3871 |
+
<span class="kn">import</span><span class="w"> </span><span class="nn">torch.nn</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">nn</span>
|
| 3872 |
+
<span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">register_kernel_mapping</span><span class="p">,</span> <span class="n">Mode</span><span class="p">,</span> <span class="n">LayerRepository</span>
|
| 3873 |
+
<span class="kn">import</span><span class="w"> </span><span class="nn">sys</span>
|
| 3874 |
+
<span class="kn">import</span><span class="w"> </span><span class="nn">torch.profiler</span>
|
| 3875 |
+
<span class="kn">import</span><span class="w"> </span><span class="nn">gc</span>
|
| 3876 |
+
<span class="kn">import</span><span class="w"> </span><span class="nn">logging</span>
|
| 3877 |
+
|
| 3878 |
+
<span class="c1"># set to debug logging</span>
|
| 3879 |
+
<span class="n">logging</span><span class="o">.</span><span class="n">basicConfig</span><span class="p">(</span><span class="n">level</span><span class="o">=</span><span class="n">logging</span><span class="o">.</span><span class="n">INFO</span><span class="p">)</span>
|
| 3880 |
+
|
| 3881 |
+
<span class="k">def</span><span class="w"> </span><span class="nf">reset_peak_memory_stats</span><span class="p">():</span>
|
| 3882 |
+
<span class="w"> </span><span class="sd">"""Clear CUDA cache and reset memory allocation counters."""</span>
|
| 3883 |
+
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">empty_cache</span><span class="p">()</span>
|
| 3884 |
+
<span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
|
| 3885 |
+
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">reset_peak_memory_stats</span><span class="p">()</span>
|
| 3886 |
+
<span class="n">gc</span><span class="o">.</span><span class="n">collect</span><span class="p">()</span>
|
| 3887 |
+
|
| 3888 |
+
<span class="k">def</span><span class="w"> </span><span class="nf">get_memory_stats</span><span class="p">():</span>
|
| 3889 |
+
<span class="w"> </span><span class="sd">"""Get current and peak CUDA memory usage."""</span>
|
| 3890 |
+
<span class="k">if</span> <span class="ow">not</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">():</span>
|
| 3891 |
+
<span class="k">return</span> <span class="p">{</span><span class="s2">"allocated_gb"</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"peak_gb"</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"reserved_gb"</span><span class="p">:</span> <span class="mi">0</span><span class="p">}</span>
|
| 3892 |
+
<span class="k">return</span> <span class="p">{</span>
|
| 3893 |
+
<span class="s2">"allocated_gb"</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_allocated</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span><span class="p">,</span>
|
| 3894 |
+
<span class="s2">"peak_gb"</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">max_memory_allocated</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span><span class="p">,</span>
|
| 3895 |
+
<span class="s2">"reserved_gb"</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">memory_reserved</span><span class="p">()</span> <span class="o">/</span> <span class="mf">1e9</span><span class="p">,</span>
|
| 3896 |
+
<span class="p">}</span>
|
| 3897 |
+
|
| 3898 |
+
<span class="k">def</span><span class="w"> </span><span class="nf">override_kernel_layer_name</span><span class="p">(</span><span class="n">cls_name</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> <span class="o">-></span> <span class="nb">bool</span><span class="p">:</span>
|
| 3899 |
+
<span class="w"> </span><span class="sd">"""Helper to dynamically override the kernel_layer_name in a model class."""</span>
|
| 3900 |
+
<span class="k">for</span> <span class="n">mod</span> <span class="ow">in</span> <span class="n">sys</span><span class="o">.</span><span class="n">modules</span><span class="o">.</span><span class="n">values</span><span class="p">():</span>
|
| 3901 |
+
<span class="k">if</span> <span class="n">mod</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
|
| 3902 |
+
<span class="k">continue</span>
|
| 3903 |
+
<span class="n">obj</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">mod</span><span class="p">,</span> <span class="n">cls_name</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
| 3904 |
+
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="nb">type</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">issubclass</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
|
| 3905 |
+
<span class="nb">setattr</span><span class="p">(</span><span class="n">obj</span><span class="p">,</span> <span class="s2">"kernel_layer_name"</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
|
| 3906 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Overrode </span><span class="si">{</span><span class="n">cls_name</span><span class="si">}</span><span class="s2">.kernel_layer_name to </span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
|
| 3907 |
+
<span class="k">return</span> <span class="kc">True</span>
|
| 3908 |
+
<span class="k">return</span> <span class="kc">False</span>
|
| 3909 |
+
|
| 3910 |
+
|
| 3911 |
+
<span class="c1"># Init the model the normal way</span>
|
| 3912 |
+
<span class="n">model_id</span> <span class="o">=</span> <span class="s2">"openai/gpt-oss-20b"</span>
|
| 3913 |
+
<span class="n">tokenizer</span> <span class="o">=</span> <span class="n">PreTrainedTokenizerFast</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">model_id</span><span class="p">)</span>
|
| 3914 |
+
<span class="n">quantization_config</span> <span class="o">=</span> <span class="n">Mxfp4Config</span><span class="p">(</span><span class="n">dequantize</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
| 3915 |
+
|
| 3916 |
+
|
| 3917 |
+
<span class="kn">from</span><span class="w"> </span><span class="nn">kernels</span><span class="w"> </span><span class="kn">import</span> <span class="n">replace_kernel_forward_from_hub</span><span class="p">,</span> <span class="n">register_kernel_mapping</span><span class="p">,</span> <span class="n">LayerRepository</span><span class="p">,</span> <span class="n">Mode</span>
|
| 3918 |
+
|
| 3919 |
+
<span class="kn">from</span><span class="w"> </span><span class="nn">transformers.models.gpt_oss.modeling_gpt_oss</span><span class="w"> </span><span class="kn">import</span> <span class="n">GptOssMLP</span><span class="p">,</span> <span class="n">GptOssRMSNorm</span>
|
| 3920 |
+
|
| 3921 |
+
<span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssMLP</span><span class="p">,</span> <span class="s2">"Yamoe"</span><span class="p">)</span>
|
| 3922 |
+
<span class="n">replace_kernel_forward_from_hub</span><span class="p">(</span><span class="n">GptOssRMSNorm</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
|
| 3923 |
+
<span class="n">custom_mapping</span> <span class="o">=</span> <span class="p">{</span>
|
| 3924 |
+
<span class="s2">"Yamoe"</span><span class="p">:</span> <span class="p">{</span>
|
| 3925 |
+
<span class="s2">"cuda"</span><span class="p">:</span> <span class="p">{</span>
|
| 3926 |
+
<span class="n">Mode</span><span class="o">.</span><span class="n">INFERENCE</span><span class="p">:</span> <span class="n">LayerRepository</span><span class="p">(</span>
|
| 3927 |
+
<span class="n">repo_id</span><span class="o">=</span><span class="s2">"drbh/yamoe"</span><span class="p">,</span>
|
| 3928 |
+
<span class="n">layer_name</span><span class="o">=</span><span class="s2">"Yamoe"</span><span class="p">,</span>
|
| 3929 |
+
<span class="n">revision</span><span class="o">=</span><span class="s2">"v0.3.0"</span><span class="p">,</span>
|
| 3930 |
+
<span class="p">)</span>
|
| 3931 |
+
<span class="p">}</span>
|
| 3932 |
+
<span class="p">}</span>
|
| 3933 |
+
<span class="p">}</span>
|
| 3934 |
+
<span class="n">register_kernel_mapping</span><span class="p">(</span><span class="n">custom_mapping</span><span class="p">)</span>
|
| 3935 |
+
|
| 3936 |
+
|
| 3937 |
+
<span class="n">model</span> <span class="o">=</span> <span class="n">GptOssForCausalLM</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span>
|
| 3938 |
+
<span class="n">model_id</span><span class="p">,</span>
|
| 3939 |
+
<span class="n">dtype</span><span class="o">=</span><span class="s2">"bfloat16"</span><span class="p">,</span>
|
| 3940 |
+
<span class="n">device_map</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span>
|
| 3941 |
+
<span class="n">use_kernels</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
| 3942 |
+
<span class="n">quantization_config</span><span class="o">=</span><span class="n">quantization_config</span><span class="p">,</span>
|
| 3943 |
+
<span class="p">)</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
|
| 3944 |
+
|
| 3945 |
+
<span class="n">messages</span> <span class="o">=</span> <span class="p">[</span>
|
| 3946 |
+
<span class="p">{</span><span class="s2">"role"</span><span class="p">:</span> <span class="s2">"system"</span><span class="p">,</span> <span class="s2">"content"</span><span class="p">:</span> <span class="s2">"What is Tensor Parallelism?"</span><span class="p">},</span>
|
| 3947 |
+
<span class="p">]</span>
|
| 3948 |
+
|
| 3949 |
+
<span class="n">inputs</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">apply_chat_template</span><span class="p">(</span>
|
| 3950 |
+
<span class="n">messages</span><span class="p">,</span>
|
| 3951 |
+
<span class="n">add_generation_prompt</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
| 3952 |
+
<span class="n">return_tensors</span><span class="o">=</span><span class="s2">"pt"</span><span class="p">,</span>
|
| 3953 |
+
<span class="n">return_dict</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
| 3954 |
+
<span class="n">reasoning_effort</span><span class="o">=</span><span class="s2">"low"</span><span class="p">,</span>
|
| 3955 |
+
<span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">"cuda"</span><span class="p">)</span>
|
| 3956 |
+
|
| 3957 |
+
<span class="n">max_tokens</span> <span class="o">=</span> <span class="mi">256</span>
|
| 3958 |
+
|
| 3959 |
+
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">inference_mode</span><span class="p">():</span>
|
| 3960 |
+
<span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
|
| 3961 |
+
<span class="n">generated</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span>
|
| 3962 |
+
<span class="o">**</span><span class="n">inputs</span><span class="p">,</span>
|
| 3963 |
+
<span class="n">max_new_tokens</span><span class="o">=</span><span class="n">max_tokens</span><span class="p">,</span>
|
| 3964 |
+
<span class="n">do_sample</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
| 3965 |
+
<span class="n">temperature</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span>
|
| 3966 |
+
<span class="p">)</span>
|
| 3967 |
+
<span class="n">end_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">perf_counter</span><span class="p">()</span>
|
| 3968 |
+
|
| 3969 |
+
<span class="nb">print</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">generated</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">skip_special_tokens</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
|
| 3970 |
+
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Generation took </span><span class="si">{</span><span class="n">end_time</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">start_time</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2"> seconds"</span><span class="p">)</span>
|
| 3971 |
</pre></div>
|
| 3972 |
|
| 3973 |
+
<div class="code-line-highlight" id="line-highlight-setup"></div>
|
| 3974 |
</div>
|
| 3975 |
</div>
|
| 3976 |
</div>
|
| 3977 |
+
<div id="output-setup" class="cell-output">
|
| 3978 |
+
<div class="cell-stderr">Downloading cpython-3.13.7-linux-x86_64-gnu (download) (32.0MiB)
|
| 3979 |
+
Downloading cpython-3.13.7-linux-x86_64-gnu (download)
|
| 3980 |
+
Updating https://github.com/huggingface/transformers.git (HEAD)
|
| 3981 |
+
Updated https://github.com/huggingface/transformers.git (449533af73874470e914a203391635e04ac2ffc8)
|
| 3982 |
+
× No solution found when resolving script dependencies:
|
| 3983 |
+
╰─▶ Because only transformers==4.57.0.dev0 is available and
|
| 3984 |
+
transformers==4.57.0.dev0 depends on huggingface-hub==1.0.0rc1,
|
| 3985 |
+
we can conclude that all versions of transformers depend on
|
| 3986 |
+
huggingface-hub==1.0.0rc1.
|
| 3987 |
+
And because kernels==0.10.0 depends on huggingface-hub>=0.26.0,<1.0,
|
| 3988 |
+
we can conclude that kernels==0.10.0 and all versions of transformers
|
| 3989 |
+
are incompatible.
|
| 3990 |
+
And because you require kernels==0.10.0 and transformers, we can
|
| 3991 |
+
conclude that your requirements are unsatisfiable.
|
| 3992 |
</div>
|
| 3993 |
</div>
|
| 3994 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3995 |
</div>
|
| 3996 |
|
| 3997 |
</body>
|
moe_benchmarks/megablocks_yamoe/torch_profile.html
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|