|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import csv |
|
|
from pathlib import Path |
|
|
import json |
|
|
|
|
|
|
|
|
# Map each attention implementation's display name to the cache directory
# produced by its uvnote benchmark cell. Each UVNOTE_FILE_* env var is set by
# the notebook runner; a missing/unset var yields None, which the loader below
# treats as "this implementation was not benchmarked in this run".
cache_dirs = {
    "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'),
    "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'),
    "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'),
    "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'),
    "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'),
    "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'),
    "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'),
    "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'),
    "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'),
}
|
|
|
|
|
# JSONL filename emitted by each benchmark inside its cache directory.
# Most benchmarks write "attn.jsonl"; the two compiled-variant benchmarks
# write distinctly named files so their results can coexist in one directory.
file_mapping = {
    "Flash (PyTorch SDPA)": "attn.jsonl",
    "MemEff (PyTorch SDPA)": "attn.jsonl",
    "Flash Attn 2": "attn.jsonl",
    "xFormers": "attn.jsonl",
    "SageAttention": "attn.jsonl",
    "Compiled (default)": "attn_default.jsonl",
    "Compiled (max-autotune)": "attn_max_autotune.jsonl",
    "HF Kernels Flash Attn": "attn.jsonl",
    "HF Kernels Flash Attn3": "attn.jsonl",
}
|
|
|
|
|
|
|
|
# Load every available benchmark's JSONL records into all_data,
# keyed by implementation display name. Implementations whose env var is
# unset, or whose result file is missing/empty, are silently skipped.
all_data = {}
for name, cache_dir in cache_dirs.items():
    if not cache_dir:
        # Env var unset: this implementation was not benchmarked in this run.
        continue
    path = Path(cache_dir) / file_mapping[name]
    if path.exists() and path.stat().st_size > 0:
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines so a trailing newline (or spacer lines) in the
            # JSONL file doesn't raise json.JSONDecodeError.
            records = [json.loads(line) for line in f if line.strip()]
        all_data[name] = records
|
|
|
|
|
|
|
|
# Flatten all loaded benchmark records into a single CSV, one row per
# (implementation, sequence length) measurement. Missing fields become ''.
csv_path = Path("latency.csv")
# newline='' is required by the csv module; utf-8 is pinned explicitly so the
# output does not depend on the platform's locale encoding.
with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)

    header = ["Implementation", "Sequence Length", "Latency (ms)", "Min (ms)", "Max (ms)", "Median (ms)"]
    writer.writerow(header)

    for impl_name, records in all_data.items():
        for record in records:
            # .get with '' default keeps the CSV well-formed even if a record
            # is missing one of the expected keys.
            writer.writerow([
                impl_name,
                record.get('seqlen', ''),
                record.get('latency', ''),
                record.get('min', ''),
                record.get('max', ''),
                record.get('median', ''),
            ])
|
|
|
|
|
# Report what was exported: destination file, implementation count, row count.
total_records = sum(len(recs) for recs in all_data.values())
print(f"✓ CSV export complete: {csv_path}")
print(f"Total implementations: {len(all_data)}")
print(f"Total records: {total_records}")