# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels-benchmark-tools",
# ]
#
# [tool.uv.sources]
# kernels-benchmark-tools = { git = "https://github.com/drbh/kernels-benchmark-tools.git", branch = "main" }
# ///
import os
import csv
from pathlib import Path
import json

# --- Locate benchmark artifacts --------------------------------------------------
cache_dirs = {
    "Flash (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_FLASH_ATTENTION_BENCHMARK'),
    "MemEff (PyTorch SDPA)": os.environ.get('UVNOTE_FILE_MEM_EFFICIENT_ATTENTION_BENCHMARK'),
    "Flash Attn 2": os.environ.get('UVNOTE_FILE_FLASH_ATTN2_BENCHMARK'),
    "xFormers": os.environ.get('UVNOTE_FILE_XFORMERS_BENCHMARK'),
    "SageAttention": os.environ.get('UVNOTE_FILE_SAGE_ATTENTION_BENCHMARK'),
    "Compiled (default)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_DEFAULT'),
    "Compiled (max-autotune)": os.environ.get('UVNOTE_FILE_COMPILED_VARIANTS_BENCHMARK_MAX_AUTOTUNE'),
    "HF Kernels Flash Attn": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN_BENCHMARK'),
    "HF Kernels Flash Attn3": os.environ.get('UVNOTE_FILE_HF_KERNELS_FLASH_ATTN3_BENCHMARK'),
}

# Each benchmark cell writes its results under a different JSONL filename.
file_mapping = {
    "Flash (PyTorch SDPA)": "attn.jsonl",
    "MemEff (PyTorch SDPA)": "attn.jsonl",
    "Flash Attn 2": "attn.jsonl",
    "xFormers": "attn.jsonl",
    "SageAttention": "attn.jsonl",
    "Compiled (default)": "attn_default.jsonl",
    "Compiled (max-autotune)": "attn_max_autotune.jsonl",
    "HF Kernels Flash Attn": "attn.jsonl",
    "HF Kernels Flash Attn3": "attn.jsonl",
}

# Collect all benchmark data
all_data = {}
for name, cache_dir in cache_dirs.items():
    if cache_dir:
        path = Path(cache_dir) / file_mapping[name]
        if path.exists() and path.stat().st_size > 0:
            with open(path, 'r') as f:
                # Skip blank lines so a trailing newline doesn't break json.loads
                records = [json.loads(line) for line in f if line.strip()]
            all_data[name] = records

# Export to CSV
csv_path = Path("latency.csv")
with open(csv_path, 'w', newline='') as csvfile:
    writer = csv.writer(csvfile)

    # Write header
    header = ["Implementation", "Sequence Length", "Latency (ms)", "Min (ms)", "Max (ms)", "Median (ms)"]
    writer.writerow(header)

    # Write data rows
    for impl_name, records in all_data.items():
        for record in records:
            row = [
                impl_name,
                record.get('seqlen', ''),
                record.get('latency', ''),
                record.get('min', ''),
                record.get('max', ''),
                record.get('median', ''),
            ]
            writer.writerow(row)

print(f"✓ CSV export complete: {csv_path}")
print(f"Total implementations: {len(all_data)}")
print(f"Total records: {sum(len(records) for records in all_data.values())}")
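
# Optional sanity check (illustrative sketch, not required for the export): read
# latency.csv back with the stdlib csv module and report how many rows each
# implementation contributed, using the header columns written above.
from collections import Counter

with open(csv_path, newline='') as csvfile:
    reader = csv.DictReader(csvfile)
    rows_per_impl = Counter(row["Implementation"] for row in reader)

for impl, count in sorted(rows_per_impl.items()):
    print(f"  {impl}: {count} rows")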