File size: 9,450 Bytes
bf90fc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""

WebSocket handler for real-time TTS streaming.



Because apparently waiting 2 seconds for audio generation is too much for modern users.

At least this will make it FEEL faster.

"""

import asyncio
import json
import logging
import uuid
import time
from typing import Optional, Dict, Any
from datetime import datetime

from flask_socketio import SocketIO, emit, disconnect
from flask import request

from ttsfm import TTSClient, Voice, AudioFormat, TTSException
from ttsfm.utils import split_text_by_length, estimate_audio_duration

logger = logging.getLogger(__name__)


class WebSocketTTSHandler:
    """

    Handles WebSocket connections for streaming TTS generation.

    

    Because your users can't wait 2 seconds for a complete response.

    """
    
    def __init__(self, socketio: SocketIO, tts_client: TTSClient):
        self.socketio = socketio
        self.tts_client = tts_client
        self.active_sessions: Dict[str, Dict[str, Any]] = {}
        
        # Register WebSocket events
        self._register_events()
        
    def _register_events(self):
        """Register all WebSocket event handlers."""
        
        @self.socketio.on('connect')
        def handle_connect():
            """Handle new WebSocket connection."""
            session_id = request.sid
            self.active_sessions[session_id] = {
                'connected_at': datetime.now(),
                'request_count': 0,
                'last_request': None
            }
            logger.info(f"WebSocket client connected: {session_id}")
            emit('connected', {'session_id': session_id, 'status': 'ready'})
            
        @self.socketio.on('disconnect')
        def handle_disconnect():
            """Handle WebSocket disconnection."""
            session_id = request.sid
            if session_id in self.active_sessions:
                del self.active_sessions[session_id]
            logger.info(f"WebSocket client disconnected: {session_id}")
            
        @self.socketio.on('generate_stream')
        def handle_generate_stream(data):
            """

            Handle streaming TTS generation request.

            

            Expected data format:

            {

                'text': str,

                'voice': str,

                'format': str,

                'chunk_size': int (optional, default 1024 chars),

                'instructions': str (optional, voice modulation instructions)

            }

            """
            session_id = request.sid
            request_id = data.get('request_id', str(uuid.uuid4()))
            
            # Update session info
            if session_id in self.active_sessions:
                self.active_sessions[session_id]['request_count'] += 1
                self.active_sessions[session_id]['last_request'] = datetime.now()
            
            # Emit acknowledgment
            emit('stream_started', {
                'request_id': request_id,
                'timestamp': time.time()
            })
            
            # Start async generation
            self.socketio.start_background_task(
                self._generate_stream,
                session_id,
                request_id,
                data
            )
            
        @self.socketio.on('cancel_stream')
        def handle_cancel_stream(data):
            """Handle stream cancellation request."""
            request_id = data.get('request_id')
            session_id = request.sid
            
            # In a real implementation, you'd track and cancel the actual generation
            logger.info(f"Stream cancellation requested: {request_id}")
            emit('stream_cancelled', {'request_id': request_id})
    
    def _generate_stream(self, session_id: str, request_id: str, data: Dict[str, Any]):
        """

        Generate TTS audio in chunks and stream to client.

        

        This is where the magic happens. And by magic, I mean

        chunking text and pretending it's real-time.

        """
        try:
            # Extract parameters
            text = data.get('text', '')
            voice = data.get('voice', 'alloy')
            format_str = data.get('format', 'mp3')
            chunk_size = data.get('chunk_size', 1024)
            instructions = data.get('instructions', None)  # Voice instructions support!
            
            if not text:
                self._emit_error(session_id, request_id, "No text provided")
                return
                
            # Convert string parameters to enums
            try:
                voice_enum = Voice(voice.lower())
                format_enum = AudioFormat(format_str.lower())
            except ValueError as e:
                self._emit_error(session_id, request_id, f"Invalid parameter: {str(e)}")
                return
            
            # Split text into chunks for "streaming" effect
            chunks = split_text_by_length(text, chunk_size, preserve_words=True)
            total_chunks = len(chunks)
            
            logger.info(f"Starting stream generation: {request_id} with {total_chunks} chunks")
            
            # Emit initial progress
            self.socketio.emit('stream_progress', {
                'request_id': request_id,
                'progress': 0,
                'total_chunks': total_chunks,
                'status': 'processing'
            }, room=session_id)
            
            # Process each chunk
            for i, chunk in enumerate(chunks):
                # Check if client is still connected
                if session_id not in self.active_sessions:
                    logger.warning(f"Client disconnected during generation: {session_id}")
                    break
                
                try:
                    # Generate audio for chunk
                    start_time = time.time()
                    response = self.tts_client.generate_speech(
                        text=chunk,
                        voice=voice_enum,
                        response_format=format_enum,
                        instructions=instructions,  # Pass voice instructions!
                        validate_length=False  # We already chunked it
                    )
                    generation_time = time.time() - start_time
                    
                    # Emit chunk data
                    chunk_data = {
                        'request_id': request_id,
                        'chunk_index': i,
                        'total_chunks': total_chunks,
                        'audio_data': response.audio_data.hex(),  # Convert bytes to hex string
                        'format': format_enum.value,
                        'duration': response.duration,
                        'generation_time': generation_time,
                        'chunk_text': chunk[:50] + '...' if len(chunk) > 50 else chunk
                    }
                    
                    self.socketio.emit('audio_chunk', chunk_data, room=session_id)
                    
                    # Emit progress update
                    progress = int(((i + 1) / total_chunks) * 100)
                    self.socketio.emit('stream_progress', {
                        'request_id': request_id,
                        'progress': progress,
                        'total_chunks': total_chunks,
                        'chunks_completed': i + 1,
                        'status': 'processing'
                    }, room=session_id)
                    
                    # Small delay to prevent overwhelming the client
                    # (and to make it feel more "real-time")
                    self.socketio.sleep(0.1)
                    
                except Exception as e:
                    logger.error(f"Error generating chunk {i}: {str(e)}")
                    self._emit_error(session_id, request_id, f"Chunk {i} generation failed: {str(e)}")
                    # Continue with next chunk instead of failing completely
                    continue
            
            # Emit completion
            self.socketio.emit('stream_complete', {
                'request_id': request_id,
                'total_chunks': total_chunks,
                'status': 'completed',
                'timestamp': time.time()
            }, room=session_id)
            
            logger.info(f"Stream generation completed: {request_id}")
            
        except Exception as e:
            logger.error(f"Stream generation failed: {str(e)}")
            self._emit_error(session_id, request_id, str(e))
    
    def _emit_error(self, session_id: str, request_id: str, error_message: str):
        """Emit error to specific session."""
        self.socketio.emit('stream_error', {
            'request_id': request_id,
            'error': error_message,
            'timestamp': time.time()
        }, room=session_id)
    
    def get_active_sessions_count(self) -> int:
        """Get count of active WebSocket sessions."""
        return len(self.active_sessions)
    
    def get_session_info(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get information about a specific session."""
        return self.active_sessions.get(session_id)