from __future__ import annotations

import asyncio
import time
from queue import Empty, Queue
from typing import Any, Optional

import torch
from transformers.generation import BaseStreamer

class AudioStreamer(BaseStreamer):
    """
    Audio streamer that stores audio chunks in queues, one queue per sample in
    the batch. This allows streaming audio generation for multiple samples
    simultaneously.

    Parameters:
        batch_size (`int`):
            The batch size for generation.
        stop_signal (`Any`, *optional*):
            The sentinel placed in a queue when generation for that sample
            ends. Defaults to `None`.
        timeout (`float`, *optional*):
            The timeout for queue operations. If `None`, queue operations
            block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout
        # Create a queue for each sample in the batch
        self.audio_queues = [Queue() for _ in range(batch_size)]
        self.finished_flags = [False for _ in range(batch_size)]
        self.sample_indices_map = {}  # Maps from sample index to queue index
    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Receives audio chunks and puts them in the appropriate queues.

        Args:
            audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks.
            sample_indices: Tensor indicating which samples these chunks belong to.
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                # Detach from the graph and move to CPU before handing the
                # chunk to consumer threads
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)
    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for specified samples or all samples.

        Args:
            sample_indices: Optional tensor of sample indices to end. If None, ends all.
        """
        if sample_indices is None:
            # End all samples
            for idx in range(self.batch_size):
                if not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True
        else:
            # End specific samples
            for sample_idx in sample_indices:
                idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx
                if idx < self.batch_size and not self.finished_flags[idx]:
                    self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                    self.finished_flags[idx] = True
    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Get the audio stream for a specific sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)
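
# Usage sketch (illustrative, not part of the streaming classes): consume one
# sample's stream while generation runs on a worker thread. `generate_fn` and
# `consume_chunk` are placeholders; `generate_fn` stands in for a model's
# generate() that accepts this streamer and calls put()/end() with
# (audio_chunks, sample_indices). Note the stock transformers streamers
# receive token ids instead, so this interface is model-specific.
def _example_single_stream(generate_fn, consume_chunk):
    from threading import Thread

    streamer = AudioStreamer(batch_size=1)
    thread = Thread(target=generate_fn, kwargs={"streamer": streamer})
    thread.start()
    # Blocks until each chunk (or the stop signal) arrives
    for chunk in streamer.get_stream(0):
        consume_chunk(chunk)
    thread.join()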

class AudioSampleIterator:
    """Iterator for a single audio stream from the batch."""

    def __init__(self, streamer: AudioStreamer, sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
        # Compare by identity: `==` against a tensor chunk would be elementwise
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value

class AudioBatchIterator:
    """Iterator that yields audio chunks for all samples in the batch."""

    def __init__(self, streamer: AudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        # Poll in a loop rather than recursing, so long idle periods cannot
        # overflow the call stack
        while self.active_samples:
            batch_chunks = {}
            samples_to_remove = set()
            # Try to get chunks from all active samples
            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                    if value is self.streamer.stop_signal:
                        samples_to_remove.add(idx)
                    else:
                        batch_chunks[idx] = value
                except Empty:
                    # Queue is empty for this sample, skip it this iteration
                    pass
            # Remove finished samples
            self.active_samples -= samples_to_remove
            if batch_chunks:
                return batch_chunks
            # No chunks were ready; wait briefly before polling again
            if self.active_samples:
                time.sleep(0.01)
        raise StopIteration()
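
# Usage sketch: drain the whole batch at once. Each iteration yields a dict
# mapping sample index -> the newest chunk for every stream that had data
# ready; a stream drops out of the dict once its stop signal arrives.
# `consume_chunk` is a placeholder.
def _example_batch_iteration(streamer: AudioStreamer, consume_chunk):
    for chunks in streamer:  # delegates to AudioBatchIterator
        for idx, chunk in chunks.items():
            consume_chunk(idx, chunk)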

class AsyncAudioStreamer(AudioStreamer):
    """
    Async version of AudioStreamer for use in async contexts. Must be
    instantiated inside a running event loop, since it captures the loop
    for thread-safe queue handoff.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)
        # Replace regular queues with async queues
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        self.loop = asyncio.get_running_loop()
    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Put audio chunks in the appropriate async queues."""
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                audio_chunk = audio_chunks[i].detach().cpu()
                # Schedule the put on the event loop so this method is safe
                # to call from the generation thread
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Signal the end of generation for specified samples."""
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]
        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True
    async def get_stream(self, sample_idx: int):
        """Async generator over a specific sample's audio stream."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        while True:
            value = await self.audio_queues[sample_idx].get()
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        """Returns an async iterator over all audio streams."""
        return AsyncAudioBatchIterator(self)
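
# Usage sketch: consume one stream from async code. The streamer must be
# created inside the running loop (see get_running_loop() in __init__) while
# the blocking `generate_fn` (a placeholder for a model's generate()) runs on
# an executor thread and feeds it via call_soon_threadsafe. `send` is a
# placeholder async consumer, e.g. a websocket write.
async def _example_async_single_stream(generate_fn, send):
    streamer = AsyncAudioStreamer(batch_size=1)
    loop = asyncio.get_running_loop()
    future = loop.run_in_executor(None, lambda: generate_fn(streamer=streamer))
    async for chunk in streamer.get_stream(0):
        await send(chunk)
    await future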

class AsyncAudioBatchIterator:
    """Async iterator for batch audio streaming."""

    def __init__(self, streamer: AsyncAudioStreamer):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Loop instead of recursing so idle periods cannot grow the stack
        while self.active_samples:
            batch_chunks = {}
            samples_to_remove = set()
            # Create tasks for all active samples
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }
            # Wait for at least one chunk to be ready
            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=self.streamer.timeout,
            )
            # Cancel tasks that did not complete in time
            for task in pending:
                task.cancel()
            # Process completed tasks
            for idx, task in tasks.items():
                if task in done:
                    try:
                        value = await task
                        if value is self.streamer.stop_signal:
                            samples_to_remove.add(idx)
                        else:
                            batch_chunks[idx] = value
                    except asyncio.CancelledError:
                        pass
            self.active_samples -= samples_to_remove
            if batch_chunks:
                return batch_chunks
            # Otherwise wait again while samples remain active
        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Helper to get a chunk from a specific queue."""
        return await self.streamer.audio_queues[idx].get()
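
# Usage sketch: fan out the whole batch from async code, e.g. one connection
# per sample served from a single batched generate() call. `sinks` is a
# placeholder mapping sample index -> an object with an async send().
async def _example_async_fan_out(streamer: AsyncAudioStreamer, sinks):
    async for chunks in streamer:
        for idx, chunk in chunks.items():
            await sinks[idx].send(chunk)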