Spaces:
Paused
Paused
| import gc | |
| import os | |
| import re | |
| import torch | |
| import torch.distributed as dist | |
| from packaging import version | |
| from step1x3d_geometry.utils.config import config_to_primitive | |
| from step1x3d_geometry.utils.typing import * | |
def parse_version(ver: str):
    """Parse ``ver`` with ``packaging.version.parse``.

    The returned version object supports ordering comparisons, so callers can
    compare it against other parsed versions.
    """
    parsed = version.parse(ver)
    return parsed
def get_rank():
    """Return this process's rank as read from launcher environment variables.

    The variables are consulted in the order ``RANK``, ``LOCAL_RANK``,
    ``SLURM_PROCID``, ``JSM_NAMESPACE_RANK``; the first one that is set is
    converted to ``int`` and returned. ``SLURM_PROCID`` can be set even when
    SLURM is not managing the processes, which is why the torch-style
    variables are checked before it. Returns 0 when none are set.
    """
    env = os.environ
    candidates = (
        env[name]
        for name in ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
        if name in env
    )
    first = next(candidates, None)
    return 0 if first is None else int(first)
def get_world_size():
    """Return the number of participating processes from launcher env vars.

    ``WORLD_SIZE``, ``SLURM_NTASKS`` and ``JSM_NAMESPACE_SIZE`` are checked in
    that order; the first one present is converted to ``int`` and returned.
    Returns 1 when none are set.
    """
    env = os.environ
    present = [env[name] for name in ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE") if name in env]
    return int(present[0]) if present else 1
def get_device():
    """Return the CUDA device this process should use.

    ``LOCAL_RANK`` (the per-node process index set by torch launchers) is
    preferred. On multi-node jobs the global rank can exceed the number of
    GPUs on a node, so it is not a valid device index. When ``LOCAL_RANK`` is
    unset, this falls back to ``get_rank()``.

    Only constructs a ``torch.device``; it does not check that the device
    exists.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    index = int(local_rank) if local_rank is not None else get_rank()
    return torch.device(f"cuda:{index}")
def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load a (possibly filtered) state dict plus counters from a checkpoint.

    Args:
        path: File passed to ``torch.load``. The loaded object must contain
            the keys ``"state_dict"``, ``"epoch"`` and ``"global_step"``.
        module_name: If given, keep only entries whose key starts with
            ``f"{module_name}."`` and strip that prefix from returned keys.
        ignore_modules: If given, an iterable of module names; entries whose
            key starts with ``f"{name}."`` for any of them are dropped.
        map_location: Forwarded to ``torch.load``; defaults to ``get_device()``.

    Returns:
        ``(state_dict, epoch, global_step)``.

    Raises:
        ValueError: If both ``module_name`` and ``ignore_modules`` are set.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]

    if ignore_modules is not None:
        # str.startswith accepts a tuple, giving one C-level check per key.
        ignored_prefixes = tuple(f"{name}." for name in ignore_modules)
        state_dict_to_load = {
            k: v for k, v in state_dict.items() if not k.startswith(ignored_prefixes)
        }
    elif module_name is not None:
        # Literal prefix match (not a regex), so a "." in module_name only
        # matches a literal dot: "a.b" must not select keys like "axb.w".
        prefix = f"{module_name}."
        state_dict_to_load = {
            k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
        }
    else:
        state_dict_to_load = state_dict

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]
def C(value: Any, epoch: int, global_step: int) -> float:
    """Evaluate a scalar that may be linearly scheduled over training.

    ``value`` is either a plain ``int``/``float`` (returned unchanged) or a
    schedule ``[start_step, start_value, end_value, end_step]``. ``start_step``
    may be omitted, in which case it defaults to 0. Schedules given as a
    ``list``/``tuple`` are used directly. Any other object is converted with
    ``config_to_primitive``, presumably an OmegaConf container (TODO confirm
    against callers).

    The result moves linearly from ``start_value`` to ``end_value`` and is
    clamped at both ends. If ``end_step`` is an ``int``, progress is measured
    in ``global_step``; if it is a ``float``, progress is measured in
    ``epoch``. When ``end_step == start_step`` the schedule is a step change
    at ``start_step``.

    Raises:
        TypeError: If the specification is not a list, or ``end_step`` is
            neither ``int`` nor ``float``.
        ValueError: If the specification does not have 3 or 4 entries.
    """
    if isinstance(value, (int, float)):
        return value

    if isinstance(value, (list, tuple)):
        spec = list(value)
    else:
        spec = config_to_primitive(value)
    if not isinstance(spec, list):
        raise TypeError(f"Scalar specification only supports list, got {type(spec)}")
    if len(spec) == 3:
        spec = [0] + spec
    if len(spec) != 4:
        raise ValueError(
            f"Scalar specification must have 3 or 4 entries, got {len(spec)}"
        )

    start_step, start_value, end_value, end_step = spec
    if isinstance(end_step, int):
        current_step = global_step
    elif isinstance(end_step, float):
        current_step = epoch
    else:
        raise TypeError(f"end_step must be int or float, got {type(end_step)}")

    if end_step == start_step:
        # A zero-length ramp would divide by zero; treat it as a step change.
        progress = 1.0 if current_step >= start_step else 0.0
    else:
        progress = (current_step - start_step) / (end_step - start_step)
        progress = min(1.0, max(0.0, progress))
    return start_value + (end_value - start_value) * progress
def cleanup():
    """Release cached memory held by Python, PyTorch and, if loaded, tiny-cuda-nn.

    Runs ``gc.collect()`` and ``torch.cuda.empty_cache()``. It then calls
    ``free_temporary_memory()`` on the ``tinycudann`` module, but only if some
    other code has already imported it into this process.
    """
    import sys

    gc.collect()
    torch.cuda.empty_cache()
    # No ``tcnn`` name is imported into this module, so a bare ``tcnn`` would
    # raise NameError. Look the module up instead, and skip it when it was
    # never imported.
    tcnn = sys.modules.get("tinycudann")
    if tcnn is not None:
        tcnn.free_temporary_memory()
def finish_with_cleanup(func: Callable):
    """Decorate ``func`` so that ``cleanup()`` runs after each successful call.

    The wrapped function's return value is passed through unchanged. If
    ``func`` raises, the exception propagates and ``cleanup()`` is not run.
    ``functools.wraps`` copies ``func``'s name, docstring and ``__wrapped__``
    onto the returned wrapper.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper
| def _distributed_available(): | |
| return torch.distributed.is_available() and torch.distributed.is_initialized() | |
def barrier():
    """Synchronize all ranks with ``torch.distributed.barrier()``.

    Does nothing (returns ``None``) when no process group is available.
    """
    if _distributed_available():
        torch.distributed.barrier()
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` to all ranks, in place.

    When no process group is available the tensor is left untouched. In
    both cases the same tensor object is returned, for call chaining.
    """
    if _distributed_available():
        torch.distributed.broadcast(tensor, src=src)
    return tensor
def enable_gradient(model, enabled: bool = True) -> None:
    """Set ``requires_grad`` to ``enabled`` on every parameter of ``model``.

    ``enabled=False`` freezes the parameters; the default re-enables
    gradients. Modifies the parameters in place and returns ``None``.
    """
    params = model.parameters()
    for weight in params:
        weight.requires_grad_(requires_grad=enabled)
def all_gather_batch(tensors):
    """Gather ``tensors`` from every rank and concatenate them along dim 0.

    Args:
        tensors: A tensor or a list of tensors. Each tensor must have the same
            shape on every rank, as ``dist.all_gather`` requires.

    Returns:
        The same container kind as the input: a tensor, or a list of tensors.
        Each one is the rank-ordered concatenation along dim 0. When no
        process group is initialized, or the group has a single rank,
        ``tensors`` itself is returned unchanged.

    NOTE: ``dist.all_gather`` writes into freshly allocated buffers, so the
    result does not carry autograd history.
    """
    # Size the collective from the actual process group rather than launcher
    # environment variables, which can disagree with it. With no group this
    # is a no-op, mirroring ``broadcast`` in this module.
    if not (dist.is_available() and dist.is_initialized()):
        return tensors
    world_size = dist.get_world_size()
    if world_size == 1:
        return tensors

    is_list = isinstance(tensors, list)
    tensor_seq = tensors if is_list else [tensors]
    gathered = []
    for tensor in tensor_seq:
        # all_gather overwrites every slot, so uninitialized buffers suffice.
        buckets = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.all_gather(buckets, tensor)
        gathered.append(torch.cat(buckets, dim=0))
    return gathered if is_list else gathered[0]