| """ | |
| Common utilities for nanochat. | |
| """ | |
| import os | |
| import re | |
| import logging | |
| import torch | |
| import torch.distributed as dist | |
| class ColoredFormatter(logging.Formatter): | |
| """Custom formatter that adds colors to log messages.""" | |
| # ANSI color codes | |
| COLORS = { | |
| 'DEBUG': '\033[36m', # Cyan | |
| 'INFO': '\033[32m', # Green | |
| 'WARNING': '\033[33m', # Yellow | |
| 'ERROR': '\033[31m', # Red | |
| 'CRITICAL': '\033[35m', # Magenta | |
| } | |
| RESET = '\033[0m' | |
| BOLD = '\033[1m' | |
| def format(self, record): | |
| # Add color to the level name | |
| levelname = record.levelname | |
| if levelname in self.COLORS: | |
| record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" | |
| # Format the message | |
| message = super().format(record) | |
| # Add color to specific parts of the message | |
| if levelname == 'INFO': | |
| # Highlight numbers and percentages | |
| message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) | |
| message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) | |
| return message | |
| def setup_default_logging(): | |
| handler = logging.StreamHandler() | |
| handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| handlers=[handler] | |
| ) | |
| setup_default_logging() | |
| logger = logging.getLogger(__name__) | |
def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):
        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    else:
        home_dir = os.path.expanduser("~")
        cache_dir = os.path.join(home_dir, ".cache")
        nanochat_dir = os.path.join(cache_dir, "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir


def print0(s="", **kwargs):
    # print only on rank 0, so DDP runs don't repeat every message once per process
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

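# Example (illustrative): intermediates default to ~/.cache/nanochat, and a run
# can redirect them via the environment, e.g. with a hypothetical path/script:
#
#   NANOCHAT_BASE_DIR=/data/nanochat python some_script.py
#
# print0() is the DDP-safe print: only rank 0 writes, so shared banners and
# progress lines are not repeated once per participating process.
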
def print_banner():
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    # (block-character art spelling "nanochat")
    banner = "nanochat"
    print0(banner)

def is_ddp():
    # TODO: is there a more principled way to detect a torchrun/DDP launch?
    return int(os.environ.get('RANK', -1)) != -1


def get_dist_info():
    """Return (is_ddp, ddp_rank, ddp_local_rank, ddp_world_size)."""
    if is_ddp():
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        return False, 0, 0, 1

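# Note: launchers like torchrun export RANK, LOCAL_RANK and WORLD_SIZE into each
# spawned process, which is what is_ddp() / get_dist_info() key off of. For
# example, a hypothetical single-node launch of 8 processes:
#
#   torchrun --standalone --nproc_per_node=8 train_script.py
#
# yields get_dist_info() == (True, r, r, 8) with r in 0..7 inside each process,
# while a plain `python train_script.py` run yields (False, 0, 0, 1).
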
def compute_init():
    """Basic initialization that we keep doing over and over, so make it common here."""
    # CUDA is currently required
    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
    # Reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # Precision
    torch.set_float32_matmul_precision("high")  # uses tf32 instead of fp32 for matmuls
    # Distributed setup: Distributed Data Parallel (DDP), optional
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device)  # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device("cuda")
    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")
    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device


def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit."""
    if is_ddp():
        dist.destroy_process_group()

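# Sketch of how a training script might bracket its work with compute_init /
# compute_cleanup (illustrative only; build_model, num_steps and the loop body
# are hypothetical placeholders):
#
#   ddp, rank, local_rank, world_size, device = compute_init()
#   model = build_model().to(device)
#   if ddp:
#       model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
#   for step in range(num_steps):
#       ...  # forward / backward / optimizer step, printing progress via print0()
#   compute_cleanup()
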
class DummyWandb:
    """Useful if we wish to not use wandb but keep the same call signatures."""
    def __init__(self):
        pass

    def log(self, *args, **kwargs):
        pass

    def finish(self):
        pass
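
# A minimal, self-contained smoke test of the utilities above. This is an
# illustrative sketch only; an actual training script would typically create a
# real wandb run on rank 0 (via wandb.init) and fall back to DummyWandb elsewhere.
if __name__ == "__main__":
    print_banner()
    logger.info(f"Base directory: {get_base_dir()}")
    wandb_run = DummyWandb()  # stands in for a real wandb run when logging is disabled
    wandb_run.log({"loss": 0.0}, step=0)  # no-op, but mirrors wandb's log(data, step=...) call
    wandb_run.finish()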