import copy
import logging
import math
from contextlib import nullcontext
from functools import partial
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict as adict
from einops import rearrange

# `logger` is used by VitModel below but was never defined in this file;
# a module-level logger is assumed here.
logger = logging.getLogger(__name__)


class MlpProjector(nn.Module):
    """Projects vision-encoder features into the LLM embedding space."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        if cfg.projector_type == "identity":
            modules = nn.Identity()

        elif cfg.projector_type == "linear":
            modules = nn.Linear(cfg.input_dim, cfg.n_embed)

        elif cfg.projector_type == "mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "normlayer_downsample_mlp_gelu":
            # Same as "downsample_mlp_gelu" below, but with a LayerNorm on the
            # concatenated (downsampled) features before the first projection.
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            modules = [
                nn.LayerNorm(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio),
                nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio),
            ]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "downsample_mlp_gelu":
            mlp_depth = cfg.get("depth", 1)
            mlp_ratio = cfg.get("mlp_ratio", 1)
            # Each output token is built from downsample_ratio**2 concatenated input tokens.
            modules = [nn.Linear(cfg.input_dim * cfg.downsample_ratio * cfg.downsample_ratio, cfg.n_embed * mlp_ratio)]
            for _ in range(1, mlp_depth - 1):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed * mlp_ratio))
            modules.append(nn.GELU())
            modules.append(nn.Linear(cfg.n_embed * mlp_ratio, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            # High- and low-resolution features are each projected to n_embed // 2
            # and concatenated channel-wise before the shared MLP.
            mlp_depth = cfg.get("depth", 1)
            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "hybrid_split_feature_mlp_gelu":
            # Here the two feature sources arrive concatenated along the channel
            # dimension; input_dim is a pair [high_dim, low_dim].
            mlp_depth = cfg.get("depth", 1)
            channel_div = cfg.get("channel_div", 0.5)
            self.high_up_proj = nn.Linear(cfg.input_dim[0], int(cfg.n_embed * channel_div))
            self.low_up_proj = nn.Linear(cfg.input_dim[1], cfg.n_embed - int(cfg.n_embed * channel_div))

            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
            modules = nn.Sequential(*modules)

        elif cfg.projector_type == "low_high_split_mlp_gelu":
            # Separate MLP stacks for the high- and low-resolution streams:
            # high_layers shares the module instances built here, while
            # low_layers is an independent deep copy with its own parameters.
            mlp_depth = cfg.get("depth", 1)
            modules = []
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(cfg.n_embed // 2, cfg.n_embed // 2))
            modules = nn.Sequential(*modules)
            self.high_layers = nn.Sequential(*modules)
            self.low_layers = copy.deepcopy(modules)

        else:
            raise ValueError(f"Unknown projector type: {cfg.projector_type}")

        if cfg.get("token_pooling", False):
            self.token_pooling_layer = nn.Linear(cfg.input_dim * 4, cfg.input_dim)

        if cfg.get("conv_fusion_high_low_features", False):
            self.fusion_layer = nn.Linear(cfg.input_dim, cfg.input_dim)
        self.layers = modules

    def forward(self, x):
        if self.cfg.get("token_pooling", False):
            batch_size, wxh, channels = x.shape
            w = h = int(wxh ** 0.5)
            x = x.view(batch_size, w, h, channels)
            x = x.permute(0, 3, 1, 2)

            # Extract non-overlapping 2x2 patches and merge each into one token.
            patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
            batch_size, channels, h_patches, w_patches, _, _ = patches.size()
            patches = patches.contiguous().view(batch_size, channels, h_patches * w_patches, -1)

            patches = patches.permute(0, 2, 1, 3).contiguous()
            patches = patches.view(batch_size, h_patches * w_patches, channels * 4)

            x = self.token_pooling_layer(patches)

        if self.cfg.get("conv_fusion_high_low_features", False):
            x = self.fusion_layer(x[:, 0]) + x[:, 1]

        if self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
            high_x, low_x = x[0], x[1]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == "hybrid_split_feature_mlp_gelu":
            high_x = x[..., :self.cfg.input_dim[0]]
            low_x = x[..., self.cfg.input_dim[0]:]
            high_x = self.high_up_proj(high_x)
            low_x = self.low_up_proj(low_x)
            x = torch.concat([high_x, low_x], dim=-1)

        if self.cfg.projector_type == "low_high_split_mlp_gelu":
            high_x, low_x = x[0], x[1]
            high_x = self.high_layers(high_x)
            low_x = self.low_layers(low_x)
            x = torch.concat([high_x, low_x], dim=-1)
            return x

        if self.cfg.projector_type in ("downsample_mlp_gelu", "normlayer_downsample_mlp_gelu"):
            bs, hw, input_dim = x.shape
            h = w = int(hw ** 0.5)

            # Compute padding so the grid divides evenly by downsample_ratio.
            if h % self.cfg.downsample_ratio:
                pad = self.cfg.downsample_ratio - h % self.cfg.downsample_ratio
            else:
                pad = 0
            x = x.reshape(bs, h, w, input_dim)
            if pad > 0:
                x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)

            # Concatenate each downsample_ratio x downsample_ratio block of
            # tokens into a single token ("4 to 1" when downsample_ratio == 2).
            x = x.permute(0, 3, 1, 2)
            x = F.unfold(x, kernel_size=self.cfg.downsample_ratio, stride=self.cfg.downsample_ratio, padding=0)
            x = x.permute(0, 2, 1)

        return self.layers(x)

    @staticmethod
    def get_flops_per_sample(cfg):
        if cfg.projector_type == "linear":
            fwd = 2 * cfg.input_dim * cfg.n_embed

        elif "mlp_gelu" in cfg.projector_type:
            mlp_depth = cfg.get("depth", 1)
            downsample_ratio = cfg.get("downsample_ratio", 1)
            input_dim = sum(cfg.input_dim) if isinstance(cfg.input_dim, list) else cfg.input_dim
            input_dim = input_dim * downsample_ratio * downsample_ratio
            fwd = 2 * input_dim * cfg.n_embed + (mlp_depth - 1) * 2 * cfg.n_embed * cfg.n_embed
        else:
            fwd = 0

        # Forward + backward: the backward pass costs roughly twice the forward.
        return fwd * 3

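
# Illustrative usage sketch (not part of the original file): building a
# "downsample_mlp_gelu" projector that maps a 24x24 grid of 1024-d vision
# features into a hypothetical 2048-d LLM embedding space. All config values
# below are assumptions chosen for the example.
#
#   cfg = adict(projector_type="downsample_mlp_gelu", input_dim=1024,
#               n_embed=2048, depth=2, mlp_ratio=1, downsample_ratio=2)
#   projector = MlpProjector(cfg)
#   feats = torch.randn(2, 576, 1024)   # (batch, 24*24 tokens, input_dim)
#   out = projector(feats)              # -> (2, 144, 2048): each 2x2 token block merged
#   flops = MlpProjector.get_flops_per_sample(cfg)  # fwd+bwd FLOPs estimate
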
class LayerNormfp32(torch.nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16: compute in fp32, cast back."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


def get_abs_pos(abs_pos, tgt_size):
    # Resize a CLIP-style absolute position embedding of shape
    # (1, 1 + src_size**2, dim), whose first row is the CLS token, to match a
    # target sequence length of 1 + tgt_size**2 tokens.
    dim = abs_pos.size(-1)

    abs_pos_new = abs_pos.squeeze(0)
    cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]

    src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
    # tgt_size arrives as a token count (grid + CLS); int(sqrt(...)) truncates
    # away the extra CLS token to recover the grid side length.
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        old_pos_embed = old_pos_embed.view(1, src_size, src_size, dim).permute(0, 3, 1, 2).contiguous()
        # Interpolate in fp32 for numerical stability, then cast back.
        old_pos_embed = old_pos_embed.to(torch.float32)
        new_pos_embed = F.interpolate(
            old_pos_embed,
            size=(tgt_size, tgt_size),
            mode='bicubic',
            antialias=True,
            align_corners=False,
        ).to(dtype)
        new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
        new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
        vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
        vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
        return vision_pos_embed
    else:
        return abs_pos

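
# Worked example (a sketch, not from the original file): a CLIP ViT-L/14
# position table for 224px inputs has 1 + 16*16 = 257 rows. Evaluating at a
# resolution that yields a 24x24 patch grid:
#
#   pos = torch.randn(1, 257, 1024)
#   resized = get_abs_pos(pos, tgt_size=1 + 24 * 24)
#   # int(sqrt(577)) == 24, so `resized` has shape (1, 1 + 24*24, 1024) = (1, 577, 1024)
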
@torch.jit.script
def quick_gelu(x):
    # "QuickGELU" sigmoid approximation of GELU, as used in the original CLIP.
    return x * torch.sigmoid(1.702 * x)


class CLIPVisionEmbeddings(nn.Module):
    def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
        super().__init__()
        self.embed_dim = hidden_size
        self.image_size = image_size
        self.patch_size = patch_size

        self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = torch.nn.Conv2d(
            in_channels=num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer(
            "position_ids", torch.arange(self.num_positions).expand((1, -1))
        )

    def forward(self, pixel_values, patch_embeds):
        batch_size = pixel_values.shape[0]

        # Use precomputed patch embeddings when provided; otherwise patchify
        # the raw pixels with the strided convolution.
        if patch_embeds is None:
            patch_embeds = self.patch_embedding(pixel_values)

        # (B, C, H', W') -> (B, H'*W', C)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        # Interpolate the position table if the token count differs from training resolution.
        embeddings = embeddings + get_abs_pos(self.position_embedding(self.position_ids), embeddings.size(1))

        return embeddings

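
# Shape walkthrough (illustrative, with the default 224px / patch-14 settings):
#
#   emb = CLIPVisionEmbeddings()              # 1024-d, 16x16 = 256 patches
#   images = torch.randn(2, 3, 224, 224)
#   tokens = emb(images, patch_embeds=None)   # -> (2, 257, 1024): CLS + 256 patch tokens
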
class NoTPFeedForward(nn.Module):
    def __init__(
            self,
            cfg,
            dim: int,
            hidden_dim: int,
    ):
        super().__init__()

        self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
        self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)

    def forward(self, x):
        output = self.fc2(quick_gelu(self.fc1(x)))
        return output


class NoTPAttention(torch.nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.n_local_heads = cfg.num_attention_heads
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.max_seq_len = cfg.seq_length
        self.use_flash_attention = cfg.use_flash_attn

        self.qkv_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size * 3, bias=True)
        self.out_proj = torch.nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=True)

        # Stored for reference; not currently applied in the SDPA call below.
        self.attn_drop = cfg.attention_dropout

    def forward(
            self,
            x: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape
        xqkv = self.qkv_proj(x)
        xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)

        # The original code had identical flash and non-flash branches; both
        # route through PyTorch's scaled_dot_product_attention, so the branch
        # on self.use_flash_attention is collapsed here without changing behavior.
        xq, xk, xv = torch.split(xqkv, 1, dim=2)
        xq = xq.squeeze(2)
        xk = xk.squeeze(2)
        xv = xv.squeeze(2)

        # (B, S, H, D) -> (B, H, S, D) for SDPA.
        xq = xq.permute(0, 2, 1, 3)
        xk = xk.permute(0, 2, 1, 3)
        xv = xv.permute(0, 2, 1, 3)

        output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None)
        output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)

        output = self.out_proj(output)
        return output

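
# Note: torch.nn.functional.scaled_dot_product_attention itself dispatches to a
# fused flash-attention kernel when device, dtype, and shapes allow, so the
# collapsed branch above does not necessarily give up flash attention.
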
class NoTPTransformerBlock(nn.Module):
    def __init__(self, cfg, layer_id: int, multiple_of=256):
        super().__init__()

        self.n_heads = cfg.num_attention_heads
        self.dim = cfg.hidden_size
        self.head_dim = cfg.hidden_size // cfg.num_attention_heads
        self.self_attn = NoTPAttention(cfg)
        self.mlp = NoTPFeedForward(
            cfg, dim=cfg.hidden_size, hidden_dim=cfg.ffn_hidden_size
        )
        self.layer_id = layer_id
        self.layer_norm1 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )
        self.layer_norm2 = torch.nn.LayerNorm(
            cfg.hidden_size, eps=cfg.layernorm_epsilon
        )

    def forward(self, x: torch.Tensor):
        # Standard pre-LayerNorm block: x + Attn(LN(x)), then h + MLP(LN(h)).
        residual = self.self_attn.forward(self.layer_norm1(x))
        h = x + residual
        out = h + self.mlp.forward(self.layer_norm2(h))
        return out


class NoTPTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.cfg = cfg
        self.num_layers = cfg.num_layers

        self.layers = torch.nn.ModuleList()
        for layer_id in range(self.num_layers):
            self.layers.append(
                NoTPTransformerBlock(
                    cfg,
                    layer_id + 1,
                )
            )

    def forward(
            self,
            hidden_states,
    ):
        for lid, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states)

        return hidden_states

class VitModel(nn.Module):
    def __init__(
            self,
            cfg,
            freeze_embed=False,
            freeze_pre_norm=False
    ) -> None:
        super().__init__()

        self.embeddings = CLIPVisionEmbeddings(
            hidden_size=cfg.hidden_size, image_size=cfg.image_size, patch_size=cfg.patch_size
        )

        if freeze_embed:
            for name, param in self.embeddings.named_parameters():
                param.requires_grad = False

        self.transformer = NoTPTransformer(cfg=cfg)

        if cfg.get("fp32norm", False):
            logger.info("Load fp32 layernorm for ViT.")
            self.pre_layrnorm = LayerNormfp32(
                cfg.hidden_size,
                eps=cfg.get("pre_layernorm_epsilon", 1e-5),
            )
        else:
            self.pre_layrnorm = torch.nn.LayerNorm(
                cfg.hidden_size,
                eps=cfg.get("pre_layernorm_epsilon", 1e-5),
            )

        if freeze_pre_norm:
            for name, param in self.pre_layrnorm.named_parameters():
                param.requires_grad = False

        # Mark parameters for the surrounding training framework's micro data
        # parallelism (the attribute is consumed elsewhere).
        for p in self.parameters():
            p.micro_dp = True

    def set_input_tensor(self, input_tensor):
        # Pipeline-parallel hook; NoTPTransformer does not define
        # set_input_tensor in this file, so it is assumed to be provided elsewhere.
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]
        self.transformer.set_input_tensor(input_tensor[0])

    def __str__(self) -> str:
        return "open_clip"

    def forward(
            self,
            x,
            patch_embeds
    ):
        x = self.embeddings(x, patch_embeds)
        hidden_states = self.pre_layrnorm(x)

        output = self.transformer(hidden_states)

        return output


# These hyperparameters match CLIP ViT-L/14: 24 layers, 1024 hidden size,
# 16 heads, patch size 14 at 224px (16x16 = 256 patch tokens).
vit_model_cfg = adict(
    num_layers=24,
    hidden_size=1024,
    num_heads=16,
    num_attention_heads=16,
    ffn_hidden_size=4096,
    seq_length=256,
    max_position_embeddings=256,
    use_flash_attn=False,
    understand_projector_stride=2,
    hidden_dropout=0.0,
    attention_dropout=0.0,
    no_persist_layer_norm=False,
    layernorm_epsilon=1e-5,
    pre_layernorm_epsilon=1e-5,
    image_size=224,
    patch_size=14,
    recompute_list=[],
)


def build_clip_l():
    return VitModel(
        cfg=vit_model_cfg,
        freeze_embed=False,
        freeze_pre_norm=False,
    )

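
# Usage sketch (illustrative): run a batch of 224px images through the tower.
#
#   model = build_clip_l()
#   images = torch.randn(2, 3, 224, 224)
#   feats = model(images, patch_embeds=None)   # -> (2, 257, 1024): CLS + 256 patch tokens
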
def get_abs_pos_sam(abs_pos, tgt_size):
    # SAM variant: the position embedding is a (1, src_size, src_size, C) grid
    # with no CLS token, so resizing is a straight 2D interpolation.
    dtype = abs_pos.dtype

    src_size = abs_pos.size(1)

    if src_size != tgt_size:
        old_pos_embed = abs_pos.permute(0, 3, 1, 2)
        old_pos_embed = old_pos_embed.to(torch.float32)
        new_pos_embed = F.interpolate(
            old_pos_embed,
            size=(tgt_size, tgt_size),
            mode='bicubic',
            antialias=True,
            align_corners=False,
        ).to(dtype)
        new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
        return new_pos_embed
    else:
        return abs_pos


class MLPBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))


class LayerNorm2d(nn.Module):
    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize over the channel dimension of an NCHW tensor.
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x

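
# LayerNorm2d above is equivalent to applying nn.LayerNorm over the channel
# axis of an NCHW tensor, i.e. (a sketch, with freshly initialized affines):
#
#   x = torch.randn(2, 256, 64, 64)
#   ref = nn.LayerNorm(256, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
#   torch.allclose(LayerNorm2d(256)(x), ref, atol=1e-5)   # True
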
class ImageEncoderViT(nn.Module):
    def __init__(
        self,
        img_size: int = 1024,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 256,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        global_attn_indexes: Tuple[int, ...] = (),
    ) -> None:
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (list): Indexes for blocks using global attention.
        """
        super().__init__()
        self.img_size = img_size

        # PatchEmbed is assumed to be defined elsewhere in this file
        # (the standard SAM patch-embedding conv producing (B, H', W', C)).
        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Absolute position embedding over the patch grid, (1, H', W', C).
            self.pos_embed = nn.Parameter(
                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
            )

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=(img_size // patch_size, img_size // patch_size),
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
            nn.Conv2d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_chans),
        )

        # Two strided convs that further downsample the neck output 4x in total.
        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
        self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos_sam(self.pos_embed, x.size(1))

        for blk in self.blocks:
            x = blk(x)

        x = self.neck(x.permute(0, 3, 1, 2))
        x2 = self.net_2(x)
        x3 = self.net_3(x2.clone())

        return x3

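
# Shape walkthrough for the defaults above (illustrative): a 1024px input gives
# a 64x64 patch grid, and the extra net_2 / net_3 convs reduce it to 16x16:
#
#   x: (B, 3, 1024, 1024)
#   patch_embed -> (B, 64, 64, 768)    # SAM-style (B, H', W', C) layout
#   neck        -> (B, 256, 64, 64)
#   net_2       -> (B, 512, 32, 32)
#   net_3       -> (B, 1024, 16, 16)   # i.e. 256 tokens of dim 1024 after flattening
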
class Block(nn.Module):
    """Transformer block with support for window attention and residual propagation blocks."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        window_size: int = 0,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks. If it equals 0, then
                use global attention.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=input_size if window_size == 0 else (window_size, window_size),
        )

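        # Windowed blocks attend within window_size x window_size patches, so the
        # attention above sizes its relative-position tables to the window rather
        # than the full token grid.
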
        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

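        # For windowed blocks, x is now (B * num_windows, window_size, window_size, C);
        # attention runs per window and the full grid is reassembled afterwards.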
        x = self.attn(x)
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x


class Attention(nn.Module):
    """Multi-head Attention block with relative position embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
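            # The two tables above hold one embedding per relative offset along
            # each axis: offsets span [-(size - 1), size - 1], i.e. 2 * size - 1 rows.
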
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

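        # q, k and v are laid out as (B * num_heads, H * W, head_dim) here.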
        rel_h, rel_w = None, None
        if self.use_rel_pos:
            rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        q = q.view(B, self.num_heads, H * W, -1)
        k = k.view(B, self.num_heads, H * W, -1)
        v = v.view(B, self.num_heads, H * W, -1)

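        # Fold the decomposed rel-pos terms into a single additive bias of shape
        # (B, num_heads, H * W, H * W) and let SDPA apply it as attn_mask.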
        if self.use_rel_pos:
            rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
            rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
            attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
        else:
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)

        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)

        x = self.proj(x)

        return x


def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

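    # e.g. a 64x64 token grid with window_size=14 pads to 70x70 and splits into
    # 5 * 5 = 25 windows per image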
    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Unpartition windows into the original sequences, removing padding if present.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate the table to the required length; the round-trip through
        # float32 is needed because linear interpolation does not support all dtypes.
        dtype = rel_pos.dtype
        rel_pos = rel_pos.to(torch.float32)
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        ).to(dtype)
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

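    # Scale coordinates when q and k sizes differ, then shift so the most
    # negative offset maps to row 0 of the table.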
    q_coords = torch.arange(q_size, device=rel_pos.device)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size, device=rel_pos.device)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        rel_h (Tensor): height term of the decomposed bias with shape (B, q_h * q_w, k_h, 1).
        rel_w (Tensor): width term of the decomposed bias with shape (B, q_h * q_w, 1, k_w).
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

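    # Each query feature is contracted with its row- and column-offset embeddings;
    # broadcasting rel_h + rel_w in the caller reconstructs the full 2D positional bias.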
    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    rel_h = rel_h.unsqueeze(-1)
    rel_w = rel_w.unsqueeze(-2)
    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)

    return rel_h, rel_w


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x


def build_sam_vit_b(checkpoint=None):
    # ViT-B: 12 layers at 768 dim; blocks 2, 5, 8 and 11 attend globally, the
    # rest use 14x14 windowed attention (see _build_sam).
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def build_sam_fast_vit_b(checkpoint=None, compile_mode='max-autotune', dtype=torch.bfloat16):
    image_encoder = build_sam_vit_b(checkpoint).eval().to(dtype)
    image_encoder = torch.compile(image_encoder, mode=compile_mode)
    return image_encoder

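
# A minimal usage sketch (the checkpoint path is hypothetical, and a compiled
# model needs a few warm-up calls before it is fast):
#
#     encoder = build_sam_fast_vit_b("sam_encoder.pth").cuda()
#     with torch.no_grad():
#         feats = encoder(torch.randn(1, 3, 1024, 1024, device="cuda", dtype=torch.bfloat16))
#     feats.shape  # torch.Size([1, 1024, 16, 16])

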
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    image_encoder = ImageEncoderViT(
        depth=encoder_depth,
        embed_dim=encoder_embed_dim,
        img_size=image_size,
        mlp_ratio=4,
        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
        num_heads=encoder_num_heads,
        patch_size=vit_patch_size,
        qkv_bias=True,
        use_rel_pos=True,
        global_attn_indexes=encoder_global_attn_indexes,
        window_size=14,
        out_chans=prompt_embed_dim,
    )
    image_encoder.eval()
    if checkpoint is not None:
        state_dict = torch.load(checkpoint)
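        # Keep only the high-resolution vision tower weights and strip their fixed
        # key prefix (k[30:]) so the remaining names match this module's parameters.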
        image_encoder.load_state_dict({k[30:]: v for k, v in state_dict.items() if 'vision_tower_high' in k}, strict=True)
        print(f"SAM ViT-B image encoder weights loaded from {checkpoint}")
    return image_encoder