Spaces:
Running
Running
| """slakh Dataset Loader | |
| .. admonition:: Dataset Info | |
| :class: dropdown | |
| • This code is modified to use the Slakh2100 dataset converted into 16k. | |
| • Unlike slakh, this version treats drum tracks as pitched instruments (80 notes appears). | |
| See Line 243, 356. | |
| The Synthesized Lakh (Slakh) Dataset is a dataset of multi-track audio and aligned | |
| MIDI for music source separation and multi-instrument automatic transcription. | |
| Individual MIDI tracks are synthesized from the Lakh MIDI Dataset v0.1 using | |
| professional-grade sample-based virtual instruments, and the resulting audio is | |
| mixed together to make musical mixtures. | |
| The original release of Slakh, called Slakh2100, | |
| contains 2100 automatically mixed tracks and accompanying, aligned MIDI files, | |
| synthesized from 187 instrument patches categorized into 34 classes, totaling | |
| 145 hours of mixture data. | |
| This loader supports two versions of Slakh: | |
| - Slakh2100-redux: a deduplicated version of slakh2100 containing 1710 multitracks | |
| - baby-slakh: a mini version with 16k wav audio and only the first 20 tracks | |
| This dataset was created at Mitsubishi Electric Research Laboratories (MERL) and | |
| Interactive Audio Lab at Northwestern University by Ethan Manilow, | |
| Gordon Wichern, Prem Seetharaman, and Jonathan Le Roux. | |
| For more information see http://www.slakh.com/ | |
| """ | |
| import os | |
| from typing import BinaryIO, Optional, Tuple | |
| from deprecated.sphinx import deprecated | |
| import librosa | |
| import numpy as np | |
| import pretty_midi | |
| from smart_open import open | |
| import yaml | |
| from mirdata import io, download_utils, jams_utils, core, annotations | |
#: BibTeX citation for the Slakh paper (WASPAA 2019); handed to ``core.Dataset``
#: as the ``bibtex`` argument by ``Dataset.__init__`` in this module.
BIBTEX = """
@inproceedings{manilow2019cutting,
title={Cutting Music Source Separation Some {Slakh}: A Dataset to Study the Impact of Training Data Quality and Quantity},
author={Manilow, Ethan and Wichern, Gordon and Seetharaman, Prem and Le Roux, Jonathan},
booktitle={Proc. IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)},
year={2019},
organization={IEEE}
}
"""
#: Dataset index registry, keyed by version name.
#: The "default" and "test" entries are aliases whose values name another key
#: of this dict; the remaining entries are ``core.Index`` objects.
#: Only the 16 kHz index declares a download ``url`` and ``checksum`` here.
INDEXES = {
    "default":
        "2100-yourmt3-16k",
    "test":
        "baby",
    "2100-yourmt3-16k":
        core.Index(
            filename="slakh_index_2100-yourmt3-16k.json",
            url="https://zenodo.org/record/7717249/files/slakh_index_2100-yourmt3-16k.json?download=1",
            checksum="fab898bd82827ddc4c3e4dbd7b7fcbd9",
            partial_download=["2100-yourmt3-16k"]),
    "2100-redux":
        core.Index(filename="slakh_index_2100-redux.json", partial_download=["2100-redux"]),
    "baby":
        core.Index(filename="slakh_index_baby.json", partial_download=["baby"]),
}
#: Remote archives, keyed by the same version names that the ``INDEXES``
#: entries list in ``partial_download``. Each entry records the archive
#: filename, its download URL and the expected checksum.
REMOTES = {
    "2100-yourmt3-16k":
        download_utils.RemoteFileMetadata(
            filename="slakh2100_yourmt3_16k.tar.gz",
            url="https://zenodo.org/record/7717249/files/slakh2100_yourmt3_16k.tar.gz?download=1",
            checksum="c44f9bcba07b3c6ddeaf604f45dc61c5",
        ),
    "2100-redux":
        download_utils.RemoteFileMetadata(
            filename="slakh2100_flac_redux.tar.gz",
            url="https://zenodo.org/record/4599666/files/slakh2100_flac_redux.tar.gz?download=1",
            checksum="f4b71b6c45ac9b506f59788456b3f0c4",
        ),
    "baby":
        download_utils.RemoteFileMetadata(
            filename="babyslakh_16k.tar.gz",
            url="https://zenodo.org/record/4603870/files/babyslakh_16k.tar.gz?download=1",
            checksum="311096dc2bde7d61c97e930edbfc7f78",
        ),
}
#: License text passed to ``core.Dataset`` as ``license_info``.
LICENSE_INFO = """
Creative Commons Attribution 4.0 International
"""

