import copy
import math
import os
import typing as tp
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModel
from transformers.utils import logging

from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig

logger = logging.get_logger(__name__)

# Try to import APEX FusedRMSNorm; it is only used when explicitly enabled.
try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
    if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
        APEX_AVAILABLE = False
        logger.warning("APEX FusedRMSNorm is available but disabled; set OPTIMIZE_FOR_SPEED=1 to enable it")
    else:
        APEX_AVAILABLE = True
        logger.info("APEX FusedRMSNorm is available and will be used for optimization")
except ImportError:
    APEX_AVAILABLE = False
    logger.warning("APEX FusedRMSNorm not available, using native implementation")

# Normalization modules
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm that moves channels to the last dimension
    before running the normalization, and moves them back right after.
    """
    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        x = nn.functional.layer_norm(
            x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
        ).type_as(x)
        x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        return x

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            weight_shape = (dim,) if weight_shape is None else weight_shape
            self.weight = nn.Parameter(torch.ones(weight_shape))
        else:
            self.register_parameter('weight', None)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
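
# In equation form, the normalization above is
#     RMSNorm(x) = x / sqrt(mean_i(x_i^2) + eps) * weight,
# i.e. a LayerNorm without mean-centering and without a bias term.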

class ConvRMSNorm(RMSNorm):
    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C)
        if (not APEX_AVAILABLE) or (not self.elementwise_affine):
            # Fallback to the native implementation
            output = self._norm(x.float()).type_as(x)
            if self.weight is not None:
                output = output * self.weight
        else:
            output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
        output = output.transpose(1, 2)  # (B, T, C) -> (B, C, T)
        return output

# Convolutional layers and utilities
CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_layer_norm', 'layer_norm', 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return nn.utils.weight_norm(module)
    elif norm == 'spectral_norm':
        return nn.utils.spectral_norm(module)
    else:
        # We already checked that norm is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module

def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or raise an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    elif norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()

def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Calculate the extra padding needed at the end so that the last stride window
    is complete and no input sample is dropped by the convolution."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length
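
# Worked example (hypothetical numbers): for length=10, kernel_size=4, stride=2
# and padding_total=3 we get n_frames = (10 - 4 + 3) / 2 + 1 = 5.5 and
# ideal_length = (ceil(5.5) - 1) * 2 + (4 - 3) = 11, so 1 extra padding sample
# makes the final stride window complete.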

def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad a 1D input, handling inputs that are shorter than the padding in reflect mode."""
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            # Reflect padding requires length > pad; zero-extend first, trim after.
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)
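
# Worked example (hypothetical numbers): reflect-padding a length-2 signal by
# (0, 3) would be rejected by F.pad, so pad1d first zero-extends the signal by
# extra_pad = 3 - 2 + 1 = 2 samples, reflect-pads the length-4 result to
# length 7, then trims the 2 helper samples to return 2 + 0 + 3 = 5 samples.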

def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling zero padding properly. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]

class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv"""
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x

class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv"""
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x

class VibeVoiceTokenizerStreamingCache:
    """Cache for streaming convolution, similar to the KV cache in attention"""
    def __init__(self):
        self.cache = {}  # Dict mapping (layer_id, sample_idx) to state tensor

    def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
        """Get cached states for the given layer and sample indices"""
        states = []
        max_length = 0
        # First pass: collect states and find the max length
        for idx in sample_indices.tolist():
            key = (layer_id, idx)
            if key not in self.cache:
                return None  # If any sample is missing, return None
            state = self.cache[key]
            states.append(state)
            max_length = max(max_length, state.shape[-1])
        # Second pass: pad states to the max length if needed
        if len(states) > 0 and states[0].dim() >= 2:
            padded_states = []
            for state in states:
                if state.shape[-1] < max_length:
                    # Pad on the time dimension (last dimension)
                    pad_size = max_length - state.shape[-1]
                    # Pad with zeros on the LEFT to align the most recent samples
                    padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0)
                    padded_states.append(padded_state)
                else:
                    padded_states.append(state)
            return torch.stack(padded_states, dim=0)
        else:
            return torch.stack(states, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Set cached states for the given layer and sample indices"""
        for i, idx in enumerate(sample_indices.tolist()):
            key = (layer_id, idx)
            self.cache[key] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Zero out all cached states for the given sample indices"""
        indices = set(sample_indices.tolist())
        for key in list(self.cache.keys()):
            layer_id, sample_idx = key
            if sample_idx in indices:
                # Keep the same shape and dtype as the cached tensor
                self.cache[key] = torch.zeros_like(self.cache[key])

    def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
        """Clear the cache for specific layer/samples, or everything"""
        if layer_id is None and sample_indices is None:
            self.cache.clear()
        elif layer_id is not None and sample_indices is None:
            # Clear all samples for a specific layer
            keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id]
            for k in keys_to_remove:
                del self.cache[k]
        elif layer_id is not None and sample_indices is not None:
            # Clear specific samples for a specific layer
            for idx in sample_indices.tolist():
                self.cache.pop((layer_id, idx), None)
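
# A minimal usage sketch of the streaming cache (illustrative only; the chunk
# and channel sizes below are hypothetical, and this helper is not used by the
# model itself): each SConv1d pulls its own context out of a shared cache,
# keyed by its layer ID and the per-sample stream index.
def _streaming_cache_usage_sketch():
    conv = SConv1d(in_channels=1, out_channels=8, kernel_size=7, stride=2, causal=True)
    cache = VibeVoiceTokenizerStreamingCache()
    sample_indices = torch.arange(2)  # two independent audio streams
    audio = torch.randn(2, 1, 640)
    outputs = []
    for chunk in audio.split(160, dim=-1):  # feed fixed-size chunks
        outputs.append(conv(chunk, cache=cache, sample_indices=sample_indices, use_cache=True))
    cache.clear()  # free the cached context once the streams end
    return torch.cat(outputs, dim=-1)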

class SConv1d(nn.Module):
    """Conv1d with built-in handling of asymmetric or causal padding and normalization."""
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode
        # Store configuration
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # For causal streaming we keep (kernel_size - 1) * dilation - (stride - 1)
        # input samples as context between chunks, so that a chunk whose length is
        # divisible by `stride` produces exactly chunk_length // stride output frames.
        self.context_size = (kernel_size - 1) * dilation - (stride - 1)
        # For non-streaming mode, the same quantity is the total padding
        self.padding_total = (kernel_size - 1) * dilation - (stride - 1)
        # Unique layer ID for cache management, created lazily
        self._layer_id = None

    @property
    def layer_id(self):
        if self._layer_id is None:
            self._layer_id = f"sconv1d_{id(self)}"
        return self._layer_id

    def forward(self, x: torch.Tensor,
                cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Forward pass with optional streaming support via cache.

        Args:
            x: Input tensor [batch_size, channels, time]
            cache: VibeVoiceTokenizerStreamingCache object for maintaining states
            sample_indices: Indices identifying each sample for cache management
            use_cache: Whether to use cached states for streaming
            debug: Whether to print debug information

        Returns:
            Output tensor
        """
        B, C, T = x.shape

        # Non-streaming mode
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)

        # Streaming mode
        assert self.causal, "Streaming mode is only supported for causal convolutions"
        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(self, x: torch.Tensor,
                           cache: VibeVoiceTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming forward pass with cache operations kept separate from compiled code"""
        B, C, T = x.shape

        # Cache operations (not compiled)
        cached_states = cache.get(self.layer_id, sample_indices)
        if cached_states is None:
            # First chunk - initialize with zeros for context
            if self.context_size > 0:
                cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype)
                if debug:
                    print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
            else:
                cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
                if debug:
                    print(f"[DEBUG] No context needed (kernel_size == stride)")

        # Concatenate cached states with the input
        if cached_states.shape[2] > 0:
            input_with_context = torch.cat([cached_states, x], dim=2)
        else:
            input_with_context = x
        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")

        # Apply the convolution directly - no extra padding in streaming mode;
        # the cached context plays the role of the causal left padding.
        output = self.conv(input_with_context)
        if debug:
            print(f"[DEBUG] Output shape: {output.shape}")

        # Update the cache for the next chunk
        if self.context_size > 0:
            total_input_length = input_with_context.shape[2]
            if total_input_length >= self.context_size:
                # Keep the last context_size samples
                new_cache = input_with_context[:, :, total_input_length - self.context_size:]
            else:
                # If we have fewer than context_size samples, keep everything
                new_cache = input_with_context
            if debug:
                print(f"[DEBUG] New cache shape: {new_cache.shape}")
            cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """Standard forward pass without streaming"""
        B, C, T = x.shape
        kernel_size = self.kernel_size
        stride = self.stride
        padding_total = self.padding_total

        # Compute extra padding for stride alignment
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")

        if self.causal:
            # Left padding for causal
            if self.pad_mode == 'constant':
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0)
            else:
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Near-symmetric padding for non-causal (larger half on the left)
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        if debug:
            print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")

        output = self.conv(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")
        return output
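
# A minimal sanity sketch (illustrative only, hypothetical sizes): chunked
# streaming should reproduce the non-streaming forward. 'constant' padding is
# used so the zero-initialized streaming context matches the non-streaming
# left padding; with the default 'reflect' padding the first frames differ.
def _sconv1d_streaming_equivalence_sketch():
    torch.manual_seed(0)
    conv = SConv1d(1, 4, kernel_size=7, stride=2, causal=True, pad_mode='constant').eval()
    x = torch.randn(1, 1, 320)
    full = conv(x)  # non-streaming pass over the whole signal
    cache = VibeVoiceTokenizerStreamingCache()
    idx = torch.arange(1)
    chunks = [conv(c, cache=cache, sample_indices=idx, use_cache=True)
              for c in x.split(80, dim=-1)]
    assert torch.allclose(full, torch.cat(chunks, dim=-1), atol=1e-5)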

class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization."""
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0. <= self.trim_right_ratio <= 1.
        # Store configuration
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        # For transposed convolution, the padding calculation is different
        self.padding_total = kernel_size - stride
        # For streaming, we keep a window of input history: the transposed conv
        # needs to see several past input frames to produce the correct output.
        self.context_size = kernel_size - 1
        # Unique layer ID for cache management, created lazily
        self._layer_id = None

    @property
    def layer_id(self):
        if self._layer_id is None:
            self._layer_id = f"sconvtr1d_{id(self)}"
        return self._layer_id

    def forward(self, x: torch.Tensor,
                cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Forward pass with optional streaming support via cache.
        """
        B, C, T = x.shape

        # Non-streaming mode
        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)

        # Streaming mode
        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"
        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(self, x: torch.Tensor,
                           cache: VibeVoiceTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming forward pass with cache operations kept separate from compiled code"""
        B, C, T = x.shape

        # Cache operations (not compiled)
        cached_input = cache.get(self.layer_id, sample_indices)
        if cached_input is None:
            # First chunk - no history yet
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized empty cache for transposed conv")

        # Concatenate cached input with the new input
        full_input = torch.cat([cached_input, x], dim=2)
        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")

        # Run the transposed convolution over the full (history + new) input
        full_output = self.convtr(full_input)
        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")

        # Calculate the padding to remove
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right

        # Remove padding
        if padding_left + padding_right > 0:
            full_output = unpad1d(full_output, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG] After unpadding: {full_output.shape}")

        # Determine which part of the output corresponds to the new input
        if cached_input.shape[2] == 0:
            # First chunk - return all output
            output = full_output
        else:
            # Subsequent chunks - return only the new output
            expected_new_output = T * self.stride
            # Take the last expected_new_output samples
            if full_output.shape[2] >= expected_new_output:
                output = full_output[:, :, -expected_new_output:]
            else:
                output = full_output
        if debug:
            print(f"[DEBUG] Final streaming output shape: {output.shape}")

        # Update the cache with the most recent input frames
        if full_input.shape[2] > self.context_size:
            new_cache = full_input[:, :, -self.context_size:]
        else:
            new_cache = full_input
        if debug:
            print(f"[DEBUG] New cache shape: {new_cache.shape}")
        cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """Standard forward pass without streaming"""
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")
        # Apply the transposed convolution
        y = self.convtr(x)
        if debug:
            print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")
        # Calculate and remove padding
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
            padding_left = self.padding_total - padding_right
        else:
            padding_right = self.padding_total // 2
            padding_left = self.padding_total - padding_right
        if padding_left + padding_right > 0:
            y = unpad1d(y, (padding_left, padding_right))
        if debug:
            print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")
        return y
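
# Worked example (hypothetical numbers): with kernel_size=4 and stride=2, the
# raw ConvTranspose1d output for T input frames has (T - 1) * 2 + 4 = 2T + 2
# samples; trimming padding_total = 4 - 2 = 2 samples leaves exactly
# T * stride = 2T, so every latent frame expands into `stride` audio samples.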

# FFN
class FFN(nn.Module):
    def __init__(
        self,
        embed_dim,
        ffn_dim,
        bias=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)

    def forward(self, x):
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.linear2(x)
        return x

class Convlayer(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_mode='zeros',
        norm='weight_norm',
        causal=True,
    ):
        super().__init__()
        self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation,
                            groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal)

    def forward(self, x):
        return self.conv(x)
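
# torch.nn has no DropPath module (the original `nn.modules.DropPath` reference
# would fail whenever drop_path > 0), so a minimal stochastic-depth
# implementation, following the usual timm-style convention, is provided here.
class DropPath(nn.Module):
    """Drop entire residual paths per sample (stochastic depth)."""
    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1. - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dimensions
        shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        return x * mask / keep_prob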

class Block1D(nn.Module):
    def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
                 layer_scale_init_value=1e-6, **kwargs):
        super().__init__()
        layernorm = kwargs.get('layernorm', 'LN')
        if layernorm == 'LN':
            self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
            self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6))
        elif layernorm == 'RMSNorm':
            self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
            self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6))
        else:
            raise ValueError(f"Unsupported layernorm: {layernorm}")

        if mixer_layer == 'conv':
            self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1),
                                   kernel_size=kernel_size,
                                   pad_mode=kwargs.get('pad_mode', 'reflect'),
                                   norm=kwargs.get('norm', 'none'),
                                   causal=kwargs.get('causal', True),
                                   bias=kwargs.get('bias', True),
                                   )
        elif mixer_layer == 'depthwise_conv':
            self.mixer = Convlayer(dim, dim, groups=dim,
                                   kernel_size=kernel_size,
                                   pad_mode=kwargs.get('pad_mode', 'reflect'),
                                   norm=kwargs.get('norm', 'none'),
                                   causal=kwargs.get('causal', True),
                                   bias=kwargs.get('bias', True),
                                   )
        else:
            raise ValueError(f"Unsupported mixer layer: {mixer_layer}")

        self.ffn = FFN(
            dim,
            kwargs.get('ffn_expansion', 4) * dim,
            bias=kwargs.get('bias', False),
        )
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
            self.ffn_gamma = None

    def forward(self, x):
        # Mixer (time mixing)
        residual = x
        x = self.norm(x)
        x = self.mixer(x)
        if self.gamma is not None:
            x = x * self.gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)
        # FFN (channel mixing)
        residual = x
        x = self.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = self.ffn(x)
        x = x.permute(0, 2, 1)
        if self.ffn_gamma is not None:
            x = x * self.ffn_gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)
        return x
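
# A minimal shape sketch (illustrative only, hypothetical sizes; not used by
# the model): Block1D preserves (B, C, T), mixing along time with the
# (depthwise) conv and along channels with the FFN.
def _block1d_shape_sketch():
    block = Block1D(dim=32, kernel_size=7, mixer_layer='depthwise_conv').eval()
    x = torch.randn(2, 32, 100)
    assert block(x).shape == x.shape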

class TokenizerEncoder(nn.Module):
    """
    Encoder component of the VibeVoice tokenizer that converts audio into latent representations.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()
        # Extract parameters from config
        self.channels = config.channels
        self.dimension = config.dimension
        self.n_filters = config.n_filters
        self.ratios = list(reversed(config.ratios))
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # Determine the norm type based on `layernorm`
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # Stem and intermediate downsampling conv layers
        stem = nn.Sequential(
            SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params,
                    causal=self.causal, pad_mode=pad_mode, bias=bias),
        )
        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(stem)
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** i)
            out_ch = self.n_filters * (2 ** (i + 1))
            downsample_layer = nn.Sequential(
                SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
                        causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)
            )
            self.downsample_layers.append(downsample_layer)

        # Configure the Block1D stages
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** i)
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size,
                            causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        for i in range(len(self.depths)):
            # Apply downsampling
            for layer in self.downsample_layers[i]:
                if isinstance(layer, SConv1d):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)
            # Apply the stage (Block1D contains a Convlayer, which contains an SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Inlined Block1D forward with cache support (drop_path is
                    # an identity at inference, so it is omitted here)
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x
                    # FFN part
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x

class TokenizerDecoder(nn.Module):
    """
    Decoder component of the VibeVoice tokenizer that converts latent representations back into audio.

    Args:
        config: Configuration object with model parameters
    """
    def __init__(self, config):
        super().__init__()
        # Extract parameters from config
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios
        # Note: depths are NOT reversed here; VibeVoiceAcousticTokenizerModel
        # already passes them in reversed order.
        self.depths = config.depths
        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        # Additional config parameters with defaults
        kernel_size = getattr(config, "kernel_size", 7)
        last_kernel_size = getattr(config, "last_kernel_size", 7)
        norm = getattr(config, "norm", "none")
        norm_params = getattr(config, "norm_params", {})
        pad_mode = getattr(config, "pad_mode", "reflect")
        bias = getattr(config, "bias", True)
        layernorm = getattr(config, "layernorm", "LN")
        layernorm_eps = getattr(config, "layernorm_eps", 1e-6)
        trim_right_ratio = getattr(config, "trim_right_ratio", 1.0)
        layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True)
        drop_path_rate = getattr(config, "drop_path_rate", 0.0)
        mixer_layer = getattr(config, "mixer_layer", "conv")
        layer_scale_init_value = getattr(config, "layer_scale_init_value", 0)
        disable_last_norm = getattr(config, "disable_last_norm", False)

        # Determine the norm type based on `layernorm`
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        # Stem and upsampling layers
        stem = nn.Sequential(
            SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm,
                    norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )
        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        for i in range(len(self.ratios)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1))
            upsample_layer = nn.Sequential(
                SConvTranspose1d(in_ch, out_ch,
                                 kernel_size=self.ratios[i] * 2, stride=self.ratios[i],
                                 norm=norm, norm_kwargs=norm_params, bias=bias,
                                 causal=self.causal, trim_right_ratio=trim_right_ratio),
            )
            self.upsample_layers.append(upsample_layer)

        # Configure the Block1D stages
        layer_type = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        # Create stages in the same order as the original model
        for i in range(len(self.depths)):
            in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i))
            stage = nn.Sequential(
                *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])]
            )
            self.stages.append(stage)
            cur += self.depths[i]

        if not disable_last_norm:
            self.norm = norm_type(in_ch, eps=layernorm_eps)
        else:
            self.norm = nn.Identity()
        self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size,
                            causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        for i in range(len(self.depths)):
            # Apply upsampling
            for layer in self.upsample_layers[i]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                else:
                    x = layer(x)
            # Apply the stage (Block1D contains a Convlayer, which contains an SConv1d)
            for block in self.stages[i]:
                if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d):
                    # Inlined Block1D forward with cache support
                    residual = x
                    x = block.norm(x)
                    x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
                    if block.gamma is not None:
                        x = x * block.gamma.unsqueeze(-1)
                    x = residual + x
                    # FFN part
                    residual = x
                    x = block.ffn_norm(x)
                    x = x.permute(0, 2, 1)
                    x = block.ffn(x)
                    x = x.permute(0, 2, 1)
                    if block.ffn_gamma is not None:
                        x = x * block.ffn_gamma.unsqueeze(-1)
                    x = residual + x
                else:
                    x = block(x)
        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return x

@dataclass
class VibeVoiceTokenizerEncoderOutput:
    """
    Output of the VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance.

    Args:
        mean (`torch.FloatTensor`): The mean parameters of the distribution.
        std (`float` or `torch.FloatTensor`, *optional*): Fixed standard deviation value.
    """
    mean: torch.Tensor
    std: Optional[Union[float, torch.Tensor]] = None

    def sample(self, dist_type='fix'):
        """
        Sample from the distribution.

        Args:
            dist_type (`str`): Sampling method: 'fix' (fixed std), 'gaussian' (randomly
                scaled std per sample), or anything else to return the mean unchanged.

        Returns:
            `torch.FloatTensor`: Sampled values.
            `float` or `torch.FloatTensor`: Standard deviation that was used.
        """
        if dist_type == 'fix':
            x = self.mean + self.std * torch.randn_like(self.mean)
            return x, self.std
        elif dist_type == 'gaussian':
            batch_size = self.mean.size(0)
            value = self.std / 0.8
            std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value
            while std.dim() < self.mean.dim():
                std = std.unsqueeze(-1)
            x = self.mean + std * torch.randn_like(self.mean)
            return x, std
        else:
            return self.mean, self.std

    def kl(self):
        """KL-style regularizer toward a standard normal; with a fixed std this
        reduces to an elementwise squared-mean penalty."""
        target = torch.zeros_like(self.mean)
        return F.mse_loss(self.mean, target, reduction='none')

    def mode(self):
        """Return the distribution mode (the mean, for a Gaussian)."""
        return self.mean
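
# A minimal sketch (illustrative only, hypothetical shapes; not used by the
# model) of the fixed-variance posterior: 'fix' sampling adds std * N(0, I)
# noise around the mean, while mode() returns the mean itself.
def _encoder_output_sketch():
    out = VibeVoiceTokenizerEncoderOutput(mean=torch.zeros(2, 10, 64), std=0.1)
    z, std = out.sample(dist_type='fix')
    assert z.shape == (2, 10, 64) and std == 0.1
    assert torch.equal(out.mode(), out.mean)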

class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer combining an encoder and a decoder for acoustic tokens"""
    config_class = VibeVoiceAcousticTokenizerConfig
    base_model_prefix = "vibevoice_acoustic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)
        self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")

        # Parse encoder depths
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Parse decoder depths if provided
        if config.decoder_depths is not None and isinstance(config.decoder_depths, str):
            decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
        else:
            # Default: use reversed encoder depths if decoder_depths is None
            decoder_depths = list(reversed(encoder_depths))

        # Create encoder config
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Create decoder config
        decoder_config = copy.deepcopy(config)
        decoder_config.dimension = config.vae_dim
        decoder_config.n_filters = config.decoder_n_filters
        decoder_config.ratios = config.decoder_ratios
        decoder_config.depths = decoder_depths
        decoder_config.norm = config.conv_norm
        decoder_config.pad_mode = config.pad_mode
        decoder_config.bias = config.conv_bias
        decoder_config.layernorm_eps = config.layernorm_eps
        decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        decoder_config.mixer_layer = config.mixer_layer
        decoder_config.layer_scale_init_value = config.layer_scale_init_value
        decoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder and decoder
        self.encoder = TokenizerEncoder(encoder_config)
        self.decoder = TokenizerDecoder(decoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model"""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations"""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std)

    def sampling(self, encoder_output, dist_type=None):
        """Sample from the encoder output distribution"""
        dist_type = dist_type or self.std_dist_type
        if dist_type in ('fix', 'gaussian'):
            return encoder_output.sample(dist_type=dist_type)
        else:
            raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'")

    def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert latent representations back to audio"""
        # The decoder expects (B, vae_dim, T); permute channel-last inputs
        if latents.shape[1] != self.config.vae_dim:
            latents = latents.permute(0, 2, 1)
        audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return audio

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Full forward pass: encode audio to latents, then decode back to audio"""
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output)
        reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return reconstructed, sampled_latents
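
# A minimal roundtrip sketch (illustrative only; not used by the model). It
# assumes the default VibeVoiceAcousticTokenizerConfig is instantiable without
# arguments, as is standard for HF configs, and that `channels`,
# `encoder_ratios` and `vae_dim` carry the meanings used above.
def _acoustic_roundtrip_sketch():
    config = VibeVoiceAcousticTokenizerConfig()
    model = VibeVoiceAcousticTokenizerModel(config).eval()
    hop = int(np.prod(config.encoder_ratios))  # audio samples per latent frame
    audio = torch.randn(1, config.channels, 4 * hop)
    with torch.no_grad():
        reconstructed, latents = model(audio)
    # latents: (B, T // hop, vae_dim); reconstructed: (B, channels, ~T)
    return reconstructed.shape, latents.shape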

class VibeVoiceSemanticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer with only an encoder, for semantic tokens"""
    config_class = VibeVoiceSemanticTokenizerConfig
    base_model_prefix = "vibevoice_semantic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerEncoder"]

    def __init__(self, config):
        super().__init__(config)
        # Parse encoder depths
        if isinstance(config.encoder_depths, str):
            encoder_depths = [int(d) for d in config.encoder_depths.split('-')]
        else:
            encoder_depths = config.encoder_depths

        # Create encoder config
        encoder_config = copy.deepcopy(config)
        encoder_config.dimension = config.vae_dim
        encoder_config.n_filters = config.encoder_n_filters
        encoder_config.ratios = config.encoder_ratios
        encoder_config.depths = encoder_depths
        encoder_config.norm = config.conv_norm
        encoder_config.pad_mode = config.pad_mode
        encoder_config.bias = config.conv_bias
        encoder_config.layernorm_eps = config.layernorm_eps
        encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine
        encoder_config.mixer_layer = config.mixer_layer
        encoder_config.layer_scale_init_value = config.layer_scale_init_value
        encoder_config.disable_last_norm = config.disable_last_norm

        # Initialize encoder (this model has no decoder)
        self.encoder = TokenizerEncoder(encoder_config)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model"""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert audio to latent representations"""
        latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1))

    def sampling(self, encoder_output, dist_type=None):
        """Deterministic 'sampling': the semantic tokenizer always returns the mean"""
        return encoder_output.sample(dist_type='none')

    def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Forward pass: encode audio to latents (this model has no decoder)"""
        encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        sampled_latents, _ = self.sampling(encoder_output, dist_type='none')
        return None, sampled_latents

AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)
AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel)

__all__ = [
    "VibeVoiceTokenizerStreamingCache",
    "VibeVoiceAcousticTokenizerModel",
    "VibeVoiceSemanticTokenizerModel",
]