# Virtual-Cloths-TryOn / diffusion.py
import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention
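# NOTE: SelfAttention and CrossAttention are assumed to come from the sibling
# attention.py module of this repo, with the constructor signatures used below.

# Maps the 320-dim sinusoidal timestep encoding through two linear layers
# (with a SiLU in between) to the 4x wider vector that conditions every
# residual block in the UNet.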
class TimeEmbedding(nn.Module):
def __init__(self, n_embed):
super().__init__()
self.linear_1=nn.Linear(n_embed, 4*n_embed)
self.linear_2=nn.Linear(4*n_embed, 4*n_embed)
def forward(self, x):
x=self.linear_1(x)
x=F.silu(x)
x=self.linear_2(x)
return x
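# Residual block conditioned on the diffusion timestep: the time embedding is
# projected to out_channels and broadcast-added over the spatial map; a 1x1
# conv aligns the skip path whenever the channel count changes.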
class UNET_ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, n_time=1280):
super().__init__()
self.grpnorm_feature=nn.GroupNorm(32, in_channels)
self.conv_feature=nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.linear_time=nn.Linear(n_time, out_channels)
self.grpnorm_merged=nn.GroupNorm(32, out_channels)
self.conv_merged=nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
if in_channels==out_channels:
self.residual_layer=nn.Identity()
else:
self.residual_layer=nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
def forward(self, feature, time):
residue=feature
feature=self.grpnorm_feature(feature)
feature=F.silu(feature)
feature=self.conv_feature(feature)
time=F.silu(time)
time=self.linear_time(time)
merged=feature+time.unsqueeze(-1).unsqueeze(-1)
merged=self.grpnorm_merged(merged)
merged=F.silu(merged)
merged=self.conv_merged(merged)
return merged + self.residual_layer(residue)
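# Spatial transformer block: normalize, project with a 1x1 conv, flatten the
# (H, W) grid into a sequence, apply self-attention over the pixels, then a
# GeGLU feed-forward, with short residuals around each sub-block and a long
# residual around the whole block.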
class UNET_AttentionBlock(nn.Module):
def __init__(self, n_head, n_embed):
super().__init__()
channels=n_head*n_embed
self.grpnorm=nn.GroupNorm(32, channels, eps=1e-6)
self.conv_input=nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.layernorm_1=nn.LayerNorm(channels)
self.attention_1=SelfAttention(n_head, channels, in_proj_bias=False)
        # Cross-attention sub-block, disabled in this variant: no conditioning
        # context flows through the network, so d_context is undefined here.
        # self.layernorm_2=nn.LayerNorm(channels)
        # self.attention_2=CrossAttention(n_head, channels, d_context, in_proj_bias=False)
self.layernorm_3=nn.LayerNorm(channels)
self.linear_geglu_1=nn.Linear(channels, 4*channels*2)
self.linear_geglu_2=nn.Linear(4*channels, channels)
self.conv_output=nn.Conv2d(channels, channels, kernel_size=1, padding=0)
def forward(self, x):
residue_long=x
x=self.grpnorm(x)
x=self.conv_input(x)
n, c, h, w=x.shape
x=x.view((n,c,h*w))
x=x.transpose(-1, -2)
residue_short=x
x=self.layernorm_1(x)
x=self.attention_1(x)
x+=residue_short
        # Cross-attention sub-block, disabled along with its layers in __init__.
        # Running layernorm_2 without the attention that should follow it would
        # only add a stray LayerNorm(x) term to the residual stream, so the
        # whole sub-block is skipped.
        # residue_short=x
        # x=self.layernorm_2(x)
        # x=self.attention_2(x, context)
        # x+=residue_short
residue_short=x
x=self.layernorm_3(x)
x, gate=self.linear_geglu_1(x).chunk(2, dim=-1)
x=x*F.gelu(gate)
x=self.linear_geglu_2(x)
x+=residue_short
x=x.transpose(-1, -2)
x=x.view((n, c, h, w))
return self.conv_output(x)+residue_long
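# Doubles the spatial resolution with nearest-neighbour interpolation, then
# smooths the result with a 3x3 conv (the standard learned upsample in a UNet
# decoder).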
class Upsample(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv=nn.Conv2d(channels, channels, kernel_size=3, padding=1)
def forward(self, x):
x=F.interpolate(x, scale_factor=2, mode='nearest')
return self.conv(x)
# SwitchSequential does not override __init__, so the layers passed at
# construction go straight to the parent nn.Sequential; only forward is
# overridden, to route the time embedding into the residual blocks.
class SwitchSequential(nn.Sequential):
def forward(self, x, time):
for layer in self:
if isinstance(layer, UNET_AttentionBlock):
x=layer(x)
elif isinstance(layer, UNET_ResidualBlock):
x=layer(x, time)
else:
x=layer(x)
return x
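# UNet backbone: the encoder halves the spatial resolution three times
# (H/8 -> H/64) while widening 320 -> 1280 channels; the decoder mirrors it,
# concatenating the matching encoder output along the channel axis before each
# block, which is why the decoder blocks expect 1.5x-2x channel counts.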
class UNET(nn.Module):
def __init__(self, in_channels=4):
super().__init__()
self.encoders=nn.ModuleList([
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(nn.Conv2d(in_channels, 320, kernel_size=3, padding=1)),
            # (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
SwitchSequential(UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)),
# (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 16, Width / 16)
SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)),
# (Batch_Size, 320, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)),
# (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)),
# (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 32, Width / 32)
SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)),
# (Batch_Size, 640, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)),
# (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)),
# (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)),
# (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(UNET_ResidualBlock(1280, 1280)),
SwitchSequential(UNET_ResidualBlock(1280, 1280)),
])
self.bottleneck = SwitchSequential(
# (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
UNET_ResidualBlock(1280, 1280),
UNET_AttentionBlock(8, 160),
UNET_ResidualBlock(1280, 1280),
)
self.decoders = nn.ModuleList([
# (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64)
SwitchSequential(UNET_ResidualBlock(2560, 1280)),
SwitchSequential(UNET_ResidualBlock(2560, 1280)),
# (Batch_Size, 2560, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 64, Width / 64) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
# (Batch_Size, 2560, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32)
SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
SwitchSequential(UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)),
# (Batch_Size, 1920, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 32, Width / 32) -> (Batch_Size, 1280, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(1920, 1280), UNET_AttentionBlock(8, 160), Upsample(1280)),
# (Batch_Size, 1920, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(1920, 640), UNET_AttentionBlock(8, 80)),
# (Batch_Size, 1280, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16)
SwitchSequential(UNET_ResidualBlock(1280, 640), UNET_AttentionBlock(8, 80)),
# (Batch_Size, 960, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 16, Width / 16) -> (Batch_Size, 640, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(960, 640), UNET_AttentionBlock(8, 80), Upsample(640)),
# (Batch_Size, 960, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(960, 320), UNET_AttentionBlock(8, 40)),
# (Batch_Size, 640, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8) -> (Batch_Size, 320, Height / 8, Width / 8)
SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
SwitchSequential(UNET_ResidualBlock(640, 320), UNET_AttentionBlock(8, 40)),
])
def forward(self, x, time):
        # x: (Batch_Size, in_channels, Height / 8, Width / 8)
        # time: (1, 1280), the timestep after TimeEmbedding
        # (no conditioning context is passed in this variant; cross-attention is disabled)
skip_connections = []
for layers in self.encoders:
x = layers(x, time)
skip_connections.append(x)
x = self.bottleneck(x, time)
for layers in self.decoders:
# Since we always concat with the skip connection of the encoder, the number of features increases before being sent to the decoder's layer
x = torch.cat((x, skip_connections.pop()), dim=1)
x = layers(x, time)
return x
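# Projects the final 320-channel feature map back to the latent channels with
# GroupNorm + SiLU + a 3x3 conv.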
class UNET_OutputLayer(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.grpnorm = nn.GroupNorm(32, in_channels)
self.conv=nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x):
x=self.grpnorm(x)
x=F.silu(x)
x=self.conv(x)
return x
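# Top-level wrapper: embed the timestep, denoise the latent through the UNet,
# and project back to the output channel count.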
class Diffusion(nn.Module):
def __init__(self, in_channels=4, out_channels=4):
super().__init__()
self.time_embedding=TimeEmbedding(320)
self.unet=UNET(in_channels)
self.final=UNET_OutputLayer(320, out_channels)
def forward(self, latent, time):
time=self.time_embedding(time)
output=self.unet(latent, time)
output=self.final(output)
return output
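# Hypothetical helper (not part of the original file): a Stable-Diffusion-style
# sinusoidal timestep encoding producing the (1, 320) vector that TimeEmbedding
# expects. The pipeline driving this model may build its encoding differently.
def get_time_embedding(timestep, dim=320):
    half = dim // 2
    # Frequencies 10000^(-i/half) for i in [0, half), as in transformer positional encodings
    freqs = torch.pow(10000, -torch.arange(0, half, dtype=torch.float32) / half)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # Concatenate cos and sin components -> (1, dim)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)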
if __name__ == "__main__":
# Input configurations
batch_size = 1
    in_channels = 9  # must match UNET's in_channels; try-on setups typically concatenate extra mask/garment latents onto the 4 base latent channels
height = 64 # Height / 8 = 64 → original H = 512
width = 64 # Width / 8 = 64 → original W = 512
# Create dummy inputs
    latent = torch.randn(batch_size, in_channels, height, width)  # [1, 9, 64, 64]
    time = torch.randn(batch_size, 320)  # [1, 320]; a real pipeline would use a sinusoidal encoding such as get_time_embedding above
# Initialize model
model = Diffusion(in_channels=in_channels, out_channels=4)
# Forward pass
with torch.no_grad():
output = model(latent, time)
# Print input and output shape
print("Input latent shape:", latent.shape)
print("Time embedding shape:", time.shape)
print("Output shape:", output.shape)