# /// script
# dependencies = [
#     "matplotlib",
# ]
# ///

import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os

# Input directories for each benchmark run (provided by uvnote via environment
# variables; fall back to the current directory)
yamoe_dir = os.environ.get('UVNOTE_INPUT_YAMOE_RUN', '.')
binned_dir = os.environ.get('UVNOTE_INPUT_BINNED_RUN', '.')
gptoss_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_RUN', '.')
gptoss_training_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_TRAINING_RUN', '.')
megablocks_dir = os.environ.get('UVNOTE_INPUT_MEGABLOCKS_RUN', '.')

result_files = [
    Path(yamoe_dir) / "yamoe_results.json",
    Path(binned_dir) / "binned_results.json", 
    Path(gptoss_dir) / "gptoss_results.json",
    Path(gptoss_training_dir) / "gptoss_training_results.json",
    Path(megablocks_dir) / "megablocks_results.json"
]
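
# For reference: each results JSON is assumed to look roughly like the sketch
# below (only the fields this script actually reads are shown; the benchmark
# cells may write additional keys, and the values here are purely illustrative):
#
#     {
#         "implementation": "Yamoe",
#         "stats": {
#             "avg_ms": 12.34,          # mean latency per run
#             "p95_ms": 15.67,          # 95th-percentile latency
#             "tokens_per_s": 81000.0   # optional; treated as 0 if absent
#         }
#     }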

# Load all benchmark results
results = {}
for file in result_files:
    if file.exists():
        with open(file, 'r') as f:
            data = json.load(f)
            results[data['implementation']] = data
        print(f"Loaded {file}")
    else:
        print(f"Missing {file}")

if not results:
    print("No benchmark results found. Run the benchmark cells first.")
else:
    # Extract data for plotting
    implementations = list(results.keys())
    avg_latencies = [results[impl]['stats']['avg_ms'] for impl in implementations]
    p95_latencies = [results[impl]['stats']['p95_ms'] for impl in implementations]
    throughputs = [results[impl]['stats'].get('tokens_per_s', 0) for impl in implementations]
    
    # Create figure with subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle('MoE Implementation Performance Comparison', fontsize=16, fontweight='bold')
    
    # Colors for each implementation
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57'][:len(implementations)]
    
    # 1. Average Latency Chart
    bars1 = ax1.bar(implementations, avg_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax1.set_title('Average Latency', fontweight='bold', fontsize=14)
    ax1.set_ylabel('Latency (ms)', fontweight='bold')
    ax1.tick_params(axis='x', rotation=45)
    ax1.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bar, val in zip(bars1, avg_latencies):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(avg_latencies)*0.01,
                f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')
    
    # 2. P95 Latency Chart
    bars2 = ax2.bar(implementations, p95_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax2.set_title('95th Percentile Latency', fontweight='bold', fontsize=14)
    ax2.set_ylabel('Latency (ms)', fontweight='bold')
    ax2.tick_params(axis='x', rotation=45)
    ax2.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bar, val in zip(bars2, p95_latencies):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(p95_latencies)*0.01,
                f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')
    
    # 3. Throughput Chart
    bars3 = ax3.bar(implementations, throughputs, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax3.set_title('Throughput', fontweight='bold', fontsize=14)
    ax3.set_ylabel('Tokens/sec', fontweight='bold')
    ax3.tick_params(axis='x', rotation=45)
    ax3.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for bar, val in zip(bars3, throughputs):
        if val > 0:  # Only show label if throughput was calculated
            ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(throughputs)*0.01,
                    f'{val:.0f}', ha='center', va='bottom', fontweight='bold')
    
    plt.tight_layout()
    plt.savefig("moe_performance_comparison.png", dpi=300)
    
    # Print summary table
    print("\nPerformance Summary:")
    print(f"{'Implementation':<30} {'Avg (ms)':<12} {'P95 (ms)':<12} {'Tokens/sec':<12} {'Relative Speed':<15}")
    print("-"*80)
    
    # Sort by average latency for relative speed calculation
    sorted_results = sorted(results.items(), key=lambda x: x[1]['stats']['avg_ms'])
    fastest_latency = sorted_results[0][1]['stats']['avg_ms']
    
    for impl, data in sorted_results:
        avg_ms = data['stats']['avg_ms']
        p95_ms = data['stats']['p95_ms']
        tokens_s = data['stats'].get('tokens_per_s', 0)
        relative_speed = fastest_latency / avg_ms
        
        print(f"{impl:<30} {avg_ms:>8.2f}    {p95_ms:>8.2f}    {tokens_s:>8.0f}      {relative_speed:>6.2f}x")
    
    print(f"\nFastest: {sorted_results[0][0]} ({sorted_results[0][1]['stats']['avg_ms']:.2f}ms avg)")
    if len(sorted_results) > 1:
        print(f"Slowest: {sorted_results[-1][0]} ({sorted_results[-1][1]['stats']['avg_ms']:.2f}ms avg)")
        speedup = sorted_results[-1][1]['stats']['avg_ms'] / sorted_results[0][1]['stats']['avg_ms']
        print(f"Max Speedup: {speedup:.1f}x")