import gc
import logging
import sys
import time

import torch
import torch.nn as nn
import torch.profiler
from kernels import register_kernel_mapping, Mode, LayerRepository, replace_kernel_forward_from_hub
from transformers import GptOssForCausalLM, PreTrainedTokenizerFast, Mxfp4Config
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm

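# Map GptOssRMSNorm to no Hub kernel (None) so this layer class keeps its
# built-in PyTorch forward for the run below.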
replace_kernel_forward_from_hub(GptOssRMSNorm, None)

logging.basicConfig(level=logging.INFO)

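# Small helpers for tracking CUDA memory around generation and the backward pass.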
def reset_peak_memory_stats():
    """Clear the CUDA cache and reset memory allocation counters."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

def get_memory_stats():
    """Get current and peak CUDA memory usage."""
    if not torch.cuda.is_available():
        return {"allocated_gb": 0, "peak_gb": 0, "reserved_gb": 0}
    return {
        "allocated_gb": torch.cuda.memory_allocated() / 1e9,
        "peak_gb": torch.cuda.max_memory_allocated() / 1e9,
        "reserved_gb": torch.cuda.memory_reserved() / 1e9,
    }

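# Debug utility: patch `kernel_layer_name` on a layer class wherever it has been
# imported. Kept for experimentation; it is not called in the run below.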
def override_kernel_layer_name(cls_name: str, value) -> bool:
    """Helper to dynamically override the kernel_layer_name in a model class."""
    for mod in sys.modules.values():
        if mod is None:
            continue
        obj = getattr(mod, cls_name, None)
        if isinstance(obj, type) and issubclass(obj, nn.Module):
            setattr(obj, "kernel_layer_name", value)
            print(f"Overrode {cls_name}.kernel_layer_name to {value}")
            return True
    return False

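# Load the tokenizer and the 20B model. Mxfp4Config(dequantize=True) unpacks the
# MXFP4-quantized weights to a dense dtype at load time, and use_kernels=False
# keeps the stock eager layers so the backward pass below avoids Hub kernels.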
model_id = "openai/gpt-oss-20b"
tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id)
quantization_config = Mxfp4Config(dequantize=True)

model = GptOssForCausalLM.from_pretrained(
    model_id,
    dtype="bfloat16",
    device_map="auto",
    use_kernels=False,
    quantization_config=quantization_config,
).eval()

messages = [
    {"role": "user", "content": "What is Tensor Parallelism?"},
]

inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
    reasoning_effort="low",
).to("cuda")

max_tokens = 128

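# Inference pass: greedy decoding with memory tracked before and after.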
reset_peak_memory_stats()
print(f"Pre-generation memory: {get_memory_stats()}")

with torch.no_grad():
    start_time = time.perf_counter()
    generated = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,
        temperature=None,
    )
    end_time = time.perf_counter()

print(tokenizer.decode(generated[0], skip_special_tokens=False))
print(f"Generation took {end_time - start_time:.2f} seconds")
print(f"Post-generation memory: {get_memory_stats()}")

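# Training-style check: reuse the generated sequence as both inputs and labels,
# enable gradient checkpointing, and cap the sequence length to bound activation
# memory for the forward/backward pass.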
if hasattr(model, 'gradient_checkpointing_enable'):
    model.gradient_checkpointing_enable()
    print("Enabled gradient checkpointing")

max_seq_len = 512
if generated.size(1) > max_seq_len:
    print(f"Truncating sequence from {generated.size(1)} to {max_seq_len} tokens")
    full_sequence = generated[:, -max_seq_len:]
else:
    full_sequence = generated

model.train()

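# Forward pass with labels so the model computes its own shifted cross-entropy
# loss; the manual fallback below only runs if no loss is returned.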
try:
    outputs = model(
        input_ids=full_sequence,
        labels=full_sequence,
        return_dict=True,
    )
    print(f"Post-forward memory: {get_memory_stats()}")

    if outputs.loss is None:
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = full_sequence[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(
            ignore_index=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -100
        )
        loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    else:
        loss = outputs.loss

    print(f"Loss: {loss.item():.4f}")

    # Release the forward outputs (and their logits) before backward to lower peak memory.
    del outputs
    torch.cuda.empty_cache()

    print("Running backward pass...")
    print(f"Pre-backward memory: {get_memory_stats()}")
    loss.backward()
    print(f"Post-backward memory: {get_memory_stats()}")

except torch.cuda.OutOfMemoryError as e:
    print(f"OOM during forward/backward pass: {e}")
    print("Try reducing max_tokens or max_seq_len")
    raise

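# Gradient inspection: accumulate the global L2 norm and record per-tensor
# statistics for a few representative parameter groups.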
total_norm = 0.0
param_count = 0
grad_samples = {}

for name, p in model.named_parameters():
    if p.grad is not None:
        param_count += 1
        grad = p.grad.detach()
        grad_norm = grad.norm(2).item()
        total_norm += grad_norm ** 2

        if any(key in name for key in ['embed', 'lm_head', 'mlp.up', 'mlp.down', 'self_attn.q_proj', 'norm']):
            grad_samples[name] = {
                'norm': grad_norm,
                'mean': grad.mean().item(),
                'std': grad.std().item(),
                'max': grad.max().item(),
                'min': grad.min().item(),
            }

total_norm = total_norm ** 0.5

print(f"\nGradient norm: {total_norm:.4f}") |
|
|
print(f"Parameters with gradients: {param_count}") |
|
|
|
|
|
|
|
|
print("\nSample gradient statistics:") |
|
|
for i, (name, stats) in enumerate(list(grad_samples.items())[:10]): |
|
|
print(f" {name[:60]:<60} | norm: {stats['norm']:.4e} | mean: {stats['mean']:.4e} | std: {stats['std']:.4e}") |
|
|
|
|
|
|
|
|
# Clear the gradients and return the model to eval mode.
model.zero_grad()
model.eval()