Spaces:
Running
Running
| ο»Ώimport torch | |
| import tempfile | |
| import logging | |
| import soundfile as sf | |
| import numpy as np | |
| from transformers import VitsModel, VitsTokenizer | |
| import asyncio | |
| from typing import Optional | |
| logger = logging.getLogger(__name__) | |
class SimpleTTSClient:
    """
    Simple TTS client using Facebook VITS model.

    Primary path: facebook/mms-tts-eng (VITS), which needs no speaker
    embeddings. If that fails to load, falls back to Microsoft SpeechT5
    driven by a fixed synthetic speaker embedding.
    """

    def __init__(self):
        # Prefer GPU when one is visible; models and inputs are moved here.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        # Fix: initialize explicitly so no hasattr() probing is needed later.
        self.use_fallback = False
        logger.info(f"Simple TTS Client initialized on device: {self.device}")

    async def load_model(self):
        """Load VITS model - simpler and more reliable.

        Returns:
            bool: True if either the primary VITS model or the SpeechT5
            fallback loaded successfully, False otherwise.
        """
        try:
            logger.info("Loading Facebook VITS TTS model...")
            # Use a simple VITS model that doesn't require speaker embeddings
            model_name = "facebook/mms-tts-eng"
            self.tokenizer = VitsTokenizer.from_pretrained(model_name)
            self.model = VitsModel.from_pretrained(model_name).to(self.device)
            self.model_loaded = True
            logger.info("β VITS TTS model loaded successfully")
            return True
        except Exception as e:
            # Any load failure (download, disk, CUDA) routes to the fallback.
            logger.error(f"β Failed to load VITS model: {e}")
            logger.info("Falling back to basic TTS approach...")
            return await self._load_fallback_model()

    async def _load_fallback_model(self):
        """Fallback to an even simpler TTS approach (SpeechT5 + HiFi-GAN).

        Returns:
            bool: True on success; False if no TTS model could be loaded.
        """
        try:
            # Local import keeps the optional dependency off the primary path.
            from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
            logger.info("Loading SpeechT5 with minimal configuration...")
            self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            self.model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(self.device)
            self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(self.device)
            # Fix: seed the synthetic speaker embedding so the fallback voice
            # is stable across process restarts (an unseeded torch.randn gave
            # a different random voice every run).
            rng = torch.Generator().manual_seed(42)
            self.speaker_embedding = torch.randn(1, 512, generator=rng).to(self.device)
            self.model_loaded = True
            self.use_fallback = True
            logger.info("β Fallback TTS model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"β All TTS models failed to load: {e}")
            return False

    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> str:
        """Convert text to speech.

        Args:
            text: Text to synthesize. Must be non-empty.
            voice_id: Currently unused; kept for interface compatibility.

        Returns:
            Path to a temporary .wav file containing the synthesized audio.
            The caller owns the file and is responsible for deleting it.

        Raises:
            Exception: If the model cannot be loaded or synthesis fails.
        """
        if not self.model_loaded:
            logger.info("Model not loaded, loading now...")
            success = await self.load_model()
            if not success:
                raise Exception("Failed to load TTS model")
        try:
            # Fix: reject empty input up front instead of letting the model
            # fail with an opaque tokenizer/shape error.
            if not text or not text.strip():
                raise ValueError("text must be a non-empty string")
            logger.info(f"Generating speech for text: {text[:50]}...")
            # Both MMS-VITS and SpeechT5 HiFi-GAN emit 16 kHz audio; prefer
            # the model's own config when it exposes one.
            sample_rate = 16000
            if self.use_fallback:
                # Use SpeechT5 fallback
                inputs = self.processor(text=text, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    speech = self.model.generate_speech(
                        inputs["input_ids"],
                        self.speaker_embedding,
                        vocoder=self.vocoder
                    )
            else:
                # Use VITS model
                inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
                with torch.no_grad():
                    output = self.model(**inputs)
                speech = output.waveform.squeeze()
                # Fix: read the output rate from the model config rather than
                # hard-coding it.
                sample_rate = getattr(self.model.config, "sampling_rate", 16000)
            # Convert to audio file
            audio_data = speech.cpu().numpy()
            # Ensure audio data is mono/1-D before writing
            if audio_data.ndim > 1:
                audio_data = audio_data.squeeze()
            # Save to temporary file. Fix: close the handle BEFORE sf.write
            # reopens the path — writing while the NamedTemporaryFile handle
            # is still open fails on Windows.
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            temp_file.close()
            sf.write(temp_file.name, audio_data, samplerate=sample_rate)
            logger.info(f"β Generated speech audio: {temp_file.name}")
            return temp_file.name
        except Exception as e:
            logger.error(f"β Error generating speech: {e}")
            raise Exception(f"TTS generation failed: {e}")