#: Split names accepted for the "2100-redux" version.
SPLITS = ["train", "validation", "test", "omitted"]

#: Split names accepted for the "2100-yourmt3" (16 kHz) version.
SPLITS_16K = ["train", "validation", "test"]

#: Mixing group to program number mapping
MIXING_GROUPS = {
    "piano": list(range(0, 8)),
    "guitar": list(range(24, 32)),
    "bass": list(range(32, 40)),
    "drums": [128],
}
class Track(core.Track):
    """slakh Track class, for individual stems

    Attributes:
        audio_path (str or None): path to the track's audio file. For some unusual tracks,
            such as sound effects, there is no audio and this attribute is None.
        split (str or None): one of 'train', 'validation', 'test', or 'omitted'.
            'omitted' tracks are part of slakh2100-redux which were found to be
            duplicates in the original slakh2100; the 2100-yourmt3 version only
            accepts 'train', 'validation' and 'test'.
            In baby slakh there are no splits, so this attribute is None.
        data_split (str or None): equivalent to split (deprecated in 0.3.6)
        metadata_path (str): path to the multitrack's metadata file
        midi_path (str or None): path to the track's midi file. For some unusual tracks,
            such as sound effects, there is no midi and this attribute is None.
        mtrack_id (str): the track's multitrack id
        track_id (str): track id
        instrument (str): MIDI instrument class, see link for details:
            https://en.wikipedia.org/wiki/General_MIDI#Program_change_events
        integrated_loudness (float): integrated loudness (dB) of this track
            as calculated by the ITU-R BS.1770-4 spec
        is_drum (bool): whether the "drum" flag is true for this MIDI track
        midi_program_name (str): MIDI instrument program name
        plugin_name (str): patch/plugin name that rendered the audio file
        mixing_group (str): which mixing group the track belongs to.
            One of MIXING_GROUPS.
        program_number (int): MIDI instrument program number

    Cached Properties:
        midi (PrettyMIDI): midi data used to generate the audio
        notes (NoteData or None): note representation of the midi data.
            If there are no notes in the midi file, returns None.
        multif0 (MultiF0Data or None): multif0 representation of the midi data.
            If there are no notes in the midi file, returns None.

    """

    def __init__(self, track_id, data_home, dataset_name, index, metadata):
        super().__init__(
            track_id,
            data_home,
            dataset_name=dataset_name,
            index=index,
            metadata=metadata,
        )

        # Track ids have the form "<mtrack_id>-<stem_id>".
        self.mtrack_id = self.track_id.split("-")[0]
        self.audio_path = self.get_path("audio")
        self.midi_path = self.get_path("midi")
        self.metadata_path = self.get_path("metadata")

        # The split (train/validation/test/omitted) is the second component of
        # the metadata file's relative path in the index. Versions that match
        # neither branch (e.g. baby slakh) have no splits and keep None.
        version = index["version"]
        if "2100-redux" in version:
            valid_splits = SPLITS
        elif "2100-yourmt3" in version:
            valid_splits = SPLITS_16K
        else:
            valid_splits = None

        self.split = None
        if valid_splits is not None:
            self.split = self._track_paths["metadata"][0].split(os.sep)[1]
            assert (self.split in valid_splits), "{} not a valid split - should be one of {}.".format(
                self.split, valid_splits)

        self.data_split = self.split  # deprecated in 0.3.6

    @core.cached_property
    def _track_metadata(self) -> dict:
        """Metadata of this stem, read from the multitrack's YAML metadata file.

        Returns:
            dict: the entry of the file's "stems" mapping keyed by the stem part
            of the track id (the text after the first "-").

        Raises:
            FileNotFoundError: if the metadata file does not exist.

        """
        try:
            with open(self.metadata_path, "r") as fhandle:
                metadata = yaml.safe_load(fhandle)
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"track metadata for {self.track_id} not found. Did you run .download()?") from err
        return metadata["stems"][self.track_id.split("-")[1]]

    @property
    def instrument(self) -> Optional[str]:
        """Value of "inst_class" in the stem metadata, or None if absent."""
        return self._track_metadata.get("inst_class")

    @property
    def integrated_loudness(self) -> Optional[float]:
        """Value of "integrated_loudness" in the stem metadata, or None if absent."""
        return self._track_metadata.get("integrated_loudness")

    @property
    def is_drum(self) -> Optional[bool]:
        """Value of "is_drum" in the stem metadata, or None if absent."""
        return self._track_metadata.get("is_drum")

    @property
    def midi_program_name(self) -> Optional[str]:
        """Value of "midi_program_name" in the stem metadata, or None if absent."""
        return self._track_metadata.get("midi_program_name")

    @property
    def plugin_name(self) -> Optional[str]:
        """Value of "plugin_name" in the stem metadata, or None if absent."""
        return self._track_metadata.get("plugin_name")

    @property
    def program_number(self) -> Optional[int]:
        """Value of "program_num" in the stem metadata, or None if absent."""
        return self._track_metadata.get("program_num")

    @property
    def mixing_group(self) -> Optional[str]:
        """First MIXING_GROUPS key whose program list contains this track's
        program number, or None when no group matches."""
        program = self.program_number
        return next(
            (name for name, programs in MIXING_GROUPS.items() if program in programs),
            None,
        )

    @core.cached_property
    def midi(self) -> Optional[pretty_midi.PrettyMIDI]:
        """MIDI data loaded from ``midi_path``."""
        return io.load_midi(self.midi_path)

    @core.cached_property
    def notes(self) -> Optional[annotations.NoteData]:
        """Note annotations derived from the MIDI data; drum notes are kept."""
        return io.load_notes_from_midi(self.midi_path, self.midi, skip_drums=False)

    @core.cached_property
    def multif0(self) -> Optional[annotations.MultiF0Data]:
        """Multi-f0 annotations derived from the MIDI data, without pitch bend."""
        # NOTE(review): drums are skipped here but kept in `notes` and in
        # MultiTrack.multif0 - confirm this asymmetry is intended.
        return io.load_multif0_from_midi(
            self.midi_path, self.midi, skip_drums=True, pitch_bend=False)

    @property
    def audio(self) -> Optional[Tuple[np.ndarray, float]]:
        """The track's audio

        Returns:
            * np.ndarray - audio signal
            * float - sample rate

        """
        return load_audio(self.audio_path)

    def to_jams(self):
        """Jams: the track's data in jams format"""
        return jams_utils.jams_converter(
            audio_path=self.audio_path,
            note_data=[(self.notes, "Notes")],
        )
