# MOSS-Speech-Codec / modeling_moss_speech_codec.py
# Last change by Phospheneser: "Update modeling_moss_speech_codec.py" (commit f11796a, verified)
# coding=utf-8
# Copyright 2025 OpenMOSS and HuggingFace Inc. teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import random
import uuid as uuid_module
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import onnxruntime
from hyperpyyaml import load_hyperpyyaml
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from safetensors.torch import load_file
from torch import nn
from transformers import PreTrainedModel, WhisperFeatureExtractor
from .configuration_moss_speech_codec import MossSpeechCodecConfig
from .modeling_whisper import WhisperVQEncoder, WhisperVQConfig
from .utils import extract_speech_token
logger = logging.getLogger(__name__)


def set_seed(seed: int) -> None:
    """Seed every RNG used by this module for reproducible runs.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU generator.
    When CUDA is available it also seeds all CUDA devices and switches cuDNN to
    deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed value.

    Raises:
        TypeError: If ``seed`` is not an ``int``.
    """
    if not isinstance(seed, int):
        raise TypeError("Seed must be an integer.")
    logger.info("Setting random seed to %s", seed)
    random.seed(seed)
    np.random.seed(seed)
    # Always seed the CPU generator. Previously it was only seeded when CUDA
    # was unavailable, which left CPU-side torch randomness unseeded on GPU hosts.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED only affects interpreters started after this point
    # (e.g. subprocesses); it does not change hashing in the running process.
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
def fade_in_out(fade_in_mel, fade_out_mel, window):
    """Cross-fade the head of ``fade_in_mel`` with the tail of ``fade_out_mel``.

    With ``n = len(window) // 2``, the first ``n`` frames (last dimension) of
    ``fade_in_mel`` are replaced by
    ``fade_in_mel[..., :n] * window[:n] + fade_out_mel[..., -n:] * window[n:]``.

    Args:
        fade_in_mel: Tensor whose leading frames are faded in.
        fade_out_mel: Tensor whose trailing frames are faded out.
        window: 1-D window (NumPy array or tensor) of even length, e.g.
            ``np.hamming(2 * n)``.

    Returns:
        A new tensor placed on ``fade_in_mel``'s original device. Neither input
        is modified.
    """
    device = fade_in_mel.device
    mel_overlap_len = int(window.shape[0] / 2)
    # ``.cpu()`` returns the very same object for CPU tensors; clone so the
    # in-place slice assignment below never mutates the caller's tensor.
    faded = fade_in_mel.cpu().clone()
    if mel_overlap_len == 0:
        # ``[..., -0:]`` would select the whole tensor; there is nothing to blend.
        return faded.to(device)
    fade_out = fade_out_mel.cpu()
    win = torch.as_tensor(window, device="cpu")
    faded[..., :mel_overlap_len] = (
        faded[..., :mel_overlap_len] * win[:mel_overlap_len]
        + fade_out[..., -mel_overlap_len:] * win[mel_overlap_len:]
    )
    return faded.to(device)
