import os
import torch
import tempfile
import logging
import soundfile as sf
import numpy as np
import asyncio
from typing import Optional

# Set HuggingFace cache directories before importing transformers
os.environ.setdefault('HF_HOME', '/tmp/huggingface')
os.environ.setdefault('TRANSFORMERS_CACHE', '/tmp/huggingface/transformers')
os.environ.setdefault('HF_DATASETS_CACHE', '/tmp/huggingface/datasets')
os.environ.setdefault('HUGGINGFACE_HUB_CACHE', '/tmp/huggingface/hub')

# Create cache directories
for cache_dir in ['/tmp/huggingface', '/tmp/huggingface/transformers', '/tmp/huggingface/datasets', '/tmp/huggingface/hub']:
    os.makedirs(cache_dir, exist_ok=True)
# Try to import transformers components
try:
    from transformers import (
        VitsModel,
        VitsTokenizer,
        SpeechT5Processor,
        SpeechT5ForTextToSpeech,
        SpeechT5HifiGan
    )
    from datasets import load_dataset
    TRANSFORMERS_AVAILABLE = True
    print("✅ Transformers and datasets available")
except ImportError as e:
    TRANSFORMERS_AVAILABLE = False
    print(f"⚠️ Advanced TTS models not available: {e}")
    print("💡 Install with: pip install transformers datasets")

logger = logging.getLogger(__name__)
class AdvancedTTSClient:
    """
    Advanced TTS client using Facebook VITS and SpeechT5 models.
    Falls back gracefully if models are not available.
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.models_loaded = False
        self.transformers_available = TRANSFORMERS_AVAILABLE

        # Model instances - loaded on demand
        self.vits_model = None
        self.vits_tokenizer = None
        self.speecht5_processor = None
        self.speecht5_model = None
        self.speecht5_vocoder = None
        self.speaker_embeddings = None

        logger.info(f"Advanced TTS Client initialized on device: {self.device}")
        logger.info(f"Transformers available: {self.transformers_available}")
    async def load_models(self):
        """Load TTS models asynchronously."""
        if not self.transformers_available:
            logger.warning("❌ Transformers not available - cannot load advanced TTS models")
            return False

        try:
            logger.info("Loading Facebook VITS and SpeechT5 models...")

            # Load SpeechT5 model (Microsoft) - usually more reliable
            try:
                logger.info("Loading Microsoft SpeechT5 model...")
                logger.info(f"Using cache directory: {os.environ.get('TRANSFORMERS_CACHE', 'default')}")
                cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/tmp/huggingface/transformers')

                # Run the blocking from_pretrained calls in worker threads with a timeout
                async def load_model_with_timeout():
                    loop = asyncio.get_running_loop()

                    # Load processor
                    processor_task = loop.run_in_executor(
                        None,
                        lambda: SpeechT5Processor.from_pretrained(
                            "microsoft/speecht5_tts",
                            cache_dir=cache_dir
                        )
                    )
                    # Load model
                    model_task = loop.run_in_executor(
                        None,
                        lambda: SpeechT5ForTextToSpeech.from_pretrained(
                            "microsoft/speecht5_tts",
                            cache_dir=cache_dir
                        ).to(self.device)
                    )
                    # Load vocoder
                    vocoder_task = loop.run_in_executor(
                        None,
                        lambda: SpeechT5HifiGan.from_pretrained(
                            "microsoft/speecht5_hifigan",
                            cache_dir=cache_dir
                        ).to(self.device)
                    )

                    # Wait for all three with a timeout
                    self.speecht5_processor, self.speecht5_model, self.speecht5_vocoder = await asyncio.wait_for(
                        asyncio.gather(processor_task, model_task, vocoder_task),
                        timeout=300  # 5-minute timeout
                    )

                await load_model_with_timeout()

                # Load speaker embeddings for SpeechT5
                logger.info("Loading speaker embeddings...")
                try:
                    embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
                    self.speaker_embeddings = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0).to(self.device)
                    logger.info("✅ Speaker embeddings loaded from dataset")
                except Exception as embed_error:
                    logger.warning(f"Failed to load speaker embeddings from dataset: {embed_error}")
                    # Fall back to a randomly generated default embedding
                    self.speaker_embeddings = torch.randn(1, 512).to(self.device)
                    logger.info("✅ Using generated speaker embeddings")

                logger.info("✅ SpeechT5 model loaded successfully")
            except asyncio.TimeoutError:
                logger.error("❌ SpeechT5 loading timed out after 5 minutes")
            except PermissionError as perm_error:
                logger.error(f"❌ SpeechT5 loading failed due to cache permission error: {perm_error}")
                logger.error("💡 Try clearing the cache directory or using a different cache location")
            except Exception as speecht5_error:
                logger.warning(f"SpeechT5 loading failed: {speecht5_error}")
            # Try to load VITS model (Facebook MMS) as a secondary option
            try:
                logger.info("Loading Facebook VITS (MMS) model...")
                cache_dir = os.environ.get('TRANSFORMERS_CACHE', '/tmp/huggingface/transformers')

                async def load_vits_with_timeout():
                    loop = asyncio.get_running_loop()
                    model_task = loop.run_in_executor(
                        None,
                        lambda: VitsModel.from_pretrained(
                            "facebook/mms-tts-eng",
                            cache_dir=cache_dir
                        ).to(self.device)
                    )
                    tokenizer_task = loop.run_in_executor(
                        None,
                        lambda: VitsTokenizer.from_pretrained(
                            "facebook/mms-tts-eng",
                            cache_dir=cache_dir
                        )
                    )
                    self.vits_model, self.vits_tokenizer = await asyncio.wait_for(
                        asyncio.gather(model_task, tokenizer_task),
                        timeout=300  # 5-minute timeout
                    )

                await load_vits_with_timeout()
                logger.info("✅ VITS model loaded successfully")
            except asyncio.TimeoutError:
                logger.error("❌ VITS loading timed out after 5 minutes")
            except PermissionError as perm_error:
                logger.error(f"❌ VITS loading failed due to cache permission error: {perm_error}")
                logger.error("💡 Try clearing the cache directory or using a different cache location")
            except Exception as vits_error:
                logger.warning(f"VITS loading failed: {vits_error}")

            # Check whether at least one model loaded
            if self.speecht5_model is not None or self.vits_model is not None:
                self.models_loaded = True
                logger.info("✅ Advanced TTS models loaded successfully!")
                return True
            else:
                logger.error("❌ No TTS models could be loaded")
                return False
        except Exception as e:
            logger.error(f"❌ Error loading TTS models: {e}")
            return False
    def get_voice_embedding(self, voice_id: Optional[str] = None):
        """Get a speaker embedding for the requested voice."""
        if self.speaker_embeddings is None:
            # Create a default embedding if none is available
            self.speaker_embeddings = torch.randn(1, 512).to(self.device)

        if voice_id is None:
            return self.speaker_embeddings

        # Map known voice IDs to embedding variations with different characteristics.
        # Note: hash() is salted per process, so variations are only stable within a run.
        voice_seed = abs(hash(voice_id)) % 1000
        torch.manual_seed(voice_seed)

        voice_variations = {
            "21m00Tcm4TlvDq8ikWAM": torch.randn(1, 512) * 0.8,  # Female-ish
            "pNInz6obpgDQGcFmaJgB": torch.randn(1, 512) * 1.2,  # Male-ish
            "EXAVITQu4vr4xnSDxMaL": torch.randn(1, 512) * 0.6,  # Sweet
            "ErXwobaYiN019PkySvjV": torch.randn(1, 512) * 1.0,  # Professional
            "TxGEqnHWrfGW9XjX": torch.randn(1, 512) * 1.4,      # Deep
            "yoZ06aMxZJJ28mfd3POQ": torch.randn(1, 512) * 0.9,  # Friendly
            "AZnzlk1XvdvUeBnXmlld": torch.randn(1, 512) * 1.1,  # Strong
        }

        if voice_id in voice_variations:
            embedding = voice_variations[voice_id].to(self.device)
            logger.info(f"Using voice variation for: {voice_id}")
            return embedding
        else:
            # Use the default embeddings for unknown voice IDs
            return self.speaker_embeddings
    async def generate_with_vits(self, text: str, voice_id: Optional[str] = None) -> tuple:
        """Generate speech using Facebook VITS model"""
        try:
            if not self.vits_model or not self.vits_tokenizer:
                raise Exception("VITS model not loaded")

            logger.info(f"Generating speech with VITS: {text[:50]}...")

            # Tokenize text
            inputs = self.vits_tokenizer(text, return_tensors="pt").to(self.device)

            # Generate speech
            with torch.no_grad():
                output = self.vits_model(**inputs).waveform

            # Convert to numpy
            audio_data = output.squeeze().cpu().numpy()
            sample_rate = self.vits_model.config.sampling_rate

            logger.info(f"✅ VITS generation successful: {len(audio_data)/sample_rate:.1f}s")
            return audio_data, sample_rate
        except Exception as e:
            logger.error(f"VITS generation failed: {e}")
            raise
    async def generate_with_speecht5(self, text: str, voice_id: Optional[str] = None) -> tuple:
        """Generate speech using Microsoft SpeechT5 model"""
        try:
            if not self.speecht5_model or not self.speecht5_processor:
                raise Exception("SpeechT5 model not loaded")

            logger.info(f"Generating speech with SpeechT5: {text[:50]}...")

            # Process text
            inputs = self.speecht5_processor(text=text, return_tensors="pt").to(self.device)

            # Get speaker embedding
            speaker_embedding = self.get_voice_embedding(voice_id)

            # Generate speech
            with torch.no_grad():
                speech = self.speecht5_model.generate_speech(
                    inputs["input_ids"],
                    speaker_embedding,
                    vocoder=self.speecht5_vocoder
                )

            # Convert to numpy
            audio_data = speech.cpu().numpy()
            sample_rate = 16000  # SpeechT5 default sample rate

            logger.info(f"✅ SpeechT5 generation successful: {len(audio_data)/sample_rate:.1f}s")
            return audio_data, sample_rate
        except Exception as e:
            logger.error(f"SpeechT5 generation failed: {e}")
            raise
    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> str:
        """
        Convert text to speech using Facebook VITS or SpeechT5.
        Returns the path to a temporary WAV file containing the generated audio.
        """
        if not self.transformers_available:
            logger.error("❌ Transformers not available - cannot use advanced TTS")
            raise Exception("Advanced TTS models not available. Install: pip install transformers datasets")

        if not self.models_loaded:
            logger.info("TTS models not loaded, loading now...")
            success = await self.load_models()
            if not success:
                logger.error("TTS model loading failed")
                raise Exception("TTS models failed to load")

        try:
            logger.info(f"Generating speech for text: {text[:50]}...")
            logger.info(f"Using voice profile: {voice_id or 'default'}")

            # Try SpeechT5 first (usually better quality and more reliable)
            try:
                audio_data, sample_rate = await self.generate_with_speecht5(text, voice_id)
                method = "SpeechT5"
            except Exception as speecht5_error:
                logger.warning(f"SpeechT5 failed: {speecht5_error}")
                # Fall back to VITS
                try:
                    audio_data, sample_rate = await self.generate_with_vits(text, voice_id)
                    method = "VITS"
                except Exception as vits_error:
                    logger.error("Both SpeechT5 and VITS failed")
                    logger.error(f"SpeechT5 error: {speecht5_error}")
                    logger.error(f"VITS error: {vits_error}")
                    raise Exception(f"All advanced TTS methods failed: SpeechT5({speecht5_error}), VITS({vits_error})")

            # Normalize audio to 80% of full scale to avoid clipping
            if np.max(np.abs(audio_data)) > 0:
                audio_data = audio_data / np.max(np.abs(audio_data)) * 0.8

            # Save to a temporary WAV file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            sf.write(temp_file.name, audio_data, samplerate=sample_rate)
            temp_file.close()

            logger.info(f"✅ Generated audio file: {temp_file.name}")
            logger.info(f"📊 Audio details: {len(audio_data)/sample_rate:.1f}s, {sample_rate}Hz, method: {method}")
            logger.info("🎙️ Using advanced open-source TTS models")
            return temp_file.name
        except Exception as e:
            logger.error(f"❌ Critical error in advanced TTS generation: {str(e)}")
            logger.error(f"Exception type: {type(e).__name__}")
            raise Exception(f"Advanced TTS generation failed: {e}")
    async def get_available_voices(self):
        """Get list of available voice configurations"""
        return {
            "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
            "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
            "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
            "ErXwobaYiN019PkySvjV": "Male (Professional)",
            "TxGEqnHWrfGW9XjX": "Male (Deep)",
            "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
            "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
        }
    def get_model_info(self):
        """Get information about loaded models"""
        return {
            "models_loaded": self.models_loaded,
            "transformers_available": self.transformers_available,
            "device": str(self.device),
            "vits_available": self.vits_model is not None,
            "speecht5_available": self.speecht5_model is not None,
            "primary_method": "SpeechT5" if self.speecht5_model else "VITS" if self.vits_model else "None",
            "fallback_method": "VITS" if self.speecht5_model and self.vits_model else "None",
            "cache_directory": os.environ.get('TRANSFORMERS_CACHE', 'default')
        }
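

# Illustrative usage sketch (not part of the original module): a minimal example of
# how AdvancedTTSClient might be exercised from a script, assuming transformers and
# datasets are installed and the model weights can be downloaded. The "demo" coroutine
# name and the sample text are purely illustrative.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    async def demo():
        client = AdvancedTTSClient()
        loaded = await client.load_models()
        print("Model info:", client.get_model_info())
        if loaded:
            wav_path = await client.text_to_speech(
                "Hello from the advanced TTS client.",
                voice_id="21m00Tcm4TlvDq8ikWAM",
            )
            print("Generated audio file:", wav_path)

    asyncio.run(demo())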