root commited on
Commit
410c1c2
·
1 Parent(s): 6e28f61

update ckpt version

Browse files
Files changed (2) hide show
  1. download.py +2 -3
  2. z_script.py +44 -0
download.py CHANGED
@@ -4,14 +4,13 @@ import os
4
  os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "1200"
5
  # os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
6
 
7
- def download_model(local_dir, repo_id="tencent/SongGeneration", revision="aa9d1b3"):
8
  downloaded_path = snapshot_download(
9
  repo_id=repo_id,
10
  local_dir=local_dir,
11
  revision=revision,
12
  token=os.environ.get("HF_TOKEN"),
13
- ignore_patterns=['.git*'],
14
- endpoint="https://hf-mirror.com"
15
  )
16
  print(f"File downloaded to:{downloaded_path}")
17
 
 
4
  os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "1200"
5
  # os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
6
 
7
+ def download_model(local_dir, repo_id="tencent/SongGeneration", revision="87bccf1"):
8
  downloaded_path = snapshot_download(
9
  repo_id=repo_id,
10
  local_dir=local_dir,
11
  revision=revision,
12
  token=os.environ.get("HF_TOKEN"),
13
+ ignore_patterns=['.git*']
 
14
  )
15
  print(f"File downloaded to:{downloaded_path}")
16
 
z_script.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from hmac import new
2
+ import sys
3
+ import os
4
+ import argparse
5
+ from safetensors.torch import save_file
6
+
7
+ import time
8
+ import json
9
+ import torch
10
+ import torchaudio
11
+ import numpy as np
12
+ from omegaconf import OmegaConf
13
+ from codeclm.models import builders
14
+ import gc
15
+ from codeclm.trainer.codec_song_pl import CodecLM_PL
16
+ from codeclm.models import CodecLM
17
+ from third_party.demucs.models.pretrained import get_model_from_yaml
18
+
19
+ cfg_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/songgeneration_base/config.yaml"
20
+ cfg = OmegaConf.load(cfg_path)
21
+ cfg.mode = 'inference'
22
+ # audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
23
+ # model = audio_tokenizer.model.model
24
+ # weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
25
+ # save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
26
+ # print(weights)
27
+
28
+ # seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
29
+ # model = seperate_tokenizer.model.model
30
+ # weights = {}
31
+ # for k, v in model.state_dict().items():
32
+ # if k.startswith("rvq_bestrq_bgm_emb") or k.startswith("rvq_bestrq_emb") or k.startswith("bestrq"):
33
+ # weights[k] = v.half()
34
+ # else:
35
+ # weights[k] = v
36
+ # # weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
37
+ # save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
38
+ # print(weights.keys())
39
+
40
+ ckpt_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model_32.pt"
41
+ # audiolm = builders.get_lm_model(cfg)
42
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
43
+ audiolm_state_dict = {k: v.half() for k, v in checkpoint.items()}
44
+ torch.save(audiolm_state_dict, "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model.pt")