# Module-level placeholders initialised to None.
# NOTE(review): nothing in this module reads or writes them; confirm no other
# module imports them before removing.
tts_speech_prev = None
tts_mel_prev = None
class AudioDecoder(nn.Module):
    """Flow + HiFT decoder stack that converts speech tokens into waveforms.

    ``flow``, ``hift``, ``feat_extractor`` and ``sample_rate`` are built from a
    HyperPyYAML config. The flow and HiFT weights are loaded from separate
    checkpoints, and an ONNX Runtime session (CPU provider) is opened for the
    ``campplus`` model and exposed as ``campplus_session``.

    Chunked synthesis keeps per-stream state (mel overlap and HiFT cache) in
    ``mel_overlap_dict`` / ``hift_cache_dict``, keyed by a stream ``uuid``. A
    ``token2wav`` call with ``finalize=True`` removes that stream's entries.
    """

    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        flow_ckpt_path: Union[str, os.PathLike],
        hift_ckpt_path: Union[str, os.PathLike],
        campplus_model: Union[str, os.PathLike],
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """Load the decoder config, checkpoints and the campplus ONNX session.

        Args:
            config_path: HyperPyYAML file providing ``flow``, ``hift``,
                ``sample_rate`` and ``feat_extractor``.
            flow_ckpt_path: Flow state dict (loaded with ``strict=False``).
            hift_ckpt_path: HiFT state dict (loaded strictly).
            campplus_model: Path to the campplus ONNX model.
            device: Target device. A CUDA device falls back to CPU, with a
                warning, when CUDA is unavailable.
        """
        super().__init__()
        device = torch.device(device) if isinstance(device, str) else device
        if device.type == "cuda" and not torch.cuda.is_available():
            # Same fallback policy as ``MossSpeechCodec.decode``; without it,
            # torch.load(map_location="cuda") fails on CPU-only hosts.
            logger.warning("CUDA is not available; AudioDecoder falls back to CPU.")
            device = torch.device("cpu")
        self.device = device
        with open(config_path, "r", encoding="utf-8") as config_file:
            logger.info("Loading decoder configurations from %s", config_path)
            self.scratch_configs = load_hyperpyyaml(config_file)
        # Load models.
        # NOTE(review): torch.load unpickles arbitrary objects; consider
        # weights_only=True once the checkpoints are confirmed tensor-only.
        self.flow = self.scratch_configs["flow"]
        self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device), strict=False)
        self.hift = self.scratch_configs["hift"]
        self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
        self.hift = self.hift.eval()
        self.sample_rate = self.scratch_configs["sample_rate"]
        self.feat_extractor = self.scratch_configs["feat_extractor"]
        # Move models to the appropriate device.
        self.flow.to(self.device)
        self.hift.to(self.device)
        # Per-stream state keyed by uuid; unseen keys read as None.
        self.mel_overlap_dict = defaultdict(lambda: None)
        self.hift_cache_dict = defaultdict(lambda: None)
        self.token_min_hop_len = 2 * self.flow.input_frame_rate
        self.token_max_hop_len = 4 * self.flow.input_frame_rate
        self.token_overlap_len = 3.5
        # Number of mel frames cross-faded between consecutive chunks.
        # NOTE(review): 24000 / 480 presumably are the output rate / hop size; confirm.
        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 24000 / (480 * 2))
        self.mel_window = np.hamming(2 * self.mel_overlap_len)
        # HiFT cache: trailing mel frames / source samples carried between chunks.
        self.mel_cache_len = 1
        self.source_cache_len = int(self.mel_cache_len * 480)
        # campplus ONNX session (CPU only).
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(
            str(campplus_model),
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )
        # NOTE(review): speech_window is not used by this class's methods.
        self.speech_window = np.hamming(2 * self.source_cache_len)

    def token2wav(
        self,
        token: torch.Tensor,
        uuid: str,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        finalize: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Synthesize one chunk of speech for stream ``uuid``.

        Args:
            token: Speech tokens; ``token.shape[1]`` is the length (presumably ``(1, T)``).
            uuid: Stream id keying the cross-chunk caches.
            prompt_token: Prompt tokens; empty ``(1, 0)`` int32 when omitted.
            prompt_feat: Prompt features, time on dim 1; empty ``(1, 0, 80)`` when omitted.
            embedding: Speaker embedding; zeros ``(1, 192)`` when omitted.
            finalize: If False, the trailing mel overlap and HiFT cache are held
                back for the next chunk. If True, the whole chunk is vocoded and
                the stream's cached state is removed.

        Returns:
            ``(speech, mel)``: the vocoded waveform chunk and the mel (time on
            dim 2) that was fed to HiFT, including any cached leading frame.
        """
        if prompt_token is None:
            prompt_token = torch.zeros(1, 0, dtype=torch.int32)
        if prompt_feat is None:
            prompt_feat = torch.zeros(1, 0, 80)
        if embedding is None:
            embedding = torch.zeros(1, 192)
        tts_mel = self.flow.inference(
            token=token.to(self.device),
            token_len=torch.tensor([token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_token=prompt_token.to(self.device),
            prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32, device=self.device),
            prompt_feat=prompt_feat.to(self.device),
            prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32, device=self.device),
            embedding=embedding.to(self.device),
            streaming=False,
            finalize=finalize,
        )
        # flow.inference returns a sequence whose first item is the mel
        # (presumably ``(mel, cache)``; confirm against the flow module).
        tts_mel = tts_mel[0]
        # Cross-fade with the overlap held back from the previous chunk.
        if self.mel_overlap_dict[uuid] is not None:
            tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
        # Prepend cached mel frame(s) so HiFT continues from the previous chunk.
        hift_cache = self.hift_cache_dict[uuid]
        if hift_cache is not None:
            tts_mel = torch.cat([hift_cache["mel"], tts_mel], dim=2)
            hift_cache_source = hift_cache["source"]
        else:
            hift_cache_source = torch.zeros(1, 1, 0, device=self.device)
        if not finalize:
            # Keep the tail for the next chunk's cross-fade and HiFT continuation.
            self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
            tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
            tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            self.hift_cache_dict[uuid] = {
                "mel": tts_mel[:, :, -self.mel_cache_len:],
                "source": tts_source[:, :, -self.source_cache_len:],
                "speech": tts_speech[:, -self.source_cache_len:],
            }
            tts_speech = tts_speech[:, :-self.source_cache_len]
        else:
            tts_speech, _ = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            # Stream finished: drop its state (the defaultdict lookups above created the keys).
            self.hift_cache_dict.pop(uuid, None)
            self.mel_overlap_dict.pop(uuid, None)
        return tts_speech, tts_mel

    def offline_inference(self, token: torch.Tensor) -> torch.Tensor:
        """Vocode ``token`` in one finalized pass without prompt or speaker conditioning.

        Returns:
            The waveform on CPU.
        """
        this_uuid = str(uuid_module.uuid1())
        tts_speech, _ = self.token2wav(token, uuid=this_uuid, finalize=True)
        return tts_speech.cpu()

    def stream_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        block_size: int = 8,
    ) -> torch.Tensor:
        """Vocode ``token`` in ``block_size`` chunks and return the joined waveform.

        Each chunk is conditioned on the prompt plus all previously generated
        tokens and mels. The last chunk is finalized.

        Args:
            token: Speech tokens; chunking runs over dim 1.
            prompt_token: Optional prompt tokens (empty when omitted).
            prompt_feat: Optional prompt features, time on dim 1 (empty when omitted).
            embedding: Optional speaker embedding (zeros when omitted).
            block_size: Tokens per chunk; must be positive.

        Returns:
            The concatenated waveform on CPU. Returns an empty ``(1, 0)`` tensor
            when ``token`` has no tokens.

        Raises:
            ValueError: If ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError("`block_size` must be a positive integer.")
        token = token.to(self.device)
        if token.size(1) == 0:
            # Nothing to synthesize; torch.cat on an empty list would raise.
            return torch.zeros(1, 0)
        this_uuid = str(uuid_module.uuid1())
        base_prompt_tensor = (
            prompt_token.to(self.device)
            if prompt_token is not None
            else torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        )
        base_prompt_feat = (
            prompt_feat.to(self.device)
            if prompt_feat is not None
            else torch.zeros(1, 0, 80, device=self.device)
        )
        embedding = embedding.to(self.device) if embedding is not None else torch.zeros(1, 192, device=self.device)
        tts_speechs: List[torch.Tensor] = []
        tts_mels: List[torch.Tensor] = []
        for idx in range(0, token.size(1), block_size):
            tts_token = token[:, idx : idx + block_size]
            if tts_mels:
                # Condition on the prompt plus everything generated so far.
                # Mels carry time on dim 2, prompt features on dim 1.
                prompt_feat_current = torch.cat(
                    [base_prompt_feat.transpose(1, 2)] + tts_mels,
                    dim=-1,
                ).transpose(1, 2)
                prompt_tensor_current = torch.cat([base_prompt_tensor, token[:, :idx]], dim=-1)
            else:
                prompt_feat_current = base_prompt_feat
                prompt_tensor_current = base_prompt_tensor
            is_finalize = idx + block_size >= token.size(-1)
            tts_speech, tts_mel = self.token2wav(
                tts_token,
                uuid=this_uuid,
                prompt_token=prompt_tensor_current,
                prompt_feat=prompt_feat_current,
                embedding=embedding,
                finalize=is_finalize,
            )
            tts_speechs.append(tts_speech)
            tts_mels.append(tts_mel)
        return torch.cat(tts_speechs, dim=-1).cpu()

    def streaming_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        uuid: Optional[str] = None,
        prev_mel: Optional[torch.Tensor] = None,
        prev_token: Optional[torch.Tensor] = None,
        is_finalize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Vocode one chunk of a caller-driven stream.

        Pass the returned ``prev_mel`` / ``prev_token`` back, with the same
        ``uuid``, on the next call. When they are given, they replace
        ``prompt_feat`` / ``prompt_token`` as conditioning.

        Args:
            token: Tokens for this chunk.
            prompt_token: Prompt tokens, used only while ``prev_token`` is None.
            prompt_feat: Prompt features (time on dim 1), used only while
                ``prev_mel`` is None.
            embedding: Optional speaker embedding (zeros when omitted).
            uuid: Stream id. A fresh one is generated when omitted, so
                cross-chunk caches are only shared if the caller passes it.
            prev_mel: Accumulated mel history from earlier calls, in the
                prompt-feature layout (time on dim 1).
            prev_token: Accumulated token history from earlier calls.
            is_finalize: Forwarded to ``token2wav`` as ``finalize``.

        Returns:
            ``(speech_cpu, prev_mel, prev_token)`` with this chunk appended to
            both histories.
        """
        token = token.to(self.device)
        this_uuid = uuid or str(uuid_module.uuid1())
        if prev_mel is not None:
            prev_mel = prev_mel.to(self.device)
        if prev_token is not None:
            prev_token = prev_token.to(self.device)
        if prev_mel is not None:
            prompt_speech_feat = prev_mel
        elif prompt_feat is not None:
            prompt_speech_feat = prompt_feat.to(self.device)
        else:
            prompt_speech_feat = torch.zeros(1, 0, 80, device=self.device)
        if prev_token is not None:
            flow_prompt_speech_token = prev_token
        elif prompt_token is not None:
            flow_prompt_speech_token = prompt_token.to(self.device)
        else:
            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        embedding_tensor = (
            embedding.to(self.device)
            if embedding is not None
            else torch.zeros(1, 192, device=self.device)
        )
        tts_speech, tts_mel = self.token2wav(
            token,
            uuid=this_uuid,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=embedding_tensor,
            finalize=is_finalize,
        )
        # token2wav returns mel with time on dim 2, but prompt features are
        # consumed with time on dim 1 (cf. the ``(1, 0, 80)`` default). Keep the
        # history in the prompt layout so it can be fed straight back in. The
        # previous code mixed layouts and concatenated on the wrong axis.
        new_feat = tts_mel.transpose(1, 2)
        prev_mel = new_feat if prev_mel is None else torch.cat([prev_mel, new_feat], dim=1)
        prev_token = token if prev_token is None else torch.cat([prev_token, token], dim=-1)
        return tts_speech.cpu(), prev_mel, prev_token
class MossSpeechCodec(PreTrainedModel):
    """MossSpeech codec model (Whisper-VQ encoder + Flow/HiFT decoder).

    Notes
    - The API is designed to stay compatible with existing
      `MossSpeechProcessor` usage, while adopting a Transformers-style layout
      similar to HF codec models (`xcodec`, `encodec`).
    - `encode` accepts raw audio tensors or file paths. For backward
      compatibility it returns a Python list of codec token ids per input sample.
    - `decode` accepts a LongTensor `(B, 1, T)` or `(B, T)`, or a nested list of
      token ids. It returns a dict with a list of waveforms under
      `"syn_wav_list"` (matching current processor expectations).
    """

    config_class = MossSpeechCodecConfig

    def __init__(
        self,
        encoder_weight_path: Union[str, os.PathLike],
        encoder_config_path: Union[str, os.PathLike],
        encoder_feature_extractor_path: Union[str, os.PathLike],
        flow_path: Union[str, os.PathLike],
    ) -> None:
        """Build the Whisper-VQ encoder and the flow/HiFT decoder from local paths.

        Args:
            encoder_weight_path: safetensors file. Only keys prefixed with
                ``"encoder."`` are loaded (prefix stripped, ``strict=False``).
            encoder_config_path: Path given to ``WhisperVQConfig.from_pretrained``.
            encoder_feature_extractor_path: Path given to
                ``WhisperFeatureExtractor.from_pretrained``.
            flow_path: Directory holding ``config.yaml``, ``flow.pt``,
                ``hift.pt`` and ``campplus.onnx``.
        """
        super().__init__(config=MossSpeechCodecConfig())
        # Whisper-VQ encoder
        self.sample_rate = 16000
        config = WhisperVQConfig.from_pretrained(str(encoder_config_path))
        self.whisper_vqmodel = WhisperVQEncoder(config)
        state_dict = load_file(str(encoder_weight_path))
        prefix = "encoder."
        new_state_dict: OrderedDict[str, torch.Tensor] = OrderedDict(
            (k[len(prefix):], v) for k, v in state_dict.items() if k.startswith(prefix)
        )
        self.whisper_vqmodel.load_state_dict(new_state_dict, strict=False)
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
            str(encoder_feature_extractor_path)
        )
        # Flow / HiFT decoder stack
        self.flow_path = str(flow_path)
        self.audio_decoder = AudioDecoder(
            config_path=os.path.join(self.flow_path, "config.yaml"),
            flow_ckpt_path=os.path.join(self.flow_path, "flow.pt"),
            hift_ckpt_path=os.path.join(self.flow_path, "hift.pt"),
            campplus_model=os.path.join(self.flow_path, "campplus.onnx"),
        ).eval()

    @torch.no_grad()
    def encode(
        self,
        inputs: Union[
            Sequence[Union[str, os.PathLike, Tuple[torch.Tensor, int], torch.Tensor]],
            torch.Tensor,
        ],
        *,
        sampling_rate: Optional[int] = None,
        batch_size: int = 128,
    ) -> List[List[int]]:
        """Encode audio into codec token ids.

        Accepts one of:
        - a list of file paths
        - a list of `(waveform, sr)` tuples
        - a list of 1D/2D waveforms (sr assumed 16k)
        - a batched tensor with shape `(B, C, T)` or `(B, T)`; ``sampling_rate``
          (default 16 kHz) applies to every row

        Returns:
            One list of token ids per input item.

        Raises:
            ValueError: If a tensor input is neither 2-D nor 3-D.
        """
        if isinstance(inputs, torch.Tensor):
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)  # (B, 1, T)
            if inputs.dim() != 3:
                raise ValueError("`inputs` must be (B, C, T) when passing a tensor.")
            sr = sampling_rate or self.sample_rate
            items: List[Tuple[torch.Tensor, int]] = [
                (inputs[i].squeeze(0).cpu(), sr) for i in range(inputs.size(0))
            ]
        else:
            items = list(inputs)  # type: ignore[assignment]
        # The shared utility handles file paths, tuples and tensors.
        audio_tokens: List[List[int]] = extract_speech_token(
            self.whisper_vqmodel, self.feature_extractor, items, batch_size=batch_size
        )
        return audio_tokens

    def _extract_speech_feat(self, speech: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the decoder's feature extractor on ``speech``.

        Returns ``(feat, feat_len)``, where ``feat`` has a leading batch dim of 1
        and time on dim 1.
        """
        speech_feat = self.audio_decoder.feat_extractor(speech).squeeze(dim=0).transpose(0, 1)
        speech_feat = speech_feat.unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32)
        return speech_feat, speech_feat_len

    def _extract_spk_embedding(self, speech_16k: torch.Tensor) -> torch.Tensor:
        """Compute a speaker embedding from 16 kHz audio via the campplus ONNX session.

        Uses mean-normalised 80-bin Kaldi fbank features. Returns a ``(1, D)`` tensor.
        """
        feat = kaldi.fbank(speech_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        embedding = self.audio_decoder.campplus_session.run(
            None,
            {self.audio_decoder.campplus_session.get_inputs()[0].name: feat.unsqueeze(0).cpu().numpy()},
        )[0].flatten().tolist()
        return torch.tensor([embedding])

    def _prepare_prompt(
        self,
        prompt_speech: Union[str, os.PathLike],
        device: torch.device,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build ``(prompt_tokens, prompt_feat, speaker_embedding)`` from an audio file."""
        prompt_wav, orig_sr = torchaudio.load(str(prompt_speech))
        if prompt_wav.size(0) > 1:
            # Down-mix to mono: _extract_speech_feat squeezes dim 0, which only
            # yields the expected layout for a single channel.
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
        target_sr = self.audio_decoder.sample_rate
        if orig_sr != target_sr:
            prompt_wav = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)(prompt_wav)
        speech_token = torch.tensor(self.encode([str(prompt_speech)])[0], device=device).unsqueeze(0)
        speech_feat, _ = self._extract_speech_feat(prompt_wav)
        if target_sr == 24000:
            # Trim so prompt features and tokens align at 4 feature frames per token.
            token_len = min(int(speech_feat.shape[1] / 4), speech_token.shape[1])
            speech_feat = speech_feat[:, : 4 * token_len]
            speech_token = speech_token[:, :token_len]
        prompt_16k = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=16000)(prompt_wav)
        embedding = self._extract_spk_embedding(prompt_16k).to(device)
        return speech_token, speech_feat.to(device), embedding

    @torch.no_grad()
    def decode(
        self,
        audio_codes: Union[Sequence[Sequence[int]], torch.LongTensor],
        *,
        prompt_speech: Optional[Union[str, os.PathLike]] = None,
        prompt_speech_sample_rate: Optional[int] = None,
        use_spk_embedding: bool = True,
        use_prompt_speech: bool = True,
        finalize: bool = True,
        device: Union[str, torch.device] = torch.device("cuda"),
    ) -> dict:
        """Decode codec token ids back to waveform(s).

        Args:
            audio_codes: A ``(B, 1, T)`` or ``(B, T)`` tensor, or one token-id
                sequence per sample.
            prompt_speech: Path to the enrollment audio used for conditioning.
                Required, and must exist, when ``use_prompt_speech`` or
                ``use_spk_embedding`` is True.
            prompt_speech_sample_rate: Currently unused; the prompt's rate is read
                from the file. Kept for API compatibility.
            use_spk_embedding: Condition on the prompt's speaker embedding.
            use_prompt_speech: Condition on the prompt's tokens and features.
            finalize: Forwarded to ``token2wav``. With False, each waveform's
                tail stays in the decoder's per-stream cache, which this method
                never clears.
            device: Device for the conditioning tensors. Accepts a string or
                ``torch.device``; falls back to CPU when CUDA is unavailable.

        Returns:
            ``{"syn_wav_list": [waveform, ...]}``, one squeezed tensor per sample.

        Raises:
            ValueError: On an unsupported tensor shape, or a missing or
                nonexistent ``prompt_speech`` when one is required.
        """
        if isinstance(audio_codes, torch.Tensor):
            if audio_codes.dim() == 3 and audio_codes.size(1) == 1:
                codes_list: List[List[int]] = [
                    audio_codes[i, 0].detach().cpu().tolist() for i in range(audio_codes.size(0))
                ]
            elif audio_codes.dim() == 2:
                codes_list = [row.detach().cpu().tolist() for row in audio_codes]
            else:
                raise ValueError("`audio_codes` must be (B, 1, T) or (B, T) when passing a tensor.")
        else:
            codes_list = [list(c) for c in audio_codes]

        device = torch.device(device)
        if device.type != "cpu" and not torch.cuda.is_available():
            device = torch.device("cpu")

        speech_token = speech_feat = embedding = None
        if use_prompt_speech or use_spk_embedding:
            if prompt_speech is None or not os.path.exists(str(prompt_speech)):
                raise ValueError("`prompt_speech` path is required for decoding and must exist.")
            speech_token, speech_feat, embedding = self._prepare_prompt(prompt_speech, device)

        syn_wav_list: List[torch.Tensor] = []
        for codes in codes_list:
            codes_t = torch.tensor(codes, dtype=torch.long, device=device).unsqueeze(0)
            # A fresh stream id per sample keeps decoder caches independent.
            kwargs = {"uuid": os.urandom(16).hex(), "finalize": finalize}
            if use_prompt_speech:
                kwargs.update({"prompt_token": speech_token, "prompt_feat": speech_feat})
            if use_spk_embedding:
                kwargs.update({"embedding": embedding})
            tts_speech, _ = self.audio_decoder.token2wav(codes_t, **kwargs)
            syn_wav_list.append(tts_speech.squeeze())
        return {"syn_wav_list": syn_wav_list}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,  # back-compat with HF Transformers kwarg
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Instantiate the codec from a local directory or a Hugging Face Hub repo.

        This mirrors the usual Hugging Face ``from_pretrained`` behavior:
        - If ``pretrained_model_name_or_path`` is a local folder, files are loaded from it.
        - Otherwise it is treated as a Hub repo ID and downloaded with ``snapshot_download``.

        Expected layout inside the resolved base folder:
        - ``model.safetensors`` (Whisper VQ encoder weights)
        - ``config.json`` (Whisper VQ config)
        - ``preprocessor_config.json`` (WhisperFeatureExtractor params)
        - ``flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}``

        Extra ``**kwargs`` are accepted for signature compatibility and ignored.

        Raises:
            RuntimeError: If a repo id is given but ``huggingface_hub`` is not installed.
            FileNotFoundError: If any required asset is missing.
        """
        path_str = str(pretrained_model_name_or_path)
        if os.path.isdir(path_str):
            base = Path(path_str)
        else:
            try:
                from huggingface_hub import snapshot_download  # lazy import to avoid a hard dependency
            except ImportError as exc:  # pragma: no cover
                raise RuntimeError(
                    "huggingface_hub is required to load from a repo id; please `pip install huggingface_hub`."
                ) from exc
            # HF Transformers historically supports both `token` and the deprecated `use_auth_token`.
            if token is None and use_auth_token is not None:
                token = use_auth_token
            snapshot_path = snapshot_download(
                repo_id=path_str,
                revision=revision,
                cache_dir=str(cache_dir) if cache_dir is not None else None,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
            )
            base = Path(snapshot_path)
        if subfolder:
            base = base / subfolder
        if kwargs:
            logger.debug("Ignoring unsupported from_pretrained arguments: %s", sorted(kwargs))

        tokenizer_dir = base
        flow_dir = base / "flow"
        # AudioDecoder opens campplus.onnx unconditionally, so it is required like
        # the other assets; validating here gives an actionable error up front.
        required = [
            tokenizer_dir / "model.safetensors",
            tokenizer_dir / "config.json",
            tokenizer_dir / "preprocessor_config.json",
            *(flow_dir / fname for fname in ("config.yaml", "flow.pt", "hift.pt", "campplus.onnx")),
        ]
        missing = [str(p) for p in required if not p.exists()]
        if missing:
            raise FileNotFoundError(
                "Missing required codec assets under resolved path. The following files were not found: "
                + ", ".join(missing)
            )

        return cls(
            encoder_weight_path=str(tokenizer_dir / "model.safetensors"),
            encoder_config_path=str(tokenizer_dir / "config.json"),
            encoder_feature_extractor_path=str(tokenizer_dir),
            flow_path=str(flow_dir),
        )