# /// script
# dependencies = [
# "matplotlib",
# ]
# ///
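
# The block above is PEP 723 inline script metadata; uv (which uvnote uses to
# execute cells) reads it to install matplotlib before running this script.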

import json
import os
from pathlib import Path

import matplotlib.pyplot as plt

# Input directories for each benchmark run; the UVNOTE_INPUT_* variables are
# expected to be set by the uvnote runner and fall back to the current directory.
yamoe_dir = os.environ.get('UVNOTE_INPUT_YAMOE_RUN', '.')
binned_dir = os.environ.get('UVNOTE_INPUT_BINNED_RUN', '.')
gptoss_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_RUN', '.')
gptoss_training_dir = os.environ.get('UVNOTE_INPUT_GPTOSS_TRAINING_RUN', '.')
megablocks_dir = os.environ.get('UVNOTE_INPUT_MEGABLOCKS_RUN', '.')

# List of expected result files, one per implementation
result_files = [
    Path(yamoe_dir) / "yamoe_results.json",
    Path(binned_dir) / "binned_results.json",
    Path(gptoss_dir) / "gptoss_results.json",
    Path(gptoss_training_dir) / "gptoss_training_results.json",
    Path(megablocks_dir) / "megablocks_results.json",
]

# Load all benchmark results
results = {}
for file in result_files:
    if file.exists():
        with open(file, 'r') as f:
            data = json.load(f)
        results[data['implementation']] = data
        print(f"Loaded {file}")
    else:
        print(f"Missing {file}")

if not results:
    print("No benchmark results found. Run the benchmark cells first.")
else:
    # Extract data for plotting
    implementations = list(results.keys())
    avg_latencies = [results[impl]['stats']['avg_ms'] for impl in implementations]
    p95_latencies = [results[impl]['stats']['p95_ms'] for impl in implementations]
    throughputs = [results[impl]['stats'].get('tokens_per_s', 0) for impl in implementations]
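    # Implementations that did not report 'tokens_per_s' default to 0 here and
    # are skipped when the throughput bars are labeled below.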

    # Create figure with subplots
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle('MoE Implementation Performance Comparison', fontsize=16, fontweight='bold')

    # Colors for each implementation
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FECA57'][:len(implementations)]
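    # The palette is sliced to the number of implementations actually loaded,
    # so the charts still render when some result files are missing.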

    # 1. Average Latency Chart
    bars1 = ax1.bar(implementations, avg_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax1.set_title('Average Latency', fontweight='bold', fontsize=14)
    ax1.set_ylabel('Latency (ms)', fontweight='bold')
    ax1.tick_params(axis='x', rotation=45)
    ax1.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for bar, val in zip(bars1, avg_latencies):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(avg_latencies)*0.01,
                 f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')

    # 2. P95 Latency Chart
    bars2 = ax2.bar(implementations, p95_latencies, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax2.set_title('95th Percentile Latency', fontweight='bold', fontsize=14)
    ax2.set_ylabel('Latency (ms)', fontweight='bold')
    ax2.tick_params(axis='x', rotation=45)
    ax2.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for bar, val in zip(bars2, p95_latencies):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(p95_latencies)*0.01,
                 f'{val:.2f}ms', ha='center', va='bottom', fontweight='bold')

    # 3. Throughput Chart
    bars3 = ax3.bar(implementations, throughputs, color=colors, alpha=0.8, edgecolor='black', linewidth=1)
    ax3.set_title('Throughput', fontweight='bold', fontsize=14)
    ax3.set_ylabel('Tokens/sec', fontweight='bold')
    ax3.tick_params(axis='x', rotation=45)
    ax3.grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for bar, val in zip(bars3, throughputs):
        if val > 0:  # Only show label if throughput was calculated
            ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(throughputs)*0.01,
                     f'{val:.0f}', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.savefig("moe_performance_comparison.png", dpi=300)
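    # The PNG is written to the cell's working directory; dpi=300 keeps the
    # bar labels legible if the image is embedded in a rendered page.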

    # Print summary table
    print("\nPerformance Summary:")
    print(f"{'Implementation':<30} {'Avg (ms)':<12} {'P95 (ms)':<12} {'Tokens/sec':<12} {'Relative Speed':<15}")
    print("-" * 85)

    # Sort by average latency; "Relative Speed" is each implementation's speed
    # relative to the fastest one (the fastest row reads 1.00x).
    sorted_results = sorted(results.items(), key=lambda x: x[1]['stats']['avg_ms'])
    fastest_latency = sorted_results[0][1]['stats']['avg_ms']
    for impl, data in sorted_results:
        avg_ms = data['stats']['avg_ms']
        p95_ms = data['stats']['p95_ms']
        tokens_s = data['stats'].get('tokens_per_s', 0)
        relative_speed = fastest_latency / avg_ms
        print(f"{impl:<30} {avg_ms:<12.2f} {p95_ms:<12.2f} {tokens_s:<12.0f} {relative_speed:.2f}x")
print(f"\nFastest: {sorted_results[0][0]} ({sorted_results[0][1]['stats']['avg_ms']:.2f}ms avg)")
if len(sorted_results) > 1:
print(f"Slowest: {sorted_results[-1][0]} ({sorted_results[-1][1]['stats']['avg_ms']:.2f}ms avg)")
speedup = sorted_results[-1][1]['stats']['avg_ms'] / sorted_results[0][1]['stats']['avg_ms']
print(f"Max Speedup: {speedup:.1f}x")