Spaces:
Build error
Build error
| import torch | |
| import annotator.uniformer.mmcv as mmcv | |
| class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm): | |
| """A general BatchNorm layer without input dimension check. | |
| Reproduced from @kapily's work: | |
| (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547) | |
| The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc | |
| is `_check_input_dim` that is designed for tensor sanity checks. | |
| The check has been bypassed in this class for the convenience of converting | |
| SyncBatchNorm. | |
| """ | |
| def _check_input_dim(self, input): | |
| return | |
def revert_sync_batchnorm(module):
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
    `BatchNormXd` layers.

    The module tree is walked recursively. Every layer that is an instance of
    ``torch.nn.modules.batchnorm.SyncBatchNorm`` (or of
    ``mmcv.ops.SyncBatchNorm`` when ``mmcv.ops`` is loaded) is replaced by a
    ``_BatchNormXd``. The replacement re-uses the original layer's affine
    parameters and running-statistics tensors; it does not copy them.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        nn.Module: A new ``_BatchNormXd`` if ``module`` itself is a SyncBN
        layer. Otherwise ``module`` itself, with its SyncBN descendants
        replaced in place via ``add_module``.
    """
    # Resolve the SyncBN classes once per top-level call rather than once per
    # visited submodule.
    # `mmcv.ops` is only probed with `hasattr`, so it is never force-imported
    # here. Presumably any model that already holds MMSyncBN instances has
    # imported it.
    sync_bn_types = [torch.nn.modules.batchnorm.SyncBatchNorm]
    if hasattr(mmcv, 'ops'):
        sync_bn_types.append(mmcv.ops.SyncBatchNorm)
    sync_bn_types = tuple(sync_bn_types)

    def _convert(mod):
        """Return ``mod`` (or its ``_BatchNormXd`` stand-in), children converted."""
        converted = mod
        if isinstance(mod, sync_bn_types):
            converted = _BatchNormXd(mod.num_features, mod.eps,
                                     mod.momentum, mod.affine,
                                     mod.track_running_stats)
            if mod.affine:
                # no_grad() may not be needed here but
                # just to be consistent with `convert_sync_batchnorm()`
                with torch.no_grad():
                    # Re-use the same Parameter objects instead of copying.
                    converted.weight = mod.weight
                    converted.bias = mod.bias
            # Re-use the original running-statistics tensors as well.
            converted.running_mean = mod.running_mean
            converted.running_var = mod.running_var
            converted.num_batches_tracked = mod.num_batches_tracked
            # Keep the train/eval mode of the replaced layer.
            converted.training = mod.training
            # qconfig exists in quantized models
            if hasattr(mod, 'qconfig'):
                converted.qconfig = mod.qconfig
        # Only values for existing child names are replaced, so the child dict
        # does not change size while it is being iterated.
        for name, child in mod.named_children():
            converted.add_module(name, _convert(child))
        return converted

    return _convert(module)