# /// script
# requires-python = ">=3.12"
# dependencies = [
# "accelerate>=1.10.1",
# "torch>=2.7.0",
# "kernels==0.10.0",
# "transformers@https://github.com/huggingface/transformers.git",
# "ipdb>=0.13.13",
# "matplotlib>=3.7.2",
# "numpy>=1.24.3",
# ]
# ///
import torch
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
import time
import torch.nn as nn
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
import sys
import torch.profiler
import gc
import logging
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
# remove liger kernel for testing
replace_kernel_forward_from_hub(GptOssRMSNorm, None)
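# Mapping the layer to None should restore the stock PyTorch forward for
# GptOssRMSNorm, so the run below measures a baseline without the hub kernel.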
# set logging to INFO level
logging.basicConfig(level=logging.INFO)

def reset_peak_memory_stats():
    """Clear CUDA cache and reset memory allocation counters."""
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    gc.collect()

def get_memory_stats():
    """Get current and peak CUDA memory usage."""
    if not torch.cuda.is_available():
        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
    return {
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
    }

def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Helper to dynamically override the kernel_layer_name in a model class."""
    for mod in sys.modules.values():
        if mod is None:
            continue
        obj = getattr(mod, cls_name, None)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            setattr(obj, "kernel_layer_name", value)
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False
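
# Example usage (not run here); the repo and layer names below are illustrative:
# override_kernel_layer_name("GptOssRMSNorm",
#     LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"))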
# Init the model the normal way
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
quantization_config = Mxfp4Config(dequantize=True)
model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=False,
    quantization_config=quantization_config,
).eval()
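# Note: Mxfp4Config(dequantize=True) should unpack the MXFP4 expert weights into
# the requested bf16 dtype at load time, which is what allows a backward pass
# through them later in this script.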
messages = [
    {"role": "system", "content": "What is Tensor Parallelism?"},
]
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")
max_tokens = 128 # Reduced to help with memory usage
# Clear memory before generation
reset_peak_memory_stats()
print(f"Pre-generation memory: {get_memory_stats()}")
# Timed generation (forward) pass; the backward pass happens further below
with torch.autograd.set_grad_enabled(True):
    start_time = time.perf_counter()
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    end_time = time.perf_counter()
    print(tokenizer.decode(generated[0], skip_special_tokens=False))
    print(f"Generation took {end_time - start_time:.2f} seconds")
print(f"Post-generation memory: {get_memory_stats()}")
# Use gradient checkpointing to reduce memory usage
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()
    print("Enabled gradient checkpointing")
# Reduce sequence length if needed for memory
max_seq_len = 512 # Limit sequence length for backward pass
if generated.size(1) > max_seq_len:
    print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
    full_sequence = generated[:, -max_seq_len:]
else:
    full_sequence = generated
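
# Keeping only the most recent max_seq_len tokens bounds activation memory for
# the backward pass while still covering the generated continuation.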
# Get model outputs for the full sequence
model.train() # Enable dropout and other training behaviors
try:
    outputs = model(
        input_ids=full_sequence,
        labels=full_sequence,  # This will compute loss internally
        return_dict=True,
    )
    print(f"Post-forward memory: {get_memory_stats()}")
    # If model doesn't compute loss, compute it manually
    if outputs.loss is None:
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = full_sequence[..., 1:].contiguous()
        # Use CrossEntropyLoss with ignore_index for padding tokens
        loss_fct = torch.nn.CrossEntropyLoss(
            ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
        )
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    else:
        loss = outputs.loss
    print(f"Loss: {loss.item():.4f}")
    # Clear intermediate tensors to save memory
    del outputs
    torch.cuda.empty_cache()
    # Perform backward pass with memory management
    print("Running backward pass...")
    print(f"Pre-backward memory: {get_memory_stats()}")
    loss.backward()
    print(f"Post-backward memory: {get_memory_stats()}")
except torch.cuda.OutOfMemoryError as e:
    print(f"OOM during forward/backward pass: {e}")
    print("Try reducing max_tokens or max_seq_len")
    raise
# Calculate gradient statistics and print sample gradients
total_norm = 0.0
param_count = 0
grad_samples = {}
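# The total gradient norm is the L2 norm over all parameters,
# i.e. sqrt(sum of ||grad_p||_2^2), accumulated one parameter at a time.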
for name, p in model.named_parameters():
    if p.grad is not None:
        param_count += 1
        grad_norm = p.grad.data.norm(2).item()
        total_norm += grad_norm ** 2
        # Collect gradient statistics for key layers
        if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
            grad_samples[name] = {
                'norm': grad_norm,
                'mean': p.grad.data.mean().item(),
                'std': p.grad.data.std().item(),
                'max': p.grad.data.max().item(),
                'min': p.grad.data.min().item(),
            }
total_norm = total_norm ** 0.5
print(f"\nGradient norm: {total_norm:.4f}")
print(f"Parameters with gradients: {param_count}")
# Print sample gradients from important layers
print("\nSample gradient statistics:")
for name, stats in list(grad_samples.items())[:10]:
    print(f"  {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}")
# Optional: zero gradients for next iteration
model.zero_grad()
model.eval() # Switch back to eval mode
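
# In a full training loop an optimizer step (e.g. torch.optim.AdamW) would sit
# between loss.backward() and zero_grad(); this script only inspects gradients.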