Spaces:
Running
Running
| import os | |
| import time | |
| import logging | |
| import torch | |
| import soundfile as sf | |
| import numpy as np | |
| from dia.model import Dia | |
| class DiaTTSWrapper: | |
| def __init__(self, model_name="nari-labs/Dia-1.6B", device="cuda", dtype="float16"): | |
| self.device = device | |
| self.sr = 44100 | |
| logging.info(f"[DiaTTS] Загрузка модели {model_name} на {device} (dtype={dtype})") | |
| self.model = Dia.from_pretrained( | |
| model_name, | |
| device=device, | |
| compute_dtype=dtype | |
| ) | |
| def generate_audio_from_text(self, text: str, paralinguistic: str = "", max_duration: float = None) -> torch.Tensor: | |
| try: | |
| if paralinguistic: | |
| clean = paralinguistic.strip("()").lower() | |
| text = f"{text} ({clean})" | |
| audio_np = self.model.generate( | |
| text, | |
| use_torch_compile=False, | |
| verbose=False | |
| ) | |
| wf = torch.from_numpy(audio_np).float().unsqueeze(0) | |
| if max_duration: | |
| max_samples = int(self.sr * max_duration) | |
| wf = wf[:, :max_samples] | |
| return wf | |
| except Exception as e: | |
| logging.error(f"[DiaTTS] Ошибка генерации аудио: {e}") | |
| return torch.zeros(1, self.sr) | |
| def generate_and_save_audio( | |
| self, | |
| text: str, | |
| paralinguistic: str = "", | |
| out_dir="tts_outputs", | |
| filename_prefix="tts", | |
| max_duration: float = None, | |
| use_timestamp=True, | |
| skip_if_exists=True, | |
| max_trim_duration: float = None | |
| ) -> torch.Tensor: | |
| os.makedirs(out_dir, exist_ok=True) | |
| if use_timestamp: | |
| timestr = time.strftime("%Y%m%d_%H%M%S") | |
| filename = f"{filename_prefix}_{timestr}.wav" | |
| else: | |
| filename = f"{filename_prefix}.wav" | |
| out_path = os.path.join(out_dir, filename) | |
| if skip_if_exists and os.path.exists(out_path): | |
| logging.info(f"[DiaTTS] ⏭️ Пропущено — уже существует: {out_path}") | |
| return None | |
| wf = self.generate_audio_from_text(text, paralinguistic, max_duration) | |
| np_wf = wf.squeeze().cpu().numpy() | |
| if max_trim_duration is not None: | |
| max_len = int(self.sr * max_trim_duration) | |
| if len(np_wf) > max_len: | |
| logging.info(f"[DiaTTS] ✂️ Обрезка аудио до {max_trim_duration} сек.") | |
| np_wf = np_wf[:max_len] | |
| sf.write(out_path, np_wf, self.sr) | |
| logging.info(f"[DiaTTS] 💾 Сохранено аудио: {out_path}") | |
| return wf | |
| def get_sample_rate(self): | |
| return self.sr | |