class MultiTrack(core.MultiTrack):
    """slakh multitrack class, containing information about the mix and
    the set of associated stems

    Attributes:
        mtrack_id (str): track id
        tracks (dict): {track_id: Track}
        track_audio_property (str): the name of the attribute of Track which
            returns the audio to be mixed
        mix_path (str): path to the multitrack mix audio
        midi_path (str): path to the full midi data used to generate the mixture
        metadata_path (str): path to the multitrack metadata file
        split (str or None): one of 'train', 'validation', 'test', or 'omitted'.
            'omitted' tracks are part of slakh2100-redux which were found to be
            duplicates in the original slakh2100; the 2100-yourmt3 version only
            accepts 'train', 'validation' and 'test'.
        data_split (str or None): equivalent to split (deprecated in 0.3.6)
        uuid (str): File name of the original MIDI file from Lakh, sans extension
        lakh_midi_dir (str): Path to the original MIDI file from a fresh download of Lakh
        normalized (bool): whether the mix and stems were normalized according to the ITU-R BS.1770-4 spec
        overall_gain (float): gain applied to every stem to make sure mixture does not clip when stems are summed

    Cached Properties:
        midi (PrettyMIDI): midi data used to generate the mixture audio
        notes (NoteData): note representation of the midi data
        multif0 (MultiF0Data): multif0 representation of the midi data

    """

    def __init__(self, mtrack_id, data_home, dataset_name, index, track_class, metadata):
        super().__init__(
            mtrack_id=mtrack_id,
            data_home=data_home,
            dataset_name=dataset_name,
            index=index,
            track_class=track_class,
            metadata=metadata,
        )
        self.mix_path = self.get_path("mix")
        self.midi_path = self.get_path("midi")
        self.metadata_path = self.get_path("metadata")

        # The split (train/validation/test/omitted) is the second component of
        # the mix file's relative path in the index. Versions that match
        # neither branch (e.g. baby slakh) have no splits and keep None.
        version = index["version"]
        if "2100-redux" in version:
            valid_splits = SPLITS
        elif "2100-yourmt3" in version:
            valid_splits = SPLITS_16K
        else:
            valid_splits = None

        self.split = None
        if valid_splits is not None:
            self.split = self._multitrack_paths["mix"][0].split(os.sep)[1]
            assert self.split in valid_splits, "{} not in SPLITS".format(self.split)

        self.data_split = self.split  # deprecated in 0.3.6

    @property
    def track_audio_property(self) -> str:
        """Name of the Track attribute that returns the audio to be mixed."""
        return "audio"

    @core.cached_property
    def _multitrack_metadata(self) -> dict:
        """Contents of the multitrack's YAML metadata file.

        Raises:
            FileNotFoundError: if the metadata file does not exist.

        """
        try:
            with open(self.metadata_path, "r") as fhandle:
                metadata = yaml.safe_load(fhandle)
        except FileNotFoundError as err:
            raise FileNotFoundError("Metadata not found. Did you run .download()?") from err
        return metadata

    @property
    def uuid(self) -> Optional[str]:
        """Value of "UUID" in the multitrack metadata, or None if absent."""
        return self._multitrack_metadata.get("UUID")

    @property
    def lakh_midi_dir(self) -> Optional[str]:
        """Value of "lmd_midi_dir" in the multitrack metadata, or None if absent."""
        return self._multitrack_metadata.get("lmd_midi_dir")

    @property
    def normalized(self) -> Optional[bool]:
        """Value of "normalized" in the multitrack metadata, or None if absent."""
        return self._multitrack_metadata.get("normalized")

    @property
    def overall_gain(self) -> Optional[float]:
        """Value of "overall_gain" in the multitrack metadata, or None if absent."""
        return self._multitrack_metadata.get("overall_gain")

    @core.cached_property
    def midi(self) -> Optional[pretty_midi.PrettyMIDI]:
        """MIDI data loaded from ``midi_path``."""
        return io.load_midi(self.midi_path)

    @core.cached_property
    def notes(self) -> Optional[annotations.NoteData]:
        """Note annotations derived from the MIDI data; drum notes are kept."""
        return io.load_notes_from_midi(self.midi_path, self.midi, skip_drums=False)

    @core.cached_property
    def multif0(self) -> Optional[annotations.MultiF0Data]:
        """Multi-f0 annotations derived from the MIDI data; drum notes are kept."""
        # TODO: setting pitch_bend to False by default, but there are some
        # patches that render pitch bend in the audio.
        return io.load_multif0_from_midi(
            self.midi_path, self.midi, skip_drums=False, pitch_bend=False)

    @property
    def audio(self) -> Optional[Tuple[np.ndarray, float]]:
        """The mixture's audio, loaded from ``mix_path``

        Returns:
            * np.ndarray - audio signal
            * float - sample rate

        """
        return load_audio(self.mix_path)

    def to_jams(self):
        """Jams: the track's data in jams format"""
        return jams_utils.jams_converter(
            audio_path=self.mix_path,
            note_data=[(self.notes, "Notes")],
        )

    def get_submix_by_group(self, target_groups):
        """Create submixes grouped by instrument type. Creates one submix
        per target group, plus one additional "other" group for any remaining sources.
        Only tracks with available audio are mixed.

        Args:
            target_groups (list): List of target groups. Elements should be one of
                MIXING_GROUPS, e.g. ["bass", "guitar"]

        Returns:
            * submixes (dict): {group: audio_signal} of submixes; the value is
              None for a group with no member tracks
            * groups (dict): {group: list of track ids} of submixes

        """
        tracks_with_audio = [track for track in self.tracks.values() if track.audio_path]

        groups = {}
        submixes = {}
        grouped_ids = set()  # ids already claimed by a target group
        for group in target_groups:
            member_ids = [
                track.track_id for track in tracks_with_audio if track.mixing_group == group
            ]
            groups[group] = member_ids
            grouped_ids.update(member_ids)
            submixes[group] = self.get_target(member_ids) if member_ids else None

        # Every remaining track with audio ends up in the catch-all group.
        leftover_ids = [
            track.track_id for track in tracks_with_audio if track.track_id not in grouped_ids
        ]
        groups["other"] = leftover_ids
        submixes["other"] = self.get_target(leftover_ids) if leftover_ids else None

        return submixes, groups
