Commit 6e28f61 · root committed
1 Parent(s): 4846f0c

remove fairseq
codeclm/tokenizer/Flow1dVAE/generate_septoken.py
CHANGED

@@ -14,8 +14,8 @@ import tools.torch_tools as torch_tools
 from safetensors.torch import load_file
 from third_party.demucs.models.pretrained import get_model_from_yaml
 from filelock import FileLock
-
-
+
+
 class Separator:
     def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None:
         if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():

(Lines 17–18 render as blank on both sides of the diff viewer, i.e. a whitespace-only change.)
codeclm/tokenizer/Flow1dVAE/model_1rvq.py
CHANGED

@@ -19,12 +19,11 @@ from libs.rvq.descript_quantize3 import ResidualVectorQuantize
 
 from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
 from models_gpt.models.gpt2_config import GPT2Config
+from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import MusicFMModel, MusicFMConfig
 
 from torch.cuda.amp import autocast
 
 
-from our_MERT_BESTRQ.test import load_model
-
 class HubertModelWithFinalProj(HubertModel):
     def __init__(self, config):
         super().__init__(config)

@@ -272,6 +271,7 @@ class PromptCondAudioDiffusion(nn.Module):
         ssl_layer=None,
         uncondition=True,
         out_paint=False,
+        ssl_path='ckpt/encode-s12k.pt'
     ):
         super().__init__()
 

@@ -294,28 +294,24 @@ class PromptCondAudioDiffusion(nn.Module):
         self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
         # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
-        self.bestrq = load_model(
-            model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
-            checkpoint_dir='ckpt/encode-s12k.pt',
-        )
+        self.bestrq = MusicFMModel(MusicFMConfig())
         self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
         self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
-        for v in self.bestrq.parameters():v.requires_grad = False
         self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False
         # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
         # for v in self.hubert.parameters():v.requires_grad = False
         self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,))
         # self.xvecmodel = XVECModel()
-        config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
-        unet = GPT2Model(config)
-        mlp = nn.Sequential(
-            nn.Linear(1200, 1024),
-            nn.SiLU(),
-            nn.Linear(1024, 1024),
-            nn.SiLU(),
-            nn.Linear(1024, 768)
-        )
+        # config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200)
+        # unet = GPT2Model(config)
+        # mlp = nn.Sequential(
+        #     nn.Linear(1200, 1024),
+        #     nn.SiLU(),
+        #     nn.Linear(1024, 1024),
+        #     nn.SiLU(),
+        #     nn.Linear(1024, 768)
+        # )
         self.set_from = "random"
         # self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer)
         self.mask_emb = torch.nn.Embedding(3, 48)

@@ -538,8 +534,6 @@ class PromptCondAudioDiffusion(nn.Module):
         input_audio_0 = self.preprocess_audio(input_audio_0)
         input_audio_1 = self.preprocess_audio(input_audio_1)
 
-        self.bestrq.eval()
-
         # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
         # bestrq_middle = bestrq_middle.detach()
         # bestrq_last = bestrq_last.detach()

@@ -575,8 +569,6 @@ class PromptCondAudioDiffusion(nn.Module):
         input_audio_0 = self.preprocess_audio(input_audio_0)
         input_audio_1 = self.preprocess_audio(input_audio_1)
 
-        self.bestrq.eval()
-
         # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
         # bestrq_middle = bestrq_middle.detach()
         # bestrq_last = bestrq_last.detach()
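Net effect in model_1rvq.py: the fairseq-backed load_model(...) call is replaced by a direct MusicFMModel(MusicFMConfig()) construction, the checkpoint path moves into the new ssl_path argument, and the inline eval()/requires_grad freezing disappears from these hunks. The actual loading and freezing step is outside the diff; a minimal sketch of what it could look like, assuming a standard torch checkpoint (the helper name, the 'model' key fallback, and strict=False are assumptions, not code from this commit):

# Sketch only: one plausible way the new ssl_path gets consumed.
import torch

from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import (
    MusicFMModel,
    MusicFMConfig,
)

def build_frozen_bestrq(ssl_path: str = 'ckpt/encode-s12k.pt') -> MusicFMModel:
    model = MusicFMModel(MusicFMConfig())
    state = torch.load(ssl_path, map_location='cpu')
    # fairseq-era checkpoints often nest weights under a 'model' key (assumption)
    model.load_state_dict(state.get('model', state), strict=False)
    model.eval()                  # stands in for the removed self.bestrq.eval() calls
    for p in model.parameters():  # stands in for the removed requires_grad loop
        p.requires_grad = False
    return model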
codeclm/tokenizer/Flow1dVAE/model_septoken.py
CHANGED

@@ -20,9 +20,9 @@ from libs.rvq.descript_quantize3 import ResidualVectorQuantize
 
 from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
 from models_gpt.models.gpt2_config import GPT2Config
+from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import MusicFMModel, MusicFMConfig
 
 from torch.cuda.amp import autocast
-from our_MERT_BESTRQ.test import load_model
 
 class HubertModelWithFinalProj(HubertModel):
     def __init__(self, config):

@@ -253,6 +253,7 @@ class PromptCondAudioDiffusion(nn.Module):
         snr_gamma=None,
         uncondition=True,
         out_paint=False,
+        ssl_path='ckpt/encode-s12k.pt'
     ):
         super().__init__()
 

@@ -273,13 +274,9 @@ class PromptCondAudioDiffusion(nn.Module):
         self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
         # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
         # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
-        self.bestrq = load_model(
-            model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq',
-            checkpoint_dir='ckpt/encode-s12k.pt',
-        )
+        self.bestrq = MusicFMModel(MusicFMConfig())
         self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
         self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)
-        for v in self.bestrq.parameters():v.requires_grad = False
         self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200)
         # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
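model_septoken.py receives the same two changes as model_1rvq.py: the MusicFM import replaces the fairseq-dependent load_model, and the constructor gains an ssl_path keyword whose default matches the previously hard-coded checkpoint path, so existing call sites keep working while new ones can override it. A hypothetical call site (sketch; only the arguments visible in this diff are shown, and the real constructor takes more):

# Hypothetical usage of the new keyword; remaining constructor arguments elided.
model = PromptCondAudioDiffusion(
    snr_gamma=None,
    uncondition=True,
    out_paint=False,
    ssl_path='ckpt/encode-s12k.pt',  # new; previously hard-coded inside __init__
)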
codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/musicfm_model.py
CHANGED

@@ -4,14 +4,6 @@ except:
     import sys, os
     sys.path.append(os.path.dirname(os.path.abspath(__file__)))
     from model.musicfm_25hz import MusicFM25Hz
-try:
-    from fairseq.fairseq.dataclass import FairseqDataclass
-    from fairseq.fairseq.models import BaseFairseqModel, register_model
-    from fairseq.fairseq.tasks.fairseq_task import FairseqTask
-except:
-    from fairseq.dataclass import FairseqDataclass
-    from fairseq.models import BaseFairseqModel, register_model
-    from fairseq.tasks.fairseq_task import FairseqTask
 
 from dataclasses import dataclass, field
 from typing import List, Tuple, Optional

@@ -22,7 +14,7 @@ from logging import getLogger
 logger = getLogger(__name__)
 
 @dataclass
-class MusicFMConfig(FairseqDataclass):
+class MusicFMConfig:
     label_rate:int = field(default=25)
     num_codebooks:int = field(default=1)
     codebook_dim:int = field(default=16)

@@ -45,9 +37,8 @@ class MusicFMConfig(FairseqDataclass):
 
 SAMPLE_RATE = 24_000
 
-
-class MusicFMModel(BaseFairseqModel):
-    def __init__(self, cfg: MusicFMConfig, task_cfg: FairseqTask):
+class MusicFMModel(torch.nn.Module):
+    def __init__(self, cfg: MusicFMConfig):
         super().__init__()
         self.cfg = cfg
         self.model = MusicFM25Hz(

@@ -91,19 +82,3 @@ class MusicFMModel(BaseFairseqModel):
         result["logits"] = logits
         result["hidden_emb"] = hidden_emb
         return result
-
-    @classmethod
-    def build_model(cls, cfg: MusicFMConfig, task: FairseqTask):
-        """Build a new model instance."""
-
-        model = MusicFMModel(cfg, task.cfg)
-        import numpy as np
-        s = 0
-        for param in model.parameters():
-            s += np.product(param.size())
-        print('# of parameters: '+str(s/1024.0/1024.0))
-        return model
-
-    def get_losses(self, result, batch):
-        return result['losses']
-
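musicfm_model.py carries the substance of "remove fairseq": the try/except fairseq imports go away, MusicFMConfig becomes a plain @dataclass rather than a FairseqDataclass, MusicFMModel subclasses torch.nn.Module instead of BaseFairseqModel (dropping the task_cfg constructor argument), and the fairseq-only build_model/get_losses plumbing is deleted. The removed build_model also printed a parameter count via np.product, which is deprecated and removed in NumPy 2.0; a dependency-free equivalent of that bookkeeping, as a sketch:

import torch

def param_count(model: torch.nn.Module) -> float:
    """Parameter count divided by 1024**2, matching the units of the removed print."""
    # p.numel() replaces np.product(param.size()) from the deleted build_model
    return sum(p.numel() for p in model.parameters()) / 1024.0 / 1024.0

# e.g.:
# model = MusicFMModel(MusicFMConfig())
# print('# of parameters: ' + str(param_count(model)))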