Spaces:
Runtime error
Runtime error
| """ | |
| WebSocket handler for real-time TTS streaming. | |
| Because apparently waiting 2 seconds for audio generation is too much for modern users. | |
| At least this will make it FEEL faster. | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import uuid | |
| import time | |
| from typing import Optional, Dict, Any | |
| from datetime import datetime | |
| from flask_socketio import SocketIO, emit, disconnect | |
| from flask import request | |
| from ttsfm import TTSClient, Voice, AudioFormat, TTSException | |
| from ttsfm.utils import split_text_by_length, estimate_audio_duration | |
| logger = logging.getLogger(__name__) | |
| class WebSocketTTSHandler: | |
| """ | |
| Handles WebSocket connections for streaming TTS generation. | |
| Because your users can't wait 2 seconds for a complete response. | |
| """ | |
| def __init__(self, socketio: SocketIO, tts_client: TTSClient): | |
| self.socketio = socketio | |
| self.tts_client = tts_client | |
| self.active_sessions: Dict[str, Dict[str, Any]] = {} | |
| # Register WebSocket events | |
| self._register_events() | |
| def _register_events(self): | |
| """Register all WebSocket event handlers.""" | |
| def handle_connect(): | |
| """Handle new WebSocket connection.""" | |
| session_id = request.sid | |
| self.active_sessions[session_id] = { | |
| 'connected_at': datetime.now(), | |
| 'request_count': 0, | |
| 'last_request': None | |
| } | |
| logger.info(f"WebSocket client connected: {session_id}") | |
| emit('connected', {'session_id': session_id, 'status': 'ready'}) | |
| def handle_disconnect(): | |
| """Handle WebSocket disconnection.""" | |
| session_id = request.sid | |
| if session_id in self.active_sessions: | |
| del self.active_sessions[session_id] | |
| logger.info(f"WebSocket client disconnected: {session_id}") | |
| def handle_generate_stream(data): | |
| """ | |
| Handle streaming TTS generation request. | |
| Expected data format: | |
| { | |
| 'text': str, | |
| 'voice': str, | |
| 'format': str, | |
| 'chunk_size': int (optional, default 1024 chars), | |
| 'instructions': str (optional, voice modulation instructions) | |
| } | |
| """ | |
| session_id = request.sid | |
| request_id = data.get('request_id', str(uuid.uuid4())) | |
| # Update session info | |
| if session_id in self.active_sessions: | |
| self.active_sessions[session_id]['request_count'] += 1 | |
| self.active_sessions[session_id]['last_request'] = datetime.now() | |
| # Emit acknowledgment | |
| emit('stream_started', { | |
| 'request_id': request_id, | |
| 'timestamp': time.time() | |
| }) | |
| # Start async generation | |
| self.socketio.start_background_task( | |
| self._generate_stream, | |
| session_id, | |
| request_id, | |
| data | |
| ) | |
| def handle_cancel_stream(data): | |
| """Handle stream cancellation request.""" | |
| request_id = data.get('request_id') | |
| session_id = request.sid | |
| # In a real implementation, you'd track and cancel the actual generation | |
| logger.info(f"Stream cancellation requested: {request_id}") | |
| emit('stream_cancelled', {'request_id': request_id}) | |
| def _generate_stream(self, session_id: str, request_id: str, data: Dict[str, Any]): | |
| """ | |
| Generate TTS audio in chunks and stream to client. | |
| This is where the magic happens. And by magic, I mean | |
| chunking text and pretending it's real-time. | |
| """ | |
| try: | |
| # Extract parameters | |
| text = data.get('text', '') | |
| voice = data.get('voice', 'alloy') | |
| format_str = data.get('format', 'mp3') | |
| chunk_size = data.get('chunk_size', 1024) | |
| instructions = data.get('instructions', None) # Voice instructions support! | |
| if not text: | |
| self._emit_error(session_id, request_id, "No text provided") | |
| return | |
| # Convert string parameters to enums | |
| try: | |
| voice_enum = Voice(voice.lower()) | |
| format_enum = AudioFormat(format_str.lower()) | |
| except ValueError as e: | |
| self._emit_error(session_id, request_id, f"Invalid parameter: {str(e)}") | |
| return | |
| # Split text into chunks for "streaming" effect | |
| chunks = split_text_by_length(text, chunk_size, preserve_words=True) | |
| total_chunks = len(chunks) | |
| logger.info(f"Starting stream generation: {request_id} with {total_chunks} chunks") | |
| # Emit initial progress | |
| self.socketio.emit('stream_progress', { | |
| 'request_id': request_id, | |
| 'progress': 0, | |
| 'total_chunks': total_chunks, | |
| 'status': 'processing' | |
| }, room=session_id) | |
| # Process each chunk | |
| for i, chunk in enumerate(chunks): | |
| # Check if client is still connected | |
| if session_id not in self.active_sessions: | |
| logger.warning(f"Client disconnected during generation: {session_id}") | |
| break | |
| try: | |
| # Generate audio for chunk | |
| start_time = time.time() | |
| response = self.tts_client.generate_speech( | |
| text=chunk, | |
| voice=voice_enum, | |
| response_format=format_enum, | |
| instructions=instructions, # Pass voice instructions! | |
| validate_length=False # We already chunked it | |
| ) | |
| generation_time = time.time() - start_time | |
| # Emit chunk data | |
| chunk_data = { | |
| 'request_id': request_id, | |
| 'chunk_index': i, | |
| 'total_chunks': total_chunks, | |
| 'audio_data': response.audio_data.hex(), # Convert bytes to hex string | |
| 'format': format_enum.value, | |
| 'duration': response.duration, | |
| 'generation_time': generation_time, | |
| 'chunk_text': chunk[:50] + '...' if len(chunk) > 50 else chunk | |
| } | |
| self.socketio.emit('audio_chunk', chunk_data, room=session_id) | |
| # Emit progress update | |
| progress = int(((i + 1) / total_chunks) * 100) | |
| self.socketio.emit('stream_progress', { | |
| 'request_id': request_id, | |
| 'progress': progress, | |
| 'total_chunks': total_chunks, | |
| 'chunks_completed': i + 1, | |
| 'status': 'processing' | |
| }, room=session_id) | |
| # Small delay to prevent overwhelming the client | |
| # (and to make it feel more "real-time") | |
| self.socketio.sleep(0.1) | |
| except Exception as e: | |
| logger.error(f"Error generating chunk {i}: {str(e)}") | |
| self._emit_error(session_id, request_id, f"Chunk {i} generation failed: {str(e)}") | |
| # Continue with next chunk instead of failing completely | |
| continue | |
| # Emit completion | |
| self.socketio.emit('stream_complete', { | |
| 'request_id': request_id, | |
| 'total_chunks': total_chunks, | |
| 'status': 'completed', | |
| 'timestamp': time.time() | |
| }, room=session_id) | |
| logger.info(f"Stream generation completed: {request_id}") | |
| except Exception as e: | |
| logger.error(f"Stream generation failed: {str(e)}") | |
| self._emit_error(session_id, request_id, str(e)) | |
| def _emit_error(self, session_id: str, request_id: str, error_message: str): | |
| """Emit error to specific session.""" | |
| self.socketio.emit('stream_error', { | |
| 'request_id': request_id, | |
| 'error': error_message, | |
| 'timestamp': time.time() | |
| }, room=session_id) | |
| def get_active_sessions_count(self) -> int: | |
| """Get count of active WebSocket sessions.""" | |
| return len(self.active_sessions) | |
| def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]: | |
| """Get information about a specific session.""" | |
| return self.active_sessions.get(session_id) |