# DiffICM / 4_ControlModule / models_IB.py
# Author: Qiyp — code of stage 1 & 3 (commit 1633fcc, "remove large files")
import torch.nn as nn
from compressai.entropy_models import EntropyBottleneck
from timm.models.vision_transformer import Block
class IF_Module(nn.Module):
    """Information-bottleneck module.

    Tokens pass through a ViT analysis transform (encoder stack + norm),
    then — during training only — through an ``EntropyBottleneck``, and
    finally through a ViT synthesis transform (decoder stack + norm).

    Args:
        embed_dim: token embedding width (also the bottleneck channel count).
        num_heads: attention heads per transformer block.
        mlp_ratio: MLP hidden-dim ratio inside each block.
        depth: number of blocks in each of the two stacks (default 4).
        norm_layer: normalization layer constructor (default ``nn.LayerNorm``).
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio, depth=4, norm_layer=nn.LayerNorm):
        super().__init__()

        def stack():
            # One stack of `depth` transformer blocks (e.g. embed_dim=768, num_heads=12).
            return nn.ModuleList(
                Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for _ in range(depth)
            )

        self.encoder_blocks = stack()
        self.encoder_norm = norm_layer(embed_dim)
        self.decoder_blocks = stack()
        self.decoder_norm = norm_layer(embed_dim)
        self.entropy_bottleneck = EntropyBottleneck(embed_dim)

    def forward(self, x, is_training=False):
        """Run analysis -> (optional) bottleneck -> synthesis.

        Returns:
            (reconstructed tokens, likelihoods) — likelihoods is ``None``
            when ``is_training`` is ``False``.
        """
        # ViT analysis transform.
        for block in self.encoder_blocks:
            x = block(x)
        x = self.encoder_norm(x)

        if is_training:
            # EntropyBottleneck treats dim 1 as channels, so swap the last
            # two dims around the call.
            # NOTE(review): assumes x is (batch, tokens, embed_dim) — confirm with caller.
            y, likelihoods = self.entropy_bottleneck(x.permute(0, 2, 1))
            y = y.permute(0, 2, 1)
        else:
            # Inference path bypasses the bottleneck entirely.
            y, likelihoods = x, None

        # ViT synthesis transform.
        for block in self.decoder_blocks:
            y = block(y)
        return self.decoder_norm(y), likelihoods