"""Text-to-speech client built on HuggingFace SpeechT5 and VITS (MMS) models."""

import asyncio
import hashlib
import logging
import os
import tempfile
from typing import Optional

import numpy as np
import soundfile as sf
import torch

# Default HuggingFace cache locations. They must be exported before
# `transformers` is imported; `setdefault` keeps any value already present
# in the environment.
_HF_CACHE_DEFAULTS = {
    'HF_HOME': '/tmp/huggingface',
    'TRANSFORMERS_CACHE': '/tmp/huggingface/transformers',
    'HF_DATASETS_CACHE': '/tmp/huggingface/datasets',
    'HUGGINGFACE_HUB_CACHE': '/tmp/huggingface/hub',
}
for _env_name, _default_path in _HF_CACHE_DEFAULTS.items():
    os.environ.setdefault(_env_name, _default_path)
    # The default directories are always created, even when the environment
    # points somewhere else.
    os.makedirs(_default_path, exist_ok=True)

# Try to import transformers components
try:
    from transformers import (
        VitsModel,
        VitsTokenizer,
        SpeechT5Processor,
        SpeechT5ForTextToSpeech,
        SpeechT5HifiGan,
    )
    from datasets import load_dataset
    TRANSFORMERS_AVAILABLE = True
    print("✅ Transformers and datasets available")
except ImportError as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"⚠️ Advanced TTS models not available: {e}")
    print("💡 Install with: pip install transformers datasets")

logger = logging.getLogger(__name__)

# Size of the speaker-embedding vector handed to SpeechT5 (shape: 1 x 512).
_EMBEDDING_DIM = 512
# Upper bound, in seconds, for loading one model family (5 minutes).
_MODEL_LOAD_TIMEOUT_SECONDS = 300
# Fallback cache directory used when TRANSFORMERS_CACHE is not set.
_DEFAULT_TRANSFORMERS_CACHE = '/tmp/huggingface/transformers'


