| """ Global Response Normalization Module | |
| Based on the GRN layer presented in | |
| `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808 | |
| This implementation | |
| * works for both NCHW and NHWC tensor layouts | |
| * uses affine param names matching existing torch norm layers | |
| * slightly improves eager mode performance via fused addcmul | |
| Hacked together by / Copyright 2023 Ross Wightman | |
| """ | |
import torch
from torch import nn as nn
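
# For reference, the GRN transform implemented below, in informal notation
# (the symbols g, n, y are illustrative only, not names from the paper or
# this module):
#
#   g(x)_c = ||x_c||_2                        # L2 norm over spatial dims, per channel c
#   n(x)_c = g(x)_c / (mean_k g(x)_k + eps)   # normalize against the cross-channel mean
#   y_c    = x_c + weight_c * (x_c * n(x)_c) + bias_c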

class GlobalResponseNorm(nn.Module):
    """ Global Response Normalization layer """

    def __init__(self, dim, eps=1e-6, channels_last=True):
        super().__init__()
        self.eps = eps
        if channels_last:
            # NHWC layout: spatial dims at positions 1, 2; channels last
            self.spatial_dim = (1, 2)
            self.channel_dim = -1
            self.wb_shape = (1, 1, 1, -1)
        else:
            # NCHW layout: spatial dims at positions 2, 3; channels at 1
            self.spatial_dim = (2, 3)
            self.channel_dim = 1
            self.wb_shape = (1, -1, 1, 1)

        # Zero-init affine params make the layer an identity mapping at init
        self.weight = nn.Parameter(torch.zeros(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        # Per-channel L2 norm aggregated over the spatial dims
        x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
        # Divisive normalization: each channel's norm relative to the cross-channel mean
        x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
        # Residual + affine, computed as bias + weight * (x * x_n) via fused addcmul
        return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)
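

# A minimal smoke test, assuming nothing beyond the module above: with
# zero-initialized weight and bias the layer reduces to an identity mapping,
# and both tensor layouts produce output matching the input shape.
if __name__ == "__main__":
    x_nhwc = torch.randn(2, 7, 7, 64)
    grn_nhwc = GlobalResponseNorm(64, channels_last=True)
    assert torch.allclose(grn_nhwc(x_nhwc), x_nhwc)  # zero-init affine => identity

    x_nchw = torch.randn(2, 64, 7, 7)
    grn_nchw = GlobalResponseNorm(64, channels_last=False)
    assert grn_nchw(x_nchw).shape == x_nchw.shape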