|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import os |
|
|
import random |
|
|
import uuid as uuid_module |
|
|
from collections import OrderedDict, defaultdict |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Sequence, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import onnxruntime |
|
|
from hyperpyyaml import load_hyperpyyaml |
|
|
|
|
|
import torch |
|
|
import torchaudio |
|
|
import torchaudio.compliance.kaldi as kaldi |
|
|
from safetensors.torch import load_file |
|
|
from torch import nn |
|
|
from transformers import PreTrainedModel, WhisperFeatureExtractor |
|
|
|
|
|
from .configuration_moss_speech_codec import MossSpeechCodecConfig |
|
|
from .modeling_whisper import WhisperVQEncoder, WhisperVQConfig |
|
|
from .utils import extract_speech_token |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The seed is applied to ``random``, ``numpy.random`` and
    ``torch.manual_seed``.  When CUDA is available, every CUDA generator is
    seeded as well and cuDNN is switched to deterministic, non-benchmark
    mode.  ``PYTHONHASHSEED`` and ``TF_CUDNN_DETERMINISTIC`` are exported
    through ``os.environ`` (this only affects code that reads them later).

    Args:
        seed: Seed value shared by all generators.

    Raises:
        TypeError: If ``seed`` is not an ``int``.
    """
    if not isinstance(seed, int):
        raise TypeError("Seed must be an integer.")

    logger.info("Setting random seed to %s", seed)
    random.seed(seed)
    np.random.seed(seed)

    # Always seed the default torch generator.  Previously this only happened
    # when CUDA was *not* available, which left CPU-side torch randomness
    # unseeded on GPU machines.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
|
|
|
|
|
|
|
|
def fade_in_out(fade_in_mel, fade_out_mel, window):
    """Cross-fade the start of ``fade_in_mel`` with the end of ``fade_out_mel``.

    With ``n = len(window) // 2``, the first ``n`` frames (last axis) of
    ``fade_in_mel`` are replaced by
    ``fade_in_mel[..., :n] * window[:n] + fade_out_mel[..., -n:] * window[n:]``.

    Args:
        fade_in_mel: Tensor of the new chunk; blended along its last axis.
        fade_out_mel: Tensor holding the tail of the previous chunk.
        window: 1-D window of length ``2 * n`` (NumPy array or tensor).

    Returns:
        The blended tensor on the original device of ``fade_in_mel``.  The
        blend is computed on CPU and written in place, so a CPU input tensor
        is modified directly (``.cpu()`` does not copy in that case).
    """
    device = fade_in_mel.device
    overlap = int(window.shape[0] / 2)
    if overlap == 0:
        # An empty window means there is nothing to blend.  Without this
        # guard ``-0:`` would select the *whole* previous chunk and the
        # element-wise product would fail on mismatched shapes.
        return fade_in_mel

    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
    # Convert explicitly instead of relying on implicit tensor * ndarray
    # dispatch; this also lets callers pass a tensor window.
    window_cpu = torch.as_tensor(window, device="cpu")
    rising, falling = window_cpu[:overlap], window_cpu[overlap:]
    fade_in_mel[..., :overlap] = (
        fade_in_mel[..., :overlap] * rising + fade_out_mel[..., -overlap:] * falling
    )
    return fade_in_mel.to(device)
|
|
|
|
|
|
|
|
# Module-level placeholders, both initialised to None.  Nothing in the visible
# part of this file reads or writes them -- NOTE(review): possibly kept for
# external importers; confirm before removing.
tts_speech_prev = tts_mel_prev = None
|
|
|
|
|
|
|
|
class AudioDecoder(nn.Module):
    """Decode codec tokens to waveforms with a flow model and a HiFT vocoder.

    Both models are built from a hyperpyyaml config and loaded from
    checkpoints.  Chunked synthesis keeps per-utterance state in
    ``mel_overlap_dict`` and ``hift_cache_dict``, keyed by the ``uuid`` passed
    to :meth:`token2wav`; the state is dropped when a chunk is decoded with
    ``finalize=True``.
    """

    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        flow_ckpt_path: Union[str, os.PathLike],
        hift_ckpt_path: Union[str, os.PathLike],
        campplus_model: Union[str, os.PathLike],
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """Load the flow/HiFT models and the CAM++ ONNX session.

        Args:
            config_path: hyperpyyaml file providing the ``flow``, ``hift``,
                ``sample_rate`` and ``feat_extractor`` entries.
            flow_ckpt_path: Checkpoint for the flow model (loaded non-strictly).
            hift_ckpt_path: Checkpoint for the HiFT model.
            campplus_model: Path of the ONNX model used for speaker embeddings.
            device: Device for the flow and HiFT models.  A CUDA device is
                replaced by CPU (with a warning) when CUDA is unavailable.
        """
        super().__init__()
        resolved_device = torch.device(device) if isinstance(device, str) else device
        if resolved_device.type == "cuda" and not torch.cuda.is_available():
            # Mirror the fallback MossSpeechCodec.decode applies, instead of
            # failing inside torch.load(map_location=cuda).
            logger.warning("CUDA is not available; loading the audio decoder on CPU instead.")
            resolved_device = torch.device("cpu")
        self.device = resolved_device

        with open(config_path, "r", encoding="utf-8") as config_file:
            logger.info("Loading decoder configurations from %s", config_path)
            self.scratch_configs = load_hyperpyyaml(config_file)

        # NOTE(security): torch.load unpickles the checkpoint files; only load
        # checkpoints from trusted sources.
        self.flow = self.scratch_configs["flow"]
        self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device), strict=False)
        self.hift = self.scratch_configs["hift"]
        self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
        self.hift = self.hift.eval()
        self.sample_rate = self.scratch_configs["sample_rate"]
        self.feat_extractor = self.scratch_configs["feat_extractor"]

        self.flow.to(self.device)
        self.hift.to(self.device)

        # Per-utterance streaming state; missing uuids read as None.
        self.mel_overlap_dict = defaultdict(lambda: None)
        self.hift_cache_dict = defaultdict(lambda: None)

        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 3.5
        # NOTE(review): 24000 and 480 * 2 look like the output sample rate and
        # mel hop size -- confirm against the flow/HiFT configuration.
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 24000 / (480 * 2))
        self.mel_window = np.hamming(2 * self.mel_overlap_len)

        self.mel_cache_len = 1
        self.source_cache_len = int(self.mel_cache_len * 480)

        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(
            str(campplus_model),
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )
        self.speech_window = np.hamming(2 * self.source_cache_len)

    def token2wav(
        self,
        token: torch.Tensor,
        uuid: str,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        finalize: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Synthesize one chunk of tokens for the utterance ``uuid``.

        Args:
            token: Codec tokens; ``token.shape[1]`` is passed as their length.
            uuid: Key under which streaming state is stored between chunks.
            prompt_token: Optional prompt tokens (defaults to an empty
                ``(1, 0)`` int32 tensor).
            prompt_feat: Optional prompt features (defaults to an empty
                ``(1, 0, 80)`` tensor).
            embedding: Optional speaker embedding (defaults to zeros of shape
                ``(1, 192)``).
            finalize: ``True`` for the last chunk: nothing is held back and
                the cached state of ``uuid`` is released.

        Returns:
            ``(tts_speech, tts_mel)``.  For non-final chunks the last
            ``mel_overlap_len`` mel frames and the last ``source_cache_len``
            speech samples are withheld and cached for the next chunk.
        """
        if prompt_token is None:
            prompt_token = torch.zeros(1, 0, dtype=torch.int32)
        if prompt_feat is None:
            prompt_feat = torch.zeros(1, 0, 80)
        if embedding is None:
            embedding = torch.zeros(1, 192)

        flow_output = self.flow.inference(
            token=token.to(self.device),
            token_len=torch.tensor([token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_token=prompt_token.to(self.device),
            prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_feat=prompt_feat.to(self.device),
            prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32, device=self.device),
            embedding=embedding.to(self.device),
            streaming=False,
            finalize=finalize,
        )
        # The mel is the first element of what the flow returns.
        tts_mel = flow_output[0]

        # ``.get`` reads the state without inserting a None entry for ``uuid``.
        previous_overlap = self.mel_overlap_dict.get(uuid)
        if previous_overlap is not None:
            tts_mel = fade_in_out(tts_mel, previous_overlap, self.mel_window)

        cached = self.hift_cache_dict.get(uuid)
        if cached is not None:
            hift_cache_source = cached["source"]
            # Prepend the cached mel frames along the last (dim=2) axis.
            tts_mel = torch.cat([cached["mel"], tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)

        if finalize:
            tts_speech, _ = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            # Last chunk: release all state held for this utterance.
            self.hift_cache_dict.pop(uuid, None)
            self.mel_overlap_dict.pop(uuid, None)
            return tts_speech, tts_mel

        # Hold back the trailing frames so the next chunk can cross-fade.
        self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
        tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
        tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
        self.hift_cache_dict[uuid] = {
            "mel": tts_mel[:, :, -self.mel_cache_len:],
            "source": tts_source[:, :, -self.source_cache_len:],
            "speech": tts_speech[:, -self.source_cache_len:],
        }
        return tts_speech[:, :-self.source_cache_len], tts_mel

    def offline_inference(self, token: torch.Tensor) -> torch.Tensor:
        """Decode ``token`` in one finalized pass and return the speech on CPU."""
        this_uuid = str(uuid_module.uuid1())
        tts_speech, _ = self.token2wav(token, uuid=this_uuid, finalize=True)
        return tts_speech.cpu()

    def stream_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        block_size: int = 8,
    ) -> torch.Tensor:
        """Decode ``token`` in blocks of ``block_size`` and join the audio.

        Every block after the first is conditioned on the base prompt plus
        all tokens and mels produced so far.  The last block is finalized.

        Args:
            token: Codec tokens; blocks are taken along dimension 1.
            prompt_token: Optional prompt tokens.
            prompt_feat: Optional prompt features.
            embedding: Optional speaker embedding.
            block_size: Number of tokens decoded per call to ``token2wav``.

        Returns:
            The concatenated speech (joined on the last axis) on CPU.

        Raises:
            ValueError: If ``block_size`` is not positive or ``token`` holds
                no tokens.
        """
        if block_size <= 0:
            raise ValueError("`block_size` must be a positive integer.")

        token = token.to(self.device)
        total_tokens = token.size(1)
        if total_tokens == 0:
            # Previously surfaced as an opaque torch.cat error on an empty list.
            raise ValueError("`token` must contain at least one codec token.")

        this_uuid = str(uuid_module.uuid1())

        if prompt_token is not None:
            base_prompt_tensor = prompt_token.to(self.device)
        else:
            base_prompt_tensor = torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        if prompt_feat is not None:
            base_prompt_feat = prompt_feat.to(self.device)
        else:
            base_prompt_feat = torch.zeros(1, 0, 80, device=self.device)
        if embedding is not None:
            embedding = embedding.to(self.device)
        else:
            embedding = torch.zeros(1, 192, device=self.device)

        tts_speechs: List[torch.Tensor] = []
        tts_mels: List[torch.Tensor] = []

        for start in range(0, total_tokens, block_size):
            prompt_tensor_current = base_prompt_tensor
            prompt_feat_current = base_prompt_feat
            if tts_mels:
                # Mels are concatenated on their last axis, so the prompt
                # features are transposed to that layout and back.
                prompt_feat_current = torch.cat(
                    [base_prompt_feat.transpose(1, 2)] + tts_mels,
                    dim=-1,
                ).transpose(1, 2)
                prompt_tensor_current = torch.cat([base_prompt_tensor, token[:, :start]], dim=-1)

            tts_speech, tts_mel = self.token2wav(
                token[:, start : start + block_size],
                uuid=this_uuid,
                prompt_token=prompt_tensor_current,
                prompt_feat=prompt_feat_current,
                embedding=embedding,
                finalize=start + block_size >= total_tokens,
            )
            tts_speechs.append(tts_speech)
            tts_mels.append(tts_mel)

        return torch.cat(tts_speechs, dim=-1).cpu()

    def streaming_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        uuid: Optional[str] = None,
        prev_mel: Optional[torch.Tensor] = None,
        prev_token: Optional[torch.Tensor] = None,
        is_finalize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Decode one chunk, with the caller carrying state between calls.

        When given, ``prev_mel`` and ``prev_token`` replace ``prompt_feat``
        and ``prompt_token`` as the conditioning prompt.

        Args:
            token: Codec tokens of the current chunk.
            prompt_token: Prompt tokens used when ``prev_token`` is ``None``.
            prompt_feat: Prompt features used when ``prev_mel`` is ``None``.
            embedding: Optional speaker embedding.
            uuid: Utterance key; a fresh one is generated when omitted.
            prev_mel: Accumulated mel returned by the previous call.
            prev_token: Accumulated tokens returned by the previous call.
            is_finalize: Whether this is the last chunk of the utterance.

        Returns:
            ``(speech_on_cpu, prev_mel, prev_token)`` with the current chunk
            appended to both accumulators.
        """
        # NOTE(review): ``prev_mel`` is fed back as ``prompt_feat`` and grown
        # along dim=1, whereas token2wav slices the mel on its last axis and
        # stream_inference transposes before concatenating -- confirm the
        # layouts agree.
        token = token.to(self.device)
        this_uuid = uuid or str(uuid_module.uuid1())

        if prev_mel is not None:
            prompt_speech_feat = prev_mel
        elif prompt_feat is not None:
            prompt_speech_feat = prompt_feat.to(self.device)
        else:
            prompt_speech_feat = torch.zeros(1, 0, 80, device=self.device)

        if prev_token is not None:
            flow_prompt_speech_token = prev_token
        elif prompt_token is not None:
            flow_prompt_speech_token = prompt_token.to(self.device)
        else:
            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32, device=self.device)

        if embedding is not None:
            embedding_tensor = embedding.to(self.device)
        else:
            embedding_tensor = torch.zeros(1, 192, device=self.device)

        tts_speech, tts_mel = self.token2wav(
            token,
            uuid=this_uuid,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=embedding_tensor,
            finalize=is_finalize,
        )

        prev_mel = tts_mel if prev_mel is None else torch.cat([prev_mel, tts_mel], dim=1)
        prev_token = token if prev_token is None else torch.cat([prev_token, token], dim=-1)

        return tts_speech.cpu(), prev_mel, prev_token
|
|
|
|
|
|
|
|
class MossSpeechCodec(PreTrainedModel):
    """MossSpeech codec model (Whisper-VQ encoder + Flow/HiFT decoder).

    Notes
    - API is designed to be compatible with the existing
      `MossSpeechProcessor` usages, while adopting a Transformers-style layout
      similar to HF codec models (`xcodec`, `encodec`).
    - `encode` accepts raw audio tensors or file paths. It returns a Python
      list of codec token ids per input sample for backward-compatibility.
    - `decode` accepts either a LongTensor of shape `(B, 1, T)` / `(B, T)` or
      a nested list of token ids, and returns a dict with a list of waveforms
      under `"syn_wav_list"` (matching current processor expectations).
    """

    config_class = MossSpeechCodecConfig

    def __init__(
        self,
        encoder_weight_path: Union[str, os.PathLike],
        encoder_config_path: Union[str, os.PathLike],
        encoder_feature_extractor_path: Union[str, os.PathLike],
        flow_path: Union[str, os.PathLike],
    ) -> None:
        """Build the Whisper-VQ encoder and the flow/HiFT audio decoder.

        Args:
            encoder_weight_path: safetensors file with the encoder weights.
            encoder_config_path: Path given to ``WhisperVQConfig.from_pretrained``.
            encoder_feature_extractor_path: Path given to
                ``WhisperFeatureExtractor.from_pretrained``.
            flow_path: Directory holding ``config.yaml``, ``flow.pt``,
                ``hift.pt`` and ``campplus.onnx``.
        """
        super().__init__(config=MossSpeechCodecConfig())

        self.sample_rate = 16000
        config = WhisperVQConfig.from_pretrained(str(encoder_config_path))
        self.whisper_vqmodel = WhisperVQEncoder(config)

        # Keep only the "encoder."-prefixed weights and strip the prefix so
        # the keys line up with the bare encoder module.
        prefix = "encoder."
        encoder_state = OrderedDict(
            (key[len(prefix):], value)
            for key, value in load_file(str(encoder_weight_path)).items()
            if key.startswith(prefix)
        )
        self.whisper_vqmodel.load_state_dict(encoder_state, strict=False)

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
            str(encoder_feature_extractor_path)
        )

        self.flow_path = str(flow_path)
        self.audio_decoder = AudioDecoder(
            config_path=os.path.join(self.flow_path, "config.yaml"),
            flow_ckpt_path=os.path.join(self.flow_path, "flow.pt"),
            hift_ckpt_path=os.path.join(self.flow_path, "hift.pt"),
            campplus_model=os.path.join(self.flow_path, "campplus.onnx"),
        ).eval()

    @torch.no_grad()
    def encode(
        self,
        inputs: Union[
            Sequence[Union[str, os.PathLike, Tuple[torch.Tensor, int], torch.Tensor]],
            torch.Tensor,
        ],
        *,
        sampling_rate: Optional[int] = None,
        batch_size: int = 128,
    ) -> List[List[int]]:
        """Encode audio into codec token ids.

        Accepts one of:
        - a list of file paths (a single path is treated as a one-item list)
        - a list of `(waveform, sr)` tuples
        - a list of 1D/2D waveforms (sr assumed 16k)
        - a batched tensor with shape `(B, C, T)` or `(B, T)`

        Args:
            inputs: Audio to encode, in one of the forms above.
            sampling_rate: Sample rate of a batched tensor input; defaults to
                ``self.sample_rate``.  Not applied to list inputs.
            batch_size: Forwarded to ``extract_speech_token``.

        Returns:
            One list of token ids per input sample.

        Raises:
            ValueError: If a tensor input is neither 2-D nor 3-D.
        """
        if isinstance(inputs, torch.Tensor):
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)
            if inputs.dim() != 3:
                raise ValueError("`inputs` must be (B, C, T) when passing a tensor.")
            sr = sampling_rate or self.sample_rate
            items: List[Tuple[torch.Tensor, int]] = [
                (sample.squeeze(0).cpu(), sr) for sample in inputs
            ]
        elif isinstance(inputs, (str, os.PathLike)):
            # A lone path would otherwise be iterated character by character.
            items = [inputs]
        else:
            items = list(inputs)

        audio_tokens: List[List[int]] = extract_speech_token(
            self.whisper_vqmodel, self.feature_extractor, items, batch_size=batch_size
        )
        return audio_tokens

    def _extract_speech_feat(self, speech: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the decoder's features for ``speech`` and their length.

        The extractor output has its leading axis squeezed, the remaining two
        axes swapped and a batch axis of size 1 added back in front; the
        length is the size of dimension 1 as an int32 tensor.
        """
        features = self.audio_decoder.feat_extractor(speech)
        features = features.squeeze(dim=0).transpose(0, 1).unsqueeze(dim=0)
        feature_len = torch.tensor([features.shape[1]], dtype=torch.int32)
        return features, feature_len

    def _extract_spk_embedding(self, speech_16k: torch.Tensor) -> torch.Tensor:
        """Compute a speaker embedding with the CAM++ ONNX session.

        80-bin Kaldi fbank features are mean-normalized over frames and run
        through the ONNX model; the flattened first output is returned as a
        tensor with a leading axis of size 1.
        """
        fbank = kaldi.fbank(speech_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
        fbank = fbank - fbank.mean(dim=0, keepdim=True)
        session = self.audio_decoder.campplus_session
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: fbank.unsqueeze(0).cpu().numpy()})
        return torch.tensor([outputs[0].flatten().tolist()])

    @torch.no_grad()
    def decode(
        self,
        audio_codes: Union[Sequence[Sequence[int]], torch.LongTensor],
        *,
        prompt_speech: Optional[Union[str, os.PathLike]] = None,
        prompt_speech_sample_rate: Optional[int] = None,
        use_spk_embedding: bool = True,
        use_prompt_speech: bool = True,
        finalize: bool = True,
        device: Union[str, torch.device] = torch.device("cuda"),
    ) -> dict:
        """Decode codec token ids back to waveform(s).

        Args
        - audio_codes: `(B, 1, T)` / `(B, T)` tensor or Python nested lists
          per sample.
        - prompt_speech: path to the enrollment audio used for conditioning.
        - prompt_speech_sample_rate: accepted for compatibility; the sample
          rate is read from the audio file and this value is not used.
        - use_spk_embedding: pass the prompt's speaker embedding to the decoder.
        - use_prompt_speech: pass the prompt's tokens and features to the decoder.
        - finalize: forwarded to `AudioDecoder.token2wav`.
        - device: device for the prepared tensors (string or `torch.device`);
          replaced by CPU when it is not a CPU device and CUDA is unavailable.
        Returns
        - {"syn_wav_list": List[Tensor(T)]}
        Raises
        - ValueError: for a tensor of unsupported shape, or when
          `prompt_speech` is missing or does not exist.
        """
        if isinstance(audio_codes, torch.Tensor):
            if audio_codes.dim() == 3 and audio_codes.size(1) == 1:
                codes_list: List[List[int]] = [
                    sample[0].detach().cpu().tolist() for sample in audio_codes
                ]
            elif audio_codes.dim() == 2:
                codes_list = [row.detach().cpu().tolist() for row in audio_codes]
            else:
                raise ValueError("`audio_codes` must be (B, 1, T) or (B, T) when passing a tensor.")
        else:
            codes_list = [list(c) for c in audio_codes]

        if prompt_speech is None or not os.path.exists(str(prompt_speech)):
            raise ValueError("`prompt_speech` path is required for decoding and must exist.")

        prompt_wav, orig_sr = torchaudio.load(str(prompt_speech))
        if prompt_wav.size(0) > 1:
            # Downmix multi-channel prompts: _extract_speech_feat squeezes
            # axis 0, which only works for a single channel.
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
        target_sr = self.audio_decoder.sample_rate
        if orig_sr != target_sr:
            prompt_wav = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)(prompt_wav)

        device = torch.device(device)  # also accepts strings such as "cpu"
        if device.type != "cpu" and not torch.cuda.is_available():
            device = torch.device("cpu")

        speech_token = torch.tensor(self.encode([str(prompt_speech)])[0], device=device).unsqueeze(0)
        speech_feat, _ = self._extract_speech_feat(prompt_wav)

        if target_sr == 24000:
            # Trim so that features and tokens keep an exact 4:1 ratio.
            token_len = min(int(speech_feat.shape[1] / 4), speech_token.shape[1])
            speech_feat = speech_feat[:, : 4 * token_len]
            speech_token = speech_token[:, :token_len]

        prompt_16k = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=16000)(prompt_wav)
        embedding = self._extract_spk_embedding(prompt_16k).to(device)
        speech_feat = speech_feat.to(device)

        syn_wav_list: List[torch.Tensor] = []
        for codes in codes_list:
            codes_t = torch.tensor(codes, device=device).unsqueeze(0)
            uuid = os.urandom(16).hex()

            kwargs = {"uuid": uuid, "finalize": finalize}
            if use_prompt_speech:
                kwargs.update({"prompt_token": speech_token, "prompt_feat": speech_feat})
            if use_spk_embedding:
                kwargs.update({"embedding": embedding})

            try:
                tts_speech, _ = self.audio_decoder.token2wav(codes_t, **kwargs)
            finally:
                # The uuid never leaves this call, so state left behind (with
                # finalize=False, or after an error) could never be reused;
                # drop it instead of letting the caches grow.
                self.audio_decoder.hift_cache_dict.pop(uuid, None)
                self.audio_decoder.mel_overlap_dict.pop(uuid, None)
            syn_wav_list.append(tts_speech.squeeze())

        return {"syn_wav_list": syn_wav_list}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Instantiate codec from a local directory or a Hugging Face Hub repo.

        This mirrors the typical Hugging Face ``from_pretrained`` behavior:
        - If ``pretrained_model_name_or_path`` is a local folder, files are loaded from it.
        - Otherwise, it is treated as a Hub repo ID and downloaded with ``snapshot_download``.
        - ``subfolder``, when given, is appended to the resolved base folder.
        - ``use_auth_token`` is used as ``token`` when ``token`` is not given.
        - Extra ``**kwargs`` are accepted and ignored.

        Expected layout inside the resolved base folder:
        - ``model.safetensors`` (Whisper VQ encoder weights)
        - ``config.json`` (Whisper VQ config)
        - ``preprocessor_config.json`` (WhisperFeatureExtractor params)
        - ``flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}``

        Raises:
            RuntimeError: If a repo id is given but ``huggingface_hub`` cannot
                be imported.
            FileNotFoundError: If any required file other than
                ``campplus.onnx`` is missing (that one only logs a warning).
        """
        path_str = str(pretrained_model_name_or_path)
        if os.path.isdir(path_str):
            base = Path(path_str)
        else:
            try:
                from huggingface_hub import snapshot_download
            except ImportError as exc:
                raise RuntimeError(
                    "huggingface_hub is required to load from a repo id; please `pip install huggingface_hub`."
                ) from exc

            if token is None and use_auth_token is not None:
                token = use_auth_token
            base = Path(
                snapshot_download(
                    repo_id=path_str,
                    revision=revision,
                    cache_dir=str(cache_dir) if cache_dir is not None else None,
                    force_download=force_download,
                    local_files_only=local_files_only,
                    token=token,
                )
            )
        if subfolder:
            base = base / subfolder

        tokenizer_dir = base
        flow_dir = base / "flow"

        required = [
            tokenizer_dir / name
            for name in ("model.safetensors", "config.json", "preprocessor_config.json")
        ]
        required += [flow_dir / name for name in ("config.yaml", "flow.pt", "hift.pt")]
        missing: List[str] = [str(path) for path in required if not path.exists()]
        if missing:
            raise FileNotFoundError(
                "Missing required codec assets under resolved path. The following files were not found: "
                + ", ".join(missing)
            )
        if not (flow_dir / "campplus.onnx").exists():
            logger.warning("campplus.onnx not found under %s; decoding speaker embedding may fail.", flow_dir)

        return cls(
            encoder_weight_path=str(tokenizer_dir / "model.safetensors"),
            encoder_config_path=str(tokenizer_dir / "config.json"),
            encoder_feature_extractor_path=str(tokenizer_dir),
            flow_path=str(flow_dir),
        )
|
|
|