root committed on
Commit
6e28f61
·
1 Parent(s): 4846f0c

remove fairseq

Browse files
codeclm/tokenizer/Flow1dVAE/generate_septoken.py CHANGED
@@ -14,8 +14,8 @@ import tools.torch_tools as torch_tools
14
  from safetensors.torch import load_file
15
  from third_party.demucs.models.pretrained import get_model_from_yaml
16
  from filelock import FileLock
17
- import kaldiio
18
- # os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml")
19
  class Separator:
20
  def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
21
  if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
 
14
  from safetensors.torch import load_file
15
  from third_party.demucs.models.pretrained import get_model_from_yaml
16
  from filelock import FileLock
17
+
18
+
19
  class Separator:
20
  def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
21
  if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
codeclm/tokenizer/Flow1dVAE/model_1rvq.py CHANGED
@@ -19,12 +19,11 @@ from libs.rvq.descript_quantize3 import ResidualVectorQuantize
19
 
20
  from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
21
  from models_gpt.models.gpt2_config import GPT2Config
 
22
 
23
  from torch.cuda.amp import autocast
24
 
25
 
26
- from our_MERT_BESTRQ.test import load_model
27
-
28
  class HubertModelWithFinalProj(HubertModel):
29
  def __init__(self, config):
30
  super().__init__(config)
@@ -272,6 +271,7 @@ class PromptCondAudioDiffusion(nn.Module):
272
  ssl_layer=None,
273
  uncondition=True,
274
  out_paint=False,
 
275
  ):
276
  super().__init__()
277
 
@@ -294,28 +294,24 @@ class PromptCondAudioDiffusion(nn.Module):
294
  self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
295
  # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
296
  # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
297
- self.bestrq = load_model(
298
- model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
299
- checkpoint_dir='ckpt/encode-s12k.pt',
300
- )
301
  self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
302
  self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
303
- for v in self.bestrq.parameters():v.requires_grad = False
304
  self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
305
  for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
306
  # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
307
  # for v in self.hubert.parameters():v.requires_grad = False
308
  self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
309
  # self.xvecmodel = XVECModel()
310
- config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
311
- unet = GPT2Model(config)
312
- mlp = nn.Sequential(
313
- nn.Linear(1200, 1024),
314
- nn.SiLU(),
315
- nn.Linear(1024, 1024),
316
- nn.SiLU(),
317
- nn.Linear(1024, 768)
318
- )
319
  self.set_from = "random"
320
  # self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
321
  self.mask_emb = torch.nn.Embedding(3, 48)
@@ -538,8 +534,6 @@ class PromptCondAudioDiffusion(nn.Module):
538
  input_audio_0 = self.preprocess_audio(input_audio_0)
539
  input_audio_1 = self.preprocess_audio(input_audio_1)
540
 
541
- self.bestrq.eval()
542
-
543
  # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
544
  # bestrq_middle = bestrq_middle.detach()
545
  # bestrq_last = bestrq_last.detach()
@@ -575,8 +569,6 @@ class PromptCondAudioDiffusion(nn.Module):
575
  input_audio_0 = self.preprocess_audio(input_audio_0)
576
  input_audio_1 = self.preprocess_audio(input_audio_1)
577
 
578
- self.bestrq.eval()
579
-
580
  # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
581
  # bestrq_middle = bestrq_middle.detach()
582
  # bestrq_last = bestrq_last.detach()
 
19
 
20
  from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
21
  from models_gpt.models.gpt2_config import GPT2Config
22
+ from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import MusicFMModel, MusicFMConfig
23
 
24
  from torch.cuda.amp import autocast
25
 
26
 
 
 
27
  class HubertModelWithFinalProj(HubertModel):
28
  def __init__(self, config):
29
  super().__init__(config)
 
271
  ssl_layer=None,
272
  uncondition=True,
273
  out_paint=False,
274
+ ssl_path='ckpt/encode-s12k.pt'
275
  ):
276
  super().__init__()
277
 
 
294
  self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
295
  # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
296
  # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
297
+ self.bestrq = MusicFMModel(MusicFMConfig())
 
 
 
298
  self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
299
  self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
 
300
  self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
301
  for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
302
  # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
303
  # for v in self.hubert.parameters():v.requires_grad = False
304
  self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
305
  # self.xvecmodel = XVECModel()
306
+ # config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
307
+ # unet = GPT2Model(config)
308
+ # mlp = nn.Sequential(
309
+ # nn.Linear(1200, 1024),
310
+ # nn.SiLU(),
311
+ # nn.Linear(1024, 1024),
312
+ # nn.SiLU(),
313
+ # nn.Linear(1024, 768)
314
+ # )
315
  self.set_from = "random"
316
  # self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
317
  self.mask_emb = torch.nn.Embedding(3, 48)
 
534
  input_audio_0 = self.preprocess_audio(input_audio_0)
535
  input_audio_1 = self.preprocess_audio(input_audio_1)
536
 
 
 
537
  # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
538
  # bestrq_middle = bestrq_middle.detach()
539
  # bestrq_last = bestrq_last.detach()
 
569
  input_audio_0 = self.preprocess_audio(input_audio_0)
570
  input_audio_1 = self.preprocess_audio(input_audio_1)
571
 
 
 
