import torch.nn as nn

from compressai.entropy_models import EntropyBottleneck
from timm.models.vision_transformer import Block


class IF_Module(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio, depth=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.encoder_blocks = nn.ModuleList([  # 4 layers, embed_dim=768, num_heads=12
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.encoder_norm = norm_layer(embed_dim)
        self.decoder_blocks = nn.ModuleList([  # 4 layers, embed_dim=768, num_heads=12
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.decoder_norm = norm_layer(embed_dim)
        # Fully factorized entropy model over the embedding channels
        self.entropy_bottleneck = EntropyBottleneck(embed_dim)

    def forward(self, x, is_training=False):
        # ViT analysis transform
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)

        if is_training:
            # EntropyBottleneck expects channels in dim 1: (B, N, C) -> (B, C, N)
            x = x.permute(0, 2, 1)
            x_hat, x_likelihood = self.entropy_bottleneck(x)
            x_hat = x_hat.permute(0, 2, 1)
        else:
            x_hat = x
            x_likelihood = None

        # ViT synthesis transform
        for blk in self.decoder_blocks:
            x_hat = blk(x_hat)
        x_hat = self.decoder_norm(x_hat)

        return x_hat, x_likelihood
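
A minimal usage sketch, not from the original source: the embed_dim=768, num_heads=12, mlp_ratio=4 settings mirror the comments in IF_Module above, and the (batch, num_tokens, embed_dim) input shape is assumed to come from a ViT backbone; the auxiliary loss call follows the standard CompressAI pattern for EntropyBottleneck.

import torch

# Hypothetical configuration matching the comments in IF_Module
if_module = IF_Module(embed_dim=768, num_heads=12, mlp_ratio=4)

# Dummy token sequence: (batch, num_tokens, embed_dim)
tokens = torch.randn(2, 197, 768)

# Training path: the entropy bottleneck returns likelihoods for a rate term
x_hat, likelihoods = if_module(tokens, is_training=True)

# Auxiliary loss of the entropy bottleneck, optimized separately as in CompressAI
aux_loss = if_module.entropy_bottleneck.loss()

# Inference path: the bottleneck is bypassed and likelihoods are None
x_hat, likelihoods = if_module(tokens, is_training=False)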