# Virtual-Cloths-TryOn / load_model.py
# Last change (commit 76c374a, by harsh99): added mask-free model support.
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion, UNET_AttentionBlock
from safetensors.torch import load_file
import model_converter
import torch
def _copy_attention_weights(layer, weights, idx):
    """Copy the tensors stored under index ``idx`` into ``layer.attention_1``; return how many were copied.

    Looks up ``{idx}.in_proj.weight``, ``{idx}.out_proj.weight`` and
    ``{idx}.out_proj.bias`` in ``weights``. Keys that are absent are skipped and
    the corresponding parameter keeps its current value.
    """
    attention = layer.attention_1
    targets = (
        (f"{idx}.in_proj.weight", attention.in_proj.weight),
        (f"{idx}.out_proj.weight", attention.out_proj.weight),
        (f"{idx}.out_proj.bias", attention.out_proj.bias),
    )
    copied = 0
    # In-place update of parameters must not be recorded by autograd.
    with torch.no_grad():
        for key, param in targets:
            if key in weights:
                print(f"Loading {key}")
                param.copy_(weights[key])
                copied += 1
    return copied


def load_finetuned_attention_weights(finetune_weights_path, diffusion, device):
    """Overwrite the UNet's first attention module in every UNET_AttentionBlock with finetuned weights.

    Args:
        finetune_weights_path: path of a safetensors file, read with
            ``safetensors.torch.load_file``.
        diffusion: model exposing ``.unet`` with ``encoders``, ``decoders`` and
            ``bottleneck`` containers of layers.
        device: device the safetensors tensors are loaded onto.

    Each ``UNET_AttentionBlock`` found is assigned an index (0, 8, 16, ...) and
    its ``attention_1`` in/out projections are copied from the keys
    ``{idx}.in_proj.weight``, ``{idx}.out_proj.weight`` and
    ``{idx}.out_proj.bias`` when present. Tensor shape mismatches raise from
    ``Tensor.copy_``.
    """
    updated_loaded_data = load_file(finetune_weights_path, device=device)
    print(f"Loaded finetuned weights from {finetune_weights_path}")
    unet = diffusion.unet

    # Index assignment depends on visiting order: encoders, then decoders, then
    # the bottleneck. This presumably mirrors the order in which the weights
    # file was produced -- TODO confirm against the converter/training script.
    candidate_layers = [layer for layers in unet.encoders for layer in layers]
    candidate_layers += [layer for layers in unet.decoders for layer in layers]
    candidate_layers += list(unet.bottleneck)

    idx = 0
    total_copied = 0
    for layer in candidate_layers:
        if isinstance(layer, UNET_AttentionBlock):
            total_copied += _copy_attention_weights(layer, updated_loaded_data, idx)
            # Successive attention blocks are keyed 8 indices apart; the stride
            # presumably matches the layout of the saved file -- TODO confirm.
            idx += 8

    if total_copied == 0:
        # A file whose keys match nothing would otherwise be a silent no-op.
        print(f"Warning: no attention weights in {finetune_weights_path} matched the UNet.")
    print(f"\nAttention module weights loaded from {finetune_weights_path} successfully.")
def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_path=None):
    """Build the VAE encoder/decoder and diffusion model from a standard checkpoint.

    Args:
        ckpt_path: checkpoint path passed to
            ``model_converter.load_from_standard_weights``.
        device: device every model is moved to; also forwarded to the
            finetuned-weights loader.
        finetune_weights_path: optional path to finetuned attention weights.
            When it is given and contains ``'maskfree'`` or ``'mask_free'``,
            the diffusion model is built with 8 input channels instead of 9.

    Returns:
        dict with keys ``'encoder'``, ``'decoder'`` and ``'diffusion'``.
    """
    # CatVTON parameters
    # in_channels: 8 for instruct-pix2pix (mask-free), 9 for sd-v1-5-inpainting (mask-based)
    in_channels = 9
    # Check for None first: the default argument would otherwise make the
    # substring test raise TypeError. str() also accepts os.PathLike paths.
    if finetune_weights_path is not None:
        weights_name = str(finetune_weights_path)
        if 'maskfree' in weights_name or 'mask_free' in weights_name:
            in_channels = 8
    out_channels = 4

    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    diffusion = Diffusion(in_channels=in_channels, out_channels=out_channels).to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)
    if finetune_weights_path is not None:
        load_finetuned_attention_weights(finetune_weights_path, diffusion, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    # NOTE(review): CLIP is built and strictly loaded (which validates
    # state_dict['clip']) but is not returned below. Confirm whether this load
    # is still needed, since it costs time and transient device memory.
    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        # 'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }