# /// script
# dependencies = [
#     "torch",
#     "numpy",
# ]
# ///

"""Reusable benchmarking utilities for performance testing."""
import time
import numpy as np
from contextlib import contextmanager, nullcontext
from typing import Callable, Dict, Tuple, Any, Optional
import torch

def to_dtype(dtype_str: str):
    """Convert string to torch dtype."""
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "bfloat16":
        return torch.bfloat16
    return torch.float32

def _sync(device: str):
    """Synchronize if running on a CUDA device (handles both "cuda" and "cuda:N")."""
    if str(device).startswith("cuda"):
        torch.cuda.synchronize()

def _compute_stats(times_s, tokens: Optional[int] = None) -> Dict[str, float]:
    """Compute comprehensive latency and throughput statistics."""
    lat_ms = np.array([t * 1000.0 for t in times_s])
    lat_ms_sorted = np.sort(lat_ms)
    n = len(lat_ms)
    
    stats = {
        "avg_ms": np.mean(lat_ms),
        "min_ms": np.min(lat_ms),
        "max_ms": np.max(lat_ms),
        "std_ms": np.std(lat_ms),
        "p50_ms": np.percentile(lat_ms, 50),
        "p95_ms": np.percentile(lat_ms, 95),
        "p99_ms": np.percentile(lat_ms, 99),
        "num_iters": n
    }
    
    if tokens is not None and n > 0:
        avg_s = np.mean(times_s)
        stats["tokens_per_s"] = tokens / avg_s if avg_s > 0 else float("inf")
        stats["throughput_variance"] = np.std([tokens / t for t in times_s if t > 0])
    
    return stats

def _format_timing_stats(stats: Dict[str, float], tokens: Optional[int] = None) -> str:
    """Format timing statistics for display."""
    lines = [
        "\n━━━━━━━━━━━━━━━━━━━━ Benchmark Results ━━━━━━━━━━━━━━━━━━━━",
        f"Iterations: {stats.get('num_iters', 0)}",
        "\nLatency Statistics:",
        f"  Average: {stats['avg_ms']:.3f} ms",
        f"  Min:     {stats['min_ms']:.3f} ms",
        f"  Max:     {stats['max_ms']:.3f} ms", 
        f"  Std Dev: {stats['std_ms']:.3f} ms",
        "\nPercentiles:",
        f"  P50 (median): {stats['p50_ms']:.3f} ms",
        f"  P95:          {stats['p95_ms']:.3f} ms",
        f"  P99:          {stats['p99_ms']:.3f} ms",
    ]
    
    if tokens is not None and 'tokens_per_s' in stats:
        lines.extend([
            "\nThroughput:",
            f"  Tokens/sec: {stats['tokens_per_s']:.1f}",
            f"  Std Dev:    {stats.get('throughput_variance', 0):.1f}",
        ])
    
    lines.append("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
    return "\n".join(lines)

def _bench_engine(
    call: Callable[..., Any], *, warmup: int, iters: int, device: str, dtype, input_gen: Optional[Callable[[], Any]] = None
) -> Tuple[Any, list]:
    """Core benchmarking engine with warmup and timing."""
    use_autocast = str(device).startswith("cuda") and dtype in (torch.float16, torch.bfloat16)

    def _run_once():
        """Run a single call, regenerating the input if a generator was supplied."""
        autocast_ctx = torch.autocast(device_type="cuda", dtype=dtype) if use_autocast else nullcontext()
        with autocast_ctx:
            return call(input_gen()) if input_gen is not None else call()

    # Warmup phase
    print(f"\nWarming up ({warmup} iterations)...")
    with torch.inference_mode():
        for _ in range(max(0, warmup)):
            _run_once()
        _sync(device)

    # Measurement phase
    print(f"Benchmarking ({iters} iterations)...")
    times_s = []
    last = None
    with torch.inference_mode():
        for i in range(max(1, iters)):
            start = time.perf_counter()
            last = _run_once()
            _sync(device)
            end = time.perf_counter()
            times_s.append(end - start)

            # Progress indicator every 20% of iterations
            if i > 0 and i % max(1, iters // 5) == 0:
                pct = (i / iters) * 100
                avg_so_far = np.mean(times_s[:i]) * 1000
                print(f"  Progress: {pct:.0f}% complete (avg: {avg_so_far:.3f} ms)")

    return last, times_s

def tensor_stats(t: torch.Tensor) -> str:
    """Generate comprehensive stats string for a tensor."""
    return (f"shape={tuple(t.shape)}, "
            f"dtype={t.dtype}, "
            f"device={t.device}, "
            f"range=[{t.min().item():.6f}, {t.max().item():.6f}], "
            f"mean={t.mean().item():.6f}, "
            f"std={t.std().item():.6f}, "
            f"norm={t.norm().item():.6f}")

@contextmanager
def bench_context(
    *,
    warmup: int = 25,
    iters: int = 100,
    device: str = "cuda",
    dtype=torch.float32,
    tokens: Optional[int] = None,
    verbose: bool = True,
    save_json: Optional[str] = None,
    vary_inputs: bool = True,
):
    """Context manager that yields a runner: runner(fn, *args, **kwargs) -> (result, stats).

    If vary_inputs=True and the first positional argument is a tensor, it is treated as a base
    input and varied each iteration by a small deterministic increment to prevent caching
    artifacts; otherwise the arguments are passed through unchanged. See the usage sketch in
    the __main__ block at the end of this file.
    """

    def runner(fn: Callable[..., Any], *args, **kwargs) -> Tuple[Any, Dict[str, float]]:
        # Log configuration
        if verbose:
            print(f"\n┌─ Benchmark Configuration ─────────────────────────────┐")
            # print(f"│ Device: {device:<15} Dtype: {dtype}              │")
            print(f"│ Warmup: {warmup:<15} Iters: {iters}              │")
            if tokens:
                print(f"│ Tokens: {tokens}                                        │")
            if vary_inputs:
                print(f"│ Input Variation: Enabled (prevents caching artifacts)  │")
            print(f"└────────────────────────────────────────────────────────┘")

        # Set up input generation
        input_gen = None
        if vary_inputs and args and isinstance(args[0], torch.Tensor):
            base_input = args[0].clone()
            iteration_counter = [0]  # Use list for mutable closure

            def generate_varied_input():
                """Generate input tensor varied by iteration to prevent caching."""
                # Add small deterministic increment: 0.001 * iteration_number
                varied_input = base_input + (iteration_counter[0] * 0.001)
                iteration_counter[0] += 1
                return varied_input

            input_gen = generate_varied_input
            call = lambda x: fn(x, *args[1:], **kwargs)

            # Log base input stats
            if verbose:
                print(f"\nBase Input: {tensor_stats(base_input)}")
                print(f"Input Variation: +{0.001:.3f} * iteration (deterministic)")
        else:
            # Legacy mode - static inputs
            call = lambda: fn(*args, **kwargs)
            if verbose and args and isinstance(args[0], torch.Tensor):
                print(f"\nInput: {tensor_stats(args[0])}")

        result, times_s = _bench_engine(call, warmup=warmup, iters=iters, device=device, dtype=dtype, input_gen=input_gen)

        # Log output if it's a tensor or tuple with tensors
        if verbose:
            print("\nOutput tensors:")
            if isinstance(result, torch.Tensor):
                print(f"  Primary: {tensor_stats(result)}")
            elif isinstance(result, tuple) and len(result) > 0 and isinstance(result[0], torch.Tensor):
                print(f"  Primary: {tensor_stats(result[0])}")
                if len(result) > 1:
                    if isinstance(result[1], torch.Tensor):
                        print(f"  Auxiliary: {tensor_stats(result[1])}")
                    else:
                        print(f"  Auxiliary: {type(result[1]).__name__}")

        # Compute and display statistics
        stats = _compute_stats(times_s, tokens=tokens)
        if verbose:
            print(_format_timing_stats(stats, tokens))

        # Save to JSON if requested
        if save_json:
            import json
            json_data = {
                "implementation": save_json.replace(".json", ""),
                "config": {
                    "warmup": warmup,
                    "iters": iters,
                    "device": str(device),  # Convert device to string
                    "dtype": str(dtype),
                    "tokens": tokens,
                    "vary_inputs": vary_inputs
                },
                "stats": stats,
                "output_sum": float(result[0].sum().item()) if isinstance(result, tuple) and len(result) > 0 else float(result.sum().item()) if isinstance(result, torch.Tensor) else None
            }
            with open(save_json, 'w') as f:
                json.dump(json_data, f, indent=2)
            if verbose:
                print(f"\nSaved benchmark results to {save_json}")

        return result, stats

    yield runner

def set_seed(seed: int):
    """Set seeds for reproducibility (PyTorch and NumPy)."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
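
# ---------------------------------------------------------------------------
# Usage sketch. This block is illustrative only: the matmul workload, shapes,
# seed, and iteration counts below are assumptions chosen for the example, not
# values required by these utilities.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    set_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = to_dtype("float16" if device == "cuda" else "float32")

    # Hypothetical workload: one matmul against a fixed weight matrix.
    weight = torch.randn(1024, 1024, device=device)
    x = torch.randn(8, 1024, device=device)

    def matmul_workload(inp: torch.Tensor) -> torch.Tensor:
        return inp @ weight

    # With vary_inputs=True (the default), the runner perturbs `x` each
    # iteration; tokens=8 makes the stats include a tokens/sec estimate.
    with bench_context(warmup=5, iters=20, device=device, dtype=dtype, tokens=8) as run:
        result, stats = run(matmul_workload, x)
        print(f"\nExample finished: avg {stats['avg_ms']:.3f} ms over {stats['num_iters']} iterations")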