File size: 3,602 Bytes
bd710e9
 
 
 
 
 
 
 
 
c015721
bd710e9
c015721
 
 
 
bd710e9
 
 
1ccb133
b282fec
 
 
bd710e9
b282fec
1ccb133
b282fec
1ccb133
 
bd710e9
 
1ccb133
b282fec
bd710e9
1ccb133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd710e9
87de695
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd710e9
87de695
bd710e9
87de695
bd710e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
"""
Hugging Face Spaces entry point for NanoChat.
This file is automatically detected and run by HF Spaces.
"""

import os
import sys

# Point nanochat's data dir and every library cache at the /data volume. This
# runs before huggingface_hub, scripts.chat_web and uvicorn are imported further
# down. setdefault() leaves any value already configured for the Space intact.
for _env_name, _env_default in (
    ("NANOCHAT_BASE_DIR", "/data"),
    ("TORCHINDUCTOR_CACHE_DIR", "/data/.cache/torch"),
    ("HF_HOME", "/data/.cache/huggingface"),
    ("TORCH_HOME", "/data/.cache/torch"),
    ("XDG_CACHE_HOME", "/data/.cache"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default

def _missing_checkpoint_files(base_dir, model_tag, step):
    """Return repo-relative paths of checkpoint files not yet present under *base_dir*."""
    step = int(step)  # MODEL_STEP arrives as a string; file names zero-pad to 6 digits
    rel_paths = (
        f"chatsft_checkpoints/{model_tag}/model_{step:06d}.pt",
        f"chatsft_checkpoints/{model_tag}/meta_{step:06d}.json",
    )
    return [p for p in rel_paths if not os.path.exists(os.path.join(base_dir, p))]


# Download model from HF if not present
def download_model(model_tag=None, step=None, base_dir=None,
                   repo_id="BrianGuo/nanochat-d20-chat"):
    """Download SFT checkpoint weights and metadata from Hugging Face.

    Each file is checked individually. A previously interrupted run therefore
    only fetches what is still missing.

    Args:
        model_tag: Checkpoint sub-directory, e.g. ``"d20"``. Defaults to the
            ``MODEL_TAG`` environment variable, else ``"d20"``. This is the
            same source the server's ``--model-tag`` argument is built from.
        step: Checkpoint step (int or numeric string). Defaults to
            ``MODEL_STEP``, else ``"650"``, matching the server's ``--step``.
        base_dir: Directory the repo layout is mirrored into. Defaults to
            ``NANOCHAT_BASE_DIR``, else ``"/data"``.
        repo_id: Hugging Face model repository to download from.
    """
    if model_tag is None:
        model_tag = os.environ.get("MODEL_TAG", "d20")
    if step is None:
        step = os.environ.get("MODEL_STEP", "650")
    if base_dir is None:
        base_dir = os.environ.get("NANOCHAT_BASE_DIR", "/data")

    # Create all necessary directories.
    os.makedirs(os.path.join(base_dir, "chatsft_checkpoints", model_tag), exist_ok=True)
    os.makedirs(
        os.environ.get("HF_HOME") or os.path.join(base_dir, ".cache", "huggingface"),
        exist_ok=True,
    )

    missing = _missing_checkpoint_files(base_dir, model_tag, step)
    if not missing:
        print("Model checkpoint and metadata found, skipping download")
        return

    print(f"Downloading model files from {repo_id}...")
    # Imported lazily so the skip path does not require huggingface_hub.
    from huggingface_hub import hf_hub_download

    for rel_path in missing:
        print(f"  - Downloading {os.path.basename(rel_path)}...")
        # With local_dir, the file lands at <base_dir>/<rel_path>. The old
        # local_dir_use_symlinks flag is deprecated/ignored in current
        # huggingface_hub releases and is therefore no longer passed.
        hf_hub_download(repo_id=repo_id, filename=rel_path, local_dir=base_dir)

    print("Model files downloaded successfully!")

def download_tokenizer(base_dir=None, repo_id="BrianGuo/nanochat-d20-chat"):
    """Download the tokenizer files from Hugging Face into ``<base_dir>/tokenizer``.

    Every required file is checked individually. A partially completed earlier
    download is therefore finished rather than mistaken for a complete one.
    The previous check only asked whether the directory was non-empty.

    Args:
        base_dir: Directory the repo layout is mirrored into. Defaults to
            ``NANOCHAT_BASE_DIR``, else ``"/data"``.
        repo_id: Hugging Face model repository to download from.
    """
    if base_dir is None:
        base_dir = os.environ.get("NANOCHAT_BASE_DIR", "/data")
    tokenizer_dir = os.path.join(base_dir, "tokenizer")

    missing = [
        name
        for name in ("token_bytes.pt", "tokenizer.pkl")
        if not os.path.exists(os.path.join(tokenizer_dir, name))
    ]
    if not missing:
        print("Tokenizer found, skipping download")
        return

    print(f"Downloading tokenizer from {repo_id}...")
    # Imported lazily so the skip path does not require huggingface_hub.
    from huggingface_hub import hf_hub_download

    os.makedirs(tokenizer_dir, exist_ok=True)
    for name in missing:
        # With local_dir, the file lands at <base_dir>/tokenizer/<name>.
        hf_hub_download(repo_id=repo_id, filename=f"tokenizer/{name}", local_dir=base_dir)
    print("Tokenizer downloaded successfully!")

if __name__ == "__main__":
    # Fetch weights and tokenizer to local disk before anything else starts.
    download_model()
    download_tokenizer()

    host = "0.0.0.0"
    port = 7860  # HF Spaces default port
    model_tag = os.environ.get("MODEL_TAG", "d20")
    model_step = os.environ.get("MODEL_STEP", "650")

    # scripts.chat_web presumably parses its CLI options from sys.argv when it
    # is imported, so argv is replaced first. TODO confirm against chat_web.
    sys.argv = [
        "app.py",
        "--port", str(port),
        "--host", host,
        "--source", "sft",
        "--model-tag", model_tag,
        "--step", model_step,
    ]

    from scripts.chat_web import app
    import uvicorn

    print("Starting NanoChat on Hugging Face Spaces...")
    print(f"Model: {model_tag} - Step: {model_step}")

    uvicorn.run(app, host=host, port=port)