File size: 5,694 Bytes
4d308e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Common utilities for nanochat.
"""

import os
import re
import logging
import torch
import torch.distributed as dist

class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages."""
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',    # Cyan
        'INFO': '\033[32m',     # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',    # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'
    def format(self, record):
        # Add color to the level name
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        # Format the message
        message = super().format(record)
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message

def setup_default_logging():
    handler = logging.StreamHandler()
    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.basicConfig(
        level=logging.INFO,
        handlers=[handler]
    )

setup_default_logging()
logger = logging.getLogger(__name__)

def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):
        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    else:
        home_dir = os.path.expanduser("~")
        cache_dir = os.path.join(home_dir, ".cache")
        nanochat_dir = os.path.join(cache_dir, "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir

def print0(s="",**kwargs):
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

def print_banner():
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
                                                   β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ                 β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
                                                  β–‘β–‘β–ˆβ–ˆβ–ˆ                 β–‘β–‘β–ˆβ–ˆβ–ˆ
 β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ    β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
β–‘β–‘β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ  β–‘β–‘β–‘β–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ  β–‘β–‘β–‘β–‘β–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–‘β–ˆβ–ˆβ–ˆβ–‘
 β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ   β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆβ–‘β–ˆβ–ˆβ–ˆ β–‘β–‘β–‘  β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ   β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ   β–‘β–ˆβ–ˆβ–ˆ
 β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ  β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆβ–‘β–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ β–‘β–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆ   β–‘β–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆ
 β–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ β–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  β–ˆβ–ˆβ–ˆβ–ˆ β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ  β–‘β–‘β–ˆβ–ˆβ–ˆβ–ˆβ–ˆ
β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘  β–‘β–‘β–‘β–‘β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘  β–‘β–‘β–‘β–‘β–‘β–‘   β–‘β–‘β–‘β–‘β–‘β–‘  β–‘β–‘β–‘β–‘ β–‘β–‘β–‘β–‘β–‘  β–‘β–‘β–‘β–‘β–‘β–‘β–‘β–‘    β–‘β–‘β–‘β–‘β–‘
"""
    print0(banner)

def is_ddp():
    # TODO is there a proper way
    return int(os.environ.get('RANK', -1)) != -1

def get_dist_info():
    if is_ddp():
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        return False, 0, 0, 1

def compute_init():
    """Basic initialization that we keep doing over and over, so make common."""

    # CUDA is currently required
    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"

    # Reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # Precision
    torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device("cuda")

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device

def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp():
        dist.destroy_process_group()

class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""
    def __init__(self):
        pass
    def log(self, *args, **kwargs):
        pass
    def finish(self):
        pass