# /// script
# dependencies = [
#     "matplotlib",
# ]
# ///
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt

# Resolve benchmark output directories from the uvnote input environment variables,
# falling back to the current directory when a variable is unset.
yamoe_dir = os.environ.get('UVNOTE_INPUT_YAMOE_RUN', '.')
binned_dir = os.environ.get('UVNOTE_INPUT_BINNED_RUN', '.')
gptoss_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_RUN', '.')
gptoss_training_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_TRAINING_RUN', '.')
megablocks_dir = os.environ.get('UVNOTE_INPUT_MEGABLOCKS_RUN', '.')

# List of expected result files
result_files = [
    Path(yamoe_dir) / "yamoe_results.json",
    Path(binned_dir) / "binned_results.json",
    Path(gptoss_dir) / "gptoss_results.json",
    Path(gptoss_training_dir) / "gptoss_training_results.json",
    Path(megablocks_dir) / "megablocks_results.json",
]

# Load all benchmark results
results = {}
for file in result_files:
    if Path(file).exists():
        with open(file, 'r') as f:
            data = json.load(f)
            results[data['implementation']] = data
        print(f"Loaded {file}")
    else:
        print(f"Missing {file}")

if not results:
    print("No benchmark results found. Run the benchmark cells first.")
else:
    # Extract data for plotting
    implementations = list(results.keys())
    avg_latencies = [results[impl]['stats']['avg_ms'] for impl in implementations]
    p95_latencies = [results[impl]['stats']['p95_ms'] for impl in implementations]
    throughputs = [results[impl]['stats'].get('tokens_per_s', 0) for impl in implementations]

    # Create figure with subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle('MoE Implementation Performance Comparison', fontsize=16, fontweight='bold')

    # Colors for each implementation
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57'][:len(implementations)]

    # 1. Average Latency Chart
    bars1 = ax1.bar(implementations, avg_latencies, color=colors, alpha=0.8,
                    edgecolor='black', linewidth=1)
    ax1.set_title('Average Latency', fontweight='bold', fontsize=14)
    ax1.set_ylabel('Latency (ms)', fontweight='bold')
    ax1.tick_params(axis='x', rotation=45)
    ax1.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for bar, val in zip(bars1, avg_latencies):
        ax1.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + max(avg_latencies) * 0.01,
                 f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')

    # 2. P95 Latency Chart
    bars2 = ax2.bar(implementations, p95_latencies, color=colors, alpha=0.8,
                    edgecolor='black', linewidth=1)
    ax2.set_title('95th Percentile Latency', fontweight='bold', fontsize=14)
    ax2.set_ylabel('Latency (ms)', fontweight='bold')
    ax2.tick_params(axis='x', rotation=45)
    ax2.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for bar, val in zip(bars2, p95_latencies):
        ax2.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + max(p95_latencies) * 0.01,
                 f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')

    # 3. Throughput Chart
    bars3 = ax3.bar(implementations, throughputs, color=colors, alpha=0.8,
                    edgecolor='black', linewidth=1)
    ax3.set_title('Throughput', fontweight='bold', fontsize=14)
    ax3.set_ylabel('Tokens/sec', fontweight='bold')
    ax3.tick_params(axis='x', rotation=45)
    ax3.grid(axis='y', alpha=0.3)

    # Add value labels on bars (only if throughput was calculated)
    for bar, val in zip(bars3, throughputs):
        if val > 0:
            ax3.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + max(throughputs) * 0.01,
                     f'{val:.0f}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.savefig("moe_performance_comparison.png", dpi=300)

    # Print summary table
    print("\nPerformance Summary:")
    print(f"{'Implementation':<30} {'Avg (ms)':<12} {'P95 (ms)':<12} {'Tokens/sec':<12} {'Relative Speed':<15}")
    print("-" * 85)

    # Sort by average latency; relative speed is expressed against the fastest implementation
    sorted_results = sorted(results.items(), key=lambda x: x[1]['stats']['avg_ms'])
    fastest_latency = sorted_results[0][1]['stats']['avg_ms']

    for impl, data in sorted_results:
        avg_ms = data['stats']['avg_ms']
        p95_ms = data['stats']['p95_ms']
        tokens_s = data['stats'].get('tokens_per_s', 0)
        relative_speed = fastest_latency / avg_ms
        print(f"{impl:<30} {avg_ms:<12.2f} {p95_ms:<12.2f} {tokens_s:<12.0f} {relative_speed:.2f}x")

    print(f"\nFastest: {sorted_results[0][0]} ({sorted_results[0][1]['stats']['avg_ms']:.2f}ms avg)")
    if len(sorted_results) > 1:
        print(f"Slowest: {sorted_results[-1][0]} ({sorted_results[-1][1]['stats']['avg_ms']:.2f}ms avg)")
        speedup = sorted_results[-1][1]['stats']['avg_ms'] / sorted_results[0][1]['stats']['avg_ms']
        print(f"Max Speedup: {speedup:.1f}x")