"""Reusable benchmarking utilities for performance testing."""
+import time
+import numpy as np
+from contextlib import contextmanager
+from typing import Callable, Dict, Tuple, Any, Optional
+import torch
+
+def to_dtype(dtype_str: str):
+ """Convert string to torch dtype."""
+ if dtype_str == "float16":
+ return torch.float16
+ if dtype_str == "bfloat16":
+ return torch.bfloat16
+ return torch.float32
+
+def _sync(device: str):
+ """Synchronize device if CUDA."""
+ if device == "cuda":
+ torch.cuda.synchronize()
+
+def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
+ """Compute comprehensive latency and throughput statistics."""
+ lat_ms = np.array([t * 1000.0 for t in times_s])
+ lat_ms_sorted = np.sort(lat_ms)
+ n = len(lat_ms)
+
+ stats = {
+ "avg_ms": np.mean(lat_ms),
+ "min_ms": np.min(lat_ms),
+ "max_ms": np.max(lat_ms),
+ "std_ms": np.std(lat_ms),
+ "p50_ms": np.percentile(lat_ms, 50),
+ "p95_ms": np.percentile(lat_ms, 95),
+ "p99_ms": np.percentile(lat_ms, 99),
+ "num_iters": n
+ }
+
+ if tokens is not None and n > 0:
+ avg_s = np.mean(times_s)
+ stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
+ stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])
+
+ return stats
+
+def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
+ """Format timing statistics for display."""
+ lines = [
+ "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
+ f"Iterations: {stats.get('num_iters', 0)}",
+ "\nLatency Statistics:",
+ f" Average: {stats['avg_ms']:.3f} ms",
+ f" Min: {stats['min_ms']:.3f} ms",
+ f" Max: {stats['max_ms']:.3f} ms",
+ f" Std Dev: {stats['std_ms']:.3f} ms",
+ "\nPercentiles:",
+ f" P50 (median): {stats['p50_ms']:.3f} ms",
+ f" P95: {stats['p95_ms']:.3f} ms",
+ f" P99: {stats['p99_ms']:.3f} ms",
+ ]
+
+ if tokens is not None and 'tokens_per_s' in stats:
+ lines.extend([
+ "\nThroughput:",
+ f" Tokens/sec: {stats['tokens_per_s']:.1f}",
+ f" Std Dev: {stats.get('throughput_variance', 0):.1f}",
+ ])
+
+ lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
+ return "\n".join(lines)
+
+def _bench_engine(
+ call: Callable[[], Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Callable[[], Any] = None
+) -> Tuple[Any, list]:
+ """Core benchmarking engine with warmup and timing."""
+ use_autocast = device == "cuda" and dtype in (torch.float16, torch.bfloat16)
+
+ # Warmup phase
+ print(f"\nWarming up ({warmup} iterations)...")
+ with torch.inference_mode():
+ for _ in range(max(0, warmup)):
+ if use_autocast:
+ with torch.autocast(device_type="cuda", dtype=dtype):
+ if input_gen is not None:
+ _ = call(input_gen())
+ else:
+ _ = call()
+ else:
+ if input_gen is not None:
+ _ = call(input_gen())
+ else:
+ _ = call()
+ _sync(device)
+
+ # Measurement phase
+ print(f"Benchmarking ({iters} iterations)...")
+ times_s = []
+ last = None
+ with torch.inference_mode():
+ for i in range(max(1, iters)):
+ start = time.perf_counter()
+ if use_autocast:
+ with torch.autocast(device_type="cuda", dtype=dtype):
+ if input_gen is not None:
+ last = call(input_gen())
+ else:
+ last = call()
+ else:
+ if input_gen is not None:
+ last = call(input_gen())
+ else:
+ last = call()
+ _sync(device)
+ end = time.perf_counter()
+ times_s.append(end - start)
+
+ # Progress indicator every 20% of iterations
+ if i > 0 and i % max(1, iters // 5) == 0:
+ pct = (i / iters) * 100
+ avg_so_far = np.mean(times_s[:i]) * 1000
+ print(f" Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")
+
+ return last, times_s
+
+def tensor_stats(t: torch.Tensor) -> str:
+ """Generate comprehensive stats string for a tensor."""
+ return (f"shape={tuple(t.shape)}, "
+ f"dtype={t.dtype}, "
+ f"device={t.device}, "
+ f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
+ f"mean={t.mean().item():.6f}, "
+ f"std={t.std().item():.6f}, "
+ f"norm={t.norm().item():.6f}")
+
+@contextmanager
+def bench_context(
+ *, warmup: int = 25, iters: int = 100, device: str = "cuda", dtype=torch.float32, tokens: Optional[int] = None, verbose: bool = True, save_json: Optional[str] = None, vary_inputs: bool = True
+):
+ """Context that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).
+
+ If vary_inputs=True, the first argument should be a base tensor that will be varied each iteration
+ by adding a small deterministic increment to prevent caching artifacts.
+ """
+
+ def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
+ # Log configuration
+ if verbose:
+ print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐")
+ # print(f"│ Device: {device:<15} Dtype: {dtype} │")
+ print(f"│ Warmup: {warmup:<15} Iters: {iters} │")
+ if tokens:
+ print(f"│ Tokens: {tokens} │")
+ if vary_inputs:
+ print(f"│ Input Variation: Enabled (prevents caching artifacts) │")
+ print(f"└────────────────────────────────────────────────────────┘")
+
+ # Set up input generation
+ input_gen = None
+ if vary_inputs and args and isinstance(args[0], torch.Tensor):
+ base_input = args[0].clone()
+ iteration_counter = [0] # Use list for mutable closure
+
+ def generate_varied_input():
+ """Generate input tensor varied by iteration to prevent caching."""
+ # Add small deterministic increment: 0.001 * iteration_number
+ varied_input = base_input + (iteration_counter[0] * 0.001)
+ iteration_counter[0] += 1
+ return varied_input
+
+ input_gen = generate_varied_input
+ call = lambda x: fn(x, *args[1:], **kwargs)
+
+ # Log base input stats
+ if verbose:
+ print(f"\nBase Input: {tensor_stats(base_input)}")
+ print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
+ else:
+ # Legacy mode - static inputs
+ call = lambda: fn(*args, **kwargs)
+ if verbose and args and isinstance(args[0], torch.Tensor):
+ print(f"\nInput: {tensor_stats(args[0])}")
+
+ result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)
+
+ # Log output if it's a tensor or tuple with tensors
+ if verbose:
+ print("\nOutput tensors:")
+ if isinstance(result, torch.Tensor):
+ print(f" Primary: {tensor_stats(result)}")
+ elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
+ print(f" Primary: {tensor_stats(result[0])}")
+ if len(result) > 1:
+ if isinstance(result[1], torch.Tensor):
+ print(f" Auxiliary: {tensor_stats(result[1])}")
+ else:
+ print(f" Auxiliary: {type(result[1]).__name__}")
+
+ # Compute and display statistics
+ stats = _compute_stats(times_s, tokens=tokens)
+ if verbose:
+ print(_format_timing_stats(stats, tokens))
+
+ # Save to JSON if requested
+ if save_json:
+ import json
+ json_data = {
+ "implementation": save_json.replace(".json", ""),
+ "config": {
+ "warmup": warmup,
+ "iters": iters,
+ "device": str(device), # Convert device to string
+ "dtype": str(dtype),
+ "tokens": tokens,
+ "vary_inputs": vary_inputs
+ },
+ "stats": stats,
+ "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None
+ }
+ with open(save_json, 'w') as f:
+ json.dump(json_data, f, indent=2)
+ if verbose:
+ print(f"\nSaved benchmark results to {save_json}")
+
+ return result, stats
+
+ yield runner
+
+def set_seed(seed: int):
+ """Set seeds for reproducibility."""
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+