Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Base classes for the datasets that also provide non-audio metadata, | |
| e.g. description, text transcription etc. | |
| """ | |
| from dataclasses import dataclass | |
| import logging | |
| import math | |
| import re | |
| import typing as tp | |
| import torch | |
| from .audio_dataset import AudioDataset, AudioMeta | |
| from ..environment import AudioCraftEnvironment | |
| from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes | |
# Module-level logger; get_keyword_list reports inputs that are neither a string,
# a list, nor a float NaN through it at DEBUG level.
logger = logging.getLogger(__name__)
def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Remap the file locations stored in `meta` for the current cluster.

    `meta.path` is always rewritten through
    `AudioCraftEnvironment.apply_dataset_mappers`; `meta.info_path.zip_path` is
    rewritten the same way when an info path is present. The object is mutated
    in place and returned for convenience.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info = meta.info_path
    if info is not None:
        info.zip_path = remap(info.zip_path)
    return meta
def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Cluster-remap every entry of `meta` with `_clusterify_meta`.

    Entries are patched in place; the returned list is a new list holding the
    same (mutated) objects, in the original order.
    """
    return list(map(_clusterify_meta, meta))
@dataclass
class AudioInfo(SegmentWithAttributes):
    """Dummy SegmentInfo with empty attributes.

    The InfoAudioDataset is expected to return metadata that inherits
    from the SegmentWithAttributes class and can return conditioning attributes.
    This basically guarantees all datasets will be compatible with current
    solvers that contain conditioners requiring this.
    """
    # The @dataclass decorator makes this a real field (default None), so it is
    # accepted by the generated __init__ and reported by fields()/__repr__.
    # NOTE(review): assumes SegmentWithAttributes is itself a dataclass whose
    # fields precede this one — confirm against audiocraft.modules.conditioners.
    audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using cached batch for training a LM.

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return an empty ConditioningAttributes; this dummy info carries no conditions."""
        return ConditioningAttributes()
class InfoAudioDataset(AudioDataset):
    """AudioDataset variant whose metadata is exposed as `SegmentWithAttributes`.

    The incoming `meta` list is cluster-remapped (see `clusterify_all_meta`)
    before being handed to the parent constructor. When `return_info` is
    enabled, each item is `(waveform, AudioInfo)`; otherwise the bare waveform
    tensor is returned. See `audiocraft.data.audio_dataset.AudioDataset` for the
    remaining initialization arguments.
    """
    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        remapped = clusterify_all_meta(meta)
        super().__init__(remapped, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        wants_info = self.return_info
        item = super().__getitem__(index)
        if wants_info:
            waveform, segment = item
            # Re-wrap the parent's segment info so callers always get conditioning-capable metadata.
            return waveform, AudioInfo(**segment.to_dict())
        assert isinstance(item, torch.Tensor)
        return item
def get_keyword_or_keyword_list(
    value: tp.Optional[tp.Union[str, tp.List[str]]],
) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
    """Preprocess a single keyword or possibly a list of keywords.

    A list is normalized with `get_keyword_list`; any other value (string,
    None, ...) is normalized with `get_keyword`.

    Returns:
        A list of keywords (or None) for list input, otherwise a single
        keyword (or None).
    """
    if isinstance(value, list):
        return get_keyword_list(value)
    return get_keyword(value)
def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single free-text string value.

    Surrounding whitespace is stripped; case is preserved (unlike `get_keyword`).

    Returns:
        The stripped string, or None when `value` is not a string, is empty or
        whitespace-only, or is the literal 'None'.
    """
    if not isinstance(value, str):
        return None
    # Strip before the emptiness check so whitespace-only input yields None, not ''.
    stripped = value.strip()
    if not stripped or stripped == 'None':
        return None
    return stripped
def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword: strip surrounding whitespace and lowercase it.

    Returns:
        The normalized keyword, or None when `value` is not a string, is empty
        or whitespace-only, or is the literal 'None'.
    """
    if not isinstance(value, str):
        return None
    # Strip before the emptiness check so whitespace-only input yields None
    # instead of '' (which would otherwise survive as an empty keyword in lists).
    stripped = value.strip()
    if not stripped or stripped == 'None':
        return None
    return stripped.lower()
def get_keyword_list(values: tp.Any) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords.

    Accepted inputs:
      - a string, split on every comma or whitespace character
        (e.g. "Rock, Pop" -> ["rock", "pop"]);
      - a list, whose elements are each normalized with `get_keyword`
        (non-string elements are dropped);
      - a float NaN, treated as an empty list;
      - any other value, logged at DEBUG level and treated as `[str(values)]`.

    Returns:
        The non-empty normalized keywords, or None if none remain.
    """
    if isinstance(values, str):
        values = [v.strip() for v in re.split(r'[,\s]', values)]
    elif isinstance(values, float) and math.isnan(values):
        values = []
    if not isinstance(values, list):
        # Lazy %-formatting: the message is only rendered when DEBUG is enabled.
        logger.debug("Unexpected keyword list %s", values)
        values = [str(values)]
    # Consecutive delimiters produce '' pieces; get_keyword maps those to None.
    kw_list = [kw for kw in (get_keyword(v) for v in values) if kw is not None]
    return kw_list or None