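"""Model loading utilities for the Stable Diffusion pipeline.

Builds the CLIP, VAE encoder/decoder, and diffusion (UNet) modules, loads
pretrained weights from a standard checkpoint via model_converter, and
optionally overlays finetuned CatVTON-style attention weights from a
safetensors file.
"""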
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion, UNET_AttentionBlock
from safetensors.torch import load_file
import model_converter
import torch

def _copy_attention_weights(layer, idx, loaded_data):
    """Copy finetuned projection weights into one UNET_AttentionBlock.

    Checkpoint keys are flat-indexed ("{idx}.in_proj.weight", etc.); the
    caller advances idx by 8 per attention block to match the layout of the
    finetuned checkpoint.
    """
    for key, param in (
        (f"{idx}.in_proj.weight", layer.attention_1.in_proj.weight),
        (f"{idx}.out_proj.weight", layer.attention_1.out_proj.weight),
        (f"{idx}.out_proj.bias", layer.attention_1.out_proj.bias),
    ):
        # Copy the tensor only if it exists in the loaded checkpoint
        if key in loaded_data:
            print(f"Loading {key}")
            param.data.copy_(loaded_data[key])

def load_finetuned_attention_weights(finetune_weights_path, diffusion, device):
    loaded_data = load_file(finetune_weights_path, device=device)
    print(f"Loaded finetuned weights from {finetune_weights_path}")

    unet = diffusion.unet
    idx = 0

    # Attention blocks in the encoder stages
    for layers in unet.encoders:
        for layer in layers:
            if isinstance(layer, UNET_AttentionBlock):
                _copy_attention_weights(layer, idx, loaded_data)
                # Move to the next attention block index in the loaded data
                idx += 8

    # Attention blocks in the decoder stages
    for layers in unet.decoders:
        for layer in layers:
            if isinstance(layer, UNET_AttentionBlock):
                _copy_attention_weights(layer, idx, loaded_data)
                idx += 8

    # Attention blocks in the bottleneck
    for layer in unet.bottleneck:
        if isinstance(layer, UNET_AttentionBlock):
            _copy_attention_weights(layer, idx, loaded_data)
            idx += 8

    print(f"\nAttention module weights loaded from {finetune_weights_path} successfully.")

def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_path=None):
    # CatVTON parameters
    # in_channels: 8 for instruct-pix2pix (mask-free), 9 for sd-v1-5-inpainting (mask-based)
    in_channels = 9
    if finetune_weights_path is not None and (
        'maskfree' in finetune_weights_path or 'mask_free' in finetune_weights_path
    ):
        in_channels = 8

    out_channels = 4

    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    diffusion = Diffusion(in_channels=in_channels, out_channels=out_channels).to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    if finetune_weights_path is not None:
        load_finetuned_attention_weights(finetune_weights_path, diffusion, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    # CLIP is built and its weights are loaded above, so return it alongside
    # the other modules rather than discarding it.
    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
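
# Minimal usage sketch. The checkpoint filenames below are hypothetical
# placeholders, not files shipped with this module.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    models = preload_models_from_standard_weights(
        ckpt_path="sd-v1-5-inpainting.ckpt",                    # hypothetical path
        device=device,
        finetune_weights_path="catvton_attention.safetensors",  # hypothetical path
    )
    diffusion = models['diffusion']
    n_params = sum(p.numel() for p in diffusion.parameters())
    print(f"Loaded diffusion model with {n_params:,} parameters on {device}")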