drbh (HF Staff) committed (verified)
Commit 782c694 · 1 Parent(s): b7b6ff8

Upload folder using huggingface_hub
site/artifacts/charts/benchmark_dashboard.png ADDED
site/artifacts/charts/latency.png ADDED
site/artifacts/charts/memory.png ADDED
site/artifacts/charts/throughput.png ADDED
site/artifacts/setup/benchmark_avg_tokens_per_sec.txt ADDED
@@ -0,0 +1 @@
+ 5.301658854167735
site/artifacts/setup/benchmark_dashboard.png ADDED
site/artifacts/setup/benchmark_memory.txt ADDED
@@ -0,0 +1 @@
+ 9.398672896,9.414898176,10.334765056
site/artifacts/setup/benchmark_times.txt ADDED
@@ -0,0 +1,5 @@
+ 12.075035744113848
+ 12.0710428240709
+ 12.070115809096023
+ 12.070908240042627
+ 12.071364195086062
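Taken together, the setup artifacts are self-consistent: with the 64-token generation length that charts.py assumes (max_tokens = 64), each ~12.07 s run works out to roughly 5.3 tokens/s, matching benchmark_avg_tokens_per_sec.txt. A minimal sanity check, assuming throughput is simply tokens generated divided by wall-clock time:

# Sketch: recompute average throughput from the values above.
times = [
    12.075035744113848,
    12.0710428240709,
    12.070115809096023,
    12.070908240042627,
    12.071364195086062,
]
max_tokens = 64  # generation length assumed in charts.py
per_run = [max_tokens / t for t in times]
print(sum(per_run) / len(per_run))  # ~5.30, matching benchmark_avg_tokens_per_sec.txt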
site/cells/charts.py ADDED
@@ -0,0 +1,140 @@
+ # /// script
+ # dependencies = [
+ #   "matplotlib",
+ #   "numpy",
+ # ]
+ # ///
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import os
+
+ # Get the input path from the UVNOTE_INPUT_SETUP env var
+ setup_path = os.getenv("UVNOTE_INPUT_SETUP", ".")
+ print(f"Reading benchmark data from: {setup_path}")
+
+ num_runs = 5
+ max_tokens = 64
+ times = []
+ with open(os.path.join(setup_path, "benchmark_times.txt"), "r") as f:
+     for line in f:
+         times.append(float(line.strip()))
+
+ # Summary statistics derived from the measured run times
+ avg_time = sum(times) / len(times)
+ min_time = min(times)
+ max_time = max(times)
+
+ with open(os.path.join(setup_path, "benchmark_avg_tokens_per_sec.txt"), "r") as f:
+     avg_tokens_per_sec = float(f.read().strip())
+
+ # benchmark_memory.txt holds one comma-separated line; assumed order: allocated, peak, reserved (GB)
+ memory_file = os.path.join(setup_path, "benchmark_memory.txt")
+ with open(memory_file, "r") as f:
+     allocated_gb, peak_gb, reserved_gb = (float(x) for x in f.read().strip().split(","))
+ final_mem = {"allocated_gb": allocated_gb, "peak_gb": peak_gb, "reserved_gb": reserved_gb}
+
+ # Minimal brutalist palette (dark theme): grayscale + 1 accent
+ ACCENT = '#5ec8f8'  # calm cyan-blue accent
+ FG = '#e6e6e6'      # light gray text/lines
+ MUTED = '#9aa0a6'   # muted gray for secondary
+ GRID = '#333333'    # grid lines
+
+ # Styling tuned for clarity, high contrast, few colors
+ plt.style.use('dark_background')
+ plt.rcParams['figure.facecolor'] = 'none'
+ plt.rcParams['axes.facecolor'] = 'none'
+ plt.rcParams['savefig.facecolor'] = 'none'
+ plt.rcParams['savefig.transparent'] = True
+ plt.rcParams['font.family'] = 'monospace'
+ plt.rcParams['font.weight'] = 'bold'
+ plt.rcParams['axes.linewidth'] = 3
+ plt.rcParams['grid.linewidth'] = 2
+ plt.rcParams['lines.linewidth'] = 3
+ plt.rcParams['patch.linewidth'] = 2
+
+ # Prepare data
+ runs = list(range(1, len(times) + 1))
+ tokens_per_sec_all = [max_tokens / t for t in times]
+
+ # Chart 1: Throughput Performance
+ fig1, ax1 = plt.subplots(1, 1, figsize=(12, 6))
+ fig1.patch.set_alpha(0)
+ ax1.patch.set_alpha(0)
+
+ ax1.plot(runs, tokens_per_sec_all, color=ACCENT, marker='o', markersize=12,
+          markerfacecolor=ACCENT, markeredgecolor=FG, markeredgewidth=3, linewidth=5, label='tok/s')
+ ax1.fill_between(runs, 0, tokens_per_sec_all, alpha=0.2, color=ACCENT)
+ ax1.axhline(y=avg_tokens_per_sec, color=FG, linestyle='--', linewidth=3,
+             label=f'AVG: {avg_tokens_per_sec:.1f}')
+ ax1.set_title('THROUGHPUT PERFORMANCE', color=FG, fontsize=18, pad=20, fontweight='bold')
+ ax1.set_xlabel('RUN NUMBER', color=FG, fontsize=14, fontweight='bold')
+ ax1.set_ylabel('TOKENS/SEC', color=FG, fontsize=14, fontweight='bold')
+ ax1.grid(True, color=GRID, alpha=0.5, linewidth=2)
+ ax1.tick_params(colors=FG, labelsize=12)
+ legend1 = ax1.legend(frameon=False, loc='lower right')
+ for text in legend1.get_texts():
+     text.set_color(FG)
+     text.set_fontweight('bold')
+ plt.tight_layout()
+ plt.savefig('throughput.png', dpi=150, bbox_inches='tight', transparent=True)
+ plt.show()
+
+ # Chart 2: Generation Latency
+ fig2, ax2 = plt.subplots(1, 1, figsize=(12, 6))
+ fig2.patch.set_alpha(0)
+ ax2.patch.set_alpha(0)
+
+ bar_colors = [ACCENT if i % 2 == 0 else MUTED for i in range(len(times))]
+ bars = ax2.bar(runs, times, color=bar_colors, edgecolor=FG, linewidth=3, width=0.6)
+ ax2.axhline(y=avg_time, color=FG, linestyle='--', linewidth=3,
+             label=f'AVG: {avg_time:.2f}s')
+ for i, (run, time, bar) in enumerate(zip(runs, times, bars)):
+     ax2.text(run, time + 0.02, f'{time:.2f}s', ha='center', va='bottom',
+              color=FG, fontweight='bold', fontsize=11)
+ ax2.set_title('GENERATION LATENCY', color=FG, fontsize=18, pad=20, fontweight='bold')
+ ax2.set_xlabel('RUN NUMBER', color=FG, fontsize=14, fontweight='bold')
+ ax2.set_ylabel('TIME (SECONDS)', color=FG, fontsize=14, fontweight='bold')
+ ax2.grid(True, axis='y', color=GRID, alpha=0.5, linewidth=2)
+ ax2.tick_params(colors=FG, labelsize=12)
+ ax2.set_ylim(0, max(times) * 1.15)
+ legend2 = ax2.legend(frameon=False, loc='upper right')
+ for text in legend2.get_texts():
+     text.set_color(FG)
+     text.set_fontweight('bold')
+ plt.tight_layout()
+ plt.savefig('latency.png', dpi=150, bbox_inches='tight', transparent=True)
+ plt.show()
+
+ # Chart 3: Memory Usage
+ fig3, ax3 = plt.subplots(1, 1, figsize=(12, 6))
+ fig3.patch.set_alpha(0)
+ ax3.patch.set_alpha(0)
+
+ memory_labels = ['ALLOCATED', 'PEAK', 'RESERVED']
+ memory_values = [final_mem['allocated_gb'], final_mem['peak_gb'], final_mem['reserved_gb']]
+ colors_mem = [MUTED, ACCENT, FG]
+ bars = ax3.barh(memory_labels, memory_values, color=colors_mem, edgecolor=FG, linewidth=3, height=0.5)
+ for i, (label, value, bar) in enumerate(zip(memory_labels, memory_values, bars)):
+     ax3.text(value + 0.5, i, f'{value:.1f} GB', va='center',
+              color=FG, fontweight='bold', fontsize=13)
+ ax3.set_title('MEMORY USAGE', color=FG, fontsize=18, pad=20, fontweight='bold')
+ ax3.set_xlabel('GIGABYTES', color=FG, fontsize=14, fontweight='bold')
+ ax3.set_xlim(0, max(memory_values) * 1.3)
+ ax3.grid(True, axis='x', color=GRID, alpha=0.5, linewidth=2)
+ ax3.tick_params(colors=FG, labelsize=12)
+ ax3.set_yticks(range(len(memory_labels)))
+ ax3.set_yticklabels(memory_labels, fontweight='bold')
+ plt.tight_layout()
+ plt.savefig('memory.png', dpi=150, bbox_inches='tight', transparent=True)
+ plt.show()
+
+ print("\n📊 Charts saved as:")
+ print("  • throughput.png")
+ print("  • latency.png")
+ print("  • memory.png")
+ print("\nBenchmark Summary:")
+ print(f"  avg tokens/sec: {avg_tokens_per_sec:.1f}")
+ print(f"  min time: {min_time:.3f}s")
+ print(f"  max time: {max_time:.3f}s")
+ print(f"  peak memory: {final_mem['peak_gb']:.2f}GB")
site/cells/forward_and_backward.py ADDED
@@ -0,0 +1,102 @@
+ # /// script
+ # requires-python = ">=3.12"
+ # dependencies = [
+ #   "accelerate>=1.10.1",
+ #   "torch>=2.7.0",
+ #   "kernels==0.10.0",
+ #   "transformers@https://github.com/huggingface/transformers.git",
+ #   "ipdb>=0.13.13",
+ #   "matplotlib>=3.7.2",
+ #   "numpy>=1.24.3",
+ # ]
+ # ///
+
+ import torch
+ from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
+ import time
+ import torch.nn as nn
+ from kernels import register_kernel_mapping, replace_kernel_forward_from_hub, Mode, LayerRepository
+ import sys
+ import torch.profiler
+ import gc
+ import logging
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
+
+ # remove liger kernel for testing
+ replace_kernel_forward_from_hub(GptOssRMSNorm, None)
+
+ # set to debug logging
+ logging.basicConfig(level=logging.INFO)
+
+ def reset_peak_memory_stats():
+     """Clear CUDA cache and reset memory allocation counters."""
+     torch.cuda.empty_cache()
+     if torch.cuda.is_available():
+         torch.cuda.reset_peak_memory_stats()
+     gc.collect()
+
+ def get_memory_stats():
+     """Get current and peak CUDA memory usage."""
+     if not torch.cuda.is_available():
+         return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
+     return {
+         "allocated_gb": torch.cuda.memory_allocated() / 1e9,
+         "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
+         "reserved_gb": torch.cuda.memory_reserved() / 1e9,
+     }
+
+ def override_kernel_layer_name(cls_name: str, value) -> bool:
+     """Helper to dynamically override the kernel_layer_name in a model class."""
+     for mod in sys.modules.values():
+         if mod is None:
+             continue
+         obj = getattr(mod, cls_name, None)
+         if isinstance(obj, type) and issubclass(obj, nn.Module):
+             setattr(obj, "kernel_layer_name", value)
+             print(f"Overrode {cls_name}.kernel_layer_name to {value}")
+             return True
+     return False
+
+
+ # Init the model the normal way
+ model_id = "openai/gpt-oss-20b"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
+ quantization_config = Mxfp4Config(dequantize=True)
+
+ model = GptOssForCausalLM.from_pretrained(
+     model_id,
+     dtype="bfloat16",
+     device_map="auto",
+     use_kernels=True,
+     quantization_config=quantization_config,
+     training=True,
+ ).eval()
+
+ messages = [
+     {"role": "system", "content": "What is Tensor Parallelism?"},
+ ]
+
+ inputs = tokenizer.apply_chat_template(
+     messages,
+     add_generation_prompt=True,
+     return_tensors="pt",
+     return_dict=True,
+     reasoning_effort="low",
+ ).to("cuda")
+
+ max_tokens = 512
+
+
+ # forward and backward pass
+ with torch.autograd.set_grad_enabled(True):
+     start_time = time.perf_counter()
+     generated = model.generate(
+         **inputs,
+         max_new_tokens=max_tokens,
+         do_sample=False,
+         temperature=None,
+     )
+     end_time = time.perf_counter()
+ print(tokenizer.decode(generated[0], skip_special_tokens=False))
+ print(f"Generation took {end_time - start_time:.2f} seconds")
+
site/cells/forward_only.py ADDED
@@ -0,0 +1,96 @@
+ # /// script
+ # requires-python = ">=3.12"
+ # dependencies = [
+ #   "accelerate>=1.10.1",
+ #   "torch>=2.7.0",
+ #   "kernels==0.10.0",
+ #   "transformers@https://github.com/huggingface/transformers.git",
+ #   "ipdb>=0.13.13",
+ #   "matplotlib>=3.7.2",
+ #   "numpy>=1.24.3",
+ # ]
+ # ///
+
+ import torch
+ from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
+ import time
+ import torch.nn as nn
+ from kernels import register_kernel_mapping, Mode, LayerRepository
+ import sys
+ import torch.profiler
+ import gc
+ import logging
+
+ # set to debug logging
+ logging.basicConfig(level=logging.INFO)
+
+ def reset_peak_memory_stats():
+     """Clear CUDA cache and reset memory allocation counters."""
+     torch.cuda.empty_cache()
+     if torch.cuda.is_available():
+         torch.cuda.reset_peak_memory_stats()
+     gc.collect()
+
+ def get_memory_stats():
+     """Get current and peak CUDA memory usage."""
+     if not torch.cuda.is_available():
+         return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
+     return {
+         "allocated_gb": torch.cuda.memory_allocated() / 1e9,
+         "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
+         "reserved_gb": torch.cuda.memory_reserved() / 1e9,
+     }
+
+ def override_kernel_layer_name(cls_name: str, value) -> bool:
+     """Helper to dynamically override the kernel_layer_name in a model class."""
+     for mod in sys.modules.values():
+         if mod is None:
+             continue
+         obj = getattr(mod, cls_name, None)
+         if isinstance(obj, type) and issubclass(obj, nn.Module):
+             setattr(obj, "kernel_layer_name", value)
+             print(f"Overrode {cls_name}.kernel_layer_name to {value}")
+             return True
+     return False
+
+
+ # Init the model the normal way
+ model_id = "openai/gpt-oss-20b"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
+ quantization_config = Mxfp4Config(dequantize=True)
+
+
+ model = GptOssForCausalLM.from_pretrained(
+     model_id,
+     dtype="bfloat16",
+     device_map="auto",
+     use_kernels=True,
+     quantization_config=quantization_config,
+ ).eval()
+
+ messages = [
+     {"role": "system", "content": "What is Tensor Parallelism?"},
+ ]
+
+ inputs = tokenizer.apply_chat_template(
+     messages,
+     add_generation_prompt=True,
+     return_tensors="pt",
+     return_dict=True,
+     reasoning_effort="low",
+ ).to("cuda")
+
+ max_tokens = 512
+
+ with torch.inference_mode():
+     start_time = time.perf_counter()
+     generated = model.generate(
+         **inputs,
+         max_new_tokens=max_tokens,
+         do_sample=False,
+         temperature=None,
+     )
+     end_time = time.perf_counter()
+
+ print(tokenizer.decode(generated[0], skip_special_tokens=False))
+ print(f"Generation took {end_time - start_time:.2f} seconds")
site/cells/setup.py ADDED
@@ -0,0 +1,116 @@
+ # /// script
+ # requires-python = ">=3.12"
+ # dependencies = [
+ #   "accelerate>=1.10.1",
+ #   "torch>=2.7.0",
+ #   "kernels==0.10.0",
+ #   "transformers@https://github.com/huggingface/transformers.git",
+ #   "ipdb>=0.13.13",
+ #   "matplotlib>=3.7.2",
+ #   "numpy>=1.24.3",
+ # ]
+ # ///
+
+ import torch
+ from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
+ import time
+ import torch.nn as nn
+ from kernels import register_kernel_mapping, Mode, LayerRepository
+ import sys
+ import torch.profiler
+ import gc
+ import logging
+
+ # set to debug logging
+ logging.basicConfig(level=logging.INFO)
+
+ def reset_peak_memory_stats():
+     """Clear CUDA cache and reset memory allocation counters."""
+     torch.cuda.empty_cache()
+     if torch.cuda.is_available():
+         torch.cuda.reset_peak_memory_stats()
+     gc.collect()
+
+ def get_memory_stats():
+     """Get current and peak CUDA memory usage."""
+     if not torch.cuda.is_available():
+         return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
+     return {
+         "allocated_gb": torch.cuda.memory_allocated() / 1e9,
+         "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
+         "reserved_gb": torch.cuda.memory_reserved() / 1e9,
+     }
+
+ def override_kernel_layer_name(cls_name: str, value) -> bool:
+     """Helper to dynamically override the kernel_layer_name in a model class."""
+     for mod in sys.modules.values():
+         if mod is None:
+             continue
+         obj = getattr(mod, cls_name, None)
+         if isinstance(obj, type) and issubclass(obj, nn.Module):
+             setattr(obj, "kernel_layer_name", value)
+             print(f"Overrode {cls_name}.kernel_layer_name to {value}")
+             return True
+     return False
+
+
+ # Init the model the normal way
+ model_id = "openai/gpt-oss-20b"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
+ quantization_config = Mxfp4Config(dequantize=True)
+
+
+ from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode
+
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm
+
+ replace_kernel_forward_from_hub(GptOssMLP, "Yamoe")  # direct, type-safe
+ replace_kernel_forward_from_hub(GptOssRMSNorm, None)  # direct, type-safe
+ custom_mapping = {
+     "Yamoe": {
+         "cuda": {
+             Mode.INFERENCE: LayerRepository(
+                 repo_id="drbh/yamoe",
+                 layer_name="Yamoe",
+                 revision="v0.3.0",
+             )
+         }
+     }
+ }
+ register_kernel_mapping(custom_mapping)
+
+
+ model = GptOssForCausalLM.from_pretrained(
+     model_id,
+     dtype="bfloat16",
+     device_map="auto",
+     use_kernels=True,
+     quantization_config=quantization_config,
+ ).eval()
+
+ messages = [
+     {"role": "system", "content": "What is Tensor Parallelism?"},
+ ]
+
+ inputs = tokenizer.apply_chat_template(
+     messages,
+     add_generation_prompt=True,
+     return_tensors="pt",
+     return_dict=True,
+     reasoning_effort="low",
+ ).to("cuda")
+
+ max_tokens = 512
+
+ with torch.inference_mode():
+     start_time = time.perf_counter()
+     generated = model.generate(
+         **inputs,
+         max_new_tokens=max_tokens,
+         do_sample=False,
+         temperature=None,
+     )
+     end_time = time.perf_counter()
+
+ print(tokenizer.decode(generated[0], skip_special_tokens=False))
+ print(f"Generation took {end_time - start_time:.2f} seconds")
site/cells/setup2.py ADDED
@@ -0,0 +1,115 @@
+ # /// script
+ # requires-python = ">=3.12"
+ # dependencies = [
+ #   "accelerate>=1.10.1",
+ #   "torch>=2.7.0",
+ #   "kernels==0.10.0",
+ #   "transformers@https://github.com/huggingface/transformers.git",
+ #   "ipdb>=0.13.13",
+ #   "matplotlib>=3.7.2",
+ #   "numpy>=1.24.3",
+ # ]
+ # ///
+
+ import torch
+ from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
+ import time
+ import torch.nn as nn
+ from kernels import register_kernel_mapping, Mode, LayerRepository
+ import sys
+ import torch.profiler
+ import gc
+ import logging
+
+ # set to debug logging
+ logging.basicConfig(level=logging.INFO)
+
+ def reset_peak_memory_stats():
+     """Clear CUDA cache and reset memory allocation counters."""
+     torch.cuda.empty_cache()
+     if torch.cuda.is_available():
+         torch.cuda.reset_peak_memory_stats()
+     gc.collect()
+
+ def get_memory_stats():
+     """Get current and peak CUDA memory usage."""
+     if not torch.cuda.is_available():
+         return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
+     return {
+         "allocated_gb": torch.cuda.memory_allocated() / 1e9,
+         "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
+         "reserved_gb": torch.cuda.memory_reserved() / 1e9,
+     }
+
+ def override_kernel_layer_name(cls_name: str, value) -> bool:
+     """Helper to dynamically override the kernel_layer_name in a model class."""
+     for mod in sys.modules.values():
+         if mod is None:
+             continue
+         obj = getattr(mod, cls_name, None)
+         if isinstance(obj, type) and issubclass(obj, nn.Module):
+             setattr(obj, "kernel_layer_name", value)
+             print(f"Overrode {cls_name}.kernel_layer_name to {value}")
+             return True
+     return False
+
+
+ # Init the model the normal way
+ model_id = "openai/gpt-oss-20b"
+ tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
+ quantization_config = Mxfp4Config(dequantize=True)
+
+
+ from kernels import replace_kernel_forward_from_hub, register_kernel_mapping, LayerRepository, Mode
+
+ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP, GptOssRMSNorm
+
+ replace_kernel_forward_from_hub(GptOssRMSNorm, None)  # direct, type-safe
+ custom_mapping = {
+     "Yamoe": {
+         "cuda": {
+             Mode.INFERENCE: LayerRepository(
+                 repo_id="drbh/yamoe",
+                 layer_name="Yamoe",
+                 revision="v0.3.0",
+             )
+         }
+     }
+ }
+ register_kernel_mapping(custom_mapping)
+
+
+ model = GptOssForCausalLM.from_pretrained(
+     model_id,
+     dtype="bfloat16",
+     device_map="auto",
+     use_kernels=True,
+     quantization_config=quantization_config,
+ ).eval()
+
+ messages = [
+     {"role": "system", "content": "What is Tensor Parallelism?"},
+ ]
+
+ inputs = tokenizer.apply_chat_template(
+     messages,
+     add_generation_prompt=True,
+     return_tensors="pt",
+     return_dict=True,
+     reasoning_effort="low",
+ ).to("cuda")
+
+ max_tokens = 512
+
+ with torch.inference_mode():
+     start_time = time.perf_counter()
+     generated = model.generate(
+         **inputs,
+         max_new_tokens=max_tokens,
+         do_sample=False,
+         temperature=None,
+     )
+     end_time = time.perf_counter()
+
+ print(tokenizer.decode(generated[0], skip_special_tokens=False))
+ print(f"Generation took {end_time - start_time:.2f} seconds")
site/megablocks_only.html ADDED
The diff for this file is too large to render. See raw diff
 
site/note.html ADDED
The diff for this file is too large to render. See raw diff
 
site/note_test_override.html ADDED
The diff for this file is too large to render. See raw diff