| """ | |
| Utilities for generating training report cards. More messy code than usual, will fix. | |
| """ | |
| import os | |
| import re | |
| import shutil | |
| import subprocess | |
| import socket | |
| import datetime | |
| import platform | |
| import psutil | |
| import torch | |
| def run_command(cmd): | |
| """Run a shell command and return output, or None if it fails.""" | |
| try: | |
| result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=5) | |
| if result.returncode == 0: | |
| return result.stdout.strip() | |
| return None | |
| except: | |
| return None | |
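
# Example (illustrative): run_command("git rev-parse --short HEAD") returns a
# short hash like "a1b2c3d" on success, and None if the command fails, raises,
# or exceeds the 5-second timeout.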

def get_git_info():
    """Get current git commit, branch, and dirty status."""
    info = {}
    info['commit'] = run_command("git rev-parse --short HEAD") or "unknown"
    info['branch'] = run_command("git rev-parse --abbrev-ref HEAD") or "unknown"
    # Check if repo is dirty (has uncommitted changes)
    status = run_command("git status --porcelain")
    info['dirty'] = bool(status) if status is not None else False
    # Get commit message
    info['message'] = run_command("git log -1 --pretty=%B") or ""
    info['message'] = info['message'].split('\n')[0][:80]  # First line, truncated
    return info
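
# Illustrative return value:
# {'commit': 'a1b2c3d', 'branch': 'master', 'dirty': False, 'message': 'fix tokenizer'}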

def get_gpu_info():
    """Get GPU information."""
    if not torch.cuda.is_available():
        return {"available": False}
    num_devices = torch.cuda.device_count()
    info = {
        "available": True,
        "count": num_devices,
        "names": [],
        "memory_gb": []
    }
    for i in range(num_devices):
        props = torch.cuda.get_device_properties(i)
        info["names"].append(props.name)
        info["memory_gb"].append(props.total_memory / (1024**3))
    # Get CUDA version
    info["cuda_version"] = torch.version.cuda or "unknown"
    return info
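
# Illustrative return value on an 8-GPU node:
# {'available': True, 'count': 8, 'names': ['NVIDIA H100 80GB HBM3', ...],
#  'memory_gb': [79.6, ...], 'cuda_version': '12.4'}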

def get_system_info():
    """Get system information."""
    info = {}
    # Basic system info
    info['hostname'] = socket.gethostname()
    info['platform'] = platform.system()
    info['python_version'] = platform.python_version()
    info['torch_version'] = torch.__version__
    # CPU and memory
    info['cpu_count'] = psutil.cpu_count(logical=False)
    info['cpu_count_logical'] = psutil.cpu_count(logical=True)
    info['memory_gb'] = psutil.virtual_memory().total / (1024**3)
    # User and environment
    info['user'] = os.environ.get('USER', 'unknown')
    info['nanochat_base_dir'] = os.environ.get('NANOCHAT_BASE_DIR', 'out')
    info['working_dir'] = os.getcwd()
    return info

def estimate_cost(gpu_info, runtime_hours=None):
    """Estimate training cost based on GPU type and runtime."""
    # Rough per-GPU hourly pricing, from Lambda Cloud
    default_rate = 2.0
    gpu_hourly_rates = {
        "H100": 3.00,
        "A100": 1.79,
        "V100": 0.55,
    }
    if not gpu_info.get("available"):
        return None
    # Try to identify GPU type from name
    hourly_rate = None
    gpu_name = gpu_info["names"][0] if gpu_info["names"] else "unknown"
    for gpu_type, rate in gpu_hourly_rates.items():
        if gpu_type in gpu_name:
            hourly_rate = rate * gpu_info["count"]
            break
    if hourly_rate is None:
        hourly_rate = default_rate * gpu_info["count"]  # Default estimate
    return {
        "hourly_rate": hourly_rate,
        "gpu_type": gpu_name,
        "estimated_total": hourly_rate * runtime_hours if runtime_hours else None
    }
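
# Worked example (rates above are rough estimates): 8x H100 at $3.00/GPU-hour
# gives hourly_rate = 24.0, so estimate_cost(gpu_info, runtime_hours=4) reports
# estimated_total = 96.0 (dollars).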

def generate_header():
    """Generate the header for a training report."""
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    git_info = get_git_info()
    gpu_info = get_gpu_info()
    sys_info = get_system_info()
    cost_info = estimate_cost(gpu_info)
    header = f"""# nanochat training report

Generated: {timestamp}

## Environment

### Git Information
- Branch: {git_info['branch']}
- Commit: {git_info['commit']} {"(dirty)" if git_info['dirty'] else "(clean)"}
- Message: {git_info['message']}

### Hardware
- Platform: {sys_info['platform']}
- CPUs: {sys_info['cpu_count']} cores ({sys_info['cpu_count_logical']} logical)
- Memory: {sys_info['memory_gb']:.1f} GB
"""
    if gpu_info.get("available"):
        gpu_names = ", ".join(set(gpu_info["names"]))
        total_vram = sum(gpu_info["memory_gb"])
        header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
- GPU Memory: {total_vram:.1f} GB total
- CUDA Version: {gpu_info['cuda_version']}
"""
    else:
        header += "- GPUs: None available\n"
    if cost_info and cost_info["hourly_rate"] > 0:
        header += f"- Hourly Rate: ${cost_info['hourly_rate']:.2f}/hour\n"
    header += f"""
### Software
- Python: {sys_info['python_version']}
- PyTorch: {sys_info['torch_version']}
"""

    # bloat metrics: package all of the source code and assess its weight
    packaged = run_command('files-to-prompt . -e py -e md -e rs -e html -e toml -e sh --ignore "*target*" --cxml')
    packaged = packaged or ""  # run_command returns None if files-to-prompt is unavailable
    num_chars = len(packaged)
    num_lines = len(packaged.split('\n'))
    num_files = len([x for x in packaged.split('\n') if x.startswith('<source>')])
    num_tokens = num_chars // 4  # assume approximately 4 chars per token
    # count dependencies via uv.lock
    uv_lock_lines = 0
    if os.path.exists('uv.lock'):
        with open('uv.lock', 'r') as f:
            uv_lock_lines = len(f.readlines())
    header += f"""
### Bloat
- Characters: {num_chars:,}
- Lines: {num_lines:,}
- Files: {num_files:,}
- Tokens (approx): {num_tokens:,}
- Dependencies (uv.lock lines): {uv_lock_lines:,}

"""
    return header

# -----------------------------------------------------------------------------

def slugify(text):
    """Slugify a text string."""
    return text.lower().replace(" ", "-")
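
# Illustrative: slugify("Base model training") -> "base-model-training", which is
# how Report.log maps section names onto the file names in EXPECTED_FILES below.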

# the expected files and their order
EXPECTED_FILES = [
    "tokenizer-training.md",
    "tokenizer-evaluation.md",
    "base-model-training.md",
    "base-model-loss.md",
    "base-model-evaluation.md",
    "midtraining.md",
    "chat-evaluation-mid.md",
    "chat-sft.md",
    "chat-evaluation-sft.md",
    "chat-rl.md",
    "chat-evaluation-rl.md",
]

# the metrics we're currently interested in
chat_metrics = ["ARC-Easy", "ARC-Challenge", "MMLU", "GSM8K", "HumanEval", "ChatCORE"]

def extract(section, keys):
    """Extract the value for each of the given keys from a section of text."""
    if not isinstance(keys, list):
        keys = [keys]  # convenience: allow a single key
    out = {}
    for line in section.split("\n"):
        for key in keys:
            if key in line:
                out[key] = line.split(":", 1)[1].strip()
    return out
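
# Illustrative: extract("- GSM8K: 0.0455\n- MMLU: 0.3111", ["GSM8K"]) returns
# {'GSM8K': '0.0455'} (values are kept as strings, exactly as they appear).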

def extract_timestamp(content, prefix):
    """Extract a timestamp from the first content line starting with the given prefix."""
    for line in content.split('\n'):
        if line.startswith(prefix):
            time_str = line.split(":", 1)[1].strip()
            try:
                return datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                pass
    return None

class Report:
    """Maintains a bunch of logs, generates a final markdown report."""

    def __init__(self, report_dir):
        os.makedirs(report_dir, exist_ok=True)
        self.report_dir = report_dir

    def log(self, section, data):
        """Log a section of data to the report."""
        slug = slugify(section)
        file_name = f"{slug}.md"
        file_path = os.path.join(self.report_dir, file_name)
        with open(file_path, "w") as f:
            f.write(f"## {section}\n")
            f.write(f"timestamp: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            for item in data:
                if not item:
                    # skip falsy values like None or empty dict etc.
                    continue
                if isinstance(item, str):
                    # directly write the string
                    f.write(item)
                else:
                    # render a dict
                    for k, v in item.items():
                        if isinstance(v, float):
                            vstr = f"{v:.4f}"
                        elif isinstance(v, int) and v >= 10000:
                            vstr = f"{v:,.0f}"
                        else:
                            vstr = str(v)
                        f.write(f"- {k}: {vstr}\n")
                f.write("\n")
        return file_path
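
    # Illustrative: log("Base model training", [{"loss": 1.2345}]) writes
    # {report_dir}/base-model-training.md containing a "## Base model training"
    # heading, a timestamp line, and a "- loss: 1.2345" bullet.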

    def generate(self):
        """Generate the final report."""
        report_dir = self.report_dir
        report_file = os.path.join(report_dir, "report.md")
        print(f"Generating report to {report_file}")
        final_metrics = {}  # the most important final metrics we'll add as a table at the end
        start_time = None
        end_time = None
        bloat_data = ""  # bloat metrics copied over from the header, if present
        with open(report_file, "w") as out_file:
            # write the header first
            header_file = os.path.join(report_dir, "header.md")
            if os.path.exists(header_file):
                with open(header_file, "r") as f:
                    header_content = f.read()
                out_file.write(header_content)
                start_time = extract_timestamp(header_content, "Run started:")
                # capture bloat data for the summary later (the stuff after the Bloat header and until \n\n)
                bloat_match = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
                if bloat_match:
                    bloat_data = bloat_match.group(1)
            # process all the individual sections
            for file_name in EXPECTED_FILES:
                section_file = os.path.join(report_dir, file_name)
                if not os.path.exists(section_file):
                    print(f"Warning: {section_file} does not exist, skipping")
                    continue
                with open(section_file, "r") as in_file:
                    section = in_file.read()
                # Extract timestamp from this section (the last section's timestamp will "stick" as end_time)
                if "rl" not in file_name:
                    # Skip RL sections for end_time calculation because RL is experimental
                    end_time = extract_timestamp(section, "timestamp:")
                # extract the most important metrics from the sections
                if file_name == "base-model-evaluation.md":
                    final_metrics["base"] = extract(section, "CORE")
                if file_name == "chat-evaluation-mid.md":
                    final_metrics["mid"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-sft.md":
                    final_metrics["sft"] = extract(section, chat_metrics)
                if file_name == "chat-evaluation-rl.md":
                    final_metrics["rl"] = extract(section, "GSM8K")  # RL only evals GSM8K
                # append this section of the report
                out_file.write(section)
                out_file.write("\n")
            # add the final metrics table
            out_file.write("## Summary\n\n")
            # Copy over the bloat metrics from the header
            out_file.write(bloat_data)
            out_file.write("\n\n")
            # Collect all unique metric names
            all_metrics = set()
            for stage_metrics in final_metrics.values():
                all_metrics.update(stage_metrics.keys())
            # Custom ordering: CORE first, ChatCORE last, rest in middle
            all_metrics = sorted(all_metrics, key=lambda x: (x != "CORE", x == "ChatCORE", x))
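            # (tuple sort keys: False < True, so CORE is pinned first, ChatCORE last,
            # and everything else falls in between alphabetically)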
            # Fixed column widths
            stages = ["base", "mid", "sft", "rl"]
            metric_width = 15
            value_width = 8
            # Write table header
            header = f"| {'Metric'.ljust(metric_width)} |"
            for stage in stages:
                header += f" {stage.upper().ljust(value_width)} |"
            out_file.write(header + "\n")
            # Write separator
            separator = f"|{'-' * (metric_width + 2)}|"
            for stage in stages:
                separator += f"{'-' * (value_width + 2)}|"
            out_file.write(separator + "\n")
            # Write table rows
            for metric in all_metrics:
                row = f"| {metric.ljust(metric_width)} |"
                for stage in stages:
                    value = final_metrics.get(stage, {}).get(metric, "-")
                    row += f" {str(value).ljust(value_width)} |"
                out_file.write(row + "\n")
            out_file.write("\n")
            # Calculate and write total wall clock time
            if start_time and end_time:
                duration = end_time - start_time
                total_seconds = int(duration.total_seconds())
                hours = total_seconds // 3600
                minutes = (total_seconds % 3600) // 60
                out_file.write(f"Total wall clock time: {hours}h{minutes}m\n")
            else:
                out_file.write("Total wall clock time: unknown\n")
        # also cp the report.md file to current directory
        print("Copying report.md to current directory for convenience")
        shutil.copy(report_file, "report.md")
        return report_file

    def reset(self):
        """Reset the report."""
        # Remove section files
        for file_name in EXPECTED_FILES:
            file_path = os.path.join(self.report_dir, file_name)
            if os.path.exists(file_path):
                os.remove(file_path)
        # Remove report.md if it exists
        report_file = os.path.join(self.report_dir, "report.md")
        if os.path.exists(report_file):
            os.remove(report_file)
        # Generate and write the header section with start timestamp
        header_file = os.path.join(self.report_dir, "header.md")
        header = generate_header()
        start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        with open(header_file, "w") as f:
            f.write(header)
            f.write(f"Run started: {start_time}\n\n---\n\n")
        print(f"Reset report and wrote header to {header_file}")

# -----------------------------------------------------------------------------
# nanochat-specific convenience functions

class DummyReport:
    """No-op stand-in for non-zero ranks so call sites don't need rank checks."""
    def log(self, *args, **kwargs):
        pass
    def reset(self, *args, **kwargs):
        pass

def get_report():
    # just for convenience: only rank 0 logs to the report
    from nanochat.common import get_base_dir, get_dist_info
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp_rank == 0:
        report_dir = os.path.join(get_base_dir(), "report")
        return Report(report_dir)
    else:
        return DummyReport()

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Generate or reset nanochat training reports.")
    parser.add_argument("command", nargs="?", default="generate", choices=["generate", "reset"], help="Operation to perform (default: generate)")
    args = parser.parse_args()
    if args.command == "generate":
        get_report().generate()
    elif args.command == "reset":
        get_report().reset()
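
# Example usage (assuming this module lives at nanochat/report.py, as the import
# in get_report suggests):
#   python -m nanochat.report reset     # start a fresh report, writing header.md
#   python -m nanochat.report generate  # stitch the section files into report.md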