import torch
import tempfile
import logging
import soundfile as sf
import numpy as np
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import asyncio
from typing import Optional

logger = logging.getLogger(__name__)


class HuggingFaceTTSClient:
    """
    Hugging Face TTS client using Microsoft SpeechT5
    Replaces ElevenLabs with free, open-source TTS
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = None
        self.model = None
        self.vocoder = None
        self.speaker_embeddings = None
        self.model_loaded = False
        logger.info(f"HF TTS Client initialized on device: {self.device}")

    async def load_model(self):
        """Load SpeechT5 model and vocoder"""
        try:
            logger.info("Loading SpeechT5 TTS model...")
            # Load processor, model and vocoder
            self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            self.model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(self.device)
            self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(self.device)
            # Load speaker embeddings dataset
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(self.device)
            self.model_loaded = True
            logger.info("✅ SpeechT5 TTS model loaded successfully")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load TTS model: {e}")
            return False

    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> str:
        """
        Convert text to speech using SpeechT5

        Args:
            text: Text to convert to speech
            voice_id: Voice identifier (for compatibility, maps to speaker embeddings)

        Returns:
            Path to generated audio file
        """
        if not self.model_loaded:
            logger.info("Model not loaded, loading now...")
            success = await self.load_model()
            if not success:
                raise Exception("Failed to load TTS model")
        try:
            logger.info(f"Generating speech for text: {text[:50]}...")
            # Choose speaker embedding based on voice_id (for variety)
            speaker_idx = self._get_speaker_index(voice_id)
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[speaker_idx]["xvector"]).unsqueeze(0).to(self.device)
            # Process text
            inputs = self.processor(text=text, return_tensors="pt").to(self.device)
            # Generate speech
            with torch.no_grad():
                speech = self.model.generate_speech(
                    inputs["input_ids"],
                    speaker_embeddings,
                    vocoder=self.vocoder
                )
            # Convert to audio file
            audio_data = speech.cpu().numpy()
            # Save to temporary file
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            sf.write(temp_file.name, audio_data, samplerate=16000)
            temp_file.close()
            logger.info(f"✅ Generated speech audio: {temp_file.name}")
            return temp_file.name
        except Exception as e:
            logger.error(f"❌ Error generating speech: {e}")
            raise Exception(f"TTS generation failed: {e}")

    def _get_speaker_index(self, voice_id: Optional[str]) -> int:
        """Map voice_id to speaker embedding index for voice variety"""
        voice_mapping = {
            # Map ElevenLabs voice IDs to speaker indices for compatibility
            "21m00Tcm4TlvDq8ikWAM": 7306,  # Female voice (default)
            "pNInz6obpgDQGcFmaJgB": 4077,  # Male voice
            "EXAVITQu4vr4xnSDxMaL": 1995,  # Female voice (sweet)
            "ErXwobaYiN019PkySvjV": 8051,  # Male voice (professional)
            "TxGEqnHWrfWFTfGW9XjX": 5688,  # Deep male voice
            "yoZ06aMxZJJ28mfd3POQ": 3570,  # Friendly voice
            "AZnzlk1XvdvUeBnXmlld": 2967,  # Strong female
        }
        return voice_mapping.get(voice_id, 7306)  # Default to female voice
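

# --- Usage sketch (added for illustration; not part of the original client) ---
# A minimal example of driving the client from an asyncio entry point. The
# sample text, the _demo() name, and the ElevenLabs-style voice_id below are
# illustrative placeholders; any key from _get_speaker_index's voice_mapping
# (or None for the default voice) would work.
async def _demo() -> None:
    client = HuggingFaceTTSClient()
    # Explicit load is optional: text_to_speech() lazily loads the model,
    # but loading up front surfaces download or memory problems early.
    if not await client.load_model():
        raise SystemExit("Model failed to load")
    wav_path = await client.text_to_speech(
        "Hello from SpeechT5 running on Hugging Face.",
        voice_id="21m00Tcm4TlvDq8ikWAM",  # maps to speaker index 7306 (default female voice)
    )
    print(f"Generated audio file: {wav_path}")


if __name__ == "__main__":
    asyncio.run(_demo())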