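"""Model loading utilities: build the CLIP, VAE encoder/decoder, and Diffusion
modules from a converted Stable Diffusion checkpoint, and optionally patch the
UNet's attention projections with finetuned (CatVTON-style) safetensors weights."""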
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion, UNET_AttentionBlock
from safetensors.torch import load_file
import model_converter
import torch
def load_finetuned_attention_weights(finetune_weights_path, diffusion, device):
    updated_loaded_data = load_file(finetune_weights_path, device=device)
    print(f"Loaded finetuned weights from {finetune_weights_path}")
    unet = diffusion.unet
    idx = 0
    # Iterate through the attention layers in the encoders
    for layers in unet.encoders:
        for layer in layers:
            if isinstance(layer, UNET_AttentionBlock):
                # Build the keys for this block's parameters in the loaded data
                in_proj_weight_key = f"{idx}.in_proj.weight"
                out_proj_weight_key = f"{idx}.out_proj.weight"
                out_proj_bias_key = f"{idx}.out_proj.bias"
                # Copy the weights if they exist in the loaded data
                if in_proj_weight_key in updated_loaded_data:
                    print(f"Loading {in_proj_weight_key}")
                    layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key])
                if out_proj_weight_key in updated_loaded_data:
                    print(f"Loading {out_proj_weight_key}")
                    layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key])
                if out_proj_bias_key in updated_loaded_data:
                    print(f"Loading {out_proj_bias_key}")
                    layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key])
                # Move to the next attention block index in the loaded data
                idx += 8
    # Iterate through the attention layers in the decoders
    for layers in unet.decoders:
        for layer in layers:
            if isinstance(layer, UNET_AttentionBlock):
                in_proj_weight_key = f"{idx}.in_proj.weight"
                out_proj_weight_key = f"{idx}.out_proj.weight"
                out_proj_bias_key = f"{idx}.out_proj.bias"
                if in_proj_weight_key in updated_loaded_data:
                    print(f"Loading {in_proj_weight_key}")
                    layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key])
                if out_proj_weight_key in updated_loaded_data:
                    print(f"Loading {out_proj_weight_key}")
                    layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key])
                if out_proj_bias_key in updated_loaded_data:
                    print(f"Loading {out_proj_bias_key}")
                    layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key])
                idx += 8
    # Iterate through the attention layers in the bottleneck
    for layer in unet.bottleneck:
        if isinstance(layer, UNET_AttentionBlock):
            in_proj_weight_key = f"{idx}.in_proj.weight"
            out_proj_weight_key = f"{idx}.out_proj.weight"
            out_proj_bias_key = f"{idx}.out_proj.bias"
            if in_proj_weight_key in updated_loaded_data:
                print(f"Loading {in_proj_weight_key}")
                layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key])
            if out_proj_weight_key in updated_loaded_data:
                print(f"Loading {out_proj_weight_key}")
                layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key])
            if out_proj_bias_key in updated_loaded_data:
                print(f"Loading {out_proj_bias_key}")
                layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key])
            idx += 8
    print(f"\nAttention module weights loaded from {finetune_weights_path} successfully.")
def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_path=None):
    # CatVTON parameters
    # in_channels: 8 for instruct-pix2pix (mask-free), 9 for sd-v1-5-inpainting (mask-based)
    in_channels = 9
    if finetune_weights_path is not None and (
        'maskfree' in finetune_weights_path or 'mask_free' in finetune_weights_path
    ):
        in_channels = 8
    out_channels = 4

    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    diffusion = Diffusion(in_channels=in_channels, out_channels=out_channels).to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)
    if finetune_weights_path is not None:
        load_finetuned_attention_weights(finetune_weights_path, diffusion, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        # 'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
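
# Usage sketch (the file paths below are placeholders, not assets shipped with this
# repository): load all modules from a converted SD v1.5 inpainting checkpoint and,
# optionally, overlay finetuned attention weights before running inference.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    models = preload_models_from_standard_weights(
        ckpt_path="sd-v1-5-inpainting.ckpt",                      # hypothetical checkpoint path
        device=device,
        finetune_weights_path="finetuned_attention.safetensors",  # hypothetical finetune path
    )
    print({name: type(module).__name__ for name, module in models.items()})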