Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| # Copyright (c) Megvii Inc. All rights reserved. | |
| import math | |
| from copy import deepcopy | |
| import torch | |
| import torch.nn as nn | |
| __all__ = ["ModelEMA", "is_parallel"] | |
def is_parallel(model):
    """Return True if ``model`` is wrapped in a PyTorch parallel container.

    The containers checked are ``nn.parallel.DataParallel`` and
    ``nn.parallel.DistributedDataParallel``.
    """
    wrapper_classes = (
        nn.parallel.DataParallel,
        nn.parallel.DistributedDataParallel,
    )
    for wrapper_cls in wrapper_classes:
        if isinstance(model, wrapper_cls):
            return True
    return False
class ModelEMA:
    """
    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0, tau=2000):
        """
        Args:
            model (nn.Module): model to apply EMA.
            decay (float): EMA decay rate; the value the ramped decay
                approaches as the update counter grows.
            updates (int): counter of EMA updates.
            tau (float): time constant (in updates) of the exponential
                warm-up ramp applied to ``decay``. Must be positive.

        Raises:
            ValueError: if ``tau`` is not positive.
        """
        if tau <= 0:
            raise ValueError(f"tau must be positive, got {tau}")

        # The EMA model is a deep copy of the unwrapped model, switched to
        # eval mode. No dtype conversion happens here: the copy keeps
        # whatever dtype the source model has.
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates
        # Decay exponential ramp (to help early epochs): the effective decay
        # is 0 at x == 0 and rises towards ``decay`` as x grows, so early
        # updates track the live model closely.
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        # The EMA weights are only ever written by ``update``; they never
        # need gradients.
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """
        Blend the current state of ``model`` into the EMA copy, in place.

        Increments ``self.updates``, evaluates the ramped decay ``d`` for the
        new counter value, and for every floating-point entry of the EMA
        state_dict computes ``ema = d * ema + (1 - d) * model``. Entries with
        a non-floating dtype are left unchanged.

        Args:
            model (nn.Module): model providing the new weights; unwrapped
                first when it is a parallel wrapper. Its state_dict must
                contain every key of the EMA state_dict.
        """
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            source = model.module if is_parallel(model) else model
            msd = source.state_dict()  # model state_dict

            # The loop relies on in-place ops on the state_dict tensors to
            # update the EMA module itself.
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1.0 - d) * msd[k].detach()