# /// script # dependencies = [ # "torch", # "numpy", # ] # /// """Reusable benchmarking utilities for performance testing.""" import time import numpy as np from contextlib import contextmanager from typing import Callable, Dict, Tuple, Any, Optional import torch def to_dtype(dtype_str: str): """Convert string to torch dtype.""" if dtype_str == "float16": return torch.float16 if dtype_str == "bfloat16": return torch.bfloat16 return torch.float32 def _sync(device: str): """Synchronize device if CUDA.""" if device == "cuda": torch.cuda.synchronize() def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]: """Compute comprehensive latency and throughput statistics.""" lat_ms = np.array([t * 1000.0 for t in times_s]) lat_ms_sorted = np.sort(lat_ms) n = len(lat_ms) stats = { "avg_ms": np.mean(lat_ms), "min_ms": np.min(lat_ms), "max_ms": np.max(lat_ms), "std_ms": np.std(lat_ms), "p50_ms": np.percentile(lat_ms, 50), "p95_ms": np.percentile(lat_ms, 95), "p99_ms": np.percentile(lat_ms, 99), "num_iters": n } if tokens is not None and n > 0: avg_s = np.mean(times_s) stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf") stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0]) return stats def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str: """Format timing statistics for display.""" lines = [ "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━", f"Iterations: {stats.get('num_iters', 0)}", "\nLatency Statistics:", f" Average: {stats['avg_ms']:.3f} ms", f" Min: {stats['min_ms']:.3f} ms", f" Max: {stats['max_ms']:.3f} ms", f" Std Dev: {stats['std_ms']:.3f} ms", "\nPercentiles:", f" P50 (median): {stats['p50_ms']:.3f} ms", f" P95: {stats['p95_ms']:.3f} ms", f" P99: {stats['p99_ms']:.3f} ms", ] if tokens is not None and 'tokens_per_s' in stats: lines.extend([ "\nThroughput:", f" Tokens/sec: {stats['tokens_per_s']:.1f}", f" Std Dev: {stats.get('throughput_variance', 0):.1f}", ]) lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") return "\n".join(lines) def _bench_engine( call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] = None ) -> Tuple[Any, list]: """Core benchmarking engine with warmup and timing.""" use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16) # Warmup phase print(f"\nWarming up ({warmup} iterations)...") with torch.inference_mode(): for _ in range(max(0, warmup)): if use_autocast: with torch.autocast(device_type="cuda", dtype=dtype): if input_gen is not None: _ = call(input_gen()) else: _ = call() else: if input_gen is not None: _ = call(input_gen()) else: _ = call() _sync(device) # Measurement phase print(f"Benchmarking ({iters} iterations)...") times_s = [] last = None with torch.inference_mode(): for i in range(max(1, iters)): start = time.perf_counter() if use_autocast: with torch.autocast(device_type="cuda", dtype=dtype): if input_gen is not None: last = call(input_gen()) else: last = call() else: if input_gen is not None: last = call(input_gen()) else: last = call() _sync(device) end = time.perf_counter() times_s.append(end - start) # Progress indicator every 20% of iterations if i > 0 and i % max(1, iters // 5) == 0: pct = (i / iters) * 100 avg_so_far = np.mean(times_s[:i]) * 1000 print(f" Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)") return last, times_s def tensor_stats(t: torch.Tensor) -> str: """Generate comprehensive stats string for a tensor.""" return (f"shape={tuple(t.shape)}, " f"dtype={t.dtype}, " f"device={t.device}, " f"range=[{t.min().item():.6f}, {t.max().item():.6f}], " f"mean={t.mean().item():.6f}, " f"std={t.std().item():.6f}, " f"norm={t.norm().item():.6f}") @contextmanager def bench_context( *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True ): """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats). If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration by adding a small deterministic increment to prevent caching artifacts. """ def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]: # Log configuration if verbose: print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐") # print(f"│ Device: {device:<15} Dtype: {dtype} │") print(f"│ Warmup: {warmup:<15} Iters: {iters} │") if tokens: print(f"│ Tokens: {tokens} │") if vary_inputs: print(f"│ Input Variation: Enabled (prevents caching artifacts) │") print(f"└────────────────────────────────────────────────────────┘") # Set up input generation input_gen = None if vary_inputs and args and isinstance(args[0], torch.Tensor): base_input = args[0].clone() iteration_counter = [0] # Use list for mutable closure def generate_varied_input(): """Generate input tensor varied by iteration to prevent caching.""" # Add small deterministic increment: 0.001 * iteration_number varied_input = base_input + (iteration_counter[0] * 0.001) iteration_counter[0] += 1 return varied_input input_gen = generate_varied_input call = lambda x: fn(x, *args[1:], **kwargs) # Log base input stats if verbose: print(f"\nBase Input: {tensor_stats(base_input)}") print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)") else: # Legacy mode - static inputs call = lambda: fn(*args, **kwargs) if verbose and args and isinstance(args[0], torch.Tensor): print(f"\nInput: {tensor_stats(args[0])}") result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen) # Log output if it's a tensor or tuple with tensors if verbose: print("\nOutput tensors:") if isinstance(result, torch.Tensor): print(f" Primary: {tensor_stats(result)}") elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor): print(f" Primary: {tensor_stats(result[0])}") if len(result) > 1: if isinstance(result[1], torch.Tensor): print(f" Auxiliary: {tensor_stats(result[1])}") else: print(f" Auxiliary: {type(result[1]).__name__}") # Compute and display statistics stats = _compute_stats(times_s, tokens=tokens) if verbose: print(_format_timing_stats(stats, tokens)) # Save to JSON if requested if save_json: import json json_data = { "implementation": save_json.replace(".json", ""), "config": { "warmup": warmup, "iters": iters, "device": str(device), # Convert device to string "dtype": str(dtype), "tokens": tokens, "vary_inputs": vary_inputs }, "stats": stats, "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None } with open(save_json, 'w') as f: json.dump(json_data, f, indent=2) if verbose: print(f"\nSaved benchmark results to {save_json}") return result, stats yield runner def set_seed(seed: int): """Set seeds for reproducibility.""" torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False