AI_Avatar_Chat / hf_tts_client.py
import torch
import tempfile
import logging
import soundfile as sf
import numpy as np
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from datasets import load_dataset
import asyncio
from typing import Optional
logger = logging.getLogger(__name__)


class HuggingFaceTTSClient:
    """
    Hugging Face TTS client using Microsoft SpeechT5
    Replaces ElevenLabs with free, open-source TTS
    """

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = None
        self.model = None
        self.vocoder = None
        self.speaker_embeddings = None
        self.model_loaded = False
        logger.info(f"HF TTS Client initialized on device: {self.device}")
    async def load_model(self):
        """Load SpeechT5 model and vocoder"""
        try:
            logger.info("Loading SpeechT5 TTS model...")

            # Load processor, model and vocoder
            self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
            self.model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(self.device)
            self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(self.device)

            # Load speaker embeddings dataset
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            self.speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(self.device)

            self.model_loaded = True
            logger.info("✅ SpeechT5 TTS model loaded successfully")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to load TTS model: {e}")
            return False
    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> str:
        """
        Convert text to speech using SpeechT5

        Args:
            text: Text to convert to speech
            voice_id: Voice identifier (for compatibility, maps to speaker embeddings)

        Returns:
            Path to the generated audio file
        """
        if not self.model_loaded:
            logger.info("Model not loaded, loading now...")
            success = await self.load_model()
            if not success:
                raise Exception("Failed to load TTS model")

        try:
            logger.info(f"Generating speech for text: {text[:50]}...")

            # Choose a speaker embedding based on voice_id (for variety).
            # load_dataset() reuses the local Hugging Face cache after the first
            # download, so this per-call lookup does not re-download the dataset.
            speaker_idx = self._get_speaker_index(voice_id)
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            speaker_embeddings = torch.tensor(embeddings_dataset[speaker_idx]["xvector"]).unsqueeze(0).to(self.device)

            # Tokenize the input text
            inputs = self.processor(text=text, return_tensors="pt").to(self.device)

            # Generate the speech waveform
            with torch.no_grad():
                speech = self.model.generate_speech(
                    inputs["input_ids"],
                    speaker_embeddings,
                    vocoder=self.vocoder
                )

            # Move the waveform to CPU as a NumPy array
            audio_data = speech.cpu().numpy()

            # Save to a temporary WAV file (SpeechT5 + HiFi-GAN output 16 kHz audio)
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            sf.write(temp_file.name, audio_data, samplerate=16000)
            temp_file.close()

            logger.info(f"✅ Generated speech audio: {temp_file.name}")
            return temp_file.name

        except Exception as e:
            logger.error(f"❌ Error generating speech: {e}")
            raise Exception(f"TTS generation failed: {e}")
    def _get_speaker_index(self, voice_id: Optional[str]) -> int:
        """Map voice_id to speaker embedding index for voice variety"""
        voice_mapping = {
            # Map ElevenLabs voice IDs to speaker indices for compatibility
            "21m00Tcm4TlvDq8ikWAM": 7306,  # Female voice (default)
            "pNInz6obpgDQGcFmaJgB": 4077,  # Male voice
            "EXAVITQu4vr4xnSDxMaL": 1995,  # Female voice (sweet)
            "ErXwobaYiN019PkySvjV": 8051,  # Male voice (professional)
            "TxGEqnHWrfWFTfGW9XjX": 5688,  # Deep male voice
            "yoZ06aMxZJJ28mfd3POQ": 3570,  # Friendly voice
            "AZnzlk1XvdvUeBnXmlld": 2967,  # Strong female
        }
        return voice_mapping.get(voice_id, 7306)  # Default to female voice
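

# Illustrative usage sketch (an assumption about how the app drives this client,
# not part of the original wiring): the model loads lazily on the first
# text_to_speech() call and a short clip is written to a temporary WAV file.
# The voice_id is one of the ElevenLabs-style IDs mapped above; unknown IDs
# fall back to speaker embedding 7306.
if __name__ == "__main__":
    async def _demo():
        client = HuggingFaceTTSClient()
        wav_path = await client.text_to_speech(
            "Hello! This is SpeechT5 speaking instead of ElevenLabs.",
            voice_id="21m00Tcm4TlvDq8ikWAM",  # default female voice in the mapping
        )
        print(f"Generated audio file: {wav_path}")

    asyncio.run(_demo())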