Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import functools | |
| import tempfile | |
| import torch | |
def spawn_and_init(fn, world_size, args=None):
    """Run ``fn`` in ``world_size`` freshly spawned, distributed-initialized processes.

    Each worker is started through ``torch.multiprocessing.spawn`` and enters
    ``init_and_run``, which sets up the process group using a shared temporary
    file and then calls ``fn(rank, group, *args)``.

    Args:
        fn: Callable invoked in every worker as ``fn(rank, group, *args)``.
        world_size: Number of worker processes to spawn.
        args: Optional extra positional arguments forwarded to ``fn``.
            Defaults to no extra arguments.
    """
    import os  # local import: only needed for cleaning up the rendezvous file

    if args is None:
        args = ()

    rendezvous_path = None
    try:
        # delete=False so the path stays valid for the workers, which open the
        # file themselves through the ``file://`` init method.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            rendezvous_path = tmp_file.name
            torch.multiprocessing.spawn(
                fn=functools.partial(init_and_run, fn, args),
                args=(world_size, rendezvous_path),
                nprocs=world_size,
                join=True,
            )
    finally:
        # The file was created with delete=False, so nothing else removes it.
        # spawn(join=True) has returned (or raised) by now, and our own handle
        # is closed, so the file can be unlinked.
        if rendezvous_path is not None:
            try:
                os.remove(rendezvous_path)
            except OSError:
                # Best-effort cleanup: the file may already have been removed.
                pass
def distributed_init(rank, world_size, tmp_file):
    """Initialize the default process group and select this rank's CUDA device.

    Args:
        rank: Rank of the calling process; also used as the CUDA device index.
        world_size: Total number of processes taking part in the group.
        tmp_file: Path of the file used for ``file://`` initialization.
    """
    rendezvous_url = f"file://{tmp_file}"
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=rendezvous_url,
        world_size=world_size,
        rank=rank,
    )
    # The rank doubles as the device index for this process.
    torch.cuda.set_device(rank)
def init_and_run(fn, args, rank, world_size, tmp_file):
    """Worker entry point: initialize distributed state, then call ``fn``.

    Args:
        fn: Callable invoked as ``fn(rank, group, *args)``.
        args: Iterable of extra positional arguments unpacked into ``fn``.
        rank: Rank of the calling process.
        world_size: Total number of processes in the group.
        tmp_file: Path of the file used to initialize the process group.
    """
    distributed_init(rank, world_size, tmp_file)
    # new_group() is called without arguments and its result is handed to fn.
    process_group = torch.distributed.new_group()
    fn(rank, process_group, *args)
def objects_are_equal(a, b) -> bool:
    """Recursively compare two objects, treating tensors by content.

    The comparison is strict about types: ``a`` and ``b`` must have exactly
    the same type at every level of nesting.

    * dicts must have the same keys and recursively equal values;
    * lists and tuples must have the same length and recursively equal
      elements at each position;
    * sets must have the same length and a one-to-one pairing of recursively
      equal elements, regardless of iteration order;
    * tensors must agree in size, dtype, device and every element;
    * anything else is compared with ``==``.

    Args:
        a: First object.
        b: Second object.

    Returns:
        ``True`` if the objects are equal under the rules above, else ``False``.
    """
    if type(a) is not type(b):
        return False

    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            return False
        return all(objects_are_equal(a[key], b[key]) for key in a)

    if isinstance(a, set):
        if len(a) != len(b):
            return False
        # Set iteration order is arbitrary, so zipping two sets can pair up
        # the wrong elements. Match each element of ``a`` against a
        # not-yet-used element of ``b`` instead.
        unmatched = list(b)
        for item in a:
            for idx, candidate in enumerate(unmatched):
                if objects_are_equal(item, candidate):
                    del unmatched[idx]
                    break
            else:
                return False
        return True

    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            return False
        return all(objects_are_equal(x, y) for x, y in zip(a, b))

    if torch.is_tensor(a):
        # Check metadata first: comparing elements across mismatched shapes
        # or devices is either meaningless or an error.
        if a.size() != b.size() or a.dtype != b.dtype or a.device != b.device:
            return False
        # torch.all returns a 0-dim tensor; convert it to honor ``-> bool``.
        return bool(torch.all(a == b))

    return a == b