import time
from enum import Enum

import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

from .audio import apply_audio_delay, build_delay_indices, build_revert_indices, decode, revert_audio_delay
from .config import DiaConfig
from .layers import DiaModel
from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState

DEFAULT_SAMPLE_RATE = 44100

def _get_default_device():
    """Picks CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    # Temperature 0 means deterministic greedy decoding.
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature

    # Optional top-k filter: mask everything outside the k largest logits.
    if cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
        logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

    # Nucleus (top-p) filter: drop the tail of the sorted distribution.
    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
        sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
        # Shift right so the first token that crosses the threshold is kept.
        sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
        sorted_indices_to_remove_BCxV[..., 0] = 0

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
        indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
    sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
    sampled_indices_C = sampled_indices_BC.squeeze(-1)
    return sampled_indices_C
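
# Sanity check for the sampler above (a sketch, not part of the model): at
# temperature 0.0 the helper must reduce to plain greedy argmax decoding.
#
#   logits = torch.randn(9, 1028)  # fake [C, V] per-codebook logits
#   greedy = _sample_next_token(logits, temperature=0.0, top_p=0.95)
#   assert torch.equal(greedy, logits.argmax(dim=-1))
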
class ComputeDtype(str, Enum):
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"

    def to_dtype(self) -> torch.dtype:
        if self == ComputeDtype.FLOAT32:
            return torch.float32
        elif self == ComputeDtype.FLOAT16:
            return torch.float16
        elif self == ComputeDtype.BFLOAT16:
            return torch.bfloat16
        else:
            raise ValueError(f"Unsupported compute dtype: {self}")
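
# Round-trip example for the enum above: ComputeDtype("bfloat16") parses the
# string form, and .to_dtype() returns torch.bfloat16.
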
class Dia:
    def __init__(
        self,
        config: DiaConfig,
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ):
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
            compute_dtype: The computation dtype, as a ComputeDtype or its string value.
            device: The device to load the model onto. If None, the best available device is selected automatically.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
        self.device = device if device is not None else _get_default_device()
        if isinstance(compute_dtype, str):
            compute_dtype = ComputeDtype(compute_dtype)
        self.compute_dtype = compute_dtype.to_dtype()
        self.model = DiaModel(config, self.compute_dtype)
        self.dac_model = None
    @classmethod
    def from_local(
        cls,
        config_path: str,
        checkpoint_path: str,
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
            compute_dtype: The computation dtype, as a ComputeDtype or its string value.
            device: The device to load the model onto. If None, the best available device is selected automatically.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If the config or checkpoint file is not found.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config = DiaConfig.load(config_path)
        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

        dia = cls(config, compute_dtype, device)

        try:
            state_dict = torch.load(checkpoint_path, map_location=dia.device)
            dia.model.load_state_dict(state_dict)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
            raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e

        dia.model.to(dia.device)
        dia.model.eval()
        dia._load_dac_model()
        return dia
    @classmethod
    def from_pretrained(
        cls,
        model_name: str = "nari-labs/Dia-1.6B",
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.

        Downloads the configuration and checkpoint files from the specified
        repository ID and then loads the model.

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
            compute_dtype: The computation dtype, as a ComputeDtype or its string value.
            device: The device to load the model onto. If None, the best available device is selected automatically.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If config or checkpoint download/loading fails.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
        return cls.from_local(config_path, checkpoint_path, compute_dtype, device)
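
    # Typical entry point (usage sketch; the prompt text and output path are
    # illustrative):
    #
    #   model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
    #   audio = model.generate("[S1] Hello there. [S2] Hi!")
    #   model.save_audio("dialogue.wav", audio)
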
    def _load_dac_model(self):
        try:
            dac_model_path = dac.utils.download()
            dac_model = dac.DAC.load(dac_model_path).to(self.device)
        except Exception as e:
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model
    def _prepare_text_input(self, text: str) -> torch.Tensor:
        """Encodes the text prompt to byte tokens and pads it to a fixed length."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length

        byte_text = text.encode("utf-8")
        # Speaker tags are mapped to single-byte control tokens.
        replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
        text_tokens = list(replaced_bytes)

        current_len = len(text_tokens)
        padding_needed = max_len - current_len
        if padding_needed <= 0:
            text_tokens = text_tokens[:max_len]
            padded_text_np = np.array(text_tokens, dtype=np.uint8)
        else:
            padded_text_np = np.pad(
                text_tokens,
                (0, padding_needed),
                mode="constant",
                constant_values=text_pad_value,
            ).astype(np.uint8)

        src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)  # [1, S]
        return src_tokens
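
    # Encoding example (illustrative): "[S1] Hi" becomes b"\x01 Hi", i.e. the
    # byte tokens [1, 32, 72, 105], which are then right-padded with
    # text_pad_value up to config.data.text_length.
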
    def _prepare_audio_prompt(self, audio_prompt: torch.Tensor | None) -> tuple[torch.Tensor, int]:
        num_channels = self.config.data.channels
        audio_bos_value = self.config.data.audio_bos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_delay_pattern = max(delay_pattern)

        # Start with a single BOS frame across all codebook channels.
        prefill = torch.full(
            (1, num_channels),
            fill_value=audio_bos_value,
            dtype=torch.int,
            device=self.device,
        )

        prefill_step = 1
        if audio_prompt is not None:
            prefill_step += audio_prompt.shape[0]
            prefill = torch.cat([prefill, audio_prompt], dim=0)

        delay_pad_tensor = torch.full(
            (max_delay_pattern, num_channels), fill_value=-1, dtype=torch.int, device=self.device
        )
        prefill = torch.cat([prefill, delay_pad_tensor], dim=0)

        delay_precomp = build_delay_indices(
            B=1,
            T=prefill.shape[0],
            C=num_channels,
            delay_pattern=delay_pattern,
        )

        prefill = apply_audio_delay(
            audio_BxTxC=prefill.unsqueeze(0),
            pad_value=audio_pad_value,
            bos_value=audio_bos_value,
            precomp=delay_precomp,
        ).squeeze(0)

        return prefill, prefill_step
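
    # Delay-pattern intuition (illustrative numbers, not the actual config):
    # with delay_pattern = [0, 1, 2], channel c at output step t holds the
    # token for frame t - delay[c], so channel 0 leads and higher-delay
    # channels trail. The BOS frame plus max(delay_pattern) padding rows above
    # keep every shifted read in bounds; revert_audio_delay undoes the shift
    # when the generated codes are decoded back to audio.
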
    def _prepare_generation(self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool):
        enc_input_cond = self._prepare_text_input(text)
        enc_input_uncond = torch.zeros_like(enc_input_cond)
        # Batch the unconditional and conditional inputs together for CFG.
        enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)

        if isinstance(audio_prompt, str):
            audio_prompt = self.load_audio(audio_prompt)
        prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)

        if verbose:
            print("generate: data loaded")

        enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
        encoder_out = self.model.encoder(enc_input, enc_state)

        dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(encoder_out, enc_state.positions)
        dec_state = DecoderInferenceState.new(
            self.config, enc_state, encoder_out, dec_cross_attn_cache, self.compute_dtype
        )
        dec_output = DecoderOutput.new(self.config, self.device)
        dec_output.prefill(prefill, prefill_step)

        dec_step = prefill_step - 1
        if dec_step > 0:
            dec_state.prepare_step(0, dec_step)
            tokens_BxTxC = dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
            self.model.decoder.forward(tokens_BxTxC, dec_state)

        return dec_state, dec_output
    def _decoder_step(
        self,
        tokens_Bx1xC: torch.Tensor,
        dec_state: DecoderInferenceState,
        cfg_scale: float,
        temperature: float,
        top_p: float,
        cfg_filter_top_k: int,
    ) -> torch.Tensor:
        audio_eos_value = self.config.data.audio_eos_value
        logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)

        logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
        uncond_logits_CxV = logits_last_BxCxV[0, :, :]
        cond_logits_CxV = logits_last_BxCxV[1, :, :]

        # Classifier-free guidance: push the conditional logits away from the
        # unconditional ones.
        logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
        # Tokens beyond EOS are never valid, and only channel 0 may emit EOS.
        logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
        logits_CxV[1:, audio_eos_value:] = -torch.inf

        pred_C = _sample_next_token(
            logits_CxV.float(),
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
        )
        return pred_C
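
    # Guidance arithmetic as applied above (standard CFG on logits):
    #   guided = cond + cfg_scale * (cond - uncond)
    # cfg_scale = 0 recovers the raw conditional logits; e.g. with cond = 2.0,
    # uncond = 1.0, cfg_scale = 3.0, the guided logit is 2.0 + 3.0 * 1.0 = 5.0.
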
    def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
        num_channels = self.config.data.channels
        seq_length = generated_codes.shape[0]
        delay_pattern = self.config.data.delay_pattern
        audio_pad_value = self.config.data.audio_pad_value
        max_delay_pattern = max(delay_pattern)

        revert_precomp = build_revert_indices(
            B=1,
            T=seq_length,
            C=num_channels,
            delay_pattern=delay_pattern,
        )

        codebook = revert_audio_delay(
            audio_BxTxC=generated_codes.unsqueeze(0),
            pad_value=audio_pad_value,
            precomp=revert_precomp,
            T=seq_length,
        )[:, :-max_delay_pattern, :]

        # Replace out-of-range entries (BOS/EOS/PAD markers) with code 0 before
        # handing the codebook to the DAC decoder.
        min_valid_index = 0
        max_valid_index = 1023
        invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
        codebook[invalid_mask] = 0

        audio = decode(self.dac_model, codebook.transpose(1, 2))

        return audio.squeeze().cpu().numpy()
    def load_audio(self, audio_path: str) -> torch.Tensor:
        audio, sr = torchaudio.load(audio_path, channels_first=True)  # C, T
        if sr != DEFAULT_SAMPLE_RATE:
            audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
        audio = audio.to(self.device).unsqueeze(0)  # 1, C, T
        audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
        _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)  # 1, C, T
        return encoded_frame.squeeze(0).transpose(0, 1)

    def save_audio(self, path: str, audio: np.ndarray):
        import soundfile as sf

        sf.write(path, audio, DEFAULT_SAMPLE_RATE)
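
    # Voice-prompt sketch (hypothetical file name): encode a reference clip to
    # DAC codes once, then reuse it across generate() calls.
    #
    #   prompt_codes = model.load_audio("reference.wav")  # [T, C] int codes
    #   audio = model.generate("[S1] Continue in this voice.", audio_prompt=prompt_codes)
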
    @torch.inference_mode()
    def generate(
        self,
        text: str,
        max_tokens: int | None = None,
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_torch_compile: bool = False,
        cfg_filter_top_k: int = 35,
        audio_prompt: str | torch.Tensor | None = None,
        audio_prompt_path: str | None = None,
        use_cfg_filter: bool | None = None,
        verbose: bool = False,
    ) -> np.ndarray | None:
        """Generates a waveform from a text prompt, optionally conditioned on an audio prompt.

        Returns:
            The generated audio as a numpy array at 44.1 kHz, or None if nothing was generated.
        """
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

        if audio_prompt_path:
            print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
            audio_prompt = audio_prompt_path
        if use_cfg_filter is not None:
            print("Warning: use_cfg_filter is deprecated.")

        if verbose:
            total_start_time = time.time()

        dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
        dec_step = dec_output.prefill_step - 1

        bos_countdown = max_delay_pattern
        eos_detected = False
        eos_countdown = -1

        if use_torch_compile:
            step_fn = torch.compile(self._decoder_step, mode="default")
        else:
            step_fn = self._decoder_step

        if verbose:
            print("generate: starting generation loop")
            if use_torch_compile:
                print("generate: with use_torch_compile=True, the first step will be slow while the graph compiles")
            start_time = time.time()

        while dec_step < max_tokens:
            dec_state.prepare_step(dec_step)
            tokens_Bx1xC = dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
            pred_C = step_fn(
                tokens_Bx1xC,
                dec_state,
                cfg_scale,
                temperature,
                top_p,
                cfg_filter_top_k,
            )

            # Trigger the EOS wind-down either on a sampled EOS in channel 0 or
            # when the token budget forces one.
            if (not eos_detected and pred_C[0] == audio_eos_value) or dec_step == max_tokens - max_delay_pattern - 1:
                eos_detected = True
                eos_countdown = max_delay_pattern

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
                        pred_C[i] = audio_eos_value
                    elif step_after_eos > d:
                        pred_C[i] = audio_pad_value
                eos_countdown -= 1

            bos_countdown = max(0, bos_countdown - 1)
            dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)

            if eos_countdown == 0:
                break

            dec_step += 1
            # ~86 codec frames correspond to one second of audio at 44.1 kHz.
            if verbose and dec_step % 86 == 0:
                duration = time.time() - start_time
                print(
                    f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
                )
                start_time = time.time()

        if dec_output.prefill_step >= dec_step + 1:
            print("Warning: Nothing generated")
            return None

        generated_codes = dec_output.generated_tokens[dec_output.prefill_step : dec_step + 1, :]

        if verbose:
            total_step = dec_step + 1 - dec_output.prefill_step
            total_duration = time.time() - total_start_time
            print(f"generate: total step={total_step}, total duration={total_duration:.3f}s")

        return self._generate_output(generated_codes)
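
if __name__ == "__main__":
    # Minimal smoke test (a sketch; the prompt and output path are
    # illustrative). Downloads the weights on first run, generates a short
    # two-speaker dialogue, and writes it to disk at 44.1 kHz.
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
    waveform = model.generate("[S1] Hello there. [S2] Hi, nice to meet you.", verbose=True)
    if waveform is not None:
        model.save_audio("example.wav", waveform)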