# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import collections.abc
from itertools import repeat
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


def make_2tuple(x):
    # Normalize an int or an existing (height, width) 2-tuple into a 2-tuple.
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B, C, H, W) -> (B, N, D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: If True, return tokens as (B, N, D); otherwise keep the (B, H', W', D) grid.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x
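

def _demo_patch_embed() -> None:
    # Illustrative sketch (not part of the referenced repositories): the
    # (B, C, H, W) -> (B, N, D) mapping from the docstring. With the defaults,
    # a 224x224 image yields (224 // 16) ** 2 = 196 patch tokens of width 768.
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    x = torch.randn(2, 3, 224, 224)
    tokens = embed(x)
    assert tokens.shape == (2, embed.num_patches, embed.embed_dim)  # (2, 196, 768)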


class PatchEmbed_Mlp(PatchEmbed):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten_embedding=True,
    ):
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten_embedding)
        # NOTE: assumes an int patch_size (square patches); PixelUnshuffle and
        # patch_size**2 below do not accept a (patch_H, patch_W) tuple.
        # Replaces the Conv2d projection with pixel-unshuffle + a per-patch MLP:
        # (B, C, H, W) -> (B, C*p*p, H/p, W/p) -> (B, H/p, W/p, C*p*p)
        #             -> (B, H/p, W/p, D) -> (B, D, H/p, W/p)
        self.proj = nn.Sequential(
            PixelUnshuffle(patch_size),
            Permute((0, 2, 3, 1)),
            Mlp(in_chans * patch_size**2, 4 * embed_dim, embed_dim),
            Permute((0, 3, 1, 2)),
        )
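

def _demo_patch_embed_mlp() -> None:
    # Illustrative sketch (not part of the referenced repositories): the MLP
    # variant is a drop-in replacement for PatchEmbed and produces the same
    # token shape, since unshuffling into (B, C*p*p, H/p, W/p) and applying a
    # per-patch MLP matches the footprint of a stride-p, kernel-p convolution.
    embed = PatchEmbed_Mlp(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    x = torch.randn(2, 3, 224, 224)
    tokens = embed(x)
    assert tokens.shape == (2, embed.num_patches, embed.embed_dim)  # (2, 196, 768)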


class PixelUnshuffle(nn.Module):
    # Module wrapper around F.pixel_unshuffle, usable inside nn.Sequential.
    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        if input.numel() == 0:
            # Empty-tensor path (not in the original torch implementation):
            # compute the output shape by hand instead of calling the kernel.
            C, H, W = input.shape[-3:]
            assert H and W and H % self.downscale_factor == W % self.downscale_factor == 0
            return input.view(
                *input.shape[:-3],
                C * self.downscale_factor**2,
                H // self.downscale_factor,
                W // self.downscale_factor,
            )
        else:
            return F.pixel_unshuffle(input, self.downscale_factor)
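

def _demo_pixel_unshuffle() -> None:
    # Illustrative sketch (not part of the referenced repositories):
    # pixel-unshuffle trades resolution for channels, mapping
    # (B, C, H, W) -> (B, C * r**2, H // r, W // r); the empty-tensor branch
    # above yields the same shape that F.pixel_unshuffle gives non-empty input.
    op = PixelUnshuffle(downscale_factor=2)
    assert op(torch.randn(1, 3, 8, 8)).shape == (1, 12, 4, 4)
    assert op(torch.empty(0, 3, 8, 8)).shape == (0, 12, 4, 4)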


class Permute(torch.nn.Module):
    # Module wrapper around Tensor.permute, usable inside nn.Sequential.
    dims: Tuple[int, ...]

    def __init__(self, dims: Tuple[int, ...]) -> None:
        super().__init__()
        self.dims = tuple(dims)

    def __repr__(self):
        return f"Permute{self.dims}"

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.permute(*self.dims)


def _ntuple(n):
    # Parse a scalar into an n-tuple; iterables (other than str) pass through.
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
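

def _demo_mlp() -> None:
    # Illustrative sketch (not part of the referenced repositories): nn.Linear
    # acts on the last dimension only, so the MLP accepts token grids with any
    # leading shape, e.g. the (B, H', W', C) layout used in PatchEmbed_Mlp.
    mlp = Mlp(in_features=768, hidden_features=3072, out_features=768)
    x = torch.randn(2, 14, 14, 768)
    assert mlp(x).shape == (2, 14, 14, 768)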