from pathlib import Path
from typing import Callable
from typing import Dict
from typing import List
from typing import Union

import numpy as np
from torch.utils.data import SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from ..core import AudioSignal
from ..core import util

class AudioLoader:
    """Loads audio endlessly from a list of audio sources
    containing paths to audio files. Audio sources can be
    folders of audio files (which are found via file
    extension) or CSV files containing paths to audio files.

    Parameters
    ----------
    sources : List[str], optional
        Sources containing folders, or CSVs with
        paths to audio files, by default None
    weights : List[float], optional
        Weights to sample audio files from each source, by default None
    relative_path : str, optional
        Path audio should be loaded relative to, by default ""
    transform : Callable, optional
        Transform to instantiate alongside audio sample,
        by default None
    ext : List[str]
        List of extensions to find audio within each source by. Can
        also be a file name (e.g. "vocals.wav"), by default
        ``['.wav', '.flac', '.mp3', '.mp4']``.
    shuffle : bool
        Whether to shuffle the files within the dataloader. Defaults to True.
    shuffle_state : int
        State used to seed the shuffle of the files.
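
    Examples
    --------
    A minimal usage sketch; the source folder here is hypothetical, and the
    loader is drawn from by passing a random state:

    >>> import numpy as np
    >>> from audiotools.data.datasets import AudioLoader
    >>>
    >>> loader = AudioLoader(sources=["tests/audio/spk"])
    >>> state = np.random.RandomState(0)
    >>> item = loader(state, sample_rate=44100, duration=1.0)
    >>> item["signal"]  # a 1-second, mono AudioSignal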
| """ | |

    def __init__(
        self,
        sources: List[str] = None,
        weights: List[float] = None,
        transform: Callable = None,
        relative_path: str = "",
        ext: List[str] = util.AUDIO_EXTENSIONS,
        shuffle: bool = True,
        shuffle_state: int = 0,
    ):
        self.audio_lists = util.read_sources(
            sources, relative_path=relative_path, ext=ext
        )

        # Flat list of (source index, item index) pairs across all sources,
        # used when drawing files without replacement via ``global_idx``.
        self.audio_indices = [
            (src_idx, item_idx)
            for src_idx, src in enumerate(self.audio_lists)
            for item_idx in range(len(src))
        ]
        if shuffle:
            state = util.random_state(shuffle_state)
            state.shuffle(self.audio_indices)

        self.sources = sources
        self.weights = weights
        self.transform = transform

    def __call__(
        self,
        state,
        sample_rate: int,
        duration: float,
        loudness_cutoff: float = -40,
        num_channels: int = 1,
        offset: float = None,
        source_idx: int = None,
        item_idx: int = None,
        global_idx: int = None,
    ):
        # Resolve which file to load: an explicit (source_idx, item_idx)
        # pair, a global index (for sampling without replacement), or a
        # random draw weighted by ``self.weights``.
        if source_idx is not None and item_idx is not None:
            try:
                audio_info = self.audio_lists[source_idx][item_idx]
            except IndexError:
                audio_info = {"path": "none"}
        elif global_idx is not None:
            source_idx, item_idx = self.audio_indices[
                global_idx % len(self.audio_indices)
            ]
            audio_info = self.audio_lists[source_idx][item_idx]
        else:
            audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
                state, self.audio_lists, p=self.weights
            )

        path = audio_info["path"]
        signal = AudioSignal.zeros(duration, sample_rate, num_channels)

        if path != "none":
            if offset is None:
                signal = AudioSignal.salient_excerpt(
                    path,
                    duration=duration,
                    state=state,
                    loudness_cutoff=loudness_cutoff,
                )
            else:
                signal = AudioSignal(
                    path,
                    offset=offset,
                    duration=duration,
                )

        if num_channels == 1:
            signal = signal.to_mono()
        signal = signal.resample(sample_rate)

        if signal.duration < duration:
            signal = signal.zero_pad_to(int(duration * sample_rate))

        for k, v in audio_info.items():
            signal.metadata[k] = v

        item = {
            "signal": signal,
            "source_idx": source_idx,
            "item_idx": item_idx,
            "source": str(self.sources[source_idx]),
            "path": str(path),
        }
        if self.transform is not None:
            item["transform_args"] = self.transform.instantiate(state, signal=signal)
        return item


def default_matcher(x, y):
    """Matches two file paths if they share the same parent directory."""
    return Path(x).parent == Path(y).parent


def align_lists(lists, matcher: Callable = default_matcher):
    """Aligns lists of audio file records in place, inserting
    ``{"path": "none"}`` placeholders so that the i-th item of every list
    refers to the same track (according to ``matcher``)."""
    longest_list = lists[np.argmax([len(l) for l in lists])]
    for i, x in enumerate(longest_list):
        for l in lists:
            if i >= len(l):
                l.append({"path": "none"})
            elif not matcher(l[i]["path"], x["path"]):
                l.insert(i, {"path": "none"})
    return lists
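

# A minimal sketch of what ``align_lists`` does, using hypothetical paths.
# Given one list that found both tracks and one that is missing track "b":
#
#   >>> lists = [
#   >>>     [{"path": "a/vox.wav"}, {"path": "b/vox.wav"}],
#   >>>     [{"path": "a/drums.wav"}],
#   >>> ]
#   >>> align_lists(lists)
#
# the second list becomes ``[{"path": "a/drums.wav"}, {"path": "none"}]``,
# so index i refers to the same track in every list.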


class AudioDataset:
    """Loads audio from multiple loaders (with associated transforms)
    for a specified number of samples. Excerpts of the specified duration
    are drawn randomly, above a specified loudness threshold, and are
    resampled on the fly to the desired sample rate
    (if it is different from the audio source sample rate).

    This takes either a single AudioLoader object, a list of AudioLoader
    objects, or a dictionary of AudioLoader objects. Each AudioLoader is
    called by the dataset, and the result is placed in the output
    dictionary. A transform can also be specified for the entire dataset,
    rather than for each specific loader. This transform can be applied to
    the output of all the loaders if desired.

    AudioLoader objects can be specified as aligned, which means the
    loaders correspond to multitrack audio (e.g. a vocals, bass,
    drums, and other loader for multitrack music mixtures).

    Parameters
    ----------
    loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
        AudioLoaders to sample audio from.
    sample_rate : int
        Desired sample rate.
    n_examples : int, optional
        Number of examples (length of dataset), by default 1000
    duration : float, optional
        Duration of audio samples, by default 0.5
    offset : float, optional
        If not None, audio is loaded at this fixed offset into each file
        instead of as a random salient excerpt, by default None
    loudness_cutoff : float, optional
        Loudness cutoff threshold for audio samples, by default -40
    num_channels : int, optional
        Number of channels in output audio, by default 1
    transform : Callable, optional
        Transform to instantiate alongside each dataset item, by default None
    aligned : bool, optional
        Whether the loaders should be sampled in an aligned manner (e.g. same
        offset, duration, and matched file name), by default False
    shuffle_loaders : bool, optional
        Whether to shuffle the loaders before sampling from them, by default False
    matcher : Callable
        How to match files from adjacent audio lists (e.g. for a multitrack
        audio loader), by default uses the parent directory of each file.
    without_replacement : bool
        Whether to choose files with or without replacement, by default True.

    Examples
    --------
    >>> from audiotools.data.datasets import AudioLoader
    >>> from audiotools.data.datasets import AudioDataset
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>>
    >>> loaders = [
    >>>     AudioLoader(
    >>>         sources=["tests/audio/spk"],
    >>>         transform=tfm.Equalizer(),
    >>>         ext=["wav"],
    >>>     )
    >>>     for i in range(5)
    >>> ]
    >>>
    >>> dataset = AudioDataset(
    >>>     loaders=loaders,
    >>>     sample_rate=44100,
    >>>     duration=1.0,
    >>>     transform=tfm.RescaleAudio(),
    >>> )
    >>>
    >>> item = dataset[np.random.randint(len(dataset))]
    >>>
    >>> for i in range(len(loaders)):
    >>>     item[i]["signal"] = loaders[i].transform(
    >>>         item[i]["signal"], **item[i]["transform_args"]
    >>>     )
    >>>     item[i]["signal"].widget(i)
    >>>
    >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
    >>> mix = dataset.transform(mix, **item["transform_args"])
    >>> mix.widget("mix")

    Below is an example of how one could load MUSDB multitrack data:

    >>> import audiotools as at
    >>> from pathlib import Path
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>> import torch
    >>>
    >>> def build_dataset(
    >>>     sample_rate: int = 44100,
    >>>     duration: float = 5.0,
    >>>     musdb_path: str = "~/.data/musdb/",
    >>> ):
    >>>     musdb_path = Path(musdb_path).expanduser()
    >>>     loaders = {
    >>>         src: at.datasets.AudioLoader(
    >>>             sources=[musdb_path],
    >>>             transform=tfm.Compose(
    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
    >>>                 tfm.Silence(prob=0.1),
    >>>             ),
    >>>             ext=[f"{src}.wav"],
    >>>         )
    >>>         for src in ["vocals", "bass", "drums", "other"]
    >>>     }
    >>>
    >>>     dataset = at.datasets.AudioDataset(
    >>>         loaders=loaders,
    >>>         sample_rate=sample_rate,
    >>>         duration=duration,
    >>>         num_channels=1,
    >>>         aligned=True,
    >>>         transform=tfm.RescaleAudio(),
    >>>         shuffle_loaders=True,
    >>>     )
    >>>     return dataset, list(loaders.keys())
    >>>
    >>> train_data, sources = build_dataset()
    >>> dataloader = torch.utils.data.DataLoader(
    >>>     train_data,
    >>>     batch_size=16,
    >>>     num_workers=0,
    >>>     collate_fn=train_data.collate,
    >>> )
    >>> batch = next(iter(dataloader))
    >>>
    >>> for k in sources:
    >>>     src = batch[k]
    >>>     src["transformed"] = train_data.loaders[k].transform(
    >>>         src["signal"].clone(), **src["transform_args"]
    >>>     )
    >>>
    >>> mixture = sum(batch[k]["transformed"] for k in sources)
    >>> mixture = train_data.transform(mixture, **batch["transform_args"])
    >>>
    >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
    >>> # Construct the targets:
    >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)

    Similarly, here's example code for loading Slakh data:

    >>> import audiotools as at
    >>> from pathlib import Path
    >>> from audiotools import transforms as tfm
    >>> import numpy as np
    >>> import torch
    >>> import glob
    >>>
    >>> def build_dataset(
    >>>     sample_rate: int = 16000,
    >>>     duration: float = 10.0,
    >>>     slakh_path: str = "~/.data/slakh/",
    >>> ):
    >>>     slakh_path = Path(slakh_path).expanduser()
    >>>
    >>>     # Find the max number of sources in Slakh
    >>>     src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
    >>>     n_sources = len(list(set(src_names)))
    >>>
    >>>     loaders = {
    >>>         f"S{i:02d}": at.datasets.AudioLoader(
    >>>             sources=[slakh_path],
    >>>             transform=tfm.Compose(
    >>>                 tfm.VolumeNorm(("uniform", -20, -10)),
    >>>                 tfm.Silence(prob=0.1),
    >>>             ),
    >>>             ext=[f"S{i:02d}.wav"],
    >>>         )
    >>>         for i in range(n_sources)
    >>>     }
    >>>     dataset = at.datasets.AudioDataset(
    >>>         loaders=loaders,
    >>>         sample_rate=sample_rate,
    >>>         duration=duration,
    >>>         num_channels=1,
    >>>         aligned=True,
    >>>         transform=tfm.RescaleAudio(),
    >>>         shuffle_loaders=False,
    >>>     )
    >>>
    >>>     return dataset, list(loaders.keys())
    >>>
    >>> train_data, sources = build_dataset()
    >>> dataloader = torch.utils.data.DataLoader(
    >>>     train_data,
    >>>     batch_size=16,
    >>>     num_workers=0,
    >>>     collate_fn=train_data.collate,
    >>> )
    >>> batch = next(iter(dataloader))
    >>>
    >>> for k in sources:
    >>>     src = batch[k]
    >>>     src["transformed"] = train_data.loaders[k].transform(
    >>>         src["signal"].clone(), **src["transform_args"]
    >>>     )
    >>>
    >>> mixture = sum(batch[k]["transformed"] for k in sources)
    >>> mixture = train_data.transform(mixture, **batch["transform_args"])
    """

    def __init__(
        self,
        loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
        sample_rate: int,
        n_examples: int = 1000,
        duration: float = 0.5,
        offset: float = None,
        loudness_cutoff: float = -40,
        num_channels: int = 1,
        transform: Callable = None,
        aligned: bool = False,
        shuffle_loaders: bool = False,
        matcher: Callable = default_matcher,
        without_replacement: bool = True,
    ):
        # Internally we convert loaders to a dictionary
        if isinstance(loaders, list):
            loaders = {i: l for i, l in enumerate(loaders)}
        elif isinstance(loaders, AudioLoader):
            loaders = {0: loaders}

        self.loaders = loaders
        self.loudness_cutoff = loudness_cutoff
        self.num_channels = num_channels

        self.length = n_examples
        self.transform = transform
        self.sample_rate = sample_rate
        self.duration = duration
        self.offset = offset
        self.aligned = aligned
        self.shuffle_loaders = shuffle_loaders
        self.without_replacement = without_replacement

        if aligned:
            loaders_list = list(loaders.values())
            for i in range(len(loaders_list[0].audio_lists)):
                input_lists = [l.audio_lists[i] for l in loaders_list]
                # Alignment happens in-place
                align_lists(input_lists, matcher)

    def __getitem__(self, idx):
        state = util.random_state(idx)
        offset = self.offset
        item = {}

        keys = list(self.loaders.keys())
        if self.shuffle_loaders:
            state.shuffle(keys)

        loader_kwargs = {
            "state": state,
            "sample_rate": self.sample_rate,
            "duration": self.duration,
            "loudness_cutoff": self.loudness_cutoff,
            "num_channels": self.num_channels,
            "offset": offset,
            "global_idx": idx if self.without_replacement else None,
        }

        # Draw item from first loader
        loader = self.loaders[keys[0]]
        item[keys[0]] = loader(**loader_kwargs)

        for key in keys[1:]:
            loader = self.loaders[key]
            if self.aligned:
                # Reuse the offset, source index, and item index from the
                # first loader, so that every loader returns the same
                # excerpt of the same (matched) track.
                offset = item[keys[0]]["signal"].metadata["offset"]
                loader_kwargs.update(
                    {
                        "offset": offset,
                        "source_idx": item[keys[0]]["source_idx"],
                        "item_idx": item[keys[0]]["item_idx"],
                    }
                )
            item[key] = loader(**loader_kwargs)

        # Sort dictionary back into original order
        keys = list(self.loaders.keys())
        item = {k: item[k] for k in keys}

        item["idx"] = idx
        if self.transform is not None:
            item["transform_args"] = self.transform.instantiate(
                state=state, signal=item[keys[0]]["signal"]
            )

        # If there's only one loader, pop it up to the main
        # dictionary, instead of keeping it nested.
        if len(keys) == 1:
            item.update(item.pop(keys[0]))

        return item

    def __len__(self):
        return self.length

    @staticmethod
    def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
        """Collates items drawn from this dataset. Uses
        :py:func:`audiotools.core.util.collate`.

        Parameters
        ----------
        list_of_dicts : typing.Union[list, dict]
            Data drawn from each item.
        n_splits : int
            Number of splits to make when creating the batches (split into
            sub-batches). Useful for things like gradient accumulation.

        Returns
        -------
        dict
            Dictionary of batched data.
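
        Examples
        --------
        A minimal sketch; ``dataset`` is a hypothetical instance of this class:

        >>> batch = AudioDataset.collate([dataset[i] for i in range(4)])
        >>> sub_batches = AudioDataset.collate(
        >>>     [dataset[i] for i in range(4)], n_splits=2
        >>> )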
| """ | |
| return util.collate(list_of_dicts, n_splits=n_splits) | |


class ConcatDataset(AudioDataset):
    """Concatenates multiple datasets, interleaving them: item ``idx`` is
    drawn from dataset ``idx % n`` at index ``idx // n``, where ``n`` is
    the number of datasets."""

    def __init__(self, datasets: list):
        self.datasets = datasets

    def __len__(self):
        return sum([len(d) for d in self.datasets])

    def __getitem__(self, idx):
        dataset = self.datasets[idx % len(self.datasets)]
        return dataset[idx // len(self.datasets)]
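

# A minimal sketch of how ``ConcatDataset`` interleaves its members
# (``speech_data`` and ``music_data`` are hypothetical AudioDatasets):
#
#   >>> concat = ConcatDataset([speech_data, music_data])
#   >>> concat[0]  # speech_data[0]
#   >>> concat[1]  # music_data[0]
#   >>> concat[2]  # speech_data[1]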


class ResumableDistributedSampler(DistributedSampler):  # pragma: no cover
    """Distributed sampler that can be resumed from a given start index."""

    def __init__(self, dataset, start_idx: int = None, **kwargs):
        super().__init__(dataset, **kwargs)
        # Start index allows resuming an experiment at the index it was at.
        # Each replica sees 1 / num_replicas of the dataset, so the global
        # start index is divided across replicas.
        self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0

    def __iter__(self):
        for i, idx in enumerate(super().__iter__()):
            if i >= self.start_idx:
                yield idx
        self.start_idx = 0  # set the index back to 0 for the next epoch


class ResumableSequentialSampler(SequentialSampler):  # pragma: no cover
    """Sequential sampler that can be resumed from a given start index."""

    def __init__(self, dataset, start_idx: int = None, **kwargs):
        super().__init__(dataset, **kwargs)
        # Start index allows resuming an experiment at the index it was at.
        self.start_idx = start_idx if start_idx is not None else 0

    def __iter__(self):
        for i, idx in enumerate(super().__iter__()):
            if i >= self.start_idx:
                yield idx
        self.start_idx = 0  # set the index back to 0 for the next epoch
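

# A minimal sketch of resuming a run mid-epoch with these samplers;
# ``train_data`` and ``steps_done`` are hypothetical, and ``batch_size``
# must match the interrupted run:
#
#   >>> sampler = ResumableSequentialSampler(
#   >>>     train_data, start_idx=steps_done * batch_size
#   >>> )
#   >>> dataloader = torch.utils.data.DataLoader(
#   >>>     train_data,
#   >>>     batch_size=batch_size,
#   >>>     sampler=sampler,
#   >>>     collate_fn=train_data.collate,
#   >>> )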