ZhenweiWang's picture
Upload folder using huggingface_hub
0ca05b5 verified
raw
history blame
5.03 kB
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from itertools import repeat
import collections.abc
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
class PatchEmbed_Mlp(PatchEmbed):
def __init__(self, img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten_embedding=True):
super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten_embedding)
self.proj = nn.Sequential(
PixelUnshuffle(patch_size),
Permute((0,2,3,1)),
Mlp(in_chans * patch_size**2, 4*embed_dim, embed_dim),
Permute((0,3,1,2)),
)
class PixelUnshuffle (nn.Module):
def __init__(self, downscale_factor):
super().__init__()
self.downscale_factor = downscale_factor
def forward(self, input):
if input.numel() == 0:
# this is not in the original torch implementation
C,H,W = input.shape[-3:]
assert H and W and H % self.downscale_factor == W%self.downscale_factor == 0
return input.view(*input.shape[:-3], C*self.downscale_factor**2, H//self.downscale_factor, W//self.downscale_factor)
else:
return F.pixel_unshuffle(input, self.downscale_factor)
class Permute(torch.nn.Module):
dims: tuple[int, ...]
def __init__(self, dims: tuple[int, ...]) -> None:
super().__init__()
self.dims = tuple(dims)
def __repr__(self):
return f"Permute{self.dims}"
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input.permute(*self.dims)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x