Thewhey-Brian committed on
Commit
b282fec
·
1 Parent(s): b6b55f4

Fix HF download: download specific checkpoint file and fix permissions

Browse files
Files changed (2) hide show
  1. Dockerfile +3 -2
  2. app.py +14 -7
Dockerfile CHANGED
@@ -25,12 +25,13 @@ RUN pip install --no-cache-dir -r requirements.txt
25
  # Build and install the Rust component
26
  RUN . $HOME/.cargo/env && maturin build --release && pip install --no-cache-dir rustbpe/target/wheels/*.whl
27
 
28
- # Create data directory for model checkpoints
29
- RUN mkdir -p /data
30
 
31
  # Set environment variables
32
  ENV NANOCHAT_BASE_DIR=/data
33
  ENV PYTHONUNBUFFERED=1
 
34
 
35
  # Expose port 7860 (HF Spaces default)
36
  EXPOSE 7860
 
25
  # Build and install the Rust component
26
  RUN . $HOME/.cargo/env && maturin build --release && pip install --no-cache-dir rustbpe/target/wheels/*.whl
27
 
28
+ # Create data directory for model checkpoints with proper permissions
29
+ RUN mkdir -p /data && chmod 777 /data
30
 
31
  # Set environment variables
32
  ENV NANOCHAT_BASE_DIR=/data
33
  ENV PYTHONUNBUFFERED=1
34
+ ENV HF_HOME=/data/.cache/huggingface
35
 
36
  # Expose port 7860 (HF Spaces default)
37
  EXPOSE 7860
app.py CHANGED
@@ -13,18 +13,25 @@ os.environ.setdefault("NANOCHAT_BASE_DIR", "/data")
13
  # Download model from HF if not present
14
  def download_model():
15
  """Download model weights from Hugging Face."""
16
- checkpoint_dir = "/data/chatsft_checkpoints"
 
 
17
 
18
- if os.path.exists(checkpoint_dir) and os.listdir(checkpoint_dir):
19
- print(f"Model checkpoints found, skipping download")
 
 
20
  return
21
 
22
- print("Downloading model from BrianGuo/nanochat-d20-chat...")
23
- from huggingface_hub import snapshot_download
24
 
25
- snapshot_download(
 
26
  repo_id="BrianGuo/nanochat-d20-chat",
27
- local_dir="/data/chatsft_checkpoints"
 
 
28
  )
29
  print("Model downloaded successfully!")
30
 
 
13
  # Download model from HF if not present
14
  def download_model():
15
  """Download model weights from Hugging Face."""
16
+ # Create all necessary directories
17
+ os.makedirs("/data/chatsft_checkpoints/d20", exist_ok=True)
18
+ os.makedirs("/data/.cache/huggingface", exist_ok=True)
19
 
20
+ checkpoint_path = "/data/chatsft_checkpoints/d20/model_000650.pt"
21
+
22
+ if os.path.exists(checkpoint_path):
23
+ print(f"Model checkpoint found at {checkpoint_path}, skipping download")
24
  return
25
 
26
+ print("Downloading model checkpoint from BrianGuo/nanochat-d20-chat...")
27
+ from huggingface_hub import hf_hub_download
28
 
29
+ # Download the specific checkpoint file
30
+ hf_hub_download(
31
  repo_id="BrianGuo/nanochat-d20-chat",
32
+ filename="chatsft_checkpoints/d20/model_000650.pt",
33
+ local_dir="/data",
34
+ local_dir_use_symlinks=False
35
  )
36
  print("Model downloaded successfully!")
37