"""
Common utilities for nanochat.
"""
import os
import re
import logging
import torch
import torch.distributed as dist
class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages."""
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',    # Cyan
        'INFO': '\033[32m',     # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',    # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'
    def format(self, record):
        # Add color to the level name
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        # Format the message
        message = super().format(record)
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
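# Note (illustrative, not taken from this file): with the format string configured below,
# an INFO record such as logger.info("Shard 3: wrote 1.2 GB") comes out with "Shard 3" in
# green bold and "1.2 GB" in bold, on top of the green INFO level name.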
def setup_default_logging():
    handler = logging.StreamHandler()
    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.basicConfig(
        level=logging.INFO,
        handlers=[handler]
    )
setup_default_logging()
logger = logging.getLogger(__name__)
def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):
        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    else:
        home_dir = os.path.expanduser("~")
        cache_dir = os.path.join(home_dir, ".cache")
        nanochat_dir = os.path.join(cache_dir, "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir
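# Example (sketch): by default this resolves to ~/.cache/nanochat; exporting
# NANOCHAT_BASE_DIR=/data/nanochat before launching a script redirects all
# intermediates there (the /data/nanochat path is just an illustration).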
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
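# Example (sketch): under a launcher like `torchrun --nproc_per_node=8 ...` every process
# has RANK set, so print0("hello") prints once (from rank 0) instead of eight times; in a
# plain single-process run RANK is absent and it always prints.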
def print_banner():
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
                                                    █████                 █████
                                                   ░░███                 ░░███
 ████████    ██████   ████████    ██████   ██████   ░███████    ██████   ███████
░░███░░███  ░░░░░███ ░░███░░███  ███░░███ ███░░███  ░███░░███  ░░░░░███ ░░░███░
 ░███ ░███   ███████  ░███ ░███ ░███ ░███░███ ░░░   ░███ ░███   ███████   ░███
 ░███ ░███  ███░░███  ░███ ░███ ░███ ░███░███  ███  ░███ ░███  ███░░███   ░███ ███
 ████ █████░░████████ ████ █████░░██████ ░░██████   ████ █████░░████████  ░░█████
░░░░ ░░░░░  ░░░░░░░░ ░░░░ ░░░░░  ░░░░░░   ░░░░░░   ░░░░ ░░░░░  ░░░░░░░░    ░░░░░
"""
    print0(banner)
def is_ddp():
    # TODO is there a proper way
    return int(os.environ.get('RANK', -1)) != -1
def get_dist_info():
    if is_ddp():
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        return False, 0, 0, 1
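# Illustrative return values (an assumption about typical launches, not exhaustive):
# a single-node 8-GPU torchrun launch yields (True, rank, local_rank, 8) with
# rank == local_rank in 0..7, while a plain `python script.py` run yields (False, 0, 0, 1).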
def compute_init():
    """Basic initialization that we keep doing over and over, so make common."""
    # CUDA is currently required
    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
    # Reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # Precision
    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
    # Distributed setup: Distributed Data Parallel (DDP), optional
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device("cuda")
    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")
    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp():
        dist.destroy_process_group()
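# Usage sketch (an assumed call pattern, not prescribed by this module):
#   ddp, rank, local_rank, world_size, device = compute_init()
#   model = MyModel().to(device)  # MyModel is a placeholder
#   if ddp:
#       model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
#   ...  # training loop
#   compute_cleanup()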
class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""
    def __init__(self):
        pass
    def log(self, *args, **kwargs):
        pass
    def finish(self):
        pass
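# Example (sketch): callers can toggle experiment logging without changing call sites, e.g.:
#   import wandb  # only if installed/enabled
#   run = wandb.init(project="nanochat") if use_wandb else DummyWandb()  # use_wandb is a hypothetical flag
#   run.log({"loss": 1.23})
#   run.finish()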