# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os

import torch
import torch.distributed as dist
from math import inf  # torch._six was removed in recent PyTorch releases; math.inf behaves identically here


def load_checkpoint(file_name, config, model, model_ema, optimizer, lr_scheduler, loss_scaler, logger):
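    """Resume training state from `file_name` and return (max_accuracy, start_epoch).

    When `config.model.params.deepspeed` is non-empty, `file_name` is treated as
    `<ckptdir>/<tag>` and the model weights are restored through DeepSpeed's
    `load_checkpoint`, while EMA/scaler/metadata are read from an extra
    `state.pth`; otherwise a plain `torch.save` checkpoint is loaded.
    """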
    if config.model.params.deepspeed != '':
        # DeepSpeed checkpoints are directories laid out as <ckptdir>/<tag>/
        parts = file_name.split('/')
        ckptdir = '/'.join(parts[:-1])
        tag = parts[-1]
        _, client_states = model.load_checkpoint(ckptdir, tag=tag)
        print(client_states)
        logger.info("Resume checkpoint %s" % file_name)
        checkpoint = torch.load(
            os.path.join(ckptdir, tag, "state.pth"), map_location="cpu"
        )
        msg = model_ema.load_state_dict(checkpoint['model_ema'])
        logger.info(msg)
        start_epoch = checkpoint["epoch"] + 1
        max_accuracy = 0.0
        if loss_scaler and "grad_scale_manager" in checkpoint:
            loss_scaler.load_state_dict(checkpoint["grad_scale_manager"])
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    else:
        logger.info(f"==============> Resuming from {file_name}....................")
        checkpoint = torch.load(file_name, map_location='cpu')
        msg = model.load_state_dict(checkpoint['model'], strict=False)
        logger.info(msg)
        max_accuracy = 0.0
        start_epoch = 0  # default when the checkpoint carries no optimizer/epoch state
        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            if 'lr_scheduler' in checkpoint:
                lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            logger.info(f"=> loaded successfully '{file_name}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy, start_epoch


def save_checkpoint(ckptdir, config, epoch, model, model_ema, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger):
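    """Write a checkpoint for `epoch` under `ckptdir`.

    With DeepSpeed enabled, rank 0 saves EMA/scaler/metadata to
    `ckpt_epoch_<epoch>/state.pth` and every rank calls DeepSpeed's
    `save_checkpoint`; otherwise rank 0 saves a single `ckpt_epoch_<epoch>.pth`.
    """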
    if config.model.params.deepspeed != '':
        if dist.get_rank() == 0:
            os.makedirs(os.path.join(ckptdir, f'ckpt_epoch_{epoch}'), exist_ok=True)
            checkpoint_path = os.path.join(ckptdir, f'ckpt_epoch_{epoch}', 'state.pth')
            to_save = {
                'epoch': epoch,
                'config': config,
                'max_accuracy': max_accuracy,
                'model_ema': model_ema.state_dict(),
            }
            if loss_scaler is not None:
                to_save["grad_scale_manager"] = loss_scaler.state_dict()
            logger.info(f"Saving checkpoint to {checkpoint_path}")
            torch.save(to_save, checkpoint_path)
        model.save_checkpoint(save_dir=ckptdir, tag=f'ckpt_epoch_{epoch}')
        print(f"rank[{dist.get_rank()}]: {ckptdir}/ckpt_epoch_{epoch} saved")
        dist.barrier()
    else:
        if dist.get_rank() == 0:
            save_state = {'model': model.state_dict(),
                          'optimizer': optimizer.state_dict(),
                          # 'lr_scheduler': lr_scheduler.state_dict(),
                          'max_accuracy': max_accuracy,
                          'scaler': loss_scaler.state_dict(),
                          'epoch': epoch,
                          'config': config}
            save_path = os.path.join(ckptdir, f'ckpt_epoch_{epoch}.pth')
            logger.info(f"{save_path} saving......")
            torch.save(save_state, save_path)
            logger.info(f"{save_path} saved !!!")


def get_grad_norm(parameters, norm_type=2):
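    """Return the total `norm_type`-norm of the gradients of `parameters` as a Python float."""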
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
    total_norm = total_norm ** (1. / norm_type)
    return total_norm


def auto_resume_helper(config, output_dir):
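    """Return the latest checkpoint in `output_dir` (highest-epoch directory for DeepSpeed, newest `.pth` otherwise), or None."""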
    if config.model.params.deepspeed != '':
        dirs = os.listdir(output_dir)
        dirs = [d for d in dirs if d.startswith('ckpt_epoch')]
        print(f"All checkpoints found in {output_dir}: {dirs}")
        if len(dirs) > 0:
            latest_epoch = max(int(d.split('_')[-1]) for d in dirs)
            latest_checkpoint = os.path.join(output_dir, 'ckpt_epoch_{}'.format(latest_epoch))
            print(f"The latest checkpoint found: {latest_checkpoint}")
            resume_file = latest_checkpoint
        else:
            resume_file = None
    else:
        checkpoints = os.listdir(output_dir)
        checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
        print(f"All checkpoints found in {output_dir}: {checkpoints}")
        if len(checkpoints) > 0:
            latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
            print(f"The latest checkpoint found: {latest_checkpoint}")
            resume_file = latest_checkpoint
        else:
            resume_file = None
    return resume_file
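

# Usage sketch (assumption: an illustrative call site; the real trainer wires in
# its own config, output_dir, and the objects built elsewhere in the script):
#
#   resume_file = auto_resume_helper(config, output_dir)
#   if resume_file is not None:
#       max_accuracy, start_epoch = load_checkpoint(
#           resume_file, config, model, model_ema, optimizer,
#           lr_scheduler, loss_scaler, logger)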


def reduce_tensor(tensor):
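    """All-reduce `tensor` across processes and return its mean."""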
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt


def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
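    """Compute the total gradient norm on the gradients' device; used by NativeScalerWithGradNormCount when no clipping is requested."""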
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device)
                                             for p in parameters]), norm_type)
    return total_norm


class NativeScalerWithGradNormCount:
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
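        """Scale `loss`, backpropagate, and (when `update_grad` is True) unscale,
        optionally clip to `clip_grad`, and step `optimizer`.

        Returns the gradient norm, or None when `update_grad` is False.
        """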
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                self._scaler.unscale_(optimizer)
                norm = ampscaler_get_grad_norm(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
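

# Usage sketch (assumption: a standard AMP training loop; `model`, `optimizer`,
# `criterion`, `loader`, and `config` below are illustrative names, not part of
# this module):
#
#   loss_scaler = NativeScalerWithGradNormCount()
#   for images, targets in loader:
#       with torch.cuda.amp.autocast():
#           loss = criterion(model(images.cuda()), targets.cuda())
#       optimizer.zero_grad()
#       grad_norm = loss_scaler(loss, optimizer,
#                               clip_grad=config.train.clip_grad,
#                               parameters=model.parameters())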