|
|
import torch.nn as nn |
|
|
from compressai.entropy_models import EntropyBottleneck |
|
|
from timm.models.vision_transformer import Block |
|
|
|
|
|
class IF_Module(nn.Module):
    """Information-Filtering module: a transformer encoder/decoder pair with a
    learned entropy bottleneck in between.

    Token features are refined by `depth` transformer encoder blocks, passed
    through a CompressAI ``EntropyBottleneck`` (training only), then refined by
    `depth` transformer decoder blocks.

    Args:
        embed_dim: Token embedding dimension (also the bottleneck's channel count).
        num_heads: Number of attention heads per transformer block.
        mlp_ratio: MLP hidden-dim ratio inside each transformer block.
        depth: Number of blocks in each of the encoder and decoder stacks.
        norm_layer: Normalization layer constructor (default ``nn.LayerNorm``).
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio, depth=4, norm_layer=nn.LayerNorm):
        super().__init__()

        # Transformer encoder stack followed by a final norm.
        self.encoder_blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.encoder_norm = norm_layer(embed_dim)

        # Transformer decoder stack followed by a final norm.
        self.decoder_blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.decoder_norm = norm_layer(embed_dim)

        # Learned entropy model over the embedding channels; yields likelihoods
        # used for the rate term of a rate-distortion loss.
        self.entropy_bottleneck = EntropyBottleneck(embed_dim)

    def forward(self, x, is_training=False):
        """Encode, (optionally) bottleneck, and decode token features.

        Args:
            x: Token features; presumably shaped (batch, tokens, embed_dim) —
                TODO confirm against callers.
            is_training: When True, quantize through the entropy bottleneck and
                return likelihoods; when False, the bottleneck is bypassed
                entirely (no quantization) and ``x_likelihood`` is ``None``.

        Returns:
            Tuple ``(x_hat, x_likelihood)`` where ``x_hat`` has the same shape
            as ``x`` and ``x_likelihood`` is the bottleneck's likelihood tensor
            or ``None``.
        """
        for blk in self.encoder_blocks:
            x = blk(x)
        x = self.encoder_norm(x)

        if is_training:
            # EntropyBottleneck expects channel-first input, so swap the token
            # and channel axes around the call and swap back afterwards.
            x = x.permute(0, 2, 1)
            x_hat, x_likelihood = self.entropy_bottleneck(x)
            x_hat = x_hat.permute(0, 2, 1)
        else:
            # NOTE(review): eval mode skips the bottleneck completely (no
            # quantization/compression applied) — confirm this is intended
            # rather than using the bottleneck's compress/decompress path.
            x_hat = x
            x_likelihood = None

        for blk in self.decoder_blocks:
            x_hat = blk(x_hat)
        x_hat = self.decoder_norm(x_hat)

        return x_hat, x_likelihood
|
|
|