572
  # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
573
  # bestrq_middle = bestrq_middle.detach()
574
  # bestrq_last = bestrq_last.detach()
codeclm/tokenizer/Flow1dVAE/model_septoken.py CHANGED
@@ -20,9 +20,9 @@ from libs.rvq.descript_quantize3 import ResidualVectorQuantize
20
 
21
  from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
22
  from models_gpt.models.gpt2_config import GPT2Config
 
23
 
24
  from torch.cuda.amp import autocast
25
- from our_MERT_BESTRQ.test import load_model
26
 
27
  class HubertModelWithFinalProj(HubertModel):
28
  def __init__(self, config):
@@ -253,6 +253,7 @@ class PromptCondAudioDiffusion(nn.Module):
253
  snr_gamma=None,
254
  uncondition=True,
255
  out_paint=False,
 
256
  ):
257
  super().__init__()
258
 
@@ -273,13 +274,9 @@ class PromptCondAudioDiffusion(nn.Module):
273
  self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
274
  # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
275
  # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
276
- self.bestrq = load_model(
277
- model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
278
- checkpoint_dir='ckpt/encode-s12k.pt',
279
- )
280
  self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
281
  self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
282
- for v in self.bestrq.parameters():v.requires_grad = False
283
  self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
284
  self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
285
  # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
 
20
 
21
  from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
22
  from models_gpt.models.gpt2_config import GPT2Config
23
+ from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import MusicFMModel, MusicFMConfig
24
 
25
  from torch.cuda.amp import autocast
 
26
 
27
  class HubertModelWithFinalProj(HubertModel):
28
  def __init__(self, config):
 
253
  snr_gamma=None,
254
  uncondition=True,
255
  out_paint=False,
256
+ ssl_path='ckpt/encode-s12k.pt'
257
  ):
258
  super().__init__()
259
 
 
274
  self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
275
  # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
276
  # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
277
+ self.bestrq = MusicFMModel(MusicFMConfig())
 
 
 
278
  self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
279
  self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
 
280
  self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
281
  self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
282
  # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/musicfm_model.py CHANGED
@@ -4,14 +4,6 @@ except:
4
  import sys, os
5
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
  from model.musicfm_25hz import MusicFM25Hz
7
- try:
8
- from fairseq.fairseq.dataclass import FairseqDataclass
9
- from fairseq.fairseq.models import BaseFairseqModel, register_model
10
- from fairseq.fairseq.tasks.fairseq_task import FairseqTask
11
- except:
12
- from fairseq.dataclass import FairseqDataclass
13
- from fairseq.models import BaseFairseqModel, register_model
14
- from fairseq.tasks.fairseq_task import FairseqTask
15
 
16
  from dataclasses import dataclass, field
17
  from typing import List, Tuple, Optional
@@ -22,7 +14,7 @@ from logging import getLogger
22
  logger = getLogger(__name__)
23
 
24
  @dataclass
25
- class MusicFMConfig(FairseqDataclass):
26
  label_rate:int = field(default=25)
27
  num_codebooks:int = field(default=1)
28
  codebook_dim:int = field(default=16)
@@ -45,9 +37,8 @@ class MusicFMConfig(FairseqDataclass):
45
 
46
  SAMPLE_RATE = 24_000
47
 
48
- @register_model("musicfm", dataclass=MusicFMConfig)
49
- class MusicFMModel(BaseFairseqModel):
50
- def __init__(self, cfg: MusicFMConfig, task_cfg: FairseqTask):
51
  super().__init__()
52
  self.cfg = cfg
53
  self.model = MusicFM25Hz(
@@ -91,19 +82,3 @@ class MusicFMModel(BaseFairseqModel):
91
  result["logits"] = logits
92
  result["hidden_emb"] = hidden_emb
93
  return result
94
-
95
- @classmethod
96
- def build_model(cls, cfg: MusicFMConfig, task: FairseqTask):
97
- """Build a new model instance."""
98
-
99
- model = MusicFMModel(cfg, task.cfg)
100
- import numpy as np
101
- s = 0
102
- for param in model.parameters():
103
- s += np.product(param.size())
104
- print('# of parameters: '+str(s/1024.0/1024.0))
105
- return model
106
-
107
- def get_losses(self, result, batch):
108
- return result['losses']
109
-
 
4
  import sys, os
5
  sys.path.append(os.path.dirname(os.path.abspath(__file__)))
6
  from model.musicfm_25hz import MusicFM25Hz
 
 
 
 
 
 
 
 
7
 
8
  from dataclasses import dataclass, field
9
  from typing import List, Tuple, Optional
 
14
  logger = getLogger(__name__)
15
 
16
  @dataclass
17
+ class MusicFMConfig:
18
  label_rate:int = field(default=25)
19
  num_codebooks:int = field(default=1)
20
  codebook_dim:int = field(default=16)
 
37
 
38
  SAMPLE_RATE = 24_000
39
 
40
+ class MusicFMModel(torch.nn.Module):
41
+ def __init__(self, cfg: MusicFMConfig):
 
42
  super().__init__()
43
  self.cfg = cfg
44
  self.model = MusicFM25Hz(
 
82
  result["logits"] = logits
83
  result["hidden_emb"] = hidden_emb
84
  return result