# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import collections
from functools import partial
from itertools import repeat
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)
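

# Illustrative sketch (not part of the original source): a minimal sanity check of
# the helper utilities above, assuming they behave like their timm-style originals.
def _demo_tuple_helpers():
    # to_2tuple broadcasts a scalar to a pair and passes non-string iterables through.
    assert to_2tuple(3) == (3, 3)
    assert to_2tuple((1, 2)) == (1, 2)
    # default falls back to the second argument only when the first is None.
    assert default(None, 0.0) == 0.0
    assert default(5, 0.0) == 5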
class ResidualBlock(nn.Module):
    """
    ResidualBlock: two conv layers with a residual connection.

    When stride != 1, a 1x1 strided convolution projects the input so it can be
    added to the downsampled output.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
        super(ResidualBlock, self).__init__()
        # padding=1 preserves the spatial size only for the default kernel_size=3.
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=kernel_size,
            padding=1,
            padding_mode="zeros",
        )
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2d(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2d(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()
        else:
            raise NotImplementedError(f"Unknown norm_fn: {norm_fn}")

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3,
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
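

# Illustrative usage sketch (added for documentation, not in the original source).
# The shapes below are example assumptions: a stride-2 block halves the spatial
# resolution while changing the channel count, with the 1x1 downsample matching
# the residual path.
def _demo_residual_block():
    block = ResidualBlock(in_planes=64, planes=128, norm_fn="group", stride=2)
    x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
    y = block(x)
    assert y.shape == (2, 128, 16, 16)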
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
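

# Illustrative usage sketch (added for documentation, not in the original source).
# With use_conv=False, Mlp acts on the last dimension, so it can be applied
# token-wise to a (batch, tokens, channels) tensor; sizes here are example choices.
def _demo_mlp():
    mlp = Mlp(in_features=256, hidden_features=1024, drop=0.1)
    tokens = torch.randn(2, 77, 256)  # (batch, tokens, channels)
    out = mlp(tokens)
    assert out.shape == (2, 77, 256)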
class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
        mlp_ratio=4.0,
        **block_kwargs,
    ):
        """
        Self-attention block.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.attn = attn_class(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, mask=None):
        # Normalize before attention. The `mask` argument is currently unused;
        # converting it to the attn_mask format expected by nn.MultiheadAttention
        # is left commented out:
        # attn_mask = mask if mask is not None else None
        x = self.norm1(x)

        # nn.MultiheadAttention returns (attn_output, attn_output_weights).
        # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
        attn_output, _ = self.attn(x, x, x)

        # Residual connections around attention and the MLP.
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
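

# Illustrative usage sketch (added for documentation, not in the original source).
# AttnBlock keeps the token shape unchanged; hidden_size must be divisible by
# num_heads for nn.MultiheadAttention. Sizes below are example assumptions.
def _demo_attn_block():
    block = AttnBlock(hidden_size=256, num_heads=8)
    tokens = torch.randn(2, 100, 256)  # batch_first=True -> (batch, tokens, channels)
    out = block(tokens)
    assert out.shape == (2, 100, 256)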
class CrossAttnBlock(nn.Module):
    def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
        """
        Cross-attention block: queries come from `x`, keys and values from `context`.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        # Since kdim/vdim are not passed to nn.MultiheadAttention below, the context
        # is normalized and attended with embed_dim == hidden_size; context_dim is
        # expected to match hidden_size here.
        self.norm_context = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
        )

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)

    def forward(self, x, context, mask=None):
        # Normalize queries and context separately.
        x = self.norm1(x)
        context = self.norm_context(context)

        # nn.MultiheadAttention returns (attn_output, attn_output_weights).
        attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)

        # Residual connections around cross attention and the MLP.
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        return x
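

# Illustrative usage sketch (added for documentation, not in the original source).
# Queries come from `x` and keys/values from `context`; because kdim/vdim are not
# set on nn.MultiheadAttention, the context here is given hidden_size channels.
# All sizes are example assumptions.
def _demo_cross_attn_block():
    block = CrossAttnBlock(hidden_size=256, context_dim=256, num_heads=4)
    x = torch.randn(2, 100, 256)       # queries
    context = torch.randn(2, 50, 256)  # keys / values
    out = block(x, context)
    assert out.shape == (2, 100, 256)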