File size: 23,323 Bytes
1a05ac7 0383930 1a05ac7 0383930 1a05ac7 0383930 1a05ac7 0383930 1a05ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 |
# coding=utf-8
# Copyright 2025 OpenMOSS and HuggingFace Inc. teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import uuid as uuid_module
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import onnxruntime
from hyperpyyaml import load_hyperpyyaml
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from safetensors.torch import load_file
from torch import nn
from transformers import PreTrainedModel, WhisperFeatureExtractor
from .configuration_moss_speech_codec import MossSpeechCodecConfig
from .modeling_whisper import WhisperVQEncoder, WhisperVQConfig
from .utils import extract_speech_token
logger = logging.getLogger(__name__)


def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to ``random``, ``numpy`` and ``torch``.

    Raises:
        TypeError: If ``seed`` is not an ``int``.
    """
    if not isinstance(seed, int):
        raise TypeError("Seed must be an integer.")
    logger.info("Setting random seed to %s", seed)
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the CPU generator and every CUDA device, so it has
    # to run regardless of whether CUDA is present.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED only influences interpreters started after this point
    # (e.g. subprocesses); the current process's hash randomization is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
def fade_in_out(fade_in_mel, fade_out_mel, window):
    """Cross-fade the head of ``fade_in_mel`` with the tail of ``fade_out_mel``.

    With ``overlap = len(window) // 2``, the first ``overlap`` frames (last axis)
    of ``fade_in_mel`` are weighted by ``window[:overlap]``. The last ``overlap``
    frames of ``fade_out_mel`` are weighted by ``window[overlap:]``. The sum
    replaces the head of the result; the remaining frames are copied unchanged.

    Args:
        fade_in_mel: Tensor whose leading frames are faded in.
        fade_out_mel: Tensor whose trailing frames are faded out.
        window: 1-D NumPy array or tensor of even length ``2 * overlap``.

    Returns:
        A new tensor on ``fade_in_mel``'s original device. The inputs are not modified.
    """
    device = fade_in_mel.device
    # clone(): ``.cpu()`` returns the very same tensor when it already lives on
    # the CPU, so writing into it would mutate the caller's tensor.
    fade_in_mel = fade_in_mel.cpu().clone()
    fade_out_mel = fade_out_mel.cpu()
    mel_overlap_len = int(window.shape[0] / 2)
    if mel_overlap_len == 0:
        # Nothing to blend; ``[..., -0:]`` would otherwise select every frame.
        return fade_in_mel.to(device)
    window = torch.as_tensor(window).cpu()
    fade_in_mel[..., :mel_overlap_len] = (
        fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len]
        + fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    )
    return fade_in_mel.to(device)
# Module-level placeholders, both initialised to None. Nothing else in this file
# reads them. NOTE(review): possibly kept for external importers — confirm before removing.
tts_speech_prev = tts_mel_prev = None
class AudioDecoder(nn.Module):
    """Token-to-waveform decoder built from a flow model and a HiFT vocoder.

    ``self.flow`` turns speech tokens into mel frames shaped ``(1, 80, T)``. It can
    optionally be conditioned on prompt tokens, prompt mel features shaped
    ``(1, T, 80)`` and a speaker embedding. ``self.hift`` renders the mel frames
    to audio. For chunked decoding, per-request overlap and vocoder-cache state
    is kept in ``mel_overlap_dict`` / ``hift_cache_dict``, keyed by a request ``uuid``.
    """

    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        flow_ckpt_path: Union[str, os.PathLike],
        hift_ckpt_path: Union[str, os.PathLike],
        campplus_model: Union[str, os.PathLike],
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """Build the decoder stack from a hyperpyyaml config and checkpoints.

        Args:
            config_path: YAML file providing ``flow``, ``hift``, ``sample_rate``
                and ``feat_extractor`` entries.
            flow_ckpt_path: State dict for the flow model (loaded with ``strict=False``).
            hift_ckpt_path: State dict for the HiFT vocoder.
            campplus_model: ONNX model, run on CPU through onnxruntime.
            device: Device the flow and HiFT models are placed on.
        """
        super().__init__()
        self.device = torch.device(device) if isinstance(device, str) else device

        with open(config_path, "r", encoding="utf-8") as config_file:
            logger.info("Loading decoder configurations from %s", config_path)
            self.scratch_configs = load_hyperpyyaml(config_file)

        # Load models. NOTE(review): torch.load unpickles the checkpoints, so only
        # point these paths at trusted files.
        self.flow = self.scratch_configs["flow"]
        self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device), strict=False)
        # Inference-only: disable training-mode behaviour (e.g. dropout), as for HiFT,
        # even when this class is used on its own.
        self.flow = self.flow.eval()
        self.hift = self.scratch_configs["hift"]
        self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
        self.hift = self.hift.eval()
        self.sample_rate = self.scratch_configs["sample_rate"]
        self.feat_extractor = self.scratch_configs["feat_extractor"]

        # Move models to the appropriate device
        self.flow.to(self.device)
        self.hift.to(self.device)

        # Per-request streaming state; ``None`` means "no state yet".
        self.mel_overlap_dict = defaultdict(lambda: None)
        self.hift_cache_dict = defaultdict(lambda: None)

        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 3.5
        # Mel frames held back at the end of a non-final chunk and cross-faded
        # into the start of the next chunk.
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 24000 / (480 * 2))
        self.mel_window = np.hamming(2 * self.mel_overlap_len)

        # hift cache: trailing mel frames / source samples carried to the next chunk
        self.mel_cache_len = 1
        self.source_cache_len = int(self.mel_cache_len * 480)

        # campplus ONNX session (CPU only, single intra-op thread).
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(
            str(campplus_model),
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )

        # speech fade in out
        self.speech_window = np.hamming(2 * self.source_cache_len)

    def token2wav(
        self,
        token: torch.Tensor,
        uuid: str,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        finalize: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode one chunk of speech tokens for request ``uuid``.

        Args:
            token: Speech tokens; ``token.shape[1]`` is the token count.
            uuid: Key of this request's state in ``mel_overlap_dict`` / ``hift_cache_dict``.
            prompt_token: Prompt tokens; empty ``(1, 0)`` int32 when ``None``.
            prompt_feat: Prompt mel features ``(1, T, 80)``; empty when ``None``.
            embedding: Speaker embedding; zeros ``(1, 192)`` when ``None``.
            finalize: If ``True``, the whole mel is vocoded and the request's state is
                discarded. Otherwise the trailing ``mel_overlap_len`` mel frames are
                held back for cross-fading with the next chunk, and the last
                ``source_cache_len`` audio samples are cached and trimmed.

        Returns:
            ``(tts_speech, tts_mel)``: the HiFT waveform and the mel frames ``(1, 80, T)``
            that were vocoded (including any cached frame prepended from the previous chunk).
        """
        prompt_token = prompt_token if prompt_token is not None else torch.zeros(1, 0, dtype=torch.int32)
        prompt_feat = prompt_feat if prompt_feat is not None else torch.zeros(1, 0, 80)
        embedding = embedding if embedding is not None else torch.zeros(1, 192)

        tts_mel = self.flow.inference(
            token=token.to(self.device),
            token_len=torch.tensor([token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_token=prompt_token.to(self.device),
            prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_feat=prompt_feat.to(self.device),
            prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32, device=self.device),
            embedding=embedding.to(self.device),
            streaming=False,
            finalize=finalize,
        )
        # NOTE(review): flow.inference appears to return a sequence whose first item is the mel.
        tts_mel = tts_mel[0]

        # Blend with the overlap frames held back by this request's previous chunk.
        if self.mel_overlap_dict[uuid] is not None:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)

        # append hift cache
        if self.hift_cache_dict[uuid] is not None:
            hift_cache_mel = self.hift_cache_dict[uuid]["mel"]
            hift_cache_source = self.hift_cache_dict[uuid]["source"]
            tts_mel = torch.cat([hift_cache_mel, tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)

        if not finalize:
            # keep overlap mel and hift cache for the next chunk
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            self.hift_cache_dict[uuid] = {
                "mel": tts_mel[:, :, -self.mel_cache_len:],
                "source": tts_source[:, :, -self.source_cache_len:],
                "speech": tts_speech[:, -self.source_cache_len:],
            }
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            # Request finished: drop its streaming state.
            self.hift_cache_dict.pop(uuid, None)
            self.mel_overlap_dict.pop(uuid, None)
        return tts_speech, tts_mel

    def offline_inference(self, token: torch.Tensor) -> torch.Tensor:
        """Decode ``token`` in a single finalized pass (no prompt, zero embedding).

        Returns:
            The waveform on the CPU.
        """
        tts_speech, _ = self.token2wav(token, uuid=str(uuid_module.uuid1()), finalize=True)
        return tts_speech.cpu()

    def stream_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        block_size: int = 8,
    ) -> torch.Tensor:
        """Decode ``token`` in blocks of ``block_size`` tokens and concatenate the audio.

        Each block is conditioned on the prompt plus every token and mel frame
        produced by earlier blocks. Only the last block is finalized.

        Returns:
            The concatenated waveform on the CPU (shape ``(1, 0)`` for empty input).
        """
        token = token.to(self.device)
        this_uuid = str(uuid_module.uuid1())
        base_prompt_tensor = (
            prompt_token.to(self.device)
            if prompt_token is not None
            else torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        )
        base_prompt_feat = (
            prompt_feat.to(self.device)
            if prompt_feat is not None
            else torch.zeros(1, 0, 80, device=self.device)
        )
        embedding = embedding.to(self.device) if embedding is not None else torch.zeros(1, 192, device=self.device)

        num_tokens = token.size(1)
        if num_tokens == 0:
            # Nothing to synthesize; torch.cat([]) below would raise.
            return torch.zeros(1, 0)

        tts_speechs: List[torch.Tensor] = []
        tts_mels: List[torch.Tensor] = []
        for idx in range(0, num_tokens, block_size):
            tts_token = token[:, idx : idx + block_size]
            if tts_mels:
                # Mels are (1, 80, T) while prompt features are (1, T, 80): join on the
                # time axis in mel layout, then transpose back.
                prompt_feat_current = torch.cat(
                    [base_prompt_feat.transpose(1, 2)] + tts_mels,
                    dim=-1,
                ).transpose(1, 2)
                prompt_tensor_current = torch.cat([base_prompt_tensor, token[:, :idx]], dim=-1)
            else:
                prompt_feat_current = base_prompt_feat
                prompt_tensor_current = base_prompt_tensor
            is_finalize = idx + block_size >= token.size(-1)
            tts_speech, tts_mel = self.token2wav(
                tts_token,
                uuid=this_uuid,
                prompt_token=prompt_tensor_current,
                prompt_feat=prompt_feat_current,
                embedding=embedding,
                finalize=is_finalize,
            )
            tts_speechs.append(tts_speech)
            tts_mels.append(tts_mel)
        return torch.cat(tts_speechs, dim=-1).cpu()

    def streaming_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        uuid: Optional[str] = None,
        prev_mel: Optional[torch.Tensor] = None,
        prev_token: Optional[torch.Tensor] = None,
        is_finalize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Decode one chunk, conditioning on the history returned by earlier calls.

        When given, ``prev_mel`` / ``prev_token`` replace ``prompt_feat`` /
        ``prompt_token`` as the flow prompt.

        Returns:
            ``(speech, prev_mel, prev_token)``:
            - ``speech`` is the CPU waveform for this chunk.
            - ``prev_mel`` is the accumulated mel history in prompt layout ``(1, T, 80)``.
            - ``prev_token`` is the accumulated tokens.
            Pass the last two back unchanged on the next call.
        """
        token = token.to(self.device)
        this_uuid = uuid or str(uuid_module.uuid1())
        prompt_speech_feat = (
            prompt_feat.to(self.device)
            if prompt_feat is not None
            else torch.zeros(1, 0, 80, device=self.device)
        )
        flow_prompt_speech_token = (
            prompt_token.to(self.device)
            if prompt_token is not None
            else torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        )
        embedding_tensor = (
            embedding.to(self.device)
            if embedding is not None
            else torch.zeros(1, 192, device=self.device)
        )
        if prev_mel is not None:
            prompt_speech_feat = prev_mel
        if prev_token is not None:
            flow_prompt_speech_token = prev_token

        tts_speech, tts_mel = self.token2wav(
            token,
            uuid=this_uuid,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=embedding_tensor,
            finalize=is_finalize,
        )

        # token2wav returns mel as (1, 80, T); the flow prompt expects (1, T, 80).
        # Accumulate in prompt layout so prev_mel can be fed back as prompt_feat.
        new_mel = tts_mel.transpose(1, 2)
        prev_mel = new_mel if prev_mel is None else torch.cat([prev_mel, new_mel], dim=1)
        prev_token = token if prev_token is None else torch.cat([prev_token, token], dim=-1)
        return tts_speech.cpu(), prev_mel, prev_token
class MossSpeechCodec(PreTrainedModel):
    """MossSpeech codec model (Whisper-VQ encoder + Flow/HiFT decoder).

    Notes
    - API is designed to be compatible with the existing
      `MossSpeechProcessor` usages, while adopting a Transformers-style layout
      similar to HF codec models (`xcodec`, `encodec`).
    - `encode` accepts raw audio tensors or file paths. It returns a Python
      list of codec token ids per input sample for backward-compatibility.
    - `decode` accepts a tensor `(B, 1, T)` / `(B, T)` or a nested list of
      token ids, and returns a dict with a list of waveforms under
      `"syn_wav_list"` (matching current processor expectations).
    """

    config_class = MossSpeechCodecConfig

    def __init__(
        self,
        encoder_weight_path: Union[str, os.PathLike],
        encoder_config_path: Union[str, os.PathLike],
        encoder_feature_extractor_path: Union[str, os.PathLike],
        flow_path: Union[str, os.PathLike],
    ) -> None:
        """Load the Whisper-VQ encoder and the flow/HiFT decoder from local paths.

        Args:
            encoder_weight_path: safetensors file; only keys prefixed ``"encoder."`` are used.
            encoder_config_path: Location of the Whisper-VQ config.
            encoder_feature_extractor_path: Directory for ``WhisperFeatureExtractor``.
            flow_path: Directory holding ``config.yaml``, ``flow.pt``, ``hift.pt`` and
                ``campplus.onnx``.
        """
        super().__init__(config=MossSpeechCodecConfig())

        # Whisper-VQ encoder; `encode` assumes this rate for tensors given without one.
        self.sample_rate = 16000
        config = WhisperVQConfig.from_pretrained(str(encoder_config_path))
        self.whisper_vqmodel = WhisperVQEncoder(config)
        prefix = "encoder."
        state_dict = load_file(str(encoder_weight_path))
        # Keep only the encoder weights, with the prefix stripped.
        encoder_state: OrderedDict[str, torch.Tensor] = OrderedDict(
            (key[len(prefix):], value) for key, value in state_dict.items() if key.startswith(prefix)
        )
        self.whisper_vqmodel.load_state_dict(encoder_state, strict=False)
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
            str(encoder_feature_extractor_path)
        )

        # Flow / HiFT decoder stack. AudioDecoder defaults to "cuda", so choose the
        # device explicitly to keep CPU-only hosts working.
        self.flow_path = str(flow_path)
        decoder_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.audio_decoder = AudioDecoder(
            config_path=os.path.join(self.flow_path, "config.yaml"),
            flow_ckpt_path=os.path.join(self.flow_path, "flow.pt"),
            hift_ckpt_path=os.path.join(self.flow_path, "hift.pt"),
            campplus_model=os.path.join(self.flow_path, "campplus.onnx"),
            device=decoder_device,
        ).eval()

    @torch.no_grad()
    def encode(
        self,
        inputs: Union[
            Sequence[Union[str, os.PathLike, Tuple[torch.Tensor, int], torch.Tensor]],
            torch.Tensor,
        ],
        *,
        sampling_rate: Optional[int] = None,
        batch_size: int = 128,
    ) -> List[List[int]]:
        """Encode audio into codec token ids.

        Accepts one of:
        - a list of file paths
        - a list of `(waveform, sr)` tuples
        - a list of 1D/2D waveforms (sr assumed 16k)
        - a batched tensor with shape `(B, C, T)` or `(B, T)`

        Raises:
            ValueError: If a tensor input is not 2-D or 3-D.
        """
        if isinstance(inputs, torch.Tensor):
            batch = inputs.unsqueeze(1) if inputs.dim() == 2 else inputs  # (B, T) -> (B, 1, T)
            if batch.dim() != 3:
                raise ValueError("`inputs` must be (B, C, T) when passing a tensor.")
            sr = sampling_rate or self.sample_rate
            items: List[Tuple[torch.Tensor, int]] = [(waveform.squeeze(0).cpu(), sr) for waveform in batch]
        else:
            items = list(inputs)  # type: ignore[assignment]
        # The shared utility handles file paths, (waveform, sr) tuples and tensors.
        return extract_speech_token(self.whisper_vqmodel, self.feature_extractor, items, batch_size=batch_size)

    def _extract_speech_feat(self, speech: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return decoder mel features ``(1, T, C)`` for a single-channel waveform, plus their length."""
        speech_feat = self.audio_decoder.feat_extractor(speech).squeeze(dim=0).transpose(0, 1)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32)
        return speech_feat, speech_feat_len

    def _extract_spk_embedding(self, speech_16k: torch.Tensor) -> torch.Tensor:
        """Return a ``(1, D)`` speaker embedding from the campplus ONNX model for 16 kHz audio."""
        feat = kaldi.fbank(speech_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)  # per-utterance mean normalisation
        session = self.audio_decoder.campplus_session
        embedding = session.run(
            None,
            {session.get_inputs()[0].name: feat.unsqueeze(0).cpu().numpy()},
        )[0].flatten().tolist()
        return torch.tensor([embedding])

    @torch.no_grad()
    def decode(
        self,
        audio_codes: Union[Sequence[Sequence[int]], torch.LongTensor],
        *,
        prompt_speech: Optional[Union[str, os.PathLike]] = None,
        prompt_speech_sample_rate: Optional[int] = None,
        use_spk_embedding: bool = True,
        use_prompt_speech: bool = True,
        finalize: bool = True,
        device: torch.device = torch.device("cuda"),
    ) -> dict:
        """Decode codec token ids back to waveform(s).

        Args
        - audio_codes: `(B, 1, T)` / `(B, T)` tensor, or Python nested lists per sample.
        - prompt_speech: path to the enrollment audio used for conditioning (required).
        - prompt_speech_sample_rate: currently unused; the rate is read from the file.
        - use_spk_embedding: condition on the prompt's speaker embedding.
        - use_prompt_speech: condition on the prompt's tokens and mel features.
        - finalize: forwarded to `AudioDecoder.token2wav`.
        - device: device for the conditioning tensors; CPU is used when CUDA is unavailable.

        Returns
        - {"syn_wav_list": List[Tensor]} with one waveform per input sample.

        Raises
        - ValueError: `audio_codes` tensor has the wrong rank, or `prompt_speech` is missing.
        """
        if isinstance(audio_codes, torch.Tensor):
            if audio_codes.dim() == 3 and audio_codes.size(1) == 1:
                codes_list: List[List[int]] = [row[0].detach().cpu().tolist() for row in audio_codes]
            elif audio_codes.dim() == 2:
                codes_list = [row.detach().cpu().tolist() for row in audio_codes]
            else:
                raise ValueError("`audio_codes` must be (B, 1, T) or (B, T) when passing a tensor.")
        else:
            codes_list = [list(c) for c in audio_codes]

        if prompt_speech is None or not os.path.exists(str(prompt_speech)):
            raise ValueError("`prompt_speech` path is required for decoding and must exist.")
        prompt_wav, orig_sr = torchaudio.load(str(prompt_speech))
        # The feature path below squeezes a leading dim of size 1, so downmix to mono.
        if prompt_wav.size(0) > 1:
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
        target_sr = self.audio_decoder.sample_rate
        if orig_sr != target_sr:
            prompt_wav = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)(prompt_wav)

        if isinstance(device, str):
            device = torch.device(device)
        if not torch.cuda.is_available() and device.type != "cpu":
            device = torch.device("cpu")

        speech_token = torch.tensor(
            self.encode([str(prompt_speech)])[0], dtype=torch.long, device=device
        ).unsqueeze(0)
        speech_feat, _ = self._extract_speech_feat(prompt_wav)
        if target_sr == 24000:
            # Trim so the prompt has exactly 4 mel frames per prompt token.
            token_len = min(speech_feat.shape[1] // 4, speech_token.shape[1])
            speech_feat = speech_feat[:, : 4 * token_len]
            speech_token = speech_token[:, :token_len]

        prompt_16k = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=16000)(prompt_wav)
        embedding = self._extract_spk_embedding(prompt_16k).to(device)
        speech_feat = speech_feat.to(device)

        syn_wav_list: List[torch.Tensor] = []
        for codes in codes_list:
            codes_t = torch.tensor(codes, dtype=torch.long, device=device).unsqueeze(0)
            kwargs = {"uuid": os.urandom(16).hex(), "finalize": finalize}
            if use_prompt_speech:
                kwargs.update({"prompt_token": speech_token, "prompt_feat": speech_feat})
            if use_spk_embedding:
                kwargs.update({"embedding": embedding})
            tts_speech, _ = self.audio_decoder.token2wav(codes_t, **kwargs)
            syn_wav_list.append(tts_speech.squeeze())
        return {"syn_wav_list": syn_wav_list}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,  # back-compat with HF Transformers kwarg
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Instantiate codec from a local directory or a Hugging Face Hub repo.

        This mirrors the typical Hugging Face ``from_pretrained`` behavior:
        - If ``pretrained_model_name_or_path`` is a local folder, files are loaded from it.
        - Otherwise, it is treated as a Hub repo ID and downloaded with ``snapshot_download``.

        Expected layout inside the resolved base folder:
        - ``model.safetensors`` (Whisper VQ encoder weights)
        - ``config.json`` (Whisper VQ config)
        - ``preprocessor_config.json`` (WhisperFeatureExtractor params)
        - ``flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}``

        Extra ``**kwargs`` are accepted for signature compatibility and ignored.

        Raises:
            RuntimeError: A repo id was given but ``huggingface_hub`` is not installed.
            FileNotFoundError: Any expected asset is missing.
        """
        path_str = str(pretrained_model_name_or_path)
        if os.path.isdir(path_str):
            base = Path(path_str)
        else:
            try:
                from huggingface_hub import snapshot_download  # lazy import to avoid hard dependency at import time
            except Exception as exc:  # pragma: no cover
                raise RuntimeError(
                    "huggingface_hub is required to load from a repo id; please `pip install huggingface_hub`."
                ) from exc
            # HF Transformers historically supports both `token` and deprecated `use_auth_token`.
            if token is None and use_auth_token is not None:
                token = use_auth_token
            snapshot_path = snapshot_download(
                repo_id=path_str,
                revision=revision,
                cache_dir=str(cache_dir) if cache_dir is not None else None,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
            )
            base = Path(snapshot_path)
        if subfolder:
            base = base / subfolder

        tokenizer_dir = base
        flow_dir = base / "flow"

        # campplus.onnx is loaded unconditionally by AudioDecoder, so it is required too.
        required = [
            tokenizer_dir / "model.safetensors",
            tokenizer_dir / "config.json",
            tokenizer_dir / "preprocessor_config.json",
            *(flow_dir / fname for fname in ("config.yaml", "flow.pt", "hift.pt", "campplus.onnx")),
        ]
        missing = [str(path) for path in required if not path.exists()]
        if missing:
            raise FileNotFoundError(
                "Missing required codec assets under resolved path. The following files were not found: "
                + ", ".join(missing)
            )

        return cls(
            encoder_weight_path=str(tokenizer_dir / "model.safetensors"),
            encoder_config_path=str(tokenizer_dir / "config.json"),
            encoder_feature_extractor_path=str(tokenizer_dir),
            flow_path=str(flow_dir),
        )
|