def load_audio(fhandle: BinaryIO) -> Tuple[np.ndarray, float]:
    """Load a slakh audio file.

    The file is read with ``sr=None`` and ``mono=False``, i.e. no target
    sample rate and no mono downmix are requested from librosa.

    Args:
        fhandle (str or file-like): path or file-like object pointing to an audio file

    Returns:
        * np.ndarray - the audio signal
        * float - The sample rate of the audio file

    """
    signal, sample_rate = librosa.load(fhandle, sr=None, mono=False)
    return signal, sample_rate
class Dataset(core.Dataset):
    """
    The slakh dataset
    """

    def __init__(self, data_home=None, version="default"):
        # Registers this module's Track/MultiTrack classes and its
        # BIBTEX / INDEXES / REMOTES / LICENSE_INFO constants with the base class.
        super().__init__(
            data_home,
            version,
            name="slakh",
            track_class=Track,
            multitrack_class=MultiTrack,
            bibtex=BIBTEX,
            indexes=INDEXES,
            remotes=REMOTES,
            license_info=LICENSE_INFO,
        )

    # NOTE(review): `deprecated` is imported at module level but not used
    # anywhere in this file - confirm whether the pass-through wrappers below
    # were meant to carry a @deprecated decorator.

    def load_audio(self, *args, **kwargs):
        """Pass-through to this module's ``load_audio`` function."""
        return load_audio(*args, **kwargs)

    def load_midi(self, *args, **kwargs):
        """Pass-through to ``io.load_midi``."""
        return io.load_midi(*args, **kwargs)

    def load_notes_from_midi(self, *args, **kwargs):
        """Pass-through to ``io.load_notes_from_midi``."""
        return io.load_notes_from_midi(*args, **kwargs)

    def load_multif0_from_midi(self, *args, **kwargs):
        """Pass-through to ``io.load_multif0_from_midi``."""
        return io.load_multif0_from_midi(*args, **kwargs)