# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Resampling script.

Converts a dataset of audio files to a target sample rate and channel count,
sharding the work across SLURM tasks with submitit.
"""
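
# Illustrative invocation (the script name and paths are placeholders; the
# flags are defined in the argparse section at the bottom of this file):
#   python resample_dataset.py --root_path /datasets/audio \
#       --out_path /datasets/audio_32k --sample_rate 32000 --format wav \
#       --nodes 4 --tasks_per_node 20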
import argparse
from pathlib import Path
import shutil
import typing as tp

import submitit
import tqdm

from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_dataset import load_audio_meta, find_audio_files
from audiocraft.data.audio_utils import convert_audio
from audiocraft.environment import AudioCraftEnvironment


def read_txt_files(path: tp.Union[str, Path]):
    # One audio path per line; drop any metadata files listed alongside them.
    with open(path) as f:
        lines = [line.rstrip() for line in f]
    print(f"Read {len(lines)} lines from .txt")
    lines = [line for line in lines if Path(line).suffix not in ['.json', '.txt', '.csv']]
    print(f"Kept {len(lines)} audio paths after filtering")
    return lines


def read_egs_files(path: tp.Union[str, Path]):
    path = Path(path)
    if path.is_dir():
        if (path / 'data.jsonl').exists():
            path = path / 'data.jsonl'
        elif (path / 'data.jsonl.gz').exists():
            path = path / 'data.jsonl.gz'
        else:
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
    meta = load_audio_meta(path)
    return [m.path for m in meta]
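
# Note: each line of data.jsonl is expected to carry at least a "path" field,
# e.g. {"path": "/datasets/audio/track.wav", "duration": 12.3, ...} (fields
# other than "path" are illustrative; the exact schema is whatever
# audiocraft.data.audio_dataset.load_audio_meta parses).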


def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None):
    if task_index is None:
        # running under SLURM: derive the task index from the submitit job environment
        env = submitit.JobEnvironment()
        task_index = env.global_rank
    # each (node, task) pair owns exactly one of the n_shards shards
    shard_index = node_index * args.tasks_per_node + task_index

    if args.files_path is None:
        lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)]
    else:
        files_path = Path(args.files_path)
        if files_path.suffix == '.txt':
            print(f"Reading file list from .txt file: {args.files_path}")
            lines = read_txt_files(args.files_path)
        else:
            print(f"Reading file list from egs: {args.files_path}")
            lines = read_egs_files(args.files_path)

    total_files = len(lines)
    print(
        f"Total of {total_files} files split across {n_shards} shards. "
        f"Current shard_index = {shard_index} -> ~{total_files // n_shards} files to process"
    )
    for idx, line in tqdm.tqdm(enumerate(lines)):
        # skip files that do not belong to this shard
        if idx % n_shards != shard_index:
            continue
        path = str(AudioCraftEnvironment.apply_dataset_mappers(line))
        root_path = str(args.root_path)
        if not root_path.endswith('/'):
            root_path += '/'
        assert path.startswith(root_path), \
            f"Mismatch between path and provided root: {path} VS {root_path}"
        try:
            metadata_path = Path(path).with_suffix('.json')
            out_path = args.out_path / path[len(root_path):]
            out_metadata_path = out_path.with_suffix('.json')
            out_done_token = out_path.with_suffix('.done')
            # a .done token marks files already handled by a previous run
            if out_done_token.exists():
                continue
            print(idx, out_path, path)
            mix, sr = audio_read(path)
            mix_channels = args.channels if args.channels is not None and args.channels > 0 else mix.size(0)
            # enforce at most stereo output
            out_channels = mix_channels
            if out_channels > 2:
                print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels")
                out_channels = 2
            out_sr = args.sample_rate if args.sample_rate is not None else sr
            out_wav = convert_audio(mix, sr, out_sr, out_channels)
            audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr,
                        format=args.format, normalize=False, strategy='clip')
            if metadata_path.exists():
                shutil.copy(metadata_path, out_metadata_path)
            else:
                print(f"No metadata found at {str(metadata_path)}")
            out_done_token.touch()
        except Exception as e:
            print(f"Error processing file line: {line}, {e}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Resample dataset with SLURM.")
    parser.add_argument(
        "--log_root",
        type=Path,
        default=Path.home() / 'tmp' / 'resample_logs',
        help="Root folder for submitit logs.",
    )
    parser.add_argument(
        "--files_path",
        type=Path,
        help="List of files to process, either .txt (one file per line) or a jsonl[.gz].",
    )
    parser.add_argument(
        "--root_path",
        type=Path,
        required=True,
        help="When rewriting paths, this will be the prefix to remove.",
    )
    parser.add_argument(
        "--out_path",
        type=Path,
        required=True,
        help="When rewriting paths, `root_path` will be replaced by this.",
    )
    parser.add_argument("--xp_name", type=str, default="shutterstock")
    parser.add_argument(
        "--nodes",
        type=int,
        default=4,
        help="Number of SLURM jobs to submit (one per node).",
    )
    parser.add_argument(
        "--tasks_per_node",
        type=int,
        default=20,
        help="Tasks per node; nodes * tasks_per_node gives the number of shards.",
    )
    parser.add_argument(
        "--cpus_per_task",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--memory_gb",
        type=int,
        help="Memory in GB.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="wav",
        help="Output audio format, passed to audio_write.",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=32000,
        help="Target sample rate.",
    )
    parser.add_argument(
        "--channels",
        type=int,
        help="Target channel count; defaults to the input's channels, capped at 2.",
    )
    parser.add_argument(
        "--partition",
        default='learnfair',
    )
    parser.add_argument("--qos")
    parser.add_argument("--account")
    parser.add_argument("--timeout", type=int, default=4320)
    parser.add_argument('--debug', action='store_true', help='debug mode (local run)')

    args = parser.parse_args()
    n_shards = args.tasks_per_node * args.nodes
    if args.files_path is None:
        print("Warning: --files_path not provided, not recommended when processing more than 10k files.")
    if args.debug:
        # local run without SLURM: process shard 0 in the current process
        print("Debugging mode")
        process_dataset(args, n_shards=n_shards, node_index=0, task_index=0)
    else:
        log_folder = Path(args.log_root) / args.xp_name / '%j'
        print(f"Logging to: {log_folder}")
        log_folder.parent.mkdir(parents=True, exist_ok=True)
        executor = submitit.AutoExecutor(folder=str(log_folder))
        if args.qos:
            executor.update_parameters(slurm_partition=args.partition, slurm_qos=args.qos, slurm_account=args.account)
        else:
            executor.update_parameters(slurm_partition=args.partition)
        executor.update_parameters(
            slurm_job_name=args.xp_name, timeout_min=args.timeout,
            cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1)
        if args.memory_gb:
            executor.update_parameters(mem=f'{args.memory_gb}GB')
        jobs = []
        # one submitit job per node; each job spawns tasks_per_node tasks,
        # and every task processes its own shard inside process_dataset
        with executor.batch():
            for node_index in range(args.nodes):
                job = executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index)
                jobs.append(job)
        for job in jobs:
            print(f"Waiting on job {job.job_id}")
            job.results()