class AdvancedTTSClient:
    """
    Advanced TTS client using Facebook VITS and SpeechT5 models
    Falls back gracefully if models are not available

    Models are loaded lazily: either call ``load_models()`` explicitly or let
    ``text_to_speech()`` trigger loading on first use.
    """

    # voice_id -> (human-readable label, scale applied to the pseudo-random
    # embedding). The embeddings are scaled random vectors, not x-vectors
    # taken from real speakers.
    _VOICE_PROFILES = {
        "21m00Tcm4TlvDq8ikWAM": ("Female (Neutral)", 0.8),
        "pNInz6obpgDQGcFmaJgB": ("Male (Professional)", 1.2),
        "EXAVITQu4vr4xnSDxMaL": ("Female (Sweet)", 0.6),
        "ErXwobaYiN019PkySvjV": ("Male (Professional)", 1.0),
        "TxGEqnHWrfGW9XjX": ("Male (Deep)", 1.4),
        "yoZ06aMxZJJ28mfd3POQ": ("Unisex (Friendly)", 0.9),
        "AZnzlk1XvdvUeBnXmlld": ("Female (Strong)", 1.1),
    }

    def __init__(self):
        """Pick the compute device and initialise all model slots to ``None``."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models_loaded = False
        self.transformers_available = TRANSFORMERS_AVAILABLE

        # Model instances - will be loaded on demand
        self.vits_model = None
        self.vits_tokenizer = None
        self.speecht5_processor = None
        self.speecht5_model = None
        self.speecht5_vocoder = None
        self.speaker_embeddings = None

        logger.info("Advanced TTS Client initialized on device: %s", self.device)
        logger.info("Transformers available: %s", self.transformers_available)

    async def load_models(self):
        """Load TTS models asynchronously.

        SpeechT5 is attempted first, then VITS. A failure of one family is
        logged and does not prevent the other from loading.

        Returns:
            True if at least one of SpeechT5 / VITS is loaded, False otherwise
            (including when transformers is not installed).
        """
        if not self.transformers_available:
            logger.warning("❌ Transformers not available - cannot load advanced TTS models")
            return False

        try:
            logger.info("Loading Facebook VITS and SpeechT5 models...")

            await self._load_speecht5()
            await self._load_vits()

            # Check if at least one model loaded
            if self.speecht5_model is not None or self.vits_model is not None:
                self.models_loaded = True
                logger.info("✅ Advanced TTS models loaded successfully!")
                return True

            logger.error("❌ No TTS models could be loaded")
            return False

        except Exception as e:
            logger.error("❌ Error loading TTS models: %s", e)
            return False

    async def _load_speecht5(self):
        """Load the SpeechT5 processor, model, vocoder and speaker embeddings.

        Errors are logged and swallowed so the caller can still try VITS.
        """
        try:
            logger.info("Loading Microsoft SpeechT5 model...")
            logger.info(
                "Using cache directory: %s",
                os.environ.get('TRANSFORMERS_CACHE', 'default'),
            )
            cache_dir = os.environ.get('TRANSFORMERS_CACHE', _DEFAULT_TRANSFORMERS_CACHE)

            # from_pretrained() blocks, so each call runs in the default
            # executor while the coroutine waits with a timeout.
            loop = asyncio.get_running_loop()
            processor_task = loop.run_in_executor(
                None,
                lambda: SpeechT5Processor.from_pretrained(
                    "microsoft/speecht5_tts", cache_dir=cache_dir
                ),
            )
            model_task = loop.run_in_executor(
                None,
                lambda: SpeechT5ForTextToSpeech.from_pretrained(
                    "microsoft/speecht5_tts", cache_dir=cache_dir
                ).to(self.device),
            )
            vocoder_task = loop.run_in_executor(
                None,
                lambda: SpeechT5HifiGan.from_pretrained(
                    "microsoft/speecht5_hifigan", cache_dir=cache_dir
                ).to(self.device),
            )

            # The three attributes are assigned together, only after every
            # load has finished, so a timeout leaves all of them untouched.
            (
                self.speecht5_processor,
                self.speecht5_model,
                self.speecht5_vocoder,
            ) = await asyncio.wait_for(
                asyncio.gather(processor_task, model_task, vocoder_task),
                timeout=_MODEL_LOAD_TIMEOUT_SECONDS,
            )

            await self._load_speaker_embeddings()
            logger.info("✅ SpeechT5 model loaded successfully")

        except asyncio.TimeoutError:
            logger.error("❌ SpeechT5 loading timed out after 5 minutes")
        except PermissionError as perm_error:
            logger.error("❌ SpeechT5 loading failed due to cache permission error: %s", perm_error)
            logger.error("💡 Try clearing cache directory or using different cache location")
        except Exception as speecht5_error:
            logger.warning("SpeechT5 loading failed: %s", speecht5_error)

    async def _load_speaker_embeddings(self):
        """Load the default speaker x-vector, or fall back to a random one."""
        logger.info("Loading speaker embeddings...")
        try:
            loop = asyncio.get_running_loop()
            # load_dataset() blocks, so keep it off the event loop thread.
            xvector = await loop.run_in_executor(
                None,
                lambda: load_dataset(
                    "Matthijs/cmu-arctic-xvectors", split="validation"
                )[0]["xvector"],
            )
            self.speaker_embeddings = torch.tensor(xvector).unsqueeze(0).to(self.device)
            logger.info("✅ Speaker embeddings loaded from dataset")
        except Exception as embed_error:
            logger.warning("Failed to load speaker embeddings from dataset: %s", embed_error)
            # Create default embedding
            self.speaker_embeddings = torch.randn(1, _EMBEDDING_DIM).to(self.device)
            logger.info("✅ Using generated speaker embeddings")

    async def _load_vits(self):
        """Load the VITS (MMS) model and tokenizer as the secondary engine.

        Errors are logged and swallowed; the caller decides whether any
        engine is usable.
        """
        try:
            logger.info("Loading Facebook VITS (MMS) model...")
            cache_dir = os.environ.get('TRANSFORMERS_CACHE', _DEFAULT_TRANSFORMERS_CACHE)

            loop = asyncio.get_running_loop()
            model_task = loop.run_in_executor(
                None,
                lambda: VitsModel.from_pretrained(
                    "facebook/mms-tts-eng", cache_dir=cache_dir
                ).to(self.device),
            )
            tokenizer_task = loop.run_in_executor(
                None,
                lambda: VitsTokenizer.from_pretrained(
                    "facebook/mms-tts-eng", cache_dir=cache_dir
                ),
            )

            self.vits_model, self.vits_tokenizer = await asyncio.wait_for(
                asyncio.gather(model_task, tokenizer_task),
                timeout=_MODEL_LOAD_TIMEOUT_SECONDS,
            )
            logger.info("✅ VITS model loaded successfully")

        except asyncio.TimeoutError:
            logger.error("❌ VITS loading timed out after 5 minutes")
        except PermissionError as perm_error:
            logger.error("❌ VITS loading failed due to cache permission error: %s", perm_error)
            logger.error("💡 Try clearing cache directory or using different cache location")
        except Exception as vits_error:
            logger.warning("VITS loading failed: %s", vits_error)

    def get_voice_embedding(self, voice_id: Optional[str] = None):
        """Get speaker embedding for different voices.

        Args:
            voice_id: One of the IDs listed by ``get_available_voices()``.
                ``None`` or an unknown ID selects the base speaker embedding.

        Returns:
            A ``(1, 512)`` tensor on ``self.device``.

        The variation for a known voice is seeded from a SHA-256 digest of the
        ID, so it is identical across processes (the builtin ``hash()`` of a
        string is salted per process). A private ``torch.Generator`` is used so
        the global torch RNG state is left untouched.
        """
        if self.speaker_embeddings is None:
            # Create default if not available
            self.speaker_embeddings = torch.randn(1, _EMBEDDING_DIM).to(self.device)

        profile = self._VOICE_PROFILES.get(voice_id) if voice_id is not None else None
        if profile is None:
            # Use original embeddings for unknown voice IDs
            return self.speaker_embeddings

        _, scale = profile
        digest = hashlib.sha256(voice_id.encode("utf-8")).digest()
        # 32 bits keeps the seed well inside the range manual_seed() accepts.
        seed = int.from_bytes(digest[:4], "big")
        generator = torch.Generator().manual_seed(seed)
        embedding = torch.randn(1, _EMBEDDING_DIM, generator=generator) * scale

        logger.info("Using voice variation for: %s", voice_id)
        return embedding.to(self.device)

    async def generate_with_vits(self, text: str, voice_id: Optional[str] = None) -> tuple:
        """Generate speech using Facebook VITS model.

        Args:
            text: Text to synthesise.
            voice_id: Accepted for signature parity with
                ``generate_with_speecht5``; not used by this method.

        Returns:
            ``(audio_data, sample_rate)`` where ``audio_data`` is a NumPy array
            and ``sample_rate`` comes from the model config.

        Raises:
            RuntimeError: If the VITS model or tokenizer is not loaded.
            Exception: Any error raised during tokenisation or inference is
                logged and re-raised.
        """
        try:
            if self.vits_model is None or self.vits_tokenizer is None:
                raise RuntimeError("VITS model not loaded")

            logger.info("Generating speech with VITS: %s...", text[:50])

            # Tokenize text
            inputs = self.vits_tokenizer(text, return_tensors="pt").to(self.device)

            # Generate speech
            with torch.no_grad():
                output = self.vits_model(**inputs).waveform

            # Convert to numpy
            audio_data = output.squeeze().cpu().numpy()
            sample_rate = self.vits_model.config.sampling_rate

            logger.info("✅ VITS generation successful: %.1fs", len(audio_data) / sample_rate)
            return audio_data, sample_rate

        except Exception as e:
            logger.error("VITS generation failed: %s", e)
            raise

    async def generate_with_speecht5(self, text: str, voice_id: Optional[str] = None) -> tuple:
        """Generate speech using Microsoft SpeechT5 model.

        Args:
            text: Text to synthesise.
            voice_id: Voice profile passed to ``get_voice_embedding``.

        Returns:
            ``(audio_data, sample_rate)`` with ``sample_rate`` fixed at 16000.

        Raises:
            RuntimeError: If the SpeechT5 model or processor is not loaded.
            Exception: Any error raised during processing or inference is
                logged and re-raised.
        """
        try:
            if self.speecht5_model is None or self.speecht5_processor is None:
                raise RuntimeError("SpeechT5 model not loaded")

            logger.info("Generating speech with SpeechT5: %s...", text[:50])

            # Process text
            inputs = self.speecht5_processor(text=text, return_tensors="pt").to(self.device)

            # Get speaker embedding
            speaker_embedding = self.get_voice_embedding(voice_id)

            # Generate speech
            with torch.no_grad():
                speech = self.speecht5_model.generate_speech(
                    inputs["input_ids"],
                    speaker_embedding,
                    vocoder=self.speecht5_vocoder,
                )

            # Convert to numpy
            audio_data = speech.cpu().numpy()
            sample_rate = 16000  # SpeechT5 default sample rate

            logger.info("✅ SpeechT5 generation successful: %.1fs", len(audio_data) / sample_rate)
            return audio_data, sample_rate

        except Exception as e:
            logger.error("SpeechT5 generation failed: %s", e)
            raise

    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> str:
        """
        Convert text to speech using Facebook VITS or SpeechT5

        SpeechT5 is tried first and VITS is the fallback. The audio is
        peak-normalised to 0.8 and written to a temporary ``.wav`` file.

        Args:
            text: Text to synthesise.
            voice_id: Optional voice profile ID.

        Returns:
            Path of the generated ``.wav`` file. The file is not deleted by
            this method; the caller owns it.

        Raises:
            Exception: If transformers is unavailable, models fail to load, or
                both engines fail. The underlying error is chained as
                ``__cause__``.
        """
        if not self.transformers_available:
            logger.error("❌ Transformers not available - cannot use advanced TTS")
            raise Exception("Advanced TTS models not available. Install: pip install transformers datasets")

        if not self.models_loaded:
            logger.info("TTS models not loaded, loading now...")
            success = await self.load_models()
            if not success:
                logger.error("TTS model loading failed")
                raise Exception("TTS models failed to load")

        try:
            logger.info("Generating speech for text: %s...", text[:50])
            logger.info("Using voice profile: %s", voice_id or 'default')

            # Try SpeechT5 first (usually better quality and more reliable)
            try:
                audio_data, sample_rate = await self.generate_with_speecht5(text, voice_id)
                method = "SpeechT5"
            except Exception as speecht5_error:
                logger.warning("SpeechT5 failed: %s", speecht5_error)

                # Fall back to VITS
                try:
                    audio_data, sample_rate = await self.generate_with_vits(text, voice_id)
                    method = "VITS"
                except Exception as vits_error:
                    logger.error("Both SpeechT5 and VITS failed")
                    logger.error("SpeechT5 error: %s", speecht5_error)
                    logger.error("VITS error: %s", vits_error)
                    raise Exception(
                        f"All advanced TTS methods failed: SpeechT5({speecht5_error}), VITS({vits_error})"
                    ) from vits_error

            # Normalize audio; all-zero (silent) output is left unscaled.
            peak = np.max(np.abs(audio_data))
            if peak > 0:
                audio_data = audio_data / peak * 0.8

            # Reserve a path and close the handle before writing: writing to a
            # path that is still held open fails on some platforms.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_file:
                output_path = temp_file.name
            try:
                sf.write(output_path, audio_data, samplerate=sample_rate)
            except Exception:
                # Do not leave an empty or partial file behind on failure.
                try:
                    os.unlink(output_path)
                except OSError:
                    pass
                raise

            logger.info("✅ Generated audio file: %s", output_path)
            logger.info(
                "📊 Audio details: %.1fs, %sHz, method: %s",
                len(audio_data) / sample_rate,
                sample_rate,
                method,
            )
            logger.info("🎙️ Using advanced open-source TTS models")
            return output_path

        except Exception as e:
            logger.error("❌ Critical error in advanced TTS generation: %s", str(e))
            logger.error("Exception type: %s", type(e).__name__)
            raise Exception(f"Advanced TTS generation failed: {e}") from e

    async def get_available_voices(self):
        """Get list of available voice configurations.

        Returns:
            A new dict mapping voice ID to its human-readable label.
        """
        return {
            voice_id: label
            for voice_id, (label, _scale) in self._VOICE_PROFILES.items()
        }

    def get_model_info(self):
        """Get information about loaded models.

        Returns:
            A dict describing load state, device, which engines are available,
            the primary/fallback engine names and the cache directory.
        """
        has_speecht5 = self.speecht5_model is not None
        has_vits = self.vits_model is not None

        if has_speecht5:
            primary = "SpeechT5"
        elif has_vits:
            primary = "VITS"
        else:
            primary = "None"

        return {
            "models_loaded": self.models_loaded,
            "transformers_available": self.transformers_available,
            "device": str(self.device),
            "vits_available": has_vits,
            "speecht5_available": has_speecht5,
            "primary_method": primary,
            # VITS only counts as a fallback when SpeechT5 is the primary.
            "fallback_method": "VITS" if has_speecht5 and has_vits else "None",
            "cache_directory": os.environ.get('TRANSFORMERS_CACHE', 'default'),
        }