Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Base implementation for audio generative models. This base implementation | |
| combines all the required components to run inference with pretrained audio | |
| generative models. It can be easily inherited by downstream model classes to | |
| provide easy access to the generation API. | |
| """ | |
| from abc import ABC, abstractmethod | |
| import typing as tp | |
| import omegaconf | |
| import torch | |
| from .encodec import CompressionModel | |
| from .flow import FlowModel | |
| from .lm import LMModel | |
| from .builders import get_wrapped_compression_model | |
| from ..data.audio_utils import convert_audio | |
| from ..modules.conditioners import ConditioningAttributes | |
| from ..utils.autocast import TorchAutocast | |
class BaseGenModel(ABC):
    """Base generative model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel or FlowModel): Model producing the token sequences that
            `compression_model` decodes back to audio (see `generate_audio`).
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: tp.Union[LMModel, FlowModel],
                 max_duration: tp.Optional[float] = None):
        """Wire the compression model and the generator together and derive inference settings.

        Raises:
            ValueError: if `max_duration` is not given and `lm` has no `cfg` to infer it from.
        """
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        self.cfg: tp.Optional[omegaconf.DictConfig] = None
        # Just to be safe, let's put everything in eval mode.
        self.compression_model.eval()
        self.lm.eval()

        if hasattr(lm, 'cfg'):
            cfg = lm.cfg
            assert isinstance(cfg, omegaconf.DictConfig)
            self.cfg = cfg

        if self.cfg is not None:
            self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)

        if max_duration is None:
            if self.cfg is not None:
                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
            else:
                raise ValueError("You must provide max_duration when building directly your GenModel")
        assert max_duration is not None

        self.max_duration: float = max_duration
        # Requested output duration; starts at the maximum single-pass duration.
        self.duration = self.max_duration

        # self.extend_stride is the length of audio extension when generating samples longer
        # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
        # positive float value when generating with self.duration > self.max_duration.
        self.extend_stride: tp.Optional[float] = None
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        # Mixed-precision (float16) autocast is only enabled on devices other than CPU / MPS.
        if self.device.type == 'cpu' or self.device.type == 'mps':
            self.autocast = TorchAutocast(enabled=False)
        else:
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16)

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per second."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.compression_model.channels

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback.

        Args:
            progress_callback: called as `progress_callback(generated_tokens, tokens_to_generate)`
                during generation with `progress=True`. Passing None restores the default
                behaviour of printing the progress to stdout.
        """
        self._progress_callback = progress_callback

    def set_generation_params(self, *args, **kwargs):
        """Set the generation parameters. Subclasses must override this."""
        raise NotImplementedError("No base implementation for setting generation params.")

    @staticmethod
    def get_pretrained(name: str, device=None):
        """Pretrained-model loader; subclasses must provide their own implementation."""
        raise NotImplementedError("No base implementation for getting pretrained model")

    @torch.no_grad()  # Prompt encoding is inference only: do not record autograd history.
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str or None): One text-conditioning entry per sample.
            prompt (torch.Tensor, optional): A batch of waveforms used for continuation
                (`generate_continuation` converts it to the model sample rate and channel
                count before calling this).
        Returns:
            A tuple `(attributes, prompt_tokens)`; `prompt_tokens` is None when no prompt is given.
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if prompt is None:
            return attributes, None

        # `descriptions` cannot be None here: the comprehension above already iterated it.
        assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
        prompt = prompt.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt)
        assert scale is None
        return attributes, prompt_tokens

    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Also return the generated tokens. Defaults to False.
        Returns:
            The decoded audio, or `(audio, tokens)` when `return_tokens` is True.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Also return the generated tokens. Defaults to False.
        Returns:
            The decoded audio, or `(audio, tokens)` when `return_tokens` is True.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                              progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on audio prompts and an optional text description.

        Args:
            prompt (torch.Tensor): A batch of waveforms used for continuation.
                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): Also return the generated tokens. Defaults to False.
        Returns:
            The decoded audio, or `(audio, tokens)` when `return_tokens` is True.
        Raises:
            ValueError: if `prompt` is neither 2-D nor 3-D.
        """
        if prompt.dim() == 2:
            prompt = prompt[None]
        if prompt.dim() != 3:
            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
        # Resample / remix the prompt to what the compression model expects.
        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
        if descriptions is None:
            descriptions = [None] * len(prompt)
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
        assert prompt_tokens is not None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        When `self.duration` exceeds `self.max_duration`, generation proceeds in chunks of at
        most `self.max_duration`, each one prompted by the tail of the previous chunk and
        advancing by `self.extend_stride` seconds.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated tokens, a 3-D tensor whose last axis is time
            (presumably [B, K, T] with K codebooks); its length is driven by `self.duration`.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            # `current_gen_offset` is read from the enclosing scope, so it reflects the current chunk.
            generated_tokens += current_gen_offset
            if self._progress_callback is not None:
                # `tokens_to_generate` comes straight from `self.lm.generate` and presumably covers
                # only the current call, whereas `generated_tokens` is shifted by the chunk offset:
                # in chunked generation the reported ratio is therefore only approximate.
                self._progress_callback(generated_tokens, tokens_to_generate)
            else:
                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = None
        if progress:
            callback = _progress_callback

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
        else:
            assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
            assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
            all_tokens = []
            if prompt_tokens is None:
                prompt_length = 0
            else:
                all_tokens.append(prompt_tokens)
                prompt_length = prompt_tokens.shape[-1]

            stride_tokens = int(self.frame_rate * self.extend_stride)
            while current_gen_offset + prompt_length < total_gen_len:
                time_offset = current_gen_offset / self.frame_rate
                chunk_duration = min(self.duration - time_offset, self.max_duration)
                max_gen_len = int(chunk_duration * self.frame_rate)
                with self.autocast:
                    gen_tokens = self.lm.generate(
                        prompt_tokens, attributes,
                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
                # Keep only the tokens produced beyond the prompt of this chunk.
                if prompt_tokens is None:
                    all_tokens.append(gen_tokens)
                else:
                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
                # The next chunk is prompted by everything after the first `stride_tokens` steps.
                prompt_tokens = gen_tokens[:, :, stride_tokens:]
                prompt_length = prompt_tokens.shape[-1]
                current_gen_offset += stride_tokens

            gen_tokens = torch.cat(all_tokens, dim=-1)
        return gen_tokens

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate Audio from tokens.

        Args:
            gen_tokens (torch.Tensor): 3-D token tensor, e.g. as returned by `_generate_tokens`.
        Returns:
            torch.Tensor: audio decoded by `self.compression_model`.
        """
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio