Upload 89 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- FUNDING.yml +3 -0
- __init__.py +0 -0
- adafactor_fused.py +106 -0
- attention.py +119 -0
- attention_processors.py +227 -0
- blip.py +245 -0
- cache_latents.py +205 -0
- cache_text_encoder_outputs.py +197 -0
- canny.py +34 -0
- cextension.py +54 -0
- check_lora_weights.py +48 -0
- clean_captions_and_tags.py +194 -0
- config_README-en.md +386 -0
- config_README-ja.md +388 -0
- config_util.py +721 -0
- control_net_lllite.py +449 -0
- control_net_lllite_for_train.py +501 -0
- convert_diffusers20_original_sd.py +163 -0
- custom_train_functions.py +559 -0
- deepspeed_utils.py +139 -0
- dependabot.yml +7 -0
- detect_face_rotate.py +253 -0
- device_utils.py +89 -0
- diffusers.py +47 -0
- dylora.py +529 -0
- extract_lora_from_dylora.py +128 -0
- extract_lora_from_models.py +360 -0
- fine_tune_README_ja.md +140 -0
- gen_img_README-ja.md +487 -0
- gradscaler.py +183 -0
- hijacks.py +367 -0
- huggingface_util.py +84 -0
- hypernetwork.py +223 -0
- hypernetwork_nai.py +96 -0
- latent_upscaler.py +354 -0
- libbitsandbytes_cpu.dll +0 -0
- libbitsandbytes_cuda116.dll +3 -0
- libbitsandbytes_cuda118.dll +3 -0
- logo_aihub.png +0 -0
- lora.py +1410 -0
- lora_diffusers.py +616 -0
- lora_fa.py +1244 -0
- lora_interrogator.py +146 -0
- lpw_stable_diffusion.py +1233 -0
- main.py +166 -0
- make_captions.py +210 -0
- make_captions_by_git.py +183 -0
- masked_loss_README-ja.md +57 -0
- masked_loss_README.md +56 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
libbitsandbytes_cuda118.dll filter=lfs diff=lfs merge=lfs -text
|
FUNDING.yml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# These are supported funding model platforms
|
| 2 |
+
|
| 3 |
+
github: kohya-ss
|
__init__.py
ADDED
|
File without changes
|
adafactor_fused.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Adafactor
|
| 4 |
+
|
| 5 |
+
@torch.no_grad()
|
| 6 |
+
def adafactor_step_param(self, p, group):
|
| 7 |
+
if p.grad is None:
|
| 8 |
+
return
|
| 9 |
+
grad = p.grad
|
| 10 |
+
if grad.dtype in {torch.float16, torch.bfloat16}:
|
| 11 |
+
grad = grad.float()
|
| 12 |
+
if grad.is_sparse:
|
| 13 |
+
raise RuntimeError("Adafactor does not support sparse gradients.")
|
| 14 |
+
|
| 15 |
+
state = self.state[p]
|
| 16 |
+
grad_shape = grad.shape
|
| 17 |
+
|
| 18 |
+
factored, use_first_moment = Adafactor._get_options(group, grad_shape)
|
| 19 |
+
# State Initialization
|
| 20 |
+
if len(state) == 0:
|
| 21 |
+
state["step"] = 0
|
| 22 |
+
|
| 23 |
+
if use_first_moment:
|
| 24 |
+
# Exponential moving average of gradient values
|
| 25 |
+
state["exp_avg"] = torch.zeros_like(grad)
|
| 26 |
+
if factored:
|
| 27 |
+
state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
|
| 28 |
+
state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
|
| 29 |
+
else:
|
| 30 |
+
state["exp_avg_sq"] = torch.zeros_like(grad)
|
| 31 |
+
|
| 32 |
+
state["RMS"] = 0
|
| 33 |
+
else:
|
| 34 |
+
if use_first_moment:
|
| 35 |
+
state["exp_avg"] = state["exp_avg"].to(grad)
|
| 36 |
+
if factored:
|
| 37 |
+
state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
|
| 38 |
+
state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
|
| 39 |
+
else:
|
| 40 |
+
state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
|
| 41 |
+
|
| 42 |
+
p_data_fp32 = p
|
| 43 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 44 |
+
p_data_fp32 = p_data_fp32.float()
|
| 45 |
+
|
| 46 |
+
state["step"] += 1
|
| 47 |
+
state["RMS"] = Adafactor._rms(p_data_fp32)
|
| 48 |
+
lr = Adafactor._get_lr(group, state)
|
| 49 |
+
|
| 50 |
+
beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
|
| 51 |
+
update = (grad ** 2) + group["eps"][0]
|
| 52 |
+
if factored:
|
| 53 |
+
exp_avg_sq_row = state["exp_avg_sq_row"]
|
| 54 |
+
exp_avg_sq_col = state["exp_avg_sq_col"]
|
| 55 |
+
|
| 56 |
+
exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
|
| 57 |
+
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
|
| 58 |
+
|
| 59 |
+
# Approximation of exponential moving average of square of gradient
|
| 60 |
+
update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
|
| 61 |
+
update.mul_(grad)
|
| 62 |
+
else:
|
| 63 |
+
exp_avg_sq = state["exp_avg_sq"]
|
| 64 |
+
|
| 65 |
+
exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
|
| 66 |
+
update = exp_avg_sq.rsqrt().mul_(grad)
|
| 67 |
+
|
| 68 |
+
update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
|
| 69 |
+
update.mul_(lr)
|
| 70 |
+
|
| 71 |
+
if use_first_moment:
|
| 72 |
+
exp_avg = state["exp_avg"]
|
| 73 |
+
exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
|
| 74 |
+
update = exp_avg
|
| 75 |
+
|
| 76 |
+
if group["weight_decay"] != 0:
|
| 77 |
+
p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
|
| 78 |
+
|
| 79 |
+
p_data_fp32.add_(-update)
|
| 80 |
+
|
| 81 |
+
if p.dtype in {torch.float16, torch.bfloat16}:
|
| 82 |
+
p.copy_(p_data_fp32)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@torch.no_grad()
|
| 86 |
+
def adafactor_step(self, closure=None):
|
| 87 |
+
"""
|
| 88 |
+
Performs a single optimization step
|
| 89 |
+
|
| 90 |
+
Arguments:
|
| 91 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 92 |
+
and returns the loss.
|
| 93 |
+
"""
|
| 94 |
+
loss = None
|
| 95 |
+
if closure is not None:
|
| 96 |
+
loss = closure()
|
| 97 |
+
|
| 98 |
+
for group in self.param_groups:
|
| 99 |
+
for p in group["params"]:
|
| 100 |
+
adafactor_step_param(self, p, group)
|
| 101 |
+
|
| 102 |
+
return loss
|
| 103 |
+
|
| 104 |
+
def patch_adafactor_fused(optimizer: Adafactor):
|
| 105 |
+
optimizer.step_param = adafactor_step_param.__get__(optimizer)
|
| 106 |
+
optimizer.step = adafactor_step.__get__(optimizer)
|
attention.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from functools import cache, wraps
|
| 4 |
+
|
| 5 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
| 6 |
+
|
| 7 |
+
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
| 8 |
+
|
| 9 |
+
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
|
| 10 |
+
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
|
| 11 |
+
|
| 12 |
+
# Find something divisible with the input_tokens
|
| 13 |
+
@cache
|
| 14 |
+
def find_split_size(original_size, slice_block_size, slice_rate=2):
|
| 15 |
+
split_size = original_size
|
| 16 |
+
while True:
|
| 17 |
+
if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
|
| 18 |
+
return split_size
|
| 19 |
+
split_size = split_size - 1
|
| 20 |
+
if split_size <= 1:
|
| 21 |
+
return 1
|
| 22 |
+
return split_size
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# Find slice sizes for SDPA
|
| 26 |
+
@cache
|
| 27 |
+
def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
|
| 28 |
+
batch_size, attn_heads, query_len, _ = query_shape
|
| 29 |
+
_, _, key_len, _ = key_shape
|
| 30 |
+
|
| 31 |
+
slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
|
| 32 |
+
|
| 33 |
+
split_batch_size = batch_size
|
| 34 |
+
split_head_size = attn_heads
|
| 35 |
+
split_query_size = query_len
|
| 36 |
+
|
| 37 |
+
do_batch_split = False
|
| 38 |
+
do_head_split = False
|
| 39 |
+
do_query_split = False
|
| 40 |
+
|
| 41 |
+
if batch_size * slice_batch_size >= trigger_rate:
|
| 42 |
+
do_batch_split = True
|
| 43 |
+
split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
|
| 44 |
+
|
| 45 |
+
if split_batch_size * slice_batch_size > slice_rate:
|
| 46 |
+
slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
|
| 47 |
+
do_head_split = True
|
| 48 |
+
split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
|
| 49 |
+
|
| 50 |
+
if split_head_size * slice_head_size > slice_rate:
|
| 51 |
+
slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
|
| 52 |
+
do_query_split = True
|
| 53 |
+
split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
|
| 54 |
+
|
| 55 |
+
return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
| 59 |
+
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
| 60 |
+
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
| 61 |
+
if query.device.type != "xpu":
|
| 62 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
| 63 |
+
is_unsqueezed = False
|
| 64 |
+
if len(query.shape) == 3:
|
| 65 |
+
query = query.unsqueeze(0)
|
| 66 |
+
is_unsqueezed = True
|
| 67 |
+
if len(key.shape) == 3:
|
| 68 |
+
key = key.unsqueeze(0)
|
| 69 |
+
if len(value.shape) == 3:
|
| 70 |
+
value = value.unsqueeze(0)
|
| 71 |
+
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
|
| 72 |
+
|
| 73 |
+
# Slice SDPA
|
| 74 |
+
if do_batch_split:
|
| 75 |
+
batch_size, attn_heads, query_len, _ = query.shape
|
| 76 |
+
_, _, _, head_dim = value.shape
|
| 77 |
+
hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
|
| 78 |
+
if attn_mask is not None:
|
| 79 |
+
attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
|
| 80 |
+
for ib in range(batch_size // split_batch_size):
|
| 81 |
+
start_idx = ib * split_batch_size
|
| 82 |
+
end_idx = (ib + 1) * split_batch_size
|
| 83 |
+
if do_head_split:
|
| 84 |
+
for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
|
| 85 |
+
start_idx_h = ih * split_head_size
|
| 86 |
+
end_idx_h = (ih + 1) * split_head_size
|
| 87 |
+
if do_query_split:
|
| 88 |
+
for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
|
| 89 |
+
start_idx_q = iq * split_query_size
|
| 90 |
+
end_idx_q = (iq + 1) * split_query_size
|
| 91 |
+
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
|
| 92 |
+
query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
|
| 93 |
+
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
| 94 |
+
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
| 95 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
|
| 96 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
|
| 100 |
+
query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
| 101 |
+
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
| 102 |
+
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
| 103 |
+
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
|
| 104 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
| 105 |
+
)
|
| 106 |
+
else:
|
| 107 |
+
hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
|
| 108 |
+
query[start_idx:end_idx, :, :, :],
|
| 109 |
+
key[start_idx:end_idx, :, :, :],
|
| 110 |
+
value[start_idx:end_idx, :, :, :],
|
| 111 |
+
attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
|
| 112 |
+
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
| 113 |
+
)
|
| 114 |
+
torch.xpu.synchronize(query.device)
|
| 115 |
+
else:
|
| 116 |
+
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
| 117 |
+
if is_unsqueezed:
|
| 118 |
+
hidden_states.squeeze(0)
|
| 119 |
+
return hidden_states
|
attention_processors.py
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Any
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
import torch
|
| 5 |
+
from diffusers.models.attention_processor import Attention
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# flash attention forwards and backwards
|
| 9 |
+
|
| 10 |
+
# https://arxiv.org/abs/2205.14135
|
| 11 |
+
|
| 12 |
+
EPSILON = 1e-6
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class FlashAttentionFunction(torch.autograd.function.Function):
|
| 16 |
+
@staticmethod
|
| 17 |
+
@torch.no_grad()
|
| 18 |
+
def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
|
| 19 |
+
"""Algorithm 2 in the paper"""
|
| 20 |
+
|
| 21 |
+
device = q.device
|
| 22 |
+
dtype = q.dtype
|
| 23 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
| 24 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
| 25 |
+
|
| 26 |
+
o = torch.zeros_like(q)
|
| 27 |
+
all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
|
| 28 |
+
all_row_maxes = torch.full(
|
| 29 |
+
(*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
scale = q.shape[-1] ** -0.5
|
| 33 |
+
|
| 34 |
+
if mask is None:
|
| 35 |
+
mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
|
| 36 |
+
else:
|
| 37 |
+
mask = rearrange(mask, "b n -> b 1 1 n")
|
| 38 |
+
mask = mask.split(q_bucket_size, dim=-1)
|
| 39 |
+
|
| 40 |
+
row_splits = zip(
|
| 41 |
+
q.split(q_bucket_size, dim=-2),
|
| 42 |
+
o.split(q_bucket_size, dim=-2),
|
| 43 |
+
mask,
|
| 44 |
+
all_row_sums.split(q_bucket_size, dim=-2),
|
| 45 |
+
all_row_maxes.split(q_bucket_size, dim=-2),
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
|
| 49 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
| 50 |
+
|
| 51 |
+
col_splits = zip(
|
| 52 |
+
k.split(k_bucket_size, dim=-2),
|
| 53 |
+
v.split(k_bucket_size, dim=-2),
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
for k_ind, (kc, vc) in enumerate(col_splits):
|
| 57 |
+
k_start_index = k_ind * k_bucket_size
|
| 58 |
+
|
| 59 |
+
attn_weights = (
|
| 60 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
if row_mask is not None:
|
| 64 |
+
attn_weights.masked_fill_(~row_mask, max_neg_value)
|
| 65 |
+
|
| 66 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
| 67 |
+
causal_mask = torch.ones(
|
| 68 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
| 69 |
+
).triu(q_start_index - k_start_index + 1)
|
| 70 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
| 71 |
+
|
| 72 |
+
block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
|
| 73 |
+
attn_weights -= block_row_maxes
|
| 74 |
+
exp_weights = torch.exp(attn_weights)
|
| 75 |
+
|
| 76 |
+
if row_mask is not None:
|
| 77 |
+
exp_weights.masked_fill_(~row_mask, 0.0)
|
| 78 |
+
|
| 79 |
+
block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
|
| 80 |
+
min=EPSILON
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
|
| 84 |
+
|
| 85 |
+
exp_values = torch.einsum(
|
| 86 |
+
"... i j, ... j d -> ... i d", exp_weights, vc
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
|
| 90 |
+
exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
|
| 91 |
+
|
| 92 |
+
new_row_sums = (
|
| 93 |
+
exp_row_max_diff * row_sums
|
| 94 |
+
+ exp_block_row_max_diff * block_row_sums
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
|
| 98 |
+
(exp_block_row_max_diff / new_row_sums) * exp_values
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
row_maxes.copy_(new_row_maxes)
|
| 102 |
+
row_sums.copy_(new_row_sums)
|
| 103 |
+
|
| 104 |
+
ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
|
| 105 |
+
ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
|
| 106 |
+
|
| 107 |
+
return o
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
@torch.no_grad()
|
| 111 |
+
def backward(ctx, do):
|
| 112 |
+
"""Algorithm 4 in the paper"""
|
| 113 |
+
|
| 114 |
+
causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
|
| 115 |
+
q, k, v, o, l, m = ctx.saved_tensors
|
| 116 |
+
|
| 117 |
+
device = q.device
|
| 118 |
+
|
| 119 |
+
max_neg_value = -torch.finfo(q.dtype).max
|
| 120 |
+
qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
|
| 121 |
+
|
| 122 |
+
dq = torch.zeros_like(q)
|
| 123 |
+
dk = torch.zeros_like(k)
|
| 124 |
+
dv = torch.zeros_like(v)
|
| 125 |
+
|
| 126 |
+
row_splits = zip(
|
| 127 |
+
q.split(q_bucket_size, dim=-2),
|
| 128 |
+
o.split(q_bucket_size, dim=-2),
|
| 129 |
+
do.split(q_bucket_size, dim=-2),
|
| 130 |
+
mask,
|
| 131 |
+
l.split(q_bucket_size, dim=-2),
|
| 132 |
+
m.split(q_bucket_size, dim=-2),
|
| 133 |
+
dq.split(q_bucket_size, dim=-2),
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
|
| 137 |
+
q_start_index = ind * q_bucket_size - qk_len_diff
|
| 138 |
+
|
| 139 |
+
col_splits = zip(
|
| 140 |
+
k.split(k_bucket_size, dim=-2),
|
| 141 |
+
v.split(k_bucket_size, dim=-2),
|
| 142 |
+
dk.split(k_bucket_size, dim=-2),
|
| 143 |
+
dv.split(k_bucket_size, dim=-2),
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
|
| 147 |
+
k_start_index = k_ind * k_bucket_size
|
| 148 |
+
|
| 149 |
+
attn_weights = (
|
| 150 |
+
torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
if causal and q_start_index < (k_start_index + k_bucket_size - 1):
|
| 154 |
+
causal_mask = torch.ones(
|
| 155 |
+
(qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
|
| 156 |
+
).triu(q_start_index - k_start_index + 1)
|
| 157 |
+
attn_weights.masked_fill_(causal_mask, max_neg_value)
|
| 158 |
+
|
| 159 |
+
exp_attn_weights = torch.exp(attn_weights - mc)
|
| 160 |
+
|
| 161 |
+
if row_mask is not None:
|
| 162 |
+
exp_attn_weights.masked_fill_(~row_mask, 0.0)
|
| 163 |
+
|
| 164 |
+
p = exp_attn_weights / lc
|
| 165 |
+
|
| 166 |
+
dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
|
| 167 |
+
dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
|
| 168 |
+
|
| 169 |
+
D = (doc * oc).sum(dim=-1, keepdims=True)
|
| 170 |
+
ds = p * scale * (dp - D)
|
| 171 |
+
|
| 172 |
+
dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
|
| 173 |
+
dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
|
| 174 |
+
|
| 175 |
+
dqc.add_(dq_chunk)
|
| 176 |
+
dkc.add_(dk_chunk)
|
| 177 |
+
dvc.add_(dv_chunk)
|
| 178 |
+
|
| 179 |
+
return dq, dk, dv, None, None, None, None
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class FlashAttnProcessor:
|
| 183 |
+
def __call__(
|
| 184 |
+
self,
|
| 185 |
+
attn: Attention,
|
| 186 |
+
hidden_states,
|
| 187 |
+
encoder_hidden_states=None,
|
| 188 |
+
attention_mask=None,
|
| 189 |
+
) -> Any:
|
| 190 |
+
q_bucket_size = 512
|
| 191 |
+
k_bucket_size = 1024
|
| 192 |
+
|
| 193 |
+
h = attn.heads
|
| 194 |
+
q = attn.to_q(hidden_states)
|
| 195 |
+
|
| 196 |
+
encoder_hidden_states = (
|
| 197 |
+
encoder_hidden_states
|
| 198 |
+
if encoder_hidden_states is not None
|
| 199 |
+
else hidden_states
|
| 200 |
+
)
|
| 201 |
+
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
|
| 202 |
+
|
| 203 |
+
if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
|
| 204 |
+
context_k, context_v = attn.hypernetwork.forward(
|
| 205 |
+
hidden_states, encoder_hidden_states
|
| 206 |
+
)
|
| 207 |
+
context_k = context_k.to(hidden_states.dtype)
|
| 208 |
+
context_v = context_v.to(hidden_states.dtype)
|
| 209 |
+
else:
|
| 210 |
+
context_k = encoder_hidden_states
|
| 211 |
+
context_v = encoder_hidden_states
|
| 212 |
+
|
| 213 |
+
k = attn.to_k(context_k)
|
| 214 |
+
v = attn.to_v(context_v)
|
| 215 |
+
del encoder_hidden_states, hidden_states
|
| 216 |
+
|
| 217 |
+
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
|
| 218 |
+
|
| 219 |
+
out = FlashAttentionFunction.apply(
|
| 220 |
+
q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
out = rearrange(out, "b h n d -> b n (h d)")
|
| 224 |
+
|
| 225 |
+
out = attn.to_out[0](out)
|
| 226 |
+
out = attn.to_out[1](out)
|
| 227 |
+
return out
|
blip.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
+
* All rights reserved.
|
| 4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
+
* By Junnan Li
|
| 7 |
+
'''
|
| 8 |
+
import warnings
|
| 9 |
+
warnings.filterwarnings("ignore")
|
| 10 |
+
|
| 11 |
+
# from models.vit import VisionTransformer, interpolate_pos_embed
|
| 12 |
+
# from models.med import BertConfig, BertModel, BertLMHeadModel
|
| 13 |
+
from blip.vit import VisionTransformer, interpolate_pos_embed
|
| 14 |
+
from blip.med import BertConfig, BertModel, BertLMHeadModel
|
| 15 |
+
from transformers import BertTokenizer
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch import nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from urllib.parse import urlparse
|
| 23 |
+
from timm.models.hub import download_cached_file
|
| 24 |
+
from library.utils import setup_logging
|
| 25 |
+
setup_logging()
|
| 26 |
+
import logging
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
class BLIP_Base(nn.Module):
|
| 30 |
+
def __init__(self,
|
| 31 |
+
med_config = 'configs/med_config.json',
|
| 32 |
+
image_size = 224,
|
| 33 |
+
vit = 'base',
|
| 34 |
+
vit_grad_ckpt = False,
|
| 35 |
+
vit_ckpt_layer = 0,
|
| 36 |
+
):
|
| 37 |
+
"""
|
| 38 |
+
Args:
|
| 39 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 40 |
+
image_size (int): input image size
|
| 41 |
+
vit (str): model size of vision transformer
|
| 42 |
+
"""
|
| 43 |
+
super().__init__()
|
| 44 |
+
|
| 45 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 46 |
+
self.tokenizer = init_tokenizer()
|
| 47 |
+
med_config = BertConfig.from_json_file(med_config)
|
| 48 |
+
med_config.encoder_width = vision_width
|
| 49 |
+
self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def forward(self, image, caption, mode):
|
| 53 |
+
|
| 54 |
+
assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
|
| 55 |
+
text = self.tokenizer(caption, return_tensors="pt").to(image.device)
|
| 56 |
+
|
| 57 |
+
if mode=='image':
|
| 58 |
+
# return image features
|
| 59 |
+
image_embeds = self.visual_encoder(image)
|
| 60 |
+
return image_embeds
|
| 61 |
+
|
| 62 |
+
elif mode=='text':
|
| 63 |
+
# return text features
|
| 64 |
+
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
|
| 65 |
+
return_dict = True, mode = 'text')
|
| 66 |
+
return text_output.last_hidden_state
|
| 67 |
+
|
| 68 |
+
elif mode=='multimodal':
|
| 69 |
+
# return multimodel features
|
| 70 |
+
image_embeds = self.visual_encoder(image)
|
| 71 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 72 |
+
|
| 73 |
+
text.input_ids[:,0] = self.tokenizer.enc_token_id
|
| 74 |
+
output = self.text_encoder(text.input_ids,
|
| 75 |
+
attention_mask = text.attention_mask,
|
| 76 |
+
encoder_hidden_states = image_embeds,
|
| 77 |
+
encoder_attention_mask = image_atts,
|
| 78 |
+
return_dict = True,
|
| 79 |
+
)
|
| 80 |
+
return output.last_hidden_state
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class BLIP_Decoder(nn.Module):
|
| 85 |
+
def __init__(self,
|
| 86 |
+
med_config = 'configs/med_config.json',
|
| 87 |
+
image_size = 384,
|
| 88 |
+
vit = 'base',
|
| 89 |
+
vit_grad_ckpt = False,
|
| 90 |
+
vit_ckpt_layer = 0,
|
| 91 |
+
prompt = 'a picture of ',
|
| 92 |
+
):
|
| 93 |
+
"""
|
| 94 |
+
Args:
|
| 95 |
+
med_config (str): path for the mixture of encoder-decoder model's configuration file
|
| 96 |
+
image_size (int): input image size
|
| 97 |
+
vit (str): model size of vision transformer
|
| 98 |
+
"""
|
| 99 |
+
super().__init__()
|
| 100 |
+
|
| 101 |
+
self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
|
| 102 |
+
self.tokenizer = init_tokenizer()
|
| 103 |
+
med_config = BertConfig.from_json_file(med_config)
|
| 104 |
+
med_config.encoder_width = vision_width
|
| 105 |
+
self.text_decoder = BertLMHeadModel(config=med_config)
|
| 106 |
+
|
| 107 |
+
self.prompt = prompt
|
| 108 |
+
self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def forward(self, image, caption):
|
| 112 |
+
|
| 113 |
+
image_embeds = self.visual_encoder(image)
|
| 114 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 115 |
+
|
| 116 |
+
text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
|
| 117 |
+
|
| 118 |
+
text.input_ids[:,0] = self.tokenizer.bos_token_id
|
| 119 |
+
|
| 120 |
+
decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
|
| 121 |
+
decoder_targets[:,:self.prompt_length] = -100
|
| 122 |
+
|
| 123 |
+
decoder_output = self.text_decoder(text.input_ids,
|
| 124 |
+
attention_mask = text.attention_mask,
|
| 125 |
+
encoder_hidden_states = image_embeds,
|
| 126 |
+
encoder_attention_mask = image_atts,
|
| 127 |
+
labels = decoder_targets,
|
| 128 |
+
return_dict = True,
|
| 129 |
+
)
|
| 130 |
+
loss_lm = decoder_output.loss
|
| 131 |
+
|
| 132 |
+
return loss_lm
|
| 133 |
+
|
| 134 |
+
def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
|
| 135 |
+
image_embeds = self.visual_encoder(image)
|
| 136 |
+
|
| 137 |
+
# recent version of transformers seems to do repeat_interleave automatically
|
| 138 |
+
# if not sample:
|
| 139 |
+
# image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
|
| 140 |
+
|
| 141 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
|
| 142 |
+
model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
|
| 143 |
+
|
| 144 |
+
prompt = [self.prompt] * image.size(0)
|
| 145 |
+
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
|
| 146 |
+
input_ids[:,0] = self.tokenizer.bos_token_id
|
| 147 |
+
input_ids = input_ids[:, :-1]
|
| 148 |
+
|
| 149 |
+
if sample:
|
| 150 |
+
#nucleus sampling
|
| 151 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
| 152 |
+
max_length=max_length,
|
| 153 |
+
min_length=min_length,
|
| 154 |
+
do_sample=True,
|
| 155 |
+
top_p=top_p,
|
| 156 |
+
num_return_sequences=1,
|
| 157 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
| 158 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 159 |
+
repetition_penalty=1.1,
|
| 160 |
+
**model_kwargs)
|
| 161 |
+
else:
|
| 162 |
+
#beam search
|
| 163 |
+
outputs = self.text_decoder.generate(input_ids=input_ids,
|
| 164 |
+
max_length=max_length,
|
| 165 |
+
min_length=min_length,
|
| 166 |
+
num_beams=num_beams,
|
| 167 |
+
eos_token_id=self.tokenizer.sep_token_id,
|
| 168 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 169 |
+
repetition_penalty=repetition_penalty,
|
| 170 |
+
**model_kwargs)
|
| 171 |
+
|
| 172 |
+
captions = []
|
| 173 |
+
for output in outputs:
|
| 174 |
+
caption = self.tokenizer.decode(output, skip_special_tokens=True)
|
| 175 |
+
captions.append(caption[len(self.prompt):])
|
| 176 |
+
return captions
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def blip_decoder(pretrained='',**kwargs):
|
| 180 |
+
model = BLIP_Decoder(**kwargs)
|
| 181 |
+
if pretrained:
|
| 182 |
+
model,msg = load_checkpoint(model,pretrained)
|
| 183 |
+
assert(len(msg.missing_keys)==0)
|
| 184 |
+
return model
|
| 185 |
+
|
| 186 |
+
def blip_feature_extractor(pretrained='',**kwargs):
|
| 187 |
+
model = BLIP_Base(**kwargs)
|
| 188 |
+
if pretrained:
|
| 189 |
+
model,msg = load_checkpoint(model,pretrained)
|
| 190 |
+
assert(len(msg.missing_keys)==0)
|
| 191 |
+
return model
|
| 192 |
+
|
| 193 |
+
def init_tokenizer():
|
| 194 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
| 195 |
+
tokenizer.add_special_tokens({'bos_token':'[DEC]'})
|
| 196 |
+
tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
|
| 197 |
+
tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
|
| 198 |
+
return tokenizer
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
|
| 202 |
+
|
| 203 |
+
assert vit in ['base', 'large'], "vit parameter must be base or large"
|
| 204 |
+
if vit=='base':
|
| 205 |
+
vision_width = 768
|
| 206 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
|
| 207 |
+
num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 208 |
+
drop_path_rate=0 or drop_path_rate
|
| 209 |
+
)
|
| 210 |
+
elif vit=='large':
|
| 211 |
+
vision_width = 1024
|
| 212 |
+
visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
|
| 213 |
+
num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
|
| 214 |
+
drop_path_rate=0.1 or drop_path_rate
|
| 215 |
+
)
|
| 216 |
+
return visual_encoder, vision_width
|
| 217 |
+
|
| 218 |
+
def is_url(url_or_filename):
|
| 219 |
+
parsed = urlparse(url_or_filename)
|
| 220 |
+
return parsed.scheme in ("http", "https")
|
| 221 |
+
|
| 222 |
+
def load_checkpoint(model,url_or_filename):
|
| 223 |
+
if is_url(url_or_filename):
|
| 224 |
+
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
|
| 225 |
+
checkpoint = torch.load(cached_file, map_location='cpu')
|
| 226 |
+
elif os.path.isfile(url_or_filename):
|
| 227 |
+
checkpoint = torch.load(url_or_filename, map_location='cpu')
|
| 228 |
+
else:
|
| 229 |
+
raise RuntimeError('checkpoint url or path is invalid')
|
| 230 |
+
|
| 231 |
+
state_dict = checkpoint['model']
|
| 232 |
+
|
| 233 |
+
state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
|
| 234 |
+
if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
|
| 235 |
+
state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
|
| 236 |
+
model.visual_encoder_m)
|
| 237 |
+
for key in model.state_dict().keys():
|
| 238 |
+
if key in state_dict.keys():
|
| 239 |
+
if state_dict[key].shape!=model.state_dict()[key].shape:
|
| 240 |
+
del state_dict[key]
|
| 241 |
+
|
| 242 |
+
msg = model.load_state_dict(state_dict,strict=False)
|
| 243 |
+
logger.info('load checkpoint from %s'%url_or_filename)
|
| 244 |
+
return model,msg
|
| 245 |
+
|
cache_latents.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# latentsのdiskへの事前キャッシュを行う / cache latents to disk
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import math
|
| 5 |
+
from multiprocessing import Value
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from accelerate.utils import set_seed
|
| 9 |
+
import torch
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from library import config_util
|
| 13 |
+
from library import train_util
|
| 14 |
+
from library import sdxl_train_util
|
| 15 |
+
from library.config_util import (
|
| 16 |
+
ConfigSanitizer,
|
| 17 |
+
BlueprintGenerator,
|
| 18 |
+
)
|
| 19 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 20 |
+
setup_logging()
|
| 21 |
+
import logging
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def cache_to_disk(args: argparse.Namespace) -> None:
|
| 27 |
+
setup_logging(args, reset=True)
|
| 28 |
+
train_util.prepare_dataset_args(args, True)
|
| 29 |
+
|
| 30 |
+
# check cache latents arg
|
| 31 |
+
assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
|
| 32 |
+
|
| 33 |
+
use_dreambooth_method = args.in_json is None
|
| 34 |
+
|
| 35 |
+
if args.seed is not None:
|
| 36 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
| 37 |
+
|
| 38 |
+
# tokenizerを準備する:datasetを動かすために必要
|
| 39 |
+
if args.sdxl:
|
| 40 |
+
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
| 41 |
+
tokenizers = [tokenizer1, tokenizer2]
|
| 42 |
+
else:
|
| 43 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 44 |
+
tokenizers = [tokenizer]
|
| 45 |
+
|
| 46 |
+
# データセットを準備する
|
| 47 |
+
if args.dataset_class is None:
|
| 48 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
| 49 |
+
if args.dataset_config is not None:
|
| 50 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 51 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 52 |
+
ignored = ["train_data_dir", "in_json"]
|
| 53 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 54 |
+
logger.warning(
|
| 55 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 56 |
+
", ".join(ignored)
|
| 57 |
+
)
|
| 58 |
+
)
|
| 59 |
+
else:
|
| 60 |
+
if use_dreambooth_method:
|
| 61 |
+
logger.info("Using DreamBooth method.")
|
| 62 |
+
user_config = {
|
| 63 |
+
"datasets": [
|
| 64 |
+
{
|
| 65 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
| 66 |
+
args.train_data_dir, args.reg_data_dir
|
| 67 |
+
)
|
| 68 |
+
}
|
| 69 |
+
]
|
| 70 |
+
}
|
| 71 |
+
else:
|
| 72 |
+
logger.info("Training with captions.")
|
| 73 |
+
user_config = {
|
| 74 |
+
"datasets": [
|
| 75 |
+
{
|
| 76 |
+
"subsets": [
|
| 77 |
+
{
|
| 78 |
+
"image_dir": args.train_data_dir,
|
| 79 |
+
"metadata_file": args.in_json,
|
| 80 |
+
}
|
| 81 |
+
]
|
| 82 |
+
}
|
| 83 |
+
]
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
| 87 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 88 |
+
else:
|
| 89 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
| 90 |
+
|
| 91 |
+
# datasetのcache_latentsを呼ばなければ、生の画像が返る
|
| 92 |
+
|
| 93 |
+
current_epoch = Value("i", 0)
|
| 94 |
+
current_step = Value("i", 0)
|
| 95 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 96 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 97 |
+
|
| 98 |
+
# acceleratorを準備する
|
| 99 |
+
logger.info("prepare accelerator")
|
| 100 |
+
args.deepspeed = False
|
| 101 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 102 |
+
|
| 103 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 104 |
+
weight_dtype, _ = train_util.prepare_dtype(args)
|
| 105 |
+
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
| 106 |
+
|
| 107 |
+
# モデルを読み込む
|
| 108 |
+
logger.info("load model")
|
| 109 |
+
if args.sdxl:
|
| 110 |
+
(_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
| 111 |
+
else:
|
| 112 |
+
_, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
| 113 |
+
|
| 114 |
+
if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
|
| 115 |
+
vae.set_use_memory_efficient_attention_xformers(args.xformers)
|
| 116 |
+
vae.to(accelerator.device, dtype=vae_dtype)
|
| 117 |
+
vae.requires_grad_(False)
|
| 118 |
+
vae.eval()
|
| 119 |
+
|
| 120 |
+
# dataloaderを準備する
|
| 121 |
+
train_dataset_group.set_caching_mode("latents")
|
| 122 |
+
|
| 123 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 124 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 125 |
+
|
| 126 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 127 |
+
train_dataset_group,
|
| 128 |
+
batch_size=1,
|
| 129 |
+
shuffle=True,
|
| 130 |
+
collate_fn=collator,
|
| 131 |
+
num_workers=n_workers,
|
| 132 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
| 136 |
+
train_dataloader = accelerator.prepare(train_dataloader)
|
| 137 |
+
|
| 138 |
+
# データ取得のためのループ
|
| 139 |
+
for batch in tqdm(train_dataloader):
|
| 140 |
+
b_size = len(batch["images"])
|
| 141 |
+
vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
|
| 142 |
+
flip_aug = batch["flip_aug"]
|
| 143 |
+
alpha_mask = batch["alpha_mask"]
|
| 144 |
+
random_crop = batch["random_crop"]
|
| 145 |
+
bucket_reso = batch["bucket_reso"]
|
| 146 |
+
|
| 147 |
+
# バッチを分割して処理する
|
| 148 |
+
for i in range(0, b_size, vae_batch_size):
|
| 149 |
+
images = batch["images"][i : i + vae_batch_size]
|
| 150 |
+
absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
|
| 151 |
+
resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
|
| 152 |
+
|
| 153 |
+
image_infos = []
|
| 154 |
+
for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
|
| 155 |
+
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
| 156 |
+
image_info.image = image
|
| 157 |
+
image_info.bucket_reso = bucket_reso
|
| 158 |
+
image_info.resized_size = resized_size
|
| 159 |
+
image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
|
| 160 |
+
|
| 161 |
+
if args.skip_existing:
|
| 162 |
+
if train_util.is_disk_cached_latents_is_expected(
|
| 163 |
+
image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask
|
| 164 |
+
):
|
| 165 |
+
logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
|
| 166 |
+
continue
|
| 167 |
+
|
| 168 |
+
image_infos.append(image_info)
|
| 169 |
+
|
| 170 |
+
if len(image_infos) > 0:
|
| 171 |
+
train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop)
|
| 172 |
+
|
| 173 |
+
accelerator.wait_for_everyone()
|
| 174 |
+
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 178 |
+
parser = argparse.ArgumentParser()
|
| 179 |
+
|
| 180 |
+
add_logging_arguments(parser)
|
| 181 |
+
train_util.add_sd_models_arguments(parser)
|
| 182 |
+
train_util.add_training_arguments(parser, True)
|
| 183 |
+
train_util.add_dataset_arguments(parser, True, True, True)
|
| 184 |
+
config_util.add_config_arguments(parser)
|
| 185 |
+
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
| 186 |
+
parser.add_argument(
|
| 187 |
+
"--no_half_vae",
|
| 188 |
+
action="store_true",
|
| 189 |
+
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
| 190 |
+
)
|
| 191 |
+
parser.add_argument(
|
| 192 |
+
"--skip_existing",
|
| 193 |
+
action="store_true",
|
| 194 |
+
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
| 195 |
+
)
|
| 196 |
+
return parser
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
if __name__ == "__main__":
|
| 200 |
+
parser = setup_parser()
|
| 201 |
+
|
| 202 |
+
args = parser.parse_args()
|
| 203 |
+
args = train_util.read_config_from_file(args, parser)
|
| 204 |
+
|
| 205 |
+
cache_to_disk(args)
|
cache_text_encoder_outputs.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import math
|
| 5 |
+
from multiprocessing import Value
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
from accelerate.utils import set_seed
|
| 9 |
+
import torch
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from library import config_util
|
| 13 |
+
from library import train_util
|
| 14 |
+
from library import sdxl_train_util
|
| 15 |
+
from library.config_util import (
|
| 16 |
+
ConfigSanitizer,
|
| 17 |
+
BlueprintGenerator,
|
| 18 |
+
)
|
| 19 |
+
from library.utils import setup_logging, add_logging_arguments
|
| 20 |
+
setup_logging()
|
| 21 |
+
import logging
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
def cache_to_disk(args: argparse.Namespace) -> None:
|
| 25 |
+
setup_logging(args, reset=True)
|
| 26 |
+
train_util.prepare_dataset_args(args, True)
|
| 27 |
+
|
| 28 |
+
# check cache arg
|
| 29 |
+
assert (
|
| 30 |
+
args.cache_text_encoder_outputs_to_disk
|
| 31 |
+
), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
|
| 32 |
+
|
| 33 |
+
# できるだけ準備はしておくが今のところSDXLのみしか動かない
|
| 34 |
+
assert (
|
| 35 |
+
args.sdxl
|
| 36 |
+
), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
|
| 37 |
+
|
| 38 |
+
use_dreambooth_method = args.in_json is None
|
| 39 |
+
|
| 40 |
+
if args.seed is not None:
|
| 41 |
+
set_seed(args.seed) # 乱数系列を初期化する
|
| 42 |
+
|
| 43 |
+
# tokenizerを準備する:datasetを動かすために必要
|
| 44 |
+
if args.sdxl:
|
| 45 |
+
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
|
| 46 |
+
tokenizers = [tokenizer1, tokenizer2]
|
| 47 |
+
else:
|
| 48 |
+
tokenizer = train_util.load_tokenizer(args)
|
| 49 |
+
tokenizers = [tokenizer]
|
| 50 |
+
|
| 51 |
+
# データセットを準備する
|
| 52 |
+
if args.dataset_class is None:
|
| 53 |
+
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
| 54 |
+
if args.dataset_config is not None:
|
| 55 |
+
logger.info(f"Load dataset config from {args.dataset_config}")
|
| 56 |
+
user_config = config_util.load_user_config(args.dataset_config)
|
| 57 |
+
ignored = ["train_data_dir", "in_json"]
|
| 58 |
+
if any(getattr(args, attr) is not None for attr in ignored):
|
| 59 |
+
logger.warning(
|
| 60 |
+
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
| 61 |
+
", ".join(ignored)
|
| 62 |
+
)
|
| 63 |
+
)
|
| 64 |
+
else:
|
| 65 |
+
if use_dreambooth_method:
|
| 66 |
+
logger.info("Using DreamBooth method.")
|
| 67 |
+
user_config = {
|
| 68 |
+
"datasets": [
|
| 69 |
+
{
|
| 70 |
+
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
| 71 |
+
args.train_data_dir, args.reg_data_dir
|
| 72 |
+
)
|
| 73 |
+
}
|
| 74 |
+
]
|
| 75 |
+
}
|
| 76 |
+
else:
|
| 77 |
+
logger.info("Training with captions.")
|
| 78 |
+
user_config = {
|
| 79 |
+
"datasets": [
|
| 80 |
+
{
|
| 81 |
+
"subsets": [
|
| 82 |
+
{
|
| 83 |
+
"image_dir": args.train_data_dir,
|
| 84 |
+
"metadata_file": args.in_json,
|
| 85 |
+
}
|
| 86 |
+
]
|
| 87 |
+
}
|
| 88 |
+
]
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
|
| 92 |
+
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
| 93 |
+
else:
|
| 94 |
+
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
|
| 95 |
+
|
| 96 |
+
current_epoch = Value("i", 0)
|
| 97 |
+
current_step = Value("i", 0)
|
| 98 |
+
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
| 99 |
+
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
| 100 |
+
|
| 101 |
+
# acceleratorを準備する
|
| 102 |
+
logger.info("prepare accelerator")
|
| 103 |
+
args.deepspeed = False
|
| 104 |
+
accelerator = train_util.prepare_accelerator(args)
|
| 105 |
+
|
| 106 |
+
# mixed precisionに対応した型を用意しておき適宜castする
|
| 107 |
+
weight_dtype, _ = train_util.prepare_dtype(args)
|
| 108 |
+
|
| 109 |
+
# モデルを読み込む
|
| 110 |
+
logger.info("load model")
|
| 111 |
+
if args.sdxl:
|
| 112 |
+
(_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
|
| 113 |
+
text_encoders = [text_encoder1, text_encoder2]
|
| 114 |
+
else:
|
| 115 |
+
text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
|
| 116 |
+
text_encoders = [text_encoder1]
|
| 117 |
+
|
| 118 |
+
for text_encoder in text_encoders:
|
| 119 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 120 |
+
text_encoder.requires_grad_(False)
|
| 121 |
+
text_encoder.eval()
|
| 122 |
+
|
| 123 |
+
# dataloaderを準備する
|
| 124 |
+
train_dataset_group.set_caching_mode("text")
|
| 125 |
+
|
| 126 |
+
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
| 127 |
+
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
| 128 |
+
|
| 129 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 130 |
+
train_dataset_group,
|
| 131 |
+
batch_size=1,
|
| 132 |
+
shuffle=True,
|
| 133 |
+
collate_fn=collator,
|
| 134 |
+
num_workers=n_workers,
|
| 135 |
+
persistent_workers=args.persistent_data_loader_workers,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
|
| 139 |
+
train_dataloader = accelerator.prepare(train_dataloader)
|
| 140 |
+
|
| 141 |
+
# データ取得のためのループ
|
| 142 |
+
for batch in tqdm(train_dataloader):
|
| 143 |
+
absolute_paths = batch["absolute_paths"]
|
| 144 |
+
input_ids1_list = batch["input_ids1_list"]
|
| 145 |
+
input_ids2_list = batch["input_ids2_list"]
|
| 146 |
+
|
| 147 |
+
image_infos = []
|
| 148 |
+
for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
|
| 149 |
+
image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
|
| 150 |
+
image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
|
| 151 |
+
image_info
|
| 152 |
+
|
| 153 |
+
if args.skip_existing:
|
| 154 |
+
if os.path.exists(image_info.text_encoder_outputs_npz):
|
| 155 |
+
logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
|
| 156 |
+
continue
|
| 157 |
+
|
| 158 |
+
image_info.input_ids1 = input_ids1
|
| 159 |
+
image_info.input_ids2 = input_ids2
|
| 160 |
+
image_infos.append(image_info)
|
| 161 |
+
|
| 162 |
+
if len(image_infos) > 0:
|
| 163 |
+
b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
|
| 164 |
+
b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
|
| 165 |
+
train_util.cache_batch_text_encoder_outputs(
|
| 166 |
+
image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
accelerator.wait_for_everyone()
|
| 170 |
+
accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 174 |
+
parser = argparse.ArgumentParser()
|
| 175 |
+
|
| 176 |
+
add_logging_arguments(parser)
|
| 177 |
+
train_util.add_sd_models_arguments(parser)
|
| 178 |
+
train_util.add_training_arguments(parser, True)
|
| 179 |
+
train_util.add_dataset_arguments(parser, True, True, True)
|
| 180 |
+
config_util.add_config_arguments(parser)
|
| 181 |
+
sdxl_train_util.add_sdxl_training_arguments(parser)
|
| 182 |
+
parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--skip_existing",
|
| 185 |
+
action="store_true",
|
| 186 |
+
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
|
| 187 |
+
)
|
| 188 |
+
return parser
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
if __name__ == "__main__":
|
| 192 |
+
parser = setup_parser()
|
| 193 |
+
|
| 194 |
+
args = parser.parse_args()
|
| 195 |
+
args = train_util.read_config_from_file(args, parser)
|
| 196 |
+
|
| 197 |
+
cache_to_disk(args)
|
canny.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import cv2
|
| 3 |
+
|
| 4 |
+
import logging
|
| 5 |
+
from library.utils import setup_logging
|
| 6 |
+
setup_logging()
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
def canny(args):
|
| 10 |
+
img = cv2.imread(args.input)
|
| 11 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
| 12 |
+
|
| 13 |
+
canny_img = cv2.Canny(img, args.thres1, args.thres2)
|
| 14 |
+
# canny_img = 255 - canny_img
|
| 15 |
+
|
| 16 |
+
cv2.imwrite(args.output, canny_img)
|
| 17 |
+
logger.info("done!")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 21 |
+
parser = argparse.ArgumentParser()
|
| 22 |
+
parser.add_argument("--input", type=str, default=None, help="input path")
|
| 23 |
+
parser.add_argument("--output", type=str, default=None, help="output path")
|
| 24 |
+
parser.add_argument("--thres1", type=int, default=32, help="thres1")
|
| 25 |
+
parser.add_argument("--thres2", type=int, default=224, help="thres2")
|
| 26 |
+
|
| 27 |
+
return parser
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
if __name__ == '__main__':
|
| 31 |
+
parser = setup_parser()
|
| 32 |
+
|
| 33 |
+
args = parser.parse_args()
|
| 34 |
+
canny(args)
|
cextension.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes as ct
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from warnings import warn
|
| 4 |
+
|
| 5 |
+
from .cuda_setup.main import evaluate_cuda_setup
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CUDALibrary_Singleton(object):
|
| 9 |
+
_instance = None
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
raise RuntimeError("Call get_instance() instead")
|
| 13 |
+
|
| 14 |
+
def initialize(self):
|
| 15 |
+
binary_name = evaluate_cuda_setup()
|
| 16 |
+
package_dir = Path(__file__).parent
|
| 17 |
+
binary_path = package_dir / binary_name
|
| 18 |
+
|
| 19 |
+
if not binary_path.exists():
|
| 20 |
+
print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
|
| 21 |
+
legacy_binary_name = "libbitsandbytes.so"
|
| 22 |
+
print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
|
| 23 |
+
binary_path = package_dir / legacy_binary_name
|
| 24 |
+
if not binary_path.exists():
|
| 25 |
+
print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
|
| 26 |
+
print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
|
| 27 |
+
raise Exception('CUDA SETUP: Setup Failed!')
|
| 28 |
+
# self.lib = ct.cdll.LoadLibrary(binary_path)
|
| 29 |
+
self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
|
| 30 |
+
else:
|
| 31 |
+
print(f"CUDA SETUP: Loading binary {binary_path}...")
|
| 32 |
+
# self.lib = ct.cdll.LoadLibrary(binary_path)
|
| 33 |
+
self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def get_instance(cls):
|
| 37 |
+
if cls._instance is None:
|
| 38 |
+
cls._instance = cls.__new__(cls)
|
| 39 |
+
cls._instance.initialize()
|
| 40 |
+
return cls._instance
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
lib = CUDALibrary_Singleton.get_instance().lib
|
| 44 |
+
try:
|
| 45 |
+
lib.cadam32bit_g32
|
| 46 |
+
lib.get_context.restype = ct.c_void_p
|
| 47 |
+
lib.get_cusparse.restype = ct.c_void_p
|
| 48 |
+
COMPILED_WITH_CUDA = True
|
| 49 |
+
except AttributeError:
|
| 50 |
+
warn(
|
| 51 |
+
"The installed version of bitsandbytes was compiled without GPU support. "
|
| 52 |
+
"8-bit optimizers and GPU quantization are unavailable."
|
| 53 |
+
)
|
| 54 |
+
COMPILED_WITH_CUDA = False
|
check_lora_weights.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
from library.utils import setup_logging
|
| 6 |
+
setup_logging()
|
| 7 |
+
import logging
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
def main(file):
|
| 11 |
+
logger.info(f"loading: {file}")
|
| 12 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 13 |
+
sd = load_file(file)
|
| 14 |
+
else:
|
| 15 |
+
sd = torch.load(file, map_location="cpu")
|
| 16 |
+
|
| 17 |
+
values = []
|
| 18 |
+
|
| 19 |
+
keys = list(sd.keys())
|
| 20 |
+
for key in keys:
|
| 21 |
+
if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
|
| 22 |
+
values.append((key, sd[key]))
|
| 23 |
+
print(f"number of LoRA modules: {len(values)}")
|
| 24 |
+
|
| 25 |
+
if args.show_all_keys:
|
| 26 |
+
for key in [k for k in keys if k not in values]:
|
| 27 |
+
values.append((key, sd[key]))
|
| 28 |
+
print(f"number of all modules: {len(values)}")
|
| 29 |
+
|
| 30 |
+
for key, value in values:
|
| 31 |
+
value = value.to(torch.float32)
|
| 32 |
+
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 36 |
+
parser = argparse.ArgumentParser()
|
| 37 |
+
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
|
| 38 |
+
parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
|
| 39 |
+
|
| 40 |
+
return parser
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
parser = setup_parser()
|
| 45 |
+
|
| 46 |
+
args = parser.parse_args()
|
| 47 |
+
|
| 48 |
+
main(args.file)
|
clean_captions_and_tags.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# このスクリプトのライセンスは、Apache License 2.0とします
|
| 2 |
+
# (c) 2022 Kohya S. @kohya_ss
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import glob
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import re
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from library.utils import setup_logging
|
| 12 |
+
setup_logging()
|
| 13 |
+
import logging
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
|
| 17 |
+
PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
|
| 18 |
+
PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
|
| 19 |
+
PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
|
| 20 |
+
|
| 21 |
+
# 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
|
| 22 |
+
PATTERNS_REMOVE_IN_MULTI = [
|
| 23 |
+
PATTERN_HAIR_LENGTH,
|
| 24 |
+
PATTERN_HAIR_CUT,
|
| 25 |
+
re.compile(r', [\w\-]+ eyes, '),
|
| 26 |
+
re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
|
| 27 |
+
# 複数の髪型定義がある場合は削除する
|
| 28 |
+
re.compile(
|
| 29 |
+
r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def clean_tags(image_key, tags):
|
| 34 |
+
# replace '_' to ' '
|
| 35 |
+
tags = tags.replace('^_^', '^@@@^')
|
| 36 |
+
tags = tags.replace('_', ' ')
|
| 37 |
+
tags = tags.replace('^@@@^', '^_^')
|
| 38 |
+
|
| 39 |
+
# remove rating: deepdanbooruのみ
|
| 40 |
+
tokens = tags.split(", rating")
|
| 41 |
+
if len(tokens) == 1:
|
| 42 |
+
# WD14 taggerのときはこちらになるのでメッセージは出さない
|
| 43 |
+
# logger.info("no rating:")
|
| 44 |
+
# logger.info(f"{image_key} {tags}")
|
| 45 |
+
pass
|
| 46 |
+
else:
|
| 47 |
+
if len(tokens) > 2:
|
| 48 |
+
logger.info("multiple ratings:")
|
| 49 |
+
logger.info(f"{image_key} {tags}")
|
| 50 |
+
tags = tokens[0]
|
| 51 |
+
|
| 52 |
+
tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
|
| 53 |
+
|
| 54 |
+
# 複数の人物がいる場合は髪色等のタグを削除する
|
| 55 |
+
if 'girls' in tags or 'boys' in tags:
|
| 56 |
+
for pat in PATTERNS_REMOVE_IN_MULTI:
|
| 57 |
+
found = pat.findall(tags)
|
| 58 |
+
if len(found) > 1: # 二つ以上、タグがある
|
| 59 |
+
tags = pat.sub("", tags)
|
| 60 |
+
|
| 61 |
+
# 髪の特殊対応
|
| 62 |
+
srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
|
| 63 |
+
if srch_hair_len:
|
| 64 |
+
org = srch_hair_len.group()
|
| 65 |
+
tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
|
| 66 |
+
|
| 67 |
+
found = PATTERN_HAIR.findall(tags)
|
| 68 |
+
if len(found) > 1:
|
| 69 |
+
tags = PATTERN_HAIR.sub("", tags)
|
| 70 |
+
|
| 71 |
+
if srch_hair_len:
|
| 72 |
+
tags = tags.replace(", @@@, ", org) # 戻す
|
| 73 |
+
|
| 74 |
+
# white shirtとshirtみたいな重複タグの削除
|
| 75 |
+
found = PATTERN_WORD.findall(tags)
|
| 76 |
+
for word in found:
|
| 77 |
+
if re.search(f", ((\w+) )+{word}, ", tags):
|
| 78 |
+
tags = tags.replace(f", {word}, ", "")
|
| 79 |
+
|
| 80 |
+
tags = tags.replace(", , ", ", ")
|
| 81 |
+
assert tags.startswith(", ") and tags.endswith(", ")
|
| 82 |
+
tags = tags[2:-2]
|
| 83 |
+
return tags
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# 上から順に検索、置換される
|
| 87 |
+
# ('置換元文字列', '置換後文字列')
|
| 88 |
+
CAPTION_REPLACEMENTS = [
|
| 89 |
+
('anime anime', 'anime'),
|
| 90 |
+
('young ', ''),
|
| 91 |
+
('anime girl', 'girl'),
|
| 92 |
+
('cartoon female', 'girl'),
|
| 93 |
+
('cartoon lady', 'girl'),
|
| 94 |
+
('cartoon character', 'girl'), # a or ~s
|
| 95 |
+
('cartoon woman', 'girl'),
|
| 96 |
+
('cartoon women', 'girls'),
|
| 97 |
+
('cartoon girl', 'girl'),
|
| 98 |
+
('anime female', 'girl'),
|
| 99 |
+
('anime lady', 'girl'),
|
| 100 |
+
('anime character', 'girl'), # a or ~s
|
| 101 |
+
('anime woman', 'girl'),
|
| 102 |
+
('anime women', 'girls'),
|
| 103 |
+
('lady', 'girl'),
|
| 104 |
+
('female', 'girl'),
|
| 105 |
+
('woman', 'girl'),
|
| 106 |
+
('women', 'girls'),
|
| 107 |
+
('people', 'girls'),
|
| 108 |
+
('person', 'girl'),
|
| 109 |
+
('a cartoon figure', 'a figure'),
|
| 110 |
+
('a cartoon image', 'an image'),
|
| 111 |
+
('a cartoon picture', 'a picture'),
|
| 112 |
+
('an anime cartoon image', 'an image'),
|
| 113 |
+
('a cartoon anime drawing', 'a drawing'),
|
| 114 |
+
('a cartoon drawing', 'a drawing'),
|
| 115 |
+
('girl girl', 'girl'),
|
| 116 |
+
]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def clean_caption(caption):
|
| 120 |
+
for rf, rt in CAPTION_REPLACEMENTS:
|
| 121 |
+
replaced = True
|
| 122 |
+
while replaced:
|
| 123 |
+
bef = caption
|
| 124 |
+
caption = caption.replace(rf, rt)
|
| 125 |
+
replaced = bef != caption
|
| 126 |
+
return caption
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def main(args):
|
| 130 |
+
if os.path.exists(args.in_json):
|
| 131 |
+
logger.info(f"loading existing metadata: {args.in_json}")
|
| 132 |
+
with open(args.in_json, "rt", encoding='utf-8') as f:
|
| 133 |
+
metadata = json.load(f)
|
| 134 |
+
else:
|
| 135 |
+
logger.error("no metadata / メタデータファイルがありません")
|
| 136 |
+
return
|
| 137 |
+
|
| 138 |
+
logger.info("cleaning captions and tags.")
|
| 139 |
+
image_keys = list(metadata.keys())
|
| 140 |
+
for image_key in tqdm(image_keys):
|
| 141 |
+
tags = metadata[image_key].get('tags')
|
| 142 |
+
if tags is None:
|
| 143 |
+
logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
|
| 144 |
+
else:
|
| 145 |
+
org = tags
|
| 146 |
+
tags = clean_tags(image_key, tags)
|
| 147 |
+
metadata[image_key]['tags'] = tags
|
| 148 |
+
if args.debug and org != tags:
|
| 149 |
+
logger.info("FROM: " + org)
|
| 150 |
+
logger.info("TO: " + tags)
|
| 151 |
+
|
| 152 |
+
caption = metadata[image_key].get('caption')
|
| 153 |
+
if caption is None:
|
| 154 |
+
logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
|
| 155 |
+
else:
|
| 156 |
+
org = caption
|
| 157 |
+
caption = clean_caption(caption)
|
| 158 |
+
metadata[image_key]['caption'] = caption
|
| 159 |
+
if args.debug and org != caption:
|
| 160 |
+
logger.info("FROM: " + org)
|
| 161 |
+
logger.info("TO: " + caption)
|
| 162 |
+
|
| 163 |
+
# metadataを書き出して終わり
|
| 164 |
+
logger.info(f"writing metadata: {args.out_json}")
|
| 165 |
+
with open(args.out_json, "wt", encoding='utf-8') as f:
|
| 166 |
+
json.dump(metadata, f, indent=2)
|
| 167 |
+
logger.info("done!")
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 171 |
+
parser = argparse.ArgumentParser()
|
| 172 |
+
# parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
| 173 |
+
parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
|
| 174 |
+
parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
|
| 175 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
| 176 |
+
|
| 177 |
+
return parser
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
if __name__ == '__main__':
|
| 181 |
+
parser = setup_parser()
|
| 182 |
+
|
| 183 |
+
args, unknown = parser.parse_known_args()
|
| 184 |
+
if len(unknown) == 1:
|
| 185 |
+
logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
|
| 186 |
+
logger.warning("All captions and tags in the metadata are processed.")
|
| 187 |
+
logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
|
| 188 |
+
logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
|
| 189 |
+
args.in_json = args.out_json
|
| 190 |
+
args.out_json = unknown[0]
|
| 191 |
+
elif len(unknown) > 0:
|
| 192 |
+
raise ValueError(f"error: unrecognized arguments: {unknown}")
|
| 193 |
+
|
| 194 |
+
main(args)
|
config_README-en.md
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Original Source by kohya-ss
|
| 2 |
+
|
| 3 |
+
First version:
|
| 4 |
+
A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
|
| 5 |
+
|
| 6 |
+
Some parts are manually added.
|
| 7 |
+
|
| 8 |
+
# Config Readme
|
| 9 |
+
|
| 10 |
+
This README is about the configuration files that can be passed with the `--dataset_config` option.
|
| 11 |
+
|
| 12 |
+
## Overview
|
| 13 |
+
|
| 14 |
+
By passing a configuration file, users can make detailed settings.
|
| 15 |
+
|
| 16 |
+
* Multiple datasets can be configured
|
| 17 |
+
* For example, by setting `resolution` for each dataset, they can be mixed and trained.
|
| 18 |
+
* In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed.
|
| 19 |
+
* Settings can be changed for each subset
|
| 20 |
+
* A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset.
|
| 21 |
+
* Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later.
|
| 22 |
+
|
| 23 |
+
The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML.
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
Here is an example of a configuration file written in TOML.
|
| 27 |
+
|
| 28 |
+
```toml
|
| 29 |
+
[general]
|
| 30 |
+
shuffle_caption = true
|
| 31 |
+
caption_extension = '.txt'
|
| 32 |
+
keep_tokens = 1
|
| 33 |
+
|
| 34 |
+
# This is a DreamBooth-style dataset
|
| 35 |
+
[[datasets]]
|
| 36 |
+
resolution = 512
|
| 37 |
+
batch_size = 4
|
| 38 |
+
keep_tokens = 2
|
| 39 |
+
|
| 40 |
+
[[datasets.subsets]]
|
| 41 |
+
image_dir = 'C:\hoge'
|
| 42 |
+
class_tokens = 'hoge girl'
|
| 43 |
+
# This subset uses keep_tokens = 2 (the value of the parent datasets)
|
| 44 |
+
|
| 45 |
+
[[datasets.subsets]]
|
| 46 |
+
image_dir = 'C:\fuga'
|
| 47 |
+
class_tokens = 'fuga boy'
|
| 48 |
+
keep_tokens = 3
|
| 49 |
+
|
| 50 |
+
[[datasets.subsets]]
|
| 51 |
+
is_reg = true
|
| 52 |
+
image_dir = 'C:\reg'
|
| 53 |
+
class_tokens = 'human'
|
| 54 |
+
keep_tokens = 1
|
| 55 |
+
|
| 56 |
+
# This is a fine-tuning dataset
|
| 57 |
+
[[datasets]]
|
| 58 |
+
resolution = [768, 768]
|
| 59 |
+
batch_size = 2
|
| 60 |
+
|
| 61 |
+
[[datasets.subsets]]
|
| 62 |
+
image_dir = 'C:\piyo'
|
| 63 |
+
metadata_file = 'C:\piyo\piyo_md.json'
|
| 64 |
+
# This subset uses keep_tokens = 1 (the value of [general])
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2).
|
| 68 |
+
|
| 69 |
+
## Settings for datasets and subsets
|
| 70 |
+
|
| 71 |
+
Settings for datasets and subsets are divided into several registration locations.
|
| 72 |
+
|
| 73 |
+
* `[general]`
|
| 74 |
+
* This is where options that apply to all datasets or all subsets are specified.
|
| 75 |
+
* If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence.
|
| 76 |
+
* `[[datasets]]`
|
| 77 |
+
* `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified.
|
| 78 |
+
* If there are subset-specific settings, the subset-specific settings take precedence.
|
| 79 |
+
* `[[datasets.subsets]]`
|
| 80 |
+
* `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified.
|
| 81 |
+
|
| 82 |
+
Here is an image showing the correspondence between image directories and registration locations in the previous example.
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
C:\
|
| 86 |
+
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
|
| 87 |
+
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
|
| 88 |
+
├─ reg -> [[datasets.subsets]] No.3 ┘ |
|
| 89 |
+
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.
|
| 93 |
+
|
| 94 |
+
The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding.
|
| 95 |
+
|
| 96 |
+
Additionally, the available options may vary depending on the method that the learning approach supports.
|
| 97 |
+
|
| 98 |
+
* Options specific to the DreamBooth method
|
| 99 |
+
* Options specific to the fine-tuning method
|
| 100 |
+
* Options available when using the caption dropout technique
|
| 101 |
+
|
| 102 |
+
When using both the DreamBooth method and the fine-tuning method, they can be used together with a learning approach that supports both.
|
| 103 |
+
When using them together, a point to note is that the method is determined based on the dataset, so it is not possible to mix DreamBooth method subsets and fine-tuning method subsets within the same dataset.
|
| 104 |
+
In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets.
|
| 105 |
+
|
| 106 |
+
In terms of program behavior, if the `metadata_file` option exists, it is determined to be a subset of fine-tuning. Therefore, for subsets belonging to the same dataset, as long as they are either "all have the `metadata_file` option" or "all have no `metadata_file` option," there is no problem.
|
| 107 |
+
|
| 108 |
+
Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs.
|
| 109 |
+
|
| 110 |
+
### Common options for all learning methods
|
| 111 |
+
|
| 112 |
+
These are options that can be specified regardless of the learning method.
|
| 113 |
+
|
| 114 |
+
#### Data set specific options
|
| 115 |
+
|
| 116 |
+
These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`.
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
| Option Name | Example Setting | `[general]` | `[[datasets]]` |
|
| 120 |
+
| ---- | ---- | ---- | ---- |
|
| 121 |
+
| `batch_size` | `1` | o | o |
|
| 122 |
+
| `bucket_no_upscale` | `true` | o | o |
|
| 123 |
+
| `bucket_reso_steps` | `64` | o | o |
|
| 124 |
+
| `enable_bucket` | `true` | o | o |
|
| 125 |
+
| `max_bucket_reso` | `1024` | o | o |
|
| 126 |
+
| `min_bucket_reso` | `128` | o | o |
|
| 127 |
+
| `resolution` | `256`, `[512, 512]` | o | o |
|
| 128 |
+
|
| 129 |
+
* `batch_size`
|
| 130 |
+
* This corresponds to the command-line argument `--train_batch_size`.
|
| 131 |
+
* `max_bucket_reso`, `min_bucket_reso`
|
| 132 |
+
* Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
|
| 133 |
+
|
| 134 |
+
These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
|
| 135 |
+
|
| 136 |
+
#### Options for Subsets
|
| 137 |
+
|
| 138 |
+
These options are related to subset configuration.
|
| 139 |
+
|
| 140 |
+
| Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 141 |
+
| ---- | ---- | ---- | ---- | ---- |
|
| 142 |
+
| `color_aug` | `false` | o | o | o |
|
| 143 |
+
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
|
| 144 |
+
| `flip_aug` | `true` | o | o | o |
|
| 145 |
+
| `keep_tokens` | `2` | o | o | o |
|
| 146 |
+
| `num_repeats` | `10` | o | o | o |
|
| 147 |
+
| `random_crop` | `false` | o | o | o |
|
| 148 |
+
| `shuffle_caption` | `true` | o | o | o |
|
| 149 |
+
| `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
|
| 150 |
+
| `caption_suffix` | `", from side"` | o | o | o |
|
| 151 |
+
| `caption_separator` | (not specified) | o | o | o |
|
| 152 |
+
| `keep_tokens_separator` | `“|||”` | o | o | o |
|
| 153 |
+
| `secondary_separator` | `“;;;”` | o | o | o |
|
| 154 |
+
| `enable_wildcard` | `true` | o | o | o |
|
| 155 |
+
|
| 156 |
+
* `num_repeats`
|
| 157 |
+
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
|
| 158 |
+
* `caption_prefix`, `caption_suffix`
|
| 159 |
+
* Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`.
|
| 160 |
+
* `caption_separator`
|
| 161 |
+
* Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set.
|
| 162 |
+
* `keep_tokens_separator`
|
| 163 |
+
* Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc.
|
| 164 |
+
* `secondary_separator`
|
| 165 |
+
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
|
| 166 |
+
* `enable_wildcard`
|
| 167 |
+
* Enables wildcard notation. This will be explained later.
|
| 168 |
+
|
| 169 |
+
### DreamBooth-specific options
|
| 170 |
+
|
| 171 |
+
DreamBooth-specific options only exist as subsets-specific options.
|
| 172 |
+
|
| 173 |
+
#### Subset-specific options
|
| 174 |
+
|
| 175 |
+
Options related to the configuration of DreamBooth subsets.
|
| 176 |
+
|
| 177 |
+
| Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 178 |
+
| ---- | ---- | ---- | ---- | ---- |
|
| 179 |
+
| `image_dir` | `'C:\hoge'` | - | - | o (required) |
|
| 180 |
+
| `caption_extension` | `".txt"` | o | o | o |
|
| 181 |
+
| `class_tokens` | `"sks girl"` | - | - | o |
|
| 182 |
+
| `cache_info` | `false` | o | o | o |
|
| 183 |
+
| `is_reg` | `false` | - | - | o |
|
| 184 |
+
|
| 185 |
+
Firstly, note that for `image_dir`, the path to the image files must be specified as being directly in the directory. Unlike the previous DreamBooth method, where images had to be placed in subdirectories, this is not compatible with that specification. Also, even if you name the folder something like "5_cat", the number of repeats of the image and the class name will not be reflected. If you want to set these individually, you will need to explicitly specify them using `num_repeats` and `class_tokens`.
|
| 186 |
+
|
| 187 |
+
* `image_dir`
|
| 188 |
+
* Specifies the path to the image directory. This is a required option.
|
| 189 |
+
* Images must be placed directly under the directory.
|
| 190 |
+
* `class_tokens`
|
| 191 |
+
* Sets the class tokens.
|
| 192 |
+
* Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
|
| 193 |
+
* `cache_info`
|
| 194 |
+
* Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
|
| 195 |
+
* Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
|
| 196 |
+
* `is_reg`
|
| 197 |
+
* Specifies whether the subset images are for normalization. If not specified, it is set to `false`, meaning that the images are not for normalization.
|
| 198 |
+
|
| 199 |
+
### Fine-tuning method specific options
|
| 200 |
+
|
| 201 |
+
The options for the fine-tuning method only exist for subset-specific options.
|
| 202 |
+
|
| 203 |
+
#### Subset-specific options
|
| 204 |
+
|
| 205 |
+
These options are related to the configuration of the fine-tuning method's subsets.
|
| 206 |
+
|
| 207 |
+
| Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 208 |
+
| ---- | ---- | ---- | ---- | ---- |
|
| 209 |
+
| `image_dir` | `'C:\hoge'` | - | - | o |
|
| 210 |
+
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |
|
| 211 |
+
|
| 212 |
+
* `image_dir`
|
| 213 |
+
* Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so.
|
| 214 |
+
* The case where it is not necessary to specify is when the `--full_path` is added to the command line when generating the metadata file.
|
| 215 |
+
* The images must be placed directly under the directory.
|
| 216 |
+
* `metadata_file`
|
| 217 |
+
* Specify the path to the metadata file used for the subset. This is a required option.
|
| 218 |
+
* It is equivalent to the command-line argument `--in_json`.
|
| 219 |
+
* Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets.
|
| 220 |
+
|
| 221 |
+
### Options available when caption dropout method can be used
|
| 222 |
+
|
| 223 |
+
The options available when the caption dropout method can be used exist only for subsets. Regardless of whether it's the DreamBooth method or fine-tuning method, if it supports caption dropout, it can be specified.
|
| 224 |
+
|
| 225 |
+
#### Subset-specific options
|
| 226 |
+
|
| 227 |
+
Options related to the setting of subsets that caption dropout can be used for.
|
| 228 |
+
|
| 229 |
+
| Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 230 |
+
| ---- | ---- | ---- | ---- |
|
| 231 |
+
| `caption_dropout_every_n_epochs` | o | o | o |
|
| 232 |
+
| `caption_dropout_rate` | o | o | o |
|
| 233 |
+
| `caption_tag_dropout_rate` | o | o | o |
|
| 234 |
+
|
| 235 |
+
## Behavior when there are duplicate subsets
|
| 236 |
+
|
| 237 |
+
In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored.
|
| 238 |
+
|
| 239 |
+
However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions.
|
| 240 |
+
|
| 241 |
+
```toml
|
| 242 |
+
# If data sets exist separately, they are not considered duplicates and are both used for training.
|
| 243 |
+
|
| 244 |
+
[[datasets]]
|
| 245 |
+
resolution = 512
|
| 246 |
+
|
| 247 |
+
[[datasets.subsets]]
|
| 248 |
+
image_dir = 'C:\hoge'
|
| 249 |
+
|
| 250 |
+
[[datasets]]
|
| 251 |
+
resolution = 768
|
| 252 |
+
|
| 253 |
+
[[datasets.subsets]]
|
| 254 |
+
image_dir = 'C:\hoge'
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
## Command Line Argument and Configuration File
|
| 258 |
+
|
| 259 |
+
There are options in the configuration file that have overlapping roles with command line argument options.
|
| 260 |
+
|
| 261 |
+
The following command line argument options are ignored if a configuration file is passed:
|
| 262 |
+
|
| 263 |
+
* `--train_data_dir`
|
| 264 |
+
* `--reg_data_dir`
|
| 265 |
+
* `--in_json`
|
| 266 |
+
|
| 267 |
+
The following command line argument options are given priority over the configuration file options if both are specified simultaneously. In most cases, they have the same names as the corresponding options in the configuration file.
|
| 268 |
+
|
| 269 |
+
| Command Line Argument Option | Prioritized Configuration File Option |
|
| 270 |
+
| ------------------------------- | ------------------------------------- |
|
| 271 |
+
| `--bucket_no_upscale` | |
|
| 272 |
+
| `--bucket_reso_steps` | |
|
| 273 |
+
| `--caption_dropout_every_n_epochs` | |
|
| 274 |
+
| `--caption_dropout_rate` | |
|
| 275 |
+
| `--caption_extension` | |
|
| 276 |
+
| `--caption_tag_dropout_rate` | |
|
| 277 |
+
| `--color_aug` | |
|
| 278 |
+
| `--dataset_repeats` | `num_repeats` |
|
| 279 |
+
| `--enable_bucket` | |
|
| 280 |
+
| `--face_crop_aug_range` | |
|
| 281 |
+
| `--flip_aug` | |
|
| 282 |
+
| `--keep_tokens` | |
|
| 283 |
+
| `--min_bucket_reso` | |
|
| 284 |
+
| `--random_crop` | |
|
| 285 |
+
| `--resolution` | |
|
| 286 |
+
| `--shuffle_caption` | |
|
| 287 |
+
| `--train_batch_size` | `batch_size` |
|
| 288 |
+
|
| 289 |
+
## Error Guide
|
| 290 |
+
|
| 291 |
+
Currently, we are using an external library to check if the configuration file is written correctly, but the development has not been completed, and there is a problem that the error message is not clear. In the future, we plan to improve this problem.
|
| 292 |
+
|
| 293 |
+
As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug.
|
| 294 |
+
|
| 295 |
+
* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name.
|
| 296 |
+
* The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting.
|
| 297 |
+
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the specified value format is incorrect. It is highly likely that the value format is incorrect. The `int` part changes depending on the target option. The example configurations in this README may be helpful.
|
| 298 |
+
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it.
|
| 299 |
+
|
| 300 |
+
## Miscellaneous
|
| 301 |
+
|
| 302 |
+
### Multi-line captions
|
| 303 |
+
|
| 304 |
+
By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption.
|
| 305 |
+
|
| 306 |
+
```txt
|
| 307 |
+
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
|
| 308 |
+
a girl with a microphone standing on a stage
|
| 309 |
+
detailed digital art of a girl with a microphone on a stage
|
| 310 |
+
```
|
| 311 |
+
|
| 312 |
+
It can be combined with wildcard notation.
|
| 313 |
+
|
| 314 |
+
In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format.
|
| 315 |
+
|
| 316 |
+
The tags in the metadata (`tags`) are added to each line of the caption.
|
| 317 |
+
|
| 318 |
+
```json
|
| 319 |
+
{
|
| 320 |
+
"/path/to/image.png": {
|
| 321 |
+
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
|
| 322 |
+
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
|
| 323 |
+
},
|
| 324 |
+
...
|
| 325 |
+
}
|
| 326 |
+
```
|
| 327 |
+
|
| 328 |
+
In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc.
|
| 329 |
+
|
| 330 |
+
### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc.
|
| 331 |
+
|
| 332 |
+
```toml
|
| 333 |
+
[general]
|
| 334 |
+
flip_aug = true
|
| 335 |
+
color_aug = false
|
| 336 |
+
resolution = [1024, 1024]
|
| 337 |
+
|
| 338 |
+
[[datasets]]
|
| 339 |
+
batch_size = 6
|
| 340 |
+
enable_bucket = true
|
| 341 |
+
bucket_no_upscale = true
|
| 342 |
+
caption_extension = ".txt"
|
| 343 |
+
keep_tokens_separator= "|||"
|
| 344 |
+
shuffle_caption = true
|
| 345 |
+
caption_tag_dropout_rate = 0.1
|
| 346 |
+
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
|
| 347 |
+
enable_wildcard = true # 同上 / same as above
|
| 348 |
+
|
| 349 |
+
[[datasets.subsets]]
|
| 350 |
+
image_dir = "/path/to/image_dir"
|
| 351 |
+
num_repeats = 1
|
| 352 |
+
|
| 353 |
+
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
|
| 354 |
+
caption_prefix = "1girl, hatsune miku, vocaloid |||"
|
| 355 |
+
|
| 356 |
+
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
|
| 357 |
+
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
|
| 358 |
+
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
|
| 359 |
+
```
|
| 360 |
+
|
| 361 |
+
### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
|
| 362 |
+
|
| 363 |
+
```txt
|
| 364 |
+
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
|
| 365 |
+
```
|
| 366 |
+
The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
|
| 367 |
+
|
| 368 |
+
### Example of caption, enable_wildcard notation: `enable_wildcard = true`
|
| 369 |
+
|
| 370 |
+
```txt
|
| 371 |
+
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
|
| 372 |
+
```
|
| 373 |
+
`simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
|
| 374 |
+
|
| 375 |
+
```txt
|
| 376 |
+
1girl, hatsune miku, vocaloid, {{retro style}}
|
| 377 |
+
```
|
| 378 |
+
If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
|
| 379 |
+
|
| 380 |
+
### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
|
| 381 |
+
|
| 382 |
+
```txt
|
| 383 |
+
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
|
| 384 |
+
```
|
| 385 |
+
It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.
|
| 386 |
+
|
config_README-ja.md
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
`--dataset_config` で渡すことができる設定ファイルに関する説明です。
|
| 2 |
+
|
| 3 |
+
## 概要
|
| 4 |
+
|
| 5 |
+
設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。
|
| 6 |
+
|
| 7 |
+
* 複数のデータセットが設定可能になります
|
| 8 |
+
* 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。
|
| 9 |
+
* DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。
|
| 10 |
+
* サブセットごとに設定を変更することが可能になります
|
| 11 |
+
* データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。
|
| 12 |
+
* `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。
|
| 13 |
+
|
| 14 |
+
設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。
|
| 15 |
+
|
| 16 |
+
TOML で記述した設定ファイルの例です。
|
| 17 |
+
|
| 18 |
+
```toml
|
| 19 |
+
[general]
|
| 20 |
+
shuffle_caption = true
|
| 21 |
+
caption_extension = '.txt'
|
| 22 |
+
keep_tokens = 1
|
| 23 |
+
|
| 24 |
+
# これは DreamBooth 方式のデータセット
|
| 25 |
+
[[datasets]]
|
| 26 |
+
resolution = 512
|
| 27 |
+
batch_size = 4
|
| 28 |
+
keep_tokens = 2
|
| 29 |
+
|
| 30 |
+
[[datasets.subsets]]
|
| 31 |
+
image_dir = 'C:\hoge'
|
| 32 |
+
class_tokens = 'hoge girl'
|
| 33 |
+
# このサブセットは keep_tokens = 2 (所属する datasets の値が使われる)
|
| 34 |
+
|
| 35 |
+
[[datasets.subsets]]
|
| 36 |
+
image_dir = 'C:\fuga'
|
| 37 |
+
class_tokens = 'fuga boy'
|
| 38 |
+
keep_tokens = 3
|
| 39 |
+
|
| 40 |
+
[[datasets.subsets]]
|
| 41 |
+
is_reg = true
|
| 42 |
+
image_dir = 'C:\reg'
|
| 43 |
+
class_tokens = 'human'
|
| 44 |
+
keep_tokens = 1
|
| 45 |
+
|
| 46 |
+
# これは fine tuning 方式のデータセット
|
| 47 |
+
[[datasets]]
|
| 48 |
+
resolution = [768, 768]
|
| 49 |
+
batch_size = 2
|
| 50 |
+
|
| 51 |
+
[[datasets.subsets]]
|
| 52 |
+
image_dir = 'C:\piyo'
|
| 53 |
+
metadata_file = 'C:\piyo\piyo_md.json'
|
| 54 |
+
# このサブセットは keep_tokens = 1 (general の値が使われる)
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。
|
| 58 |
+
|
| 59 |
+
## データセット・サブセットに関する設定
|
| 60 |
+
|
| 61 |
+
データセット・サブセットに関する設定は、登録可能な箇所がいくつかに分かれています。
|
| 62 |
+
|
| 63 |
+
* `[general]`
|
| 64 |
+
* 全データセットまたは全サブセットに適用されるオプションを指定する箇所です。
|
| 65 |
+
* データセットごとの設定及びサブセットごとの設定に同名のオプションが存在していた場合には、データセット・サブセットごとの設定が優先されます。
|
| 66 |
+
* `[[datasets]]`
|
| 67 |
+
* `datasets` はデータセットに関する設定の登録箇所になります。各データセットに個別に適用されるオプションを指定する箇所です。
|
| 68 |
+
* サブセットごとの設定が存在していた場合には、サブセットごとの設定が優先されます。
|
| 69 |
+
* `[[datasets.subsets]]`
|
| 70 |
+
* `datasets.subsets` はサブセットに関する設定の登録箇所になります。各サブセットに個別に適用されるオプションを指定する箇所です。
|
| 71 |
+
|
| 72 |
+
先程の例における、画像ディレクトリと登録箇所の対応に関するイメージ図です。
|
| 73 |
+
|
| 74 |
+
```
|
| 75 |
+
C:\
|
| 76 |
+
├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
|
| 77 |
+
├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
|
| 78 |
+
├─ reg -> [[datasets.subsets]] No.3 ┘ |
|
| 79 |
+
└─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
画像ディレクトリがそれぞれ1つの `[[datasets.subsets]]` に対応しています。そして `[[datasets.subsets]]` が1つ以上組み合わさって1つの `[[datasets]]` を構成します。`[general]` には全ての `[[datasets]]`, `[[datasets.subsets]]` が属します。
|
| 83 |
+
|
| 84 |
+
登録箇所ごとに指定可能なオプションは異なりますが、同名のオプションが指定された場合は下位の登録箇所にある値が優先されます。先程の例の `keep_tokens` オプションの扱われ方を確認してもらうと理解しやすいかと思います。
|
| 85 |
+
|
| 86 |
+
加えて、学習方法が対応している手法によっても指定可能なオプションが変化します。
|
| 87 |
+
|
| 88 |
+
* DreamBooth 方式専用のオプション
|
| 89 |
+
* fine tuning 方式専用のオプション
|
| 90 |
+
* caption dropout の手法が使える場合のオプション
|
| 91 |
+
|
| 92 |
+
DreamBooth の手法と fine tuning の手法の両方とも利用可能な学習方法では、両者を併用することができます。
|
| 93 |
+
併用する際の注意点として、DreamBooth 方式なのか fine tuning 方式なのかはデータセット単位で判別を行っているため、同じデータセット中に DreamBooth 方式のサブセットと fine tuning 方式のサブセットを混在させることはできません。
|
| 94 |
+
つまり、これらを併用したい場合には異なる方式のサブセットが異なるデータセットに所属するように設定する必要があります。
|
| 95 |
+
|
| 96 |
+
プログラムの挙動としては、後述する `metadata_file` オプションが存在していたら fine tuning 方式のサブセットだと判断します。
|
| 97 |
+
そのため、同一のデータセットに所属するサブセットについて言うと、「全てが `metadata_file` オプションを持つ」か「全てが `metadata_file` オプションを持たない」かのどちらかになっていれば問題ありません。
|
| 98 |
+
|
| 99 |
+
以下、利用可能なオプションを説明します。コマンドライン引数と名称が同一のオプションについては、基本的に説明を割愛します。他の README を参照してください。
|
| 100 |
+
|
| 101 |
+
### 全学習方法で共通のオプション
|
| 102 |
+
|
| 103 |
+
学習方法によらずに指定可能なオプションです。
|
| 104 |
+
|
| 105 |
+
#### データセット向けオプション
|
| 106 |
+
|
| 107 |
+
データセットの設定に関わるオプションです。`datasets.subsets` には記述できません。
|
| 108 |
+
|
| 109 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` |
|
| 110 |
+
| ---- | ---- | ---- | ---- |
|
| 111 |
+
| `batch_size` | `1` | o | o |
|
| 112 |
+
| `bucket_no_upscale` | `true` | o | o |
|
| 113 |
+
| `bucket_reso_steps` | `64` | o | o |
|
| 114 |
+
| `enable_bucket` | `true` | o | o |
|
| 115 |
+
| `max_bucket_reso` | `1024` | o | o |
|
| 116 |
+
| `min_bucket_reso` | `128` | o | o |
|
| 117 |
+
| `resolution` | `256`, `[512, 512]` | o | o |
|
| 118 |
+
|
| 119 |
+
* `batch_size`
|
| 120 |
+
* コマンドライン引数の `--train_batch_size` と同等です。
|
| 121 |
+
* `max_bucket_reso`, `min_bucket_reso`
|
| 122 |
+
* bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
|
| 123 |
+
|
| 124 |
+
これらの設定はデータセットごとに固定です。
|
| 125 |
+
つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
|
| 126 |
+
例えば解像度が異なるデータセットを用意したい場合は、上に挙げた例のように別々のデータセットとして定義すれば別々の解像度を設定可能です。
|
| 127 |
+
|
| 128 |
+
#### サブセット向けオプション
|
| 129 |
+
|
| 130 |
+
サブセットの設定に関わるオプションです。
|
| 131 |
+
|
| 132 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 133 |
+
| ---- | ---- | ---- | ---- | ---- |
|
| 134 |
+
| `color_aug` | `false` | o | o | o |
|
| 135 |
+
| `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
|
| 136 |
+
| `flip_aug` | `true` | o | o | o |
|
| 137 |
+
| `keep_tokens` | `2` | o | o | o |
|
| 138 |
+
| `num_repeats` | `10` | o | o | o |
|
| 139 |
+
| `random_crop` | `false` | o | o | o |
|
| 140 |
+
| `shuffle_caption` | `true` | o | o | o |
|
| 141 |
+
| `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
|
| 142 |
+
| `caption_suffix` | `“, from side”` | o | o | o |
|
| 143 |
+
| `caption_separator` | (通常は設定しません) | o | o | o |
|
| 144 |
+
| `keep_tokens_separator` | `“|||”` | o | o | o |
|
| 145 |
+
| `secondary_separator` | `“;;;”` | o | o | o |
|
| 146 |
+
| `enable_wildcard` | `true` | o | o | o |
|
| 147 |
+
|
| 148 |
+
* `num_repeats`
|
| 149 |
+
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
|
| 150 |
+
* `caption_prefix`, `caption_suffix`
|
| 151 |
+
* キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
|
| 152 |
+
|
| 153 |
+
* `caption_separator`
|
| 154 |
+
* タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。
|
| 155 |
+
|
| 156 |
+
* `keep_tokens_separator`
|
| 157 |
+
* キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb` と `ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh` や `aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。
|
| 158 |
+
|
| 159 |
+
* `secondary_separator`
|
| 160 |
+
* 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。
|
| 161 |
+
|
| 162 |
+
* `enable_wildcard`
|
| 163 |
+
* ワイルドカード���法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
|
| 164 |
+
|
| 165 |
+
### DreamBooth 方式専用のオプション
|
| 166 |
+
|
| 167 |
+
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
|
| 168 |
+
|
| 169 |
+
#### サブセット向けオプション
|
| 170 |
+
|
| 171 |
+
DreamBooth 方式のサブセットの設定に関わるオプションです。
|
| 172 |
+
|
| 173 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 174 |
+
| ---- | ---- | ---- | ---- | ---- |
|
| 175 |
+
| `image_dir` | `‘C:\hoge’` | - | - | o(必須) |
|
| 176 |
+
| `caption_extension` | `".txt"` | o | o | o |
|
| 177 |
+
| `class_tokens` | `“sks girl”` | - | - | o |
|
| 178 |
+
| `cache_info` | `false` | o | o | o |
|
| 179 |
+
| `is_reg` | `false` | - | - | o |
|
| 180 |
+
|
| 181 |
+
まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
|
| 182 |
+
|
| 183 |
+
* `image_dir`
|
| 184 |
+
* 画像ディレクトリのパスを指定します。指定必須オプションです。
|
| 185 |
+
* 画像はディレクトリ直下に置かれている必要があります。
|
| 186 |
+
* `class_tokens`
|
| 187 |
+
* クラストークンを設定します。
|
| 188 |
+
* 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
|
| 189 |
+
* `cache_info`
|
| 190 |
+
* 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。
|
| 191 |
+
* キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
|
| 192 |
+
* `is_reg`
|
| 193 |
+
* サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
|
| 194 |
+
|
| 195 |
+
### fine tuning 方式専用のオプション
|
| 196 |
+
|
| 197 |
+
fine tuning 方式のオプションは、サブセット向けオプションのみ存在します。
|
| 198 |
+
|
| 199 |
+
#### サブセット向けオプション
|
| 200 |
+
|
| 201 |
+
fine tuning 方式のサブセットの設定に関わるオプションです。
|
| 202 |
+
|
| 203 |
+
| オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 204 |
+
| ---- | ---- | ---- | ---- | ---- |
|
| 205 |
+
| `image_dir` | `‘C:\hoge’` | - | - | o |
|
| 206 |
+
| `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必須) |
|
| 207 |
+
|
| 208 |
+
* `image_dir`
|
| 209 |
+
* 画像ディレクトリのパスを指定します。DreamBooth の手法の方とは異なり指定は必須ではありませんが、設定することを推奨します。
|
| 210 |
+
* 指定する必要がない状況としては、メタデータファイルの生成時に `--full_path` を付与して実行していた場合です。
|
| 211 |
+
* 画像はディレクトリ直下に置かれている必要があります。
|
| 212 |
+
* `metadata_file`
|
| 213 |
+
* サブセットで利用されるメタデータファイルのパスを指定します。指定必須オプションです。
|
| 214 |
+
* コマンドライン引数の `--in_json` と同等です。
|
| 215 |
+
* サブセットごとにメタデータファイルを指定する必要がある仕様上、ディレクトリを跨いだメタデータを1つのメタデータファイルとして作成することは避けた方が良いでしょう。画像ディレクトリごとにメタデータファイルを用意し、それらを別々のサブセットとして登録することを強く推奨します。
|
| 216 |
+
|
| 217 |
+
### caption dropout の手法が使える場合に指定可能なオプション
|
| 218 |
+
|
| 219 |
+
caption dropout の手法が使える場合のオプションは、サブセット向けオプションのみ存在します。
|
| 220 |
+
DreamBooth 方式か fine tuning 方式かに関わらず、caption dropout に対応している学習方法であれば指定可能です。
|
| 221 |
+
|
| 222 |
+
#### サブセット向けオプション
|
| 223 |
+
|
| 224 |
+
caption dropout が使えるサブセットの設定に関わるオプションです。
|
| 225 |
+
|
| 226 |
+
| オプション名 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
|
| 227 |
+
| ---- | ---- | ---- | ---- |
|
| 228 |
+
| `caption_dropout_every_n_epochs` | o | o | o |
|
| 229 |
+
| `caption_dropout_rate` | o | o | o |
|
| 230 |
+
| `caption_tag_dropout_rate` | o | o | o |
|
| 231 |
+
|
| 232 |
+
## 重複したサブセットが存在する時の挙動
|
| 233 |
+
|
| 234 |
+
DreamBooth 方式のデータセットの場合、���の中にある `image_dir` が同一のサブセットは重複していると見なされます。
|
| 235 |
+
fine tuning 方式のデータセットの場合は、その中にある `metadata_file` が同一のサブセットは重複していると見なされます。
|
| 236 |
+
データセット中に重複したサブセットが存在する場合、2個目以降は無視されます。
|
| 237 |
+
|
| 238 |
+
一方、異なるデータセットに所属している場合は、重複しているとは見なされません。
|
| 239 |
+
例えば、以下のように同一の `image_dir` を持つサブセットを別々のデータセットに入れた場合には、重複していないと見なします。
|
| 240 |
+
これは、同じ画像でも異なる解像度で学習したい場合に役立ちます。
|
| 241 |
+
|
| 242 |
+
```toml
|
| 243 |
+
# 別々のデータセットに存在している場合は重複とは見なされず、両方とも学習に使われる
|
| 244 |
+
|
| 245 |
+
[[datasets]]
|
| 246 |
+
resolution = 512
|
| 247 |
+
|
| 248 |
+
[[datasets.subsets]]
|
| 249 |
+
image_dir = 'C:\hoge'
|
| 250 |
+
|
| 251 |
+
[[datasets]]
|
| 252 |
+
resolution = 768
|
| 253 |
+
|
| 254 |
+
[[datasets.subsets]]
|
| 255 |
+
image_dir = 'C:\hoge'
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
## コマンドライン引数との併用
|
| 259 |
+
|
| 260 |
+
設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
|
| 261 |
+
|
| 262 |
+
以下に挙げるコマンドライン引数のオプションは、設定ファイルを渡した場合には無視されます。
|
| 263 |
+
|
| 264 |
+
* `--train_data_dir`
|
| 265 |
+
* `--reg_data_dir`
|
| 266 |
+
* `--in_json`
|
| 267 |
+
|
| 268 |
+
以下に挙げるコマンドライン引数のオプションは、コマンドライン引数と設定ファイルで同時に指定された場合、コマンドライン引数の値よりも設定ファイルの値が優先されます。特に断りがなければ同名のオプションとなります。
|
| 269 |
+
|
| 270 |
+
| コマンドライン引数のオプション | 優先される設定ファイルのオプション |
|
| 271 |
+
| ---------------------------------- | ---------------------------------- |
|
| 272 |
+
| `--bucket_no_upscale` | |
|
| 273 |
+
| `--bucket_reso_steps` | |
|
| 274 |
+
| `--caption_dropout_every_n_epochs` | |
|
| 275 |
+
| `--caption_dropout_rate` | |
|
| 276 |
+
| `--caption_extension` | |
|
| 277 |
+
| `--caption_tag_dropout_rate` | |
|
| 278 |
+
| `--color_aug` | |
|
| 279 |
+
| `--dataset_repeats` | `num_repeats` |
|
| 280 |
+
| `--enable_bucket` | |
|
| 281 |
+
| `--face_crop_aug_range` | |
|
| 282 |
+
| `--flip_aug` | |
|
| 283 |
+
| `--keep_tokens` | |
|
| 284 |
+
| `--min_bucket_reso` | |
|
| 285 |
+
| `--random_crop` | |
|
| 286 |
+
| `--resolution` | |
|
| 287 |
+
| `--shuffle_caption` | |
|
| 288 |
+
| `--train_batch_size` | `batch_size` |
|
| 289 |
+
|
| 290 |
+
## エラーの手引き
|
| 291 |
+
|
| 292 |
+
現在、外部ライブラリを利用して設定ファイルの記述が正しいかどうかをチェックしているのですが、整備が行き届いておらずエラーメッセージがわかりづらいという問題があります。
|
| 293 |
+
将来的にはこの問題の改善に取り組む予定です。
|
| 294 |
+
|
| 295 |
+
次善策として、頻出のエラーとその対処法について載せておきます。
|
| 296 |
+
正しいはずなのにエラーが出る場合、エラー内容がどうしても分からない場合は、バグかもしれないのでご連絡ください。
|
| 297 |
+
|
| 298 |
+
* `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されていないというエラーです。指定を忘れているか、オプション名を間違って記述している可能性が高いです。
|
| 299 |
+
* `...` の箇所にはエラーが発生した場所が載っています。例えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のようなエラーが出たら、0 番目の `datasets` 中の 0 番目の `subsets` の設定に `image_dir` が存在しないということになります。
|
| 300 |
+
* `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
|
| 301 |
+
* `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って��述しているか、誤って紛れ込んでいる可能性が高いです。
|
| 302 |
+
|
| 303 |
+
## その他
|
| 304 |
+
|
| 305 |
+
### 複数行キャプション
|
| 306 |
+
|
| 307 |
+
`enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。
|
| 308 |
+
|
| 309 |
+
```txt
|
| 310 |
+
1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
|
| 311 |
+
a girl with a microphone standing on a stage
|
| 312 |
+
detailed digital art of a girl with a microphone on a stage
|
| 313 |
+
```
|
| 314 |
+
|
| 315 |
+
ワイルドカード記法と組み合わせることも可能です。
|
| 316 |
+
|
| 317 |
+
メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。
|
| 318 |
+
|
| 319 |
+
メタデータのタグ (`tags`) は、キャプションの各行に追加されます。
|
| 320 |
+
|
| 321 |
+
```json
|
| 322 |
+
{
|
| 323 |
+
"/path/to/image.png": {
|
| 324 |
+
"caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
|
| 325 |
+
"tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
|
| 326 |
+
},
|
| 327 |
+
...
|
| 328 |
+
}
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...`、 `test multiline caption2, open mouth, simple background ...` 等になります。
|
| 332 |
+
|
| 333 |
+
### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等
|
| 334 |
+
|
| 335 |
+
```toml
|
| 336 |
+
[general]
|
| 337 |
+
flip_aug = true
|
| 338 |
+
color_aug = false
|
| 339 |
+
resolution = [1024, 1024]
|
| 340 |
+
|
| 341 |
+
[[datasets]]
|
| 342 |
+
batch_size = 6
|
| 343 |
+
enable_bucket = true
|
| 344 |
+
bucket_no_upscale = true
|
| 345 |
+
caption_extension = ".txt"
|
| 346 |
+
keep_tokens_separator= "|||"
|
| 347 |
+
shuffle_caption = true
|
| 348 |
+
caption_tag_dropout_rate = 0.1
|
| 349 |
+
secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
|
| 350 |
+
enable_wildcard = true # 同上 / same as above
|
| 351 |
+
|
| 352 |
+
[[datasets.subsets]]
|
| 353 |
+
image_dir = "/path/to/image_dir"
|
| 354 |
+
num_repeats = 1
|
| 355 |
+
|
| 356 |
+
# ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
|
| 357 |
+
caption_prefix = "1girl, hatsune miku, vocaloid |||"
|
| 358 |
+
|
| 359 |
+
# ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
|
| 360 |
+
# 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
|
| 361 |
+
caption_suffix = ", anime screencap ||| masterpiece, rating: general"
|
| 362 |
+
```
|
| 363 |
+
|
| 364 |
+
### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
|
| 365 |
+
|
| 366 |
+
```txt
|
| 367 |
+
1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
|
| 368 |
+
```
|
| 369 |
+
`sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。
|
| 370 |
+
|
| 371 |
+
### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
|
| 372 |
+
|
| 373 |
+
```txt
|
| 374 |
+
1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
|
| 375 |
+
```
|
| 376 |
+
ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
|
| 377 |
+
|
| 378 |
+
```txt
|
| 379 |
+
1girl, hatsune miku, vocaloid, {{retro style}}
|
| 380 |
+
```
|
| 381 |
+
タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
|
| 382 |
+
|
| 383 |
+
### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
|
| 384 |
+
|
| 385 |
+
```txt
|
| 386 |
+
1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
|
| 387 |
+
```
|
| 388 |
+
`1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。
|
config_util.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from dataclasses import (
|
| 3 |
+
asdict,
|
| 4 |
+
dataclass,
|
| 5 |
+
)
|
| 6 |
+
import functools
|
| 7 |
+
import random
|
| 8 |
+
from textwrap import dedent, indent
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# from toolz import curry
|
| 13 |
+
from typing import (
|
| 14 |
+
List,
|
| 15 |
+
Optional,
|
| 16 |
+
Sequence,
|
| 17 |
+
Tuple,
|
| 18 |
+
Union,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
import toml
|
| 22 |
+
import voluptuous
|
| 23 |
+
from voluptuous import (
|
| 24 |
+
Any,
|
| 25 |
+
ExactSequence,
|
| 26 |
+
MultipleInvalid,
|
| 27 |
+
Object,
|
| 28 |
+
Required,
|
| 29 |
+
Schema,
|
| 30 |
+
)
|
| 31 |
+
from transformers import CLIPTokenizer
|
| 32 |
+
|
| 33 |
+
from . import train_util
|
| 34 |
+
from .train_util import (
|
| 35 |
+
DreamBoothSubset,
|
| 36 |
+
FineTuningSubset,
|
| 37 |
+
ControlNetSubset,
|
| 38 |
+
DreamBoothDataset,
|
| 39 |
+
FineTuningDataset,
|
| 40 |
+
ControlNetDataset,
|
| 41 |
+
DatasetGroup,
|
| 42 |
+
)
|
| 43 |
+
from .utils import setup_logging
|
| 44 |
+
|
| 45 |
+
setup_logging()
|
| 46 |
+
import logging
|
| 47 |
+
|
| 48 |
+
logger = logging.getLogger(__name__)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def add_config_arguments(parser: argparse.ArgumentParser):
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# TODO: inherit Params class in Subset, Dataset
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
|
| 61 |
+
class BaseSubsetParams:
|
| 62 |
+
image_dir: Optional[str] = None
|
| 63 |
+
num_repeats: int = 1
|
| 64 |
+
shuffle_caption: bool = False
|
| 65 |
+
caption_separator: str = (",",)
|
| 66 |
+
keep_tokens: int = 0
|
| 67 |
+
keep_tokens_separator: str = (None,)
|
| 68 |
+
secondary_separator: Optional[str] = None
|
| 69 |
+
enable_wildcard: bool = False
|
| 70 |
+
color_aug: bool = False
|
| 71 |
+
flip_aug: bool = False
|
| 72 |
+
face_crop_aug_range: Optional[Tuple[float, float]] = None
|
| 73 |
+
random_crop: bool = False
|
| 74 |
+
caption_prefix: Optional[str] = None
|
| 75 |
+
caption_suffix: Optional[str] = None
|
| 76 |
+
caption_dropout_rate: float = 0.0
|
| 77 |
+
caption_dropout_every_n_epochs: int = 0
|
| 78 |
+
caption_tag_dropout_rate: float = 0.0
|
| 79 |
+
token_warmup_min: int = 1
|
| 80 |
+
token_warmup_step: float = 0
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
|
| 84 |
+
class DreamBoothSubsetParams(BaseSubsetParams):
|
| 85 |
+
is_reg: bool = False
|
| 86 |
+
class_tokens: Optional[str] = None
|
| 87 |
+
caption_extension: str = ".caption"
|
| 88 |
+
cache_info: bool = False
|
| 89 |
+
alpha_mask: bool = False
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@dataclass
|
| 93 |
+
class FineTuningSubsetParams(BaseSubsetParams):
|
| 94 |
+
metadata_file: Optional[str] = None
|
| 95 |
+
alpha_mask: bool = False
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@dataclass
|
| 99 |
+
class ControlNetSubsetParams(BaseSubsetParams):
|
| 100 |
+
conditioning_data_dir: str = None
|
| 101 |
+
caption_extension: str = ".caption"
|
| 102 |
+
cache_info: bool = False
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass
|
| 106 |
+
class BaseDatasetParams:
|
| 107 |
+
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
| 108 |
+
max_token_length: int = None
|
| 109 |
+
resolution: Optional[Tuple[int, int]] = None
|
| 110 |
+
network_multiplier: float = 1.0
|
| 111 |
+
debug_dataset: bool = False
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@dataclass
|
| 115 |
+
class DreamBoothDatasetParams(BaseDatasetParams):
|
| 116 |
+
batch_size: int = 1
|
| 117 |
+
enable_bucket: bool = False
|
| 118 |
+
min_bucket_reso: int = 256
|
| 119 |
+
max_bucket_reso: int = 1024
|
| 120 |
+
bucket_reso_steps: int = 64
|
| 121 |
+
bucket_no_upscale: bool = False
|
| 122 |
+
prior_loss_weight: float = 1.0
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass
|
| 126 |
+
class FineTuningDatasetParams(BaseDatasetParams):
|
| 127 |
+
batch_size: int = 1
|
| 128 |
+
enable_bucket: bool = False
|
| 129 |
+
min_bucket_reso: int = 256
|
| 130 |
+
max_bucket_reso: int = 1024
|
| 131 |
+
bucket_reso_steps: int = 64
|
| 132 |
+
bucket_no_upscale: bool = False
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@dataclass
|
| 136 |
+
class ControlNetDatasetParams(BaseDatasetParams):
|
| 137 |
+
batch_size: int = 1
|
| 138 |
+
enable_bucket: bool = False
|
| 139 |
+
min_bucket_reso: int = 256
|
| 140 |
+
max_bucket_reso: int = 1024
|
| 141 |
+
bucket_reso_steps: int = 64
|
| 142 |
+
bucket_no_upscale: bool = False
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@dataclass
|
| 146 |
+
class SubsetBlueprint:
|
| 147 |
+
params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@dataclass
|
| 151 |
+
class DatasetBlueprint:
|
| 152 |
+
is_dreambooth: bool
|
| 153 |
+
is_controlnet: bool
|
| 154 |
+
params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
|
| 155 |
+
subsets: Sequence[SubsetBlueprint]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
@dataclass
|
| 159 |
+
class DatasetGroupBlueprint:
|
| 160 |
+
datasets: Sequence[DatasetBlueprint]
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@dataclass
|
| 164 |
+
class Blueprint:
|
| 165 |
+
dataset_group: DatasetGroupBlueprint
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class ConfigSanitizer:
|
| 169 |
+
# @curry
|
| 170 |
+
@staticmethod
|
| 171 |
+
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
|
| 172 |
+
Schema(ExactSequence([klass, klass]))(value)
|
| 173 |
+
return tuple(value)
|
| 174 |
+
|
| 175 |
+
# @curry
|
| 176 |
+
@staticmethod
|
| 177 |
+
def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
|
| 178 |
+
Schema(Any(klass, ExactSequence([klass, klass])))(value)
|
| 179 |
+
try:
|
| 180 |
+
Schema(klass)(value)
|
| 181 |
+
return (value, value)
|
| 182 |
+
except:
|
| 183 |
+
return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
|
| 184 |
+
|
| 185 |
+
# subset schema
|
| 186 |
+
SUBSET_ASCENDABLE_SCHEMA = {
|
| 187 |
+
"color_aug": bool,
|
| 188 |
+
"face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
|
| 189 |
+
"flip_aug": bool,
|
| 190 |
+
"num_repeats": int,
|
| 191 |
+
"random_crop": bool,
|
| 192 |
+
"shuffle_caption": bool,
|
| 193 |
+
"keep_tokens": int,
|
| 194 |
+
"keep_tokens_separator": str,
|
| 195 |
+
"secondary_separator": str,
|
| 196 |
+
"caption_separator": str,
|
| 197 |
+
"enable_wildcard": bool,
|
| 198 |
+
"token_warmup_min": int,
|
| 199 |
+
"token_warmup_step": Any(float, int),
|
| 200 |
+
"caption_prefix": str,
|
| 201 |
+
"caption_suffix": str,
|
| 202 |
+
}
|
| 203 |
+
# DO means DropOut
|
| 204 |
+
DO_SUBSET_ASCENDABLE_SCHEMA = {
|
| 205 |
+
"caption_dropout_every_n_epochs": int,
|
| 206 |
+
"caption_dropout_rate": Any(float, int),
|
| 207 |
+
"caption_tag_dropout_rate": Any(float, int),
|
| 208 |
+
}
|
| 209 |
+
# DB means DreamBooth
|
| 210 |
+
DB_SUBSET_ASCENDABLE_SCHEMA = {
|
| 211 |
+
"caption_extension": str,
|
| 212 |
+
"class_tokens": str,
|
| 213 |
+
"cache_info": bool,
|
| 214 |
+
}
|
| 215 |
+
DB_SUBSET_DISTINCT_SCHEMA = {
|
| 216 |
+
Required("image_dir"): str,
|
| 217 |
+
"is_reg": bool,
|
| 218 |
+
"alpha_mask": bool,
|
| 219 |
+
}
|
| 220 |
+
# FT means FineTuning
|
| 221 |
+
FT_SUBSET_DISTINCT_SCHEMA = {
|
| 222 |
+
Required("metadata_file"): str,
|
| 223 |
+
"image_dir": str,
|
| 224 |
+
"alpha_mask": bool,
|
| 225 |
+
}
|
| 226 |
+
CN_SUBSET_ASCENDABLE_SCHEMA = {
|
| 227 |
+
"caption_extension": str,
|
| 228 |
+
"cache_info": bool,
|
| 229 |
+
}
|
| 230 |
+
CN_SUBSET_DISTINCT_SCHEMA = {
|
| 231 |
+
Required("image_dir"): str,
|
| 232 |
+
Required("conditioning_data_dir"): str,
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
# datasets schema
|
| 236 |
+
DATASET_ASCENDABLE_SCHEMA = {
|
| 237 |
+
"batch_size": int,
|
| 238 |
+
"bucket_no_upscale": bool,
|
| 239 |
+
"bucket_reso_steps": int,
|
| 240 |
+
"enable_bucket": bool,
|
| 241 |
+
"max_bucket_reso": int,
|
| 242 |
+
"min_bucket_reso": int,
|
| 243 |
+
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
| 244 |
+
"network_multiplier": float,
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
# options handled by argparse but not handled by user config
|
| 248 |
+
ARGPARSE_SPECIFIC_SCHEMA = {
|
| 249 |
+
"debug_dataset": bool,
|
| 250 |
+
"max_token_length": Any(None, int),
|
| 251 |
+
"prior_loss_weight": Any(float, int),
|
| 252 |
+
}
|
| 253 |
+
# for handling default None value of argparse
|
| 254 |
+
ARGPARSE_NULLABLE_OPTNAMES = [
|
| 255 |
+
"face_crop_aug_range",
|
| 256 |
+
"resolution",
|
| 257 |
+
]
|
| 258 |
+
# prepare map because option name may differ among argparse and user config
|
| 259 |
+
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
|
| 260 |
+
"train_batch_size": "batch_size",
|
| 261 |
+
"dataset_repeats": "num_repeats",
|
| 262 |
+
}
|
| 263 |
+
|
| 264 |
+
def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
|
| 265 |
+
assert support_dreambooth or support_finetuning or support_controlnet, (
|
| 266 |
+
"Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
|
| 267 |
+
+ " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
self.db_subset_schema = self.__merge_dict(
|
| 271 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 272 |
+
self.DB_SUBSET_DISTINCT_SCHEMA,
|
| 273 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
| 274 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
self.ft_subset_schema = self.__merge_dict(
|
| 278 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 279 |
+
self.FT_SUBSET_DISTINCT_SCHEMA,
|
| 280 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
self.cn_subset_schema = self.__merge_dict(
|
| 284 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 285 |
+
self.CN_SUBSET_DISTINCT_SCHEMA,
|
| 286 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
| 287 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
self.db_dataset_schema = self.__merge_dict(
|
| 291 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 292 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 293 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA,
|
| 294 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 295 |
+
{"subsets": [self.db_subset_schema]},
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
self.ft_dataset_schema = self.__merge_dict(
|
| 299 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 300 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 301 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 302 |
+
{"subsets": [self.ft_subset_schema]},
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
self.cn_dataset_schema = self.__merge_dict(
|
| 306 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 307 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 308 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA,
|
| 309 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 310 |
+
{"subsets": [self.cn_subset_schema]},
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
if support_dreambooth and support_finetuning:
|
| 314 |
+
|
| 315 |
+
def validate_flex_dataset(dataset_config: dict):
|
| 316 |
+
subsets_config = dataset_config.get("subsets", [])
|
| 317 |
+
|
| 318 |
+
if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
|
| 319 |
+
return Schema(self.cn_dataset_schema)(dataset_config)
|
| 320 |
+
# check dataset meets FT style
|
| 321 |
+
# NOTE: all FT subsets should have "metadata_file"
|
| 322 |
+
elif all(["metadata_file" in subset for subset in subsets_config]):
|
| 323 |
+
return Schema(self.ft_dataset_schema)(dataset_config)
|
| 324 |
+
# check dataset meets DB style
|
| 325 |
+
# NOTE: all DB subsets should have no "metadata_file"
|
| 326 |
+
elif all(["metadata_file" not in subset for subset in subsets_config]):
|
| 327 |
+
return Schema(self.db_dataset_schema)(dataset_config)
|
| 328 |
+
else:
|
| 329 |
+
raise voluptuous.Invalid(
|
| 330 |
+
"DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
self.dataset_schema = validate_flex_dataset
|
| 334 |
+
elif support_dreambooth:
|
| 335 |
+
if support_controlnet:
|
| 336 |
+
self.dataset_schema = self.cn_dataset_schema
|
| 337 |
+
else:
|
| 338 |
+
self.dataset_schema = self.db_dataset_schema
|
| 339 |
+
elif support_finetuning:
|
| 340 |
+
self.dataset_schema = self.ft_dataset_schema
|
| 341 |
+
elif support_controlnet:
|
| 342 |
+
self.dataset_schema = self.cn_dataset_schema
|
| 343 |
+
|
| 344 |
+
self.general_schema = self.__merge_dict(
|
| 345 |
+
self.DATASET_ASCENDABLE_SCHEMA,
|
| 346 |
+
self.SUBSET_ASCENDABLE_SCHEMA,
|
| 347 |
+
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
|
| 348 |
+
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
|
| 349 |
+
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
self.user_config_validator = Schema(
|
| 353 |
+
{
|
| 354 |
+
"general": self.general_schema,
|
| 355 |
+
"datasets": [self.dataset_schema],
|
| 356 |
+
}
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
self.argparse_schema = self.__merge_dict(
|
| 360 |
+
self.general_schema,
|
| 361 |
+
self.ARGPARSE_SPECIFIC_SCHEMA,
|
| 362 |
+
{optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
|
| 363 |
+
{a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
|
| 367 |
+
|
| 368 |
+
def sanitize_user_config(self, user_config: dict) -> dict:
|
| 369 |
+
try:
|
| 370 |
+
return self.user_config_validator(user_config)
|
| 371 |
+
except MultipleInvalid:
|
| 372 |
+
# TODO: エラー発生時のメッセージをわかりやすくする
|
| 373 |
+
logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
|
| 374 |
+
raise
|
| 375 |
+
|
| 376 |
+
# NOTE: In nature, argument parser result is not needed to be sanitize
|
| 377 |
+
# However this will help us to detect program bug
|
| 378 |
+
def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
|
| 379 |
+
try:
|
| 380 |
+
return self.argparse_config_validator(argparse_namespace)
|
| 381 |
+
except MultipleInvalid:
|
| 382 |
+
# XXX: this should be a bug
|
| 383 |
+
logger.error(
|
| 384 |
+
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
|
| 385 |
+
)
|
| 386 |
+
raise
|
| 387 |
+
|
| 388 |
+
# NOTE: value would be overwritten by latter dict if there is already the same key
|
| 389 |
+
@staticmethod
|
| 390 |
+
def __merge_dict(*dict_list: dict) -> dict:
|
| 391 |
+
merged = {}
|
| 392 |
+
for schema in dict_list:
|
| 393 |
+
# merged |= schema
|
| 394 |
+
for k, v in schema.items():
|
| 395 |
+
merged[k] = v
|
| 396 |
+
return merged
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class BlueprintGenerator:
|
| 400 |
+
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
|
| 401 |
+
|
| 402 |
+
def __init__(self, sanitizer: ConfigSanitizer):
|
| 403 |
+
self.sanitizer = sanitizer
|
| 404 |
+
|
| 405 |
+
# runtime_params is for parameters which is only configurable on runtime, such as tokenizer
|
| 406 |
+
def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
|
| 407 |
+
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
|
| 408 |
+
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
|
| 409 |
+
|
| 410 |
+
# convert argparse namespace to dict like config
|
| 411 |
+
# NOTE: it is ok to have extra entries in dict
|
| 412 |
+
optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
|
| 413 |
+
argparse_config = {
|
| 414 |
+
optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
general_config = sanitized_user_config.get("general", {})
|
| 418 |
+
|
| 419 |
+
dataset_blueprints = []
|
| 420 |
+
for dataset_config in sanitized_user_config.get("datasets", []):
|
| 421 |
+
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
|
| 422 |
+
subsets = dataset_config.get("subsets", [])
|
| 423 |
+
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
|
| 424 |
+
is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
|
| 425 |
+
if is_controlnet:
|
| 426 |
+
subset_params_klass = ControlNetSubsetParams
|
| 427 |
+
dataset_params_klass = ControlNetDatasetParams
|
| 428 |
+
elif is_dreambooth:
|
| 429 |
+
subset_params_klass = DreamBoothSubsetParams
|
| 430 |
+
dataset_params_klass = DreamBoothDatasetParams
|
| 431 |
+
else:
|
| 432 |
+
subset_params_klass = FineTuningSubsetParams
|
| 433 |
+
dataset_params_klass = FineTuningDatasetParams
|
| 434 |
+
|
| 435 |
+
subset_blueprints = []
|
| 436 |
+
for subset_config in subsets:
|
| 437 |
+
params = self.generate_params_by_fallbacks(
|
| 438 |
+
subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
|
| 439 |
+
)
|
| 440 |
+
subset_blueprints.append(SubsetBlueprint(params))
|
| 441 |
+
|
| 442 |
+
params = self.generate_params_by_fallbacks(
|
| 443 |
+
dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
|
| 444 |
+
)
|
| 445 |
+
dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
|
| 446 |
+
|
| 447 |
+
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
|
| 448 |
+
|
| 449 |
+
return Blueprint(dataset_group_blueprint)
|
| 450 |
+
|
| 451 |
+
@staticmethod
|
| 452 |
+
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
|
| 453 |
+
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
|
| 454 |
+
search_value = BlueprintGenerator.search_value
|
| 455 |
+
default_params = asdict(param_klass())
|
| 456 |
+
param_names = default_params.keys()
|
| 457 |
+
|
| 458 |
+
params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
|
| 459 |
+
|
| 460 |
+
return param_klass(**params)
|
| 461 |
+
|
| 462 |
+
@staticmethod
|
| 463 |
+
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
|
| 464 |
+
for cand in fallbacks:
|
| 465 |
+
value = cand.get(key)
|
| 466 |
+
if value is not None:
|
| 467 |
+
return value
|
| 468 |
+
|
| 469 |
+
return default_value
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
| 473 |
+
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
| 474 |
+
|
| 475 |
+
for dataset_blueprint in dataset_group_blueprint.datasets:
|
| 476 |
+
if dataset_blueprint.is_controlnet:
|
| 477 |
+
subset_klass = ControlNetSubset
|
| 478 |
+
dataset_klass = ControlNetDataset
|
| 479 |
+
elif dataset_blueprint.is_dreambooth:
|
| 480 |
+
subset_klass = DreamBoothSubset
|
| 481 |
+
dataset_klass = DreamBoothDataset
|
| 482 |
+
else:
|
| 483 |
+
subset_klass = FineTuningSubset
|
| 484 |
+
dataset_klass = FineTuningDataset
|
| 485 |
+
|
| 486 |
+
subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
|
| 487 |
+
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
|
| 488 |
+
datasets.append(dataset)
|
| 489 |
+
|
| 490 |
+
# print info
|
| 491 |
+
info = ""
|
| 492 |
+
for i, dataset in enumerate(datasets):
|
| 493 |
+
is_dreambooth = isinstance(dataset, DreamBoothDataset)
|
| 494 |
+
is_controlnet = isinstance(dataset, ControlNetDataset)
|
| 495 |
+
info += dedent(
|
| 496 |
+
f"""\
|
| 497 |
+
[Dataset {i}]
|
| 498 |
+
batch_size: {dataset.batch_size}
|
| 499 |
+
resolution: {(dataset.width, dataset.height)}
|
| 500 |
+
enable_bucket: {dataset.enable_bucket}
|
| 501 |
+
network_multiplier: {dataset.network_multiplier}
|
| 502 |
+
"""
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
if dataset.enable_bucket:
|
| 506 |
+
info += indent(
|
| 507 |
+
dedent(
|
| 508 |
+
f"""\
|
| 509 |
+
min_bucket_reso: {dataset.min_bucket_reso}
|
| 510 |
+
max_bucket_reso: {dataset.max_bucket_reso}
|
| 511 |
+
bucket_reso_steps: {dataset.bucket_reso_steps}
|
| 512 |
+
bucket_no_upscale: {dataset.bucket_no_upscale}
|
| 513 |
+
\n"""
|
| 514 |
+
),
|
| 515 |
+
" ",
|
| 516 |
+
)
|
| 517 |
+
else:
|
| 518 |
+
info += "\n"
|
| 519 |
+
|
| 520 |
+
for j, subset in enumerate(dataset.subsets):
|
| 521 |
+
info += indent(
|
| 522 |
+
dedent(
|
| 523 |
+
f"""\
|
| 524 |
+
[Subset {j} of Dataset {i}]
|
| 525 |
+
image_dir: "{subset.image_dir}"
|
| 526 |
+
image_count: {subset.img_count}
|
| 527 |
+
num_repeats: {subset.num_repeats}
|
| 528 |
+
shuffle_caption: {subset.shuffle_caption}
|
| 529 |
+
keep_tokens: {subset.keep_tokens}
|
| 530 |
+
keep_tokens_separator: {subset.keep_tokens_separator}
|
| 531 |
+
caption_separator: {subset.caption_separator}
|
| 532 |
+
secondary_separator: {subset.secondary_separator}
|
| 533 |
+
enable_wildcard: {subset.enable_wildcard}
|
| 534 |
+
caption_dropout_rate: {subset.caption_dropout_rate}
|
| 535 |
+
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
|
| 536 |
+
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
|
| 537 |
+
caption_prefix: {subset.caption_prefix}
|
| 538 |
+
caption_suffix: {subset.caption_suffix}
|
| 539 |
+
color_aug: {subset.color_aug}
|
| 540 |
+
flip_aug: {subset.flip_aug}
|
| 541 |
+
face_crop_aug_range: {subset.face_crop_aug_range}
|
| 542 |
+
random_crop: {subset.random_crop}
|
| 543 |
+
token_warmup_min: {subset.token_warmup_min},
|
| 544 |
+
token_warmup_step: {subset.token_warmup_step},
|
| 545 |
+
alpha_mask: {subset.alpha_mask},
|
| 546 |
+
"""
|
| 547 |
+
),
|
| 548 |
+
" ",
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
if is_dreambooth:
|
| 552 |
+
info += indent(
|
| 553 |
+
dedent(
|
| 554 |
+
f"""\
|
| 555 |
+
is_reg: {subset.is_reg}
|
| 556 |
+
class_tokens: {subset.class_tokens}
|
| 557 |
+
caption_extension: {subset.caption_extension}
|
| 558 |
+
\n"""
|
| 559 |
+
),
|
| 560 |
+
" ",
|
| 561 |
+
)
|
| 562 |
+
elif not is_controlnet:
|
| 563 |
+
info += indent(
|
| 564 |
+
dedent(
|
| 565 |
+
f"""\
|
| 566 |
+
metadata_file: {subset.metadata_file}
|
| 567 |
+
\n"""
|
| 568 |
+
),
|
| 569 |
+
" ",
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
logger.info(f"{info}")
|
| 573 |
+
|
| 574 |
+
# make buckets first because it determines the length of dataset
|
| 575 |
+
# and set the same seed for all datasets
|
| 576 |
+
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
|
| 577 |
+
for i, dataset in enumerate(datasets):
|
| 578 |
+
logger.info(f"[Dataset {i}]")
|
| 579 |
+
dataset.make_buckets()
|
| 580 |
+
dataset.set_seed(seed)
|
| 581 |
+
|
| 582 |
+
return DatasetGroup(datasets)
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
|
| 586 |
+
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
|
| 587 |
+
tokens = name.split("_")
|
| 588 |
+
try:
|
| 589 |
+
n_repeats = int(tokens[0])
|
| 590 |
+
except ValueError as e:
|
| 591 |
+
logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
|
| 592 |
+
return 0, ""
|
| 593 |
+
caption_by_folder = "_".join(tokens[1:])
|
| 594 |
+
return n_repeats, caption_by_folder
|
| 595 |
+
|
| 596 |
+
def generate(base_dir: Optional[str], is_reg: bool):
|
| 597 |
+
if base_dir is None:
|
| 598 |
+
return []
|
| 599 |
+
|
| 600 |
+
base_dir: Path = Path(base_dir)
|
| 601 |
+
if not base_dir.is_dir():
|
| 602 |
+
return []
|
| 603 |
+
|
| 604 |
+
subsets_config = []
|
| 605 |
+
for subdir in base_dir.iterdir():
|
| 606 |
+
if not subdir.is_dir():
|
| 607 |
+
continue
|
| 608 |
+
|
| 609 |
+
num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
|
| 610 |
+
if num_repeats < 1:
|
| 611 |
+
continue
|
| 612 |
+
|
| 613 |
+
subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
|
| 614 |
+
subsets_config.append(subset_config)
|
| 615 |
+
|
| 616 |
+
return subsets_config
|
| 617 |
+
|
| 618 |
+
subsets_config = []
|
| 619 |
+
subsets_config += generate(train_data_dir, False)
|
| 620 |
+
subsets_config += generate(reg_data_dir, True)
|
| 621 |
+
|
| 622 |
+
return subsets_config
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def generate_controlnet_subsets_config_by_subdirs(
|
| 626 |
+
train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
|
| 627 |
+
):
|
| 628 |
+
def generate(base_dir: Optional[str]):
|
| 629 |
+
if base_dir is None:
|
| 630 |
+
return []
|
| 631 |
+
|
| 632 |
+
base_dir: Path = Path(base_dir)
|
| 633 |
+
if not base_dir.is_dir():
|
| 634 |
+
return []
|
| 635 |
+
|
| 636 |
+
subsets_config = []
|
| 637 |
+
subset_config = {
|
| 638 |
+
"image_dir": train_data_dir,
|
| 639 |
+
"conditioning_data_dir": conditioning_data_dir,
|
| 640 |
+
"caption_extension": caption_extension,
|
| 641 |
+
"num_repeats": 1,
|
| 642 |
+
}
|
| 643 |
+
subsets_config.append(subset_config)
|
| 644 |
+
|
| 645 |
+
return subsets_config
|
| 646 |
+
|
| 647 |
+
subsets_config = []
|
| 648 |
+
subsets_config += generate(train_data_dir)
|
| 649 |
+
|
| 650 |
+
return subsets_config
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
def load_user_config(file: str) -> dict:
|
| 654 |
+
file: Path = Path(file)
|
| 655 |
+
if not file.is_file():
|
| 656 |
+
raise ValueError(f"file not found / ファイルが見つかりません: {file}")
|
| 657 |
+
|
| 658 |
+
if file.name.lower().endswith(".json"):
|
| 659 |
+
try:
|
| 660 |
+
with open(file, "r") as f:
|
| 661 |
+
config = json.load(f)
|
| 662 |
+
except Exception:
|
| 663 |
+
logger.error(
|
| 664 |
+
f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
| 665 |
+
)
|
| 666 |
+
raise
|
| 667 |
+
elif file.name.lower().endswith(".toml"):
|
| 668 |
+
try:
|
| 669 |
+
config = toml.load(file)
|
| 670 |
+
except Exception:
|
| 671 |
+
logger.error(
|
| 672 |
+
f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
|
| 673 |
+
)
|
| 674 |
+
raise
|
| 675 |
+
else:
|
| 676 |
+
raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
| 677 |
+
|
| 678 |
+
return config
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
# for config test
|
| 682 |
+
if __name__ == "__main__":
|
| 683 |
+
parser = argparse.ArgumentParser()
|
| 684 |
+
parser.add_argument("--support_dreambooth", action="store_true")
|
| 685 |
+
parser.add_argument("--support_finetuning", action="store_true")
|
| 686 |
+
parser.add_argument("--support_controlnet", action="store_true")
|
| 687 |
+
parser.add_argument("--support_dropout", action="store_true")
|
| 688 |
+
parser.add_argument("dataset_config")
|
| 689 |
+
config_args, remain = parser.parse_known_args()
|
| 690 |
+
|
| 691 |
+
parser = argparse.ArgumentParser()
|
| 692 |
+
train_util.add_dataset_arguments(
|
| 693 |
+
parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
|
| 694 |
+
)
|
| 695 |
+
train_util.add_training_arguments(parser, config_args.support_dreambooth)
|
| 696 |
+
argparse_namespace = parser.parse_args(remain)
|
| 697 |
+
train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
|
| 698 |
+
|
| 699 |
+
logger.info("[argparse_namespace]")
|
| 700 |
+
logger.info(f"{vars(argparse_namespace)}")
|
| 701 |
+
|
| 702 |
+
user_config = load_user_config(config_args.dataset_config)
|
| 703 |
+
|
| 704 |
+
logger.info("")
|
| 705 |
+
logger.info("[user_config]")
|
| 706 |
+
logger.info(f"{user_config}")
|
| 707 |
+
|
| 708 |
+
sanitizer = ConfigSanitizer(
|
| 709 |
+
config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
|
| 710 |
+
)
|
| 711 |
+
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
| 712 |
+
|
| 713 |
+
logger.info("")
|
| 714 |
+
logger.info("[sanitized_user_config]")
|
| 715 |
+
logger.info(f"{sanitized_user_config}")
|
| 716 |
+
|
| 717 |
+
blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
|
| 718 |
+
|
| 719 |
+
logger.info("")
|
| 720 |
+
logger.info("[blueprint]")
|
| 721 |
+
logger.info(f"{blueprint}")
|
control_net_lllite.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional, List, Type
|
| 3 |
+
import torch
|
| 4 |
+
from library import sdxl_original_unet
|
| 5 |
+
from library.utils import setup_logging
|
| 6 |
+
setup_logging()
|
| 7 |
+
import logging
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
| 11 |
+
SKIP_INPUT_BLOCKS = False
|
| 12 |
+
|
| 13 |
+
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
| 14 |
+
SKIP_OUTPUT_BLOCKS = True
|
| 15 |
+
|
| 16 |
+
# conv2dに適用するかどうか / if True, conv2d are not applied
|
| 17 |
+
SKIP_CONV2D = False
|
| 18 |
+
|
| 19 |
+
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
| 20 |
+
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
| 21 |
+
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
| 22 |
+
|
| 23 |
+
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
| 24 |
+
ATTN1_2_ONLY = True
|
| 25 |
+
|
| 26 |
+
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
| 27 |
+
ATTN_QKV_ONLY = True
|
| 28 |
+
|
| 29 |
+
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
| 30 |
+
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
| 31 |
+
ATTN1_ETC_ONLY = False # True
|
| 32 |
+
|
| 33 |
+
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
| 34 |
+
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
| 35 |
+
TRANSFORMER_MAX_BLOCK_INDEX = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class LLLiteModule(torch.nn.Module):
|
| 39 |
+
def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
|
| 40 |
+
super().__init__()
|
| 41 |
+
|
| 42 |
+
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
| 43 |
+
self.lllite_name = name
|
| 44 |
+
self.cond_emb_dim = cond_emb_dim
|
| 45 |
+
self.org_module = [org_module]
|
| 46 |
+
self.dropout = dropout
|
| 47 |
+
self.multiplier = multiplier
|
| 48 |
+
|
| 49 |
+
if self.is_conv2d:
|
| 50 |
+
in_dim = org_module.in_channels
|
| 51 |
+
else:
|
| 52 |
+
in_dim = org_module.in_features
|
| 53 |
+
|
| 54 |
+
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
| 55 |
+
# conditioning1 embeds conditioning image. it is not called for each timestep
|
| 56 |
+
modules = []
|
| 57 |
+
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
| 58 |
+
if depth == 1:
|
| 59 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 60 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
| 61 |
+
elif depth == 2:
|
| 62 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 63 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
| 64 |
+
elif depth == 3:
|
| 65 |
+
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
| 66 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 67 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
| 68 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 69 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
| 70 |
+
|
| 71 |
+
self.conditioning1 = torch.nn.Sequential(*modules)
|
| 72 |
+
|
| 73 |
+
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
| 74 |
+
# midでconditioning image embeddingと入力を結合する
|
| 75 |
+
# upで元の次元数に戻す
|
| 76 |
+
# これらはtimestepごとに呼ばれる
|
| 77 |
+
# reduce the number of input dimensions with down. inspired by LoRA
|
| 78 |
+
# combine conditioning image embedding and input with mid
|
| 79 |
+
# restore to the original dimension with up
|
| 80 |
+
# these are called for each timestep
|
| 81 |
+
|
| 82 |
+
if self.is_conv2d:
|
| 83 |
+
self.down = torch.nn.Sequential(
|
| 84 |
+
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
| 85 |
+
torch.nn.ReLU(inplace=True),
|
| 86 |
+
)
|
| 87 |
+
self.mid = torch.nn.Sequential(
|
| 88 |
+
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
| 89 |
+
torch.nn.ReLU(inplace=True),
|
| 90 |
+
)
|
| 91 |
+
self.up = torch.nn.Sequential(
|
| 92 |
+
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
| 93 |
+
)
|
| 94 |
+
else:
|
| 95 |
+
# midの前にconditioningをreshapeすること / reshape conditioning before mid
|
| 96 |
+
self.down = torch.nn.Sequential(
|
| 97 |
+
torch.nn.Linear(in_dim, mlp_dim),
|
| 98 |
+
torch.nn.ReLU(inplace=True),
|
| 99 |
+
)
|
| 100 |
+
self.mid = torch.nn.Sequential(
|
| 101 |
+
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
| 102 |
+
torch.nn.ReLU(inplace=True),
|
| 103 |
+
)
|
| 104 |
+
self.up = torch.nn.Sequential(
|
| 105 |
+
torch.nn.Linear(mlp_dim, in_dim),
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# Zero-Convにする / set to Zero-Conv
|
| 109 |
+
torch.nn.init.zeros_(self.up[0].weight) # zero conv
|
| 110 |
+
|
| 111 |
+
self.depth = depth # 1~3
|
| 112 |
+
self.cond_emb = None
|
| 113 |
+
self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
|
| 114 |
+
self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
|
| 115 |
+
|
| 116 |
+
# batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
|
| 117 |
+
# Controlの種類によっては使えるかも
|
| 118 |
+
# both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
|
| 119 |
+
# it may be available depending on the type of Control
|
| 120 |
+
|
| 121 |
+
def set_cond_image(self, cond_image):
|
| 122 |
+
r"""
|
| 123 |
+
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
| 124 |
+
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
| 125 |
+
"""
|
| 126 |
+
if cond_image is None:
|
| 127 |
+
self.cond_emb = None
|
| 128 |
+
return
|
| 129 |
+
|
| 130 |
+
# timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
|
| 131 |
+
# logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
|
| 132 |
+
cx = self.conditioning1(cond_image)
|
| 133 |
+
if not self.is_conv2d:
|
| 134 |
+
# reshape / b,c,h,w -> b,h*w,c
|
| 135 |
+
n, c, h, w = cx.shape
|
| 136 |
+
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
| 137 |
+
self.cond_emb = cx
|
| 138 |
+
|
| 139 |
+
def set_batch_cond_only(self, cond_only, zeros):
|
| 140 |
+
self.batch_cond_only = cond_only
|
| 141 |
+
self.use_zeros_for_batch_uncond = zeros
|
| 142 |
+
|
| 143 |
+
def apply_to(self):
|
| 144 |
+
self.org_forward = self.org_module[0].forward
|
| 145 |
+
self.org_module[0].forward = self.forward
|
| 146 |
+
|
| 147 |
+
def forward(self, x):
|
| 148 |
+
r"""
|
| 149 |
+
学習用の便利forward。元のモジュールのforwardを呼び出す
|
| 150 |
+
/ convenient forward for training. call the forward of the original module
|
| 151 |
+
"""
|
| 152 |
+
if self.multiplier == 0.0 or self.cond_emb is None:
|
| 153 |
+
return self.org_forward(x)
|
| 154 |
+
|
| 155 |
+
cx = self.cond_emb
|
| 156 |
+
|
| 157 |
+
if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
|
| 158 |
+
cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
|
| 159 |
+
if self.use_zeros_for_batch_uncond:
|
| 160 |
+
cx[0::2] = 0.0 # uncond is zero
|
| 161 |
+
# logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
|
| 162 |
+
|
| 163 |
+
# downで入力の次元数を削減し、conditioning image embeddingと結合する
|
| 164 |
+
# 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
|
| 165 |
+
# down reduces the number of input dimensions and combines it with conditioning image embedding
|
| 166 |
+
# we expect that it will mix well by combining in the channel direction instead of adding
|
| 167 |
+
|
| 168 |
+
cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
|
| 169 |
+
cx = self.mid(cx)
|
| 170 |
+
|
| 171 |
+
if self.dropout is not None and self.training:
|
| 172 |
+
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
| 173 |
+
|
| 174 |
+
cx = self.up(cx) * self.multiplier
|
| 175 |
+
|
| 176 |
+
# residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
|
| 177 |
+
if self.batch_cond_only:
|
| 178 |
+
zx = torch.zeros_like(x)
|
| 179 |
+
zx[1::2] += cx
|
| 180 |
+
cx = zx
|
| 181 |
+
|
| 182 |
+
x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
| 183 |
+
return x
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class ControlNetLLLite(torch.nn.Module):
|
| 187 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
| 188 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 189 |
+
|
| 190 |
+
def __init__(
|
| 191 |
+
self,
|
| 192 |
+
unet: sdxl_original_unet.SdxlUNet2DConditionModel,
|
| 193 |
+
cond_emb_dim: int = 16,
|
| 194 |
+
mlp_dim: int = 16,
|
| 195 |
+
dropout: Optional[float] = None,
|
| 196 |
+
varbose: Optional[bool] = False,
|
| 197 |
+
multiplier: Optional[float] = 1.0,
|
| 198 |
+
) -> None:
|
| 199 |
+
super().__init__()
|
| 200 |
+
# self.unets = [unet]
|
| 201 |
+
|
| 202 |
+
def create_modules(
|
| 203 |
+
root_module: torch.nn.Module,
|
| 204 |
+
target_replace_modules: List[torch.nn.Module],
|
| 205 |
+
module_class: Type[object],
|
| 206 |
+
) -> List[torch.nn.Module]:
|
| 207 |
+
prefix = "lllite_unet"
|
| 208 |
+
|
| 209 |
+
modules = []
|
| 210 |
+
for name, module in root_module.named_modules():
|
| 211 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 212 |
+
for child_name, child_module in module.named_modules():
|
| 213 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
| 214 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
| 215 |
+
|
| 216 |
+
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
| 217 |
+
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
| 218 |
+
# block index to depth: depth is using to calculate conditioning size and channels
|
| 219 |
+
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
| 220 |
+
index1 = int(index1)
|
| 221 |
+
if block_name == "input_blocks":
|
| 222 |
+
if SKIP_INPUT_BLOCKS:
|
| 223 |
+
continue
|
| 224 |
+
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
| 225 |
+
elif block_name == "middle_block":
|
| 226 |
+
depth = 3
|
| 227 |
+
elif block_name == "output_blocks":
|
| 228 |
+
if SKIP_OUTPUT_BLOCKS:
|
| 229 |
+
continue
|
| 230 |
+
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
|
| 231 |
+
if int(index2) >= 2:
|
| 232 |
+
depth -= 1
|
| 233 |
+
else:
|
| 234 |
+
raise NotImplementedError()
|
| 235 |
+
|
| 236 |
+
lllite_name = prefix + "." + name + "." + child_name
|
| 237 |
+
lllite_name = lllite_name.replace(".", "_")
|
| 238 |
+
|
| 239 |
+
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
| 240 |
+
p = lllite_name.find("transformer_blocks")
|
| 241 |
+
if p >= 0:
|
| 242 |
+
tf_index = int(lllite_name[p:].split("_")[2])
|
| 243 |
+
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
| 244 |
+
continue
|
| 245 |
+
|
| 246 |
+
# time embは適用外とする
|
| 247 |
+
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
| 248 |
+
# time emb is not applied
|
| 249 |
+
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
| 250 |
+
if "emb_layers" in lllite_name or (
|
| 251 |
+
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
| 252 |
+
):
|
| 253 |
+
continue
|
| 254 |
+
|
| 255 |
+
if ATTN1_2_ONLY:
|
| 256 |
+
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
| 257 |
+
continue
|
| 258 |
+
if ATTN_QKV_ONLY:
|
| 259 |
+
if "to_out" in lllite_name:
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
if ATTN1_ETC_ONLY:
|
| 263 |
+
if "proj_out" in lllite_name:
|
| 264 |
+
pass
|
| 265 |
+
elif "attn1" in lllite_name and (
|
| 266 |
+
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
| 267 |
+
):
|
| 268 |
+
pass
|
| 269 |
+
elif "ff_net_2" in lllite_name:
|
| 270 |
+
pass
|
| 271 |
+
else:
|
| 272 |
+
continue
|
| 273 |
+
|
| 274 |
+
module = module_class(
|
| 275 |
+
depth,
|
| 276 |
+
cond_emb_dim,
|
| 277 |
+
lllite_name,
|
| 278 |
+
child_module,
|
| 279 |
+
mlp_dim,
|
| 280 |
+
dropout=dropout,
|
| 281 |
+
multiplier=multiplier,
|
| 282 |
+
)
|
| 283 |
+
modules.append(module)
|
| 284 |
+
return modules
|
| 285 |
+
|
| 286 |
+
target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
| 287 |
+
if not TRANSFORMER_ONLY:
|
| 288 |
+
target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 289 |
+
|
| 290 |
+
# create module instances
|
| 291 |
+
self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
|
| 292 |
+
logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
|
| 293 |
+
|
| 294 |
+
def forward(self, x):
|
| 295 |
+
return x # dummy
|
| 296 |
+
|
| 297 |
+
def set_cond_image(self, cond_image):
|
| 298 |
+
r"""
|
| 299 |
+
中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
|
| 300 |
+
/ call the model inside, so if necessary, surround it with torch.no_grad()
|
| 301 |
+
"""
|
| 302 |
+
for module in self.unet_modules:
|
| 303 |
+
module.set_cond_image(cond_image)
|
| 304 |
+
|
| 305 |
+
def set_batch_cond_only(self, cond_only, zeros):
|
| 306 |
+
for module in self.unet_modules:
|
| 307 |
+
module.set_batch_cond_only(cond_only, zeros)
|
| 308 |
+
|
| 309 |
+
def set_multiplier(self, multiplier):
|
| 310 |
+
for module in self.unet_modules:
|
| 311 |
+
module.multiplier = multiplier
|
| 312 |
+
|
| 313 |
+
def load_weights(self, file):
|
| 314 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 315 |
+
from safetensors.torch import load_file
|
| 316 |
+
|
| 317 |
+
weights_sd = load_file(file)
|
| 318 |
+
else:
|
| 319 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 320 |
+
|
| 321 |
+
info = self.load_state_dict(weights_sd, False)
|
| 322 |
+
return info
|
| 323 |
+
|
| 324 |
+
def apply_to(self):
|
| 325 |
+
logger.info("applying LLLite for U-Net...")
|
| 326 |
+
for module in self.unet_modules:
|
| 327 |
+
module.apply_to()
|
| 328 |
+
self.add_module(module.lllite_name, module)
|
| 329 |
+
|
| 330 |
+
# マージできるかどうかを返す
|
| 331 |
+
def is_mergeable(self):
|
| 332 |
+
return False
|
| 333 |
+
|
| 334 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
| 335 |
+
raise NotImplementedError()
|
| 336 |
+
|
| 337 |
+
def enable_gradient_checkpointing(self):
|
| 338 |
+
# not supported
|
| 339 |
+
pass
|
| 340 |
+
|
| 341 |
+
def prepare_optimizer_params(self):
|
| 342 |
+
self.requires_grad_(True)
|
| 343 |
+
return self.parameters()
|
| 344 |
+
|
| 345 |
+
def prepare_grad_etc(self):
|
| 346 |
+
self.requires_grad_(True)
|
| 347 |
+
|
| 348 |
+
def on_epoch_start(self):
|
| 349 |
+
self.train()
|
| 350 |
+
|
| 351 |
+
def get_trainable_params(self):
|
| 352 |
+
return self.parameters()
|
| 353 |
+
|
| 354 |
+
def save_weights(self, file, dtype, metadata):
|
| 355 |
+
if metadata is not None and len(metadata) == 0:
|
| 356 |
+
metadata = None
|
| 357 |
+
|
| 358 |
+
state_dict = self.state_dict()
|
| 359 |
+
|
| 360 |
+
if dtype is not None:
|
| 361 |
+
for key in list(state_dict.keys()):
|
| 362 |
+
v = state_dict[key]
|
| 363 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
| 364 |
+
state_dict[key] = v
|
| 365 |
+
|
| 366 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 367 |
+
from safetensors.torch import save_file
|
| 368 |
+
|
| 369 |
+
save_file(state_dict, file, metadata)
|
| 370 |
+
else:
|
| 371 |
+
torch.save(state_dict, file)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
if __name__ == "__main__":
|
| 375 |
+
# デバッグ用 / for debug
|
| 376 |
+
|
| 377 |
+
# sdxl_original_unet.USE_REENTRANT = False
|
| 378 |
+
|
| 379 |
+
# test shape etc
|
| 380 |
+
logger.info("create unet")
|
| 381 |
+
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
| 382 |
+
unet.to("cuda").to(torch.float16)
|
| 383 |
+
|
| 384 |
+
logger.info("create ControlNet-LLLite")
|
| 385 |
+
control_net = ControlNetLLLite(unet, 32, 64)
|
| 386 |
+
control_net.apply_to()
|
| 387 |
+
control_net.to("cuda")
|
| 388 |
+
|
| 389 |
+
logger.info(control_net)
|
| 390 |
+
|
| 391 |
+
# logger.info number of parameters
|
| 392 |
+
logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
|
| 393 |
+
|
| 394 |
+
input()
|
| 395 |
+
|
| 396 |
+
unet.set_use_memory_efficient_attention(True, False)
|
| 397 |
+
unet.set_gradient_checkpointing(True)
|
| 398 |
+
unet.train() # for gradient checkpointing
|
| 399 |
+
|
| 400 |
+
control_net.train()
|
| 401 |
+
|
| 402 |
+
# # visualize
|
| 403 |
+
# import torchviz
|
| 404 |
+
# logger.info("run visualize")
|
| 405 |
+
# controlnet.set_control(conditioning_image)
|
| 406 |
+
# output = unet(x, t, ctx, y)
|
| 407 |
+
# logger.info("make_dot")
|
| 408 |
+
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
| 409 |
+
# logger.info("render")
|
| 410 |
+
# image.format = "svg" # "png"
|
| 411 |
+
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
| 412 |
+
# input()
|
| 413 |
+
|
| 414 |
+
import bitsandbytes
|
| 415 |
+
|
| 416 |
+
optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
|
| 417 |
+
|
| 418 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
| 419 |
+
|
| 420 |
+
logger.info("start training")
|
| 421 |
+
steps = 10
|
| 422 |
+
|
| 423 |
+
sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
|
| 424 |
+
for step in range(steps):
|
| 425 |
+
logger.info(f"step {step}")
|
| 426 |
+
|
| 427 |
+
batch_size = 1
|
| 428 |
+
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
| 429 |
+
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
| 430 |
+
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
|
| 431 |
+
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
| 432 |
+
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
| 433 |
+
|
| 434 |
+
with torch.cuda.amp.autocast(enabled=True):
|
| 435 |
+
control_net.set_cond_image(conditioning_image)
|
| 436 |
+
|
| 437 |
+
output = unet(x, t, ctx, y)
|
| 438 |
+
target = torch.randn_like(output)
|
| 439 |
+
loss = torch.nn.functional.mse_loss(output, target)
|
| 440 |
+
|
| 441 |
+
scaler.scale(loss).backward()
|
| 442 |
+
scaler.step(optimizer)
|
| 443 |
+
scaler.update()
|
| 444 |
+
optimizer.zero_grad(set_to_none=True)
|
| 445 |
+
logger.info(f"{sample_param}")
|
| 446 |
+
|
| 447 |
+
# from safetensors.torch import save_file
|
| 448 |
+
|
| 449 |
+
# save_file(control_net.state_dict(), "logs/control_net.safetensors")
|
control_net_lllite_for_train.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
|
| 2 |
+
# ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
from typing import Optional, List, Type
|
| 7 |
+
import torch
|
| 8 |
+
from library import sdxl_original_unet
|
| 9 |
+
from library.utils import setup_logging
|
| 10 |
+
|
| 11 |
+
setup_logging()
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# input_blocksに適用するかどうか / if True, input_blocks are not applied
|
| 17 |
+
SKIP_INPUT_BLOCKS = False
|
| 18 |
+
|
| 19 |
+
# output_blocksに適用するかどうか / if True, output_blocks are not applied
|
| 20 |
+
SKIP_OUTPUT_BLOCKS = True
|
| 21 |
+
|
| 22 |
+
# conv2dに適用するかどうか / if True, conv2d are not applied
|
| 23 |
+
SKIP_CONV2D = False
|
| 24 |
+
|
| 25 |
+
# transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
|
| 26 |
+
# if True, only transformer_blocks are applied, and ResBlocks are not applied
|
| 27 |
+
TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
|
| 28 |
+
|
| 29 |
+
# Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
|
| 30 |
+
ATTN1_2_ONLY = True
|
| 31 |
+
|
| 32 |
+
# Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
|
| 33 |
+
ATTN_QKV_ONLY = True
|
| 34 |
+
|
| 35 |
+
# Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
|
| 36 |
+
# ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
|
| 37 |
+
ATTN1_ETC_ONLY = False # True
|
| 38 |
+
|
| 39 |
+
# transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
|
| 40 |
+
# max index of transformer_blocks. if None, apply to all transformer_blocks
|
| 41 |
+
TRANSFORMER_MAX_BLOCK_INDEX = None
|
| 42 |
+
|
| 43 |
+
ORIGINAL_LINEAR = torch.nn.Linear
|
| 44 |
+
ORIGINAL_CONV2D = torch.nn.Conv2d
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
|
| 48 |
+
# conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
|
| 49 |
+
# conditioning1 embeds conditioning image. it is not called for each timestep
|
| 50 |
+
modules = []
|
| 51 |
+
modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
|
| 52 |
+
if depth == 1:
|
| 53 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 54 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
| 55 |
+
elif depth == 2:
|
| 56 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 57 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
| 58 |
+
elif depth == 3:
|
| 59 |
+
# kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
|
| 60 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 61 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
| 62 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 63 |
+
modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
| 64 |
+
|
| 65 |
+
module.lllite_conditioning1 = torch.nn.Sequential(*modules)
|
| 66 |
+
|
| 67 |
+
# downで入力の次元数を削減する。LoRAにヒントを得ていることにする
|
| 68 |
+
# midでconditioning image embeddingと入力を結合する
|
| 69 |
+
# upで元の次元数に戻す
|
| 70 |
+
# これらはtimestepごとに呼ばれる
|
| 71 |
+
# reduce the number of input dimensions with down. inspired by LoRA
|
| 72 |
+
# combine conditioning image embedding and input with mid
|
| 73 |
+
# restore to the original dimension with up
|
| 74 |
+
# these are called for each timestep
|
| 75 |
+
|
| 76 |
+
module.lllite_down = torch.nn.Sequential(
|
| 77 |
+
ORIGINAL_LINEAR(in_dim, mlp_dim),
|
| 78 |
+
torch.nn.ReLU(inplace=True),
|
| 79 |
+
)
|
| 80 |
+
module.lllite_mid = torch.nn.Sequential(
|
| 81 |
+
ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
|
| 82 |
+
torch.nn.ReLU(inplace=True),
|
| 83 |
+
)
|
| 84 |
+
module.lllite_up = torch.nn.Sequential(
|
| 85 |
+
ORIGINAL_LINEAR(mlp_dim, in_dim),
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Zero-Convにする / set to Zero-Conv
|
| 89 |
+
torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class LLLiteLinear(ORIGINAL_LINEAR):
|
| 93 |
+
def __init__(self, in_features: int, out_features: int, **kwargs):
|
| 94 |
+
super().__init__(in_features, out_features, **kwargs)
|
| 95 |
+
self.enabled = False
|
| 96 |
+
|
| 97 |
+
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
|
| 98 |
+
self.enabled = True
|
| 99 |
+
self.lllite_name = name
|
| 100 |
+
self.cond_emb_dim = cond_emb_dim
|
| 101 |
+
self.dropout = dropout
|
| 102 |
+
self.multiplier = multiplier # ignored
|
| 103 |
+
|
| 104 |
+
in_dim = self.in_features
|
| 105 |
+
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
| 106 |
+
|
| 107 |
+
self.cond_image = None
|
| 108 |
+
|
| 109 |
+
def set_cond_image(self, cond_image):
|
| 110 |
+
self.cond_image = cond_image
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
if not self.enabled:
|
| 114 |
+
return super().forward(x)
|
| 115 |
+
|
| 116 |
+
cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
|
| 117 |
+
|
| 118 |
+
# reshape / b,c,h,w -> b,h*w,c
|
| 119 |
+
n, c, h, w = cx.shape
|
| 120 |
+
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
| 121 |
+
|
| 122 |
+
cx = torch.cat([cx, self.lllite_down(x)], dim=2)
|
| 123 |
+
cx = self.lllite_mid(cx)
|
| 124 |
+
|
| 125 |
+
if self.dropout is not None and self.training:
|
| 126 |
+
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
| 127 |
+
|
| 128 |
+
cx = self.lllite_up(cx) * self.multiplier
|
| 129 |
+
|
| 130 |
+
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
| 131 |
+
return x
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class LLLiteConv2d(ORIGINAL_CONV2D):
|
| 135 |
+
def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
|
| 136 |
+
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
|
| 137 |
+
self.enabled = False
|
| 138 |
+
|
| 139 |
+
def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
|
| 140 |
+
self.enabled = True
|
| 141 |
+
self.lllite_name = name
|
| 142 |
+
self.cond_emb_dim = cond_emb_dim
|
| 143 |
+
self.dropout = dropout
|
| 144 |
+
self.multiplier = multiplier # ignored
|
| 145 |
+
|
| 146 |
+
in_dim = self.in_channels
|
| 147 |
+
add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
|
| 148 |
+
|
| 149 |
+
self.cond_image = None
|
| 150 |
+
self.cond_emb = None
|
| 151 |
+
|
| 152 |
+
def set_cond_image(self, cond_image):
|
| 153 |
+
self.cond_image = cond_image
|
| 154 |
+
self.cond_emb = None
|
| 155 |
+
|
| 156 |
+
def forward(self, x): # , cond_image=None):
|
| 157 |
+
if not self.enabled:
|
| 158 |
+
return super().forward(x)
|
| 159 |
+
|
| 160 |
+
cx = self.lllite_conditioning1(self.cond_image)
|
| 161 |
+
|
| 162 |
+
cx = torch.cat([cx, self.down(x)], dim=1)
|
| 163 |
+
cx = self.mid(cx)
|
| 164 |
+
|
| 165 |
+
if self.dropout is not None and self.training:
|
| 166 |
+
cx = torch.nn.functional.dropout(cx, p=self.dropout)
|
| 167 |
+
|
| 168 |
+
cx = self.up(cx) * self.multiplier
|
| 169 |
+
|
| 170 |
+
x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
|
| 171 |
+
return x
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
|
| 175 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
| 176 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 177 |
+
LLLITE_PREFIX = "lllite_unet"
|
| 178 |
+
|
| 179 |
+
def __init__(self, **kwargs):
|
| 180 |
+
super().__init__(**kwargs)
|
| 181 |
+
|
| 182 |
+
def apply_lllite(
|
| 183 |
+
self,
|
| 184 |
+
cond_emb_dim: int = 16,
|
| 185 |
+
mlp_dim: int = 16,
|
| 186 |
+
dropout: Optional[float] = None,
|
| 187 |
+
varbose: Optional[bool] = False,
|
| 188 |
+
multiplier: Optional[float] = 1.0,
|
| 189 |
+
) -> None:
|
| 190 |
+
def apply_to_modules(
|
| 191 |
+
root_module: torch.nn.Module,
|
| 192 |
+
target_replace_modules: List[torch.nn.Module],
|
| 193 |
+
) -> List[torch.nn.Module]:
|
| 194 |
+
prefix = "lllite_unet"
|
| 195 |
+
|
| 196 |
+
modules = []
|
| 197 |
+
for name, module in root_module.named_modules():
|
| 198 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 199 |
+
for child_name, child_module in module.named_modules():
|
| 200 |
+
is_linear = child_module.__class__.__name__ == "LLLiteLinear"
|
| 201 |
+
is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
|
| 202 |
+
|
| 203 |
+
if is_linear or (is_conv2d and not SKIP_CONV2D):
|
| 204 |
+
# block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
|
| 205 |
+
# block index to depth: depth is using to calculate conditioning size and channels
|
| 206 |
+
block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
|
| 207 |
+
index1 = int(index1)
|
| 208 |
+
if block_name == "input_blocks":
|
| 209 |
+
if SKIP_INPUT_BLOCKS:
|
| 210 |
+
continue
|
| 211 |
+
depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
|
| 212 |
+
elif block_name == "middle_block":
|
| 213 |
+
depth = 3
|
| 214 |
+
elif block_name == "output_blocks":
|
| 215 |
+
if SKIP_OUTPUT_BLOCKS:
|
| 216 |
+
continue
|
| 217 |
+
depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
|
| 218 |
+
if int(index2) >= 2:
|
| 219 |
+
depth -= 1
|
| 220 |
+
else:
|
| 221 |
+
raise NotImplementedError()
|
| 222 |
+
|
| 223 |
+
lllite_name = prefix + "." + name + "." + child_name
|
| 224 |
+
lllite_name = lllite_name.replace(".", "_")
|
| 225 |
+
|
| 226 |
+
if TRANSFORMER_MAX_BLOCK_INDEX is not None:
|
| 227 |
+
p = lllite_name.find("transformer_blocks")
|
| 228 |
+
if p >= 0:
|
| 229 |
+
tf_index = int(lllite_name[p:].split("_")[2])
|
| 230 |
+
if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
|
| 231 |
+
continue
|
| 232 |
+
|
| 233 |
+
# time embは適用外とする
|
| 234 |
+
# attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
|
| 235 |
+
# time emb is not applied
|
| 236 |
+
# attn2 conditioning (input from CLIP) cannot be applied because the shape is different
|
| 237 |
+
if "emb_layers" in lllite_name or (
|
| 238 |
+
"attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
|
| 239 |
+
):
|
| 240 |
+
continue
|
| 241 |
+
|
| 242 |
+
if ATTN1_2_ONLY:
|
| 243 |
+
if not ("attn1" in lllite_name or "attn2" in lllite_name):
|
| 244 |
+
continue
|
| 245 |
+
if ATTN_QKV_ONLY:
|
| 246 |
+
if "to_out" in lllite_name:
|
| 247 |
+
continue
|
| 248 |
+
|
| 249 |
+
if ATTN1_ETC_ONLY:
|
| 250 |
+
if "proj_out" in lllite_name:
|
| 251 |
+
pass
|
| 252 |
+
elif "attn1" in lllite_name and (
|
| 253 |
+
"to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
|
| 254 |
+
):
|
| 255 |
+
pass
|
| 256 |
+
elif "ff_net_2" in lllite_name:
|
| 257 |
+
pass
|
| 258 |
+
else:
|
| 259 |
+
continue
|
| 260 |
+
|
| 261 |
+
child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
|
| 262 |
+
modules.append(child_module)
|
| 263 |
+
|
| 264 |
+
return modules
|
| 265 |
+
|
| 266 |
+
target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
|
| 267 |
+
if not TRANSFORMER_ONLY:
|
| 268 |
+
target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 269 |
+
|
| 270 |
+
# create module instances
|
| 271 |
+
self.lllite_modules = apply_to_modules(self, target_modules)
|
| 272 |
+
logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
|
| 273 |
+
|
| 274 |
+
# def prepare_optimizer_params(self):
|
| 275 |
+
def prepare_params(self):
|
| 276 |
+
train_params = []
|
| 277 |
+
non_train_params = []
|
| 278 |
+
for name, p in self.named_parameters():
|
| 279 |
+
if "lllite" in name:
|
| 280 |
+
train_params.append(p)
|
| 281 |
+
else:
|
| 282 |
+
non_train_params.append(p)
|
| 283 |
+
logger.info(f"count of trainable parameters: {len(train_params)}")
|
| 284 |
+
logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
|
| 285 |
+
|
| 286 |
+
for p in non_train_params:
|
| 287 |
+
p.requires_grad_(False)
|
| 288 |
+
|
| 289 |
+
# without this, an error occurs in the optimizer
|
| 290 |
+
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
|
| 291 |
+
non_train_params[0].requires_grad_(True)
|
| 292 |
+
|
| 293 |
+
for p in train_params:
|
| 294 |
+
p.requires_grad_(True)
|
| 295 |
+
|
| 296 |
+
return train_params
|
| 297 |
+
|
| 298 |
+
# def prepare_grad_etc(self):
|
| 299 |
+
# self.requires_grad_(True)
|
| 300 |
+
|
| 301 |
+
# def on_epoch_start(self):
|
| 302 |
+
# self.train()
|
| 303 |
+
|
| 304 |
+
def get_trainable_params(self):
|
| 305 |
+
return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
|
| 306 |
+
|
| 307 |
+
def save_lllite_weights(self, file, dtype, metadata):
|
| 308 |
+
if metadata is not None and len(metadata) == 0:
|
| 309 |
+
metadata = None
|
| 310 |
+
|
| 311 |
+
org_state_dict = self.state_dict()
|
| 312 |
+
|
| 313 |
+
# copy LLLite keys from org_state_dict to state_dict with key conversion
|
| 314 |
+
state_dict = {}
|
| 315 |
+
for key in org_state_dict.keys():
|
| 316 |
+
# split with ".lllite"
|
| 317 |
+
pos = key.find(".lllite")
|
| 318 |
+
if pos < 0:
|
| 319 |
+
continue
|
| 320 |
+
lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
|
| 321 |
+
lllite_key = lllite_key.replace(".", "_") + key[pos:]
|
| 322 |
+
lllite_key = lllite_key.replace(".lllite_", ".")
|
| 323 |
+
state_dict[lllite_key] = org_state_dict[key]
|
| 324 |
+
|
| 325 |
+
if dtype is not None:
|
| 326 |
+
for key in list(state_dict.keys()):
|
| 327 |
+
v = state_dict[key]
|
| 328 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
| 329 |
+
state_dict[key] = v
|
| 330 |
+
|
| 331 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 332 |
+
from safetensors.torch import save_file
|
| 333 |
+
|
| 334 |
+
save_file(state_dict, file, metadata)
|
| 335 |
+
else:
|
| 336 |
+
torch.save(state_dict, file)
|
| 337 |
+
|
| 338 |
+
def load_lllite_weights(self, file, non_lllite_unet_sd=None):
|
| 339 |
+
r"""
|
| 340 |
+
LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。
|
| 341 |
+
この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
|
| 342 |
+
|
| 343 |
+
If you do not want to load LLLite weights (use initialized values), specify None for file.
|
| 344 |
+
In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
|
| 345 |
+
"""
|
| 346 |
+
if not file:
|
| 347 |
+
state_dict = self.state_dict()
|
| 348 |
+
for key in non_lllite_unet_sd:
|
| 349 |
+
if key in state_dict:
|
| 350 |
+
state_dict[key] = non_lllite_unet_sd[key]
|
| 351 |
+
info = self.load_state_dict(state_dict, False)
|
| 352 |
+
return info
|
| 353 |
+
|
| 354 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 355 |
+
from safetensors.torch import load_file
|
| 356 |
+
|
| 357 |
+
weights_sd = load_file(file)
|
| 358 |
+
else:
|
| 359 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 360 |
+
|
| 361 |
+
# module_name = module_name.replace("_block", "@blocks")
|
| 362 |
+
# module_name = module_name.replace("_layer", "@layer")
|
| 363 |
+
# module_name = module_name.replace("to_", "to@")
|
| 364 |
+
# module_name = module_name.replace("time_embed", "time@embed")
|
| 365 |
+
# module_name = module_name.replace("label_emb", "label@emb")
|
| 366 |
+
# module_name = module_name.replace("skip_connection", "skip@connection")
|
| 367 |
+
# module_name = module_name.replace("proj_in", "proj@in")
|
| 368 |
+
# module_name = module_name.replace("proj_out", "proj@out")
|
| 369 |
+
pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
|
| 370 |
+
|
| 371 |
+
# convert to lllite with U-Net state dict
|
| 372 |
+
state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
|
| 373 |
+
for key in weights_sd.keys():
|
| 374 |
+
# split with "."
|
| 375 |
+
pos = key.find(".")
|
| 376 |
+
if pos < 0:
|
| 377 |
+
continue
|
| 378 |
+
|
| 379 |
+
module_name = key[:pos]
|
| 380 |
+
weight_name = key[pos + 1 :] # exclude "."
|
| 381 |
+
module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
|
| 382 |
+
|
| 383 |
+
# これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
|
| 384 |
+
# module_name = module_name.replace("_", ".")
|
| 385 |
+
|
| 386 |
+
# ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
|
| 387 |
+
matches = pattern.findall(module_name)
|
| 388 |
+
if matches is not None:
|
| 389 |
+
for m in matches:
|
| 390 |
+
logger.info(f"{module_name} {m}")
|
| 391 |
+
module_name = module_name.replace(m, m.replace("_", "@"))
|
| 392 |
+
module_name = module_name.replace("_", ".")
|
| 393 |
+
module_name = module_name.replace("@", "_")
|
| 394 |
+
|
| 395 |
+
lllite_key = module_name + ".lllite_" + weight_name
|
| 396 |
+
|
| 397 |
+
state_dict[lllite_key] = weights_sd[key]
|
| 398 |
+
|
| 399 |
+
info = self.load_state_dict(state_dict, False)
|
| 400 |
+
return info
|
| 401 |
+
|
| 402 |
+
def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
|
| 403 |
+
for m in self.lllite_modules:
|
| 404 |
+
m.set_cond_image(cond_image)
|
| 405 |
+
return super().forward(x, timesteps, context, y, **kwargs)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def replace_unet_linear_and_conv2d():
|
| 409 |
+
logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
|
| 410 |
+
sdxl_original_unet.torch.nn.Linear = LLLiteLinear
|
| 411 |
+
sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
if __name__ == "__main__":
|
| 415 |
+
# デバッグ用 / for debug
|
| 416 |
+
|
| 417 |
+
# sdxl_original_unet.USE_REENTRANT = False
|
| 418 |
+
replace_unet_linear_and_conv2d()
|
| 419 |
+
|
| 420 |
+
# test shape etc
|
| 421 |
+
logger.info("create unet")
|
| 422 |
+
unet = SdxlUNet2DConditionModelControlNetLLLite()
|
| 423 |
+
|
| 424 |
+
logger.info("enable ControlNet-LLLite")
|
| 425 |
+
unet.apply_lllite(32, 64, None, False, 1.0)
|
| 426 |
+
unet.to("cuda") # .to(torch.float16)
|
| 427 |
+
|
| 428 |
+
# from safetensors.torch import load_file
|
| 429 |
+
|
| 430 |
+
# model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
|
| 431 |
+
# unet_sd = {}
|
| 432 |
+
|
| 433 |
+
# # copy U-Net keys from unet_state_dict to state_dict
|
| 434 |
+
# prefix = "model.diffusion_model."
|
| 435 |
+
# for key in model_sd.keys():
|
| 436 |
+
# if key.startswith(prefix):
|
| 437 |
+
# converted_key = key[len(prefix) :]
|
| 438 |
+
# unet_sd[converted_key] = model_sd[key]
|
| 439 |
+
|
| 440 |
+
# info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
|
| 441 |
+
# logger.info(info)
|
| 442 |
+
|
| 443 |
+
# logger.info(unet)
|
| 444 |
+
|
| 445 |
+
# logger.info number of parameters
|
| 446 |
+
params = unet.prepare_params()
|
| 447 |
+
logger.info(f"number of parameters {sum(p.numel() for p in params)}")
|
| 448 |
+
# logger.info("type any key to continue")
|
| 449 |
+
# input()
|
| 450 |
+
|
| 451 |
+
unet.set_use_memory_efficient_attention(True, False)
|
| 452 |
+
unet.set_gradient_checkpointing(True)
|
| 453 |
+
unet.train() # for gradient checkpointing
|
| 454 |
+
|
| 455 |
+
# # visualize
|
| 456 |
+
# import torchviz
|
| 457 |
+
# logger.info("run visualize")
|
| 458 |
+
# controlnet.set_control(conditioning_image)
|
| 459 |
+
# output = unet(x, t, ctx, y)
|
| 460 |
+
# logger.info("make_dot")
|
| 461 |
+
# image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
|
| 462 |
+
# logger.info("render")
|
| 463 |
+
# image.format = "svg" # "png"
|
| 464 |
+
# image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
|
| 465 |
+
# input()
|
| 466 |
+
|
| 467 |
+
import bitsandbytes
|
| 468 |
+
|
| 469 |
+
optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
|
| 470 |
+
|
| 471 |
+
scaler = torch.cuda.amp.GradScaler(enabled=True)
|
| 472 |
+
|
| 473 |
+
logger.info("start training")
|
| 474 |
+
steps = 10
|
| 475 |
+
batch_size = 1
|
| 476 |
+
|
| 477 |
+
sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
|
| 478 |
+
for step in range(steps):
|
| 479 |
+
logger.info(f"step {step}")
|
| 480 |
+
|
| 481 |
+
conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
|
| 482 |
+
x = torch.randn(batch_size, 4, 128, 128).cuda()
|
| 483 |
+
t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
|
| 484 |
+
ctx = torch.randn(batch_size, 77, 2048).cuda()
|
| 485 |
+
y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
|
| 486 |
+
|
| 487 |
+
with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
|
| 488 |
+
output = unet(x, t, ctx, y, conditioning_image)
|
| 489 |
+
target = torch.randn_like(output)
|
| 490 |
+
loss = torch.nn.functional.mse_loss(output, target)
|
| 491 |
+
|
| 492 |
+
scaler.scale(loss).backward()
|
| 493 |
+
scaler.step(optimizer)
|
| 494 |
+
scaler.update()
|
| 495 |
+
optimizer.zero_grad(set_to_none=True)
|
| 496 |
+
logger.info(sample_param)
|
| 497 |
+
|
| 498 |
+
# from safetensors.torch import save_file
|
| 499 |
+
|
| 500 |
+
# logger.info("save weights")
|
| 501 |
+
# unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
|
convert_diffusers20_original_sd.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# convert Diffusers v1.x/v2.0 model to original Stable Diffusion
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers import StableDiffusionPipeline
|
| 7 |
+
|
| 8 |
+
import library.model_util as model_util
|
| 9 |
+
from library.utils import setup_logging
|
| 10 |
+
setup_logging()
|
| 11 |
+
import logging
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
def convert(args):
|
| 15 |
+
# 引数を確認する
|
| 16 |
+
load_dtype = torch.float16 if args.fp16 else None
|
| 17 |
+
|
| 18 |
+
save_dtype = None
|
| 19 |
+
if args.fp16 or args.save_precision_as == "fp16":
|
| 20 |
+
save_dtype = torch.float16
|
| 21 |
+
elif args.bf16 or args.save_precision_as == "bf16":
|
| 22 |
+
save_dtype = torch.bfloat16
|
| 23 |
+
elif args.float or args.save_precision_as == "float":
|
| 24 |
+
save_dtype = torch.float
|
| 25 |
+
|
| 26 |
+
is_load_ckpt = os.path.isfile(args.model_to_load)
|
| 27 |
+
is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
|
| 28 |
+
|
| 29 |
+
assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
|
| 30 |
+
# assert (
|
| 31 |
+
# is_save_ckpt or args.reference_model is not None
|
| 32 |
+
# ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
|
| 33 |
+
|
| 34 |
+
# モデルを読み込む
|
| 35 |
+
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
| 36 |
+
logger.info(f"loading {msg}: {args.model_to_load}")
|
| 37 |
+
|
| 38 |
+
if is_load_ckpt:
|
| 39 |
+
v2_model = args.v2
|
| 40 |
+
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
|
| 41 |
+
v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
|
| 42 |
+
)
|
| 43 |
+
else:
|
| 44 |
+
pipe = StableDiffusionPipeline.from_pretrained(
|
| 45 |
+
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
|
| 46 |
+
)
|
| 47 |
+
text_encoder = pipe.text_encoder
|
| 48 |
+
vae = pipe.vae
|
| 49 |
+
unet = pipe.unet
|
| 50 |
+
|
| 51 |
+
if args.v1 == args.v2:
|
| 52 |
+
# 自動判定する
|
| 53 |
+
v2_model = unet.config.cross_attention_dim == 1024
|
| 54 |
+
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
|
| 55 |
+
else:
|
| 56 |
+
v2_model = not args.v1
|
| 57 |
+
|
| 58 |
+
# 変換して保存する
|
| 59 |
+
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
| 60 |
+
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
|
| 61 |
+
|
| 62 |
+
if is_save_ckpt:
|
| 63 |
+
original_model = args.model_to_load if is_load_ckpt else None
|
| 64 |
+
key_count = model_util.save_stable_diffusion_checkpoint(
|
| 65 |
+
v2_model,
|
| 66 |
+
args.model_to_save,
|
| 67 |
+
text_encoder,
|
| 68 |
+
unet,
|
| 69 |
+
original_model,
|
| 70 |
+
args.epoch,
|
| 71 |
+
args.global_step,
|
| 72 |
+
None if args.metadata is None else eval(args.metadata),
|
| 73 |
+
save_dtype=save_dtype,
|
| 74 |
+
vae=vae,
|
| 75 |
+
)
|
| 76 |
+
logger.info(f"model saved. total converted state_dict keys: {key_count}")
|
| 77 |
+
else:
|
| 78 |
+
logger.info(
|
| 79 |
+
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
| 80 |
+
)
|
| 81 |
+
model_util.save_diffusers_checkpoint(
|
| 82 |
+
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
| 83 |
+
)
|
| 84 |
+
logger.info("model saved.")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 88 |
+
parser = argparse.ArgumentParser()
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
|
| 91 |
+
)
|
| 92 |
+
parser.add_argument(
|
| 93 |
+
"--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
|
| 94 |
+
)
|
| 95 |
+
parser.add_argument(
|
| 96 |
+
"--unet_use_linear_projection",
|
| 97 |
+
action="store_true",
|
| 98 |
+
help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--fp16",
|
| 102 |
+
action="store_true",
|
| 103 |
+
help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)")
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)"
|
| 108 |
+
)
|
| 109 |
+
parser.add_argument(
|
| 110 |
+
"--save_precision_as",
|
| 111 |
+
type=str,
|
| 112 |
+
default="no",
|
| 113 |
+
choices=["fp16", "bf16", "float"],
|
| 114 |
+
help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでくださ���",
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--metadata",
|
| 122 |
+
type=str,
|
| 123 |
+
default=None,
|
| 124 |
+
help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
"--variant",
|
| 128 |
+
type=str,
|
| 129 |
+
default=None,
|
| 130 |
+
help="読む込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
|
| 131 |
+
)
|
| 132 |
+
parser.add_argument(
|
| 133 |
+
"--reference_model",
|
| 134 |
+
type=str,
|
| 135 |
+
default=None,
|
| 136 |
+
help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
|
| 137 |
+
)
|
| 138 |
+
parser.add_argument(
|
| 139 |
+
"--use_safetensors",
|
| 140 |
+
action="store_true",
|
| 141 |
+
help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)",
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
parser.add_argument(
|
| 145 |
+
"model_to_load",
|
| 146 |
+
type=str,
|
| 147 |
+
default=None,
|
| 148 |
+
help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"model_to_save",
|
| 152 |
+
type=str,
|
| 153 |
+
default=None,
|
| 154 |
+
help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusesモデルとして保存",
|
| 155 |
+
)
|
| 156 |
+
return parser
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
parser = setup_parser()
|
| 161 |
+
|
| 162 |
+
args = parser.parse_args()
|
| 163 |
+
convert(args)
|
custom_train_functions.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
from .utils import setup_logging
|
| 7 |
+
|
| 8 |
+
setup_logging()
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
| 15 |
+
if hasattr(noise_scheduler, "all_snr"):
|
| 16 |
+
return
|
| 17 |
+
|
| 18 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
| 19 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
| 20 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
| 21 |
+
alpha = sqrt_alphas_cumprod
|
| 22 |
+
sigma = sqrt_one_minus_alphas_cumprod
|
| 23 |
+
all_snr = (alpha / sigma) ** 2
|
| 24 |
+
|
| 25 |
+
noise_scheduler.all_snr = all_snr.to(device)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
| 29 |
+
# fix beta: zero terminal SNR
|
| 30 |
+
logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
| 31 |
+
|
| 32 |
+
def enforce_zero_terminal_snr(betas):
|
| 33 |
+
# Convert betas to alphas_bar_sqrt
|
| 34 |
+
alphas = 1 - betas
|
| 35 |
+
alphas_bar = alphas.cumprod(0)
|
| 36 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
| 37 |
+
|
| 38 |
+
# Store old values.
|
| 39 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
| 40 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
| 41 |
+
# Shift so last timestep is zero.
|
| 42 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
| 43 |
+
# Scale so first timestep is back to old value.
|
| 44 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
| 45 |
+
|
| 46 |
+
# Convert alphas_bar_sqrt to betas
|
| 47 |
+
alphas_bar = alphas_bar_sqrt**2
|
| 48 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
| 49 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
| 50 |
+
betas = 1 - alphas
|
| 51 |
+
return betas
|
| 52 |
+
|
| 53 |
+
betas = noise_scheduler.betas
|
| 54 |
+
betas = enforce_zero_terminal_snr(betas)
|
| 55 |
+
alphas = 1.0 - betas
|
| 56 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 57 |
+
|
| 58 |
+
# logger.info(f"original: {noise_scheduler.betas}")
|
| 59 |
+
# logger.info(f"fixed: {betas}")
|
| 60 |
+
|
| 61 |
+
noise_scheduler.betas = betas
|
| 62 |
+
noise_scheduler.alphas = alphas
|
| 63 |
+
noise_scheduler.alphas_cumprod = alphas_cumprod
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
|
| 67 |
+
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
| 68 |
+
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
|
| 69 |
+
if v_prediction:
|
| 70 |
+
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
|
| 71 |
+
else:
|
| 72 |
+
snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
|
| 73 |
+
loss = loss * snr_weight
|
| 74 |
+
return loss
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
| 78 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
| 79 |
+
loss = loss * scale
|
| 80 |
+
return loss
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def get_snr_scale(timesteps, noise_scheduler):
|
| 84 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
| 85 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
| 86 |
+
scale = snr_t / (snr_t + 1)
|
| 87 |
+
# # show debug info
|
| 88 |
+
# logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
| 89 |
+
return scale
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
| 93 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
| 94 |
+
# logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
| 95 |
+
loss = loss + loss / scale * v_pred_like_loss
|
| 96 |
+
return loss
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
|
| 100 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
| 101 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
| 102 |
+
if v_prediction:
|
| 103 |
+
weight = 1 / (snr_t + 1)
|
| 104 |
+
else:
|
| 105 |
+
weight = 1 / torch.sqrt(snr_t)
|
| 106 |
+
loss = weight * loss
|
| 107 |
+
return loss
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# TODO train_utilと分散しているのでどちらかに寄せる
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--min_snr_gamma",
|
| 116 |
+
type=float,
|
| 117 |
+
default=None,
|
| 118 |
+
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
| 119 |
+
)
|
| 120 |
+
parser.add_argument(
|
| 121 |
+
"--scale_v_pred_loss_like_noise_pred",
|
| 122 |
+
action="store_true",
|
| 123 |
+
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
| 124 |
+
)
|
| 125 |
+
parser.add_argument(
|
| 126 |
+
"--v_pred_like_loss",
|
| 127 |
+
type=float,
|
| 128 |
+
default=None,
|
| 129 |
+
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけ��ものをlossに加算する",
|
| 130 |
+
)
|
| 131 |
+
parser.add_argument(
|
| 132 |
+
"--debiased_estimation_loss",
|
| 133 |
+
action="store_true",
|
| 134 |
+
help="debiased estimation loss / debiased estimation loss",
|
| 135 |
+
)
|
| 136 |
+
if support_weighted_captions:
|
| 137 |
+
parser.add_argument(
|
| 138 |
+
"--weighted_captions",
|
| 139 |
+
action="store_true",
|
| 140 |
+
default=False,
|
| 141 |
+
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
re_attention = re.compile(
|
| 146 |
+
r"""
|
| 147 |
+
\\\(|
|
| 148 |
+
\\\)|
|
| 149 |
+
\\\[|
|
| 150 |
+
\\]|
|
| 151 |
+
\\\\|
|
| 152 |
+
\\|
|
| 153 |
+
\(|
|
| 154 |
+
\[|
|
| 155 |
+
:([+-]?[.\d]+)\)|
|
| 156 |
+
\)|
|
| 157 |
+
]|
|
| 158 |
+
[^\\()\[\]:]+|
|
| 159 |
+
:
|
| 160 |
+
""",
|
| 161 |
+
re.X,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def parse_prompt_attention(text):
|
| 166 |
+
"""
|
| 167 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
| 168 |
+
Accepted tokens are:
|
| 169 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
| 170 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
| 171 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
| 172 |
+
\( - literal character '('
|
| 173 |
+
\[ - literal character '['
|
| 174 |
+
\) - literal character ')'
|
| 175 |
+
\] - literal character ']'
|
| 176 |
+
\\ - literal character '\'
|
| 177 |
+
anything else - just text
|
| 178 |
+
>>> parse_prompt_attention('normal text')
|
| 179 |
+
[['normal text', 1.0]]
|
| 180 |
+
>>> parse_prompt_attention('an (important) word')
|
| 181 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
| 182 |
+
>>> parse_prompt_attention('(unbalanced')
|
| 183 |
+
[['unbalanced', 1.1]]
|
| 184 |
+
>>> parse_prompt_attention('\(literal\]')
|
| 185 |
+
[['(literal]', 1.0]]
|
| 186 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
| 187 |
+
[['unnecessaryparens', 1.1]]
|
| 188 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
| 189 |
+
[['a ', 1.0],
|
| 190 |
+
['house', 1.5730000000000004],
|
| 191 |
+
[' ', 1.1],
|
| 192 |
+
['on', 1.0],
|
| 193 |
+
[' a ', 1.1],
|
| 194 |
+
['hill', 0.55],
|
| 195 |
+
[', sun, ', 1.1],
|
| 196 |
+
['sky', 1.4641000000000006],
|
| 197 |
+
['.', 1.1]]
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
res = []
|
| 201 |
+
round_brackets = []
|
| 202 |
+
square_brackets = []
|
| 203 |
+
|
| 204 |
+
round_bracket_multiplier = 1.1
|
| 205 |
+
square_bracket_multiplier = 1 / 1.1
|
| 206 |
+
|
| 207 |
+
def multiply_range(start_position, multiplier):
|
| 208 |
+
for p in range(start_position, len(res)):
|
| 209 |
+
res[p][1] *= multiplier
|
| 210 |
+
|
| 211 |
+
for m in re_attention.finditer(text):
|
| 212 |
+
text = m.group(0)
|
| 213 |
+
weight = m.group(1)
|
| 214 |
+
|
| 215 |
+
if text.startswith("\\"):
|
| 216 |
+
res.append([text[1:], 1.0])
|
| 217 |
+
elif text == "(":
|
| 218 |
+
round_brackets.append(len(res))
|
| 219 |
+
elif text == "[":
|
| 220 |
+
square_brackets.append(len(res))
|
| 221 |
+
elif weight is not None and len(round_brackets) > 0:
|
| 222 |
+
multiply_range(round_brackets.pop(), float(weight))
|
| 223 |
+
elif text == ")" and len(round_brackets) > 0:
|
| 224 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 225 |
+
elif text == "]" and len(square_brackets) > 0:
|
| 226 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 227 |
+
else:
|
| 228 |
+
res.append([text, 1.0])
|
| 229 |
+
|
| 230 |
+
for pos in round_brackets:
|
| 231 |
+
multiply_range(pos, round_bracket_multiplier)
|
| 232 |
+
|
| 233 |
+
for pos in square_brackets:
|
| 234 |
+
multiply_range(pos, square_bracket_multiplier)
|
| 235 |
+
|
| 236 |
+
if len(res) == 0:
|
| 237 |
+
res = [["", 1.0]]
|
| 238 |
+
|
| 239 |
+
# merge runs of identical weights
|
| 240 |
+
i = 0
|
| 241 |
+
while i + 1 < len(res):
|
| 242 |
+
if res[i][1] == res[i + 1][1]:
|
| 243 |
+
res[i][0] += res[i + 1][0]
|
| 244 |
+
res.pop(i + 1)
|
| 245 |
+
else:
|
| 246 |
+
i += 1
|
| 247 |
+
|
| 248 |
+
return res
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
| 252 |
+
r"""
|
| 253 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
| 254 |
+
|
| 255 |
+
No padding, starting or ending token is included.
|
| 256 |
+
"""
|
| 257 |
+
tokens = []
|
| 258 |
+
weights = []
|
| 259 |
+
truncated = False
|
| 260 |
+
for text in prompt:
|
| 261 |
+
texts_and_weights = parse_prompt_attention(text)
|
| 262 |
+
text_token = []
|
| 263 |
+
text_weight = []
|
| 264 |
+
for word, weight in texts_and_weights:
|
| 265 |
+
# tokenize and discard the starting and the ending token
|
| 266 |
+
token = tokenizer(word).input_ids[1:-1]
|
| 267 |
+
text_token += token
|
| 268 |
+
# copy the weight by length of token
|
| 269 |
+
text_weight += [weight] * len(token)
|
| 270 |
+
# stop if the text is too long (longer than truncation limit)
|
| 271 |
+
if len(text_token) > max_length:
|
| 272 |
+
truncated = True
|
| 273 |
+
break
|
| 274 |
+
# truncate
|
| 275 |
+
if len(text_token) > max_length:
|
| 276 |
+
truncated = True
|
| 277 |
+
text_token = text_token[:max_length]
|
| 278 |
+
text_weight = text_weight[:max_length]
|
| 279 |
+
tokens.append(text_token)
|
| 280 |
+
weights.append(text_weight)
|
| 281 |
+
if truncated:
|
| 282 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
| 283 |
+
return tokens, weights
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
| 287 |
+
r"""
|
| 288 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
| 289 |
+
"""
|
| 290 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
| 291 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
| 292 |
+
for i in range(len(tokens)):
|
| 293 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
| 294 |
+
if no_boseos_middle:
|
| 295 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
| 296 |
+
else:
|
| 297 |
+
w = []
|
| 298 |
+
if len(weights[i]) == 0:
|
| 299 |
+
w = [1.0] * weights_length
|
| 300 |
+
else:
|
| 301 |
+
for j in range(max_embeddings_multiples):
|
| 302 |
+
w.append(1.0) # weight for starting token in this chunk
|
| 303 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
| 304 |
+
w.append(1.0) # weight for ending token in this chunk
|
| 305 |
+
w += [1.0] * (weights_length - len(w))
|
| 306 |
+
weights[i] = w[:]
|
| 307 |
+
|
| 308 |
+
return tokens, weights
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def get_unweighted_text_embeddings(
|
| 312 |
+
tokenizer,
|
| 313 |
+
text_encoder,
|
| 314 |
+
text_input: torch.Tensor,
|
| 315 |
+
chunk_length: int,
|
| 316 |
+
clip_skip: int,
|
| 317 |
+
eos: int,
|
| 318 |
+
pad: int,
|
| 319 |
+
no_boseos_middle: Optional[bool] = True,
|
| 320 |
+
):
|
| 321 |
+
"""
|
| 322 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
| 323 |
+
it should be split into chunks and sent to the text encoder individually.
|
| 324 |
+
"""
|
| 325 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
| 326 |
+
if max_embeddings_multiples > 1:
|
| 327 |
+
text_embeddings = []
|
| 328 |
+
for i in range(max_embeddings_multiples):
|
| 329 |
+
# extract the i-th chunk
|
| 330 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
| 331 |
+
|
| 332 |
+
# cover the head and the tail by the starting and the ending tokens
|
| 333 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
| 334 |
+
if pad == eos: # v1
|
| 335 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
| 336 |
+
else: # v2
|
| 337 |
+
for j in range(len(text_input_chunk)):
|
| 338 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
| 339 |
+
text_input_chunk[j, -1] = eos
|
| 340 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
| 341 |
+
text_input_chunk[j, 1] = eos
|
| 342 |
+
|
| 343 |
+
if clip_skip is None or clip_skip == 1:
|
| 344 |
+
text_embedding = text_encoder(text_input_chunk)[0]
|
| 345 |
+
else:
|
| 346 |
+
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
| 347 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
| 348 |
+
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
| 349 |
+
|
| 350 |
+
if no_boseos_middle:
|
| 351 |
+
if i == 0:
|
| 352 |
+
# discard the ending token
|
| 353 |
+
text_embedding = text_embedding[:, :-1]
|
| 354 |
+
elif i == max_embeddings_multiples - 1:
|
| 355 |
+
# discard the starting token
|
| 356 |
+
text_embedding = text_embedding[:, 1:]
|
| 357 |
+
else:
|
| 358 |
+
# discard both starting and ending tokens
|
| 359 |
+
text_embedding = text_embedding[:, 1:-1]
|
| 360 |
+
|
| 361 |
+
text_embeddings.append(text_embedding)
|
| 362 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
| 363 |
+
else:
|
| 364 |
+
if clip_skip is None or clip_skip == 1:
|
| 365 |
+
text_embeddings = text_encoder(text_input)[0]
|
| 366 |
+
else:
|
| 367 |
+
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
| 368 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
| 369 |
+
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
| 370 |
+
return text_embeddings
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def get_weighted_text_embeddings(
|
| 374 |
+
tokenizer,
|
| 375 |
+
text_encoder,
|
| 376 |
+
prompt: Union[str, List[str]],
|
| 377 |
+
device,
|
| 378 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 379 |
+
no_boseos_middle: Optional[bool] = False,
|
| 380 |
+
clip_skip=None,
|
| 381 |
+
):
|
| 382 |
+
r"""
|
| 383 |
+
Prompts can be assigned with local weights using brackets. For example,
|
| 384 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
| 385 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
| 386 |
+
|
| 387 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
prompt (`str` or `List[str]`):
|
| 391 |
+
The prompt or prompts to guide the image generation.
|
| 392 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 393 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 394 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
| 395 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
| 396 |
+
ending token in each of the chunk in the middle.
|
| 397 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
| 398 |
+
Skip the parsing of brackets.
|
| 399 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
| 400 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
| 401 |
+
"""
|
| 402 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 403 |
+
if isinstance(prompt, str):
|
| 404 |
+
prompt = [prompt]
|
| 405 |
+
|
| 406 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
| 407 |
+
|
| 408 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
| 409 |
+
max_length = max([len(token) for token in prompt_tokens])
|
| 410 |
+
|
| 411 |
+
max_embeddings_multiples = min(
|
| 412 |
+
max_embeddings_multiples,
|
| 413 |
+
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
| 414 |
+
)
|
| 415 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
| 416 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 417 |
+
|
| 418 |
+
# pad the length of tokens and weights
|
| 419 |
+
bos = tokenizer.bos_token_id
|
| 420 |
+
eos = tokenizer.eos_token_id
|
| 421 |
+
pad = tokenizer.pad_token_id
|
| 422 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
| 423 |
+
prompt_tokens,
|
| 424 |
+
prompt_weights,
|
| 425 |
+
max_length,
|
| 426 |
+
bos,
|
| 427 |
+
eos,
|
| 428 |
+
no_boseos_middle=no_boseos_middle,
|
| 429 |
+
chunk_length=tokenizer.model_max_length,
|
| 430 |
+
)
|
| 431 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
| 432 |
+
|
| 433 |
+
# get the embeddings
|
| 434 |
+
text_embeddings = get_unweighted_text_embeddings(
|
| 435 |
+
tokenizer,
|
| 436 |
+
text_encoder,
|
| 437 |
+
prompt_tokens,
|
| 438 |
+
tokenizer.model_max_length,
|
| 439 |
+
clip_skip,
|
| 440 |
+
eos,
|
| 441 |
+
pad,
|
| 442 |
+
no_boseos_middle=no_boseos_middle,
|
| 443 |
+
)
|
| 444 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
| 445 |
+
|
| 446 |
+
# assign weights to the prompts and normalize in the sense of mean
|
| 447 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 448 |
+
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
| 449 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 450 |
+
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 451 |
+
|
| 452 |
+
return text_embeddings
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
| 456 |
+
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
| 457 |
+
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
| 458 |
+
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
| 459 |
+
for i in range(iterations):
|
| 460 |
+
r = random.random() * 2 + 2 # Rather than always going 2x,
|
| 461 |
+
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
| 462 |
+
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
| 463 |
+
if wn == 1 or hn == 1:
|
| 464 |
+
break # Lowest resolution is 1x1
|
| 465 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
| 469 |
+
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
| 470 |
+
if noise_offset is None:
|
| 471 |
+
return noise
|
| 472 |
+
if adaptive_noise_scale is not None:
|
| 473 |
+
# latent shape: (batch_size, channels, height, width)
|
| 474 |
+
# abs mean value for each channel
|
| 475 |
+
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
| 476 |
+
|
| 477 |
+
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
| 478 |
+
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
| 479 |
+
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
| 480 |
+
|
| 481 |
+
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
| 482 |
+
return noise
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def apply_masked_loss(loss, batch):
|
| 486 |
+
if "conditioning_images" in batch:
|
| 487 |
+
# conditioning image is -1 to 1. we need to convert it to 0 to 1
|
| 488 |
+
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
|
| 489 |
+
mask_image = mask_image / 2 + 0.5
|
| 490 |
+
# print(f"conditioning_image: {mask_image.shape}")
|
| 491 |
+
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
|
| 492 |
+
# alpha mask is 0 to 1
|
| 493 |
+
mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
|
| 494 |
+
# print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
|
| 495 |
+
else:
|
| 496 |
+
return loss
|
| 497 |
+
|
| 498 |
+
# resize to the same size as the loss
|
| 499 |
+
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
|
| 500 |
+
loss = loss * mask_image
|
| 501 |
+
return loss
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
"""
|
| 505 |
+
##########################################
|
| 506 |
+
# Perlin Noise
|
| 507 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
| 508 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
| 509 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
| 510 |
+
|
| 511 |
+
grid = (
|
| 512 |
+
torch.stack(
|
| 513 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
| 514 |
+
dim=-1,
|
| 515 |
+
)
|
| 516 |
+
% 1
|
| 517 |
+
)
|
| 518 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
| 519 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
| 520 |
+
|
| 521 |
+
tile_grads = (
|
| 522 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
| 523 |
+
.repeat_interleave(d[0], 0)
|
| 524 |
+
.repeat_interleave(d[1], 1)
|
| 525 |
+
)
|
| 526 |
+
dot = lambda grad, shift: (
|
| 527 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
| 528 |
+
* grad[: shape[0], : shape[1]]
|
| 529 |
+
).sum(dim=-1)
|
| 530 |
+
|
| 531 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
| 532 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
| 533 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
| 534 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
| 535 |
+
t = fade(grid[: shape[0], : shape[1]])
|
| 536 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
| 540 |
+
noise = torch.zeros(shape, device=device)
|
| 541 |
+
frequency = 1
|
| 542 |
+
amplitude = 1
|
| 543 |
+
for _ in range(octaves):
|
| 544 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
| 545 |
+
frequency *= 2
|
| 546 |
+
amplitude *= persistence
|
| 547 |
+
return noise
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def perlin_noise(noise, device, octaves):
|
| 551 |
+
_, c, w, h = noise.shape
|
| 552 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
| 553 |
+
noise_perlin = []
|
| 554 |
+
for _ in range(c):
|
| 555 |
+
noise_perlin.append(perlin())
|
| 556 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
| 557 |
+
noise += noise_perlin # broadcast for each batch
|
| 558 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
| 559 |
+
"""
|
deepspeed_utils.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
from accelerate import DeepSpeedPlugin, Accelerator
|
| 5 |
+
|
| 6 |
+
from .utils import setup_logging
|
| 7 |
+
|
| 8 |
+
setup_logging()
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def add_deepspeed_arguments(parser: argparse.ArgumentParser):
|
| 15 |
+
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
| 16 |
+
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
| 17 |
+
parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
|
| 18 |
+
parser.add_argument(
|
| 19 |
+
"--offload_optimizer_device",
|
| 20 |
+
type=str,
|
| 21 |
+
default=None,
|
| 22 |
+
choices=[None, "cpu", "nvme"],
|
| 23 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument(
|
| 26 |
+
"--offload_optimizer_nvme_path",
|
| 27 |
+
type=str,
|
| 28 |
+
default=None,
|
| 29 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
| 30 |
+
)
|
| 31 |
+
parser.add_argument(
|
| 32 |
+
"--offload_param_device",
|
| 33 |
+
type=str,
|
| 34 |
+
default=None,
|
| 35 |
+
choices=[None, "cpu", "nvme"],
|
| 36 |
+
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
|
| 37 |
+
)
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"--offload_param_nvme_path",
|
| 40 |
+
type=str,
|
| 41 |
+
default=None,
|
| 42 |
+
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
"--zero3_init_flag",
|
| 46 |
+
action="store_true",
|
| 47 |
+
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
| 48 |
+
"Only applicable with ZeRO Stage-3.",
|
| 49 |
+
)
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
"--zero3_save_16bit_model",
|
| 52 |
+
action="store_true",
|
| 53 |
+
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
|
| 54 |
+
)
|
| 55 |
+
parser.add_argument(
|
| 56 |
+
"--fp16_master_weights_and_gradients",
|
| 57 |
+
action="store_true",
|
| 58 |
+
help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def prepare_deepspeed_args(args: argparse.Namespace):
|
| 63 |
+
if not args.deepspeed:
|
| 64 |
+
return
|
| 65 |
+
|
| 66 |
+
# To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
|
| 67 |
+
args.max_data_loader_n_workers = 1
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def prepare_deepspeed_plugin(args: argparse.Namespace):
|
| 71 |
+
if not args.deepspeed:
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
try:
|
| 75 |
+
import deepspeed
|
| 76 |
+
except ImportError as e:
|
| 77 |
+
logger.error(
|
| 78 |
+
"deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
|
| 79 |
+
)
|
| 80 |
+
exit(1)
|
| 81 |
+
|
| 82 |
+
deepspeed_plugin = DeepSpeedPlugin(
|
| 83 |
+
zero_stage=args.zero_stage,
|
| 84 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 85 |
+
gradient_clipping=args.max_grad_norm,
|
| 86 |
+
offload_optimizer_device=args.offload_optimizer_device,
|
| 87 |
+
offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
| 88 |
+
offload_param_device=args.offload_param_device,
|
| 89 |
+
offload_param_nvme_path=args.offload_param_nvme_path,
|
| 90 |
+
zero3_init_flag=args.zero3_init_flag,
|
| 91 |
+
zero3_save_16bit_model=args.zero3_save_16bit_model,
|
| 92 |
+
)
|
| 93 |
+
deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
|
| 94 |
+
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
| 95 |
+
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
| 96 |
+
)
|
| 97 |
+
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
| 98 |
+
if args.mixed_precision.lower() == "fp16":
|
| 99 |
+
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
| 100 |
+
if args.full_fp16 or args.fp16_master_weights_and_gradients:
|
| 101 |
+
if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
|
| 102 |
+
deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
|
| 103 |
+
logger.info("[DeepSpeed] full fp16 enable.")
|
| 104 |
+
else:
|
| 105 |
+
logger.info(
|
| 106 |
+
"[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if args.offload_optimizer_device is not None:
|
| 110 |
+
logger.info("[DeepSpeed] start to manually build cpu_adam.")
|
| 111 |
+
deepspeed.ops.op_builder.CPUAdamBuilder().load()
|
| 112 |
+
logger.info("[DeepSpeed] building cpu_adam done.")
|
| 113 |
+
|
| 114 |
+
return deepspeed_plugin
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
|
| 118 |
+
def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
| 119 |
+
# remove None from models
|
| 120 |
+
models = {k: v for k, v in models.items() if v is not None}
|
| 121 |
+
|
| 122 |
+
class DeepSpeedWrapper(torch.nn.Module):
|
| 123 |
+
def __init__(self, **kw_models) -> None:
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.models = torch.nn.ModuleDict()
|
| 126 |
+
|
| 127 |
+
for key, model in kw_models.items():
|
| 128 |
+
if isinstance(model, list):
|
| 129 |
+
model = torch.nn.ModuleList(model)
|
| 130 |
+
assert isinstance(
|
| 131 |
+
model, torch.nn.Module
|
| 132 |
+
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
| 133 |
+
self.models.update(torch.nn.ModuleDict({key: model}))
|
| 134 |
+
|
| 135 |
+
def get_models(self):
|
| 136 |
+
return self.models
|
| 137 |
+
|
| 138 |
+
ds_model = DeepSpeedWrapper(**models)
|
| 139 |
+
return ds_model
|
dependabot.yml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
version: 2
|
| 3 |
+
updates:
|
| 4 |
+
- package-ecosystem: "github-actions"
|
| 5 |
+
directory: "/"
|
| 6 |
+
schedule:
|
| 7 |
+
interval: "monthly"
|
detect_face_rotate.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
|
| 2 |
+
# (c) 2022 Kohya S. @kohya_ss
|
| 3 |
+
|
| 4 |
+
# 横長の画像から顔検出して正立するように回転し、そこを中心に正方形に切り出す
|
| 5 |
+
|
| 6 |
+
# v2: extract max face if multiple faces are found
|
| 7 |
+
# v3: add crop_ratio option
|
| 8 |
+
# v4: add multiple faces extraction and min/max size
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import math
|
| 12 |
+
import cv2
|
| 13 |
+
import glob
|
| 14 |
+
import os
|
| 15 |
+
from anime_face_detector import create_detector
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
import numpy as np
|
| 18 |
+
from library.utils import setup_logging, pil_resize
|
| 19 |
+
setup_logging()
|
| 20 |
+
import logging
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
KP_REYE = 11
|
| 24 |
+
KP_LEYE = 19
|
| 25 |
+
|
| 26 |
+
SCORE_THRES = 0.90
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def detect_faces(detector, image, min_size):
|
| 30 |
+
preds = detector(image) # bgr
|
| 31 |
+
# logger.info(len(preds))
|
| 32 |
+
|
| 33 |
+
faces = []
|
| 34 |
+
for pred in preds:
|
| 35 |
+
bb = pred['bbox']
|
| 36 |
+
score = bb[-1]
|
| 37 |
+
if score < SCORE_THRES:
|
| 38 |
+
continue
|
| 39 |
+
|
| 40 |
+
left, top, right, bottom = bb[:4]
|
| 41 |
+
cx = int((left + right) / 2)
|
| 42 |
+
cy = int((top + bottom) / 2)
|
| 43 |
+
fw = int(right - left)
|
| 44 |
+
fh = int(bottom - top)
|
| 45 |
+
|
| 46 |
+
lex, ley = pred['keypoints'][KP_LEYE, 0:2]
|
| 47 |
+
rex, rey = pred['keypoints'][KP_REYE, 0:2]
|
| 48 |
+
angle = math.atan2(ley - rey, lex - rex)
|
| 49 |
+
angle = angle / math.pi * 180
|
| 50 |
+
|
| 51 |
+
faces.append((cx, cy, fw, fh, angle))
|
| 52 |
+
|
| 53 |
+
faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # 大きい順
|
| 54 |
+
return faces
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def rotate_image(image, angle, cx, cy):
|
| 58 |
+
h, w = image.shape[0:2]
|
| 59 |
+
rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
|
| 60 |
+
|
| 61 |
+
# # 回転する分、すこし画像サイズを大きくする→とりあえず無効化
|
| 62 |
+
# nh = max(h, int(w * math.sin(angle)))
|
| 63 |
+
# nw = max(w, int(h * math.sin(angle)))
|
| 64 |
+
# if nh > h or nw > w:
|
| 65 |
+
# pad_y = nh - h
|
| 66 |
+
# pad_t = pad_y // 2
|
| 67 |
+
# pad_x = nw - w
|
| 68 |
+
# pad_l = pad_x // 2
|
| 69 |
+
# m = np.array([[0, 0, pad_l],
|
| 70 |
+
# [0, 0, pad_t]])
|
| 71 |
+
# rot_mat = rot_mat + m
|
| 72 |
+
# h, w = nh, nw
|
| 73 |
+
# cx += pad_l
|
| 74 |
+
# cy += pad_t
|
| 75 |
+
|
| 76 |
+
result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
|
| 77 |
+
return result, cx, cy
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def process(args):
|
| 81 |
+
assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't be specified both / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
|
| 82 |
+
assert args.crop_ratio is None or args.resize_face_size is None, f"crop_ratio指定時はresize_face_sizeは指定できません"
|
| 83 |
+
|
| 84 |
+
# アニメ顔検出モデルを読み込む
|
| 85 |
+
logger.info("loading face detector.")
|
| 86 |
+
detector = create_detector('yolov3')
|
| 87 |
+
|
| 88 |
+
# cropの引数を解析する
|
| 89 |
+
if args.crop_size is None:
|
| 90 |
+
crop_width = crop_height = None
|
| 91 |
+
else:
|
| 92 |
+
tokens = args.crop_size.split(',')
|
| 93 |
+
assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
|
| 94 |
+
crop_width, crop_height = [int(t) for t in tokens]
|
| 95 |
+
|
| 96 |
+
if args.crop_ratio is None:
|
| 97 |
+
crop_h_ratio = crop_v_ratio = None
|
| 98 |
+
else:
|
| 99 |
+
tokens = args.crop_ratio.split(',')
|
| 100 |
+
assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
|
| 101 |
+
crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
|
| 102 |
+
|
| 103 |
+
# 画像を処理する
|
| 104 |
+
logger.info("processing.")
|
| 105 |
+
output_extension = ".png"
|
| 106 |
+
|
| 107 |
+
os.makedirs(args.dst_dir, exist_ok=True)
|
| 108 |
+
paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
|
| 109 |
+
glob.glob(os.path.join(args.src_dir, "*.webp"))
|
| 110 |
+
for path in tqdm(paths):
|
| 111 |
+
basename = os.path.splitext(os.path.basename(path))[0]
|
| 112 |
+
|
| 113 |
+
# image = cv2.imread(path) # 日本語ファイル名でエラーになる
|
| 114 |
+
image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
|
| 115 |
+
if len(image.shape) == 2:
|
| 116 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
|
| 117 |
+
if image.shape[2] == 4:
|
| 118 |
+
logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
|
| 119 |
+
image = image[:, :, :3].copy() # copyをしないと内部的に透明度情報が付いたままになるらしい
|
| 120 |
+
|
| 121 |
+
h, w = image.shape[:2]
|
| 122 |
+
|
| 123 |
+
faces = detect_faces(detector, image, args.multiple_faces)
|
| 124 |
+
for i, face in enumerate(faces):
|
| 125 |
+
cx, cy, fw, fh, angle = face
|
| 126 |
+
face_size = max(fw, fh)
|
| 127 |
+
if args.min_size is not None and face_size < args.min_size:
|
| 128 |
+
continue
|
| 129 |
+
if args.max_size is not None and face_size >= args.max_size:
|
| 130 |
+
continue
|
| 131 |
+
face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
|
| 132 |
+
|
| 133 |
+
# オプション指定があれば回転する
|
| 134 |
+
face_img = image
|
| 135 |
+
if args.rotate:
|
| 136 |
+
face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
|
| 137 |
+
|
| 138 |
+
# オプション指定があれば顔を中心に切り出す
|
| 139 |
+
if crop_width is not None or crop_h_ratio is not None:
|
| 140 |
+
cur_crop_width, cur_crop_height = crop_width, crop_height
|
| 141 |
+
if crop_h_ratio is not None:
|
| 142 |
+
cur_crop_width = int(face_size * crop_h_ratio + .5)
|
| 143 |
+
cur_crop_height = int(face_size * crop_v_ratio + .5)
|
| 144 |
+
|
| 145 |
+
# リサイズを必要なら行う
|
| 146 |
+
scale = 1.0
|
| 147 |
+
if args.resize_face_size is not None:
|
| 148 |
+
# 顔サイズを基準にリサイズする
|
| 149 |
+
scale = args.resize_face_size / face_size
|
| 150 |
+
if scale < cur_crop_width / w:
|
| 151 |
+
logger.warning(
|
| 152 |
+
f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
| 153 |
+
scale = cur_crop_width / w
|
| 154 |
+
if scale < cur_crop_height / h:
|
| 155 |
+
logger.warning(
|
| 156 |
+
f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
|
| 157 |
+
scale = cur_crop_height / h
|
| 158 |
+
elif crop_h_ratio is not None:
|
| 159 |
+
# 倍率指定の時にはリサイズしない
|
| 160 |
+
pass
|
| 161 |
+
else:
|
| 162 |
+
# 切り出しサイズ指定あり
|
| 163 |
+
if w < cur_crop_width:
|
| 164 |
+
logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
|
| 165 |
+
scale = cur_crop_width / w
|
| 166 |
+
if h < cur_crop_height:
|
| 167 |
+
logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
|
| 168 |
+
scale = cur_crop_height / h
|
| 169 |
+
if args.resize_fit:
|
| 170 |
+
scale = max(cur_crop_width / w, cur_crop_height / h)
|
| 171 |
+
|
| 172 |
+
if scale != 1.0:
|
| 173 |
+
w = int(w * scale + .5)
|
| 174 |
+
h = int(h * scale + .5)
|
| 175 |
+
if scale < 1.0:
|
| 176 |
+
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
|
| 177 |
+
else:
|
| 178 |
+
face_img = pil_resize(face_img, (w, h))
|
| 179 |
+
cx = int(cx * scale + .5)
|
| 180 |
+
cy = int(cy * scale + .5)
|
| 181 |
+
fw = int(fw * scale + .5)
|
| 182 |
+
fh = int(fh * scale + .5)
|
| 183 |
+
|
| 184 |
+
cur_crop_width = min(cur_crop_width, face_img.shape[1])
|
| 185 |
+
cur_crop_height = min(cur_crop_height, face_img.shape[0])
|
| 186 |
+
|
| 187 |
+
x = cx - cur_crop_width // 2
|
| 188 |
+
cx = cur_crop_width // 2
|
| 189 |
+
if x < 0:
|
| 190 |
+
cx = cx + x
|
| 191 |
+
x = 0
|
| 192 |
+
elif x + cur_crop_width > w:
|
| 193 |
+
cx = cx + (x + cur_crop_width - w)
|
| 194 |
+
x = w - cur_crop_width
|
| 195 |
+
face_img = face_img[:, x:x+cur_crop_width]
|
| 196 |
+
|
| 197 |
+
y = cy - cur_crop_height // 2
|
| 198 |
+
cy = cur_crop_height // 2
|
| 199 |
+
if y < 0:
|
| 200 |
+
cy = cy + y
|
| 201 |
+
y = 0
|
| 202 |
+
elif y + cur_crop_height > h:
|
| 203 |
+
cy = cy + (y + cur_crop_height - h)
|
| 204 |
+
y = h - cur_crop_height
|
| 205 |
+
face_img = face_img[y:y + cur_crop_height]
|
| 206 |
+
|
| 207 |
+
# # debug
|
| 208 |
+
# logger.info(path, cx, cy, angle)
|
| 209 |
+
# crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
|
| 210 |
+
# cv2.imshow("image", crp)
|
| 211 |
+
# if cv2.waitKey() == 27:
|
| 212 |
+
# break
|
| 213 |
+
# cv2.destroyAllWindows()
|
| 214 |
+
|
| 215 |
+
# debug
|
| 216 |
+
if args.debug:
|
| 217 |
+
cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
|
| 218 |
+
|
| 219 |
+
_, buf = cv2.imencode(output_extension, face_img)
|
| 220 |
+
with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
|
| 221 |
+
buf.tofile(f)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 225 |
+
parser = argparse.ArgumentParser()
|
| 226 |
+
parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
|
| 227 |
+
parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
|
| 228 |
+
parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
|
| 229 |
+
parser.add_argument("--resize_fit", action="store_true",
|
| 230 |
+
help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
|
| 231 |
+
parser.add_argument("--resize_face_size", type=int, default=None,
|
| 232 |
+
help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
|
| 233 |
+
parser.add_argument("--crop_size", type=str, default=None,
|
| 234 |
+
help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
|
| 235 |
+
parser.add_argument("--crop_ratio", type=str, default=None,
|
| 236 |
+
help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
|
| 237 |
+
parser.add_argument("--min_size", type=int, default=None,
|
| 238 |
+
help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
|
| 239 |
+
parser.add_argument("--max_size", type=int, default=None,
|
| 240 |
+
help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
|
| 241 |
+
parser.add_argument("--multiple_faces", action="store_true",
|
| 242 |
+
help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
|
| 243 |
+
parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
|
| 244 |
+
|
| 245 |
+
return parser
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == '__main__':
|
| 249 |
+
parser = setup_parser()
|
| 250 |
+
|
| 251 |
+
args = parser.parse_args()
|
| 252 |
+
|
| 253 |
+
process(args)
|
device_utils.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import gc
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
try:
|
| 6 |
+
# intel gpu support for pytorch older than 2.5
|
| 7 |
+
# ipex is not needed after pytorch 2.5
|
| 8 |
+
import intel_extension_for_pytorch as ipex # noqa
|
| 9 |
+
except Exception:
|
| 10 |
+
pass
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
HAS_CUDA = torch.cuda.is_available()
|
| 15 |
+
except Exception:
|
| 16 |
+
HAS_CUDA = False
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
HAS_MPS = torch.backends.mps.is_available()
|
| 20 |
+
except Exception:
|
| 21 |
+
HAS_MPS = False
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
HAS_XPU = torch.xpu.is_available()
|
| 25 |
+
except Exception:
|
| 26 |
+
HAS_XPU = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def clean_memory():
|
| 30 |
+
gc.collect()
|
| 31 |
+
if HAS_CUDA:
|
| 32 |
+
torch.cuda.empty_cache()
|
| 33 |
+
if HAS_XPU:
|
| 34 |
+
torch.xpu.empty_cache()
|
| 35 |
+
if HAS_MPS:
|
| 36 |
+
torch.mps.empty_cache()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def clean_memory_on_device(device: torch.device):
|
| 40 |
+
r"""
|
| 41 |
+
Clean memory on the specified device, will be called from training scripts.
|
| 42 |
+
"""
|
| 43 |
+
gc.collect()
|
| 44 |
+
|
| 45 |
+
# device may "cuda" or "cuda:0", so we need to check the type of device
|
| 46 |
+
if device.type == "cuda":
|
| 47 |
+
torch.cuda.empty_cache()
|
| 48 |
+
if device.type == "xpu":
|
| 49 |
+
torch.xpu.empty_cache()
|
| 50 |
+
if device.type == "mps":
|
| 51 |
+
torch.mps.empty_cache()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@functools.lru_cache(maxsize=None)
|
| 55 |
+
def get_preferred_device() -> torch.device:
|
| 56 |
+
r"""
|
| 57 |
+
Do not call this function from training scripts. Use accelerator.device instead.
|
| 58 |
+
"""
|
| 59 |
+
if HAS_CUDA:
|
| 60 |
+
device = torch.device("cuda")
|
| 61 |
+
elif HAS_XPU:
|
| 62 |
+
device = torch.device("xpu")
|
| 63 |
+
elif HAS_MPS:
|
| 64 |
+
device = torch.device("mps")
|
| 65 |
+
else:
|
| 66 |
+
device = torch.device("cpu")
|
| 67 |
+
print(f"get_preferred_device() -> {device}")
|
| 68 |
+
return device
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def init_ipex():
|
| 72 |
+
"""
|
| 73 |
+
Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
|
| 74 |
+
|
| 75 |
+
This function should run right after importing torch and before doing anything else.
|
| 76 |
+
|
| 77 |
+
If xpu is not available, this function does nothing.
|
| 78 |
+
"""
|
| 79 |
+
try:
|
| 80 |
+
if HAS_XPU:
|
| 81 |
+
from library.ipex import ipex_init
|
| 82 |
+
|
| 83 |
+
is_initialized, error_message = ipex_init()
|
| 84 |
+
if not is_initialized:
|
| 85 |
+
print("failed to initialize ipex:", error_message)
|
| 86 |
+
else:
|
| 87 |
+
return
|
| 88 |
+
except Exception as e:
|
| 89 |
+
print("failed to initialize ipex:", e)
|
diffusers.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import wraps
|
| 2 |
+
import torch
|
| 3 |
+
import diffusers # pylint: disable=import-error
|
| 4 |
+
|
| 5 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Diffusers FreeU
|
| 9 |
+
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
|
| 10 |
+
@wraps(diffusers.utils.torch_utils.fourier_filter)
|
| 11 |
+
def fourier_filter(x_in, threshold, scale):
|
| 12 |
+
return_dtype = x_in.dtype
|
| 13 |
+
return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# fp64 error
|
| 17 |
+
class FluxPosEmbed(torch.nn.Module):
|
| 18 |
+
def __init__(self, theta: int, axes_dim):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.theta = theta
|
| 21 |
+
self.axes_dim = axes_dim
|
| 22 |
+
|
| 23 |
+
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
| 24 |
+
n_axes = ids.shape[-1]
|
| 25 |
+
cos_out = []
|
| 26 |
+
sin_out = []
|
| 27 |
+
pos = ids.float()
|
| 28 |
+
for i in range(n_axes):
|
| 29 |
+
cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
|
| 30 |
+
self.axes_dim[i],
|
| 31 |
+
pos[:, i],
|
| 32 |
+
theta=self.theta,
|
| 33 |
+
repeat_interleave_real=True,
|
| 34 |
+
use_real=True,
|
| 35 |
+
freqs_dtype=torch.float32,
|
| 36 |
+
)
|
| 37 |
+
cos_out.append(cos)
|
| 38 |
+
sin_out.append(sin)
|
| 39 |
+
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
| 40 |
+
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
| 41 |
+
return freqs_cos, freqs_sin
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
|
| 45 |
+
diffusers.utils.torch_utils.fourier_filter = fourier_filter
|
| 46 |
+
if not device_supports_fp64:
|
| 47 |
+
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
|
dylora.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# some codes are copied from:
|
| 2 |
+
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
|
| 3 |
+
|
| 4 |
+
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
|
| 5 |
+
# Changes made to the original code:
|
| 6 |
+
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
|
| 7 |
+
# ------------------------------------------------------------------------------------------
|
| 8 |
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
| 9 |
+
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
| 10 |
+
# ------------------------------------------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import math
|
| 13 |
+
import os
|
| 14 |
+
import random
|
| 15 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
| 16 |
+
from diffusers import AutoencoderKL
|
| 17 |
+
from transformers import CLIPTextModel
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
from library.utils import setup_logging
|
| 21 |
+
|
| 22 |
+
setup_logging()
|
| 23 |
+
import logging
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DyLoRAModule(torch.nn.Module):
|
| 29 |
+
"""
|
| 30 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
# NOTE: support dropout in future
|
| 34 |
+
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
|
| 35 |
+
super().__init__()
|
| 36 |
+
self.lora_name = lora_name
|
| 37 |
+
self.lora_dim = lora_dim
|
| 38 |
+
self.unit = unit
|
| 39 |
+
assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
|
| 40 |
+
|
| 41 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 42 |
+
in_dim = org_module.in_channels
|
| 43 |
+
out_dim = org_module.out_channels
|
| 44 |
+
else:
|
| 45 |
+
in_dim = org_module.in_features
|
| 46 |
+
out_dim = org_module.out_features
|
| 47 |
+
|
| 48 |
+
if type(alpha) == torch.Tensor:
|
| 49 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
| 50 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
| 51 |
+
self.scale = alpha / self.lora_dim
|
| 52 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
| 53 |
+
|
| 54 |
+
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
| 55 |
+
self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
|
| 56 |
+
|
| 57 |
+
if self.is_conv2d and self.is_conv2d_3x3:
|
| 58 |
+
kernel_size = org_module.kernel_size
|
| 59 |
+
self.stride = org_module.stride
|
| 60 |
+
self.padding = org_module.padding
|
| 61 |
+
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
|
| 62 |
+
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
|
| 63 |
+
else:
|
| 64 |
+
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
|
| 65 |
+
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
|
| 66 |
+
|
| 67 |
+
# same as microsoft's
|
| 68 |
+
for lora in self.lora_A:
|
| 69 |
+
torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
|
| 70 |
+
for lora in self.lora_B:
|
| 71 |
+
torch.nn.init.zeros_(lora)
|
| 72 |
+
|
| 73 |
+
self.multiplier = multiplier
|
| 74 |
+
self.org_module = org_module # remove in applying
|
| 75 |
+
|
| 76 |
+
def apply_to(self):
|
| 77 |
+
self.org_forward = self.org_module.forward
|
| 78 |
+
self.org_module.forward = self.forward
|
| 79 |
+
del self.org_module
|
| 80 |
+
|
| 81 |
+
def forward(self, x):
|
| 82 |
+
result = self.org_forward(x)
|
| 83 |
+
|
| 84 |
+
# specify the dynamic rank
|
| 85 |
+
trainable_rank = random.randint(0, self.lora_dim - 1)
|
| 86 |
+
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
| 87 |
+
|
| 88 |
+
# 一部のパラメータを固定して、残りのパラメータを学習する
|
| 89 |
+
for i in range(0, trainable_rank):
|
| 90 |
+
self.lora_A[i].requires_grad = False
|
| 91 |
+
self.lora_B[i].requires_grad = False
|
| 92 |
+
for i in range(trainable_rank, trainable_rank + self.unit):
|
| 93 |
+
self.lora_A[i].requires_grad = True
|
| 94 |
+
self.lora_B[i].requires_grad = True
|
| 95 |
+
for i in range(trainable_rank + self.unit, self.lora_dim):
|
| 96 |
+
self.lora_A[i].requires_grad = False
|
| 97 |
+
self.lora_B[i].requires_grad = False
|
| 98 |
+
|
| 99 |
+
lora_A = torch.cat(tuple(self.lora_A), dim=0)
|
| 100 |
+
lora_B = torch.cat(tuple(self.lora_B), dim=1)
|
| 101 |
+
|
| 102 |
+
# calculate with lora_A and lora_B
|
| 103 |
+
if self.is_conv2d_3x3:
|
| 104 |
+
ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
|
| 105 |
+
ab = torch.nn.functional.conv2d(ab, lora_B)
|
| 106 |
+
else:
|
| 107 |
+
ab = x
|
| 108 |
+
if self.is_conv2d:
|
| 109 |
+
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
|
| 110 |
+
|
| 111 |
+
ab = torch.nn.functional.linear(ab, lora_A)
|
| 112 |
+
ab = torch.nn.functional.linear(ab, lora_B)
|
| 113 |
+
|
| 114 |
+
if self.is_conv2d:
|
| 115 |
+
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
|
| 116 |
+
|
| 117 |
+
# 最後の項は、低rankをより大きくするためのスケー���ング(じゃないかな)
|
| 118 |
+
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
| 119 |
+
|
| 120 |
+
# NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも
|
| 121 |
+
return result
|
| 122 |
+
|
| 123 |
+
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
| 124 |
+
# state dictを通常のLoRAと同じにする:
|
| 125 |
+
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
|
| 126 |
+
sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
| 127 |
+
|
| 128 |
+
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
|
| 129 |
+
if self.is_conv2d and not self.is_conv2d_3x3:
|
| 130 |
+
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
| 131 |
+
|
| 132 |
+
lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
|
| 133 |
+
if self.is_conv2d and not self.is_conv2d_3x3:
|
| 134 |
+
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
| 135 |
+
|
| 136 |
+
sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
|
| 137 |
+
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
|
| 138 |
+
|
| 139 |
+
i = 0
|
| 140 |
+
while True:
|
| 141 |
+
key_a = f"{self.lora_name}.lora_A.{i}"
|
| 142 |
+
key_b = f"{self.lora_name}.lora_B.{i}"
|
| 143 |
+
if key_a in sd:
|
| 144 |
+
sd.pop(key_a)
|
| 145 |
+
sd.pop(key_b)
|
| 146 |
+
else:
|
| 147 |
+
break
|
| 148 |
+
i += 1
|
| 149 |
+
return sd
|
| 150 |
+
|
| 151 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
| 152 |
+
# 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた
|
| 153 |
+
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
|
| 154 |
+
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
|
| 155 |
+
|
| 156 |
+
if lora_A_weight is None or lora_B_weight is None:
|
| 157 |
+
if strict:
|
| 158 |
+
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
|
| 159 |
+
else:
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
if self.is_conv2d and not self.is_conv2d_3x3:
|
| 163 |
+
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
| 164 |
+
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
| 165 |
+
|
| 166 |
+
state_dict.update(
|
| 167 |
+
{f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
|
| 168 |
+
)
|
| 169 |
+
state_dict.update(
|
| 170 |
+
{f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def create_network(
|
| 177 |
+
multiplier: float,
|
| 178 |
+
network_dim: Optional[int],
|
| 179 |
+
network_alpha: Optional[float],
|
| 180 |
+
vae: AutoencoderKL,
|
| 181 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
| 182 |
+
unet,
|
| 183 |
+
**kwargs,
|
| 184 |
+
):
|
| 185 |
+
if network_dim is None:
|
| 186 |
+
network_dim = 4 # default
|
| 187 |
+
if network_alpha is None:
|
| 188 |
+
network_alpha = 1.0
|
| 189 |
+
|
| 190 |
+
# extract dim/alpha for conv2d, and block dim
|
| 191 |
+
conv_dim = kwargs.get("conv_dim", None)
|
| 192 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
| 193 |
+
unit = kwargs.get("unit", None)
|
| 194 |
+
if conv_dim is not None:
|
| 195 |
+
conv_dim = int(conv_dim)
|
| 196 |
+
assert conv_dim == network_dim, "conv_dim must be same as network_dim"
|
| 197 |
+
if conv_alpha is None:
|
| 198 |
+
conv_alpha = 1.0
|
| 199 |
+
else:
|
| 200 |
+
conv_alpha = float(conv_alpha)
|
| 201 |
+
|
| 202 |
+
if unit is not None:
|
| 203 |
+
unit = int(unit)
|
| 204 |
+
else:
|
| 205 |
+
unit = 1
|
| 206 |
+
|
| 207 |
+
network = DyLoRANetwork(
|
| 208 |
+
text_encoder,
|
| 209 |
+
unet,
|
| 210 |
+
multiplier=multiplier,
|
| 211 |
+
lora_dim=network_dim,
|
| 212 |
+
alpha=network_alpha,
|
| 213 |
+
apply_to_conv=conv_dim is not None,
|
| 214 |
+
unit=unit,
|
| 215 |
+
varbose=True,
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
| 219 |
+
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
| 220 |
+
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
| 221 |
+
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
| 222 |
+
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
| 223 |
+
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
| 224 |
+
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
| 225 |
+
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
| 226 |
+
|
| 227 |
+
return network
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
| 231 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
| 232 |
+
if weights_sd is None:
|
| 233 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 234 |
+
from safetensors.torch import load_file, safe_open
|
| 235 |
+
|
| 236 |
+
weights_sd = load_file(file)
|
| 237 |
+
else:
|
| 238 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 239 |
+
|
| 240 |
+
# get dim/alpha mapping
|
| 241 |
+
modules_dim = {}
|
| 242 |
+
modules_alpha = {}
|
| 243 |
+
for key, value in weights_sd.items():
|
| 244 |
+
if "." not in key:
|
| 245 |
+
continue
|
| 246 |
+
|
| 247 |
+
lora_name = key.split(".")[0]
|
| 248 |
+
if "alpha" in key:
|
| 249 |
+
modules_alpha[lora_name] = value
|
| 250 |
+
elif "lora_down" in key:
|
| 251 |
+
dim = value.size()[0]
|
| 252 |
+
modules_dim[lora_name] = dim
|
| 253 |
+
# logger.info(f"{lora_name} {value.size()} {dim}")
|
| 254 |
+
|
| 255 |
+
# support old LoRA without alpha
|
| 256 |
+
for key in modules_dim.keys():
|
| 257 |
+
if key not in modules_alpha:
|
| 258 |
+
modules_alpha = modules_dim[key]
|
| 259 |
+
|
| 260 |
+
module_class = DyLoRAModule
|
| 261 |
+
|
| 262 |
+
network = DyLoRANetwork(
|
| 263 |
+
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
| 264 |
+
)
|
| 265 |
+
return network, weights_sd
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class DyLoRANetwork(torch.nn.Module):
|
| 269 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
| 270 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 271 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
| 272 |
+
LORA_PREFIX_UNET = "lora_unet"
|
| 273 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
| 274 |
+
|
| 275 |
+
def __init__(
|
| 276 |
+
self,
|
| 277 |
+
text_encoder,
|
| 278 |
+
unet,
|
| 279 |
+
multiplier=1.0,
|
| 280 |
+
lora_dim=4,
|
| 281 |
+
alpha=1,
|
| 282 |
+
apply_to_conv=False,
|
| 283 |
+
modules_dim=None,
|
| 284 |
+
modules_alpha=None,
|
| 285 |
+
unit=1,
|
| 286 |
+
module_class=DyLoRAModule,
|
| 287 |
+
varbose=False,
|
| 288 |
+
) -> None:
|
| 289 |
+
super().__init__()
|
| 290 |
+
self.multiplier = multiplier
|
| 291 |
+
|
| 292 |
+
self.lora_dim = lora_dim
|
| 293 |
+
self.alpha = alpha
|
| 294 |
+
self.apply_to_conv = apply_to_conv
|
| 295 |
+
|
| 296 |
+
self.loraplus_lr_ratio = None
|
| 297 |
+
self.loraplus_unet_lr_ratio = None
|
| 298 |
+
self.loraplus_text_encoder_lr_ratio = None
|
| 299 |
+
|
| 300 |
+
if modules_dim is not None:
|
| 301 |
+
logger.info("create LoRA network from weights")
|
| 302 |
+
else:
|
| 303 |
+
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
| 304 |
+
if self.apply_to_conv:
|
| 305 |
+
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
|
| 306 |
+
|
| 307 |
+
# create module instances
|
| 308 |
+
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
| 309 |
+
prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
| 310 |
+
loras = []
|
| 311 |
+
for name, module in root_module.named_modules():
|
| 312 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 313 |
+
for child_name, child_module in module.named_modules():
|
| 314 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
| 315 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
| 316 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
| 317 |
+
|
| 318 |
+
if is_linear or is_conv2d:
|
| 319 |
+
lora_name = prefix + "." + name + "." + child_name
|
| 320 |
+
lora_name = lora_name.replace(".", "_")
|
| 321 |
+
|
| 322 |
+
dim = None
|
| 323 |
+
alpha = None
|
| 324 |
+
if modules_dim is not None:
|
| 325 |
+
if lora_name in modules_dim:
|
| 326 |
+
dim = modules_dim[lora_name]
|
| 327 |
+
alpha = modules_alpha[lora_name]
|
| 328 |
+
else:
|
| 329 |
+
if is_linear or is_conv2d_1x1 or apply_to_conv:
|
| 330 |
+
dim = self.lora_dim
|
| 331 |
+
alpha = self.alpha
|
| 332 |
+
|
| 333 |
+
if dim is None or dim == 0:
|
| 334 |
+
continue
|
| 335 |
+
|
| 336 |
+
# dropout and fan_in_fan_out is default
|
| 337 |
+
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
|
| 338 |
+
loras.append(lora)
|
| 339 |
+
return loras
|
| 340 |
+
|
| 341 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
| 342 |
+
|
| 343 |
+
self.text_encoder_loras = []
|
| 344 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 345 |
+
if len(text_encoders) > 1:
|
| 346 |
+
index = i + 1
|
| 347 |
+
logger.info(f"create LoRA for Text Encoder {index}")
|
| 348 |
+
else:
|
| 349 |
+
index = None
|
| 350 |
+
logger.info("create LoRA for Text Encoder")
|
| 351 |
+
|
| 352 |
+
text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
| 353 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
| 354 |
+
|
| 355 |
+
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
| 356 |
+
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
| 357 |
+
|
| 358 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
| 359 |
+
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
| 360 |
+
if modules_dim is not None or self.apply_to_conv:
|
| 361 |
+
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 362 |
+
|
| 363 |
+
self.unet_loras = create_modules(True, unet, target_modules)
|
| 364 |
+
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
| 365 |
+
|
| 366 |
+
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
| 367 |
+
self.loraplus_lr_ratio = loraplus_lr_ratio
|
| 368 |
+
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
| 369 |
+
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
| 370 |
+
|
| 371 |
+
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
| 372 |
+
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
| 373 |
+
|
| 374 |
+
def set_multiplier(self, multiplier):
|
| 375 |
+
self.multiplier = multiplier
|
| 376 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 377 |
+
lora.multiplier = self.multiplier
|
| 378 |
+
|
| 379 |
+
def load_weights(self, file):
|
| 380 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 381 |
+
from safetensors.torch import load_file
|
| 382 |
+
|
| 383 |
+
weights_sd = load_file(file)
|
| 384 |
+
else:
|
| 385 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 386 |
+
|
| 387 |
+
info = self.load_state_dict(weights_sd, False)
|
| 388 |
+
return info
|
| 389 |
+
|
| 390 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
| 391 |
+
if apply_text_encoder:
|
| 392 |
+
logger.info("enable LoRA for text encoder")
|
| 393 |
+
else:
|
| 394 |
+
self.text_encoder_loras = []
|
| 395 |
+
|
| 396 |
+
if apply_unet:
|
| 397 |
+
logger.info("enable LoRA for U-Net")
|
| 398 |
+
else:
|
| 399 |
+
self.unet_loras = []
|
| 400 |
+
|
| 401 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 402 |
+
lora.apply_to()
|
| 403 |
+
self.add_module(lora.lora_name, lora)
|
| 404 |
+
|
| 405 |
+
"""
|
| 406 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
| 407 |
+
apply_text_encoder = apply_unet = False
|
| 408 |
+
for key in weights_sd.keys():
|
| 409 |
+
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
| 410 |
+
apply_text_encoder = True
|
| 411 |
+
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
|
| 412 |
+
apply_unet = True
|
| 413 |
+
|
| 414 |
+
if apply_text_encoder:
|
| 415 |
+
logger.info("enable LoRA for text encoder")
|
| 416 |
+
else:
|
| 417 |
+
self.text_encoder_loras = []
|
| 418 |
+
|
| 419 |
+
if apply_unet:
|
| 420 |
+
logger.info("enable LoRA for U-Net")
|
| 421 |
+
else:
|
| 422 |
+
self.unet_loras = []
|
| 423 |
+
|
| 424 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 425 |
+
sd_for_lora = {}
|
| 426 |
+
for key in weights_sd.keys():
|
| 427 |
+
if key.startswith(lora.lora_name):
|
| 428 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
| 429 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
| 430 |
+
|
| 431 |
+
logger.info(f"weights are merged")
|
| 432 |
+
"""
|
| 433 |
+
|
| 434 |
+
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
| 435 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
| 436 |
+
self.requires_grad_(True)
|
| 437 |
+
all_params = []
|
| 438 |
+
|
| 439 |
+
def assemble_params(loras, lr, ratio):
|
| 440 |
+
param_groups = {"lora": {}, "plus": {}}
|
| 441 |
+
for lora in loras:
|
| 442 |
+
for name, param in lora.named_parameters():
|
| 443 |
+
if ratio is not None and "lora_B" in name:
|
| 444 |
+
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
| 445 |
+
else:
|
| 446 |
+
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
| 447 |
+
|
| 448 |
+
params = []
|
| 449 |
+
for key in param_groups.keys():
|
| 450 |
+
param_data = {"params": param_groups[key].values()}
|
| 451 |
+
|
| 452 |
+
if len(param_data["params"]) == 0:
|
| 453 |
+
continue
|
| 454 |
+
|
| 455 |
+
if lr is not None:
|
| 456 |
+
if key == "plus":
|
| 457 |
+
param_data["lr"] = lr * ratio
|
| 458 |
+
else:
|
| 459 |
+
param_data["lr"] = lr
|
| 460 |
+
|
| 461 |
+
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
| 462 |
+
continue
|
| 463 |
+
|
| 464 |
+
params.append(param_data)
|
| 465 |
+
|
| 466 |
+
return params
|
| 467 |
+
|
| 468 |
+
if self.text_encoder_loras:
|
| 469 |
+
params = assemble_params(
|
| 470 |
+
self.text_encoder_loras,
|
| 471 |
+
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
| 472 |
+
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
| 473 |
+
)
|
| 474 |
+
all_params.extend(params)
|
| 475 |
+
|
| 476 |
+
if self.unet_loras:
|
| 477 |
+
params = assemble_params(
|
| 478 |
+
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
|
| 479 |
+
)
|
| 480 |
+
all_params.extend(params)
|
| 481 |
+
|
| 482 |
+
return all_params
|
| 483 |
+
|
| 484 |
+
def enable_gradient_checkpointing(self):
|
| 485 |
+
# not supported
|
| 486 |
+
pass
|
| 487 |
+
|
| 488 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
| 489 |
+
self.requires_grad_(True)
|
| 490 |
+
|
| 491 |
+
def on_epoch_start(self, text_encoder, unet):
|
| 492 |
+
self.train()
|
| 493 |
+
|
| 494 |
+
def get_trainable_params(self):
|
| 495 |
+
return self.parameters()
|
| 496 |
+
|
| 497 |
+
def save_weights(self, file, dtype, metadata):
|
| 498 |
+
if metadata is not None and len(metadata) == 0:
|
| 499 |
+
metadata = None
|
| 500 |
+
|
| 501 |
+
state_dict = self.state_dict()
|
| 502 |
+
|
| 503 |
+
if dtype is not None:
|
| 504 |
+
for key in list(state_dict.keys()):
|
| 505 |
+
v = state_dict[key]
|
| 506 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
| 507 |
+
state_dict[key] = v
|
| 508 |
+
|
| 509 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 510 |
+
from safetensors.torch import save_file
|
| 511 |
+
from library import train_util
|
| 512 |
+
|
| 513 |
+
# Precalculate model hashes to save time on indexing
|
| 514 |
+
if metadata is None:
|
| 515 |
+
metadata = {}
|
| 516 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
| 517 |
+
metadata["sshs_model_hash"] = model_hash
|
| 518 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
| 519 |
+
|
| 520 |
+
save_file(state_dict, file, metadata)
|
| 521 |
+
else:
|
| 522 |
+
torch.save(state_dict, file)
|
| 523 |
+
|
| 524 |
+
# mask is a tensor with values from 0 to 1
|
| 525 |
+
def set_region(self, sub_prompt_index, is_last_network, mask):
|
| 526 |
+
pass
|
| 527 |
+
|
| 528 |
+
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
| 529 |
+
pass
|
extract_lora_from_dylora.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
| 2 |
+
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
| 3 |
+
# Thanks to cloneofsimo
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
from safetensors.torch import load_file, save_file, safe_open
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from library import train_util, model_util
|
| 12 |
+
import numpy as np
|
| 13 |
+
from library.utils import setup_logging
|
| 14 |
+
setup_logging()
|
| 15 |
+
import logging
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
def load_state_dict(file_name):
|
| 19 |
+
if model_util.is_safetensors(file_name):
|
| 20 |
+
sd = load_file(file_name)
|
| 21 |
+
with safe_open(file_name, framework="pt") as f:
|
| 22 |
+
metadata = f.metadata()
|
| 23 |
+
else:
|
| 24 |
+
sd = torch.load(file_name, map_location="cpu")
|
| 25 |
+
metadata = None
|
| 26 |
+
|
| 27 |
+
return sd, metadata
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def save_to_file(file_name, model, metadata):
|
| 31 |
+
if model_util.is_safetensors(file_name):
|
| 32 |
+
save_file(model, file_name, metadata)
|
| 33 |
+
else:
|
| 34 |
+
torch.save(model, file_name)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def split_lora_model(lora_sd, unit):
|
| 38 |
+
max_rank = 0
|
| 39 |
+
|
| 40 |
+
# Extract loaded lora dim and alpha
|
| 41 |
+
for key, value in lora_sd.items():
|
| 42 |
+
if "lora_down" in key:
|
| 43 |
+
rank = value.size()[0]
|
| 44 |
+
if rank > max_rank:
|
| 45 |
+
max_rank = rank
|
| 46 |
+
logger.info(f"Max rank: {max_rank}")
|
| 47 |
+
|
| 48 |
+
rank = unit
|
| 49 |
+
split_models = []
|
| 50 |
+
new_alpha = None
|
| 51 |
+
while rank < max_rank:
|
| 52 |
+
logger.info(f"Splitting rank {rank}")
|
| 53 |
+
new_sd = {}
|
| 54 |
+
for key, value in lora_sd.items():
|
| 55 |
+
if "lora_down" in key:
|
| 56 |
+
new_sd[key] = value[:rank].contiguous()
|
| 57 |
+
elif "lora_up" in key:
|
| 58 |
+
new_sd[key] = value[:, :rank].contiguous()
|
| 59 |
+
else:
|
| 60 |
+
# なぜかscaleするとおかしくなる……
|
| 61 |
+
# this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
|
| 62 |
+
# scale = math.sqrt(this_rank / rank) # rank is > unit
|
| 63 |
+
# logger.info(key, value.size(), this_rank, rank, value, scale)
|
| 64 |
+
# new_alpha = value * scale # always same
|
| 65 |
+
# new_sd[key] = new_alpha
|
| 66 |
+
new_sd[key] = value
|
| 67 |
+
|
| 68 |
+
split_models.append((new_sd, rank, new_alpha))
|
| 69 |
+
rank += unit
|
| 70 |
+
|
| 71 |
+
return max_rank, split_models
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def split(args):
|
| 75 |
+
logger.info("loading Model...")
|
| 76 |
+
lora_sd, metadata = load_state_dict(args.model)
|
| 77 |
+
|
| 78 |
+
logger.info("Splitting Model...")
|
| 79 |
+
original_rank, split_models = split_lora_model(lora_sd, args.unit)
|
| 80 |
+
|
| 81 |
+
comment = metadata.get("ss_training_comment", "")
|
| 82 |
+
for state_dict, new_rank, new_alpha in split_models:
|
| 83 |
+
# update metadata
|
| 84 |
+
if metadata is None:
|
| 85 |
+
new_metadata = {}
|
| 86 |
+
else:
|
| 87 |
+
new_metadata = metadata.copy()
|
| 88 |
+
|
| 89 |
+
new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
|
| 90 |
+
new_metadata["ss_network_dim"] = str(new_rank)
|
| 91 |
+
# new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
|
| 92 |
+
|
| 93 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
| 94 |
+
metadata["sshs_model_hash"] = model_hash
|
| 95 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
| 96 |
+
|
| 97 |
+
filename, ext = os.path.splitext(args.save_to)
|
| 98 |
+
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
| 99 |
+
|
| 100 |
+
logger.info(f"saving model to: {model_file_name}")
|
| 101 |
+
save_to_file(model_file_name, state_dict, new_metadata)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 105 |
+
parser = argparse.ArgumentParser()
|
| 106 |
+
|
| 107 |
+
parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
|
| 108 |
+
parser.add_argument(
|
| 109 |
+
"--save_to",
|
| 110 |
+
type=str,
|
| 111 |
+
default=None,
|
| 112 |
+
help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
|
| 113 |
+
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
"--model",
|
| 116 |
+
type=str,
|
| 117 |
+
default=None,
|
| 118 |
+
help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
return parser
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
parser = setup_parser()
|
| 126 |
+
|
| 127 |
+
args = parser.parse_args()
|
| 128 |
+
split(args)
|
extract_lora_from_models.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# extract approximating LoRA by svd from two SD models
|
| 2 |
+
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
| 3 |
+
# Thanks to cloneofsimo!
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import time
|
| 9 |
+
import torch
|
| 10 |
+
from safetensors.torch import load_file, save_file
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from library import sai_model_spec, model_util, sdxl_model_util
|
| 13 |
+
import lora
|
| 14 |
+
from library.utils import setup_logging
|
| 15 |
+
setup_logging()
|
| 16 |
+
import logging
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
# CLAMP_QUANTILE = 0.99
|
| 20 |
+
# MIN_DIFF = 1e-1
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def save_to_file(file_name, model, state_dict, dtype):
|
| 24 |
+
if dtype is not None:
|
| 25 |
+
for key in list(state_dict.keys()):
|
| 26 |
+
if type(state_dict[key]) == torch.Tensor:
|
| 27 |
+
state_dict[key] = state_dict[key].to(dtype)
|
| 28 |
+
|
| 29 |
+
if os.path.splitext(file_name)[1] == ".safetensors":
|
| 30 |
+
save_file(model, file_name)
|
| 31 |
+
else:
|
| 32 |
+
torch.save(model, file_name)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def svd(
|
| 36 |
+
model_org=None,
|
| 37 |
+
model_tuned=None,
|
| 38 |
+
save_to=None,
|
| 39 |
+
dim=4,
|
| 40 |
+
v2=None,
|
| 41 |
+
sdxl=None,
|
| 42 |
+
conv_dim=None,
|
| 43 |
+
v_parameterization=None,
|
| 44 |
+
device=None,
|
| 45 |
+
save_precision=None,
|
| 46 |
+
clamp_quantile=0.99,
|
| 47 |
+
min_diff=0.01,
|
| 48 |
+
no_metadata=False,
|
| 49 |
+
load_precision=None,
|
| 50 |
+
load_original_model_to=None,
|
| 51 |
+
load_tuned_model_to=None,
|
| 52 |
+
):
|
| 53 |
+
def str_to_dtype(p):
|
| 54 |
+
if p == "float":
|
| 55 |
+
return torch.float
|
| 56 |
+
if p == "fp16":
|
| 57 |
+
return torch.float16
|
| 58 |
+
if p == "bf16":
|
| 59 |
+
return torch.bfloat16
|
| 60 |
+
return None
|
| 61 |
+
|
| 62 |
+
assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
|
| 63 |
+
if v_parameterization is None:
|
| 64 |
+
v_parameterization = v2
|
| 65 |
+
|
| 66 |
+
load_dtype = str_to_dtype(load_precision) if load_precision else None
|
| 67 |
+
save_dtype = str_to_dtype(save_precision)
|
| 68 |
+
work_device = "cpu"
|
| 69 |
+
|
| 70 |
+
# load models
|
| 71 |
+
if not sdxl:
|
| 72 |
+
logger.info(f"loading original SD model : {model_org}")
|
| 73 |
+
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
|
| 74 |
+
text_encoders_o = [text_encoder_o]
|
| 75 |
+
if load_dtype is not None:
|
| 76 |
+
text_encoder_o = text_encoder_o.to(load_dtype)
|
| 77 |
+
unet_o = unet_o.to(load_dtype)
|
| 78 |
+
|
| 79 |
+
logger.info(f"loading tuned SD model : {model_tuned}")
|
| 80 |
+
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
|
| 81 |
+
text_encoders_t = [text_encoder_t]
|
| 82 |
+
if load_dtype is not None:
|
| 83 |
+
text_encoder_t = text_encoder_t.to(load_dtype)
|
| 84 |
+
unet_t = unet_t.to(load_dtype)
|
| 85 |
+
|
| 86 |
+
model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
|
| 87 |
+
else:
|
| 88 |
+
device_org = load_original_model_to if load_original_model_to else "cpu"
|
| 89 |
+
device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
|
| 90 |
+
|
| 91 |
+
logger.info(f"loading original SDXL model : {model_org}")
|
| 92 |
+
text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
| 93 |
+
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
|
| 94 |
+
)
|
| 95 |
+
text_encoders_o = [text_encoder_o1, text_encoder_o2]
|
| 96 |
+
if load_dtype is not None:
|
| 97 |
+
text_encoder_o1 = text_encoder_o1.to(load_dtype)
|
| 98 |
+
text_encoder_o2 = text_encoder_o2.to(load_dtype)
|
| 99 |
+
unet_o = unet_o.to(load_dtype)
|
| 100 |
+
|
| 101 |
+
logger.info(f"loading original SDXL model : {model_tuned}")
|
| 102 |
+
text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
|
| 103 |
+
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
|
| 104 |
+
)
|
| 105 |
+
text_encoders_t = [text_encoder_t1, text_encoder_t2]
|
| 106 |
+
if load_dtype is not None:
|
| 107 |
+
text_encoder_t1 = text_encoder_t1.to(load_dtype)
|
| 108 |
+
text_encoder_t2 = text_encoder_t2.to(load_dtype)
|
| 109 |
+
unet_t = unet_t.to(load_dtype)
|
| 110 |
+
|
| 111 |
+
model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
|
| 112 |
+
|
| 113 |
+
# create LoRA network to extract weights: Use dim (rank) as alpha
|
| 114 |
+
if conv_dim is None:
|
| 115 |
+
kwargs = {}
|
| 116 |
+
else:
|
| 117 |
+
kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
|
| 118 |
+
|
| 119 |
+
lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
|
| 120 |
+
lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
|
| 121 |
+
assert len(lora_network_o.text_encoder_loras) == len(
|
| 122 |
+
lora_network_t.text_encoder_loras
|
| 123 |
+
), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
| 124 |
+
|
| 125 |
+
# get diffs
|
| 126 |
+
diffs = {}
|
| 127 |
+
text_encoder_different = False
|
| 128 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
| 129 |
+
lora_name = lora_o.lora_name
|
| 130 |
+
module_o = lora_o.org_module
|
| 131 |
+
module_t = lora_t.org_module
|
| 132 |
+
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
| 133 |
+
|
| 134 |
+
# clear weight to save memory
|
| 135 |
+
module_o.weight = None
|
| 136 |
+
module_t.weight = None
|
| 137 |
+
|
| 138 |
+
# Text Encoder might be same
|
| 139 |
+
if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
|
| 140 |
+
text_encoder_different = True
|
| 141 |
+
logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
|
| 142 |
+
|
| 143 |
+
diffs[lora_name] = diff
|
| 144 |
+
|
| 145 |
+
# clear target Text Encoder to save memory
|
| 146 |
+
for text_encoder in text_encoders_t:
|
| 147 |
+
del text_encoder
|
| 148 |
+
|
| 149 |
+
if not text_encoder_different:
|
| 150 |
+
logger.warning("Text encoder is same. Extract U-Net only.")
|
| 151 |
+
lora_network_o.text_encoder_loras = []
|
| 152 |
+
diffs = {} # clear diffs
|
| 153 |
+
|
| 154 |
+
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
| 155 |
+
lora_name = lora_o.lora_name
|
| 156 |
+
module_o = lora_o.org_module
|
| 157 |
+
module_t = lora_t.org_module
|
| 158 |
+
diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
|
| 159 |
+
|
| 160 |
+
# clear weight to save memory
|
| 161 |
+
module_o.weight = None
|
| 162 |
+
module_t.weight = None
|
| 163 |
+
|
| 164 |
+
diffs[lora_name] = diff
|
| 165 |
+
|
| 166 |
+
# clear LoRA network, target U-Net to save memory
|
| 167 |
+
del lora_network_o
|
| 168 |
+
del lora_network_t
|
| 169 |
+
del unet_t
|
| 170 |
+
|
| 171 |
+
# make LoRA with svd
|
| 172 |
+
logger.info("calculating by svd")
|
| 173 |
+
lora_weights = {}
|
| 174 |
+
with torch.no_grad():
|
| 175 |
+
for lora_name, mat in tqdm(list(diffs.items())):
|
| 176 |
+
if args.device:
|
| 177 |
+
mat = mat.to(args.device)
|
| 178 |
+
mat = mat.to(torch.float) # calc by float
|
| 179 |
+
|
| 180 |
+
# if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
| 181 |
+
conv2d = len(mat.size()) == 4
|
| 182 |
+
kernel_size = None if not conv2d else mat.size()[2:4]
|
| 183 |
+
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
| 184 |
+
|
| 185 |
+
rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
|
| 186 |
+
out_dim, in_dim = mat.size()[0:2]
|
| 187 |
+
|
| 188 |
+
if device:
|
| 189 |
+
mat = mat.to(device)
|
| 190 |
+
|
| 191 |
+
# logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
|
| 192 |
+
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
| 193 |
+
|
| 194 |
+
if conv2d:
|
| 195 |
+
if conv2d_3x3:
|
| 196 |
+
mat = mat.flatten(start_dim=1)
|
| 197 |
+
else:
|
| 198 |
+
mat = mat.squeeze()
|
| 199 |
+
|
| 200 |
+
U, S, Vh = torch.linalg.svd(mat)
|
| 201 |
+
|
| 202 |
+
U = U[:, :rank]
|
| 203 |
+
S = S[:rank]
|
| 204 |
+
U = U @ torch.diag(S)
|
| 205 |
+
|
| 206 |
+
Vh = Vh[:rank, :]
|
| 207 |
+
|
| 208 |
+
dist = torch.cat([U.flatten(), Vh.flatten()])
|
| 209 |
+
hi_val = torch.quantile(dist, clamp_quantile)
|
| 210 |
+
low_val = -hi_val
|
| 211 |
+
|
| 212 |
+
U = U.clamp(low_val, hi_val)
|
| 213 |
+
Vh = Vh.clamp(low_val, hi_val)
|
| 214 |
+
|
| 215 |
+
if conv2d:
|
| 216 |
+
U = U.reshape(out_dim, rank, 1, 1)
|
| 217 |
+
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
| 218 |
+
|
| 219 |
+
U = U.to(work_device, dtype=save_dtype).contiguous()
|
| 220 |
+
Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
|
| 221 |
+
|
| 222 |
+
lora_weights[lora_name] = (U, Vh)
|
| 223 |
+
|
| 224 |
+
# make state dict for LoRA
|
| 225 |
+
lora_sd = {}
|
| 226 |
+
for lora_name, (up_weight, down_weight) in lora_weights.items():
|
| 227 |
+
lora_sd[lora_name + ".lora_up.weight"] = up_weight
|
| 228 |
+
lora_sd[lora_name + ".lora_down.weight"] = down_weight
|
| 229 |
+
lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
|
| 230 |
+
|
| 231 |
+
# load state dict to LoRA and save it
|
| 232 |
+
lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
|
| 233 |
+
lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
|
| 234 |
+
|
| 235 |
+
info = lora_network_save.load_state_dict(lora_sd)
|
| 236 |
+
logger.info(f"Loading extracted LoRA weights: {info}")
|
| 237 |
+
|
| 238 |
+
dir_name = os.path.dirname(save_to)
|
| 239 |
+
if dir_name and not os.path.exists(dir_name):
|
| 240 |
+
os.makedirs(dir_name, exist_ok=True)
|
| 241 |
+
|
| 242 |
+
# minimum metadata
|
| 243 |
+
net_kwargs = {}
|
| 244 |
+
if conv_dim is not None:
|
| 245 |
+
net_kwargs["conv_dim"] = str(conv_dim)
|
| 246 |
+
net_kwargs["conv_alpha"] = str(float(conv_dim))
|
| 247 |
+
|
| 248 |
+
metadata = {
|
| 249 |
+
"ss_v2": str(v2),
|
| 250 |
+
"ss_base_model_version": model_version,
|
| 251 |
+
"ss_network_module": "networks.lora",
|
| 252 |
+
"ss_network_dim": str(dim),
|
| 253 |
+
"ss_network_alpha": str(float(dim)),
|
| 254 |
+
"ss_network_args": json.dumps(net_kwargs),
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
if not no_metadata:
|
| 258 |
+
title = os.path.splitext(os.path.basename(save_to))[0]
|
| 259 |
+
sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
|
| 260 |
+
metadata.update(sai_metadata)
|
| 261 |
+
|
| 262 |
+
lora_network_save.save_weights(save_to, save_dtype, metadata)
|
| 263 |
+
logger.info(f"LoRA weights are saved to: {save_to}")
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 267 |
+
parser = argparse.ArgumentParser()
|
| 268 |
+
parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
|
| 269 |
+
parser.add_argument(
|
| 270 |
+
"--v_parameterization",
|
| 271 |
+
action="store_true",
|
| 272 |
+
default=None,
|
| 273 |
+
help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
|
| 274 |
+
)
|
| 275 |
+
parser.add_argument(
|
| 276 |
+
"--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
|
| 277 |
+
)
|
| 278 |
+
parser.add_argument(
|
| 279 |
+
"--load_precision",
|
| 280 |
+
type=str,
|
| 281 |
+
default=None,
|
| 282 |
+
choices=[None, "float", "fp16", "bf16"],
|
| 283 |
+
help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
|
| 284 |
+
)
|
| 285 |
+
parser.add_argument(
|
| 286 |
+
"--save_precision",
|
| 287 |
+
type=str,
|
| 288 |
+
default=None,
|
| 289 |
+
choices=[None, "float", "fp16", "bf16"],
|
| 290 |
+
help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
|
| 291 |
+
)
|
| 292 |
+
parser.add_argument(
|
| 293 |
+
"--model_org",
|
| 294 |
+
type=str,
|
| 295 |
+
default=None,
|
| 296 |
+
required=True,
|
| 297 |
+
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
|
| 298 |
+
)
|
| 299 |
+
parser.add_argument(
|
| 300 |
+
"--model_tuned",
|
| 301 |
+
type=str,
|
| 302 |
+
default=None,
|
| 303 |
+
required=True,
|
| 304 |
+
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
|
| 305 |
+
)
|
| 306 |
+
parser.add_argument(
|
| 307 |
+
"--save_to",
|
| 308 |
+
type=str,
|
| 309 |
+
default=None,
|
| 310 |
+
required=True,
|
| 311 |
+
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
|
| 312 |
+
)
|
| 313 |
+
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
| 314 |
+
parser.add_argument(
|
| 315 |
+
"--conv_dim",
|
| 316 |
+
type=int,
|
| 317 |
+
default=None,
|
| 318 |
+
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
|
| 319 |
+
)
|
| 320 |
+
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
| 321 |
+
parser.add_argument(
|
| 322 |
+
"--clamp_quantile",
|
| 323 |
+
type=float,
|
| 324 |
+
default=0.99,
|
| 325 |
+
help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
|
| 326 |
+
)
|
| 327 |
+
parser.add_argument(
|
| 328 |
+
"--min_diff",
|
| 329 |
+
type=float,
|
| 330 |
+
default=0.01,
|
| 331 |
+
help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
|
| 332 |
+
+ "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
|
| 333 |
+
)
|
| 334 |
+
parser.add_argument(
|
| 335 |
+
"--no_metadata",
|
| 336 |
+
action="store_true",
|
| 337 |
+
help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
|
| 338 |
+
+ "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
|
| 339 |
+
)
|
| 340 |
+
parser.add_argument(
|
| 341 |
+
"--load_original_model_to",
|
| 342 |
+
type=str,
|
| 343 |
+
default=None,
|
| 344 |
+
help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
| 345 |
+
)
|
| 346 |
+
parser.add_argument(
|
| 347 |
+
"--load_tuned_model_to",
|
| 348 |
+
type=str,
|
| 349 |
+
default=None,
|
| 350 |
+
help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
return parser
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
if __name__ == "__main__":
|
| 357 |
+
parser = setup_parser()
|
| 358 |
+
|
| 359 |
+
args = parser.parse_args()
|
| 360 |
+
svd(**vars(args))
|
fine_tune_README_ja.md
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません)
|
| 2 |
+
|
| 3 |
+
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
| 4 |
+
|
| 5 |
+
# 概要
|
| 6 |
+
|
| 7 |
+
Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。
|
| 8 |
+
|
| 9 |
+
* CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。
|
| 10 |
+
* 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。
|
| 11 |
+
* トークン長を75から225に拡張する。
|
| 12 |
+
* BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。
|
| 13 |
+
* Hypernetworkの学習にも対応する。
|
| 14 |
+
* Stable Diffusion v2.0(baseおよび768/v)に対応。
|
| 15 |
+
* VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。
|
| 16 |
+
|
| 17 |
+
デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。
|
| 18 |
+
|
| 19 |
+
# 追加機能について
|
| 20 |
+
|
| 21 |
+
## CLIPの出力の変更
|
| 22 |
+
|
| 23 |
+
プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。
|
| 24 |
+
元のまま、最後の層の出力を用いることも可能です。
|
| 25 |
+
|
| 26 |
+
※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。
|
| 27 |
+
|
| 28 |
+
## 正方形以外の解像度での学習
|
| 29 |
+
|
| 30 |
+
Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。
|
| 31 |
+
学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。
|
| 32 |
+
|
| 33 |
+
機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
|
| 34 |
+
|
| 35 |
+
## トークン長の75から225への拡張
|
| 36 |
+
|
| 37 |
+
Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。
|
| 38 |
+
ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。
|
| 39 |
+
|
| 40 |
+
※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。
|
| 41 |
+
|
| 42 |
+
※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。
|
| 43 |
+
|
| 44 |
+
# 学習の手順
|
| 45 |
+
|
| 46 |
+
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
| 47 |
+
|
| 48 |
+
## データの準備
|
| 49 |
+
|
| 50 |
+
[学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。
|
| 51 |
+
|
| 52 |
+
## 学習の実行
|
| 53 |
+
たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
| 57 |
+
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
| 58 |
+
--output_dir=<学習したモデルの出力先フォルダ>
|
| 59 |
+
--output_name=<学習したモデル出力時のファイル名>
|
| 60 |
+
--dataset_config=<データ準備で作成した.tomlファイル>
|
| 61 |
+
--save_model_as=safetensors
|
| 62 |
+
--learning_rate=5e-6 --max_train_steps=10000
|
| 63 |
+
--use_8bit_adam --xformers --gradient_checkpointing
|
| 64 |
+
--mixed_precision=fp16
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
`num_cpu_threads_per_process` には通常は1を指定するとよいようです。
|
| 68 |
+
|
| 69 |
+
`pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
|
| 70 |
+
|
| 71 |
+
`output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
|
| 72 |
+
|
| 73 |
+
`dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
|
| 74 |
+
|
| 75 |
+
学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。
|
| 76 |
+
|
| 77 |
+
省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
|
| 78 |
+
|
| 79 |
+
オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
|
| 80 |
+
|
| 81 |
+
`xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
|
| 82 |
+
|
| 83 |
+
ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。
|
| 84 |
+
|
| 85 |
+
### よく使われるオプションについて
|
| 86 |
+
|
| 87 |
+
以下の場合にはオプションに関するドキュメントを参照してください。
|
| 88 |
+
|
| 89 |
+
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
| 90 |
+
- clip skipを2以上を前提としたモデルを学習する
|
| 91 |
+
- 75トークンを超えたキャプションで学習する
|
| 92 |
+
|
| 93 |
+
### バッチサイズについて
|
| 94 |
+
|
| 95 |
+
モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。
|
| 96 |
+
|
| 97 |
+
### 学習率について
|
| 98 |
+
|
| 99 |
+
1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。
|
| 100 |
+
|
| 101 |
+
### 以前の形式のデータセット指定をした場合のコマンドライン
|
| 102 |
+
|
| 103 |
+
解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
|
| 104 |
+
|
| 105 |
+
```
|
| 106 |
+
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
| 107 |
+
--pretrained_model_name_or_path=model.ckpt
|
| 108 |
+
--in_json meta_lat.json
|
| 109 |
+
--train_data_dir=train_data
|
| 110 |
+
--output_dir=fine_tuned
|
| 111 |
+
--shuffle_caption
|
| 112 |
+
--train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
|
| 113 |
+
--use_8bit_adam --xformers --gradient_checkpointing
|
| 114 |
+
--mixed_precision=bf16
|
| 115 |
+
--save_every_n_epochs=4
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
<!--
|
| 119 |
+
### 勾配をfp16とした学習(実験的機能)
|
| 120 |
+
full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。これによりSD1.xの512*512サイズでは8GB未満、SD2.xの512*512サイズで12GB未満のVRAM使用量で学習できるようです。
|
| 121 |
+
|
| 122 |
+
あらかじめaccelerate configでfp16を指定し、オプションでmixed_precision="fp16"としてください(bf16では動作しません)。
|
| 123 |
+
|
| 124 |
+
メモリ使用量を最小化するためには、xformers、use_8bit_adam、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
|
| 125 |
+
(余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
|
| 126 |
+
|
| 127 |
+
PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
|
| 128 |
+
-->
|
| 129 |
+
|
| 130 |
+
# fine tuning特有のその他の主なオプション
|
| 131 |
+
|
| 132 |
+
すべてのオプションについては別文書を参照してください。
|
| 133 |
+
|
| 134 |
+
## `train_text_encoder`
|
| 135 |
+
Text Encoderも学習対象とします。メモリ使用量が若干増加します。
|
| 136 |
+
|
| 137 |
+
通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。
|
| 138 |
+
|
| 139 |
+
## `diffusers_xformers`
|
| 140 |
+
スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。
|
gen_img_README-ja.md
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、ControlNet(v1.0のみ動作確認)などに対応した、Diffusersベースの推論(画像生成)スクリプトです。コマンドラインから用います。
|
| 2 |
+
|
| 3 |
+
# 概要
|
| 4 |
+
|
| 5 |
+
* Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。
|
| 6 |
+
* SD 1.xおよび2.x (base/v-parameterization)モデルに対応。
|
| 7 |
+
* txt2img、img2img、inpaintingに対応。
|
| 8 |
+
* 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。
|
| 9 |
+
* プロンプト1行あたりの生成枚数を指定可能。
|
| 10 |
+
* 全体の繰り返し回数を指定可能。
|
| 11 |
+
* `fp16`だけでなく`bf16`にも対応。
|
| 12 |
+
* xformersに対応し高速生成が可能。
|
| 13 |
+
* xformersにより省メモリ生成を行いますが、Automatic 1111氏のWeb UIほど最適化していないため、512*512の画像生成でおおむね6GB程度のVRAMを使用します。
|
| 14 |
+
* プロンプトの225トークンへの拡張。ネガティブプロンプト、重みづけに対応。
|
| 15 |
+
* Diffusersの各種samplerに対応(Web UIよりもsampler数は少ないです)。
|
| 16 |
+
* Text Encoderのclip skip(最後からn番目の層の出力を用いる)に対応。
|
| 17 |
+
* VAEの別途読み込み。
|
| 18 |
+
* CLIP Guided Stable Diffusion、VGG16 Guided Stable Diffusion、Highres. fix、upscale対応。
|
| 19 |
+
* Highres. fixはWeb UIの実装を全く確認していない独自実装のため、出力結果は異なるかもしれません。
|
| 20 |
+
* LoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
|
| 21 |
+
* Text EncoderとU-Netで別の適用率を指定することはできません。
|
| 22 |
+
* Attention Coupleに対応。
|
| 23 |
+
* ControlNet v1.0に対応。
|
| 24 |
+
* 途中でモデルを切り替えることはできませんが、バッチファイルを組むことで対応できます。
|
| 25 |
+
* 個人的に欲しくなった機能をいろいろ追加。
|
| 26 |
+
|
| 27 |
+
機能追加時にすべてのテストを行っているわけではないため、以前の機能に影響が出て一部機能が動かない可能性があります。何か問題があればお知らせください。
|
| 28 |
+
|
| 29 |
+
# 基本的な使い方
|
| 30 |
+
|
| 31 |
+
## 対話モードでの画像生成
|
| 32 |
+
|
| 33 |
+
以下のように入力してください。
|
| 34 |
+
|
| 35 |
+
```batchfile
|
| 36 |
+
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
`--ckpt`オプションにモデル(Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ)、`--outdir`オプションに画像の出力先フォルダを指定します。
|
| 40 |
+
|
| 41 |
+
`--xformers`オプションでxformersの使用を指定します(xformersを使わない場合は外してください)。`--fp16`オプションでfp16(単精度)での推論を行います。RTX 30系のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
|
| 42 |
+
|
| 43 |
+
`--interactive`オプションで対話モードを指定しています。
|
| 44 |
+
|
| 45 |
+
Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う場合は`--v2`オプションを追加してください。v-parameterizationを使うモデル(`768-v-ema.ckpt`およびそこからの追加学習モデル)を使う場合はさらに`--v_parameterization`を追加してください。
|
| 46 |
+
|
| 47 |
+
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
| 48 |
+
|
| 49 |
+
`Type prompt:`と表示されたらプロンプトを入力してください。
|
| 50 |
+
|
| 51 |
+

|
| 52 |
+
|
| 53 |
+
※画像が表示されずエラーになる場合、headless(画面表示機能なし)のOpenCVがインストールされているかもしれません。`pip install opencv-python`として通常のOpenCVを入れてください。または`--no_preview`オプションで画像表示を止めてください。
|
| 54 |
+
|
| 55 |
+
画像ウィンドウを選択してから何らかのキーを押すとウィンドウが閉じ、次のプロンプトが入力できます。プロンプトでCtrl+Z、エンターの順に打鍵するとスクリプトを閉じます。
|
| 56 |
+
|
| 57 |
+
## 単一のプロンプトで画像を一括生成
|
| 58 |
+
|
| 59 |
+
以下のように入力します(実際には1行で入力します)。
|
| 60 |
+
|
| 61 |
+
```batchfile
|
| 62 |
+
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
| 63 |
+
--xformers --fp16 --images_per_prompt <生成枚数> --prompt "<プロンプト>"
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
`--images_per_prompt`オプションで、プロンプト1件当たりの生成枚数を指定します。`--prompt`オプションでプロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
|
| 67 |
+
|
| 68 |
+
`--batch_size`オプションでバッチサイズを指定できます(後述)。
|
| 69 |
+
|
| 70 |
+
## ファイルからプロンプトを読み込み一括生成
|
| 71 |
+
|
| 72 |
+
以下のように入力します。
|
| 73 |
+
|
| 74 |
+
```batchfile
|
| 75 |
+
python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
|
| 76 |
+
--xformers --fp16 --from_file <プロンプトファイル名>
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
`--from_file`オプションで、プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。`--images_per_prompt`オプションを指定して1行あたり生成枚数を指定できます。
|
| 80 |
+
|
| 81 |
+
## ネガティブプロンプト、重みづけの使用
|
| 82 |
+
|
| 83 |
+
プロンプトオプション(プロンプト内で`--x`のように指定、後述)で`--n`を書くと、以降がネガティブプロンプトとなります。
|
| 84 |
+
|
| 85 |
+
またAUTOMATIC1111氏のWeb UIと同様の `()` や` []` 、`(xxx:1.3)` などによる重みづけが可能です(実装はDiffusersの[Long Prompt Weighting Stable Diffusion](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#long-prompt-weighting-stable-diffusion)からコピーしたものです)。
|
| 86 |
+
|
| 87 |
+
コマンドラインからのプロンプト指定、ファイルからのプロンプト読み込みでも同様に指定できます。
|
| 88 |
+
|
| 89 |
+

|
| 90 |
+
|
| 91 |
+
# 主なオプション
|
| 92 |
+
|
| 93 |
+
コマンドラインから指定してください。
|
| 94 |
+
|
| 95 |
+
## モデルの指定
|
| 96 |
+
|
| 97 |
+
- `--ckpt <モデル名>`:モデル名を指定します。`--ckpt`オプションは必須です。Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ、Hugging FaceのモデルIDを指定できます。
|
| 98 |
+
|
| 99 |
+
- `--v2`:Stable Diffusion 2.x系のモデルを使う場合に指定します。1.x系の場合には指定不要です。
|
| 100 |
+
|
| 101 |
+
- `--v_parameterization`:v-parameterizationを使うモデルを使う場合に指定します(`768-v-ema.ckpt`およびそこからの追加学習モデル、Waifu Diffusion v1.5など)。
|
| 102 |
+
|
| 103 |
+
`--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
|
| 104 |
+
|
| 105 |
+
- `--vae`:使用するVAEを指定します。未指定時はモデル内のVAEを使用します。
|
| 106 |
+
|
| 107 |
+
## 画像生成と出力
|
| 108 |
+
|
| 109 |
+
- `--interactive`:インタラクティブモードで動作します。プロンプトを入力すると画像が生成されます。
|
| 110 |
+
|
| 111 |
+
- `--prompt <プロンプト>`:プロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
|
| 112 |
+
|
| 113 |
+
- `--from_file <プロンプトファイル名>`:プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。なお画像サイズやguidance scaleはプロンプトオプション(後述)で指定できます。
|
| 114 |
+
|
| 115 |
+
- `--W <画像幅>`:画像の幅を指定します。デフォルトは`512`です。
|
| 116 |
+
|
| 117 |
+
- `--H <画像高さ>`:画像の高さを指定します。デフォルトは`512`です。
|
| 118 |
+
|
| 119 |
+
- `--steps <ステップ数>`:サンプリングステップ数を指定します。デフォルトは`50`です。
|
| 120 |
+
|
| 121 |
+
- `--scale <ガイダンススケール>`:unconditionalガイダンススケールを指定します。デフォルトは`7.5`です。
|
| 122 |
+
|
| 123 |
+
- `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。Diffusersで提供されているddim、pndm、dpmsolver、dpmsolver+++、lms、euler、euler_a、が指定可能です(後ろの三つはk_lms、k_euler、k_euler_aでも指定できます)。
|
| 124 |
+
|
| 125 |
+
- `--outdir <画像出力先フォルダ>`:画像の出力先を指定します。
|
| 126 |
+
|
| 127 |
+
- `--images_per_prompt <生成枚数>`:プロンプト1件当たりの生成枚数を指定します。デフォルトは`1`です。
|
| 128 |
+
|
| 129 |
+
- `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。省略時は最後の層を使います。
|
| 130 |
+
|
| 131 |
+
- `--max_embeddings_multiples <倍数>`:CLIPの入出力長をデフォルト(75)の何倍にするかを指定します。未指定時は75のままです。たとえば3を指定すると入出力長が225になります。
|
| 132 |
+
|
| 133 |
+
- `--negative_scale` : uncoditioningのguidance scaleを個別に指定します。[gcem156氏のこちらの記事](https://note.com/gcem156/n/ne9a53e4a6f43)を参考に実装したものです。
|
| 134 |
+
|
| 135 |
+
## メモリ使用量や生成速度の調整
|
| 136 |
+
|
| 137 |
+
- `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。
|
| 138 |
+
|
| 139 |
+
- `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。
|
| 140 |
+
VAEのほうがメモリを多く消費するため、デノイジング後(stepが100%になった後)でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてくだ���い。
|
| 141 |
+
|
| 142 |
+
- `--xformers`:xformersを使う場合に指定します。
|
| 143 |
+
|
| 144 |
+
- `--fp16`:fp16(単精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
|
| 145 |
+
|
| 146 |
+
- `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
|
| 147 |
+
|
| 148 |
+
## 追加ネットワーク(LoRA等)の使用
|
| 149 |
+
|
| 150 |
+
- `--network_module`:使用する追加ネットワークを指定します。LoRAの場合は`--network_module networks.lora`と指定します。複数のLoRAを使用する場合は`--network_module networks.lora networks.lora networks.lora`のように指定します。
|
| 151 |
+
|
| 152 |
+
- `--network_weights`:使用する追加ネットワークの重みファイルを指定します。`--network_weights model.safetensors`のように指定します。複数のLoRAを使用する場合は`--network_weights model1.safetensors model2.safetensors model3.safetensors`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
|
| 153 |
+
|
| 154 |
+
- `--network_mul`:使用する追加ネットワークの重みを何倍にするかを指定します。デフォルトは`1`です。`--network_mul 0.8`のように指定します。複数のLoRAを使用する場合は`--network_mul 0.4 0.5 0.7`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
|
| 155 |
+
|
| 156 |
+
- `--network_merge`:使用する追加ネットワークの重みを`--network_mul`に指定した重みであらかじめマージします。`--network_pre_calc` と同時に使用できません。プロンプトオプションの`--am`、およびRegional LoRAは使用できなくなりますが、LoRA未使用時と同じ程度まで生成が高速化されます。
|
| 157 |
+
|
| 158 |
+
- `--network_pre_calc`:使用する追加ネットワークの重みを生成ごとにあらかじめ計算します。プロンプトオプションの`--am`が使用できます。LoRA未使用時と同じ程度まで生成は高速化されますが、生成前に重みを計算する時間が必要で、またメモリ使用量も若干増加します。Regional LoRA使用時は無効になります 。
|
| 159 |
+
|
| 160 |
+
# 主なオプションの指定例
|
| 161 |
+
|
| 162 |
+
次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。
|
| 163 |
+
|
| 164 |
+
```batchfile
|
| 165 |
+
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
| 166 |
+
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
| 167 |
+
--steps 32 --batch_size 4 --images_per_prompt 64
|
| 168 |
+
--prompt "beautiful flowers --n monochrome"
|
| 169 |
+
```
|
| 170 |
+
|
| 171 |
+
次はファイルに書かれたプロンプトを、それぞれ10枚ずつ、バッチサイズ4で一括生成する例です。
|
| 172 |
+
|
| 173 |
+
```batchfile
|
| 174 |
+
python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
|
| 175 |
+
--xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
|
| 176 |
+
--steps 32 --batch_size 4 --images_per_prompt 10
|
| 177 |
+
--from_file prompts.txt
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
Textual Inversion(後述)およびLoRAの使用例です。
|
| 181 |
+
|
| 182 |
+
```batchfile
|
| 183 |
+
python gen_img_diffusers.py --ckpt model.safetensors
|
| 184 |
+
--scale 8 --steps 48 --outdir txt2img --xformers
|
| 185 |
+
--W 512 --H 768 --fp16 --sampler k_euler_a
|
| 186 |
+
--textual_inversion_embeddings goodembed.safetensors negprompt.pt
|
| 187 |
+
--network_module networks.lora networks.lora
|
| 188 |
+
--network_weights model1.safetensors model2.safetensors
|
| 189 |
+
--network_mul 0.4 0.8
|
| 190 |
+
--clip_skip 2 --max_embeddings_multiples 1
|
| 191 |
+
--batch_size 8 --images_per_prompt 1 --interactive
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
# プロンプトオプション
|
| 195 |
+
|
| 196 |
+
プロンプト内で、`--n`のように「ハイフンふたつ+アルファベットn文字」でプロンプトから各種オプションの指定が可能です。対話モード、コマンドライン、ファイル、いずれからプロンプトを指定する場合でも有効です。
|
| 197 |
+
|
| 198 |
+
プロンプトのオプション指定`--n`の前後にはスペースを入れてください。
|
| 199 |
+
|
| 200 |
+
- `--n`:ネガティブプロンプトを指定します。
|
| 201 |
+
|
| 202 |
+
- `--w`:画像幅を指定します。コマンドラインからの指定を上書きします。
|
| 203 |
+
|
| 204 |
+
- `--h`:画像高さを指定します。コマンドラインからの指定を上書きします。
|
| 205 |
+
|
| 206 |
+
- `--s`:ステップ数を指定します。コマンドラインからの指定を上書きします。
|
| 207 |
+
|
| 208 |
+
- `--d`:この画像の乱数seedを指定します。`--images_per_prompt`を指定している場合は「--d 1,2,3,4」のようにカンマ区切りで複数指定してください。
|
| 209 |
+
※様々な理由により、Web UIとは同じ乱数seedでも生成される画像が異なる場合があります。
|
| 210 |
+
|
| 211 |
+
- `--l`:guidance scaleを指定します。コマンドラインからの指定を上書きします。
|
| 212 |
+
|
| 213 |
+
- `--t`:img2img(後述)のstrengthを指定します。コ��ンドラインからの指定を上書きします。
|
| 214 |
+
|
| 215 |
+
- `--nl`:ネガティブプロンプトのguidance scaleを指定します(後述)。コマンドラインからの指定を上書きします。
|
| 216 |
+
|
| 217 |
+
- `--am`:追加ネットワークの重みを指定します。コマンドラインからの指定を上書きします。複数の追加ネットワークを使用する場合は`--am 0.8,0.5,0.3`のように __カンマ区切りで__ 指定します。
|
| 218 |
+
|
| 219 |
+
※これらのオプションを指定すると、バッチサイズよりも小さいサイズでバッチが実行される場合があります(これらの値が異なると一括生成できないため)。(あまり気にしなくて大丈夫ですが、ファイルからプロンプトを読み込み生成する場合は、これらの値が同一のプロンプトを並べておくと効率が良くなります。)
|
| 220 |
+
|
| 221 |
+
例:
|
| 222 |
+
```
|
| 223 |
+
(masterpiece, best quality), 1girl, in shirt and plated skirt, standing at street under cherry blossoms, upper body, [from below], kind smile, looking at another, [goodembed] --n realistic, real life, (negprompt), (lowres:1.1), (worst quality:1.2), (low quality:1.1), bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, normal quality, jpeg artifacts, signature, watermark, username, blurry --w 960 --h 640 --s 28 --d 1
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+

|
| 227 |
+
|
| 228 |
+
# img2img
|
| 229 |
+
|
| 230 |
+
## オプション
|
| 231 |
+
|
| 232 |
+
- `--image_path`:img2imgに利用する画像を指定します。`--image_path template.png`のように指定します。フォルダを指定すると、そのフォルダの画像を順次利用します。
|
| 233 |
+
|
| 234 |
+
- `--strength`:img2imgのstrengthを指定します。`--strength 0.8`のように指定します。デフォルトは`0.8`です。
|
| 235 |
+
|
| 236 |
+
- `--sequential_file_name`:ファイル名を連番にするかどうかを指定します。指定すると生成されるファイル名が`im_000001.png`からの連番になります。
|
| 237 |
+
|
| 238 |
+
- `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名と同じになります。
|
| 239 |
+
|
| 240 |
+
## コマンドラインからの実行例
|
| 241 |
+
|
| 242 |
+
```batchfile
|
| 243 |
+
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
| 244 |
+
--outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32
|
| 245 |
+
--image_path template.png --strength 0.8
|
| 246 |
+
--prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes,
|
| 247 |
+
sailor school uniform, outdoors
|
| 248 |
+
--n lowres, bad anatomy, bad hands, error, missing fingers, cropped,
|
| 249 |
+
worst quality, low quality, normal quality, jpeg artifacts, (blurry),
|
| 250 |
+
hair ornament, glasses"
|
| 251 |
+
--batch_size 8 --images_per_prompt 32
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
`--image_path`オプションにフォルダを指定すると、そのフォルダの画像を順次読み込みます。生成される枚数は画像枚数ではなく、プロンプト数になりますので、`--images_per_promptPPオプションを指定してimg2imgする画像の枚数とプロンプト数を合わせてください。
|
| 255 |
+
|
| 256 |
+
ファイルはファイル名でソートして読み込みます。なおソート順は文字列順となりますので(`1.jpg→2.jpg→10.jpg`ではなく`1.jpg→10.jpg→2.jpg`の順)、頭を0埋めするなどしてご対応ください(`01.jpg→02.jpg→10.jpg`)。
|
| 257 |
+
|
| 258 |
+
## img2imgを利用したupscale
|
| 259 |
+
|
| 260 |
+
img2img時にコマンドラインオプションの`--W`と`--H`で生成画像サイズを指定すると、元画像をそのサイズにリサイズしてからimg2imgを行います。
|
| 261 |
+
|
| 262 |
+
またimg2imgの元画像がこのスクリプトで生成した画像の場合、プロンプトを省略すると、元画像のメタデータからプロンプトを取得しそのまま用います。これによりHighres. fixの2nd stageの動作だけを行うことができます。
|
| 263 |
+
|
| 264 |
+
## img2img時のinpainting
|
| 265 |
+
|
| 266 |
+
画像およびマスク画像を指定してinpaintingできます(inpaintingモデルには対応しておらず、単にマスク領域を対象にimg2imgするだけです)。
|
| 267 |
+
|
| 268 |
+
オプションは以下の通りです。
|
| 269 |
+
|
| 270 |
+
- `--mask_image`:マスク画像を指定します。`--img_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。
|
| 271 |
+
|
| 272 |
+
マスク画像はグレースケール画像で、白の部分がinpaintingされます。境界をグラデーションしておくとなんとなく滑らかになりますのでお勧めです。
|
| 273 |
+
|
| 274 |
+

|
| 275 |
+
|
| 276 |
+
# その他の機能
|
| 277 |
+
|
| 278 |
+
## Textual Inversion
|
| 279 |
+
|
| 280 |
+
`--textual_inversion_embeddings`オプションで使用するembeddingsを指定します(複数指定可)。拡張子を除いたファイル名をプロンプト内で使用することで、そのembeddingsを利用します(Web UIと同様の使用法です)。ネガティブプロンプト内でも使用できます。
|
| 281 |
+
|
| 282 |
+
モデルとして、当リポジトリで学習したTextual Inversionモデル、およびWeb UIで学習したTextual Inversionモデル(画像埋め込みは非対応)を利用できます
|
| 283 |
+
|
| 284 |
+
## Extended Textual Inversion
|
| 285 |
+
|
| 286 |
+
`--textual_inversion_embeddings`の代わりに`--XTI_embeddings`オプションを指定してください。使用法は`--textual_inversion_embeddings`と同じです。
|
| 287 |
+
|
| 288 |
+
## Highres. fix
|
| 289 |
+
|
| 290 |
+
AUTOMATIC1111氏のWeb UIにある機能の類似機能です(独自実装のためもしかしたらいろいろ異なるかもしれません)。最初に小さめの画像を生成し、その画像を元にimg2imgすることで、画像全体の破綻を防ぎつつ大きな解像度の画像を生成します。
|
| 291 |
+
|
| 292 |
+
2nd stageのstep数は`--steps` と`--strength`オプションの値から計算されます(`steps*strength`)。
|
| 293 |
+
|
| 294 |
+
img2imgと併用できません。
|
| 295 |
+
|
| 296 |
+
以下のオプションがあります。
|
| 297 |
+
|
| 298 |
+
- `--highres_fix_scale`:Highres. fixを有効にして、1st stageで生成する画像のサイズを、倍率で指定します。最終出力が1024x1024で、最初に512x512の画像を生成する場合は`--highres_fix_scale 0.5`のように指定します。Web UI出の指定の逆数になっていますのでご注意ください。
|
| 299 |
+
|
| 300 |
+
- `--highres_fix_steps`:1st stageの画像のステップ数を指定します。デフォルトは`28`です。
|
| 301 |
+
|
| 302 |
+
- `--highres_fix_save_1st`:1st stageの画像を保存するかどうかを指定します。
|
| 303 |
+
|
| 304 |
+
- `--highres_fix_latents_upscaling`:指定すると2nd stageの画像生成時に1st stageの画像をlatentベースでupscalingします(bilinearのみ対応)。未指定時は画像をLANCZOS4でupscalingします。
|
| 305 |
+
|
| 306 |
+
- `--highres_fix_upscaler`:2nd stageに任意のupscalerを利用します。現在は`--highres_fix_upscaler tools.latent_upscaler` のみ対応しています。
|
| 307 |
+
|
| 308 |
+
- `--highres_fix_upscaler_args`:`--highres_fix_upscaler`で指定したupscalerに渡す引数を指定します。
|
| 309 |
+
`tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。
|
| 310 |
+
|
| 311 |
+
コマンドラインの例です。
|
| 312 |
+
|
| 313 |
+
```batchfile
|
| 314 |
+
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
| 315 |
+
--n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img
|
| 316 |
+
--steps 48 --sampler ddim --fp16
|
| 317 |
+
--xformers
|
| 318 |
+
--images_per_prompt 1 --interactive
|
| 319 |
+
--highres_fix_scale 0.5 --highres_fix_steps 28 --strength 0.5
|
| 320 |
+
```
|
| 321 |
+
|
| 322 |
+
## ControlNet
|
| 323 |
+
|
| 324 |
+
現在はControlNet 1.0のみ動作確認しています。プリプロセスはCannyのみサポートしています。
|
| 325 |
+
|
| 326 |
+
以下のオプションがあります。
|
| 327 |
+
|
| 328 |
+
- `--control_net_models`:ControlNetのモデルファイルを指定します。
|
| 329 |
+
複数指定すると、それらをstepごとに切り替えて利用します(Web UIのControlNet拡張の実装と異なります)。diffと通常の両方をサポートします。
|
| 330 |
+
|
| 331 |
+
- `--guide_image_path`:ControlNetに使うヒント画像を指定します。`--img_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。Canny以外のモデルの場合には、あらかじめプリプロセスを行っておいてください。
|
| 332 |
+
|
| 333 |
+
- `--control_net_preps`:ControlNetのプリプロセスを指定します。`--control_net_models`と同様に複数指定可能です。現在はcannyのみ対応しています。対象モデルでプリプロセスを使用しない場合は `none` を指定します。
|
| 334 |
+
cannyの場合 `--control_net_preps canny_63_191`のように、閾値1と2を'_'で区切って指定できます。
|
| 335 |
+
|
| 336 |
+
- `--control_net_weights`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
|
| 337 |
+
|
| 338 |
+
- `--control_net_ratios`:ControlNetを適用するstepの範囲を指定します。`0.5`の場合は、step数の半分までControlNetを適用します。`--control_net_models`と同様に複数指定可能です。
|
| 339 |
+
|
| 340 |
+
コマンドラインの例です。
|
| 341 |
+
|
| 342 |
+
```batchfile
|
| 343 |
+
python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
|
| 344 |
+
--W 512 --H 768 --bf16 --sampler k_euler_a
|
| 345 |
+
--control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0
|
| 346 |
+
--guide_image_path guide.png --control_net_ratios 1.0 --interactive
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
## Attention Couple + Reginal LoRA
|
| 350 |
+
|
| 351 |
+
プロンプトをいくつかの部分に分割し、それぞれのプロンプトを画像内のどの領域に適用するかを指定できる機能です。個別のオプションはありませんが、`mask_path`とプロンプトで指定します。
|
| 352 |
+
|
| 353 |
+
まず、プロンプトで` AND `を利用して、複数部分を定義します。最初の3つに対して領域指定ができ、以降の部分は画像全体へ適用されます。ネガティブプロンプトは画像全体に適用されます。
|
| 354 |
+
|
| 355 |
+
以下ではANDで3つの部分を定義しています。
|
| 356 |
+
|
| 357 |
+
```
|
| 358 |
+
shs 2girls, looking at viewer, smile AND bsb 2girls, looking back AND 2girls --n bad quality, worst quality
|
| 359 |
+
```
|
| 360 |
+
|
| 361 |
+
次にマスク画像を用意します。マスク画像はカラーの画像で、RGBの各チャネルがプロンプトのANDで区切られた部分に対応します。またあるチャネルの値がすべて0の場合、画像全体に適用されます。
|
| 362 |
+
|
| 363 |
+
上記の例では、Rチャネルが`shs 2girls, looking at viewer, smile`、Gチャネルが`bsb 2girls, looking back`に、Bチャネルが`2girls`に対応します。次のようなマスク画像を使用すると、Bチャネルに指定がありませんので、`2girls`は画像全体に適用されます。
|
| 364 |
+
|
| 365 |
+

|
| 366 |
+
|
| 367 |
+
マスク画像は`--mask_path`で指定します。現在は1枚のみ対応しています。指定した画像サイズに自動的にリサイズされ適用されます。
|
| 368 |
+
|
| 369 |
+
ControlNetと組み合わせることも可能です(細かい位置指定にはControlNetとの組み合わせを推奨します)。
|
| 370 |
+
|
| 371 |
+
LoRAを指定すると、`--network_weights`で指定した複数のLoRAがそれぞれANDの各部分に対応します。現在の制約として、LoRAの数はANDの部分の数と同じである必要があります。
|
| 372 |
+
|
| 373 |
+
## CLIP Guided Stable Diffusion
|
| 374 |
+
|
| 375 |
+
DiffusersのCommunity Examplesの[こちらのcustom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion)からソースをコピー、変更したものです。
|
| 376 |
+
|
| 377 |
+
通常のプロンプトによる生成指定に加えて、追加でより大規模のCLIPでプロンプトのテキストの特徴量を取得し、生成中の画像の特徴量がそのテキストの特徴量に近づくよう、生成される画像をコントロールします(私のざっくりとした理解です)。大きめのCLIPを使いますのでVRAM使用量はかなり増加し(VRAM 8GBでは512*512でも厳しいかもしれません)、生成時間も掛かります。
|
| 378 |
+
|
| 379 |
+
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
| 380 |
+
|
| 381 |
+
`--clip_guidance_scale`オプションにどの程度、CLIPの特徴量を反映するかを数値で指定します。先のサンプルでは100になっていますので、そのあたりから始めて増減すると良いようです。
|
| 382 |
+
|
| 383 |
+
デフォルトではプロンプトの先頭75トークン(重みづけの特殊文字を除く)がCLIPに渡されます。プロンプトの`--c`オプションで、通常のプロンプトではなく、CLIPに渡すテキストを別に指定できます(たとえばCLIPはDreamBoothのidentifier(識別子)や「1girl」などのモデル特有の単語は認識できないと思われますので、それらを省いたテキストが良いと思われます)。
|
| 384 |
+
|
| 385 |
+
コマンドラインの例です。
|
| 386 |
+
|
| 387 |
+
```batchfile
|
| 388 |
+
python gen_img_diffusers.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1
|
| 389 |
+
--scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36
|
| 390 |
+
--sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1
|
| 391 |
+
--interactive --clip_guidance_scale 100
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
## CLIP Image Guided Stable Diffusion
|
| 395 |
+
|
| 396 |
+
テキストではなくCLIPに別の画像を渡し、その特徴量に近づくよう生成をコントロールする機能です。`--clip_image_guidance_scale`オプションで適用量の数値を、`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
| 397 |
+
|
| 398 |
+
コマンドラインの例です。
|
| 399 |
+
|
| 400 |
+
```batchfile
|
| 401 |
+
python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
|
| 402 |
+
--n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img
|
| 403 |
+
--steps 80 --sampler ddim --fp16 --opt_channels_last --xformers
|
| 404 |
+
--images_per_prompt 1 --interactive --clip_image_guidance_scale 100
|
| 405 |
+
--guide_image_path YUKA160113420I9A4104_TP_V.jpg
|
| 406 |
+
```
|
| 407 |
+
|
| 408 |
+
### VGG16 Guided Stable Diffusion
|
| 409 |
+
|
| 410 |
+
指定した画像に近づくように画像生成する機能です。通常のプロンプトによる生成指定に加えて、追加でVGG16の特徴量を取得し、生成中の画像が指定したガイド画像に近づくよう、生成される画像をコントロールします。img2imgでの使用をお勧めします(通常の生成では画像がぼやけた感じになります)。CLIP Guided Stable Diffusionの仕組みを流用した独自の機能です。またアイデアはVGGを利用したスタイル変換から拝借しています。
|
| 411 |
+
|
| 412 |
+
なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
|
| 413 |
+
|
| 414 |
+
`--vgg16_guidance_scale`オプションにどの程度、VGG16特徴量を反映するかを数値で指定します。試した感じでは100くらいから始めて増減すると良いようです。`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
|
| 415 |
+
|
| 416 |
+
複数枚の画像を一括でimg2img変換し、元画像をガイド画像とする場合、`--guide_image_path`と`--image_path`に同じ値を指定すればOKです。
|
| 417 |
+
|
| 418 |
+
コマンドラインの例です。
|
| 419 |
+
|
| 420 |
+
```batchfile
|
| 421 |
+
python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
|
| 422 |
+
--n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img
|
| 423 |
+
--xformers --sampler ddim --fp16 --W 512 --H 704
|
| 424 |
+
--batch_size 1 --images_per_prompt 1
|
| 425 |
+
--prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face
|
| 426 |
+
--n lowres, bad anatomy, bad hands, error, missing fingers,
|
| 427 |
+
cropped, worst quality, low quality, normal quality,
|
| 428 |
+
jpeg artifacts, blurry, 3d, bad face, monochrome --d 1"
|
| 429 |
+
--strength 0.8 --image_path ..\src_image
|
| 430 |
+
--vgg16_guidance_scale 100 --guide_image_path ..\src_image
|
| 431 |
+
```
|
| 432 |
+
|
| 433 |
+
`--vgg16_guidance_layerPで特徴量取得に使用するVGG16のレイヤー番号を指定できます(デフォルトは20でconv4-2のReLUです)。上の層ほど画風を表現し、下の層ほどコンテンツを表現するといわれています。
|
| 434 |
+
|
| 435 |
+

|
| 436 |
+
|
| 437 |
+
# その他のオプション
|
| 438 |
+
|
| 439 |
+
- `--no_preview` : 対話モードでプレビュー画像を表示しません。OpenCVがインストールされていない場合や、出力されたファイルを直接確認する場合に指定してください。
|
| 440 |
+
|
| 441 |
+
- `--n_iter` : 生成を繰り返す回数を指定します。デフォルトは1です。プロンプトをファイルから読み込むとき、複数回の生成を行いたい場合に指定します。
|
| 442 |
+
|
| 443 |
+
- `--tokenizer_cache_dir` : トークナイザーのキャッシュディレクトリを指定します。(作業中)
|
| 444 |
+
|
| 445 |
+
- `--seed` : 乱数seedを指定します。1枚生成時はその画像のseed、複数枚生成時は各画像のseedを生成するための乱数のseedになります(`--from_file`で複数画像生成するとき、`--seed`オプションを指定すると複数回実行したときに各画像が同じseedになります)。
|
| 446 |
+
|
| 447 |
+
- `--iter_same_seed` : プロンプトに乱数seedの指定がないとき、`--n_iter`の繰り返し内ではすべて同じseedを使います。`--from_file`で指定した複数のプロンプト間でseedを統一して比較するときに使います。
|
| 448 |
+
|
| 449 |
+
- `--diffusers_xformers` : Diffuserのxformersを使用します。
|
| 450 |
+
|
| 451 |
+
- `--opt_channels_last` : 推論時にテンソルのチャンネルを最後に配置します。場合によっては高速化されることがあります。
|
| 452 |
+
|
| 453 |
+
- `--network_show_meta` : 追加ネットワークのメタデータを表示します。
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
---
|
| 457 |
+
|
| 458 |
+
# About Gradual Latent
|
| 459 |
+
|
| 460 |
+
Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
|
| 461 |
+
|
| 462 |
+
- `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
|
| 463 |
+
- `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
|
| 464 |
+
- `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
|
| 465 |
+
- `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
|
| 466 |
+
|
| 467 |
+
Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
|
| 468 |
+
|
| 469 |
+
__Please specify `euler_a` for the sampler.__ Because the source code of the sampler is modified. It will not work with other samplers.
|
| 470 |
+
|
| 471 |
+
It is more effective with SD 1.5. It is quite subtle with SDXL.
|
| 472 |
+
|
| 473 |
+
# Gradual Latent について
|
| 474 |
+
|
| 475 |
+
latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、``sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。
|
| 476 |
+
|
| 477 |
+
- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
|
| 478 |
+
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
|
| 479 |
+
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
|
| 480 |
+
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定しま��。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
|
| 481 |
+
|
| 482 |
+
それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。
|
| 483 |
+
|
| 484 |
+
サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
|
| 485 |
+
|
| 486 |
+
SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。
|
| 487 |
+
|
gradscaler.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import torch
|
| 3 |
+
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
| 4 |
+
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
|
| 5 |
+
|
| 6 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
| 7 |
+
|
| 8 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
| 9 |
+
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
| 10 |
+
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
| 11 |
+
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
| 12 |
+
|
| 13 |
+
def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
|
| 14 |
+
per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
|
| 15 |
+
per_device_found_inf = _MultiDeviceReplicator(found_inf)
|
| 16 |
+
|
| 17 |
+
# To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
|
| 18 |
+
# There could be hundreds of grads, so we'd like to iterate through them just once.
|
| 19 |
+
# However, we don't know their devices or dtypes in advance.
|
| 20 |
+
|
| 21 |
+
# https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
|
| 22 |
+
# Google says mypy struggles with defaultdicts type annotations.
|
| 23 |
+
per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
|
| 24 |
+
# sync grad to master weight
|
| 25 |
+
if hasattr(optimizer, "sync_grad"):
|
| 26 |
+
optimizer.sync_grad()
|
| 27 |
+
with torch.no_grad():
|
| 28 |
+
for group in optimizer.param_groups:
|
| 29 |
+
for param in group["params"]:
|
| 30 |
+
if param.grad is None:
|
| 31 |
+
continue
|
| 32 |
+
if (not allow_fp16) and param.grad.dtype == torch.float16:
|
| 33 |
+
raise ValueError("Attempting to unscale FP16 gradients.")
|
| 34 |
+
if param.grad.is_sparse:
|
| 35 |
+
# is_coalesced() == False means the sparse grad has values with duplicate indices.
|
| 36 |
+
# coalesce() deduplicates indices and adds all values that have the same index.
|
| 37 |
+
# For scaled fp16 values, there's a good chance coalescing will cause overflow,
|
| 38 |
+
# so we should check the coalesced _values().
|
| 39 |
+
if param.grad.dtype is torch.float16:
|
| 40 |
+
param.grad = param.grad.coalesce()
|
| 41 |
+
to_unscale = param.grad._values()
|
| 42 |
+
else:
|
| 43 |
+
to_unscale = param.grad
|
| 44 |
+
|
| 45 |
+
# -: is there a way to split by device and dtype without appending in the inner loop?
|
| 46 |
+
to_unscale = to_unscale.to("cpu")
|
| 47 |
+
per_device_and_dtype_grads[to_unscale.device][
|
| 48 |
+
to_unscale.dtype
|
| 49 |
+
].append(to_unscale)
|
| 50 |
+
|
| 51 |
+
for _, per_dtype_grads in per_device_and_dtype_grads.items():
|
| 52 |
+
for grads in per_dtype_grads.values():
|
| 53 |
+
core._amp_foreach_non_finite_check_and_unscale_(
|
| 54 |
+
grads,
|
| 55 |
+
per_device_found_inf.get("cpu"),
|
| 56 |
+
per_device_inv_scale.get("cpu"),
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return per_device_found_inf._per_device_tensors
|
| 60 |
+
|
| 61 |
+
def unscale_(self, optimizer):
|
| 62 |
+
"""
|
| 63 |
+
Divides ("unscales") the optimizer's gradient tensors by the scale factor.
|
| 64 |
+
:meth:`unscale_` is optional, serving cases where you need to
|
| 65 |
+
:ref:`modify or inspect gradients<working-with-unscaled-gradients>`
|
| 66 |
+
between the backward pass(es) and :meth:`step`.
|
| 67 |
+
If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
|
| 68 |
+
Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
|
| 69 |
+
...
|
| 70 |
+
scaler.scale(loss).backward()
|
| 71 |
+
scaler.unscale_(optimizer)
|
| 72 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
|
| 73 |
+
scaler.step(optimizer)
|
| 74 |
+
scaler.update()
|
| 75 |
+
Args:
|
| 76 |
+
optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
|
| 77 |
+
.. warning::
|
| 78 |
+
:meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
|
| 79 |
+
and only after all gradients for that optimizer's assigned parameters have been accumulated.
|
| 80 |
+
Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
|
| 81 |
+
.. warning::
|
| 82 |
+
:meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
|
| 83 |
+
"""
|
| 84 |
+
if not self._enabled:
|
| 85 |
+
return
|
| 86 |
+
|
| 87 |
+
self._check_scale_growth_tracker("unscale_")
|
| 88 |
+
|
| 89 |
+
optimizer_state = self._per_optimizer_states[id(optimizer)]
|
| 90 |
+
|
| 91 |
+
if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
|
| 92 |
+
raise RuntimeError(
|
| 93 |
+
"unscale_() has already been called on this optimizer since the last update()."
|
| 94 |
+
)
|
| 95 |
+
elif optimizer_state["stage"] is OptState.STEPPED:
|
| 96 |
+
raise RuntimeError("unscale_() is being called after step().")
|
| 97 |
+
|
| 98 |
+
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
| 99 |
+
assert self._scale is not None
|
| 100 |
+
if device_supports_fp64:
|
| 101 |
+
inv_scale = self._scale.double().reciprocal().float()
|
| 102 |
+
else:
|
| 103 |
+
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
| 104 |
+
found_inf = torch.full(
|
| 105 |
+
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
optimizer_state["found_inf_per_device"] = self._unscale_grads_(
|
| 109 |
+
optimizer, inv_scale, found_inf, False
|
| 110 |
+
)
|
| 111 |
+
optimizer_state["stage"] = OptState.UNSCALED
|
| 112 |
+
|
| 113 |
+
def update(self, new_scale=None):
|
| 114 |
+
"""
|
| 115 |
+
Updates the scale factor.
|
| 116 |
+
If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
|
| 117 |
+
to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
|
| 118 |
+
the scale is multiplied by ``growth_factor`` to increase it.
|
| 119 |
+
Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
|
| 120 |
+
used directly, it's used to fill GradScaler's internal scale tensor. So if
|
| 121 |
+
``new_scale`` was a tensor, later in-place changes to that tensor will not further
|
| 122 |
+
affect the scale GradScaler uses internally.)
|
| 123 |
+
Args:
|
| 124 |
+
new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
|
| 125 |
+
.. warning::
|
| 126 |
+
:meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
|
| 127 |
+
been invoked for all optimizers used this iteration.
|
| 128 |
+
"""
|
| 129 |
+
if not self._enabled:
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
_scale, _growth_tracker = self._check_scale_growth_tracker("update")
|
| 133 |
+
|
| 134 |
+
if new_scale is not None:
|
| 135 |
+
# Accept a new user-defined scale.
|
| 136 |
+
if isinstance(new_scale, float):
|
| 137 |
+
self._scale.fill_(new_scale) # type: ignore[union-attr]
|
| 138 |
+
else:
|
| 139 |
+
reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
|
| 140 |
+
assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
|
| 141 |
+
assert new_scale.numel() == 1, reason
|
| 142 |
+
assert new_scale.requires_grad is False, reason
|
| 143 |
+
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
| 144 |
+
else:
|
| 145 |
+
# Consume shared inf/nan data collected from optimizers to update the scale.
|
| 146 |
+
# If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
|
| 147 |
+
found_infs = [
|
| 148 |
+
found_inf.to(device="cpu", non_blocking=True)
|
| 149 |
+
for state in self._per_optimizer_states.values()
|
| 150 |
+
for found_inf in state["found_inf_per_device"].values()
|
| 151 |
+
]
|
| 152 |
+
|
| 153 |
+
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
| 154 |
+
|
| 155 |
+
found_inf_combined = found_infs[0]
|
| 156 |
+
if len(found_infs) > 1:
|
| 157 |
+
for i in range(1, len(found_infs)):
|
| 158 |
+
found_inf_combined += found_infs[i]
|
| 159 |
+
|
| 160 |
+
to_device = _scale.device
|
| 161 |
+
_scale = _scale.to("cpu")
|
| 162 |
+
_growth_tracker = _growth_tracker.to("cpu")
|
| 163 |
+
|
| 164 |
+
core._amp_update_scale_(
|
| 165 |
+
_scale,
|
| 166 |
+
_growth_tracker,
|
| 167 |
+
found_inf_combined,
|
| 168 |
+
self._growth_factor,
|
| 169 |
+
self._backoff_factor,
|
| 170 |
+
self._growth_interval,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
_scale = _scale.to(to_device)
|
| 174 |
+
_growth_tracker = _growth_tracker.to(to_device)
|
| 175 |
+
# To prepare for next iteration, clear the data collected from optimizers this iteration.
|
| 176 |
+
self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
|
| 177 |
+
|
| 178 |
+
def gradscaler_init():
|
| 179 |
+
torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
| 180 |
+
torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
|
| 181 |
+
torch.xpu.amp.GradScaler.unscale_ = unscale_
|
| 182 |
+
torch.xpu.amp.GradScaler.update = update
|
| 183 |
+
return torch.xpu.amp.GradScaler
|
hijacks.py
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from functools import wraps
|
| 3 |
+
from contextlib import nullcontext
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
| 8 |
+
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
|
| 9 |
+
try:
|
| 10 |
+
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
|
| 11 |
+
del x
|
| 12 |
+
torch.xpu.empty_cache()
|
| 13 |
+
can_allocate_plus_4gb = True
|
| 14 |
+
except Exception:
|
| 15 |
+
can_allocate_plus_4gb = False
|
| 16 |
+
else:
|
| 17 |
+
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
|
| 18 |
+
|
| 19 |
+
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
| 20 |
+
|
| 21 |
+
class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
|
| 22 |
+
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
|
| 23 |
+
if isinstance(device_ids, list) and len(device_ids) > 1:
|
| 24 |
+
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
|
| 25 |
+
return module.to("xpu")
|
| 26 |
+
|
| 27 |
+
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
|
| 28 |
+
return nullcontext()
|
| 29 |
+
|
| 30 |
+
@property
|
| 31 |
+
def is_cuda(self):
|
| 32 |
+
return self.device.type == 'xpu' or self.device.type == 'cuda'
|
| 33 |
+
|
| 34 |
+
def check_device(device):
|
| 35 |
+
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
| 36 |
+
|
| 37 |
+
def return_xpu(device):
|
| 38 |
+
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# Autocast
|
| 42 |
+
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
|
| 43 |
+
@wraps(torch.amp.autocast_mode.autocast.__init__)
|
| 44 |
+
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
|
| 45 |
+
if device_type == "cuda":
|
| 46 |
+
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
| 47 |
+
else:
|
| 48 |
+
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
|
| 49 |
+
|
| 50 |
+
# Latent Antialias CPU Offload:
|
| 51 |
+
original_interpolate = torch.nn.functional.interpolate
|
| 52 |
+
@wraps(torch.nn.functional.interpolate)
|
| 53 |
+
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
| 54 |
+
if mode in {'bicubic', 'bilinear'}:
|
| 55 |
+
return_device = tensor.device
|
| 56 |
+
return_dtype = tensor.dtype
|
| 57 |
+
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
| 58 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
|
| 59 |
+
else:
|
| 60 |
+
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
|
| 61 |
+
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
| 65 |
+
original_from_numpy = torch.from_numpy
|
| 66 |
+
@wraps(torch.from_numpy)
|
| 67 |
+
def from_numpy(ndarray):
|
| 68 |
+
if ndarray.dtype == float:
|
| 69 |
+
return original_from_numpy(ndarray.astype('float32'))
|
| 70 |
+
else:
|
| 71 |
+
return original_from_numpy(ndarray)
|
| 72 |
+
|
| 73 |
+
original_as_tensor = torch.as_tensor
|
| 74 |
+
@wraps(torch.as_tensor)
|
| 75 |
+
def as_tensor(data, dtype=None, device=None):
|
| 76 |
+
if check_device(device):
|
| 77 |
+
device = return_xpu(device)
|
| 78 |
+
if isinstance(data, np.ndarray) and data.dtype == float and not (
|
| 79 |
+
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
|
| 80 |
+
return original_as_tensor(data, dtype=torch.float32, device=device)
|
| 81 |
+
else:
|
| 82 |
+
return original_as_tensor(data, dtype=dtype, device=device)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if can_allocate_plus_4gb:
|
| 86 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
| 87 |
+
else:
|
| 88 |
+
# 32 bit attention workarounds for Alchemist:
|
| 89 |
+
try:
|
| 90 |
+
from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
|
| 91 |
+
except Exception: # pylint: disable=broad-exception-caught
|
| 92 |
+
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
| 93 |
+
|
| 94 |
+
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
| 95 |
+
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
| 96 |
+
if query.dtype != key.dtype:
|
| 97 |
+
key = key.to(dtype=query.dtype)
|
| 98 |
+
if query.dtype != value.dtype:
|
| 99 |
+
value = value.to(dtype=query.dtype)
|
| 100 |
+
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
| 101 |
+
attn_mask = attn_mask.to(dtype=query.dtype)
|
| 102 |
+
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
| 103 |
+
|
| 104 |
+
# Data Type Errors:
|
| 105 |
+
original_torch_bmm = torch.bmm
|
| 106 |
+
@wraps(torch.bmm)
|
| 107 |
+
def torch_bmm(input, mat2, *, out=None):
|
| 108 |
+
if input.dtype != mat2.dtype:
|
| 109 |
+
mat2 = mat2.to(input.dtype)
|
| 110 |
+
return original_torch_bmm(input, mat2, out=out)
|
| 111 |
+
|
| 112 |
+
# Diffusers FreeU
|
| 113 |
+
original_fft_fftn = torch.fft.fftn
|
| 114 |
+
@wraps(torch.fft.fftn)
|
| 115 |
+
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
|
| 116 |
+
return_dtype = input.dtype
|
| 117 |
+
return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
| 118 |
+
|
| 119 |
+
# Diffusers FreeU
|
| 120 |
+
original_fft_ifftn = torch.fft.ifftn
|
| 121 |
+
@wraps(torch.fft.ifftn)
|
| 122 |
+
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
|
| 123 |
+
return_dtype = input.dtype
|
| 124 |
+
return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
| 125 |
+
|
| 126 |
+
# A1111 FP16
|
| 127 |
+
original_functional_group_norm = torch.nn.functional.group_norm
|
| 128 |
+
@wraps(torch.nn.functional.group_norm)
|
| 129 |
+
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
|
| 130 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
| 131 |
+
input = input.to(dtype=weight.data.dtype)
|
| 132 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
| 133 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 134 |
+
return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
|
| 135 |
+
|
| 136 |
+
# A1111 BF16
|
| 137 |
+
original_functional_layer_norm = torch.nn.functional.layer_norm
|
| 138 |
+
@wraps(torch.nn.functional.layer_norm)
|
| 139 |
+
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
|
| 140 |
+
if weight is not None and input.dtype != weight.data.dtype:
|
| 141 |
+
input = input.to(dtype=weight.data.dtype)
|
| 142 |
+
if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
|
| 143 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 144 |
+
return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
|
| 145 |
+
|
| 146 |
+
# Training
|
| 147 |
+
original_functional_linear = torch.nn.functional.linear
|
| 148 |
+
@wraps(torch.nn.functional.linear)
|
| 149 |
+
def functional_linear(input, weight, bias=None):
|
| 150 |
+
if input.dtype != weight.data.dtype:
|
| 151 |
+
input = input.to(dtype=weight.data.dtype)
|
| 152 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
| 153 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 154 |
+
return original_functional_linear(input, weight, bias=bias)
|
| 155 |
+
|
| 156 |
+
original_functional_conv1d = torch.nn.functional.conv1d
|
| 157 |
+
@wraps(torch.nn.functional.conv1d)
|
| 158 |
+
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| 159 |
+
if input.dtype != weight.data.dtype:
|
| 160 |
+
input = input.to(dtype=weight.data.dtype)
|
| 161 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
| 162 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 163 |
+
return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
| 164 |
+
|
| 165 |
+
original_functional_conv2d = torch.nn.functional.conv2d
|
| 166 |
+
@wraps(torch.nn.functional.conv2d)
|
| 167 |
+
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| 168 |
+
if input.dtype != weight.data.dtype:
|
| 169 |
+
input = input.to(dtype=weight.data.dtype)
|
| 170 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
| 171 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 172 |
+
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
| 173 |
+
|
| 174 |
+
# LTX Video
|
| 175 |
+
original_functional_conv3d = torch.nn.functional.conv3d
|
| 176 |
+
@wraps(torch.nn.functional.conv3d)
|
| 177 |
+
def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
| 178 |
+
if input.dtype != weight.data.dtype:
|
| 179 |
+
input = input.to(dtype=weight.data.dtype)
|
| 180 |
+
if bias is not None and bias.data.dtype != weight.data.dtype:
|
| 181 |
+
bias.data = bias.data.to(dtype=weight.data.dtype)
|
| 182 |
+
return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
| 183 |
+
|
| 184 |
+
# SwinIR BF16:
|
| 185 |
+
original_functional_pad = torch.nn.functional.pad
|
| 186 |
+
@wraps(torch.nn.functional.pad)
|
| 187 |
+
def functional_pad(input, pad, mode='constant', value=None):
|
| 188 |
+
if mode == 'reflect' and input.dtype == torch.bfloat16:
|
| 189 |
+
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
|
| 190 |
+
else:
|
| 191 |
+
return original_functional_pad(input, pad, mode=mode, value=value)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
original_torch_tensor = torch.tensor
|
| 195 |
+
@wraps(torch.tensor)
|
| 196 |
+
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
| 197 |
+
global device_supports_fp64
|
| 198 |
+
if check_device(device):
|
| 199 |
+
device = return_xpu(device)
|
| 200 |
+
if not device_supports_fp64:
|
| 201 |
+
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
|
| 202 |
+
if dtype == torch.float64:
|
| 203 |
+
dtype = torch.float32
|
| 204 |
+
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
|
| 205 |
+
dtype = torch.float32
|
| 206 |
+
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
|
| 207 |
+
|
| 208 |
+
original_Tensor_to = torch.Tensor.to
|
| 209 |
+
@wraps(torch.Tensor.to)
|
| 210 |
+
def Tensor_to(self, device=None, *args, **kwargs):
|
| 211 |
+
if check_device(device):
|
| 212 |
+
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
|
| 213 |
+
else:
|
| 214 |
+
return original_Tensor_to(self, device, *args, **kwargs)
|
| 215 |
+
|
| 216 |
+
original_Tensor_cuda = torch.Tensor.cuda
|
| 217 |
+
@wraps(torch.Tensor.cuda)
|
| 218 |
+
def Tensor_cuda(self, device=None, *args, **kwargs):
|
| 219 |
+
if check_device(device):
|
| 220 |
+
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
|
| 221 |
+
else:
|
| 222 |
+
return original_Tensor_cuda(self, device, *args, **kwargs)
|
| 223 |
+
|
| 224 |
+
original_Tensor_pin_memory = torch.Tensor.pin_memory
|
| 225 |
+
@wraps(torch.Tensor.pin_memory)
|
| 226 |
+
def Tensor_pin_memory(self, device=None, *args, **kwargs):
|
| 227 |
+
if device is None:
|
| 228 |
+
device = "xpu"
|
| 229 |
+
if check_device(device):
|
| 230 |
+
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
|
| 231 |
+
else:
|
| 232 |
+
return original_Tensor_pin_memory(self, device, *args, **kwargs)
|
| 233 |
+
|
| 234 |
+
original_UntypedStorage_init = torch.UntypedStorage.__init__
|
| 235 |
+
@wraps(torch.UntypedStorage.__init__)
|
| 236 |
+
def UntypedStorage_init(*args, device=None, **kwargs):
|
| 237 |
+
if check_device(device):
|
| 238 |
+
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
|
| 239 |
+
else:
|
| 240 |
+
return original_UntypedStorage_init(*args, device=device, **kwargs)
|
| 241 |
+
|
| 242 |
+
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
|
| 243 |
+
@wraps(torch.UntypedStorage.cuda)
|
| 244 |
+
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
|
| 245 |
+
if check_device(device):
|
| 246 |
+
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
|
| 247 |
+
else:
|
| 248 |
+
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
|
| 249 |
+
|
| 250 |
+
original_torch_empty = torch.empty
|
| 251 |
+
@wraps(torch.empty)
|
| 252 |
+
def torch_empty(*args, device=None, **kwargs):
|
| 253 |
+
if check_device(device):
|
| 254 |
+
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
|
| 255 |
+
else:
|
| 256 |
+
return original_torch_empty(*args, device=device, **kwargs)
|
| 257 |
+
|
| 258 |
+
original_torch_randn = torch.randn
|
| 259 |
+
@wraps(torch.randn)
|
| 260 |
+
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
| 261 |
+
if dtype is bytes:
|
| 262 |
+
dtype = None
|
| 263 |
+
if check_device(device):
|
| 264 |
+
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
| 265 |
+
else:
|
| 266 |
+
return original_torch_randn(*args, device=device, **kwargs)
|
| 267 |
+
|
| 268 |
+
original_torch_ones = torch.ones
|
| 269 |
+
@wraps(torch.ones)
|
| 270 |
+
def torch_ones(*args, device=None, **kwargs):
|
| 271 |
+
if check_device(device):
|
| 272 |
+
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
|
| 273 |
+
else:
|
| 274 |
+
return original_torch_ones(*args, device=device, **kwargs)
|
| 275 |
+
|
| 276 |
+
original_torch_zeros = torch.zeros
|
| 277 |
+
@wraps(torch.zeros)
|
| 278 |
+
def torch_zeros(*args, device=None, **kwargs):
|
| 279 |
+
if check_device(device):
|
| 280 |
+
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
|
| 281 |
+
else:
|
| 282 |
+
return original_torch_zeros(*args, device=device, **kwargs)
|
| 283 |
+
|
| 284 |
+
original_torch_full = torch.full
|
| 285 |
+
@wraps(torch.full)
|
| 286 |
+
def torch_full(*args, device=None, **kwargs):
|
| 287 |
+
if check_device(device):
|
| 288 |
+
return original_torch_full(*args, device=return_xpu(device), **kwargs)
|
| 289 |
+
else:
|
| 290 |
+
return original_torch_full(*args, device=device, **kwargs)
|
| 291 |
+
|
| 292 |
+
original_torch_linspace = torch.linspace
|
| 293 |
+
@wraps(torch.linspace)
|
| 294 |
+
def torch_linspace(*args, device=None, **kwargs):
|
| 295 |
+
if check_device(device):
|
| 296 |
+
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
|
| 297 |
+
else:
|
| 298 |
+
return original_torch_linspace(*args, device=device, **kwargs)
|
| 299 |
+
|
| 300 |
+
original_torch_load = torch.load
|
| 301 |
+
@wraps(torch.load)
|
| 302 |
+
def torch_load(f, map_location=None, *args, **kwargs):
|
| 303 |
+
if map_location is None:
|
| 304 |
+
map_location = "xpu"
|
| 305 |
+
if check_device(map_location):
|
| 306 |
+
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
|
| 307 |
+
else:
|
| 308 |
+
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
| 309 |
+
|
| 310 |
+
original_torch_Generator = torch.Generator
|
| 311 |
+
@wraps(torch.Generator)
|
| 312 |
+
def torch_Generator(device=None):
|
| 313 |
+
if check_device(device):
|
| 314 |
+
return original_torch_Generator(return_xpu(device))
|
| 315 |
+
else:
|
| 316 |
+
return original_torch_Generator(device)
|
| 317 |
+
|
| 318 |
+
@wraps(torch.cuda.synchronize)
|
| 319 |
+
def torch_cuda_synchronize(device=None):
|
| 320 |
+
if check_device(device):
|
| 321 |
+
return torch.xpu.synchronize(return_xpu(device))
|
| 322 |
+
else:
|
| 323 |
+
return torch.xpu.synchronize(device)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# Hijack Functions:
|
| 327 |
+
def ipex_hijacks(legacy=True):
|
| 328 |
+
global device_supports_fp64, can_allocate_plus_4gb
|
| 329 |
+
if legacy and float(torch.__version__[:3]) < 2.5:
|
| 330 |
+
torch.nn.functional.interpolate = interpolate
|
| 331 |
+
torch.tensor = torch_tensor
|
| 332 |
+
torch.Tensor.to = Tensor_to
|
| 333 |
+
torch.Tensor.cuda = Tensor_cuda
|
| 334 |
+
torch.Tensor.pin_memory = Tensor_pin_memory
|
| 335 |
+
torch.UntypedStorage.__init__ = UntypedStorage_init
|
| 336 |
+
torch.UntypedStorage.cuda = UntypedStorage_cuda
|
| 337 |
+
torch.empty = torch_empty
|
| 338 |
+
torch.randn = torch_randn
|
| 339 |
+
torch.ones = torch_ones
|
| 340 |
+
torch.zeros = torch_zeros
|
| 341 |
+
torch.full = torch_full
|
| 342 |
+
torch.linspace = torch_linspace
|
| 343 |
+
torch.load = torch_load
|
| 344 |
+
torch.Generator = torch_Generator
|
| 345 |
+
torch.cuda.synchronize = torch_cuda_synchronize
|
| 346 |
+
|
| 347 |
+
torch.backends.cuda.sdp_kernel = return_null_context
|
| 348 |
+
torch.nn.DataParallel = DummyDataParallel
|
| 349 |
+
torch.UntypedStorage.is_cuda = is_cuda
|
| 350 |
+
torch.amp.autocast_mode.autocast.__init__ = autocast_init
|
| 351 |
+
|
| 352 |
+
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
|
| 353 |
+
torch.nn.functional.group_norm = functional_group_norm
|
| 354 |
+
torch.nn.functional.layer_norm = functional_layer_norm
|
| 355 |
+
torch.nn.functional.linear = functional_linear
|
| 356 |
+
torch.nn.functional.conv1d = functional_conv1d
|
| 357 |
+
torch.nn.functional.conv2d = functional_conv2d
|
| 358 |
+
torch.nn.functional.conv3d = functional_conv3d
|
| 359 |
+
torch.nn.functional.pad = functional_pad
|
| 360 |
+
|
| 361 |
+
torch.bmm = torch_bmm
|
| 362 |
+
torch.fft.fftn = fft_fftn
|
| 363 |
+
torch.fft.ifftn = fft_ifftn
|
| 364 |
+
if not device_supports_fp64:
|
| 365 |
+
torch.from_numpy = from_numpy
|
| 366 |
+
torch.as_tensor = as_tensor
|
| 367 |
+
return device_supports_fp64, can_allocate_plus_4gb
|
huggingface_util.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, BinaryIO
|
| 2 |
+
from huggingface_hub import HfApi
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import argparse
|
| 5 |
+
import os
|
| 6 |
+
from library.utils import fire_in_thread
|
| 7 |
+
from library.utils import setup_logging
|
| 8 |
+
setup_logging()
|
| 9 |
+
import logging
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
| 13 |
+
api = HfApi(
|
| 14 |
+
token=token,
|
| 15 |
+
)
|
| 16 |
+
try:
|
| 17 |
+
api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
| 18 |
+
return True
|
| 19 |
+
except:
|
| 20 |
+
return False
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def upload(
|
| 24 |
+
args: argparse.Namespace,
|
| 25 |
+
src: Union[str, Path, bytes, BinaryIO],
|
| 26 |
+
dest_suffix: str = "",
|
| 27 |
+
force_sync_upload: bool = False,
|
| 28 |
+
):
|
| 29 |
+
repo_id = args.huggingface_repo_id
|
| 30 |
+
repo_type = args.huggingface_repo_type
|
| 31 |
+
token = args.huggingface_token
|
| 32 |
+
path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
|
| 33 |
+
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
| 34 |
+
api = HfApi(token=token)
|
| 35 |
+
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
| 36 |
+
try:
|
| 37 |
+
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
| 38 |
+
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
| 39 |
+
logger.error("===========================================")
|
| 40 |
+
logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
| 41 |
+
logger.error("===========================================")
|
| 42 |
+
|
| 43 |
+
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
| 44 |
+
|
| 45 |
+
def uploader():
|
| 46 |
+
try:
|
| 47 |
+
if is_folder:
|
| 48 |
+
api.upload_folder(
|
| 49 |
+
repo_id=repo_id,
|
| 50 |
+
repo_type=repo_type,
|
| 51 |
+
folder_path=src,
|
| 52 |
+
path_in_repo=path_in_repo,
|
| 53 |
+
)
|
| 54 |
+
else:
|
| 55 |
+
api.upload_file(
|
| 56 |
+
repo_id=repo_id,
|
| 57 |
+
repo_type=repo_type,
|
| 58 |
+
path_or_fileobj=src,
|
| 59 |
+
path_in_repo=path_in_repo,
|
| 60 |
+
)
|
| 61 |
+
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
| 62 |
+
logger.error("===========================================")
|
| 63 |
+
logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
| 64 |
+
logger.error("===========================================")
|
| 65 |
+
|
| 66 |
+
if args.async_upload and not force_sync_upload:
|
| 67 |
+
fire_in_thread(uploader)
|
| 68 |
+
else:
|
| 69 |
+
uploader()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def list_dir(
|
| 73 |
+
repo_id: str,
|
| 74 |
+
subfolder: str,
|
| 75 |
+
repo_type: str,
|
| 76 |
+
revision: str = "main",
|
| 77 |
+
token: str = None,
|
| 78 |
+
):
|
| 79 |
+
api = HfApi(
|
| 80 |
+
token=token,
|
| 81 |
+
)
|
| 82 |
+
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
| 83 |
+
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
|
| 84 |
+
return file_list
|
hypernetwork.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from diffusers.models.attention_processor import (
|
| 4 |
+
Attention,
|
| 5 |
+
AttnProcessor2_0,
|
| 6 |
+
SlicedAttnProcessor,
|
| 7 |
+
XFormersAttnProcessor
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import xformers.ops
|
| 12 |
+
except:
|
| 13 |
+
xformers = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
loaded_networks = []
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def apply_single_hypernetwork(
|
| 20 |
+
hypernetwork, hidden_states, encoder_hidden_states
|
| 21 |
+
):
|
| 22 |
+
context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
|
| 23 |
+
return context_k, context_v
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def apply_hypernetworks(context_k, context_v, layer=None):
|
| 27 |
+
if len(loaded_networks) == 0:
|
| 28 |
+
return context_v, context_v
|
| 29 |
+
for hypernetwork in loaded_networks:
|
| 30 |
+
context_k, context_v = hypernetwork.forward(context_k, context_v)
|
| 31 |
+
|
| 32 |
+
context_k = context_k.to(dtype=context_k.dtype)
|
| 33 |
+
context_v = context_v.to(dtype=context_k.dtype)
|
| 34 |
+
|
| 35 |
+
return context_k, context_v
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def xformers_forward(
|
| 40 |
+
self: XFormersAttnProcessor,
|
| 41 |
+
attn: Attention,
|
| 42 |
+
hidden_states: torch.Tensor,
|
| 43 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 44 |
+
attention_mask: torch.Tensor = None,
|
| 45 |
+
):
|
| 46 |
+
batch_size, sequence_length, _ = (
|
| 47 |
+
hidden_states.shape
|
| 48 |
+
if encoder_hidden_states is None
|
| 49 |
+
else encoder_hidden_states.shape
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
attention_mask = attn.prepare_attention_mask(
|
| 53 |
+
attention_mask, sequence_length, batch_size
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
query = attn.to_q(hidden_states)
|
| 57 |
+
|
| 58 |
+
if encoder_hidden_states is None:
|
| 59 |
+
encoder_hidden_states = hidden_states
|
| 60 |
+
elif attn.norm_cross:
|
| 61 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 62 |
+
|
| 63 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
| 64 |
+
|
| 65 |
+
key = attn.to_k(context_k)
|
| 66 |
+
value = attn.to_v(context_v)
|
| 67 |
+
|
| 68 |
+
query = attn.head_to_batch_dim(query).contiguous()
|
| 69 |
+
key = attn.head_to_batch_dim(key).contiguous()
|
| 70 |
+
value = attn.head_to_batch_dim(value).contiguous()
|
| 71 |
+
|
| 72 |
+
hidden_states = xformers.ops.memory_efficient_attention(
|
| 73 |
+
query,
|
| 74 |
+
key,
|
| 75 |
+
value,
|
| 76 |
+
attn_bias=attention_mask,
|
| 77 |
+
op=self.attention_op,
|
| 78 |
+
scale=attn.scale,
|
| 79 |
+
)
|
| 80 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 81 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 82 |
+
|
| 83 |
+
# linear proj
|
| 84 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 85 |
+
# dropout
|
| 86 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 87 |
+
return hidden_states
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def sliced_attn_forward(
|
| 91 |
+
self: SlicedAttnProcessor,
|
| 92 |
+
attn: Attention,
|
| 93 |
+
hidden_states: torch.Tensor,
|
| 94 |
+
encoder_hidden_states: torch.Tensor = None,
|
| 95 |
+
attention_mask: torch.Tensor = None,
|
| 96 |
+
):
|
| 97 |
+
batch_size, sequence_length, _ = (
|
| 98 |
+
hidden_states.shape
|
| 99 |
+
if encoder_hidden_states is None
|
| 100 |
+
else encoder_hidden_states.shape
|
| 101 |
+
)
|
| 102 |
+
attention_mask = attn.prepare_attention_mask(
|
| 103 |
+
attention_mask, sequence_length, batch_size
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
query = attn.to_q(hidden_states)
|
| 107 |
+
dim = query.shape[-1]
|
| 108 |
+
query = attn.head_to_batch_dim(query)
|
| 109 |
+
|
| 110 |
+
if encoder_hidden_states is None:
|
| 111 |
+
encoder_hidden_states = hidden_states
|
| 112 |
+
elif attn.norm_cross:
|
| 113 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 114 |
+
|
| 115 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
| 116 |
+
|
| 117 |
+
key = attn.to_k(context_k)
|
| 118 |
+
value = attn.to_v(context_v)
|
| 119 |
+
key = attn.head_to_batch_dim(key)
|
| 120 |
+
value = attn.head_to_batch_dim(value)
|
| 121 |
+
|
| 122 |
+
batch_size_attention, query_tokens, _ = query.shape
|
| 123 |
+
hidden_states = torch.zeros(
|
| 124 |
+
(batch_size_attention, query_tokens, dim // attn.heads),
|
| 125 |
+
device=query.device,
|
| 126 |
+
dtype=query.dtype,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
for i in range(batch_size_attention // self.slice_size):
|
| 130 |
+
start_idx = i * self.slice_size
|
| 131 |
+
end_idx = (i + 1) * self.slice_size
|
| 132 |
+
|
| 133 |
+
query_slice = query[start_idx:end_idx]
|
| 134 |
+
key_slice = key[start_idx:end_idx]
|
| 135 |
+
attn_mask_slice = (
|
| 136 |
+
attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
| 140 |
+
|
| 141 |
+
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
| 142 |
+
|
| 143 |
+
hidden_states[start_idx:end_idx] = attn_slice
|
| 144 |
+
|
| 145 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 146 |
+
|
| 147 |
+
# linear proj
|
| 148 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 149 |
+
# dropout
|
| 150 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 151 |
+
|
| 152 |
+
return hidden_states
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def v2_0_forward(
|
| 156 |
+
self: AttnProcessor2_0,
|
| 157 |
+
attn: Attention,
|
| 158 |
+
hidden_states,
|
| 159 |
+
encoder_hidden_states=None,
|
| 160 |
+
attention_mask=None,
|
| 161 |
+
):
|
| 162 |
+
batch_size, sequence_length, _ = (
|
| 163 |
+
hidden_states.shape
|
| 164 |
+
if encoder_hidden_states is None
|
| 165 |
+
else encoder_hidden_states.shape
|
| 166 |
+
)
|
| 167 |
+
inner_dim = hidden_states.shape[-1]
|
| 168 |
+
|
| 169 |
+
if attention_mask is not None:
|
| 170 |
+
attention_mask = attn.prepare_attention_mask(
|
| 171 |
+
attention_mask, sequence_length, batch_size
|
| 172 |
+
)
|
| 173 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 174 |
+
# (batch, heads, source_length, target_length)
|
| 175 |
+
attention_mask = attention_mask.view(
|
| 176 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
query = attn.to_q(hidden_states)
|
| 180 |
+
|
| 181 |
+
if encoder_hidden_states is None:
|
| 182 |
+
encoder_hidden_states = hidden_states
|
| 183 |
+
elif attn.norm_cross:
|
| 184 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 185 |
+
|
| 186 |
+
context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
|
| 187 |
+
|
| 188 |
+
key = attn.to_k(context_k)
|
| 189 |
+
value = attn.to_v(context_v)
|
| 190 |
+
|
| 191 |
+
head_dim = inner_dim // attn.heads
|
| 192 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 193 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 194 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 195 |
+
|
| 196 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
| 197 |
+
# TODO: add support for attn.scale when we move to Torch 2.1
|
| 198 |
+
hidden_states = F.scaled_dot_product_attention(
|
| 199 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
| 203 |
+
batch_size, -1, attn.heads * head_dim
|
| 204 |
+
)
|
| 205 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 206 |
+
|
| 207 |
+
# linear proj
|
| 208 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 209 |
+
# dropout
|
| 210 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 211 |
+
return hidden_states
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def replace_attentions_for_hypernetwork():
|
| 215 |
+
import diffusers.models.attention_processor
|
| 216 |
+
|
| 217 |
+
diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
|
| 218 |
+
xformers_forward
|
| 219 |
+
)
|
| 220 |
+
diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
|
| 221 |
+
sliced_attn_forward
|
| 222 |
+
)
|
| 223 |
+
diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
|
hypernetwork_nai.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NAI compatible
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class HypernetworkModule(torch.nn.Module):
|
| 7 |
+
def __init__(self, dim, multiplier=1.0):
|
| 8 |
+
super().__init__()
|
| 9 |
+
|
| 10 |
+
linear1 = torch.nn.Linear(dim, dim * 2)
|
| 11 |
+
linear2 = torch.nn.Linear(dim * 2, dim)
|
| 12 |
+
linear1.weight.data.normal_(mean=0.0, std=0.01)
|
| 13 |
+
linear1.bias.data.zero_()
|
| 14 |
+
linear2.weight.data.normal_(mean=0.0, std=0.01)
|
| 15 |
+
linear2.bias.data.zero_()
|
| 16 |
+
linears = [linear1, linear2]
|
| 17 |
+
|
| 18 |
+
self.linear = torch.nn.Sequential(*linears)
|
| 19 |
+
self.multiplier = multiplier
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
return x + self.linear(x) * self.multiplier
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Hypernetwork(torch.nn.Module):
|
| 26 |
+
enable_sizes = [320, 640, 768, 1280]
|
| 27 |
+
# return self.modules[Hypernetwork.enable_sizes.index(size)]
|
| 28 |
+
|
| 29 |
+
def __init__(self, multiplier=1.0) -> None:
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.modules = []
|
| 32 |
+
for size in Hypernetwork.enable_sizes:
|
| 33 |
+
self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
|
| 34 |
+
self.register_module(f"{size}_0", self.modules[-1][0])
|
| 35 |
+
self.register_module(f"{size}_1", self.modules[-1][1])
|
| 36 |
+
|
| 37 |
+
def apply_to_stable_diffusion(self, text_encoder, vae, unet):
|
| 38 |
+
blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
|
| 39 |
+
for block in blocks:
|
| 40 |
+
for subblk in block:
|
| 41 |
+
if 'SpatialTransformer' in str(type(subblk)):
|
| 42 |
+
for tf_block in subblk.transformer_blocks:
|
| 43 |
+
for attn in [tf_block.attn1, tf_block.attn2]:
|
| 44 |
+
size = attn.context_dim
|
| 45 |
+
if size in Hypernetwork.enable_sizes:
|
| 46 |
+
attn.hypernetwork = self
|
| 47 |
+
else:
|
| 48 |
+
attn.hypernetwork = None
|
| 49 |
+
|
| 50 |
+
def apply_to_diffusers(self, text_encoder, vae, unet):
|
| 51 |
+
blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
|
| 52 |
+
for block in blocks:
|
| 53 |
+
if hasattr(block, 'attentions'):
|
| 54 |
+
for subblk in block.attentions:
|
| 55 |
+
if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
|
| 56 |
+
for tf_block in subblk.transformer_blocks:
|
| 57 |
+
for attn in [tf_block.attn1, tf_block.attn2]:
|
| 58 |
+
size = attn.to_k.in_features
|
| 59 |
+
if size in Hypernetwork.enable_sizes:
|
| 60 |
+
attn.hypernetwork = self
|
| 61 |
+
else:
|
| 62 |
+
attn.hypernetwork = None
|
| 63 |
+
return True # TODO error checking
|
| 64 |
+
|
| 65 |
+
def forward(self, x, context):
|
| 66 |
+
size = context.shape[-1]
|
| 67 |
+
assert size in Hypernetwork.enable_sizes
|
| 68 |
+
module = self.modules[Hypernetwork.enable_sizes.index(size)]
|
| 69 |
+
return module[0].forward(context), module[1].forward(context)
|
| 70 |
+
|
| 71 |
+
def load_from_state_dict(self, state_dict):
|
| 72 |
+
# old ver to new ver
|
| 73 |
+
changes = {
|
| 74 |
+
'linear1.bias': 'linear.0.bias',
|
| 75 |
+
'linear1.weight': 'linear.0.weight',
|
| 76 |
+
'linear2.bias': 'linear.1.bias',
|
| 77 |
+
'linear2.weight': 'linear.1.weight',
|
| 78 |
+
}
|
| 79 |
+
for key_from, key_to in changes.items():
|
| 80 |
+
if key_from in state_dict:
|
| 81 |
+
state_dict[key_to] = state_dict[key_from]
|
| 82 |
+
del state_dict[key_from]
|
| 83 |
+
|
| 84 |
+
for size, sd in state_dict.items():
|
| 85 |
+
if type(size) == int:
|
| 86 |
+
self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
|
| 87 |
+
self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
|
| 88 |
+
return True
|
| 89 |
+
|
| 90 |
+
def get_state_dict(self):
|
| 91 |
+
state_dict = {}
|
| 92 |
+
for i, size in enumerate(Hypernetwork.enable_sizes):
|
| 93 |
+
sd0 = self.modules[i][0].state_dict()
|
| 94 |
+
sd1 = self.modules[i][1].state_dict()
|
| 95 |
+
state_dict[size] = [sd0, sd1]
|
| 96 |
+
return state_dict
|
latent_upscaler.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 外部から簡単にupscalerを呼ぶためのスクリプト
|
| 2 |
+
# 単体で動くようにモデル定義も含めている
|
| 3 |
+
|
| 4 |
+
import argparse
|
| 5 |
+
import glob
|
| 6 |
+
import os
|
| 7 |
+
import cv2
|
| 8 |
+
from diffusers import AutoencoderKL
|
| 9 |
+
|
| 10 |
+
from typing import Dict, List
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from library.device_utils import init_ipex, get_preferred_device
|
| 15 |
+
init_ipex()
|
| 16 |
+
|
| 17 |
+
from torch import nn
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from PIL import Image
|
| 20 |
+
from library.utils import setup_logging
|
| 21 |
+
setup_logging()
|
| 22 |
+
import logging
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
class ResidualBlock(nn.Module):
|
| 26 |
+
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
|
| 27 |
+
super(ResidualBlock, self).__init__()
|
| 28 |
+
|
| 29 |
+
if out_channels is None:
|
| 30 |
+
out_channels = in_channels
|
| 31 |
+
|
| 32 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
|
| 33 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
| 34 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 35 |
+
|
| 36 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
|
| 37 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
| 38 |
+
|
| 39 |
+
self.relu2 = nn.ReLU(inplace=True) # このReLUはresidualに足す前にかけるほうがいいかも
|
| 40 |
+
|
| 41 |
+
# initialize weights
|
| 42 |
+
self._initialize_weights()
|
| 43 |
+
|
| 44 |
+
def _initialize_weights(self):
|
| 45 |
+
for m in self.modules():
|
| 46 |
+
if isinstance(m, nn.Conv2d):
|
| 47 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 48 |
+
if m.bias is not None:
|
| 49 |
+
nn.init.constant_(m.bias, 0)
|
| 50 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 51 |
+
nn.init.constant_(m.weight, 1)
|
| 52 |
+
nn.init.constant_(m.bias, 0)
|
| 53 |
+
elif isinstance(m, nn.Linear):
|
| 54 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 55 |
+
nn.init.constant_(m.bias, 0)
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
residual = x
|
| 59 |
+
|
| 60 |
+
out = self.conv1(x)
|
| 61 |
+
out = self.bn1(out)
|
| 62 |
+
out = self.relu1(out)
|
| 63 |
+
|
| 64 |
+
out = self.conv2(out)
|
| 65 |
+
out = self.bn2(out)
|
| 66 |
+
|
| 67 |
+
out += residual
|
| 68 |
+
|
| 69 |
+
out = self.relu2(out)
|
| 70 |
+
|
| 71 |
+
return out
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Upscaler(nn.Module):
|
| 75 |
+
def __init__(self):
|
| 76 |
+
super(Upscaler, self).__init__()
|
| 77 |
+
|
| 78 |
+
# define layers
|
| 79 |
+
# latent has 4 channels
|
| 80 |
+
|
| 81 |
+
self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
| 82 |
+
self.bn1 = nn.BatchNorm2d(128)
|
| 83 |
+
self.relu1 = nn.ReLU(inplace=True)
|
| 84 |
+
|
| 85 |
+
# resblocks
|
| 86 |
+
# 数の暴力で20個:次元数を増やすよりもブロックを増やしたほうがreceptive fieldが広がるはずだぞ
|
| 87 |
+
self.resblock1 = ResidualBlock(128)
|
| 88 |
+
self.resblock2 = ResidualBlock(128)
|
| 89 |
+
self.resblock3 = ResidualBlock(128)
|
| 90 |
+
self.resblock4 = ResidualBlock(128)
|
| 91 |
+
self.resblock5 = ResidualBlock(128)
|
| 92 |
+
self.resblock6 = ResidualBlock(128)
|
| 93 |
+
self.resblock7 = ResidualBlock(128)
|
| 94 |
+
self.resblock8 = ResidualBlock(128)
|
| 95 |
+
self.resblock9 = ResidualBlock(128)
|
| 96 |
+
self.resblock10 = ResidualBlock(128)
|
| 97 |
+
self.resblock11 = ResidualBlock(128)
|
| 98 |
+
self.resblock12 = ResidualBlock(128)
|
| 99 |
+
self.resblock13 = ResidualBlock(128)
|
| 100 |
+
self.resblock14 = ResidualBlock(128)
|
| 101 |
+
self.resblock15 = ResidualBlock(128)
|
| 102 |
+
self.resblock16 = ResidualBlock(128)
|
| 103 |
+
self.resblock17 = ResidualBlock(128)
|
| 104 |
+
self.resblock18 = ResidualBlock(128)
|
| 105 |
+
self.resblock19 = ResidualBlock(128)
|
| 106 |
+
self.resblock20 = ResidualBlock(128)
|
| 107 |
+
|
| 108 |
+
# last convs
|
| 109 |
+
self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
| 110 |
+
self.bn2 = nn.BatchNorm2d(64)
|
| 111 |
+
self.relu2 = nn.ReLU(inplace=True)
|
| 112 |
+
|
| 113 |
+
self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
|
| 114 |
+
self.bn3 = nn.BatchNorm2d(64)
|
| 115 |
+
self.relu3 = nn.ReLU(inplace=True)
|
| 116 |
+
|
| 117 |
+
# final conv: output 4 channels
|
| 118 |
+
self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
| 119 |
+
|
| 120 |
+
# initialize weights
|
| 121 |
+
self._initialize_weights()
|
| 122 |
+
|
| 123 |
+
def _initialize_weights(self):
|
| 124 |
+
for m in self.modules():
|
| 125 |
+
if isinstance(m, nn.Conv2d):
|
| 126 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 127 |
+
if m.bias is not None:
|
| 128 |
+
nn.init.constant_(m.bias, 0)
|
| 129 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 130 |
+
nn.init.constant_(m.weight, 1)
|
| 131 |
+
nn.init.constant_(m.bias, 0)
|
| 132 |
+
elif isinstance(m, nn.Linear):
|
| 133 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
| 134 |
+
nn.init.constant_(m.bias, 0)
|
| 135 |
+
|
| 136 |
+
# initialize final conv weights to 0: 流行りのzero conv
|
| 137 |
+
nn.init.constant_(self.conv_final.weight, 0)
|
| 138 |
+
|
| 139 |
+
def forward(self, x):
|
| 140 |
+
inp = x
|
| 141 |
+
|
| 142 |
+
x = self.conv1(x)
|
| 143 |
+
x = self.bn1(x)
|
| 144 |
+
x = self.relu1(x)
|
| 145 |
+
|
| 146 |
+
# いくつかのresblockを通した��に、residualを足すことで精度向上と学習速度向上が見込めるはず
|
| 147 |
+
residual = x
|
| 148 |
+
x = self.resblock1(x)
|
| 149 |
+
x = self.resblock2(x)
|
| 150 |
+
x = self.resblock3(x)
|
| 151 |
+
x = self.resblock4(x)
|
| 152 |
+
x = x + residual
|
| 153 |
+
residual = x
|
| 154 |
+
x = self.resblock5(x)
|
| 155 |
+
x = self.resblock6(x)
|
| 156 |
+
x = self.resblock7(x)
|
| 157 |
+
x = self.resblock8(x)
|
| 158 |
+
x = x + residual
|
| 159 |
+
residual = x
|
| 160 |
+
x = self.resblock9(x)
|
| 161 |
+
x = self.resblock10(x)
|
| 162 |
+
x = self.resblock11(x)
|
| 163 |
+
x = self.resblock12(x)
|
| 164 |
+
x = x + residual
|
| 165 |
+
residual = x
|
| 166 |
+
x = self.resblock13(x)
|
| 167 |
+
x = self.resblock14(x)
|
| 168 |
+
x = self.resblock15(x)
|
| 169 |
+
x = self.resblock16(x)
|
| 170 |
+
x = x + residual
|
| 171 |
+
residual = x
|
| 172 |
+
x = self.resblock17(x)
|
| 173 |
+
x = self.resblock18(x)
|
| 174 |
+
x = self.resblock19(x)
|
| 175 |
+
x = self.resblock20(x)
|
| 176 |
+
x = x + residual
|
| 177 |
+
|
| 178 |
+
x = self.conv2(x)
|
| 179 |
+
x = self.bn2(x)
|
| 180 |
+
x = self.relu2(x)
|
| 181 |
+
x = self.conv3(x)
|
| 182 |
+
x = self.bn3(x)
|
| 183 |
+
|
| 184 |
+
# ここにreluを入れないほうがいい気がする
|
| 185 |
+
|
| 186 |
+
x = self.conv_final(x)
|
| 187 |
+
|
| 188 |
+
# network estimates the difference between the input and the output
|
| 189 |
+
x = x + inp
|
| 190 |
+
|
| 191 |
+
return x
|
| 192 |
+
|
| 193 |
+
def support_latents(self) -> bool:
|
| 194 |
+
return False
|
| 195 |
+
|
| 196 |
+
def upscale(
|
| 197 |
+
self,
|
| 198 |
+
vae: AutoencoderKL,
|
| 199 |
+
lowreso_images: List[Image.Image],
|
| 200 |
+
lowreso_latents: torch.Tensor,
|
| 201 |
+
dtype: torch.dtype,
|
| 202 |
+
width: int,
|
| 203 |
+
height: int,
|
| 204 |
+
batch_size: int = 1,
|
| 205 |
+
vae_batch_size: int = 1,
|
| 206 |
+
):
|
| 207 |
+
# assertion
|
| 208 |
+
assert lowreso_images is not None, "Upscaler requires lowreso image"
|
| 209 |
+
|
| 210 |
+
# make upsampled image with lanczos4
|
| 211 |
+
upsampled_images = []
|
| 212 |
+
for lowreso_image in lowreso_images:
|
| 213 |
+
upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
|
| 214 |
+
upsampled_images.append(upsampled_image)
|
| 215 |
+
|
| 216 |
+
# convert to tensor: this tensor is too large to be converted to cuda
|
| 217 |
+
upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
|
| 218 |
+
upsampled_images = torch.stack(upsampled_images, dim=0)
|
| 219 |
+
upsampled_images = upsampled_images.to(dtype)
|
| 220 |
+
|
| 221 |
+
# normalize to [-1, 1]
|
| 222 |
+
upsampled_images = upsampled_images / 127.5 - 1.0
|
| 223 |
+
|
| 224 |
+
# convert upsample images to latents with batch size
|
| 225 |
+
# logger.info("Encoding upsampled (LANCZOS4) images...")
|
| 226 |
+
upsampled_latents = []
|
| 227 |
+
for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
|
| 228 |
+
batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
|
| 229 |
+
with torch.no_grad():
|
| 230 |
+
batch = vae.encode(batch).latent_dist.sample()
|
| 231 |
+
upsampled_latents.append(batch)
|
| 232 |
+
|
| 233 |
+
upsampled_latents = torch.cat(upsampled_latents, dim=0)
|
| 234 |
+
|
| 235 |
+
# upscale (refine) latents with this model with batch size
|
| 236 |
+
logger.info("Upscaling latents...")
|
| 237 |
+
upscaled_latents = []
|
| 238 |
+
for i in range(0, upsampled_latents.shape[0], batch_size):
|
| 239 |
+
with torch.no_grad():
|
| 240 |
+
upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
|
| 241 |
+
upscaled_latents = torch.cat(upscaled_latents, dim=0)
|
| 242 |
+
|
| 243 |
+
return upscaled_latents * 0.18215
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
# external interface: returns a model
|
| 247 |
+
def create_upscaler(**kwargs):
|
| 248 |
+
weights = kwargs["weights"]
|
| 249 |
+
model = Upscaler()
|
| 250 |
+
|
| 251 |
+
logger.info(f"Loading weights from {weights}...")
|
| 252 |
+
if os.path.splitext(weights)[1] == ".safetensors":
|
| 253 |
+
from safetensors.torch import load_file
|
| 254 |
+
|
| 255 |
+
sd = load_file(weights)
|
| 256 |
+
else:
|
| 257 |
+
sd = torch.load(weights, map_location=torch.device("cpu"))
|
| 258 |
+
model.load_state_dict(sd)
|
| 259 |
+
return model
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# another interface: upscale images with a model for given images from command line
|
| 263 |
+
def upscale_images(args: argparse.Namespace):
|
| 264 |
+
DEVICE = get_preferred_device()
|
| 265 |
+
us_dtype = torch.float16 # TODO: support fp32/bf16
|
| 266 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 267 |
+
|
| 268 |
+
# load VAE with Diffusers
|
| 269 |
+
assert args.vae_path is not None, "VAE path is required"
|
| 270 |
+
logger.info(f"Loading VAE from {args.vae_path}...")
|
| 271 |
+
vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
|
| 272 |
+
vae.to(DEVICE, dtype=us_dtype)
|
| 273 |
+
|
| 274 |
+
# prepare model
|
| 275 |
+
logger.info("Preparing model...")
|
| 276 |
+
upscaler: Upscaler = create_upscaler(weights=args.weights)
|
| 277 |
+
# logger.info("Loading weights from", args.weights)
|
| 278 |
+
# upscaler.load_state_dict(torch.load(args.weights))
|
| 279 |
+
upscaler.eval()
|
| 280 |
+
upscaler.to(DEVICE, dtype=us_dtype)
|
| 281 |
+
|
| 282 |
+
# load images
|
| 283 |
+
image_paths = glob.glob(args.image_pattern)
|
| 284 |
+
images = []
|
| 285 |
+
for image_path in image_paths:
|
| 286 |
+
image = Image.open(image_path)
|
| 287 |
+
image = image.convert("RGB")
|
| 288 |
+
|
| 289 |
+
# make divisible by 8
|
| 290 |
+
width = image.width
|
| 291 |
+
height = image.height
|
| 292 |
+
if width % 8 != 0:
|
| 293 |
+
width = width - (width % 8)
|
| 294 |
+
if height % 8 != 0:
|
| 295 |
+
height = height - (height % 8)
|
| 296 |
+
if width != image.width or height != image.height:
|
| 297 |
+
image = image.crop((0, 0, width, height))
|
| 298 |
+
|
| 299 |
+
images.append(image)
|
| 300 |
+
|
| 301 |
+
# debug output
|
| 302 |
+
if args.debug:
|
| 303 |
+
for image, image_path in zip(images, image_paths):
|
| 304 |
+
image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
|
| 305 |
+
|
| 306 |
+
basename = os.path.basename(image_path)
|
| 307 |
+
basename_wo_ext, ext = os.path.splitext(basename)
|
| 308 |
+
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
|
| 309 |
+
image_debug.save(dest_file_name)
|
| 310 |
+
|
| 311 |
+
# upscale
|
| 312 |
+
logger.info("Upscaling...")
|
| 313 |
+
upscaled_latents = upscaler.upscale(
|
| 314 |
+
vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
|
| 315 |
+
)
|
| 316 |
+
upscaled_latents /= 0.18215
|
| 317 |
+
|
| 318 |
+
# decode with batch
|
| 319 |
+
logger.info("Decoding...")
|
| 320 |
+
upscaled_images = []
|
| 321 |
+
for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
|
| 322 |
+
with torch.no_grad():
|
| 323 |
+
batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
|
| 324 |
+
batch = batch.to("cpu")
|
| 325 |
+
upscaled_images.append(batch)
|
| 326 |
+
upscaled_images = torch.cat(upscaled_images, dim=0)
|
| 327 |
+
|
| 328 |
+
# tensor to numpy
|
| 329 |
+
upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
|
| 330 |
+
upscaled_images = (upscaled_images + 1.0) * 127.5
|
| 331 |
+
upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
|
| 332 |
+
|
| 333 |
+
upscaled_images = upscaled_images[..., ::-1]
|
| 334 |
+
|
| 335 |
+
# save images
|
| 336 |
+
for i, image in enumerate(upscaled_images):
|
| 337 |
+
basename = os.path.basename(image_paths[i])
|
| 338 |
+
basename_wo_ext, ext = os.path.splitext(basename)
|
| 339 |
+
dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
|
| 340 |
+
cv2.imwrite(dest_file_name, image)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
if __name__ == "__main__":
|
| 344 |
+
parser = argparse.ArgumentParser()
|
| 345 |
+
parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
|
| 346 |
+
parser.add_argument("--weights", type=str, default=None, help="Weights path")
|
| 347 |
+
parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
|
| 348 |
+
parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
|
| 349 |
+
parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
|
| 350 |
+
parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
|
| 351 |
+
parser.add_argument("--debug", action="store_true", help="Debug mode")
|
| 352 |
+
|
| 353 |
+
args = parser.parse_args()
|
| 354 |
+
upscale_images(args)
|
libbitsandbytes_cpu.dll
ADDED
|
Binary file (76.3 kB). View file
|
|
|
libbitsandbytes_cuda116.dll
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:88f7bd2916ca3effc43f88492f1e1b9088d13cb5be3b4a3a4aede6aa3bf8d412
|
| 3 |
+
size 4724224
|
libbitsandbytes_cuda118.dll
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4dc34709b8dcb078cbcdd65e5684f116cb395644d12b9c9fb144af5455bb1c18
|
| 3 |
+
size 14026752
|
logo_aihub.png
ADDED
|
lora.py
ADDED
|
@@ -0,0 +1,1410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LoRA network module
|
| 2 |
+
# reference:
|
| 3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
| 4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
| 9 |
+
from diffusers import AutoencoderKL
|
| 10 |
+
from transformers import CLIPTextModel
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import re
|
| 14 |
+
from library.utils import setup_logging
|
| 15 |
+
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
| 16 |
+
|
| 17 |
+
setup_logging()
|
| 18 |
+
import logging
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LoRAModule(torch.nn.Module):
|
| 26 |
+
"""
|
| 27 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
lora_name,
|
| 33 |
+
org_module: torch.nn.Module,
|
| 34 |
+
multiplier=1.0,
|
| 35 |
+
lora_dim=4,
|
| 36 |
+
alpha=1,
|
| 37 |
+
dropout=None,
|
| 38 |
+
rank_dropout=None,
|
| 39 |
+
module_dropout=None,
|
| 40 |
+
):
|
| 41 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.lora_name = lora_name
|
| 44 |
+
|
| 45 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 46 |
+
in_dim = org_module.in_channels
|
| 47 |
+
out_dim = org_module.out_channels
|
| 48 |
+
else:
|
| 49 |
+
in_dim = org_module.in_features
|
| 50 |
+
out_dim = org_module.out_features
|
| 51 |
+
|
| 52 |
+
# if limit_rank:
|
| 53 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
| 54 |
+
# if self.lora_dim != lora_dim:
|
| 55 |
+
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
| 56 |
+
# else:
|
| 57 |
+
self.lora_dim = lora_dim
|
| 58 |
+
|
| 59 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 60 |
+
kernel_size = org_module.kernel_size
|
| 61 |
+
stride = org_module.stride
|
| 62 |
+
padding = org_module.padding
|
| 63 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
| 64 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
| 65 |
+
else:
|
| 66 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
| 67 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
| 68 |
+
|
| 69 |
+
if type(alpha) == torch.Tensor:
|
| 70 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
| 71 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
| 72 |
+
self.scale = alpha / self.lora_dim
|
| 73 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
| 74 |
+
|
| 75 |
+
# same as microsoft's
|
| 76 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
| 77 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
| 78 |
+
|
| 79 |
+
self.multiplier = multiplier
|
| 80 |
+
self.org_module = org_module # remove in applying
|
| 81 |
+
self.dropout = dropout
|
| 82 |
+
self.rank_dropout = rank_dropout
|
| 83 |
+
self.module_dropout = module_dropout
|
| 84 |
+
|
| 85 |
+
def apply_to(self):
|
| 86 |
+
self.org_forward = self.org_module.forward
|
| 87 |
+
self.org_module.forward = self.forward
|
| 88 |
+
del self.org_module
|
| 89 |
+
|
| 90 |
+
def forward(self, x):
|
| 91 |
+
org_forwarded = self.org_forward(x)
|
| 92 |
+
|
| 93 |
+
# module dropout
|
| 94 |
+
if self.module_dropout is not None and self.training:
|
| 95 |
+
if torch.rand(1) < self.module_dropout:
|
| 96 |
+
return org_forwarded
|
| 97 |
+
|
| 98 |
+
lx = self.lora_down(x)
|
| 99 |
+
|
| 100 |
+
# normal dropout
|
| 101 |
+
if self.dropout is not None and self.training:
|
| 102 |
+
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
| 103 |
+
|
| 104 |
+
# rank dropout
|
| 105 |
+
if self.rank_dropout is not None and self.training:
|
| 106 |
+
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
| 107 |
+
if len(lx.size()) == 3:
|
| 108 |
+
mask = mask.unsqueeze(1) # for Text Encoder
|
| 109 |
+
elif len(lx.size()) == 4:
|
| 110 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
| 111 |
+
lx = lx * mask
|
| 112 |
+
|
| 113 |
+
# scaling for rank dropout: treat as if the rank is changed
|
| 114 |
+
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
| 115 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
| 116 |
+
else:
|
| 117 |
+
scale = self.scale
|
| 118 |
+
|
| 119 |
+
lx = self.lora_up(lx)
|
| 120 |
+
|
| 121 |
+
return org_forwarded + lx * self.multiplier * scale
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class LoRAInfModule(LoRAModule):
|
| 125 |
+
def __init__(
|
| 126 |
+
self,
|
| 127 |
+
lora_name,
|
| 128 |
+
org_module: torch.nn.Module,
|
| 129 |
+
multiplier=1.0,
|
| 130 |
+
lora_dim=4,
|
| 131 |
+
alpha=1,
|
| 132 |
+
**kwargs,
|
| 133 |
+
):
|
| 134 |
+
# no dropout for inference
|
| 135 |
+
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
| 136 |
+
|
| 137 |
+
self.org_module_ref = [org_module] # 後から参照できるように
|
| 138 |
+
self.enabled = True
|
| 139 |
+
|
| 140 |
+
# check regional or not by lora_name
|
| 141 |
+
self.text_encoder = False
|
| 142 |
+
if lora_name.startswith("lora_te_"):
|
| 143 |
+
self.regional = False
|
| 144 |
+
self.use_sub_prompt = True
|
| 145 |
+
self.text_encoder = True
|
| 146 |
+
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
| 147 |
+
self.regional = False
|
| 148 |
+
self.use_sub_prompt = True
|
| 149 |
+
elif "time_emb" in lora_name:
|
| 150 |
+
self.regional = False
|
| 151 |
+
self.use_sub_prompt = False
|
| 152 |
+
else:
|
| 153 |
+
self.regional = True
|
| 154 |
+
self.use_sub_prompt = False
|
| 155 |
+
|
| 156 |
+
self.network: LoRANetwork = None
|
| 157 |
+
|
| 158 |
+
def set_network(self, network):
|
| 159 |
+
self.network = network
|
| 160 |
+
|
| 161 |
+
# freezeしてマージする
|
| 162 |
+
def merge_to(self, sd, dtype, device):
|
| 163 |
+
# get up/down weight
|
| 164 |
+
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
| 165 |
+
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
| 166 |
+
|
| 167 |
+
# extract weight from org_module
|
| 168 |
+
org_sd = self.org_module.state_dict()
|
| 169 |
+
weight = org_sd["weight"].to(torch.float)
|
| 170 |
+
|
| 171 |
+
# merge weight
|
| 172 |
+
if len(weight.size()) == 2:
|
| 173 |
+
# linear
|
| 174 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
| 175 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 176 |
+
# conv2d 1x1
|
| 177 |
+
weight = (
|
| 178 |
+
weight
|
| 179 |
+
+ self.multiplier
|
| 180 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 181 |
+
* self.scale
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
# conv2d 3x3
|
| 185 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 186 |
+
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
| 187 |
+
weight = weight + self.multiplier * conved * self.scale
|
| 188 |
+
|
| 189 |
+
# set weight to org_module
|
| 190 |
+
org_sd["weight"] = weight.to(dtype)
|
| 191 |
+
self.org_module.load_state_dict(org_sd)
|
| 192 |
+
|
| 193 |
+
# 復元できるマージのため、このモジュールのweightを返す
|
| 194 |
+
def get_weight(self, multiplier=None):
|
| 195 |
+
if multiplier is None:
|
| 196 |
+
multiplier = self.multiplier
|
| 197 |
+
|
| 198 |
+
# get up/down weight from module
|
| 199 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
| 200 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
| 201 |
+
|
| 202 |
+
# pre-calculated weight
|
| 203 |
+
if len(down_weight.size()) == 2:
|
| 204 |
+
# linear
|
| 205 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
| 206 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 207 |
+
# conv2d 1x1
|
| 208 |
+
weight = (
|
| 209 |
+
self.multiplier
|
| 210 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 211 |
+
* self.scale
|
| 212 |
+
)
|
| 213 |
+
else:
|
| 214 |
+
# conv2d 3x3
|
| 215 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 216 |
+
weight = self.multiplier * conved * self.scale
|
| 217 |
+
|
| 218 |
+
return weight
|
| 219 |
+
|
| 220 |
+
def set_region(self, region):
|
| 221 |
+
self.region = region
|
| 222 |
+
self.region_mask = None
|
| 223 |
+
|
| 224 |
+
def default_forward(self, x):
|
| 225 |
+
# logger.info(f"default_forward {self.lora_name} {x.size()}")
|
| 226 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 227 |
+
|
| 228 |
+
def forward(self, x):
|
| 229 |
+
if not self.enabled:
|
| 230 |
+
return self.org_forward(x)
|
| 231 |
+
|
| 232 |
+
if self.network is None or self.network.sub_prompt_index is None:
|
| 233 |
+
return self.default_forward(x)
|
| 234 |
+
if not self.regional and not self.use_sub_prompt:
|
| 235 |
+
return self.default_forward(x)
|
| 236 |
+
|
| 237 |
+
if self.regional:
|
| 238 |
+
return self.regional_forward(x)
|
| 239 |
+
else:
|
| 240 |
+
return self.sub_prompt_forward(x)
|
| 241 |
+
|
| 242 |
+
def get_mask_for_x(self, x):
|
| 243 |
+
# calculate size from shape of x
|
| 244 |
+
if len(x.size()) == 4:
|
| 245 |
+
h, w = x.size()[2:4]
|
| 246 |
+
area = h * w
|
| 247 |
+
else:
|
| 248 |
+
area = x.size()[1]
|
| 249 |
+
|
| 250 |
+
mask = self.network.mask_dic.get(area, None)
|
| 251 |
+
if mask is None or len(x.size()) == 2:
|
| 252 |
+
# emb_layers in SDXL doesn't have mask
|
| 253 |
+
# if "emb" not in self.lora_name:
|
| 254 |
+
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
|
| 255 |
+
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
| 256 |
+
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
| 257 |
+
if len(x.size()) == 3:
|
| 258 |
+
mask = torch.reshape(mask, (1, -1, 1))
|
| 259 |
+
return mask
|
| 260 |
+
|
| 261 |
+
def regional_forward(self, x):
|
| 262 |
+
if "attn2_to_out" in self.lora_name:
|
| 263 |
+
return self.to_out_forward(x)
|
| 264 |
+
|
| 265 |
+
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
| 266 |
+
return self.default_forward(x)
|
| 267 |
+
|
| 268 |
+
# apply mask for LoRA result
|
| 269 |
+
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 270 |
+
mask = self.get_mask_for_x(lx)
|
| 271 |
+
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
| 272 |
+
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
|
| 273 |
+
# mask = mask.squeeze(-1)
|
| 274 |
+
lx = lx * mask
|
| 275 |
+
|
| 276 |
+
x = self.org_forward(x)
|
| 277 |
+
x = x + lx
|
| 278 |
+
|
| 279 |
+
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
| 280 |
+
x = self.postp_to_q(x)
|
| 281 |
+
|
| 282 |
+
return x
|
| 283 |
+
|
| 284 |
+
def postp_to_q(self, x):
|
| 285 |
+
# repeat x to num_sub_prompts
|
| 286 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
| 287 |
+
qc = self.network.batch_size # uncond
|
| 288 |
+
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
| 289 |
+
if has_real_uncond:
|
| 290 |
+
qc += self.network.batch_size # real_uncond
|
| 291 |
+
|
| 292 |
+
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
| 293 |
+
query[: self.network.batch_size] = x[: self.network.batch_size]
|
| 294 |
+
|
| 295 |
+
for i in range(self.network.batch_size):
|
| 296 |
+
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
| 297 |
+
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
| 298 |
+
|
| 299 |
+
if has_real_uncond:
|
| 300 |
+
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
| 301 |
+
|
| 302 |
+
# logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
|
| 303 |
+
return query
|
| 304 |
+
|
| 305 |
+
def sub_prompt_forward(self, x):
|
| 306 |
+
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
| 307 |
+
return self.org_forward(x)
|
| 308 |
+
|
| 309 |
+
emb_idx = self.network.sub_prompt_index
|
| 310 |
+
if not self.text_encoder:
|
| 311 |
+
emb_idx += self.network.batch_size
|
| 312 |
+
|
| 313 |
+
# apply sub prompt of X
|
| 314 |
+
lx = x[emb_idx :: self.network.num_sub_prompts]
|
| 315 |
+
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
| 316 |
+
|
| 317 |
+
# logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
|
| 318 |
+
|
| 319 |
+
x = self.org_forward(x)
|
| 320 |
+
x[emb_idx :: self.network.num_sub_prompts] += lx
|
| 321 |
+
|
| 322 |
+
return x
|
| 323 |
+
|
| 324 |
+
def to_out_forward(self, x):
|
| 325 |
+
# logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
|
| 326 |
+
|
| 327 |
+
if self.network.is_last_network:
|
| 328 |
+
masks = [None] * self.network.num_sub_prompts
|
| 329 |
+
self.network.shared[self.lora_name] = (None, masks)
|
| 330 |
+
else:
|
| 331 |
+
lx, masks = self.network.shared[self.lora_name]
|
| 332 |
+
|
| 333 |
+
# call own LoRA
|
| 334 |
+
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
| 335 |
+
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
| 336 |
+
|
| 337 |
+
if self.network.is_last_network:
|
| 338 |
+
lx = torch.zeros(
|
| 339 |
+
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
| 340 |
+
)
|
| 341 |
+
self.network.shared[self.lora_name] = (lx, masks)
|
| 342 |
+
|
| 343 |
+
# logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
|
| 344 |
+
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
| 345 |
+
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
| 346 |
+
|
| 347 |
+
# if not last network, return x and masks
|
| 348 |
+
x = self.org_forward(x)
|
| 349 |
+
if not self.network.is_last_network:
|
| 350 |
+
return x
|
| 351 |
+
|
| 352 |
+
lx, masks = self.network.shared.pop(self.lora_name)
|
| 353 |
+
|
| 354 |
+
# if last network, combine separated x with mask weighted sum
|
| 355 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
| 356 |
+
|
| 357 |
+
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
| 358 |
+
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
| 359 |
+
if has_real_uncond:
|
| 360 |
+
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
| 361 |
+
|
| 362 |
+
# logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
|
| 363 |
+
# if num_sub_prompts > num of LoRAs, fill with zero
|
| 364 |
+
for i in range(len(masks)):
|
| 365 |
+
if masks[i] is None:
|
| 366 |
+
masks[i] = torch.zeros_like(masks[0])
|
| 367 |
+
|
| 368 |
+
mask = torch.cat(masks)
|
| 369 |
+
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
| 370 |
+
for i in range(self.network.batch_size):
|
| 371 |
+
# 1枚の画像ごとに処理する
|
| 372 |
+
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
| 373 |
+
lx1 = lx1 * mask
|
| 374 |
+
lx1 = torch.sum(lx1, dim=0)
|
| 375 |
+
|
| 376 |
+
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
| 377 |
+
x1 = x[xi : xi + self.network.num_sub_prompts]
|
| 378 |
+
x1 = x1 * mask
|
| 379 |
+
x1 = torch.sum(x1, dim=0)
|
| 380 |
+
x1 = x1 / mask_sum
|
| 381 |
+
|
| 382 |
+
x1 = x1 + lx1
|
| 383 |
+
out[self.network.batch_size + i] = x1
|
| 384 |
+
|
| 385 |
+
# logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
|
| 386 |
+
return out
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
|
| 390 |
+
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
|
| 391 |
+
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
|
| 392 |
+
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
|
| 393 |
+
|
| 394 |
+
# 以上のいずれにも設定がない場合は無効としてNoneを返す
|
| 395 |
+
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
|
| 396 |
+
return None
|
| 397 |
+
|
| 398 |
+
# extract learning rate weight for each block
|
| 399 |
+
if down_lr_weight is not None:
|
| 400 |
+
# if some parameters are not set, use zero
|
| 401 |
+
if "," in down_lr_weight:
|
| 402 |
+
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
| 403 |
+
|
| 404 |
+
if mid_lr_weight is not None:
|
| 405 |
+
mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]
|
| 406 |
+
|
| 407 |
+
if up_lr_weight is not None:
|
| 408 |
+
if "," in up_lr_weight:
|
| 409 |
+
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
| 410 |
+
|
| 411 |
+
return get_block_lr_weight(
|
| 412 |
+
is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
def create_network(
|
| 417 |
+
multiplier: float,
|
| 418 |
+
network_dim: Optional[int],
|
| 419 |
+
network_alpha: Optional[float],
|
| 420 |
+
vae: AutoencoderKL,
|
| 421 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
| 422 |
+
unet,
|
| 423 |
+
neuron_dropout: Optional[float] = None,
|
| 424 |
+
**kwargs,
|
| 425 |
+
):
|
| 426 |
+
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
| 427 |
+
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
|
| 428 |
+
|
| 429 |
+
if network_dim is None:
|
| 430 |
+
network_dim = 4 # default
|
| 431 |
+
if network_alpha is None:
|
| 432 |
+
network_alpha = 1.0
|
| 433 |
+
|
| 434 |
+
# extract dim/alpha for conv2d, and block dim
|
| 435 |
+
conv_dim = kwargs.get("conv_dim", None)
|
| 436 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
| 437 |
+
if conv_dim is not None:
|
| 438 |
+
conv_dim = int(conv_dim)
|
| 439 |
+
if conv_alpha is None:
|
| 440 |
+
conv_alpha = 1.0
|
| 441 |
+
else:
|
| 442 |
+
conv_alpha = float(conv_alpha)
|
| 443 |
+
|
| 444 |
+
# block dim/alpha/lr
|
| 445 |
+
block_dims = kwargs.get("block_dims", None)
|
| 446 |
+
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
|
| 447 |
+
|
| 448 |
+
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
|
| 449 |
+
if block_dims is not None or block_lr_weight is not None:
|
| 450 |
+
block_alphas = kwargs.get("block_alphas", None)
|
| 451 |
+
conv_block_dims = kwargs.get("conv_block_dims", None)
|
| 452 |
+
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
| 453 |
+
|
| 454 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
| 455 |
+
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
# remove block dim/alpha without learning rate
|
| 459 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
| 460 |
+
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
else:
|
| 464 |
+
block_alphas = None
|
| 465 |
+
conv_block_dims = None
|
| 466 |
+
conv_block_alphas = None
|
| 467 |
+
|
| 468 |
+
# rank/module dropout
|
| 469 |
+
rank_dropout = kwargs.get("rank_dropout", None)
|
| 470 |
+
if rank_dropout is not None:
|
| 471 |
+
rank_dropout = float(rank_dropout)
|
| 472 |
+
module_dropout = kwargs.get("module_dropout", None)
|
| 473 |
+
if module_dropout is not None:
|
| 474 |
+
module_dropout = float(module_dropout)
|
| 475 |
+
|
| 476 |
+
# すごく引数が多いな ( ^ω^)・・・
|
| 477 |
+
network = LoRANetwork(
|
| 478 |
+
text_encoder,
|
| 479 |
+
unet,
|
| 480 |
+
multiplier=multiplier,
|
| 481 |
+
lora_dim=network_dim,
|
| 482 |
+
alpha=network_alpha,
|
| 483 |
+
dropout=neuron_dropout,
|
| 484 |
+
rank_dropout=rank_dropout,
|
| 485 |
+
module_dropout=module_dropout,
|
| 486 |
+
conv_lora_dim=conv_dim,
|
| 487 |
+
conv_alpha=conv_alpha,
|
| 488 |
+
block_dims=block_dims,
|
| 489 |
+
block_alphas=block_alphas,
|
| 490 |
+
conv_block_dims=conv_block_dims,
|
| 491 |
+
conv_block_alphas=conv_block_alphas,
|
| 492 |
+
varbose=True,
|
| 493 |
+
is_sdxl=is_sdxl,
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
| 497 |
+
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
| 498 |
+
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
| 499 |
+
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
| 500 |
+
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
| 501 |
+
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
| 502 |
+
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
| 503 |
+
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
| 504 |
+
|
| 505 |
+
if block_lr_weight is not None:
|
| 506 |
+
network.set_block_lr_weight(block_lr_weight)
|
| 507 |
+
|
| 508 |
+
return network
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# このメソッドは外部から呼び出される可能性を考慮しておく
|
| 512 |
+
# network_dim, network_alpha にはデフォルト値が入っている。
|
| 513 |
+
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
|
| 514 |
+
# conv_dim, conv_alpha は両方ともNoneまたは両��とも値が入っている
|
| 515 |
+
def get_block_dims_and_alphas(
|
| 516 |
+
is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
| 517 |
+
):
|
| 518 |
+
if not is_sdxl:
|
| 519 |
+
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
|
| 520 |
+
else:
|
| 521 |
+
# 1+9+3+9+1=23, no LoRA for emb_layers (0)
|
| 522 |
+
num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
|
| 523 |
+
|
| 524 |
+
def parse_ints(s):
|
| 525 |
+
return [int(i) for i in s.split(",")]
|
| 526 |
+
|
| 527 |
+
def parse_floats(s):
|
| 528 |
+
return [float(i) for i in s.split(",")]
|
| 529 |
+
|
| 530 |
+
# block_dimsとblock_alphasをパースする。必ず値が入る
|
| 531 |
+
if block_dims is not None:
|
| 532 |
+
block_dims = parse_ints(block_dims)
|
| 533 |
+
assert len(block_dims) == num_total_blocks, (
|
| 534 |
+
f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
|
| 535 |
+
+ f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})"
|
| 536 |
+
)
|
| 537 |
+
else:
|
| 538 |
+
logger.warning(
|
| 539 |
+
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
| 540 |
+
)
|
| 541 |
+
block_dims = [network_dim] * num_total_blocks
|
| 542 |
+
|
| 543 |
+
if block_alphas is not None:
|
| 544 |
+
block_alphas = parse_floats(block_alphas)
|
| 545 |
+
assert (
|
| 546 |
+
len(block_alphas) == num_total_blocks
|
| 547 |
+
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
| 548 |
+
else:
|
| 549 |
+
logger.warning(
|
| 550 |
+
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
| 551 |
+
)
|
| 552 |
+
block_alphas = [network_alpha] * num_total_blocks
|
| 553 |
+
|
| 554 |
+
# conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
|
| 555 |
+
if conv_block_dims is not None:
|
| 556 |
+
conv_block_dims = parse_ints(conv_block_dims)
|
| 557 |
+
assert (
|
| 558 |
+
len(conv_block_dims) == num_total_blocks
|
| 559 |
+
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
|
| 560 |
+
|
| 561 |
+
if conv_block_alphas is not None:
|
| 562 |
+
conv_block_alphas = parse_floats(conv_block_alphas)
|
| 563 |
+
assert (
|
| 564 |
+
len(conv_block_alphas) == num_total_blocks
|
| 565 |
+
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
|
| 566 |
+
else:
|
| 567 |
+
if conv_alpha is None:
|
| 568 |
+
conv_alpha = 1.0
|
| 569 |
+
logger.warning(
|
| 570 |
+
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
| 571 |
+
)
|
| 572 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
| 573 |
+
else:
|
| 574 |
+
if conv_dim is not None:
|
| 575 |
+
logger.warning(
|
| 576 |
+
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
| 577 |
+
)
|
| 578 |
+
conv_block_dims = [conv_dim] * num_total_blocks
|
| 579 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
| 580 |
+
else:
|
| 581 |
+
conv_block_dims = None
|
| 582 |
+
conv_block_alphas = None
|
| 583 |
+
|
| 584 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく
|
| 588 |
+
# 戻り値は block ごとの倍率のリスト
|
| 589 |
+
def get_block_lr_weight(
|
| 590 |
+
is_sdxl,
|
| 591 |
+
down_lr_weight: Union[str, List[float]],
|
| 592 |
+
mid_lr_weight: List[float],
|
| 593 |
+
up_lr_weight: Union[str, List[float]],
|
| 594 |
+
zero_threshold: float,
|
| 595 |
+
) -> Optional[List[float]]:
|
| 596 |
+
# パラメータ未指定時は何もせず、今までと同じ動作とする
|
| 597 |
+
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
| 598 |
+
return None
|
| 599 |
+
|
| 600 |
+
if not is_sdxl:
|
| 601 |
+
max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
|
| 602 |
+
max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
|
| 603 |
+
else:
|
| 604 |
+
max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
|
| 605 |
+
max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
|
| 606 |
+
|
| 607 |
+
def get_list(name_with_suffix) -> List[float]:
|
| 608 |
+
import math
|
| 609 |
+
|
| 610 |
+
tokens = name_with_suffix.split("+")
|
| 611 |
+
name = tokens[0]
|
| 612 |
+
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
| 613 |
+
|
| 614 |
+
if name == "cosine":
|
| 615 |
+
return [
|
| 616 |
+
math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
|
| 617 |
+
for i in reversed(range(max_len_for_down_or_up))
|
| 618 |
+
]
|
| 619 |
+
elif name == "sine":
|
| 620 |
+
return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)]
|
| 621 |
+
elif name == "linear":
|
| 622 |
+
return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)]
|
| 623 |
+
elif name == "reverse_linear":
|
| 624 |
+
return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))]
|
| 625 |
+
elif name == "zeros":
|
| 626 |
+
return [0.0 + base_lr] * max_len_for_down_or_up
|
| 627 |
+
else:
|
| 628 |
+
logger.error(
|
| 629 |
+
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
| 630 |
+
% (name)
|
| 631 |
+
)
|
| 632 |
+
return None
|
| 633 |
+
|
| 634 |
+
if type(down_lr_weight) == str:
|
| 635 |
+
down_lr_weight = get_list(down_lr_weight)
|
| 636 |
+
if type(up_lr_weight) == str:
|
| 637 |
+
up_lr_weight = get_list(up_lr_weight)
|
| 638 |
+
|
| 639 |
+
if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
|
| 640 |
+
down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
|
| 641 |
+
):
|
| 642 |
+
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up)
|
| 643 |
+
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up)
|
| 644 |
+
up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
|
| 645 |
+
down_lr_weight = down_lr_weight[:max_len_for_down_or_up]
|
| 646 |
+
|
| 647 |
+
if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
|
| 648 |
+
logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid)
|
| 649 |
+
logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid)
|
| 650 |
+
mid_lr_weight = mid_lr_weight[:max_len_for_mid]
|
| 651 |
+
|
| 652 |
+
if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
|
| 653 |
+
down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
|
| 654 |
+
):
|
| 655 |
+
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up)
|
| 656 |
+
logger.warning(
|
| 657 |
+
"down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
|
| 661 |
+
down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight))
|
| 662 |
+
if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
|
| 663 |
+
up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight))
|
| 664 |
+
|
| 665 |
+
if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
|
| 666 |
+
logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid)
|
| 667 |
+
logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid)
|
| 668 |
+
mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))
|
| 669 |
+
|
| 670 |
+
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
| 671 |
+
logger.info("apply block learning rate / 階層別学習率を適用します。")
|
| 672 |
+
if down_lr_weight != None:
|
| 673 |
+
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
| 674 |
+
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
|
| 675 |
+
else:
|
| 676 |
+
down_lr_weight = [1.0] * max_len_for_down_or_up
|
| 677 |
+
logger.info("down_lr_weight: all 1.0, すべて1.0")
|
| 678 |
+
|
| 679 |
+
if mid_lr_weight != None:
|
| 680 |
+
mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
|
| 681 |
+
logger.info(f"mid_lr_weight: {mid_lr_weight}")
|
| 682 |
+
else:
|
| 683 |
+
mid_lr_weight = [1.0] * max_len_for_mid
|
| 684 |
+
logger.info("mid_lr_weight: all 1.0, すべて1.0")
|
| 685 |
+
|
| 686 |
+
if up_lr_weight != None:
|
| 687 |
+
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
| 688 |
+
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
|
| 689 |
+
else:
|
| 690 |
+
up_lr_weight = [1.0] * max_len_for_down_or_up
|
| 691 |
+
logger.info("up_lr_weight: all 1.0, すべて1.0")
|
| 692 |
+
|
| 693 |
+
lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
|
| 694 |
+
|
| 695 |
+
if is_sdxl:
|
| 696 |
+
lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out
|
| 697 |
+
|
| 698 |
+
assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or (
|
| 699 |
+
is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
|
| 700 |
+
), f"lr_weight length is invalid: {len(lr_weight)}"
|
| 701 |
+
|
| 702 |
+
return lr_weight
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
|
| 706 |
+
def remove_block_dims_and_alphas(
|
| 707 |
+
is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]]
|
| 708 |
+
):
|
| 709 |
+
if block_lr_weight is not None:
|
| 710 |
+
for i, lr in enumerate(block_lr_weight):
|
| 711 |
+
if lr == 0:
|
| 712 |
+
block_dims[i] = 0
|
| 713 |
+
if conv_block_dims is not None:
|
| 714 |
+
conv_block_dims[i] = 0
|
| 715 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# 外部から呼び出す可能性を考慮しておく
|
| 719 |
+
def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
|
| 720 |
+
block_idx = -1 # invalid lora name
|
| 721 |
+
if not is_sdxl:
|
| 722 |
+
m = RE_UPDOWN.search(lora_name)
|
| 723 |
+
if m:
|
| 724 |
+
g = m.groups()
|
| 725 |
+
i = int(g[1])
|
| 726 |
+
j = int(g[3])
|
| 727 |
+
if g[2] == "resnets":
|
| 728 |
+
idx = 3 * i + j
|
| 729 |
+
elif g[2] == "attentions":
|
| 730 |
+
idx = 3 * i + j
|
| 731 |
+
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
| 732 |
+
idx = 3 * i + 2
|
| 733 |
+
|
| 734 |
+
if g[0] == "down":
|
| 735 |
+
block_idx = 1 + idx # 0に該当するLoRAは存在しない
|
| 736 |
+
elif g[0] == "up":
|
| 737 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
| 738 |
+
elif "mid_block_" in lora_name:
|
| 739 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
| 740 |
+
else:
|
| 741 |
+
# copy from sdxl_train
|
| 742 |
+
if lora_name.startswith("lora_unet_"):
|
| 743 |
+
name = lora_name[len("lora_unet_") :]
|
| 744 |
+
if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA
|
| 745 |
+
block_idx = 0 # 0
|
| 746 |
+
elif name.startswith("input_blocks_"): # 1-9
|
| 747 |
+
block_idx = 1 + int(name.split("_")[2])
|
| 748 |
+
elif name.startswith("middle_block_"): # 10-12
|
| 749 |
+
block_idx = 10 + int(name.split("_")[2])
|
| 750 |
+
elif name.startswith("output_blocks_"): # 13-21
|
| 751 |
+
block_idx = 13 + int(name.split("_")[2])
|
| 752 |
+
elif name.startswith("out_"): # 22, out, no LoRA
|
| 753 |
+
block_idx = 22
|
| 754 |
+
|
| 755 |
+
return block_idx
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def convert_diffusers_to_sai_if_needed(weights_sd):
|
| 759 |
+
# only supports U-Net LoRA modules
|
| 760 |
+
|
| 761 |
+
found_up_down_blocks = False
|
| 762 |
+
for k in list(weights_sd.keys()):
|
| 763 |
+
if "down_blocks" in k:
|
| 764 |
+
found_up_down_blocks = True
|
| 765 |
+
break
|
| 766 |
+
if "up_blocks" in k:
|
| 767 |
+
found_up_down_blocks = True
|
| 768 |
+
break
|
| 769 |
+
if not found_up_down_blocks:
|
| 770 |
+
return
|
| 771 |
+
|
| 772 |
+
from library.sdxl_model_util import make_unet_conversion_map
|
| 773 |
+
|
| 774 |
+
unet_conversion_map = make_unet_conversion_map()
|
| 775 |
+
unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
| 776 |
+
|
| 777 |
+
# # add extra conversion
|
| 778 |
+
# unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
|
| 779 |
+
|
| 780 |
+
logger.info(f"Converting LoRA keys from Diffusers to SAI")
|
| 781 |
+
lora_unet_prefix = "lora_unet_"
|
| 782 |
+
for k in list(weights_sd.keys()):
|
| 783 |
+
if not k.startswith(lora_unet_prefix):
|
| 784 |
+
continue
|
| 785 |
+
|
| 786 |
+
unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
|
| 787 |
+
|
| 788 |
+
# search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
|
| 789 |
+
for hf_module_name, sd_module_name in unet_conversion_map.items():
|
| 790 |
+
if hf_module_name in unet_module_name:
|
| 791 |
+
new_key = (
|
| 792 |
+
lora_unet_prefix
|
| 793 |
+
+ unet_module_name.replace(hf_module_name, sd_module_name)
|
| 794 |
+
+ k[len(lora_unet_prefix) + len(unet_module_name) :]
|
| 795 |
+
)
|
| 796 |
+
weights_sd[new_key] = weights_sd.pop(k)
|
| 797 |
+
found = True
|
| 798 |
+
break
|
| 799 |
+
|
| 800 |
+
if not found:
|
| 801 |
+
logger.warning(f"Key {k} is not found in unet_conversion_map")
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
| 805 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
| 806 |
+
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
|
| 807 |
+
is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
|
| 808 |
+
|
| 809 |
+
if weights_sd is None:
|
| 810 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 811 |
+
from safetensors.torch import load_file, safe_open
|
| 812 |
+
|
| 813 |
+
weights_sd = load_file(file)
|
| 814 |
+
else:
|
| 815 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 816 |
+
|
| 817 |
+
# if keys are Diffusers based, convert to SAI based
|
| 818 |
+
if is_sdxl:
|
| 819 |
+
convert_diffusers_to_sai_if_needed(weights_sd)
|
| 820 |
+
|
| 821 |
+
# get dim/alpha mapping
|
| 822 |
+
modules_dim = {}
|
| 823 |
+
modules_alpha = {}
|
| 824 |
+
for key, value in weights_sd.items():
|
| 825 |
+
if "." not in key:
|
| 826 |
+
continue
|
| 827 |
+
|
| 828 |
+
lora_name = key.split(".")[0]
|
| 829 |
+
if "alpha" in key:
|
| 830 |
+
modules_alpha[lora_name] = value
|
| 831 |
+
elif "lora_down" in key:
|
| 832 |
+
dim = value.size()[0]
|
| 833 |
+
modules_dim[lora_name] = dim
|
| 834 |
+
# logger.info(lora_name, value.size(), dim)
|
| 835 |
+
|
| 836 |
+
# support old LoRA without alpha
|
| 837 |
+
for key in modules_dim.keys():
|
| 838 |
+
if key not in modules_alpha:
|
| 839 |
+
modules_alpha[key] = modules_dim[key]
|
| 840 |
+
|
| 841 |
+
module_class = LoRAInfModule if for_inference else LoRAModule
|
| 842 |
+
|
| 843 |
+
network = LoRANetwork(
|
| 844 |
+
text_encoder,
|
| 845 |
+
unet,
|
| 846 |
+
multiplier=multiplier,
|
| 847 |
+
modules_dim=modules_dim,
|
| 848 |
+
modules_alpha=modules_alpha,
|
| 849 |
+
module_class=module_class,
|
| 850 |
+
is_sdxl=is_sdxl,
|
| 851 |
+
)
|
| 852 |
+
|
| 853 |
+
# block lr
|
| 854 |
+
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
|
| 855 |
+
if block_lr_weight is not None:
|
| 856 |
+
network.set_block_lr_weight(block_lr_weight)
|
| 857 |
+
|
| 858 |
+
return network, weights_sd
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
class LoRANetwork(torch.nn.Module):
|
| 862 |
+
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
| 863 |
+
NUM_OF_MID_BLOCKS = 1
|
| 864 |
+
SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23
|
| 865 |
+
SDXL_NUM_OF_MID_BLOCKS = 3
|
| 866 |
+
|
| 867 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
| 868 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 869 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
| 870 |
+
LORA_PREFIX_UNET = "lora_unet"
|
| 871 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
| 872 |
+
|
| 873 |
+
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
|
| 874 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
| 875 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
| 876 |
+
|
| 877 |
+
def __init__(
|
| 878 |
+
self,
|
| 879 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
| 880 |
+
unet,
|
| 881 |
+
multiplier: float = 1.0,
|
| 882 |
+
lora_dim: int = 4,
|
| 883 |
+
alpha: float = 1,
|
| 884 |
+
dropout: Optional[float] = None,
|
| 885 |
+
rank_dropout: Optional[float] = None,
|
| 886 |
+
module_dropout: Optional[float] = None,
|
| 887 |
+
conv_lora_dim: Optional[int] = None,
|
| 888 |
+
conv_alpha: Optional[float] = None,
|
| 889 |
+
block_dims: Optional[List[int]] = None,
|
| 890 |
+
block_alphas: Optional[List[float]] = None,
|
| 891 |
+
conv_block_dims: Optional[List[int]] = None,
|
| 892 |
+
conv_block_alphas: Optional[List[float]] = None,
|
| 893 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
| 894 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
| 895 |
+
module_class: Type[object] = LoRAModule,
|
| 896 |
+
varbose: Optional[bool] = False,
|
| 897 |
+
is_sdxl: Optional[bool] = False,
|
| 898 |
+
) -> None:
|
| 899 |
+
"""
|
| 900 |
+
LoRA network: すごく引数が多いが、パターンは以下の通り
|
| 901 |
+
1. lora_dimとalphaを指定
|
| 902 |
+
2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
|
| 903 |
+
3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
|
| 904 |
+
4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
|
| 905 |
+
5. modules_dimとmodules_alphaを指定 (推論用)
|
| 906 |
+
"""
|
| 907 |
+
super().__init__()
|
| 908 |
+
self.multiplier = multiplier
|
| 909 |
+
|
| 910 |
+
self.lora_dim = lora_dim
|
| 911 |
+
self.alpha = alpha
|
| 912 |
+
self.conv_lora_dim = conv_lora_dim
|
| 913 |
+
self.conv_alpha = conv_alpha
|
| 914 |
+
self.dropout = dropout
|
| 915 |
+
self.rank_dropout = rank_dropout
|
| 916 |
+
self.module_dropout = module_dropout
|
| 917 |
+
|
| 918 |
+
self.loraplus_lr_ratio = None
|
| 919 |
+
self.loraplus_unet_lr_ratio = None
|
| 920 |
+
self.loraplus_text_encoder_lr_ratio = None
|
| 921 |
+
|
| 922 |
+
if modules_dim is not None:
|
| 923 |
+
logger.info(f"create LoRA network from weights")
|
| 924 |
+
elif block_dims is not None:
|
| 925 |
+
logger.info(f"create LoRA network from block_dims")
|
| 926 |
+
logger.info(
|
| 927 |
+
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
| 928 |
+
)
|
| 929 |
+
logger.info(f"block_dims: {block_dims}")
|
| 930 |
+
logger.info(f"block_alphas: {block_alphas}")
|
| 931 |
+
if conv_block_dims is not None:
|
| 932 |
+
logger.info(f"conv_block_dims: {conv_block_dims}")
|
| 933 |
+
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
| 934 |
+
else:
|
| 935 |
+
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
| 936 |
+
logger.info(
|
| 937 |
+
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
| 938 |
+
)
|
| 939 |
+
if self.conv_lora_dim is not None:
|
| 940 |
+
logger.info(
|
| 941 |
+
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
| 942 |
+
)
|
| 943 |
+
|
| 944 |
+
# create module instances
|
| 945 |
+
def create_modules(
|
| 946 |
+
is_unet: bool,
|
| 947 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
| 948 |
+
root_module: torch.nn.Module,
|
| 949 |
+
target_replace_modules: List[torch.nn.Module],
|
| 950 |
+
) -> List[LoRAModule]:
|
| 951 |
+
prefix = (
|
| 952 |
+
self.LORA_PREFIX_UNET
|
| 953 |
+
if is_unet
|
| 954 |
+
else (
|
| 955 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
| 956 |
+
if text_encoder_idx is None
|
| 957 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
| 958 |
+
)
|
| 959 |
+
)
|
| 960 |
+
loras = []
|
| 961 |
+
skipped = []
|
| 962 |
+
for name, module in root_module.named_modules():
|
| 963 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 964 |
+
for child_name, child_module in module.named_modules():
|
| 965 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
| 966 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
| 967 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
| 968 |
+
|
| 969 |
+
if is_linear or is_conv2d:
|
| 970 |
+
lora_name = prefix + "." + name + "." + child_name
|
| 971 |
+
lora_name = lora_name.replace(".", "_")
|
| 972 |
+
|
| 973 |
+
dim = None
|
| 974 |
+
alpha = None
|
| 975 |
+
|
| 976 |
+
if modules_dim is not None:
|
| 977 |
+
# モジュール指定あり
|
| 978 |
+
if lora_name in modules_dim:
|
| 979 |
+
dim = modules_dim[lora_name]
|
| 980 |
+
alpha = modules_alpha[lora_name]
|
| 981 |
+
elif is_unet and block_dims is not None:
|
| 982 |
+
# U-Netでblock_dims指定あり
|
| 983 |
+
block_idx = get_block_index(lora_name, is_sdxl)
|
| 984 |
+
if is_linear or is_conv2d_1x1:
|
| 985 |
+
dim = block_dims[block_idx]
|
| 986 |
+
alpha = block_alphas[block_idx]
|
| 987 |
+
elif conv_block_dims is not None:
|
| 988 |
+
dim = conv_block_dims[block_idx]
|
| 989 |
+
alpha = conv_block_alphas[block_idx]
|
| 990 |
+
else:
|
| 991 |
+
# 通常、すべて対象とする
|
| 992 |
+
if is_linear or is_conv2d_1x1:
|
| 993 |
+
dim = self.lora_dim
|
| 994 |
+
alpha = self.alpha
|
| 995 |
+
elif self.conv_lora_dim is not None:
|
| 996 |
+
dim = self.conv_lora_dim
|
| 997 |
+
alpha = self.conv_alpha
|
| 998 |
+
|
| 999 |
+
if dim is None or dim == 0:
|
| 1000 |
+
# skipした情報を出力
|
| 1001 |
+
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
| 1002 |
+
skipped.append(lora_name)
|
| 1003 |
+
continue
|
| 1004 |
+
|
| 1005 |
+
lora = module_class(
|
| 1006 |
+
lora_name,
|
| 1007 |
+
child_module,
|
| 1008 |
+
self.multiplier,
|
| 1009 |
+
dim,
|
| 1010 |
+
alpha,
|
| 1011 |
+
dropout=dropout,
|
| 1012 |
+
rank_dropout=rank_dropout,
|
| 1013 |
+
module_dropout=module_dropout,
|
| 1014 |
+
)
|
| 1015 |
+
loras.append(lora)
|
| 1016 |
+
return loras, skipped
|
| 1017 |
+
|
| 1018 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
| 1019 |
+
|
| 1020 |
+
# create LoRA for text encoder
|
| 1021 |
+
# 毎回すべてのモジュールを作るのは無駄なので要検討
|
| 1022 |
+
self.text_encoder_loras = []
|
| 1023 |
+
skipped_te = []
|
| 1024 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 1025 |
+
if len(text_encoders) > 1:
|
| 1026 |
+
index = i + 1
|
| 1027 |
+
logger.info(f"create LoRA for Text Encoder {index}:")
|
| 1028 |
+
else:
|
| 1029 |
+
index = None
|
| 1030 |
+
logger.info(f"create LoRA for Text Encoder:")
|
| 1031 |
+
|
| 1032 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
| 1033 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
| 1034 |
+
skipped_te += skipped
|
| 1035 |
+
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
| 1036 |
+
|
| 1037 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
| 1038 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
| 1039 |
+
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
| 1040 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 1041 |
+
|
| 1042 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
| 1043 |
+
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
| 1044 |
+
|
| 1045 |
+
skipped = skipped_te + skipped_un
|
| 1046 |
+
if varbose and len(skipped) > 0:
|
| 1047 |
+
logger.warning(
|
| 1048 |
+
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
| 1049 |
+
)
|
| 1050 |
+
for name in skipped:
|
| 1051 |
+
logger.info(f"\t{name}")
|
| 1052 |
+
|
| 1053 |
+
self.block_lr_weight = None
|
| 1054 |
+
self.block_lr = False
|
| 1055 |
+
|
| 1056 |
+
# assertion
|
| 1057 |
+
names = set()
|
| 1058 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1059 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
| 1060 |
+
names.add(lora.lora_name)
|
| 1061 |
+
|
| 1062 |
+
def set_multiplier(self, multiplier):
|
| 1063 |
+
self.multiplier = multiplier
|
| 1064 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1065 |
+
lora.multiplier = self.multiplier
|
| 1066 |
+
|
| 1067 |
+
def set_enabled(self, is_enabled):
|
| 1068 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1069 |
+
lora.enabled = is_enabled
|
| 1070 |
+
|
| 1071 |
+
def load_weights(self, file):
|
| 1072 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 1073 |
+
from safetensors.torch import load_file
|
| 1074 |
+
|
| 1075 |
+
weights_sd = load_file(file)
|
| 1076 |
+
else:
|
| 1077 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 1078 |
+
|
| 1079 |
+
info = self.load_state_dict(weights_sd, False)
|
| 1080 |
+
return info
|
| 1081 |
+
|
| 1082 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
| 1083 |
+
if apply_text_encoder:
|
| 1084 |
+
logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
|
| 1085 |
+
else:
|
| 1086 |
+
self.text_encoder_loras = []
|
| 1087 |
+
|
| 1088 |
+
if apply_unet:
|
| 1089 |
+
logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
|
| 1090 |
+
else:
|
| 1091 |
+
self.unet_loras = []
|
| 1092 |
+
|
| 1093 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1094 |
+
lora.apply_to()
|
| 1095 |
+
self.add_module(lora.lora_name, lora)
|
| 1096 |
+
|
| 1097 |
+
# マージできるかどうかを返す
|
| 1098 |
+
def is_mergeable(self):
|
| 1099 |
+
return True
|
| 1100 |
+
|
| 1101 |
+
# TODO refactor to common function with apply_to
|
| 1102 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
| 1103 |
+
apply_text_encoder = apply_unet = False
|
| 1104 |
+
for key in weights_sd.keys():
|
| 1105 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
| 1106 |
+
apply_text_encoder = True
|
| 1107 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
| 1108 |
+
apply_unet = True
|
| 1109 |
+
|
| 1110 |
+
if apply_text_encoder:
|
| 1111 |
+
logger.info("enable LoRA for text encoder")
|
| 1112 |
+
else:
|
| 1113 |
+
self.text_encoder_loras = []
|
| 1114 |
+
|
| 1115 |
+
if apply_unet:
|
| 1116 |
+
logger.info("enable LoRA for U-Net")
|
| 1117 |
+
else:
|
| 1118 |
+
self.unet_loras = []
|
| 1119 |
+
|
| 1120 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1121 |
+
sd_for_lora = {}
|
| 1122 |
+
for key in weights_sd.keys():
|
| 1123 |
+
if key.startswith(lora.lora_name):
|
| 1124 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
| 1125 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
| 1126 |
+
|
| 1127 |
+
logger.info(f"weights are merged")
|
| 1128 |
+
|
| 1129 |
+
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
| 1130 |
+
def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
|
| 1131 |
+
self.block_lr = True
|
| 1132 |
+
self.block_lr_weight = block_lr_weight
|
| 1133 |
+
|
| 1134 |
+
def get_lr_weight(self, block_idx: int) -> float:
|
| 1135 |
+
if not self.block_lr or self.block_lr_weight is None:
|
| 1136 |
+
return 1.0
|
| 1137 |
+
return self.block_lr_weight[block_idx]
|
| 1138 |
+
|
| 1139 |
+
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
| 1140 |
+
self.loraplus_lr_ratio = loraplus_lr_ratio
|
| 1141 |
+
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
| 1142 |
+
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
| 1143 |
+
|
| 1144 |
+
logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
|
| 1145 |
+
logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
|
| 1146 |
+
|
| 1147 |
+
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
| 1148 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
| 1149 |
+
# TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
|
| 1150 |
+
# if (
|
| 1151 |
+
# self.loraplus_lr_ratio is not None
|
| 1152 |
+
# or self.loraplus_text_encoder_lr_ratio is not None
|
| 1153 |
+
# or self.loraplus_unet_lr_ratio is not None
|
| 1154 |
+
# ):
|
| 1155 |
+
# assert (
|
| 1156 |
+
# optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
|
| 1157 |
+
# ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
|
| 1158 |
+
|
| 1159 |
+
self.requires_grad_(True)
|
| 1160 |
+
|
| 1161 |
+
all_params = []
|
| 1162 |
+
lr_descriptions = []
|
| 1163 |
+
|
| 1164 |
+
def assemble_params(loras, lr, ratio):
|
| 1165 |
+
param_groups = {"lora": {}, "plus": {}}
|
| 1166 |
+
for lora in loras:
|
| 1167 |
+
for name, param in lora.named_parameters():
|
| 1168 |
+
if ratio is not None and "lora_up" in name:
|
| 1169 |
+
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
| 1170 |
+
else:
|
| 1171 |
+
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
| 1172 |
+
|
| 1173 |
+
params = []
|
| 1174 |
+
descriptions = []
|
| 1175 |
+
for key in param_groups.keys():
|
| 1176 |
+
param_data = {"params": param_groups[key].values()}
|
| 1177 |
+
|
| 1178 |
+
if len(param_data["params"]) == 0:
|
| 1179 |
+
continue
|
| 1180 |
+
|
| 1181 |
+
if lr is not None:
|
| 1182 |
+
if key == "plus":
|
| 1183 |
+
param_data["lr"] = lr * ratio
|
| 1184 |
+
else:
|
| 1185 |
+
param_data["lr"] = lr
|
| 1186 |
+
|
| 1187 |
+
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
| 1188 |
+
logger.info("NO LR skipping!")
|
| 1189 |
+
continue
|
| 1190 |
+
|
| 1191 |
+
params.append(param_data)
|
| 1192 |
+
descriptions.append("plus" if key == "plus" else "")
|
| 1193 |
+
|
| 1194 |
+
return params, descriptions
|
| 1195 |
+
|
| 1196 |
+
if self.text_encoder_loras:
|
| 1197 |
+
params, descriptions = assemble_params(
|
| 1198 |
+
self.text_encoder_loras,
|
| 1199 |
+
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
| 1200 |
+
self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
|
| 1201 |
+
)
|
| 1202 |
+
all_params.extend(params)
|
| 1203 |
+
lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
|
| 1204 |
+
|
| 1205 |
+
if self.unet_loras:
|
| 1206 |
+
if self.block_lr:
|
| 1207 |
+
is_sdxl = False
|
| 1208 |
+
for lora in self.unet_loras:
|
| 1209 |
+
if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
|
| 1210 |
+
is_sdxl = True
|
| 1211 |
+
break
|
| 1212 |
+
|
| 1213 |
+
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
| 1214 |
+
block_idx_to_lora = {}
|
| 1215 |
+
for lora in self.unet_loras:
|
| 1216 |
+
idx = get_block_index(lora.lora_name, is_sdxl)
|
| 1217 |
+
if idx not in block_idx_to_lora:
|
| 1218 |
+
block_idx_to_lora[idx] = []
|
| 1219 |
+
block_idx_to_lora[idx].append(lora)
|
| 1220 |
+
|
| 1221 |
+
# blockごとにパラメータを設定する
|
| 1222 |
+
for idx, block_loras in block_idx_to_lora.items():
|
| 1223 |
+
params, descriptions = assemble_params(
|
| 1224 |
+
block_loras,
|
| 1225 |
+
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
|
| 1226 |
+
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
| 1227 |
+
)
|
| 1228 |
+
all_params.extend(params)
|
| 1229 |
+
lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
|
| 1230 |
+
|
| 1231 |
+
else:
|
| 1232 |
+
params, descriptions = assemble_params(
|
| 1233 |
+
self.unet_loras,
|
| 1234 |
+
unet_lr if unet_lr is not None else default_lr,
|
| 1235 |
+
self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
|
| 1236 |
+
)
|
| 1237 |
+
all_params.extend(params)
|
| 1238 |
+
lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
|
| 1239 |
+
|
| 1240 |
+
return all_params, lr_descriptions
|
| 1241 |
+
|
| 1242 |
+
def enable_gradient_checkpointing(self):
|
| 1243 |
+
# not supported
|
| 1244 |
+
pass
|
| 1245 |
+
|
| 1246 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
| 1247 |
+
self.requires_grad_(True)
|
| 1248 |
+
|
| 1249 |
+
def on_epoch_start(self, text_encoder, unet):
|
| 1250 |
+
self.train()
|
| 1251 |
+
|
| 1252 |
+
def get_trainable_params(self):
|
| 1253 |
+
return self.parameters()
|
| 1254 |
+
|
| 1255 |
+
def save_weights(self, file, dtype, metadata):
|
| 1256 |
+
if metadata is not None and len(metadata) == 0:
|
| 1257 |
+
metadata = None
|
| 1258 |
+
|
| 1259 |
+
state_dict = self.state_dict()
|
| 1260 |
+
|
| 1261 |
+
if dtype is not None:
|
| 1262 |
+
for key in list(state_dict.keys()):
|
| 1263 |
+
v = state_dict[key]
|
| 1264 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
| 1265 |
+
state_dict[key] = v
|
| 1266 |
+
|
| 1267 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 1268 |
+
from safetensors.torch import save_file
|
| 1269 |
+
from library import train_util
|
| 1270 |
+
|
| 1271 |
+
# Precalculate model hashes to save time on indexing
|
| 1272 |
+
if metadata is None:
|
| 1273 |
+
metadata = {}
|
| 1274 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
| 1275 |
+
metadata["sshs_model_hash"] = model_hash
|
| 1276 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
| 1277 |
+
|
| 1278 |
+
save_file(state_dict, file, metadata)
|
| 1279 |
+
else:
|
| 1280 |
+
torch.save(state_dict, file)
|
| 1281 |
+
|
| 1282 |
+
# mask is a tensor with values from 0 to 1
|
| 1283 |
+
def set_region(self, sub_prompt_index, is_last_network, mask):
|
| 1284 |
+
if mask.max() == 0:
|
| 1285 |
+
mask = torch.ones_like(mask)
|
| 1286 |
+
|
| 1287 |
+
self.mask = mask
|
| 1288 |
+
self.sub_prompt_index = sub_prompt_index
|
| 1289 |
+
self.is_last_network = is_last_network
|
| 1290 |
+
|
| 1291 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1292 |
+
lora.set_network(self)
|
| 1293 |
+
|
| 1294 |
+
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
|
| 1295 |
+
self.batch_size = batch_size
|
| 1296 |
+
self.num_sub_prompts = num_sub_prompts
|
| 1297 |
+
self.current_size = (height, width)
|
| 1298 |
+
self.shared = shared
|
| 1299 |
+
|
| 1300 |
+
# create masks
|
| 1301 |
+
mask = self.mask
|
| 1302 |
+
mask_dic = {}
|
| 1303 |
+
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
| 1304 |
+
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
| 1305 |
+
dtype = ref_weight.dtype
|
| 1306 |
+
device = ref_weight.device
|
| 1307 |
+
|
| 1308 |
+
def resize_add(mh, mw):
|
| 1309 |
+
# logger.info(mh, mw, mh * mw)
|
| 1310 |
+
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
| 1311 |
+
m = m.to(device, dtype=dtype)
|
| 1312 |
+
mask_dic[mh * mw] = m
|
| 1313 |
+
|
| 1314 |
+
h = height // 8
|
| 1315 |
+
w = width // 8
|
| 1316 |
+
for _ in range(4):
|
| 1317 |
+
resize_add(h, w)
|
| 1318 |
+
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
| 1319 |
+
resize_add(h + h % 2, w + w % 2)
|
| 1320 |
+
|
| 1321 |
+
# deep shrink
|
| 1322 |
+
if ds_ratio is not None:
|
| 1323 |
+
hd = int(h * ds_ratio)
|
| 1324 |
+
wd = int(w * ds_ratio)
|
| 1325 |
+
resize_add(hd, wd)
|
| 1326 |
+
|
| 1327 |
+
h = (h + 1) // 2
|
| 1328 |
+
w = (w + 1) // 2
|
| 1329 |
+
|
| 1330 |
+
self.mask_dic = mask_dic
|
| 1331 |
+
|
| 1332 |
+
def backup_weights(self):
|
| 1333 |
+
# 重みのバックアップを行う
|
| 1334 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1335 |
+
for lora in loras:
|
| 1336 |
+
org_module = lora.org_module_ref[0]
|
| 1337 |
+
if not hasattr(org_module, "_lora_org_weight"):
|
| 1338 |
+
sd = org_module.state_dict()
|
| 1339 |
+
org_module._lora_org_weight = sd["weight"].detach().clone()
|
| 1340 |
+
org_module._lora_restored = True
|
| 1341 |
+
|
| 1342 |
+
def restore_weights(self):
|
| 1343 |
+
# 重みのリストアを行う
|
| 1344 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1345 |
+
for lora in loras:
|
| 1346 |
+
org_module = lora.org_module_ref[0]
|
| 1347 |
+
if not org_module._lora_restored:
|
| 1348 |
+
sd = org_module.state_dict()
|
| 1349 |
+
sd["weight"] = org_module._lora_org_weight
|
| 1350 |
+
org_module.load_state_dict(sd)
|
| 1351 |
+
org_module._lora_restored = True
|
| 1352 |
+
|
| 1353 |
+
def pre_calculation(self):
|
| 1354 |
+
# 事前計算を行う
|
| 1355 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1356 |
+
for lora in loras:
|
| 1357 |
+
org_module = lora.org_module_ref[0]
|
| 1358 |
+
sd = org_module.state_dict()
|
| 1359 |
+
|
| 1360 |
+
org_weight = sd["weight"]
|
| 1361 |
+
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
| 1362 |
+
sd["weight"] = org_weight + lora_weight
|
| 1363 |
+
assert sd["weight"].shape == org_weight.shape
|
| 1364 |
+
org_module.load_state_dict(sd)
|
| 1365 |
+
|
| 1366 |
+
org_module._lora_restored = False
|
| 1367 |
+
lora.enabled = False
|
| 1368 |
+
|
| 1369 |
+
def apply_max_norm_regularization(self, max_norm_value, device):
|
| 1370 |
+
downkeys = []
|
| 1371 |
+
upkeys = []
|
| 1372 |
+
alphakeys = []
|
| 1373 |
+
norms = []
|
| 1374 |
+
keys_scaled = 0
|
| 1375 |
+
|
| 1376 |
+
state_dict = self.state_dict()
|
| 1377 |
+
for key in state_dict.keys():
|
| 1378 |
+
if "lora_down" in key and "weight" in key:
|
| 1379 |
+
downkeys.append(key)
|
| 1380 |
+
upkeys.append(key.replace("lora_down", "lora_up"))
|
| 1381 |
+
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
| 1382 |
+
|
| 1383 |
+
for i in range(len(downkeys)):
|
| 1384 |
+
down = state_dict[downkeys[i]].to(device)
|
| 1385 |
+
up = state_dict[upkeys[i]].to(device)
|
| 1386 |
+
alpha = state_dict[alphakeys[i]].to(device)
|
| 1387 |
+
dim = down.shape[0]
|
| 1388 |
+
scale = alpha / dim
|
| 1389 |
+
|
| 1390 |
+
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
| 1391 |
+
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 1392 |
+
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
| 1393 |
+
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
| 1394 |
+
else:
|
| 1395 |
+
updown = up @ down
|
| 1396 |
+
|
| 1397 |
+
updown *= scale
|
| 1398 |
+
|
| 1399 |
+
norm = updown.norm().clamp(min=max_norm_value / 2)
|
| 1400 |
+
desired = torch.clamp(norm, max=max_norm_value)
|
| 1401 |
+
ratio = desired.cpu() / norm.cpu()
|
| 1402 |
+
sqrt_ratio = ratio**0.5
|
| 1403 |
+
if ratio != 1:
|
| 1404 |
+
keys_scaled += 1
|
| 1405 |
+
state_dict[upkeys[i]] *= sqrt_ratio
|
| 1406 |
+
state_dict[downkeys[i]] *= sqrt_ratio
|
| 1407 |
+
scalednorm = updown.norm() * ratio
|
| 1408 |
+
norms.append(scalednorm.item())
|
| 1409 |
+
|
| 1410 |
+
return keys_scaled, sum(norms) / len(norms), max(norms)
|
lora_diffusers.py
ADDED
|
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Diffusersで動くLoRA。このファイル単独で完結する。
|
| 2 |
+
# LoRA module for Diffusers. This file works independently.
|
| 3 |
+
|
| 4 |
+
import bisect
|
| 5 |
+
import math
|
| 6 |
+
import random
|
| 7 |
+
from typing import Any, Dict, List, Mapping, Optional, Union
|
| 8 |
+
from diffusers import UNet2DConditionModel
|
| 9 |
+
import numpy as np
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from transformers import CLIPTextModel
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from library.device_utils import init_ipex, get_preferred_device
|
| 15 |
+
init_ipex()
|
| 16 |
+
|
| 17 |
+
from library.utils import setup_logging
|
| 18 |
+
setup_logging()
|
| 19 |
+
import logging
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
def make_unet_conversion_map() -> Dict[str, str]:
|
| 23 |
+
unet_conversion_map_layer = []
|
| 24 |
+
|
| 25 |
+
for i in range(3): # num_blocks is 3 in sdxl
|
| 26 |
+
# loop over downblocks/upblocks
|
| 27 |
+
for j in range(2):
|
| 28 |
+
# loop over resnets/attentions for downblocks
|
| 29 |
+
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
| 30 |
+
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
| 31 |
+
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
| 32 |
+
|
| 33 |
+
if i < 3:
|
| 34 |
+
# no attention layers in down_blocks.3
|
| 35 |
+
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
| 36 |
+
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
| 37 |
+
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
| 38 |
+
|
| 39 |
+
for j in range(3):
|
| 40 |
+
# loop over resnets/attentions for upblocks
|
| 41 |
+
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
| 42 |
+
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
| 43 |
+
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
| 44 |
+
|
| 45 |
+
# if i > 0: commentout for sdxl
|
| 46 |
+
# no attention layers in up_blocks.0
|
| 47 |
+
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
| 48 |
+
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
| 49 |
+
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
| 50 |
+
|
| 51 |
+
if i < 3:
|
| 52 |
+
# no downsample in down_blocks.3
|
| 53 |
+
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
| 54 |
+
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
| 55 |
+
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
| 56 |
+
|
| 57 |
+
# no upsample in up_blocks.3
|
| 58 |
+
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
| 59 |
+
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
|
| 60 |
+
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
| 61 |
+
|
| 62 |
+
hf_mid_atn_prefix = "mid_block.attentions.0."
|
| 63 |
+
sd_mid_atn_prefix = "middle_block.1."
|
| 64 |
+
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
| 65 |
+
|
| 66 |
+
for j in range(2):
|
| 67 |
+
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
| 68 |
+
sd_mid_res_prefix = f"middle_block.{2*j}."
|
| 69 |
+
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
| 70 |
+
|
| 71 |
+
unet_conversion_map_resnet = [
|
| 72 |
+
# (stable-diffusion, HF Diffusers)
|
| 73 |
+
("in_layers.0.", "norm1."),
|
| 74 |
+
("in_layers.2.", "conv1."),
|
| 75 |
+
("out_layers.0.", "norm2."),
|
| 76 |
+
("out_layers.3.", "conv2."),
|
| 77 |
+
("emb_layers.1.", "time_emb_proj."),
|
| 78 |
+
("skip_connection.", "conv_shortcut."),
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
unet_conversion_map = []
|
| 82 |
+
for sd, hf in unet_conversion_map_layer:
|
| 83 |
+
if "resnets" in hf:
|
| 84 |
+
for sd_res, hf_res in unet_conversion_map_resnet:
|
| 85 |
+
unet_conversion_map.append((sd + sd_res, hf + hf_res))
|
| 86 |
+
else:
|
| 87 |
+
unet_conversion_map.append((sd, hf))
|
| 88 |
+
|
| 89 |
+
for j in range(2):
|
| 90 |
+
hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
|
| 91 |
+
sd_time_embed_prefix = f"time_embed.{j*2}."
|
| 92 |
+
unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
|
| 93 |
+
|
| 94 |
+
for j in range(2):
|
| 95 |
+
hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
|
| 96 |
+
sd_label_embed_prefix = f"label_emb.0.{j*2}."
|
| 97 |
+
unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
|
| 98 |
+
|
| 99 |
+
unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
|
| 100 |
+
unet_conversion_map.append(("out.0.", "conv_norm_out."))
|
| 101 |
+
unet_conversion_map.append(("out.2.", "conv_out."))
|
| 102 |
+
|
| 103 |
+
sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
|
| 104 |
+
return sd_hf_conversion_map
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
UNET_CONVERSION_MAP = make_unet_conversion_map()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class LoRAModule(torch.nn.Module):
|
| 111 |
+
"""
|
| 112 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
lora_name,
|
| 118 |
+
org_module: torch.nn.Module,
|
| 119 |
+
multiplier=1.0,
|
| 120 |
+
lora_dim=4,
|
| 121 |
+
alpha=1,
|
| 122 |
+
):
|
| 123 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
| 124 |
+
super().__init__()
|
| 125 |
+
self.lora_name = lora_name
|
| 126 |
+
|
| 127 |
+
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
| 128 |
+
in_dim = org_module.in_channels
|
| 129 |
+
out_dim = org_module.out_channels
|
| 130 |
+
else:
|
| 131 |
+
in_dim = org_module.in_features
|
| 132 |
+
out_dim = org_module.out_features
|
| 133 |
+
|
| 134 |
+
self.lora_dim = lora_dim
|
| 135 |
+
|
| 136 |
+
if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
|
| 137 |
+
kernel_size = org_module.kernel_size
|
| 138 |
+
stride = org_module.stride
|
| 139 |
+
padding = org_module.padding
|
| 140 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
| 141 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
| 142 |
+
else:
|
| 143 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
| 144 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
| 145 |
+
|
| 146 |
+
if type(alpha) == torch.Tensor:
|
| 147 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
| 148 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
| 149 |
+
self.scale = alpha / self.lora_dim
|
| 150 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
|
| 151 |
+
|
| 152 |
+
# same as microsoft's
|
| 153 |
+
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
| 154 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
| 155 |
+
|
| 156 |
+
self.multiplier = multiplier
|
| 157 |
+
self.org_module = [org_module]
|
| 158 |
+
self.enabled = True
|
| 159 |
+
self.network: LoRANetwork = None
|
| 160 |
+
self.org_forward = None
|
| 161 |
+
|
| 162 |
+
# override org_module's forward method
|
| 163 |
+
def apply_to(self, multiplier=None):
|
| 164 |
+
if multiplier is not None:
|
| 165 |
+
self.multiplier = multiplier
|
| 166 |
+
if self.org_forward is None:
|
| 167 |
+
self.org_forward = self.org_module[0].forward
|
| 168 |
+
self.org_module[0].forward = self.forward
|
| 169 |
+
|
| 170 |
+
# restore org_module's forward method
|
| 171 |
+
def unapply_to(self):
|
| 172 |
+
if self.org_forward is not None:
|
| 173 |
+
self.org_module[0].forward = self.org_forward
|
| 174 |
+
|
| 175 |
+
# forward with lora
|
| 176 |
+
# scale is used LoRACompatibleConv, but we ignore it because we have multiplier
|
| 177 |
+
def forward(self, x, scale=1.0):
|
| 178 |
+
if not self.enabled:
|
| 179 |
+
return self.org_forward(x)
|
| 180 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 181 |
+
|
| 182 |
+
def set_network(self, network):
|
| 183 |
+
self.network = network
|
| 184 |
+
|
| 185 |
+
# merge lora weight to org weight
|
| 186 |
+
def merge_to(self, multiplier=1.0):
|
| 187 |
+
# get lora weight
|
| 188 |
+
lora_weight = self.get_weight(multiplier)
|
| 189 |
+
|
| 190 |
+
# get org weight
|
| 191 |
+
org_sd = self.org_module[0].state_dict()
|
| 192 |
+
org_weight = org_sd["weight"]
|
| 193 |
+
weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
| 194 |
+
|
| 195 |
+
# set weight to org_module
|
| 196 |
+
org_sd["weight"] = weight
|
| 197 |
+
self.org_module[0].load_state_dict(org_sd)
|
| 198 |
+
|
| 199 |
+
# restore org weight from lora weight
|
| 200 |
+
def restore_from(self, multiplier=1.0):
|
| 201 |
+
# get lora weight
|
| 202 |
+
lora_weight = self.get_weight(multiplier)
|
| 203 |
+
|
| 204 |
+
# get org weight
|
| 205 |
+
org_sd = self.org_module[0].state_dict()
|
| 206 |
+
org_weight = org_sd["weight"]
|
| 207 |
+
weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
|
| 208 |
+
|
| 209 |
+
# set weight to org_module
|
| 210 |
+
org_sd["weight"] = weight
|
| 211 |
+
self.org_module[0].load_state_dict(org_sd)
|
| 212 |
+
|
| 213 |
+
# return lora weight
|
| 214 |
+
def get_weight(self, multiplier=None):
|
| 215 |
+
if multiplier is None:
|
| 216 |
+
multiplier = self.multiplier
|
| 217 |
+
|
| 218 |
+
# get up/down weight from module
|
| 219 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
| 220 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
| 221 |
+
|
| 222 |
+
# pre-calculated weight
|
| 223 |
+
if len(down_weight.size()) == 2:
|
| 224 |
+
# linear
|
| 225 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
| 226 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 227 |
+
# conv2d 1x1
|
| 228 |
+
weight = (
|
| 229 |
+
self.multiplier
|
| 230 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 231 |
+
* self.scale
|
| 232 |
+
)
|
| 233 |
+
else:
|
| 234 |
+
# conv2d 3x3
|
| 235 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 236 |
+
weight = self.multiplier * conved * self.scale
|
| 237 |
+
|
| 238 |
+
return weight
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Create network from weights for inference, weights are not loaded here
|
| 242 |
+
def create_network_from_weights(
|
| 243 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
|
| 244 |
+
):
|
| 245 |
+
# get dim/alpha mapping
|
| 246 |
+
modules_dim = {}
|
| 247 |
+
modules_alpha = {}
|
| 248 |
+
for key, value in weights_sd.items():
|
| 249 |
+
if "." not in key:
|
| 250 |
+
continue
|
| 251 |
+
|
| 252 |
+
lora_name = key.split(".")[0]
|
| 253 |
+
if "alpha" in key:
|
| 254 |
+
modules_alpha[lora_name] = value
|
| 255 |
+
elif "lora_down" in key:
|
| 256 |
+
dim = value.size()[0]
|
| 257 |
+
modules_dim[lora_name] = dim
|
| 258 |
+
# logger.info(f"{lora_name} {value.size()} {dim}")
|
| 259 |
+
|
| 260 |
+
# support old LoRA without alpha
|
| 261 |
+
for key in modules_dim.keys():
|
| 262 |
+
if key not in modules_alpha:
|
| 263 |
+
modules_alpha[key] = modules_dim[key]
|
| 264 |
+
|
| 265 |
+
return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
|
| 269 |
+
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
|
| 270 |
+
unet = pipe.unet
|
| 271 |
+
|
| 272 |
+
lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
|
| 273 |
+
lora_network.load_state_dict(weights_sd)
|
| 274 |
+
lora_network.merge_to(multiplier=multiplier)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# block weightや学習に対応しない簡易版 / simple version without block weight and training
|
| 278 |
+
class LoRANetwork(torch.nn.Module):
|
| 279 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
| 280 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 281 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
| 282 |
+
LORA_PREFIX_UNET = "lora_unet"
|
| 283 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
| 284 |
+
|
| 285 |
+
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
|
| 286 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
| 287 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
| 288 |
+
|
| 289 |
+
def __init__(
|
| 290 |
+
self,
|
| 291 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
| 292 |
+
unet: UNet2DConditionModel,
|
| 293 |
+
multiplier: float = 1.0,
|
| 294 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
| 295 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
| 296 |
+
varbose: Optional[bool] = False,
|
| 297 |
+
) -> None:
|
| 298 |
+
super().__init__()
|
| 299 |
+
self.multiplier = multiplier
|
| 300 |
+
|
| 301 |
+
logger.info("create LoRA network from weights")
|
| 302 |
+
|
| 303 |
+
# convert SDXL Stability AI's U-Net modules to Diffusers
|
| 304 |
+
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
| 305 |
+
if converted:
|
| 306 |
+
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
| 307 |
+
|
| 308 |
+
# create module instances
|
| 309 |
+
def create_modules(
|
| 310 |
+
is_unet: bool,
|
| 311 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
| 312 |
+
root_module: torch.nn.Module,
|
| 313 |
+
target_replace_modules: List[torch.nn.Module],
|
| 314 |
+
) -> List[LoRAModule]:
|
| 315 |
+
prefix = (
|
| 316 |
+
self.LORA_PREFIX_UNET
|
| 317 |
+
if is_unet
|
| 318 |
+
else (
|
| 319 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
| 320 |
+
if text_encoder_idx is None
|
| 321 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
| 322 |
+
)
|
| 323 |
+
)
|
| 324 |
+
loras = []
|
| 325 |
+
skipped = []
|
| 326 |
+
for name, module in root_module.named_modules():
|
| 327 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 328 |
+
for child_name, child_module in module.named_modules():
|
| 329 |
+
is_linear = (
|
| 330 |
+
child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
|
| 331 |
+
)
|
| 332 |
+
is_conv2d = (
|
| 333 |
+
child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
if is_linear or is_conv2d:
|
| 337 |
+
lora_name = prefix + "." + name + "." + child_name
|
| 338 |
+
lora_name = lora_name.replace(".", "_")
|
| 339 |
+
|
| 340 |
+
if lora_name not in modules_dim:
|
| 341 |
+
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
|
| 342 |
+
skipped.append(lora_name)
|
| 343 |
+
continue
|
| 344 |
+
|
| 345 |
+
dim = modules_dim[lora_name]
|
| 346 |
+
alpha = modules_alpha[lora_name]
|
| 347 |
+
lora = LoRAModule(
|
| 348 |
+
lora_name,
|
| 349 |
+
child_module,
|
| 350 |
+
self.multiplier,
|
| 351 |
+
dim,
|
| 352 |
+
alpha,
|
| 353 |
+
)
|
| 354 |
+
loras.append(lora)
|
| 355 |
+
return loras, skipped
|
| 356 |
+
|
| 357 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
| 358 |
+
|
| 359 |
+
# create LoRA for text encoder
|
| 360 |
+
# 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
|
| 361 |
+
self.text_encoder_loras: List[LoRAModule] = []
|
| 362 |
+
skipped_te = []
|
| 363 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 364 |
+
if len(text_encoders) > 1:
|
| 365 |
+
index = i + 1
|
| 366 |
+
else:
|
| 367 |
+
index = None
|
| 368 |
+
|
| 369 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
| 370 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
| 371 |
+
skipped_te += skipped
|
| 372 |
+
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
| 373 |
+
if len(skipped_te) > 0:
|
| 374 |
+
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
| 375 |
+
|
| 376 |
+
# extend U-Net target modules to include Conv2d 3x3
|
| 377 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 378 |
+
|
| 379 |
+
self.unet_loras: List[LoRAModule]
|
| 380 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
| 381 |
+
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
| 382 |
+
if len(skipped_un) > 0:
|
| 383 |
+
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
| 384 |
+
|
| 385 |
+
# assertion
|
| 386 |
+
names = set()
|
| 387 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 388 |
+
names.add(lora.lora_name)
|
| 389 |
+
for lora_name in modules_dim.keys():
|
| 390 |
+
assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
|
| 391 |
+
|
| 392 |
+
# make to work load_state_dict
|
| 393 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 394 |
+
self.add_module(lora.lora_name, lora)
|
| 395 |
+
|
| 396 |
+
# SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
|
| 397 |
+
def convert_unet_modules(self, modules_dim, modules_alpha):
|
| 398 |
+
converted_count = 0
|
| 399 |
+
not_converted_count = 0
|
| 400 |
+
|
| 401 |
+
map_keys = list(UNET_CONVERSION_MAP.keys())
|
| 402 |
+
map_keys.sort()
|
| 403 |
+
|
| 404 |
+
for key in list(modules_dim.keys()):
|
| 405 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
| 406 |
+
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
|
| 407 |
+
position = bisect.bisect_right(map_keys, search_key)
|
| 408 |
+
map_key = map_keys[position - 1]
|
| 409 |
+
if search_key.startswith(map_key):
|
| 410 |
+
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
|
| 411 |
+
modules_dim[new_key] = modules_dim[key]
|
| 412 |
+
modules_alpha[new_key] = modules_alpha[key]
|
| 413 |
+
del modules_dim[key]
|
| 414 |
+
del modules_alpha[key]
|
| 415 |
+
converted_count += 1
|
| 416 |
+
else:
|
| 417 |
+
not_converted_count += 1
|
| 418 |
+
assert (
|
| 419 |
+
converted_count == 0 or not_converted_count == 0
|
| 420 |
+
), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
|
| 421 |
+
return converted_count
|
| 422 |
+
|
| 423 |
+
def set_multiplier(self, multiplier):
|
| 424 |
+
self.multiplier = multiplier
|
| 425 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 426 |
+
lora.multiplier = self.multiplier
|
| 427 |
+
|
| 428 |
+
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
| 429 |
+
if apply_text_encoder:
|
| 430 |
+
logger.info("enable LoRA for text encoder")
|
| 431 |
+
for lora in self.text_encoder_loras:
|
| 432 |
+
lora.apply_to(multiplier)
|
| 433 |
+
if apply_unet:
|
| 434 |
+
logger.info("enable LoRA for U-Net")
|
| 435 |
+
for lora in self.unet_loras:
|
| 436 |
+
lora.apply_to(multiplier)
|
| 437 |
+
|
| 438 |
+
def unapply_to(self):
|
| 439 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 440 |
+
lora.unapply_to()
|
| 441 |
+
|
| 442 |
+
def merge_to(self, multiplier=1.0):
|
| 443 |
+
logger.info("merge LoRA weights to original weights")
|
| 444 |
+
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
| 445 |
+
lora.merge_to(multiplier)
|
| 446 |
+
logger.info(f"weights are merged")
|
| 447 |
+
|
| 448 |
+
def restore_from(self, multiplier=1.0):
|
| 449 |
+
logger.info("restore LoRA weights from original weights")
|
| 450 |
+
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
| 451 |
+
lora.restore_from(multiplier)
|
| 452 |
+
logger.info(f"weights are restored")
|
| 453 |
+
|
| 454 |
+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
| 455 |
+
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
| 456 |
+
map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
|
| 457 |
+
map_keys.sort()
|
| 458 |
+
for key in list(state_dict.keys()):
|
| 459 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
|
| 460 |
+
search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
|
| 461 |
+
position = bisect.bisect_right(map_keys, search_key)
|
| 462 |
+
map_key = map_keys[position - 1]
|
| 463 |
+
if search_key.startswith(map_key):
|
| 464 |
+
new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
|
| 465 |
+
state_dict[new_key] = state_dict[key]
|
| 466 |
+
del state_dict[key]
|
| 467 |
+
|
| 468 |
+
# in case of V2, some weights have different shape, so we need to convert them
|
| 469 |
+
# because V2 LoRA is based on U-Net created by use_linear_projection=False
|
| 470 |
+
my_state_dict = self.state_dict()
|
| 471 |
+
for key in state_dict.keys():
|
| 472 |
+
if state_dict[key].size() != my_state_dict[key].size():
|
| 473 |
+
# logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
| 474 |
+
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
|
| 475 |
+
|
| 476 |
+
return super().load_state_dict(state_dict, strict)
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
if __name__ == "__main__":
|
| 480 |
+
# sample code to use LoRANetwork
|
| 481 |
+
import os
|
| 482 |
+
import argparse
|
| 483 |
+
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
|
| 484 |
+
import torch
|
| 485 |
+
|
| 486 |
+
device = get_preferred_device()
|
| 487 |
+
|
| 488 |
+
parser = argparse.ArgumentParser()
|
| 489 |
+
parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
|
| 490 |
+
parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights")
|
| 491 |
+
parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
|
| 492 |
+
parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
|
| 493 |
+
parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
|
| 494 |
+
parser.add_argument("--seed", type=int, default=0, help="random seed")
|
| 495 |
+
args = parser.parse_args()
|
| 496 |
+
|
| 497 |
+
image_prefix = args.model_id.replace("/", "_") + "_"
|
| 498 |
+
|
| 499 |
+
# load Diffusers model
|
| 500 |
+
logger.info(f"load model from {args.model_id}")
|
| 501 |
+
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
|
| 502 |
+
if args.sdxl:
|
| 503 |
+
# use_safetensors=True does not work with 0.18.2
|
| 504 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
|
| 505 |
+
else:
|
| 506 |
+
pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
|
| 507 |
+
pipe.to(device)
|
| 508 |
+
pipe.set_use_memory_efficient_attention_xformers(True)
|
| 509 |
+
|
| 510 |
+
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
|
| 511 |
+
|
| 512 |
+
# load LoRA weights
|
| 513 |
+
logger.info(f"load LoRA weights from {args.lora_weights}")
|
| 514 |
+
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
|
| 515 |
+
from safetensors.torch import load_file
|
| 516 |
+
|
| 517 |
+
lora_sd = load_file(args.lora_weights)
|
| 518 |
+
else:
|
| 519 |
+
lora_sd = torch.load(args.lora_weights)
|
| 520 |
+
|
| 521 |
+
# create by LoRA weights and load weights
|
| 522 |
+
logger.info(f"create LoRA network")
|
| 523 |
+
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
|
| 524 |
+
|
| 525 |
+
logger.info(f"load LoRA network weights")
|
| 526 |
+
lora_network.load_state_dict(lora_sd)
|
| 527 |
+
|
| 528 |
+
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
|
| 529 |
+
|
| 530 |
+
# 必要があれば、元のモデルの重みをバックアップしておく
|
| 531 |
+
# back-up unet/text encoder weights if necessary
|
| 532 |
+
def detach_and_move_to_cpu(state_dict):
|
| 533 |
+
for k, v in state_dict.items():
|
| 534 |
+
state_dict[k] = v.detach().cpu()
|
| 535 |
+
return state_dict
|
| 536 |
+
|
| 537 |
+
org_unet_sd = pipe.unet.state_dict()
|
| 538 |
+
detach_and_move_to_cpu(org_unet_sd)
|
| 539 |
+
|
| 540 |
+
org_text_encoder_sd = pipe.text_encoder.state_dict()
|
| 541 |
+
detach_and_move_to_cpu(org_text_encoder_sd)
|
| 542 |
+
|
| 543 |
+
if args.sdxl:
|
| 544 |
+
org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
|
| 545 |
+
detach_and_move_to_cpu(org_text_encoder_2_sd)
|
| 546 |
+
|
| 547 |
+
def seed_everything(seed):
|
| 548 |
+
torch.manual_seed(seed)
|
| 549 |
+
torch.cuda.manual_seed_all(seed)
|
| 550 |
+
np.random.seed(seed)
|
| 551 |
+
random.seed(seed)
|
| 552 |
+
|
| 553 |
+
# create image with original weights
|
| 554 |
+
logger.info(f"create image with original weights")
|
| 555 |
+
seed_everything(args.seed)
|
| 556 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
| 557 |
+
image.save(image_prefix + "original.png")
|
| 558 |
+
|
| 559 |
+
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
|
| 560 |
+
logger.info(f"apply LoRA network to the model")
|
| 561 |
+
lora_network.apply_to(multiplier=1.0)
|
| 562 |
+
|
| 563 |
+
logger.info(f"create image with applied LoRA")
|
| 564 |
+
seed_everything(args.seed)
|
| 565 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
| 566 |
+
image.save(image_prefix + "applied_lora.png")
|
| 567 |
+
|
| 568 |
+
# unapply LoRA network to the model
|
| 569 |
+
logger.info(f"unapply LoRA network to the model")
|
| 570 |
+
lora_network.unapply_to()
|
| 571 |
+
|
| 572 |
+
logger.info(f"create image with unapplied LoRA")
|
| 573 |
+
seed_everything(args.seed)
|
| 574 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
| 575 |
+
image.save(image_prefix + "unapplied_lora.png")
|
| 576 |
+
|
| 577 |
+
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
|
| 578 |
+
logger.info(f"merge LoRA network to the model")
|
| 579 |
+
lora_network.merge_to(multiplier=1.0)
|
| 580 |
+
|
| 581 |
+
logger.info(f"create image with LoRA")
|
| 582 |
+
seed_everything(args.seed)
|
| 583 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
| 584 |
+
image.save(image_prefix + "merged_lora.png")
|
| 585 |
+
|
| 586 |
+
# restore (unmerge) LoRA weights: numerically unstable
|
| 587 |
+
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
|
| 588 |
+
# 保存したstate_dictから元の重みを復元するのが確実
|
| 589 |
+
logger.info(f"restore (unmerge) LoRA weights")
|
| 590 |
+
lora_network.restore_from(multiplier=1.0)
|
| 591 |
+
|
| 592 |
+
logger.info(f"create image without LoRA")
|
| 593 |
+
seed_everything(args.seed)
|
| 594 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
| 595 |
+
image.save(image_prefix + "unmerged_lora.png")
|
| 596 |
+
|
| 597 |
+
# restore original weights
|
| 598 |
+
logger.info(f"restore original weights")
|
| 599 |
+
pipe.unet.load_state_dict(org_unet_sd)
|
| 600 |
+
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
|
| 601 |
+
if args.sdxl:
|
| 602 |
+
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
|
| 603 |
+
|
| 604 |
+
logger.info(f"create image with restored original weights")
|
| 605 |
+
seed_everything(args.seed)
|
| 606 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
| 607 |
+
image.save(image_prefix + "restore_original.png")
|
| 608 |
+
|
| 609 |
+
# use convenience function to merge LoRA weights
|
| 610 |
+
logger.info(f"merge LoRA weights with convenience function")
|
| 611 |
+
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
| 612 |
+
|
| 613 |
+
logger.info(f"create image with merged LoRA weights")
|
| 614 |
+
seed_everything(args.seed)
|
| 615 |
+
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
| 616 |
+
image.save(image_prefix + "convenience_merged_lora.png")
|
lora_fa.py
ADDED
|
@@ -0,0 +1,1244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LoRA network module
|
| 2 |
+
# reference:
|
| 3 |
+
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
|
| 4 |
+
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
| 5 |
+
|
| 6 |
+
# temporary implementation of LoRA-FA: https://arxiv.org/abs/2308.03303
|
| 7 |
+
# need to be refactored and merged to lora.py
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
from typing import Dict, List, Optional, Tuple, Type, Union
|
| 12 |
+
from diffusers import AutoencoderKL
|
| 13 |
+
from transformers import CLIPTextModel
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
import re
|
| 17 |
+
from library.utils import setup_logging
|
| 18 |
+
setup_logging()
|
| 19 |
+
import logging
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class LoRAModule(torch.nn.Module):
|
| 26 |
+
"""
|
| 27 |
+
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
lora_name,
|
| 33 |
+
org_module: torch.nn.Module,
|
| 34 |
+
multiplier=1.0,
|
| 35 |
+
lora_dim=4,
|
| 36 |
+
alpha=1,
|
| 37 |
+
dropout=None,
|
| 38 |
+
rank_dropout=None,
|
| 39 |
+
module_dropout=None,
|
| 40 |
+
):
|
| 41 |
+
"""if alpha == 0 or None, alpha is rank (no scaling)."""
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.lora_name = lora_name
|
| 44 |
+
|
| 45 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 46 |
+
in_dim = org_module.in_channels
|
| 47 |
+
out_dim = org_module.out_channels
|
| 48 |
+
else:
|
| 49 |
+
in_dim = org_module.in_features
|
| 50 |
+
out_dim = org_module.out_features
|
| 51 |
+
|
| 52 |
+
# if limit_rank:
|
| 53 |
+
# self.lora_dim = min(lora_dim, in_dim, out_dim)
|
| 54 |
+
# if self.lora_dim != lora_dim:
|
| 55 |
+
# logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
|
| 56 |
+
# else:
|
| 57 |
+
self.lora_dim = lora_dim
|
| 58 |
+
|
| 59 |
+
if org_module.__class__.__name__ == "Conv2d":
|
| 60 |
+
kernel_size = org_module.kernel_size
|
| 61 |
+
stride = org_module.stride
|
| 62 |
+
padding = org_module.padding
|
| 63 |
+
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
| 64 |
+
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
| 65 |
+
else:
|
| 66 |
+
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
| 67 |
+
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
| 68 |
+
|
| 69 |
+
if type(alpha) == torch.Tensor:
|
| 70 |
+
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
| 71 |
+
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
| 72 |
+
self.scale = alpha / self.lora_dim
|
| 73 |
+
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
| 74 |
+
|
| 75 |
+
# # same as microsoft's
|
| 76 |
+
# torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
| 77 |
+
|
| 78 |
+
# according to the paper, initialize LoRA-A (down) as normal distribution
|
| 79 |
+
torch.nn.init.normal_(self.lora_down.weight, std=math.sqrt(2.0 / (in_dim + self.lora_dim)))
|
| 80 |
+
|
| 81 |
+
torch.nn.init.zeros_(self.lora_up.weight)
|
| 82 |
+
|
| 83 |
+
self.multiplier = multiplier
|
| 84 |
+
self.org_module = org_module # remove in applying
|
| 85 |
+
self.dropout = dropout
|
| 86 |
+
self.rank_dropout = rank_dropout
|
| 87 |
+
self.module_dropout = module_dropout
|
| 88 |
+
|
| 89 |
+
def get_trainable_params(self):
|
| 90 |
+
params = self.named_parameters()
|
| 91 |
+
trainable_params = []
|
| 92 |
+
for param in params:
|
| 93 |
+
if param[0] == "lora_up.weight": # up only
|
| 94 |
+
trainable_params.append(param[1])
|
| 95 |
+
return trainable_params
|
| 96 |
+
|
| 97 |
+
def requires_grad_(self, requires_grad: bool = True):
|
| 98 |
+
self.lora_up.requires_grad_(requires_grad)
|
| 99 |
+
self.lora_down.requires_grad_(False)
|
| 100 |
+
return self
|
| 101 |
+
|
| 102 |
+
def apply_to(self):
|
| 103 |
+
self.org_forward = self.org_module.forward
|
| 104 |
+
self.org_module.forward = self.forward
|
| 105 |
+
del self.org_module
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
org_forwarded = self.org_forward(x)
|
| 109 |
+
|
| 110 |
+
# module dropout
|
| 111 |
+
if self.module_dropout is not None and self.training:
|
| 112 |
+
if torch.rand(1) < self.module_dropout:
|
| 113 |
+
return org_forwarded
|
| 114 |
+
|
| 115 |
+
lx = self.lora_down(x)
|
| 116 |
+
|
| 117 |
+
# normal dropout
|
| 118 |
+
if self.dropout is not None and self.training:
|
| 119 |
+
lx = torch.nn.functional.dropout(lx, p=self.dropout)
|
| 120 |
+
|
| 121 |
+
# rank dropout
|
| 122 |
+
if self.rank_dropout is not None and self.training:
|
| 123 |
+
mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
|
| 124 |
+
if len(lx.size()) == 3:
|
| 125 |
+
mask = mask.unsqueeze(1) # for Text Encoder
|
| 126 |
+
elif len(lx.size()) == 4:
|
| 127 |
+
mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
|
| 128 |
+
lx = lx * mask
|
| 129 |
+
|
| 130 |
+
# scaling for rank dropout: treat as if the rank is changed
|
| 131 |
+
# maskから計算することも考えられるが、augmentation的な効果を期待してrank_dropoutを用いる
|
| 132 |
+
scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
|
| 133 |
+
else:
|
| 134 |
+
scale = self.scale
|
| 135 |
+
|
| 136 |
+
lx = self.lora_up(lx)
|
| 137 |
+
|
| 138 |
+
return org_forwarded + lx * self.multiplier * scale
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class LoRAInfModule(LoRAModule):
|
| 142 |
+
def __init__(
|
| 143 |
+
self,
|
| 144 |
+
lora_name,
|
| 145 |
+
org_module: torch.nn.Module,
|
| 146 |
+
multiplier=1.0,
|
| 147 |
+
lora_dim=4,
|
| 148 |
+
alpha=1,
|
| 149 |
+
**kwargs,
|
| 150 |
+
):
|
| 151 |
+
# no dropout for inference
|
| 152 |
+
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
| 153 |
+
|
| 154 |
+
self.org_module_ref = [org_module] # 後から参照できるように
|
| 155 |
+
self.enabled = True
|
| 156 |
+
|
| 157 |
+
# check regional or not by lora_name
|
| 158 |
+
self.text_encoder = False
|
| 159 |
+
if lora_name.startswith("lora_te_"):
|
| 160 |
+
self.regional = False
|
| 161 |
+
self.use_sub_prompt = True
|
| 162 |
+
self.text_encoder = True
|
| 163 |
+
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
| 164 |
+
self.regional = False
|
| 165 |
+
self.use_sub_prompt = True
|
| 166 |
+
elif "time_emb" in lora_name:
|
| 167 |
+
self.regional = False
|
| 168 |
+
self.use_sub_prompt = False
|
| 169 |
+
else:
|
| 170 |
+
self.regional = True
|
| 171 |
+
self.use_sub_prompt = False
|
| 172 |
+
|
| 173 |
+
self.network: LoRANetwork = None
|
| 174 |
+
|
| 175 |
+
def set_network(self, network):
|
| 176 |
+
self.network = network
|
| 177 |
+
|
| 178 |
+
# freezeしてマージする
|
| 179 |
+
def merge_to(self, sd, dtype, device):
|
| 180 |
+
# get up/down weight
|
| 181 |
+
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
| 182 |
+
down_weight = sd["lora_down.weight"].to(torch.float).to(device)
|
| 183 |
+
|
| 184 |
+
# extract weight from org_module
|
| 185 |
+
org_sd = self.org_module.state_dict()
|
| 186 |
+
weight = org_sd["weight"].to(torch.float)
|
| 187 |
+
|
| 188 |
+
# merge weight
|
| 189 |
+
if len(weight.size()) == 2:
|
| 190 |
+
# linear
|
| 191 |
+
weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
|
| 192 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 193 |
+
# conv2d 1x1
|
| 194 |
+
weight = (
|
| 195 |
+
weight
|
| 196 |
+
+ self.multiplier
|
| 197 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 198 |
+
* self.scale
|
| 199 |
+
)
|
| 200 |
+
else:
|
| 201 |
+
# conv2d 3x3
|
| 202 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 203 |
+
# logger.info(conved.size(), weight.size(), module.stride, module.padding)
|
| 204 |
+
weight = weight + self.multiplier * conved * self.scale
|
| 205 |
+
|
| 206 |
+
# set weight to org_module
|
| 207 |
+
org_sd["weight"] = weight.to(dtype)
|
| 208 |
+
self.org_module.load_state_dict(org_sd)
|
| 209 |
+
|
| 210 |
+
# 復元できるマージのため、このモジュールのweightを返す
|
| 211 |
+
def get_weight(self, multiplier=None):
|
| 212 |
+
if multiplier is None:
|
| 213 |
+
multiplier = self.multiplier
|
| 214 |
+
|
| 215 |
+
# get up/down weight from module
|
| 216 |
+
up_weight = self.lora_up.weight.to(torch.float)
|
| 217 |
+
down_weight = self.lora_down.weight.to(torch.float)
|
| 218 |
+
|
| 219 |
+
# pre-calculated weight
|
| 220 |
+
if len(down_weight.size()) == 2:
|
| 221 |
+
# linear
|
| 222 |
+
weight = self.multiplier * (up_weight @ down_weight) * self.scale
|
| 223 |
+
elif down_weight.size()[2:4] == (1, 1):
|
| 224 |
+
# conv2d 1x1
|
| 225 |
+
weight = (
|
| 226 |
+
self.multiplier
|
| 227 |
+
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 228 |
+
* self.scale
|
| 229 |
+
)
|
| 230 |
+
else:
|
| 231 |
+
# conv2d 3x3
|
| 232 |
+
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
| 233 |
+
weight = self.multiplier * conved * self.scale
|
| 234 |
+
|
| 235 |
+
return weight
|
| 236 |
+
|
| 237 |
+
def set_region(self, region):
|
| 238 |
+
self.region = region
|
| 239 |
+
self.region_mask = None
|
| 240 |
+
|
| 241 |
+
def default_forward(self, x):
|
| 242 |
+
# logger.info("default_forward", self.lora_name, x.size())
|
| 243 |
+
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 244 |
+
|
| 245 |
+
def forward(self, x):
|
| 246 |
+
if not self.enabled:
|
| 247 |
+
return self.org_forward(x)
|
| 248 |
+
|
| 249 |
+
if self.network is None or self.network.sub_prompt_index is None:
|
| 250 |
+
return self.default_forward(x)
|
| 251 |
+
if not self.regional and not self.use_sub_prompt:
|
| 252 |
+
return self.default_forward(x)
|
| 253 |
+
|
| 254 |
+
if self.regional:
|
| 255 |
+
return self.regional_forward(x)
|
| 256 |
+
else:
|
| 257 |
+
return self.sub_prompt_forward(x)
|
| 258 |
+
|
| 259 |
+
def get_mask_for_x(self, x):
|
| 260 |
+
# calculate size from shape of x
|
| 261 |
+
if len(x.size()) == 4:
|
| 262 |
+
h, w = x.size()[2:4]
|
| 263 |
+
area = h * w
|
| 264 |
+
else:
|
| 265 |
+
area = x.size()[1]
|
| 266 |
+
|
| 267 |
+
mask = self.network.mask_dic[area]
|
| 268 |
+
if mask is None:
|
| 269 |
+
raise ValueError(f"mask is None for resolution {area}")
|
| 270 |
+
if len(x.size()) != 4:
|
| 271 |
+
mask = torch.reshape(mask, (1, -1, 1))
|
| 272 |
+
return mask
|
| 273 |
+
|
| 274 |
+
def regional_forward(self, x):
|
| 275 |
+
if "attn2_to_out" in self.lora_name:
|
| 276 |
+
return self.to_out_forward(x)
|
| 277 |
+
|
| 278 |
+
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
| 279 |
+
return self.default_forward(x)
|
| 280 |
+
|
| 281 |
+
# apply mask for LoRA result
|
| 282 |
+
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
| 283 |
+
mask = self.get_mask_for_x(lx)
|
| 284 |
+
# logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
| 285 |
+
lx = lx * mask
|
| 286 |
+
|
| 287 |
+
x = self.org_forward(x)
|
| 288 |
+
x = x + lx
|
| 289 |
+
|
| 290 |
+
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
| 291 |
+
x = self.postp_to_q(x)
|
| 292 |
+
|
| 293 |
+
return x
|
| 294 |
+
|
| 295 |
+
def postp_to_q(self, x):
|
| 296 |
+
# repeat x to num_sub_prompts
|
| 297 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
| 298 |
+
qc = self.network.batch_size # uncond
|
| 299 |
+
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
| 300 |
+
if has_real_uncond:
|
| 301 |
+
qc += self.network.batch_size # real_uncond
|
| 302 |
+
|
| 303 |
+
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
| 304 |
+
query[: self.network.batch_size] = x[: self.network.batch_size]
|
| 305 |
+
|
| 306 |
+
for i in range(self.network.batch_size):
|
| 307 |
+
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
| 308 |
+
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
| 309 |
+
|
| 310 |
+
if has_real_uncond:
|
| 311 |
+
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
| 312 |
+
|
| 313 |
+
# logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
| 314 |
+
return query
|
| 315 |
+
|
| 316 |
+
def sub_prompt_forward(self, x):
|
| 317 |
+
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
| 318 |
+
return self.org_forward(x)
|
| 319 |
+
|
| 320 |
+
emb_idx = self.network.sub_prompt_index
|
| 321 |
+
if not self.text_encoder:
|
| 322 |
+
emb_idx += self.network.batch_size
|
| 323 |
+
|
| 324 |
+
# apply sub prompt of X
|
| 325 |
+
lx = x[emb_idx :: self.network.num_sub_prompts]
|
| 326 |
+
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
| 327 |
+
|
| 328 |
+
# logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
| 329 |
+
|
| 330 |
+
x = self.org_forward(x)
|
| 331 |
+
x[emb_idx :: self.network.num_sub_prompts] += lx
|
| 332 |
+
|
| 333 |
+
return x
|
| 334 |
+
|
| 335 |
+
def to_out_forward(self, x):
|
| 336 |
+
# logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
| 337 |
+
|
| 338 |
+
if self.network.is_last_network:
|
| 339 |
+
masks = [None] * self.network.num_sub_prompts
|
| 340 |
+
self.network.shared[self.lora_name] = (None, masks)
|
| 341 |
+
else:
|
| 342 |
+
lx, masks = self.network.shared[self.lora_name]
|
| 343 |
+
|
| 344 |
+
# call own LoRA
|
| 345 |
+
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
| 346 |
+
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
| 347 |
+
|
| 348 |
+
if self.network.is_last_network:
|
| 349 |
+
lx = torch.zeros(
|
| 350 |
+
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
| 351 |
+
)
|
| 352 |
+
self.network.shared[self.lora_name] = (lx, masks)
|
| 353 |
+
|
| 354 |
+
# logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
| 355 |
+
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
| 356 |
+
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
| 357 |
+
|
| 358 |
+
# if not last network, return x and masks
|
| 359 |
+
x = self.org_forward(x)
|
| 360 |
+
if not self.network.is_last_network:
|
| 361 |
+
return x
|
| 362 |
+
|
| 363 |
+
lx, masks = self.network.shared.pop(self.lora_name)
|
| 364 |
+
|
| 365 |
+
# if last network, combine separated x with mask weighted sum
|
| 366 |
+
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
| 367 |
+
|
| 368 |
+
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
| 369 |
+
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
| 370 |
+
if has_real_uncond:
|
| 371 |
+
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
| 372 |
+
|
| 373 |
+
# logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
| 374 |
+
# for i in range(len(masks)):
|
| 375 |
+
# if masks[i] is None:
|
| 376 |
+
# masks[i] = torch.zeros_like(masks[-1])
|
| 377 |
+
|
| 378 |
+
mask = torch.cat(masks)
|
| 379 |
+
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
| 380 |
+
for i in range(self.network.batch_size):
|
| 381 |
+
# 1枚の画像ごとに処理する
|
| 382 |
+
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
| 383 |
+
lx1 = lx1 * mask
|
| 384 |
+
lx1 = torch.sum(lx1, dim=0)
|
| 385 |
+
|
| 386 |
+
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
| 387 |
+
x1 = x[xi : xi + self.network.num_sub_prompts]
|
| 388 |
+
x1 = x1 * mask
|
| 389 |
+
x1 = torch.sum(x1, dim=0)
|
| 390 |
+
x1 = x1 / mask_sum
|
| 391 |
+
|
| 392 |
+
x1 = x1 + lx1
|
| 393 |
+
out[self.network.batch_size + i] = x1
|
| 394 |
+
|
| 395 |
+
# logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
|
| 396 |
+
return out
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def parse_block_lr_kwargs(nw_kwargs):
|
| 400 |
+
down_lr_weight = nw_kwargs.get("down_lr_weight", None)
|
| 401 |
+
mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
|
| 402 |
+
up_lr_weight = nw_kwargs.get("up_lr_weight", None)
|
| 403 |
+
|
| 404 |
+
# 以上のいずれにも設定がない場合は無効としてNoneを返す
|
| 405 |
+
if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
|
| 406 |
+
return None, None, None
|
| 407 |
+
|
| 408 |
+
# extract learning rate weight for each block
|
| 409 |
+
if down_lr_weight is not None:
|
| 410 |
+
# if some parameters are not set, use zero
|
| 411 |
+
if "," in down_lr_weight:
|
| 412 |
+
down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
|
| 413 |
+
|
| 414 |
+
if mid_lr_weight is not None:
|
| 415 |
+
mid_lr_weight = float(mid_lr_weight)
|
| 416 |
+
|
| 417 |
+
if up_lr_weight is not None:
|
| 418 |
+
if "," in up_lr_weight:
|
| 419 |
+
up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
|
| 420 |
+
|
| 421 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
|
| 422 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def create_network(
|
| 429 |
+
multiplier: float,
|
| 430 |
+
network_dim: Optional[int],
|
| 431 |
+
network_alpha: Optional[float],
|
| 432 |
+
vae: AutoencoderKL,
|
| 433 |
+
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
| 434 |
+
unet,
|
| 435 |
+
neuron_dropout: Optional[float] = None,
|
| 436 |
+
**kwargs,
|
| 437 |
+
):
|
| 438 |
+
if network_dim is None:
|
| 439 |
+
network_dim = 4 # default
|
| 440 |
+
if network_alpha is None:
|
| 441 |
+
network_alpha = 1.0
|
| 442 |
+
|
| 443 |
+
# extract dim/alpha for conv2d, and block dim
|
| 444 |
+
conv_dim = kwargs.get("conv_dim", None)
|
| 445 |
+
conv_alpha = kwargs.get("conv_alpha", None)
|
| 446 |
+
if conv_dim is not None:
|
| 447 |
+
conv_dim = int(conv_dim)
|
| 448 |
+
if conv_alpha is None:
|
| 449 |
+
conv_alpha = 1.0
|
| 450 |
+
else:
|
| 451 |
+
conv_alpha = float(conv_alpha)
|
| 452 |
+
|
| 453 |
+
# block dim/alpha/lr
|
| 454 |
+
block_dims = kwargs.get("block_dims", None)
|
| 455 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
| 456 |
+
|
| 457 |
+
# 以上のいずれかに指定があればblockごとのdim(rank)を有効にする
|
| 458 |
+
if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
|
| 459 |
+
block_alphas = kwargs.get("block_alphas", None)
|
| 460 |
+
conv_block_dims = kwargs.get("conv_block_dims", None)
|
| 461 |
+
conv_block_alphas = kwargs.get("conv_block_alphas", None)
|
| 462 |
+
|
| 463 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
|
| 464 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
# remove block dim/alpha without learning rate
|
| 468 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
|
| 469 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
else:
|
| 473 |
+
block_alphas = None
|
| 474 |
+
conv_block_dims = None
|
| 475 |
+
conv_block_alphas = None
|
| 476 |
+
|
| 477 |
+
# rank/module dropout
|
| 478 |
+
rank_dropout = kwargs.get("rank_dropout", None)
|
| 479 |
+
if rank_dropout is not None:
|
| 480 |
+
rank_dropout = float(rank_dropout)
|
| 481 |
+
module_dropout = kwargs.get("module_dropout", None)
|
| 482 |
+
if module_dropout is not None:
|
| 483 |
+
module_dropout = float(module_dropout)
|
| 484 |
+
|
| 485 |
+
# すごく引数が多いな ( ^ω^)・・・
|
| 486 |
+
network = LoRANetwork(
|
| 487 |
+
text_encoder,
|
| 488 |
+
unet,
|
| 489 |
+
multiplier=multiplier,
|
| 490 |
+
lora_dim=network_dim,
|
| 491 |
+
alpha=network_alpha,
|
| 492 |
+
dropout=neuron_dropout,
|
| 493 |
+
rank_dropout=rank_dropout,
|
| 494 |
+
module_dropout=module_dropout,
|
| 495 |
+
conv_lora_dim=conv_dim,
|
| 496 |
+
conv_alpha=conv_alpha,
|
| 497 |
+
block_dims=block_dims,
|
| 498 |
+
block_alphas=block_alphas,
|
| 499 |
+
conv_block_dims=conv_block_dims,
|
| 500 |
+
conv_block_alphas=conv_block_alphas,
|
| 501 |
+
varbose=True,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
| 505 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
| 506 |
+
|
| 507 |
+
return network
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
# このメソッドは外部から呼び出される可能性を考慮しておく
|
| 511 |
+
# network_dim, network_alpha にはデフォルト値が入っている。
|
| 512 |
+
# block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている
|
| 513 |
+
# conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている
|
| 514 |
+
def get_block_dims_and_alphas(
|
| 515 |
+
block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
|
| 516 |
+
):
|
| 517 |
+
num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
|
| 518 |
+
|
| 519 |
+
def parse_ints(s):
|
| 520 |
+
return [int(i) for i in s.split(",")]
|
| 521 |
+
|
| 522 |
+
def parse_floats(s):
|
| 523 |
+
return [float(i) for i in s.split(",")]
|
| 524 |
+
|
| 525 |
+
# block_dimsとblock_alphasをパースする。必ず値が入る
|
| 526 |
+
if block_dims is not None:
|
| 527 |
+
block_dims = parse_ints(block_dims)
|
| 528 |
+
assert (
|
| 529 |
+
len(block_dims) == num_total_blocks
|
| 530 |
+
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
| 531 |
+
else:
|
| 532 |
+
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
| 533 |
+
block_dims = [network_dim] * num_total_blocks
|
| 534 |
+
|
| 535 |
+
if block_alphas is not None:
|
| 536 |
+
block_alphas = parse_floats(block_alphas)
|
| 537 |
+
assert (
|
| 538 |
+
len(block_alphas) == num_total_blocks
|
| 539 |
+
), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
|
| 540 |
+
else:
|
| 541 |
+
logger.warning(
|
| 542 |
+
f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
|
| 543 |
+
)
|
| 544 |
+
block_alphas = [network_alpha] * num_total_blocks
|
| 545 |
+
|
| 546 |
+
# conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う
|
| 547 |
+
if conv_block_dims is not None:
|
| 548 |
+
conv_block_dims = parse_ints(conv_block_dims)
|
| 549 |
+
assert (
|
| 550 |
+
len(conv_block_dims) == num_total_blocks
|
| 551 |
+
), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
|
| 552 |
+
|
| 553 |
+
if conv_block_alphas is not None:
|
| 554 |
+
conv_block_alphas = parse_floats(conv_block_alphas)
|
| 555 |
+
assert (
|
| 556 |
+
len(conv_block_alphas) == num_total_blocks
|
| 557 |
+
), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
|
| 558 |
+
else:
|
| 559 |
+
if conv_alpha is None:
|
| 560 |
+
conv_alpha = 1.0
|
| 561 |
+
logger.warning(
|
| 562 |
+
f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
|
| 563 |
+
)
|
| 564 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
| 565 |
+
else:
|
| 566 |
+
if conv_dim is not None:
|
| 567 |
+
logger.warning(
|
| 568 |
+
f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
|
| 569 |
+
)
|
| 570 |
+
conv_block_dims = [conv_dim] * num_total_blocks
|
| 571 |
+
conv_block_alphas = [conv_alpha] * num_total_blocks
|
| 572 |
+
else:
|
| 573 |
+
conv_block_dims = None
|
| 574 |
+
conv_block_alphas = None
|
| 575 |
+
|
| 576 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
| 577 |
+
|
| 578 |
+
|
| 579 |
+
# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく
|
| 580 |
+
def get_block_lr_weight(
|
| 581 |
+
down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
|
| 582 |
+
) -> Tuple[List[float], List[float], List[float]]:
|
| 583 |
+
# パラメータ未指定時は何もせず、今までと同じ動作とする
|
| 584 |
+
if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
|
| 585 |
+
return None, None, None
|
| 586 |
+
|
| 587 |
+
max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数
|
| 588 |
+
|
| 589 |
+
def get_list(name_with_suffix) -> List[float]:
|
| 590 |
+
import math
|
| 591 |
+
|
| 592 |
+
tokens = name_with_suffix.split("+")
|
| 593 |
+
name = tokens[0]
|
| 594 |
+
base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
|
| 595 |
+
|
| 596 |
+
if name == "cosine":
|
| 597 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
|
| 598 |
+
elif name == "sine":
|
| 599 |
+
return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
|
| 600 |
+
elif name == "linear":
|
| 601 |
+
return [i / (max_len - 1) + base_lr for i in range(max_len)]
|
| 602 |
+
elif name == "reverse_linear":
|
| 603 |
+
return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
|
| 604 |
+
elif name == "zeros":
|
| 605 |
+
return [0.0 + base_lr] * max_len
|
| 606 |
+
else:
|
| 607 |
+
logger.error(
|
| 608 |
+
"Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
|
| 609 |
+
% (name)
|
| 610 |
+
)
|
| 611 |
+
return None
|
| 612 |
+
|
| 613 |
+
if type(down_lr_weight) == str:
|
| 614 |
+
down_lr_weight = get_list(down_lr_weight)
|
| 615 |
+
if type(up_lr_weight) == str:
|
| 616 |
+
up_lr_weight = get_list(up_lr_weight)
|
| 617 |
+
|
| 618 |
+
if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
|
| 619 |
+
logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
|
| 620 |
+
logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
|
| 621 |
+
up_lr_weight = up_lr_weight[:max_len]
|
| 622 |
+
down_lr_weight = down_lr_weight[:max_len]
|
| 623 |
+
|
| 624 |
+
if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
|
| 625 |
+
logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
|
| 626 |
+
logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
|
| 627 |
+
|
| 628 |
+
if down_lr_weight != None and len(down_lr_weight) < max_len:
|
| 629 |
+
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
|
| 630 |
+
if up_lr_weight != None and len(up_lr_weight) < max_len:
|
| 631 |
+
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
|
| 632 |
+
|
| 633 |
+
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
|
| 634 |
+
logger.info("apply block learning rate / 階層別学習率を適用します。")
|
| 635 |
+
if down_lr_weight != None:
|
| 636 |
+
down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
|
| 637 |
+
logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
|
| 638 |
+
else:
|
| 639 |
+
logger.info("down_lr_weight: all 1.0, すべて1.0")
|
| 640 |
+
|
| 641 |
+
if mid_lr_weight != None:
|
| 642 |
+
mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
|
| 643 |
+
logger.info(f"mid_lr_weight: {mid_lr_weight}")
|
| 644 |
+
else:
|
| 645 |
+
logger.info("mid_lr_weight: 1.0")
|
| 646 |
+
|
| 647 |
+
if up_lr_weight != None:
|
| 648 |
+
up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
|
| 649 |
+
logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
|
| 650 |
+
else:
|
| 651 |
+
logger.info("up_lr_weight: all 1.0, すべて1.0")
|
| 652 |
+
|
| 653 |
+
return down_lr_weight, mid_lr_weight, up_lr_weight
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
# lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく
|
| 657 |
+
def remove_block_dims_and_alphas(
|
| 658 |
+
block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
|
| 659 |
+
):
|
| 660 |
+
# set 0 to block dim without learning rate to remove the block
|
| 661 |
+
if down_lr_weight != None:
|
| 662 |
+
for i, lr in enumerate(down_lr_weight):
|
| 663 |
+
if lr == 0:
|
| 664 |
+
block_dims[i] = 0
|
| 665 |
+
if conv_block_dims is not None:
|
| 666 |
+
conv_block_dims[i] = 0
|
| 667 |
+
if mid_lr_weight != None:
|
| 668 |
+
if mid_lr_weight == 0:
|
| 669 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
| 670 |
+
if conv_block_dims is not None:
|
| 671 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
|
| 672 |
+
if up_lr_weight != None:
|
| 673 |
+
for i, lr in enumerate(up_lr_weight):
|
| 674 |
+
if lr == 0:
|
| 675 |
+
block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
| 676 |
+
if conv_block_dims is not None:
|
| 677 |
+
conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
|
| 678 |
+
|
| 679 |
+
return block_dims, block_alphas, conv_block_dims, conv_block_alphas
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
# 外部から呼び出す可能性を考慮しておく
|
| 683 |
+
def get_block_index(lora_name: str) -> int:
|
| 684 |
+
block_idx = -1 # invalid lora name
|
| 685 |
+
|
| 686 |
+
m = RE_UPDOWN.search(lora_name)
|
| 687 |
+
if m:
|
| 688 |
+
g = m.groups()
|
| 689 |
+
i = int(g[1])
|
| 690 |
+
j = int(g[3])
|
| 691 |
+
if g[2] == "resnets":
|
| 692 |
+
idx = 3 * i + j
|
| 693 |
+
elif g[2] == "attentions":
|
| 694 |
+
idx = 3 * i + j
|
| 695 |
+
elif g[2] == "upsamplers" or g[2] == "downsamplers":
|
| 696 |
+
idx = 3 * i + 2
|
| 697 |
+
|
| 698 |
+
if g[0] == "down":
|
| 699 |
+
block_idx = 1 + idx # 0に該当するLoRAは存在しない
|
| 700 |
+
elif g[0] == "up":
|
| 701 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
|
| 702 |
+
|
| 703 |
+
elif "mid_block_" in lora_name:
|
| 704 |
+
block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
|
| 705 |
+
|
| 706 |
+
return block_idx
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
| 710 |
+
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
| 711 |
+
if weights_sd is None:
|
| 712 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 713 |
+
from safetensors.torch import load_file, safe_open
|
| 714 |
+
|
| 715 |
+
weights_sd = load_file(file)
|
| 716 |
+
else:
|
| 717 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 718 |
+
|
| 719 |
+
# get dim/alpha mapping
|
| 720 |
+
modules_dim = {}
|
| 721 |
+
modules_alpha = {}
|
| 722 |
+
for key, value in weights_sd.items():
|
| 723 |
+
if "." not in key:
|
| 724 |
+
continue
|
| 725 |
+
|
| 726 |
+
lora_name = key.split(".")[0]
|
| 727 |
+
if "alpha" in key:
|
| 728 |
+
modules_alpha[lora_name] = value
|
| 729 |
+
elif "lora_down" in key:
|
| 730 |
+
dim = value.size()[0]
|
| 731 |
+
modules_dim[lora_name] = dim
|
| 732 |
+
# logger.info(lora_name, value.size(), dim)
|
| 733 |
+
|
| 734 |
+
# support old LoRA without alpha
|
| 735 |
+
for key in modules_dim.keys():
|
| 736 |
+
if key not in modules_alpha:
|
| 737 |
+
modules_alpha[key] = modules_dim[key]
|
| 738 |
+
|
| 739 |
+
module_class = LoRAInfModule if for_inference else LoRAModule
|
| 740 |
+
|
| 741 |
+
network = LoRANetwork(
|
| 742 |
+
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
# block lr
|
| 746 |
+
down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
|
| 747 |
+
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
| 748 |
+
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
| 749 |
+
|
| 750 |
+
return network, weights_sd
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
class LoRANetwork(torch.nn.Module):
|
| 754 |
+
NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
|
| 755 |
+
|
| 756 |
+
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
|
| 757 |
+
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
| 758 |
+
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
|
| 759 |
+
LORA_PREFIX_UNET = "lora_unet"
|
| 760 |
+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
| 761 |
+
|
| 762 |
+
# SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
|
| 763 |
+
LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
|
| 764 |
+
LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
|
| 765 |
+
|
| 766 |
+
def __init__(
|
| 767 |
+
self,
|
| 768 |
+
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
| 769 |
+
unet,
|
| 770 |
+
multiplier: float = 1.0,
|
| 771 |
+
lora_dim: int = 4,
|
| 772 |
+
alpha: float = 1,
|
| 773 |
+
dropout: Optional[float] = None,
|
| 774 |
+
rank_dropout: Optional[float] = None,
|
| 775 |
+
module_dropout: Optional[float] = None,
|
| 776 |
+
conv_lora_dim: Optional[int] = None,
|
| 777 |
+
conv_alpha: Optional[float] = None,
|
| 778 |
+
block_dims: Optional[List[int]] = None,
|
| 779 |
+
block_alphas: Optional[List[float]] = None,
|
| 780 |
+
conv_block_dims: Optional[List[int]] = None,
|
| 781 |
+
conv_block_alphas: Optional[List[float]] = None,
|
| 782 |
+
modules_dim: Optional[Dict[str, int]] = None,
|
| 783 |
+
modules_alpha: Optional[Dict[str, int]] = None,
|
| 784 |
+
module_class: Type[object] = LoRAModule,
|
| 785 |
+
varbose: Optional[bool] = False,
|
| 786 |
+
) -> None:
|
| 787 |
+
"""
|
| 788 |
+
LoRA network: すごく引数が多いが、パターンは以下の通り
|
| 789 |
+
1. lora_dimとalphaを指定
|
| 790 |
+
2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定
|
| 791 |
+
3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない
|
| 792 |
+
4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する
|
| 793 |
+
5. modules_dimとmodules_alphaを指定 (推論用)
|
| 794 |
+
"""
|
| 795 |
+
super().__init__()
|
| 796 |
+
self.multiplier = multiplier
|
| 797 |
+
|
| 798 |
+
self.lora_dim = lora_dim
|
| 799 |
+
self.alpha = alpha
|
| 800 |
+
self.conv_lora_dim = conv_lora_dim
|
| 801 |
+
self.conv_alpha = conv_alpha
|
| 802 |
+
self.dropout = dropout
|
| 803 |
+
self.rank_dropout = rank_dropout
|
| 804 |
+
self.module_dropout = module_dropout
|
| 805 |
+
|
| 806 |
+
if modules_dim is not None:
|
| 807 |
+
logger.info(f"create LoRA network from weights")
|
| 808 |
+
elif block_dims is not None:
|
| 809 |
+
logger.info(f"create LoRA network from block_dims")
|
| 810 |
+
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
| 811 |
+
logger.info(f"block_dims: {block_dims}")
|
| 812 |
+
logger.info(f"block_alphas: {block_alphas}")
|
| 813 |
+
if conv_block_dims is not None:
|
| 814 |
+
logger.info(f"conv_block_dims: {conv_block_dims}")
|
| 815 |
+
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
| 816 |
+
else:
|
| 817 |
+
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
| 818 |
+
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
| 819 |
+
if self.conv_lora_dim is not None:
|
| 820 |
+
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
| 821 |
+
|
| 822 |
+
# create module instances
|
| 823 |
+
def create_modules(
|
| 824 |
+
is_unet: bool,
|
| 825 |
+
text_encoder_idx: Optional[int], # None, 1, 2
|
| 826 |
+
root_module: torch.nn.Module,
|
| 827 |
+
target_replace_modules: List[torch.nn.Module],
|
| 828 |
+
) -> List[LoRAModule]:
|
| 829 |
+
prefix = (
|
| 830 |
+
self.LORA_PREFIX_UNET
|
| 831 |
+
if is_unet
|
| 832 |
+
else (
|
| 833 |
+
self.LORA_PREFIX_TEXT_ENCODER
|
| 834 |
+
if text_encoder_idx is None
|
| 835 |
+
else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
|
| 836 |
+
)
|
| 837 |
+
)
|
| 838 |
+
loras = []
|
| 839 |
+
skipped = []
|
| 840 |
+
for name, module in root_module.named_modules():
|
| 841 |
+
if module.__class__.__name__ in target_replace_modules:
|
| 842 |
+
for child_name, child_module in module.named_modules():
|
| 843 |
+
is_linear = child_module.__class__.__name__ == "Linear"
|
| 844 |
+
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
| 845 |
+
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
| 846 |
+
|
| 847 |
+
if is_linear or is_conv2d:
|
| 848 |
+
lora_name = prefix + "." + name + "." + child_name
|
| 849 |
+
lora_name = lora_name.replace(".", "_")
|
| 850 |
+
|
| 851 |
+
dim = None
|
| 852 |
+
alpha = None
|
| 853 |
+
|
| 854 |
+
if modules_dim is not None:
|
| 855 |
+
# モジュール指定あり
|
| 856 |
+
if lora_name in modules_dim:
|
| 857 |
+
dim = modules_dim[lora_name]
|
| 858 |
+
alpha = modules_alpha[lora_name]
|
| 859 |
+
elif is_unet and block_dims is not None:
|
| 860 |
+
# U-Netでblock_dims指定あり
|
| 861 |
+
block_idx = get_block_index(lora_name)
|
| 862 |
+
if is_linear or is_conv2d_1x1:
|
| 863 |
+
dim = block_dims[block_idx]
|
| 864 |
+
alpha = block_alphas[block_idx]
|
| 865 |
+
elif conv_block_dims is not None:
|
| 866 |
+
dim = conv_block_dims[block_idx]
|
| 867 |
+
alpha = conv_block_alphas[block_idx]
|
| 868 |
+
else:
|
| 869 |
+
# 通常、すべて対象とする
|
| 870 |
+
if is_linear or is_conv2d_1x1:
|
| 871 |
+
dim = self.lora_dim
|
| 872 |
+
alpha = self.alpha
|
| 873 |
+
elif self.conv_lora_dim is not None:
|
| 874 |
+
dim = self.conv_lora_dim
|
| 875 |
+
alpha = self.conv_alpha
|
| 876 |
+
|
| 877 |
+
if dim is None or dim == 0:
|
| 878 |
+
# skipした情報を出力
|
| 879 |
+
if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
|
| 880 |
+
skipped.append(lora_name)
|
| 881 |
+
continue
|
| 882 |
+
|
| 883 |
+
lora = module_class(
|
| 884 |
+
lora_name,
|
| 885 |
+
child_module,
|
| 886 |
+
self.multiplier,
|
| 887 |
+
dim,
|
| 888 |
+
alpha,
|
| 889 |
+
dropout=dropout,
|
| 890 |
+
rank_dropout=rank_dropout,
|
| 891 |
+
module_dropout=module_dropout,
|
| 892 |
+
)
|
| 893 |
+
loras.append(lora)
|
| 894 |
+
return loras, skipped
|
| 895 |
+
|
| 896 |
+
text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
|
| 897 |
+
|
| 898 |
+
# create LoRA for text encoder
|
| 899 |
+
# 毎回すべてのモジュールを作るのは無駄なので要検討
|
| 900 |
+
self.text_encoder_loras = []
|
| 901 |
+
skipped_te = []
|
| 902 |
+
for i, text_encoder in enumerate(text_encoders):
|
| 903 |
+
if len(text_encoders) > 1:
|
| 904 |
+
index = i + 1
|
| 905 |
+
logger.info(f"create LoRA for Text Encoder {index}:")
|
| 906 |
+
else:
|
| 907 |
+
index = None
|
| 908 |
+
logger.info(f"create LoRA for Text Encoder:")
|
| 909 |
+
|
| 910 |
+
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
| 911 |
+
self.text_encoder_loras.extend(text_encoder_loras)
|
| 912 |
+
skipped_te += skipped
|
| 913 |
+
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
| 914 |
+
|
| 915 |
+
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
| 916 |
+
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
| 917 |
+
if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
|
| 918 |
+
target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
| 919 |
+
|
| 920 |
+
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
| 921 |
+
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
| 922 |
+
|
| 923 |
+
skipped = skipped_te + skipped_un
|
| 924 |
+
if varbose and len(skipped) > 0:
|
| 925 |
+
logger.warning(
|
| 926 |
+
f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
|
| 927 |
+
)
|
| 928 |
+
for name in skipped:
|
| 929 |
+
logger.info(f"\t{name}")
|
| 930 |
+
|
| 931 |
+
self.up_lr_weight: List[float] = None
|
| 932 |
+
self.down_lr_weight: List[float] = None
|
| 933 |
+
self.mid_lr_weight: float = None
|
| 934 |
+
self.block_lr = False
|
| 935 |
+
|
| 936 |
+
# assertion
|
| 937 |
+
names = set()
|
| 938 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 939 |
+
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
| 940 |
+
names.add(lora.lora_name)
|
| 941 |
+
|
| 942 |
+
def set_multiplier(self, multiplier):
|
| 943 |
+
self.multiplier = multiplier
|
| 944 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 945 |
+
lora.multiplier = self.multiplier
|
| 946 |
+
|
| 947 |
+
def load_weights(self, file):
|
| 948 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 949 |
+
from safetensors.torch import load_file
|
| 950 |
+
|
| 951 |
+
weights_sd = load_file(file)
|
| 952 |
+
else:
|
| 953 |
+
weights_sd = torch.load(file, map_location="cpu")
|
| 954 |
+
|
| 955 |
+
info = self.load_state_dict(weights_sd, False)
|
| 956 |
+
return info
|
| 957 |
+
|
| 958 |
+
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
| 959 |
+
if apply_text_encoder:
|
| 960 |
+
logger.info("enable LoRA for text encoder")
|
| 961 |
+
else:
|
| 962 |
+
self.text_encoder_loras = []
|
| 963 |
+
|
| 964 |
+
if apply_unet:
|
| 965 |
+
logger.info("enable LoRA for U-Net")
|
| 966 |
+
else:
|
| 967 |
+
self.unet_loras = []
|
| 968 |
+
|
| 969 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 970 |
+
lora.apply_to()
|
| 971 |
+
self.add_module(lora.lora_name, lora)
|
| 972 |
+
|
| 973 |
+
# マージできるかどうかを返す
|
| 974 |
+
def is_mergeable(self):
|
| 975 |
+
return True
|
| 976 |
+
|
| 977 |
+
# TODO refactor to common function with apply_to
|
| 978 |
+
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
| 979 |
+
apply_text_encoder = apply_unet = False
|
| 980 |
+
for key in weights_sd.keys():
|
| 981 |
+
if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
| 982 |
+
apply_text_encoder = True
|
| 983 |
+
elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
|
| 984 |
+
apply_unet = True
|
| 985 |
+
|
| 986 |
+
if apply_text_encoder:
|
| 987 |
+
logger.info("enable LoRA for text encoder")
|
| 988 |
+
else:
|
| 989 |
+
self.text_encoder_loras = []
|
| 990 |
+
|
| 991 |
+
if apply_unet:
|
| 992 |
+
logger.info("enable LoRA for U-Net")
|
| 993 |
+
else:
|
| 994 |
+
self.unet_loras = []
|
| 995 |
+
|
| 996 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 997 |
+
sd_for_lora = {}
|
| 998 |
+
for key in weights_sd.keys():
|
| 999 |
+
if key.startswith(lora.lora_name):
|
| 1000 |
+
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
| 1001 |
+
lora.merge_to(sd_for_lora, dtype, device)
|
| 1002 |
+
|
| 1003 |
+
logger.info(f"weights are merged")
|
| 1004 |
+
|
| 1005 |
+
# 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない
|
| 1006 |
+
def set_block_lr_weight(
|
| 1007 |
+
self,
|
| 1008 |
+
up_lr_weight: List[float] = None,
|
| 1009 |
+
mid_lr_weight: float = None,
|
| 1010 |
+
down_lr_weight: List[float] = None,
|
| 1011 |
+
):
|
| 1012 |
+
self.block_lr = True
|
| 1013 |
+
self.down_lr_weight = down_lr_weight
|
| 1014 |
+
self.mid_lr_weight = mid_lr_weight
|
| 1015 |
+
self.up_lr_weight = up_lr_weight
|
| 1016 |
+
|
| 1017 |
+
def get_lr_weight(self, lora: LoRAModule) -> float:
|
| 1018 |
+
lr_weight = 1.0
|
| 1019 |
+
block_idx = get_block_index(lora.lora_name)
|
| 1020 |
+
if block_idx < 0:
|
| 1021 |
+
return lr_weight
|
| 1022 |
+
|
| 1023 |
+
if block_idx < LoRANetwork.NUM_OF_BLOCKS:
|
| 1024 |
+
if self.down_lr_weight != None:
|
| 1025 |
+
lr_weight = self.down_lr_weight[block_idx]
|
| 1026 |
+
elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
|
| 1027 |
+
if self.mid_lr_weight != None:
|
| 1028 |
+
lr_weight = self.mid_lr_weight
|
| 1029 |
+
elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
|
| 1030 |
+
if self.up_lr_weight != None:
|
| 1031 |
+
lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
|
| 1032 |
+
|
| 1033 |
+
return lr_weight
|
| 1034 |
+
|
| 1035 |
+
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
| 1036 |
+
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
| 1037 |
+
self.requires_grad_(True)
|
| 1038 |
+
all_params = []
|
| 1039 |
+
|
| 1040 |
+
def enumerate_params(loras: List[LoRAModule]):
|
| 1041 |
+
params = []
|
| 1042 |
+
for lora in loras:
|
| 1043 |
+
# params.extend(lora.parameters())
|
| 1044 |
+
params.extend(lora.get_trainable_params())
|
| 1045 |
+
return params
|
| 1046 |
+
|
| 1047 |
+
if self.text_encoder_loras:
|
| 1048 |
+
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
| 1049 |
+
if text_encoder_lr is not None:
|
| 1050 |
+
param_data["lr"] = text_encoder_lr
|
| 1051 |
+
all_params.append(param_data)
|
| 1052 |
+
|
| 1053 |
+
if self.unet_loras:
|
| 1054 |
+
if self.block_lr:
|
| 1055 |
+
# 学習率のグラフをblockごとにしたいので、blockごとにloraを分類
|
| 1056 |
+
block_idx_to_lora = {}
|
| 1057 |
+
for lora in self.unet_loras:
|
| 1058 |
+
idx = get_block_index(lora.lora_name)
|
| 1059 |
+
if idx not in block_idx_to_lora:
|
| 1060 |
+
block_idx_to_lora[idx] = []
|
| 1061 |
+
block_idx_to_lora[idx].append(lora)
|
| 1062 |
+
|
| 1063 |
+
# blockごとにパラメータを設定する
|
| 1064 |
+
for idx, block_loras in block_idx_to_lora.items():
|
| 1065 |
+
param_data = {"params": enumerate_params(block_loras)}
|
| 1066 |
+
|
| 1067 |
+
if unet_lr is not None:
|
| 1068 |
+
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
| 1069 |
+
elif default_lr is not None:
|
| 1070 |
+
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
| 1071 |
+
if ("lr" in param_data) and (param_data["lr"] == 0):
|
| 1072 |
+
continue
|
| 1073 |
+
all_params.append(param_data)
|
| 1074 |
+
|
| 1075 |
+
else:
|
| 1076 |
+
param_data = {"params": enumerate_params(self.unet_loras)}
|
| 1077 |
+
if unet_lr is not None:
|
| 1078 |
+
param_data["lr"] = unet_lr
|
| 1079 |
+
all_params.append(param_data)
|
| 1080 |
+
|
| 1081 |
+
return all_params
|
| 1082 |
+
|
| 1083 |
+
def enable_gradient_checkpointing(self):
|
| 1084 |
+
# not supported
|
| 1085 |
+
pass
|
| 1086 |
+
|
| 1087 |
+
def prepare_grad_etc(self, text_encoder, unet):
|
| 1088 |
+
self.requires_grad_(True)
|
| 1089 |
+
|
| 1090 |
+
def on_epoch_start(self, text_encoder, unet):
|
| 1091 |
+
self.train()
|
| 1092 |
+
|
| 1093 |
+
def get_trainable_params(self):
|
| 1094 |
+
return self.parameters()
|
| 1095 |
+
|
| 1096 |
+
def save_weights(self, file, dtype, metadata):
|
| 1097 |
+
if metadata is not None and len(metadata) == 0:
|
| 1098 |
+
metadata = None
|
| 1099 |
+
|
| 1100 |
+
state_dict = self.state_dict()
|
| 1101 |
+
|
| 1102 |
+
if dtype is not None:
|
| 1103 |
+
for key in list(state_dict.keys()):
|
| 1104 |
+
v = state_dict[key]
|
| 1105 |
+
v = v.detach().clone().to("cpu").to(dtype)
|
| 1106 |
+
state_dict[key] = v
|
| 1107 |
+
|
| 1108 |
+
if os.path.splitext(file)[1] == ".safetensors":
|
| 1109 |
+
from safetensors.torch import save_file
|
| 1110 |
+
from library import train_util
|
| 1111 |
+
|
| 1112 |
+
# Precalculate model hashes to save time on indexing
|
| 1113 |
+
if metadata is None:
|
| 1114 |
+
metadata = {}
|
| 1115 |
+
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
| 1116 |
+
metadata["sshs_model_hash"] = model_hash
|
| 1117 |
+
metadata["sshs_legacy_hash"] = legacy_hash
|
| 1118 |
+
|
| 1119 |
+
save_file(state_dict, file, metadata)
|
| 1120 |
+
else:
|
| 1121 |
+
torch.save(state_dict, file)
|
| 1122 |
+
|
| 1123 |
+
# mask is a tensor with values from 0 to 1
|
| 1124 |
+
def set_region(self, sub_prompt_index, is_last_network, mask):
|
| 1125 |
+
if mask.max() == 0:
|
| 1126 |
+
mask = torch.ones_like(mask)
|
| 1127 |
+
|
| 1128 |
+
self.mask = mask
|
| 1129 |
+
self.sub_prompt_index = sub_prompt_index
|
| 1130 |
+
self.is_last_network = is_last_network
|
| 1131 |
+
|
| 1132 |
+
for lora in self.text_encoder_loras + self.unet_loras:
|
| 1133 |
+
lora.set_network(self)
|
| 1134 |
+
|
| 1135 |
+
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
| 1136 |
+
self.batch_size = batch_size
|
| 1137 |
+
self.num_sub_prompts = num_sub_prompts
|
| 1138 |
+
self.current_size = (height, width)
|
| 1139 |
+
self.shared = shared
|
| 1140 |
+
|
| 1141 |
+
# create masks
|
| 1142 |
+
mask = self.mask
|
| 1143 |
+
mask_dic = {}
|
| 1144 |
+
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
| 1145 |
+
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
| 1146 |
+
dtype = ref_weight.dtype
|
| 1147 |
+
device = ref_weight.device
|
| 1148 |
+
|
| 1149 |
+
def resize_add(mh, mw):
|
| 1150 |
+
# logger.info(mh, mw, mh * mw)
|
| 1151 |
+
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
| 1152 |
+
m = m.to(device, dtype=dtype)
|
| 1153 |
+
mask_dic[mh * mw] = m
|
| 1154 |
+
|
| 1155 |
+
h = height // 8
|
| 1156 |
+
w = width // 8
|
| 1157 |
+
for _ in range(4):
|
| 1158 |
+
resize_add(h, w)
|
| 1159 |
+
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
| 1160 |
+
resize_add(h + h % 2, w + w % 2)
|
| 1161 |
+
h = (h + 1) // 2
|
| 1162 |
+
w = (w + 1) // 2
|
| 1163 |
+
|
| 1164 |
+
self.mask_dic = mask_dic
|
| 1165 |
+
|
| 1166 |
+
def backup_weights(self):
|
| 1167 |
+
# 重みのバックアップを行う
|
| 1168 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1169 |
+
for lora in loras:
|
| 1170 |
+
org_module = lora.org_module_ref[0]
|
| 1171 |
+
if not hasattr(org_module, "_lora_org_weight"):
|
| 1172 |
+
sd = org_module.state_dict()
|
| 1173 |
+
org_module._lora_org_weight = sd["weight"].detach().clone()
|
| 1174 |
+
org_module._lora_restored = True
|
| 1175 |
+
|
| 1176 |
+
def restore_weights(self):
|
| 1177 |
+
# 重みのリストアを行う
|
| 1178 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1179 |
+
for lora in loras:
|
| 1180 |
+
org_module = lora.org_module_ref[0]
|
| 1181 |
+
if not org_module._lora_restored:
|
| 1182 |
+
sd = org_module.state_dict()
|
| 1183 |
+
sd["weight"] = org_module._lora_org_weight
|
| 1184 |
+
org_module.load_state_dict(sd)
|
| 1185 |
+
org_module._lora_restored = True
|
| 1186 |
+
|
| 1187 |
+
def pre_calculation(self):
|
| 1188 |
+
# 事前計算を行う
|
| 1189 |
+
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
|
| 1190 |
+
for lora in loras:
|
| 1191 |
+
org_module = lora.org_module_ref[0]
|
| 1192 |
+
sd = org_module.state_dict()
|
| 1193 |
+
|
| 1194 |
+
org_weight = sd["weight"]
|
| 1195 |
+
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
| 1196 |
+
sd["weight"] = org_weight + lora_weight
|
| 1197 |
+
assert sd["weight"].shape == org_weight.shape
|
| 1198 |
+
org_module.load_state_dict(sd)
|
| 1199 |
+
|
| 1200 |
+
org_module._lora_restored = False
|
| 1201 |
+
lora.enabled = False
|
| 1202 |
+
|
| 1203 |
+
def apply_max_norm_regularization(self, max_norm_value, device):
|
| 1204 |
+
downkeys = []
|
| 1205 |
+
upkeys = []
|
| 1206 |
+
alphakeys = []
|
| 1207 |
+
norms = []
|
| 1208 |
+
keys_scaled = 0
|
| 1209 |
+
|
| 1210 |
+
state_dict = self.state_dict()
|
| 1211 |
+
for key in state_dict.keys():
|
| 1212 |
+
if "lora_down" in key and "weight" in key:
|
| 1213 |
+
downkeys.append(key)
|
| 1214 |
+
upkeys.append(key.replace("lora_down", "lora_up"))
|
| 1215 |
+
alphakeys.append(key.replace("lora_down.weight", "alpha"))
|
| 1216 |
+
|
| 1217 |
+
for i in range(len(downkeys)):
|
| 1218 |
+
down = state_dict[downkeys[i]].to(device)
|
| 1219 |
+
up = state_dict[upkeys[i]].to(device)
|
| 1220 |
+
alpha = state_dict[alphakeys[i]].to(device)
|
| 1221 |
+
dim = down.shape[0]
|
| 1222 |
+
scale = alpha / dim
|
| 1223 |
+
|
| 1224 |
+
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
| 1225 |
+
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
| 1226 |
+
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
| 1227 |
+
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
| 1228 |
+
else:
|
| 1229 |
+
updown = up @ down
|
| 1230 |
+
|
| 1231 |
+
updown *= scale
|
| 1232 |
+
|
| 1233 |
+
norm = updown.norm().clamp(min=max_norm_value / 2)
|
| 1234 |
+
desired = torch.clamp(norm, max=max_norm_value)
|
| 1235 |
+
ratio = desired.cpu() / norm.cpu()
|
| 1236 |
+
sqrt_ratio = ratio**0.5
|
| 1237 |
+
if ratio != 1:
|
| 1238 |
+
keys_scaled += 1
|
| 1239 |
+
state_dict[upkeys[i]] *= sqrt_ratio
|
| 1240 |
+
state_dict[downkeys[i]] *= sqrt_ratio
|
| 1241 |
+
scalednorm = updown.norm() * ratio
|
| 1242 |
+
norms.append(scalednorm.item())
|
| 1243 |
+
|
| 1244 |
+
return keys_scaled, sum(norms) / len(norms), max(norms)
|
lora_interrogator.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
from tqdm import tqdm
|
| 4 |
+
from library import model_util
|
| 5 |
+
import library.train_util as train_util
|
| 6 |
+
import argparse
|
| 7 |
+
from transformers import CLIPTokenizer
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from library.device_utils import init_ipex, get_preferred_device
|
| 11 |
+
init_ipex()
|
| 12 |
+
|
| 13 |
+
import library.model_util as model_util
|
| 14 |
+
import lora
|
| 15 |
+
from library.utils import setup_logging
|
| 16 |
+
setup_logging()
|
| 17 |
+
import logging
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
| 21 |
+
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
| 22 |
+
|
| 23 |
+
DEVICE = get_preferred_device()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def interrogate(args):
|
| 27 |
+
weights_dtype = torch.float16
|
| 28 |
+
|
| 29 |
+
# いろいろ準備する
|
| 30 |
+
logger.info(f"loading SD model: {args.sd_model}")
|
| 31 |
+
args.pretrained_model_name_or_path = args.sd_model
|
| 32 |
+
args.vae = None
|
| 33 |
+
text_encoder, vae, unet, _ = train_util._load_target_model(args,weights_dtype, DEVICE)
|
| 34 |
+
|
| 35 |
+
logger.info(f"loading LoRA: {args.model}")
|
| 36 |
+
network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
| 37 |
+
|
| 38 |
+
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
| 39 |
+
has_te_weight = False
|
| 40 |
+
for key in weights_sd.keys():
|
| 41 |
+
if 'lora_te' in key:
|
| 42 |
+
has_te_weight = True
|
| 43 |
+
break
|
| 44 |
+
if not has_te_weight:
|
| 45 |
+
logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
| 46 |
+
return
|
| 47 |
+
del vae
|
| 48 |
+
|
| 49 |
+
logger.info("loading tokenizer")
|
| 50 |
+
if args.v2:
|
| 51 |
+
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
| 52 |
+
else:
|
| 53 |
+
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
| 54 |
+
|
| 55 |
+
text_encoder.to(DEVICE, dtype=weights_dtype)
|
| 56 |
+
text_encoder.eval()
|
| 57 |
+
unet.to(DEVICE, dtype=weights_dtype)
|
| 58 |
+
unet.eval() # U-Netは呼び出さないので不要だけど
|
| 59 |
+
|
| 60 |
+
# トークンをひとつひとつ当たっていく
|
| 61 |
+
token_id_start = 0
|
| 62 |
+
token_id_end = max(tokenizer.all_special_ids)
|
| 63 |
+
logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
| 64 |
+
|
| 65 |
+
def get_all_embeddings(text_encoder):
|
| 66 |
+
embs = []
|
| 67 |
+
with torch.no_grad():
|
| 68 |
+
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
|
| 69 |
+
batch = []
|
| 70 |
+
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
|
| 71 |
+
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
|
| 72 |
+
# tokens = [tid] # こちらは結果がいまひとつ
|
| 73 |
+
batch.append(tokens)
|
| 74 |
+
|
| 75 |
+
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
|
| 76 |
+
# clip skip対応
|
| 77 |
+
batch = torch.tensor(batch).to(DEVICE)
|
| 78 |
+
if args.clip_skip is None:
|
| 79 |
+
encoder_hidden_states = text_encoder(batch)[0]
|
| 80 |
+
else:
|
| 81 |
+
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
|
| 82 |
+
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
| 83 |
+
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
| 84 |
+
encoder_hidden_states = encoder_hidden_states.to("cpu")
|
| 85 |
+
|
| 86 |
+
embs.extend(encoder_hidden_states)
|
| 87 |
+
return torch.stack(embs)
|
| 88 |
+
|
| 89 |
+
logger.info("get original text encoder embeddings.")
|
| 90 |
+
orig_embs = get_all_embeddings(text_encoder)
|
| 91 |
+
|
| 92 |
+
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
| 93 |
+
info = network.load_state_dict(weights_sd, strict=False)
|
| 94 |
+
logger.info(f"Loading LoRA weights: {info}")
|
| 95 |
+
|
| 96 |
+
network.to(DEVICE, dtype=weights_dtype)
|
| 97 |
+
network.eval()
|
| 98 |
+
|
| 99 |
+
del unet
|
| 100 |
+
|
| 101 |
+
logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
| 102 |
+
logger.info("get text encoder embeddings with lora.")
|
| 103 |
+
lora_embs = get_all_embeddings(text_encoder)
|
| 104 |
+
|
| 105 |
+
# 比べる:とりあえず単純に差分の絶対値で
|
| 106 |
+
logger.info("comparing...")
|
| 107 |
+
diffs = {}
|
| 108 |
+
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
| 109 |
+
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
| 110 |
+
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
|
| 111 |
+
diff = float(diff.detach().to('cpu').numpy())
|
| 112 |
+
diffs[token_id_start + i] = diff
|
| 113 |
+
|
| 114 |
+
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
|
| 115 |
+
|
| 116 |
+
# 結果を表示する
|
| 117 |
+
print("top 100:")
|
| 118 |
+
for i, (token, diff) in enumerate(diffs_sorted[:100]):
|
| 119 |
+
# if diff < 1e-6:
|
| 120 |
+
# break
|
| 121 |
+
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
|
| 122 |
+
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 126 |
+
parser = argparse.ArgumentParser()
|
| 127 |
+
|
| 128 |
+
parser.add_argument("--v2", action='store_true',
|
| 129 |
+
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
| 130 |
+
parser.add_argument("--sd_model", type=str, default=None,
|
| 131 |
+
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
|
| 132 |
+
parser.add_argument("--model", type=str, default=None,
|
| 133 |
+
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
| 134 |
+
parser.add_argument("--batch_size", type=int, default=16,
|
| 135 |
+
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
|
| 136 |
+
parser.add_argument("--clip_skip", type=int, default=None,
|
| 137 |
+
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
| 138 |
+
|
| 139 |
+
return parser
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == '__main__':
|
| 143 |
+
parser = setup_parser()
|
| 144 |
+
|
| 145 |
+
args = parser.parse_args()
|
| 146 |
+
interrogate(args)
|
lpw_stable_diffusion.py
ADDED
|
@@ -0,0 +1,1233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
|
| 2 |
+
# and modify to support SD2.x
|
| 3 |
+
|
| 4 |
+
import inspect
|
| 5 |
+
import re
|
| 6 |
+
from typing import Callable, List, Optional, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import PIL.Image
|
| 10 |
+
import torch
|
| 11 |
+
from packaging import version
|
| 12 |
+
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 13 |
+
|
| 14 |
+
import diffusers
|
| 15 |
+
from diffusers import SchedulerMixin, StableDiffusionPipeline
|
| 16 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
| 17 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
|
| 18 |
+
from diffusers.utils import logging
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from diffusers.utils import PIL_INTERPOLATION
|
| 22 |
+
except ImportError:
|
| 23 |
+
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
|
| 24 |
+
PIL_INTERPOLATION = {
|
| 25 |
+
"linear": PIL.Image.Resampling.BILINEAR,
|
| 26 |
+
"bilinear": PIL.Image.Resampling.BILINEAR,
|
| 27 |
+
"bicubic": PIL.Image.Resampling.BICUBIC,
|
| 28 |
+
"lanczos": PIL.Image.Resampling.LANCZOS,
|
| 29 |
+
"nearest": PIL.Image.Resampling.NEAREST,
|
| 30 |
+
}
|
| 31 |
+
else:
|
| 32 |
+
PIL_INTERPOLATION = {
|
| 33 |
+
"linear": PIL.Image.LINEAR,
|
| 34 |
+
"bilinear": PIL.Image.BILINEAR,
|
| 35 |
+
"bicubic": PIL.Image.BICUBIC,
|
| 36 |
+
"lanczos": PIL.Image.LANCZOS,
|
| 37 |
+
"nearest": PIL.Image.NEAREST,
|
| 38 |
+
}
|
| 39 |
+
# ------------------------------------------------------------------------------
|
| 40 |
+
|
| 41 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 42 |
+
|
| 43 |
+
re_attention = re.compile(
|
| 44 |
+
r"""
|
| 45 |
+
\\\(|
|
| 46 |
+
\\\)|
|
| 47 |
+
\\\[|
|
| 48 |
+
\\]|
|
| 49 |
+
\\\\|
|
| 50 |
+
\\|
|
| 51 |
+
\(|
|
| 52 |
+
\[|
|
| 53 |
+
:([+-]?[.\d]+)\)|
|
| 54 |
+
\)|
|
| 55 |
+
]|
|
| 56 |
+
[^\\()\[\]:]+|
|
| 57 |
+
:
|
| 58 |
+
""",
|
| 59 |
+
re.X,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def parse_prompt_attention(text):
|
| 64 |
+
"""
|
| 65 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
| 66 |
+
Accepted tokens are:
|
| 67 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
| 68 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
| 69 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
| 70 |
+
\( - literal character '('
|
| 71 |
+
\[ - literal character '['
|
| 72 |
+
\) - literal character ')'
|
| 73 |
+
\] - literal character ']'
|
| 74 |
+
\\ - literal character '\'
|
| 75 |
+
anything else - just text
|
| 76 |
+
>>> parse_prompt_attention('normal text')
|
| 77 |
+
[['normal text', 1.0]]
|
| 78 |
+
>>> parse_prompt_attention('an (important) word')
|
| 79 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
| 80 |
+
>>> parse_prompt_attention('(unbalanced')
|
| 81 |
+
[['unbalanced', 1.1]]
|
| 82 |
+
>>> parse_prompt_attention('\(literal\]')
|
| 83 |
+
[['(literal]', 1.0]]
|
| 84 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
| 85 |
+
[['unnecessaryparens', 1.1]]
|
| 86 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
| 87 |
+
[['a ', 1.0],
|
| 88 |
+
['house', 1.5730000000000004],
|
| 89 |
+
[' ', 1.1],
|
| 90 |
+
['on', 1.0],
|
| 91 |
+
[' a ', 1.1],
|
| 92 |
+
['hill', 0.55],
|
| 93 |
+
[', sun, ', 1.1],
|
| 94 |
+
['sky', 1.4641000000000006],
|
| 95 |
+
['.', 1.1]]
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
res = []
|
| 99 |
+
round_brackets = []
|
| 100 |
+
square_brackets = []
|
| 101 |
+
|
| 102 |
+
round_bracket_multiplier = 1.1
|
| 103 |
+
square_bracket_multiplier = 1 / 1.1
|
| 104 |
+
|
| 105 |
+
def multiply_range(start_position, multiplier):
|
| 106 |
+
for p in range(start_position, len(res)):
|
| 107 |
+
res[p][1] *= multiplier
|
| 108 |
+
|
| 109 |
+
for m in re_attention.finditer(text):
|
| 110 |
+
text = m.group(0)
|
| 111 |
+
weight = m.group(1)
|
| 112 |
+
|
| 113 |
+
if text.startswith("\\"):
|
| 114 |
+
res.append([text[1:], 1.0])
|
| 115 |
+
elif text == "(":
|
| 116 |
+
round_brackets.append(len(res))
|
| 117 |
+
elif text == "[":
|
| 118 |
+
square_brackets.append(len(res))
|
| 119 |
+
elif weight is not None and len(round_brackets) > 0:
|
| 120 |
+
multiply_range(round_brackets.pop(), float(weight))
|
| 121 |
+
elif text == ")" and len(round_brackets) > 0:
|
| 122 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 123 |
+
elif text == "]" and len(square_brackets) > 0:
|
| 124 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 125 |
+
else:
|
| 126 |
+
res.append([text, 1.0])
|
| 127 |
+
|
| 128 |
+
for pos in round_brackets:
|
| 129 |
+
multiply_range(pos, round_bracket_multiplier)
|
| 130 |
+
|
| 131 |
+
for pos in square_brackets:
|
| 132 |
+
multiply_range(pos, square_bracket_multiplier)
|
| 133 |
+
|
| 134 |
+
if len(res) == 0:
|
| 135 |
+
res = [["", 1.0]]
|
| 136 |
+
|
| 137 |
+
# merge runs of identical weights
|
| 138 |
+
i = 0
|
| 139 |
+
while i + 1 < len(res):
|
| 140 |
+
if res[i][1] == res[i + 1][1]:
|
| 141 |
+
res[i][0] += res[i + 1][0]
|
| 142 |
+
res.pop(i + 1)
|
| 143 |
+
else:
|
| 144 |
+
i += 1
|
| 145 |
+
|
| 146 |
+
return res
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
|
| 150 |
+
r"""
|
| 151 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
| 152 |
+
|
| 153 |
+
No padding, starting or ending token is included.
|
| 154 |
+
"""
|
| 155 |
+
tokens = []
|
| 156 |
+
weights = []
|
| 157 |
+
truncated = False
|
| 158 |
+
for text in prompt:
|
| 159 |
+
texts_and_weights = parse_prompt_attention(text)
|
| 160 |
+
text_token = []
|
| 161 |
+
text_weight = []
|
| 162 |
+
for word, weight in texts_and_weights:
|
| 163 |
+
# tokenize and discard the starting and the ending token
|
| 164 |
+
token = pipe.tokenizer(word).input_ids[1:-1]
|
| 165 |
+
text_token += token
|
| 166 |
+
# copy the weight by length of token
|
| 167 |
+
text_weight += [weight] * len(token)
|
| 168 |
+
# stop if the text is too long (longer than truncation limit)
|
| 169 |
+
if len(text_token) > max_length:
|
| 170 |
+
truncated = True
|
| 171 |
+
break
|
| 172 |
+
# truncate
|
| 173 |
+
if len(text_token) > max_length:
|
| 174 |
+
truncated = True
|
| 175 |
+
text_token = text_token[:max_length]
|
| 176 |
+
text_weight = text_weight[:max_length]
|
| 177 |
+
tokens.append(text_token)
|
| 178 |
+
weights.append(text_weight)
|
| 179 |
+
if truncated:
|
| 180 |
+
logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
| 181 |
+
return tokens, weights
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
| 185 |
+
r"""
|
| 186 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
| 187 |
+
"""
|
| 188 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
| 189 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
| 190 |
+
for i in range(len(tokens)):
|
| 191 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
| 192 |
+
if no_boseos_middle:
|
| 193 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
| 194 |
+
else:
|
| 195 |
+
w = []
|
| 196 |
+
if len(weights[i]) == 0:
|
| 197 |
+
w = [1.0] * weights_length
|
| 198 |
+
else:
|
| 199 |
+
for j in range(max_embeddings_multiples):
|
| 200 |
+
w.append(1.0) # weight for starting token in this chunk
|
| 201 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
| 202 |
+
w.append(1.0) # weight for ending token in this chunk
|
| 203 |
+
w += [1.0] * (weights_length - len(w))
|
| 204 |
+
weights[i] = w[:]
|
| 205 |
+
|
| 206 |
+
return tokens, weights
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def get_unweighted_text_embeddings(
|
| 210 |
+
pipe: StableDiffusionPipeline,
|
| 211 |
+
text_input: torch.Tensor,
|
| 212 |
+
chunk_length: int,
|
| 213 |
+
clip_skip: int,
|
| 214 |
+
eos: int,
|
| 215 |
+
pad: int,
|
| 216 |
+
no_boseos_middle: Optional[bool] = True,
|
| 217 |
+
):
|
| 218 |
+
"""
|
| 219 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
| 220 |
+
it should be split into chunks and sent to the text encoder individually.
|
| 221 |
+
"""
|
| 222 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
| 223 |
+
if max_embeddings_multiples > 1:
|
| 224 |
+
text_embeddings = []
|
| 225 |
+
for i in range(max_embeddings_multiples):
|
| 226 |
+
# extract the i-th chunk
|
| 227 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
| 228 |
+
|
| 229 |
+
# cover the head and the tail by the starting and the ending tokens
|
| 230 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
| 231 |
+
if pad == eos: # v1
|
| 232 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
| 233 |
+
else: # v2
|
| 234 |
+
for j in range(len(text_input_chunk)):
|
| 235 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
| 236 |
+
text_input_chunk[j, -1] = eos
|
| 237 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
| 238 |
+
text_input_chunk[j, 1] = eos
|
| 239 |
+
|
| 240 |
+
if clip_skip is None or clip_skip == 1:
|
| 241 |
+
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
| 242 |
+
else:
|
| 243 |
+
enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
| 244 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
| 245 |
+
text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
|
| 246 |
+
|
| 247 |
+
if no_boseos_middle:
|
| 248 |
+
if i == 0:
|
| 249 |
+
# discard the ending token
|
| 250 |
+
text_embedding = text_embedding[:, :-1]
|
| 251 |
+
elif i == max_embeddings_multiples - 1:
|
| 252 |
+
# discard the starting token
|
| 253 |
+
text_embedding = text_embedding[:, 1:]
|
| 254 |
+
else:
|
| 255 |
+
# discard both starting and ending tokens
|
| 256 |
+
text_embedding = text_embedding[:, 1:-1]
|
| 257 |
+
|
| 258 |
+
text_embeddings.append(text_embedding)
|
| 259 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
| 260 |
+
else:
|
| 261 |
+
if clip_skip is None or clip_skip == 1:
|
| 262 |
+
text_embeddings = pipe.text_encoder(text_input)[0]
|
| 263 |
+
else:
|
| 264 |
+
enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
| 265 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
| 266 |
+
text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
|
| 267 |
+
return text_embeddings
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def get_weighted_text_embeddings(
|
| 271 |
+
pipe: StableDiffusionPipeline,
|
| 272 |
+
prompt: Union[str, List[str]],
|
| 273 |
+
uncond_prompt: Optional[Union[str, List[str]]] = None,
|
| 274 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 275 |
+
no_boseos_middle: Optional[bool] = False,
|
| 276 |
+
skip_parsing: Optional[bool] = False,
|
| 277 |
+
skip_weighting: Optional[bool] = False,
|
| 278 |
+
clip_skip=None,
|
| 279 |
+
):
|
| 280 |
+
r"""
|
| 281 |
+
Prompts can be assigned with local weights using brackets. For example,
|
| 282 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
| 283 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
| 284 |
+
|
| 285 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
pipe (`StableDiffusionPipeline`):
|
| 289 |
+
Pipe to provide access to the tokenizer and the text encoder.
|
| 290 |
+
prompt (`str` or `List[str]`):
|
| 291 |
+
The prompt or prompts to guide the image generation.
|
| 292 |
+
uncond_prompt (`str` or `List[str]`):
|
| 293 |
+
The unconditional prompt or prompts for guide the image generation. If unconditional prompt
|
| 294 |
+
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
| 295 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 296 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 297 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
| 298 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
| 299 |
+
ending token in each of the chunk in the middle.
|
| 300 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
| 301 |
+
Skip the parsing of brackets.
|
| 302 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
| 303 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
| 304 |
+
"""
|
| 305 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 306 |
+
if isinstance(prompt, str):
|
| 307 |
+
prompt = [prompt]
|
| 308 |
+
|
| 309 |
+
if not skip_parsing:
|
| 310 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
| 311 |
+
if uncond_prompt is not None:
|
| 312 |
+
if isinstance(uncond_prompt, str):
|
| 313 |
+
uncond_prompt = [uncond_prompt]
|
| 314 |
+
uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
|
| 315 |
+
else:
|
| 316 |
+
prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
|
| 317 |
+
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
| 318 |
+
if uncond_prompt is not None:
|
| 319 |
+
if isinstance(uncond_prompt, str):
|
| 320 |
+
uncond_prompt = [uncond_prompt]
|
| 321 |
+
uncond_tokens = [
|
| 322 |
+
token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
|
| 323 |
+
]
|
| 324 |
+
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
| 325 |
+
|
| 326 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
| 327 |
+
max_length = max([len(token) for token in prompt_tokens])
|
| 328 |
+
if uncond_prompt is not None:
|
| 329 |
+
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
| 330 |
+
|
| 331 |
+
max_embeddings_multiples = min(
|
| 332 |
+
max_embeddings_multiples,
|
| 333 |
+
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
| 334 |
+
)
|
| 335 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
| 336 |
+
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 337 |
+
|
| 338 |
+
# pad the length of tokens and weights
|
| 339 |
+
bos = pipe.tokenizer.bos_token_id
|
| 340 |
+
eos = pipe.tokenizer.eos_token_id
|
| 341 |
+
pad = pipe.tokenizer.pad_token_id
|
| 342 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
| 343 |
+
prompt_tokens,
|
| 344 |
+
prompt_weights,
|
| 345 |
+
max_length,
|
| 346 |
+
bos,
|
| 347 |
+
eos,
|
| 348 |
+
no_boseos_middle=no_boseos_middle,
|
| 349 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
| 350 |
+
)
|
| 351 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
| 352 |
+
if uncond_prompt is not None:
|
| 353 |
+
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
| 354 |
+
uncond_tokens,
|
| 355 |
+
uncond_weights,
|
| 356 |
+
max_length,
|
| 357 |
+
bos,
|
| 358 |
+
eos,
|
| 359 |
+
no_boseos_middle=no_boseos_middle,
|
| 360 |
+
chunk_length=pipe.tokenizer.model_max_length,
|
| 361 |
+
)
|
| 362 |
+
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
| 363 |
+
|
| 364 |
+
# get the embeddings
|
| 365 |
+
text_embeddings = get_unweighted_text_embeddings(
|
| 366 |
+
pipe,
|
| 367 |
+
prompt_tokens,
|
| 368 |
+
pipe.tokenizer.model_max_length,
|
| 369 |
+
clip_skip,
|
| 370 |
+
eos,
|
| 371 |
+
pad,
|
| 372 |
+
no_boseos_middle=no_boseos_middle,
|
| 373 |
+
)
|
| 374 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
| 375 |
+
if uncond_prompt is not None:
|
| 376 |
+
uncond_embeddings = get_unweighted_text_embeddings(
|
| 377 |
+
pipe,
|
| 378 |
+
uncond_tokens,
|
| 379 |
+
pipe.tokenizer.model_max_length,
|
| 380 |
+
clip_skip,
|
| 381 |
+
eos,
|
| 382 |
+
pad,
|
| 383 |
+
no_boseos_middle=no_boseos_middle,
|
| 384 |
+
)
|
| 385 |
+
uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
|
| 386 |
+
|
| 387 |
+
# assign weights to the prompts and normalize in the sense of mean
|
| 388 |
+
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
| 389 |
+
if (not skip_parsing) and (not skip_weighting):
|
| 390 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 391 |
+
text_embeddings *= prompt_weights.unsqueeze(-1)
|
| 392 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 393 |
+
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 394 |
+
if uncond_prompt is not None:
|
| 395 |
+
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
| 396 |
+
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
| 397 |
+
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
| 398 |
+
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 399 |
+
|
| 400 |
+
if uncond_prompt is not None:
|
| 401 |
+
return text_embeddings, uncond_embeddings
|
| 402 |
+
return text_embeddings, None
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def preprocess_image(image):
|
| 406 |
+
w, h = image.size
|
| 407 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
| 408 |
+
image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
|
| 409 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 410 |
+
image = image[None].transpose(0, 3, 1, 2)
|
| 411 |
+
image = torch.from_numpy(image)
|
| 412 |
+
return 2.0 * image - 1.0
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def preprocess_mask(mask, scale_factor=8):
|
| 416 |
+
mask = mask.convert("L")
|
| 417 |
+
w, h = mask.size
|
| 418 |
+
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
| 419 |
+
mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
|
| 420 |
+
mask = np.array(mask).astype(np.float32) / 255.0
|
| 421 |
+
mask = np.tile(mask, (4, 1, 1))
|
| 422 |
+
mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
|
| 423 |
+
mask = 1 - mask # repaint white, keep black
|
| 424 |
+
mask = torch.from_numpy(mask)
|
| 425 |
+
return mask
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def prepare_controlnet_image(
|
| 429 |
+
image: PIL.Image.Image,
|
| 430 |
+
width: int,
|
| 431 |
+
height: int,
|
| 432 |
+
batch_size: int,
|
| 433 |
+
num_images_per_prompt: int,
|
| 434 |
+
device: torch.device,
|
| 435 |
+
dtype: torch.dtype,
|
| 436 |
+
do_classifier_free_guidance: bool = False,
|
| 437 |
+
guess_mode: bool = False,
|
| 438 |
+
):
|
| 439 |
+
if not isinstance(image, torch.Tensor):
|
| 440 |
+
if isinstance(image, PIL.Image.Image):
|
| 441 |
+
image = [image]
|
| 442 |
+
|
| 443 |
+
if isinstance(image[0], PIL.Image.Image):
|
| 444 |
+
images = []
|
| 445 |
+
|
| 446 |
+
for image_ in image:
|
| 447 |
+
image_ = image_.convert("RGB")
|
| 448 |
+
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
| 449 |
+
image_ = np.array(image_)
|
| 450 |
+
image_ = image_[None, :]
|
| 451 |
+
images.append(image_)
|
| 452 |
+
|
| 453 |
+
image = images
|
| 454 |
+
|
| 455 |
+
image = np.concatenate(image, axis=0)
|
| 456 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 457 |
+
image = image.transpose(0, 3, 1, 2)
|
| 458 |
+
image = torch.from_numpy(image)
|
| 459 |
+
elif isinstance(image[0], torch.Tensor):
|
| 460 |
+
image = torch.cat(image, dim=0)
|
| 461 |
+
|
| 462 |
+
image_batch_size = image.shape[0]
|
| 463 |
+
|
| 464 |
+
if image_batch_size == 1:
|
| 465 |
+
repeat_by = batch_size
|
| 466 |
+
else:
|
| 467 |
+
# image batch size is the same as prompt batch size
|
| 468 |
+
repeat_by = num_images_per_prompt
|
| 469 |
+
|
| 470 |
+
image = image.repeat_interleave(repeat_by, dim=0)
|
| 471 |
+
|
| 472 |
+
image = image.to(device=device, dtype=dtype)
|
| 473 |
+
|
| 474 |
+
if do_classifier_free_guidance and not guess_mode:
|
| 475 |
+
image = torch.cat([image] * 2)
|
| 476 |
+
|
| 477 |
+
return image
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
| 481 |
+
r"""
|
| 482 |
+
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
| 483 |
+
weighting in prompt.
|
| 484 |
+
|
| 485 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 486 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
vae ([`AutoencoderKL`]):
|
| 490 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 491 |
+
text_encoder ([`CLIPTextModel`]):
|
| 492 |
+
Frozen text-encoder. Stable Diffusion uses the text portion of
|
| 493 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
| 494 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
| 495 |
+
tokenizer (`CLIPTokenizer`):
|
| 496 |
+
Tokenizer of class
|
| 497 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
| 498 |
+
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
| 499 |
+
scheduler ([`SchedulerMixin`]):
|
| 500 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
| 501 |
+
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
| 502 |
+
safety_checker ([`StableDiffusionSafetyChecker`]):
|
| 503 |
+
Classification module that estimates whether generated images could be considered offensive or harmful.
|
| 504 |
+
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
| 505 |
+
feature_extractor ([`CLIPFeatureExtractor`]):
|
| 506 |
+
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
| 507 |
+
"""
|
| 508 |
+
|
| 509 |
+
# if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
|
| 510 |
+
|
| 511 |
+
def __init__(
|
| 512 |
+
self,
|
| 513 |
+
vae: AutoencoderKL,
|
| 514 |
+
text_encoder: CLIPTextModel,
|
| 515 |
+
tokenizer: CLIPTokenizer,
|
| 516 |
+
unet: UNet2DConditionModel,
|
| 517 |
+
scheduler: SchedulerMixin,
|
| 518 |
+
# clip_skip: int,
|
| 519 |
+
safety_checker: StableDiffusionSafetyChecker,
|
| 520 |
+
feature_extractor: CLIPFeatureExtractor,
|
| 521 |
+
requires_safety_checker: bool = True,
|
| 522 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
| 523 |
+
clip_skip: int = 1,
|
| 524 |
+
):
|
| 525 |
+
super().__init__(
|
| 526 |
+
vae=vae,
|
| 527 |
+
text_encoder=text_encoder,
|
| 528 |
+
tokenizer=tokenizer,
|
| 529 |
+
unet=unet,
|
| 530 |
+
scheduler=scheduler,
|
| 531 |
+
safety_checker=safety_checker,
|
| 532 |
+
feature_extractor=feature_extractor,
|
| 533 |
+
requires_safety_checker=requires_safety_checker,
|
| 534 |
+
image_encoder=image_encoder,
|
| 535 |
+
)
|
| 536 |
+
self.custom_clip_skip = clip_skip
|
| 537 |
+
self.__init__additional__()
|
| 538 |
+
|
| 539 |
+
def __init__additional__(self):
|
| 540 |
+
if not hasattr(self, "vae_scale_factor"):
|
| 541 |
+
setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
|
| 542 |
+
|
| 543 |
+
@property
|
| 544 |
+
def _execution_device(self):
|
| 545 |
+
r"""
|
| 546 |
+
Returns the device on which the pipeline's models will be executed. After calling
|
| 547 |
+
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
|
| 548 |
+
hooks.
|
| 549 |
+
"""
|
| 550 |
+
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
| 551 |
+
return self.device
|
| 552 |
+
for module in self.unet.modules():
|
| 553 |
+
if (
|
| 554 |
+
hasattr(module, "_hf_hook")
|
| 555 |
+
and hasattr(module._hf_hook, "execution_device")
|
| 556 |
+
and module._hf_hook.execution_device is not None
|
| 557 |
+
):
|
| 558 |
+
return torch.device(module._hf_hook.execution_device)
|
| 559 |
+
return self.device
|
| 560 |
+
|
| 561 |
+
def _encode_prompt(
|
| 562 |
+
self,
|
| 563 |
+
prompt,
|
| 564 |
+
device,
|
| 565 |
+
num_images_per_prompt,
|
| 566 |
+
do_classifier_free_guidance,
|
| 567 |
+
negative_prompt,
|
| 568 |
+
max_embeddings_multiples,
|
| 569 |
+
):
|
| 570 |
+
r"""
|
| 571 |
+
Encodes the prompt into text encoder hidden states.
|
| 572 |
+
|
| 573 |
+
Args:
|
| 574 |
+
prompt (`str` or `list(int)`):
|
| 575 |
+
prompt to be encoded
|
| 576 |
+
device: (`torch.device`):
|
| 577 |
+
torch device
|
| 578 |
+
num_images_per_prompt (`int`):
|
| 579 |
+
number of images that should be generated per prompt
|
| 580 |
+
do_classifier_free_guidance (`bool`):
|
| 581 |
+
whether to use classifier free guidance or not
|
| 582 |
+
negative_prompt (`str` or `List[str]`):
|
| 583 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 584 |
+
if `guidance_scale` is less than `1`).
|
| 585 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 586 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 587 |
+
"""
|
| 588 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
| 589 |
+
|
| 590 |
+
if negative_prompt is None:
|
| 591 |
+
negative_prompt = [""] * batch_size
|
| 592 |
+
elif isinstance(negative_prompt, str):
|
| 593 |
+
negative_prompt = [negative_prompt] * batch_size
|
| 594 |
+
if batch_size != len(negative_prompt):
|
| 595 |
+
raise ValueError(
|
| 596 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 597 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 598 |
+
" the batch size of `prompt`."
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
|
| 602 |
+
pipe=self,
|
| 603 |
+
prompt=prompt,
|
| 604 |
+
uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
|
| 605 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 606 |
+
clip_skip=self.custom_clip_skip,
|
| 607 |
+
)
|
| 608 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
| 609 |
+
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
| 610 |
+
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 611 |
+
|
| 612 |
+
if do_classifier_free_guidance:
|
| 613 |
+
bs_embed, seq_len, _ = uncond_embeddings.shape
|
| 614 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
| 615 |
+
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
| 616 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
| 617 |
+
|
| 618 |
+
return text_embeddings
|
| 619 |
+
|
| 620 |
+
def check_inputs(self, prompt, height, width, strength, callback_steps):
|
| 621 |
+
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
| 622 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 623 |
+
|
| 624 |
+
if strength < 0 or strength > 1:
|
| 625 |
+
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
| 626 |
+
|
| 627 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 628 |
+
logger.info(f'{height} {width}')
|
| 629 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 630 |
+
|
| 631 |
+
if (callback_steps is None) or (
|
| 632 |
+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
| 633 |
+
):
|
| 634 |
+
raise ValueError(
|
| 635 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
|
| 636 |
+
)
|
| 637 |
+
|
| 638 |
+
def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
|
| 639 |
+
if is_text2img:
|
| 640 |
+
return self.scheduler.timesteps.to(device), num_inference_steps
|
| 641 |
+
else:
|
| 642 |
+
# get the original timestep using init_timestep
|
| 643 |
+
offset = self.scheduler.config.get("steps_offset", 0)
|
| 644 |
+
init_timestep = int(num_inference_steps * strength) + offset
|
| 645 |
+
init_timestep = min(init_timestep, num_inference_steps)
|
| 646 |
+
|
| 647 |
+
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
| 648 |
+
timesteps = self.scheduler.timesteps[t_start:].to(device)
|
| 649 |
+
return timesteps, num_inference_steps - t_start
|
| 650 |
+
|
| 651 |
+
def run_safety_checker(self, image, device, dtype):
|
| 652 |
+
if self.safety_checker is not None:
|
| 653 |
+
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
|
| 654 |
+
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
|
| 655 |
+
else:
|
| 656 |
+
has_nsfw_concept = None
|
| 657 |
+
return image, has_nsfw_concept
|
| 658 |
+
|
| 659 |
+
def decode_latents(self, latents):
|
| 660 |
+
latents = 1 / 0.18215 * latents
|
| 661 |
+
image = self.vae.decode(latents).sample
|
| 662 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 663 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 664 |
+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 665 |
+
return image
|
| 666 |
+
|
| 667 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 668 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 669 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 670 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 671 |
+
# and should be between [0, 1]
|
| 672 |
+
|
| 673 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 674 |
+
extra_step_kwargs = {}
|
| 675 |
+
if accepts_eta:
|
| 676 |
+
extra_step_kwargs["eta"] = eta
|
| 677 |
+
|
| 678 |
+
# check if the scheduler accepts generator
|
| 679 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 680 |
+
if accepts_generator:
|
| 681 |
+
extra_step_kwargs["generator"] = generator
|
| 682 |
+
return extra_step_kwargs
|
| 683 |
+
|
| 684 |
+
def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
|
| 685 |
+
if image is None:
|
| 686 |
+
shape = (
|
| 687 |
+
batch_size,
|
| 688 |
+
self.unet.in_channels,
|
| 689 |
+
height // self.vae_scale_factor,
|
| 690 |
+
width // self.vae_scale_factor,
|
| 691 |
+
)
|
| 692 |
+
|
| 693 |
+
if latents is None:
|
| 694 |
+
if device.type == "mps":
|
| 695 |
+
# randn does not work reproducibly on mps
|
| 696 |
+
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
| 697 |
+
else:
|
| 698 |
+
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
| 699 |
+
else:
|
| 700 |
+
if latents.shape != shape:
|
| 701 |
+
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
| 702 |
+
latents = latents.to(device)
|
| 703 |
+
|
| 704 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 705 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 706 |
+
return latents, None, None
|
| 707 |
+
else:
|
| 708 |
+
init_latent_dist = self.vae.encode(image).latent_dist
|
| 709 |
+
init_latents = init_latent_dist.sample(generator=generator)
|
| 710 |
+
init_latents = 0.18215 * init_latents
|
| 711 |
+
init_latents = torch.cat([init_latents] * batch_size, dim=0)
|
| 712 |
+
init_latents_orig = init_latents
|
| 713 |
+
shape = init_latents.shape
|
| 714 |
+
|
| 715 |
+
# add noise to latents using the timesteps
|
| 716 |
+
if device.type == "mps":
|
| 717 |
+
noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
| 718 |
+
else:
|
| 719 |
+
noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
| 720 |
+
latents = self.scheduler.add_noise(init_latents, noise, timestep)
|
| 721 |
+
return latents, init_latents_orig, noise
|
| 722 |
+
|
| 723 |
+
@torch.no_grad()
|
| 724 |
+
def __call__(
|
| 725 |
+
self,
|
| 726 |
+
prompt: Union[str, List[str]],
|
| 727 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 728 |
+
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
| 729 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
| 730 |
+
height: int = 512,
|
| 731 |
+
width: int = 512,
|
| 732 |
+
num_inference_steps: int = 50,
|
| 733 |
+
guidance_scale: float = 7.5,
|
| 734 |
+
strength: float = 0.8,
|
| 735 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 736 |
+
eta: float = 0.0,
|
| 737 |
+
generator: Optional[torch.Generator] = None,
|
| 738 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 739 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 740 |
+
output_type: Optional[str] = "pil",
|
| 741 |
+
return_dict: bool = True,
|
| 742 |
+
controlnet=None,
|
| 743 |
+
controlnet_image=None,
|
| 744 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 745 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 746 |
+
callback_steps: int = 1,
|
| 747 |
+
):
|
| 748 |
+
r"""
|
| 749 |
+
Function invoked when calling the pipeline for generation.
|
| 750 |
+
|
| 751 |
+
Args:
|
| 752 |
+
prompt (`str` or `List[str]`):
|
| 753 |
+
The prompt or prompts to guide the image generation.
|
| 754 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 755 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 756 |
+
if `guidance_scale` is less than `1`).
|
| 757 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 758 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 759 |
+
process.
|
| 760 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 761 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
| 762 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
| 763 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
| 764 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
| 765 |
+
height (`int`, *optional*, defaults to 512):
|
| 766 |
+
The height in pixels of the generated image.
|
| 767 |
+
width (`int`, *optional*, defaults to 512):
|
| 768 |
+
The width in pixels of the generated image.
|
| 769 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 770 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 771 |
+
expense of slower inference.
|
| 772 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 773 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 774 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 775 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 776 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 777 |
+
usually at the expense of lower image quality.
|
| 778 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 779 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
| 780 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
| 781 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
| 782 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
| 783 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
| 784 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 785 |
+
The number of images to generate per prompt.
|
| 786 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 787 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 788 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 789 |
+
generator (`torch.Generator`, *optional*):
|
| 790 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 791 |
+
deterministic.
|
| 792 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 793 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 794 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 795 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 796 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 797 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 798 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 799 |
+
The output format of the generate image. Choose between
|
| 800 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 801 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 802 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 803 |
+
plain tuple.
|
| 804 |
+
controlnet (`diffusers.ControlNetModel`, *optional*):
|
| 805 |
+
A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
|
| 806 |
+
controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
|
| 807 |
+
`Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
|
| 808 |
+
inference.
|
| 809 |
+
callback (`Callable`, *optional*):
|
| 810 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 811 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 812 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 813 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 814 |
+
`True`, the inference will be cancelled.
|
| 815 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 816 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 817 |
+
called at every step.
|
| 818 |
+
|
| 819 |
+
Returns:
|
| 820 |
+
`None` if cancelled by `is_cancelled_callback`,
|
| 821 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 822 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
| 823 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 824 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 825 |
+
(nsfw) content, according to the `safety_checker`.
|
| 826 |
+
"""
|
| 827 |
+
if controlnet is not None and controlnet_image is None:
|
| 828 |
+
raise ValueError("controlnet_image must be provided if controlnet is not None.")
|
| 829 |
+
|
| 830 |
+
# 0. Default height and width to unet
|
| 831 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 832 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 833 |
+
|
| 834 |
+
# 1. Check inputs. Raise error if not correct
|
| 835 |
+
self.check_inputs(prompt, height, width, strength, callback_steps)
|
| 836 |
+
|
| 837 |
+
# 2. Define call parameters
|
| 838 |
+
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
| 839 |
+
device = self._execution_device
|
| 840 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 841 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 842 |
+
# corresponds to doing no classifier free guidance.
|
| 843 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 844 |
+
|
| 845 |
+
# 3. Encode input prompt
|
| 846 |
+
text_embeddings = self._encode_prompt(
|
| 847 |
+
prompt,
|
| 848 |
+
device,
|
| 849 |
+
num_images_per_prompt,
|
| 850 |
+
do_classifier_free_guidance,
|
| 851 |
+
negative_prompt,
|
| 852 |
+
max_embeddings_multiples,
|
| 853 |
+
)
|
| 854 |
+
dtype = text_embeddings.dtype
|
| 855 |
+
|
| 856 |
+
# 4. Preprocess image and mask
|
| 857 |
+
if isinstance(image, PIL.Image.Image):
|
| 858 |
+
image = preprocess_image(image)
|
| 859 |
+
if image is not None:
|
| 860 |
+
image = image.to(device=self.device, dtype=dtype)
|
| 861 |
+
if isinstance(mask_image, PIL.Image.Image):
|
| 862 |
+
mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
|
| 863 |
+
if mask_image is not None:
|
| 864 |
+
mask = mask_image.to(device=self.device, dtype=dtype)
|
| 865 |
+
mask = torch.cat([mask] * batch_size * num_images_per_prompt)
|
| 866 |
+
else:
|
| 867 |
+
mask = None
|
| 868 |
+
|
| 869 |
+
if controlnet_image is not None:
|
| 870 |
+
controlnet_image = prepare_controlnet_image(
|
| 871 |
+
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
| 872 |
+
)
|
| 873 |
+
|
| 874 |
+
# 5. set timesteps
|
| 875 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 876 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
|
| 877 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
| 878 |
+
|
| 879 |
+
# 6. Prepare latent variables
|
| 880 |
+
latents, init_latents_orig, noise = self.prepare_latents(
|
| 881 |
+
image,
|
| 882 |
+
latent_timestep,
|
| 883 |
+
batch_size * num_images_per_prompt,
|
| 884 |
+
height,
|
| 885 |
+
width,
|
| 886 |
+
dtype,
|
| 887 |
+
device,
|
| 888 |
+
generator,
|
| 889 |
+
latents,
|
| 890 |
+
)
|
| 891 |
+
|
| 892 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 893 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 894 |
+
|
| 895 |
+
# 8. Denoising loop
|
| 896 |
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
| 897 |
+
# expand the latents if we are doing classifier free guidance
|
| 898 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 899 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 900 |
+
|
| 901 |
+
unet_additional_args = {}
|
| 902 |
+
if controlnet is not None:
|
| 903 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
| 904 |
+
latent_model_input,
|
| 905 |
+
t,
|
| 906 |
+
encoder_hidden_states=text_embeddings,
|
| 907 |
+
controlnet_cond=controlnet_image,
|
| 908 |
+
conditioning_scale=1.0,
|
| 909 |
+
guess_mode=False,
|
| 910 |
+
return_dict=False,
|
| 911 |
+
)
|
| 912 |
+
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
| 913 |
+
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
| 914 |
+
|
| 915 |
+
# predict the noise residual
|
| 916 |
+
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
| 917 |
+
|
| 918 |
+
# perform guidance
|
| 919 |
+
if do_classifier_free_guidance:
|
| 920 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 921 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 922 |
+
|
| 923 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 924 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
| 925 |
+
|
| 926 |
+
if mask is not None:
|
| 927 |
+
# masking
|
| 928 |
+
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
|
| 929 |
+
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
| 930 |
+
|
| 931 |
+
# call the callback, if provided
|
| 932 |
+
if i % callback_steps == 0:
|
| 933 |
+
if callback is not None:
|
| 934 |
+
callback(i, t, latents)
|
| 935 |
+
if is_cancelled_callback is not None and is_cancelled_callback():
|
| 936 |
+
return None
|
| 937 |
+
|
| 938 |
+
return latents
|
| 939 |
+
|
| 940 |
+
def latents_to_image(self, latents):
|
| 941 |
+
# 9. Post-processing
|
| 942 |
+
image = self.decode_latents(latents.to(self.vae.dtype))
|
| 943 |
+
image = self.numpy_to_pil(image)
|
| 944 |
+
return image
|
| 945 |
+
|
| 946 |
+
def text2img(
|
| 947 |
+
self,
|
| 948 |
+
prompt: Union[str, List[str]],
|
| 949 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 950 |
+
height: int = 512,
|
| 951 |
+
width: int = 512,
|
| 952 |
+
num_inference_steps: int = 50,
|
| 953 |
+
guidance_scale: float = 7.5,
|
| 954 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 955 |
+
eta: float = 0.0,
|
| 956 |
+
generator: Optional[torch.Generator] = None,
|
| 957 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 958 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 959 |
+
output_type: Optional[str] = "pil",
|
| 960 |
+
return_dict: bool = True,
|
| 961 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 962 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 963 |
+
callback_steps: int = 1,
|
| 964 |
+
):
|
| 965 |
+
r"""
|
| 966 |
+
Function for text-to-image generation.
|
| 967 |
+
Args:
|
| 968 |
+
prompt (`str` or `List[str]`):
|
| 969 |
+
The prompt or prompts to guide the image generation.
|
| 970 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 971 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 972 |
+
if `guidance_scale` is less than `1`).
|
| 973 |
+
height (`int`, *optional*, defaults to 512):
|
| 974 |
+
The height in pixels of the generated image.
|
| 975 |
+
width (`int`, *optional*, defaults to 512):
|
| 976 |
+
The width in pixels of the generated image.
|
| 977 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 978 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 979 |
+
expense of slower inference.
|
| 980 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 981 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 982 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 983 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 984 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 985 |
+
usually at the expense of lower image quality.
|
| 986 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 987 |
+
The number of images to generate per prompt.
|
| 988 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 989 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 990 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 991 |
+
generator (`torch.Generator`, *optional*):
|
| 992 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 993 |
+
deterministic.
|
| 994 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 995 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 996 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 997 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
| 998 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 999 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 1000 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1001 |
+
The output format of the generate image. Choose between
|
| 1002 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1003 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1004 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1005 |
+
plain tuple.
|
| 1006 |
+
callback (`Callable`, *optional*):
|
| 1007 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1008 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1009 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1010 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1011 |
+
`True`, the inference will be cancelled.
|
| 1012 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1013 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1014 |
+
called at every step.
|
| 1015 |
+
Returns:
|
| 1016 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1017 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
| 1018 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1019 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1020 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1021 |
+
"""
|
| 1022 |
+
return self.__call__(
|
| 1023 |
+
prompt=prompt,
|
| 1024 |
+
negative_prompt=negative_prompt,
|
| 1025 |
+
height=height,
|
| 1026 |
+
width=width,
|
| 1027 |
+
num_inference_steps=num_inference_steps,
|
| 1028 |
+
guidance_scale=guidance_scale,
|
| 1029 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1030 |
+
eta=eta,
|
| 1031 |
+
generator=generator,
|
| 1032 |
+
latents=latents,
|
| 1033 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1034 |
+
output_type=output_type,
|
| 1035 |
+
return_dict=return_dict,
|
| 1036 |
+
callback=callback,
|
| 1037 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1038 |
+
callback_steps=callback_steps,
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
def img2img(
|
| 1042 |
+
self,
|
| 1043 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1044 |
+
prompt: Union[str, List[str]],
|
| 1045 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 1046 |
+
strength: float = 0.8,
|
| 1047 |
+
num_inference_steps: Optional[int] = 50,
|
| 1048 |
+
guidance_scale: Optional[float] = 7.5,
|
| 1049 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1050 |
+
eta: Optional[float] = 0.0,
|
| 1051 |
+
generator: Optional[torch.Generator] = None,
|
| 1052 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 1053 |
+
output_type: Optional[str] = "pil",
|
| 1054 |
+
return_dict: bool = True,
|
| 1055 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1056 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1057 |
+
callback_steps: int = 1,
|
| 1058 |
+
):
|
| 1059 |
+
r"""
|
| 1060 |
+
Function for image-to-image generation.
|
| 1061 |
+
Args:
|
| 1062 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1063 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 1064 |
+
process.
|
| 1065 |
+
prompt (`str` or `List[str]`):
|
| 1066 |
+
The prompt or prompts to guide the image generation.
|
| 1067 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1068 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 1069 |
+
if `guidance_scale` is less than `1`).
|
| 1070 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 1071 |
+
Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
|
| 1072 |
+
`image` will be used as a starting point, adding more noise to it the larger the `strength`. The
|
| 1073 |
+
number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
|
| 1074 |
+
noise will be maximum and the denoising process will run for the full number of iterations specified in
|
| 1075 |
+
`num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
|
| 1076 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1077 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 1078 |
+
expense of slower inference. This parameter will be modulated by `strength`.
|
| 1079 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 1080 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 1081 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 1082 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 1083 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 1084 |
+
usually at the expense of lower image quality.
|
| 1085 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1086 |
+
The number of images to generate per prompt.
|
| 1087 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 1088 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 1089 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 1090 |
+
generator (`torch.Generator`, *optional*):
|
| 1091 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 1092 |
+
deterministic.
|
| 1093 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 1094 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 1095 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1096 |
+
The output format of the generate image. Choose between
|
| 1097 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1098 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1099 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1100 |
+
plain tuple.
|
| 1101 |
+
callback (`Callable`, *optional*):
|
| 1102 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1103 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1104 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1105 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1106 |
+
`True`, the inference will be cancelled.
|
| 1107 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1108 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1109 |
+
called at every step.
|
| 1110 |
+
Returns:
|
| 1111 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1112 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
| 1113 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1114 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1115 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1116 |
+
"""
|
| 1117 |
+
return self.__call__(
|
| 1118 |
+
prompt=prompt,
|
| 1119 |
+
negative_prompt=negative_prompt,
|
| 1120 |
+
image=image,
|
| 1121 |
+
num_inference_steps=num_inference_steps,
|
| 1122 |
+
guidance_scale=guidance_scale,
|
| 1123 |
+
strength=strength,
|
| 1124 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1125 |
+
eta=eta,
|
| 1126 |
+
generator=generator,
|
| 1127 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1128 |
+
output_type=output_type,
|
| 1129 |
+
return_dict=return_dict,
|
| 1130 |
+
callback=callback,
|
| 1131 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1132 |
+
callback_steps=callback_steps,
|
| 1133 |
+
)
|
| 1134 |
+
|
| 1135 |
+
def inpaint(
|
| 1136 |
+
self,
|
| 1137 |
+
image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1138 |
+
mask_image: Union[torch.FloatTensor, PIL.Image.Image],
|
| 1139 |
+
prompt: Union[str, List[str]],
|
| 1140 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 1141 |
+
strength: float = 0.8,
|
| 1142 |
+
num_inference_steps: Optional[int] = 50,
|
| 1143 |
+
guidance_scale: Optional[float] = 7.5,
|
| 1144 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 1145 |
+
eta: Optional[float] = 0.0,
|
| 1146 |
+
generator: Optional[torch.Generator] = None,
|
| 1147 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 1148 |
+
output_type: Optional[str] = "pil",
|
| 1149 |
+
return_dict: bool = True,
|
| 1150 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
| 1151 |
+
is_cancelled_callback: Optional[Callable[[], bool]] = None,
|
| 1152 |
+
callback_steps: int = 1,
|
| 1153 |
+
):
|
| 1154 |
+
r"""
|
| 1155 |
+
Function for inpaint.
|
| 1156 |
+
Args:
|
| 1157 |
+
image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1158 |
+
`Image`, or tensor representing an image batch, that will be used as the starting point for the
|
| 1159 |
+
process. This is the image whose masked region will be inpainted.
|
| 1160 |
+
mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
|
| 1161 |
+
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
|
| 1162 |
+
replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
|
| 1163 |
+
PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
|
| 1164 |
+
contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
|
| 1165 |
+
prompt (`str` or `List[str]`):
|
| 1166 |
+
The prompt or prompts to guide the image generation.
|
| 1167 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 1168 |
+
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
| 1169 |
+
if `guidance_scale` is less than `1`).
|
| 1170 |
+
strength (`float`, *optional*, defaults to 0.8):
|
| 1171 |
+
Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
|
| 1172 |
+
is 1, the denoising process will be run on the masked area for the full number of iterations specified
|
| 1173 |
+
in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
|
| 1174 |
+
noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
|
| 1175 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 1176 |
+
The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
|
| 1177 |
+
the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
|
| 1178 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 1179 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 1180 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 1181 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 1182 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 1183 |
+
usually at the expense of lower image quality.
|
| 1184 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 1185 |
+
The number of images to generate per prompt.
|
| 1186 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 1187 |
+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
| 1188 |
+
[`schedulers.DDIMScheduler`], will be ignored for others.
|
| 1189 |
+
generator (`torch.Generator`, *optional*):
|
| 1190 |
+
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
| 1191 |
+
deterministic.
|
| 1192 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 1193 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 1194 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 1195 |
+
The output format of the generate image. Choose between
|
| 1196 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 1197 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1198 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 1199 |
+
plain tuple.
|
| 1200 |
+
callback (`Callable`, *optional*):
|
| 1201 |
+
A function that will be called every `callback_steps` steps during inference. The function will be
|
| 1202 |
+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
| 1203 |
+
is_cancelled_callback (`Callable`, *optional*):
|
| 1204 |
+
A function that will be called every `callback_steps` steps during inference. If the function returns
|
| 1205 |
+
`True`, the inference will be cancelled.
|
| 1206 |
+
callback_steps (`int`, *optional*, defaults to 1):
|
| 1207 |
+
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
| 1208 |
+
called at every step.
|
| 1209 |
+
Returns:
|
| 1210 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
| 1211 |
+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
|
| 1212 |
+
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
| 1213 |
+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
| 1214 |
+
(nsfw) content, according to the `safety_checker`.
|
| 1215 |
+
"""
|
| 1216 |
+
return self.__call__(
|
| 1217 |
+
prompt=prompt,
|
| 1218 |
+
negative_prompt=negative_prompt,
|
| 1219 |
+
image=image,
|
| 1220 |
+
mask_image=mask_image,
|
| 1221 |
+
num_inference_steps=num_inference_steps,
|
| 1222 |
+
guidance_scale=guidance_scale,
|
| 1223 |
+
strength=strength,
|
| 1224 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 1225 |
+
eta=eta,
|
| 1226 |
+
generator=generator,
|
| 1227 |
+
max_embeddings_multiples=max_embeddings_multiples,
|
| 1228 |
+
output_type=output_type,
|
| 1229 |
+
return_dict=return_dict,
|
| 1230 |
+
callback=callback,
|
| 1231 |
+
is_cancelled_callback=is_cancelled_callback,
|
| 1232 |
+
callback_steps=callback_steps,
|
| 1233 |
+
)
|
main.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
extract factors the build is dependent on:
|
| 3 |
+
[X] compute capability
|
| 4 |
+
[ ] TODO: Q - What if we have multiple GPUs of different makes?
|
| 5 |
+
- CUDA version
|
| 6 |
+
- Software:
|
| 7 |
+
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
|
| 8 |
+
- CuBLAS-LT: full-build 8-bit optimizer
|
| 9 |
+
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
|
| 10 |
+
|
| 11 |
+
evaluation:
|
| 12 |
+
- if paths faulty, return meaningful error
|
| 13 |
+
- else:
|
| 14 |
+
- determine CUDA version
|
| 15 |
+
- determine capabilities
|
| 16 |
+
- based on that set the default path
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import ctypes
|
| 20 |
+
|
| 21 |
+
from .paths import determine_cuda_runtime_lib_path
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def check_cuda_result(cuda, result_val):
|
| 25 |
+
# 3. Check for CUDA errors
|
| 26 |
+
if result_val != 0:
|
| 27 |
+
error_str = ctypes.c_char_p()
|
| 28 |
+
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
|
| 29 |
+
print(f"CUDA exception! Error code: {error_str.value.decode()}")
|
| 30 |
+
|
| 31 |
+
def get_cuda_version(cuda, cudart_path):
|
| 32 |
+
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
|
| 33 |
+
try:
|
| 34 |
+
cudart = ctypes.CDLL(cudart_path)
|
| 35 |
+
except OSError:
|
| 36 |
+
# TODO: shouldn't we error or at least warn here?
|
| 37 |
+
print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
version = ctypes.c_int()
|
| 41 |
+
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
|
| 42 |
+
version = int(version.value)
|
| 43 |
+
major = version//1000
|
| 44 |
+
minor = (version-(major*1000))//10
|
| 45 |
+
|
| 46 |
+
if major < 11:
|
| 47 |
+
print('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
|
| 48 |
+
|
| 49 |
+
return f'{major}{minor}'
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_cuda_lib_handle():
|
| 53 |
+
# 1. find libcuda.so library (GPU driver) (/usr/lib)
|
| 54 |
+
try:
|
| 55 |
+
cuda = ctypes.CDLL("libcuda.so")
|
| 56 |
+
except OSError:
|
| 57 |
+
# TODO: shouldn't we error or at least warn here?
|
| 58 |
+
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
|
| 59 |
+
return None
|
| 60 |
+
check_cuda_result(cuda, cuda.cuInit(0))
|
| 61 |
+
|
| 62 |
+
return cuda
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_compute_capabilities(cuda):
|
| 66 |
+
"""
|
| 67 |
+
1. find libcuda.so library (GPU driver) (/usr/lib)
|
| 68 |
+
init_device -> init variables -> call function by reference
|
| 69 |
+
2. call extern C function to determine CC
|
| 70 |
+
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
| 71 |
+
3. Check for CUDA errors
|
| 72 |
+
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
|
| 73 |
+
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
nGpus = ctypes.c_int()
|
| 78 |
+
cc_major = ctypes.c_int()
|
| 79 |
+
cc_minor = ctypes.c_int()
|
| 80 |
+
|
| 81 |
+
device = ctypes.c_int()
|
| 82 |
+
|
| 83 |
+
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
|
| 84 |
+
ccs = []
|
| 85 |
+
for i in range(nGpus.value):
|
| 86 |
+
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
|
| 87 |
+
ref_major = ctypes.byref(cc_major)
|
| 88 |
+
ref_minor = ctypes.byref(cc_minor)
|
| 89 |
+
# 2. call extern C function to determine CC
|
| 90 |
+
check_cuda_result(
|
| 91 |
+
cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
|
| 92 |
+
)
|
| 93 |
+
ccs.append(f"{cc_major.value}.{cc_minor.value}")
|
| 94 |
+
|
| 95 |
+
return ccs
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
|
| 99 |
+
def get_compute_capability(cuda):
|
| 100 |
+
"""
|
| 101 |
+
Extracts the highest compute capbility from all available GPUs, as compute
|
| 102 |
+
capabilities are downwards compatible. If no GPUs are detected, it returns
|
| 103 |
+
None.
|
| 104 |
+
"""
|
| 105 |
+
ccs = get_compute_capabilities(cuda)
|
| 106 |
+
if ccs is not None:
|
| 107 |
+
# TODO: handle different compute capabilities; for now, take the max
|
| 108 |
+
return ccs[-1]
|
| 109 |
+
return None
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def evaluate_cuda_setup():
|
| 113 |
+
print('')
|
| 114 |
+
print('='*35 + 'BUG REPORT' + '='*35)
|
| 115 |
+
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
|
| 116 |
+
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
|
| 117 |
+
print('='*80)
|
| 118 |
+
return "libbitsandbytes_cuda116.dll" # $$$
|
| 119 |
+
|
| 120 |
+
binary_name = "libbitsandbytes_cpu.so"
|
| 121 |
+
#if not torch.cuda.is_available():
|
| 122 |
+
#print('No GPU detected. Loading CPU library...')
|
| 123 |
+
#return binary_name
|
| 124 |
+
|
| 125 |
+
cudart_path = determine_cuda_runtime_lib_path()
|
| 126 |
+
if cudart_path is None:
|
| 127 |
+
print(
|
| 128 |
+
"WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
|
| 129 |
+
)
|
| 130 |
+
return binary_name
|
| 131 |
+
|
| 132 |
+
print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
|
| 133 |
+
cuda = get_cuda_lib_handle()
|
| 134 |
+
cc = get_compute_capability(cuda)
|
| 135 |
+
print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
|
| 136 |
+
cuda_version_string = get_cuda_version(cuda, cudart_path)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
if cc == '':
|
| 140 |
+
print(
|
| 141 |
+
"WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
|
| 142 |
+
)
|
| 143 |
+
return binary_name
|
| 144 |
+
|
| 145 |
+
# 7.5 is the minimum CC vor cublaslt
|
| 146 |
+
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
|
| 147 |
+
|
| 148 |
+
# TODO:
|
| 149 |
+
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
|
| 150 |
+
# (2) Multiple CUDA versions installed
|
| 151 |
+
|
| 152 |
+
# we use ls -l instead of nvcc to determine the cuda version
|
| 153 |
+
# since most installations will have the libcudart.so installed, but not the compiler
|
| 154 |
+
print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
|
| 155 |
+
|
| 156 |
+
def get_binary_name():
|
| 157 |
+
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
|
| 158 |
+
bin_base_name = "libbitsandbytes_cuda"
|
| 159 |
+
if has_cublaslt:
|
| 160 |
+
return f"{bin_base_name}{cuda_version_string}.so"
|
| 161 |
+
else:
|
| 162 |
+
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
|
| 163 |
+
|
| 164 |
+
binary_name = get_binary_name()
|
| 165 |
+
|
| 166 |
+
return binary_name
|
make_captions.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from library.device_utils import init_ipex, get_preferred_device
|
| 15 |
+
init_ipex()
|
| 16 |
+
|
| 17 |
+
from torchvision import transforms
|
| 18 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 19 |
+
sys.path.append(os.path.dirname(__file__))
|
| 20 |
+
from blip.blip import blip_decoder, is_url
|
| 21 |
+
import library.train_util as train_util
|
| 22 |
+
from library.utils import setup_logging
|
| 23 |
+
setup_logging()
|
| 24 |
+
import logging
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
DEVICE = get_preferred_device()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
IMAGE_SIZE = 384
|
| 31 |
+
|
| 32 |
+
# 正方形でいいのか? という気がするがソースがそうなので
|
| 33 |
+
IMAGE_TRANSFORM = transforms.Compose(
|
| 34 |
+
[
|
| 35 |
+
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
|
| 36 |
+
transforms.ToTensor(),
|
| 37 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
| 38 |
+
]
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# 共通化したいが微妙に処理が異なる……
|
| 43 |
+
class ImageLoadingTransformDataset(torch.utils.data.Dataset):
|
| 44 |
+
def __init__(self, image_paths):
|
| 45 |
+
self.images = image_paths
|
| 46 |
+
|
| 47 |
+
def __len__(self):
|
| 48 |
+
return len(self.images)
|
| 49 |
+
|
| 50 |
+
def __getitem__(self, idx):
|
| 51 |
+
img_path = self.images[idx]
|
| 52 |
+
|
| 53 |
+
try:
|
| 54 |
+
image = Image.open(img_path).convert("RGB")
|
| 55 |
+
# convert to tensor temporarily so dataloader will accept it
|
| 56 |
+
tensor = IMAGE_TRANSFORM(image)
|
| 57 |
+
except Exception as e:
|
| 58 |
+
logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
return (tensor, img_path)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def collate_fn_remove_corrupted(batch):
|
| 65 |
+
"""Collate function that allows to remove corrupted examples in the
|
| 66 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
| 67 |
+
The 'None's in the batch are removed.
|
| 68 |
+
"""
|
| 69 |
+
# Filter out all the Nones (corrupted examples)
|
| 70 |
+
batch = list(filter(lambda x: x is not None, batch))
|
| 71 |
+
return batch
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def main(args):
|
| 75 |
+
# fix the seed for reproducibility
|
| 76 |
+
seed = args.seed # + utils.get_rank()
|
| 77 |
+
torch.manual_seed(seed)
|
| 78 |
+
np.random.seed(seed)
|
| 79 |
+
random.seed(seed)
|
| 80 |
+
|
| 81 |
+
if not os.path.exists("blip"):
|
| 82 |
+
args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
|
| 83 |
+
|
| 84 |
+
cwd = os.getcwd()
|
| 85 |
+
logger.info(f"Current Working Directory is: {cwd}")
|
| 86 |
+
os.chdir("finetune")
|
| 87 |
+
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
|
| 88 |
+
args.caption_weights = os.path.join("..", args.caption_weights)
|
| 89 |
+
|
| 90 |
+
logger.info(f"load images from {args.train_data_dir}")
|
| 91 |
+
train_data_dir_path = Path(args.train_data_dir)
|
| 92 |
+
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
| 93 |
+
logger.info(f"found {len(image_paths)} images.")
|
| 94 |
+
|
| 95 |
+
logger.info(f"loading BLIP caption: {args.caption_weights}")
|
| 96 |
+
model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
|
| 97 |
+
model.eval()
|
| 98 |
+
model = model.to(DEVICE)
|
| 99 |
+
logger.info("BLIP loaded")
|
| 100 |
+
|
| 101 |
+
# captioningする
|
| 102 |
+
def run_batch(path_imgs):
|
| 103 |
+
imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
|
| 104 |
+
|
| 105 |
+
with torch.no_grad():
|
| 106 |
+
if args.beam_search:
|
| 107 |
+
captions = model.generate(
|
| 108 |
+
imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
|
| 109 |
+
)
|
| 110 |
+
else:
|
| 111 |
+
captions = model.generate(
|
| 112 |
+
imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
for (image_path, _), caption in zip(path_imgs, captions):
|
| 116 |
+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
| 117 |
+
f.write(caption + "\n")
|
| 118 |
+
if args.debug:
|
| 119 |
+
logger.info(f'{image_path} {caption}')
|
| 120 |
+
|
| 121 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
| 122 |
+
if args.max_data_loader_n_workers is not None:
|
| 123 |
+
dataset = ImageLoadingTransformDataset(image_paths)
|
| 124 |
+
data = torch.utils.data.DataLoader(
|
| 125 |
+
dataset,
|
| 126 |
+
batch_size=args.batch_size,
|
| 127 |
+
shuffle=False,
|
| 128 |
+
num_workers=args.max_data_loader_n_workers,
|
| 129 |
+
collate_fn=collate_fn_remove_corrupted,
|
| 130 |
+
drop_last=False,
|
| 131 |
+
)
|
| 132 |
+
else:
|
| 133 |
+
data = [[(None, ip)] for ip in image_paths]
|
| 134 |
+
|
| 135 |
+
b_imgs = []
|
| 136 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
| 137 |
+
for data in data_entry:
|
| 138 |
+
if data is None:
|
| 139 |
+
continue
|
| 140 |
+
|
| 141 |
+
img_tensor, image_path = data
|
| 142 |
+
if img_tensor is None:
|
| 143 |
+
try:
|
| 144 |
+
raw_image = Image.open(image_path)
|
| 145 |
+
if raw_image.mode != "RGB":
|
| 146 |
+
raw_image = raw_image.convert("RGB")
|
| 147 |
+
img_tensor = IMAGE_TRANSFORM(raw_image)
|
| 148 |
+
except Exception as e:
|
| 149 |
+
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
| 150 |
+
continue
|
| 151 |
+
|
| 152 |
+
b_imgs.append((image_path, img_tensor))
|
| 153 |
+
if len(b_imgs) >= args.batch_size:
|
| 154 |
+
run_batch(b_imgs)
|
| 155 |
+
b_imgs.clear()
|
| 156 |
+
if len(b_imgs) > 0:
|
| 157 |
+
run_batch(b_imgs)
|
| 158 |
+
|
| 159 |
+
logger.info("done!")
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 163 |
+
parser = argparse.ArgumentParser()
|
| 164 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--caption_weights",
|
| 167 |
+
type=str,
|
| 168 |
+
default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
|
| 169 |
+
help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
|
| 170 |
+
)
|
| 171 |
+
parser.add_argument(
|
| 172 |
+
"--caption_extention",
|
| 173 |
+
type=str,
|
| 174 |
+
default=None,
|
| 175 |
+
help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
|
| 176 |
+
)
|
| 177 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"--beam_search",
|
| 180 |
+
action="store_true",
|
| 181 |
+
help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)",
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
| 184 |
+
parser.add_argument(
|
| 185 |
+
"--max_data_loader_n_workers",
|
| 186 |
+
type=int,
|
| 187 |
+
default=None,
|
| 188 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
| 189 |
+
)
|
| 190 |
+
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
|
| 191 |
+
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
| 192 |
+
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
| 193 |
+
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
| 194 |
+
parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
|
| 195 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
| 196 |
+
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
| 197 |
+
|
| 198 |
+
return parser
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
if __name__ == "__main__":
|
| 202 |
+
parser = setup_parser()
|
| 203 |
+
|
| 204 |
+
args = parser.parse_args()
|
| 205 |
+
|
| 206 |
+
# スペルミスしていたオプションを復元する
|
| 207 |
+
if args.caption_extention is not None:
|
| 208 |
+
args.caption_extension = args.caption_extention
|
| 209 |
+
|
| 210 |
+
main(args)
|
make_captions_by_git.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from library.device_utils import init_ipex, get_preferred_device
|
| 11 |
+
init_ipex()
|
| 12 |
+
|
| 13 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 14 |
+
from transformers.generation.utils import GenerationMixin
|
| 15 |
+
|
| 16 |
+
import library.train_util as train_util
|
| 17 |
+
from library.utils import setup_logging
|
| 18 |
+
setup_logging()
|
| 19 |
+
import logging
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
|
| 24 |
+
PATTERN_REPLACE = [
|
| 25 |
+
re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
|
| 26 |
+
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
|
| 27 |
+
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
|
| 28 |
+
re.compile(r"with the number \d+ on (it|\w+ \w+)"),
|
| 29 |
+
re.compile(r'with the words "'),
|
| 30 |
+
re.compile(r"word \w+ on it"),
|
| 31 |
+
re.compile(r"that says the word \w+ on it"),
|
| 32 |
+
re.compile("that says'the word \"( on it)?"),
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
# 誤検知しまくりの with the word xxxx を消す
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def remove_words(captions, debug):
|
| 39 |
+
removed_caps = []
|
| 40 |
+
for caption in captions:
|
| 41 |
+
cap = caption
|
| 42 |
+
for pat in PATTERN_REPLACE:
|
| 43 |
+
cap = pat.sub("", cap)
|
| 44 |
+
if debug and cap != caption:
|
| 45 |
+
logger.info(caption)
|
| 46 |
+
logger.info(cap)
|
| 47 |
+
removed_caps.append(cap)
|
| 48 |
+
return removed_caps
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def collate_fn_remove_corrupted(batch):
|
| 52 |
+
"""Collate function that allows to remove corrupted examples in the
|
| 53 |
+
dataloader. It expects that the dataloader returns 'None' when that occurs.
|
| 54 |
+
The 'None's in the batch are removed.
|
| 55 |
+
"""
|
| 56 |
+
# Filter out all the Nones (corrupted examples)
|
| 57 |
+
batch = list(filter(lambda x: x is not None, batch))
|
| 58 |
+
return batch
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def main(args):
|
| 62 |
+
r"""
|
| 63 |
+
transformers 4.30.2で、バッチサイズ>1でも動くようになったので、以下コメントアウト
|
| 64 |
+
|
| 65 |
+
# GITにバッチサイズが1より大きくても動くようにパッチを当てる: transformers 4.26.0用
|
| 66 |
+
org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
|
| 67 |
+
curr_batch_size = [args.batch_size] # ループの最後で件数がbatch_size未満になるので入れ替えられるように
|
| 68 |
+
|
| 69 |
+
# input_idsがバッチサイズと同じ件数である必要がある:バッチサイズはこの関数から参照できないので外から渡す
|
| 70 |
+
# ここより上で置き換えようとするとすごく大変
|
| 71 |
+
def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
|
| 72 |
+
input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
|
| 73 |
+
if input_ids.size()[0] != curr_batch_size[0]:
|
| 74 |
+
input_ids = input_ids.repeat(curr_batch_size[0], 1)
|
| 75 |
+
return input_ids
|
| 76 |
+
|
| 77 |
+
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
logger.info(f"load images from {args.train_data_dir}")
|
| 81 |
+
train_data_dir_path = Path(args.train_data_dir)
|
| 82 |
+
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
| 83 |
+
logger.info(f"found {len(image_paths)} images.")
|
| 84 |
+
|
| 85 |
+
# できればcacheに依存せず明示的にダウンロードしたい
|
| 86 |
+
logger.info(f"loading GIT: {args.model_id}")
|
| 87 |
+
git_processor = AutoProcessor.from_pretrained(args.model_id)
|
| 88 |
+
git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
|
| 89 |
+
logger.info("GIT loaded")
|
| 90 |
+
|
| 91 |
+
# captioningする
|
| 92 |
+
def run_batch(path_imgs):
|
| 93 |
+
imgs = [im for _, im in path_imgs]
|
| 94 |
+
|
| 95 |
+
# curr_batch_size[0] = len(path_imgs)
|
| 96 |
+
inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # 画像はpil形式
|
| 97 |
+
generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
|
| 98 |
+
captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 99 |
+
|
| 100 |
+
if args.remove_words:
|
| 101 |
+
captions = remove_words(captions, args.debug)
|
| 102 |
+
|
| 103 |
+
for (image_path, _), caption in zip(path_imgs, captions):
|
| 104 |
+
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
| 105 |
+
f.write(caption + "\n")
|
| 106 |
+
if args.debug:
|
| 107 |
+
logger.info(f"{image_path} {caption}")
|
| 108 |
+
|
| 109 |
+
# 読み込みの高速化のためにDataLoaderを使うオプション
|
| 110 |
+
if args.max_data_loader_n_workers is not None:
|
| 111 |
+
dataset = train_util.ImageLoadingDataset(image_paths)
|
| 112 |
+
data = torch.utils.data.DataLoader(
|
| 113 |
+
dataset,
|
| 114 |
+
batch_size=args.batch_size,
|
| 115 |
+
shuffle=False,
|
| 116 |
+
num_workers=args.max_data_loader_n_workers,
|
| 117 |
+
collate_fn=collate_fn_remove_corrupted,
|
| 118 |
+
drop_last=False,
|
| 119 |
+
)
|
| 120 |
+
else:
|
| 121 |
+
data = [[(None, ip)] for ip in image_paths]
|
| 122 |
+
|
| 123 |
+
b_imgs = []
|
| 124 |
+
for data_entry in tqdm(data, smoothing=0.0):
|
| 125 |
+
for data in data_entry:
|
| 126 |
+
if data is None:
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
image, image_path = data
|
| 130 |
+
if image is None:
|
| 131 |
+
try:
|
| 132 |
+
image = Image.open(image_path)
|
| 133 |
+
if image.mode != "RGB":
|
| 134 |
+
image = image.convert("RGB")
|
| 135 |
+
except Exception as e:
|
| 136 |
+
logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
|
| 137 |
+
continue
|
| 138 |
+
|
| 139 |
+
b_imgs.append((image_path, image))
|
| 140 |
+
if len(b_imgs) >= args.batch_size:
|
| 141 |
+
run_batch(b_imgs)
|
| 142 |
+
b_imgs.clear()
|
| 143 |
+
|
| 144 |
+
if len(b_imgs) > 0:
|
| 145 |
+
run_batch(b_imgs)
|
| 146 |
+
|
| 147 |
+
logger.info("done!")
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def setup_parser() -> argparse.ArgumentParser:
|
| 151 |
+
parser = argparse.ArgumentParser()
|
| 152 |
+
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
| 153 |
+
parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--model_id",
|
| 156 |
+
type=str,
|
| 157 |
+
default="microsoft/git-large-textcaps",
|
| 158 |
+
help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID",
|
| 159 |
+
)
|
| 160 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
|
| 161 |
+
parser.add_argument(
|
| 162 |
+
"--max_data_loader_n_workers",
|
| 163 |
+
type=int,
|
| 164 |
+
default=None,
|
| 165 |
+
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
|
| 166 |
+
)
|
| 167 |
+
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--remove_words",
|
| 170 |
+
action="store_true",
|
| 171 |
+
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する",
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument("--debug", action="store_true", help="debug mode")
|
| 174 |
+
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
|
| 175 |
+
|
| 176 |
+
return parser
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if __name__ == "__main__":
|
| 180 |
+
parser = setup_parser()
|
| 181 |
+
|
| 182 |
+
args = parser.parse_args()
|
| 183 |
+
main(args)
|
masked_loss_README-ja.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## マスクロスについて
|
| 2 |
+
|
| 3 |
+
マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。
|
| 4 |
+
たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。
|
| 5 |
+
|
| 6 |
+
マスクロスのマスクには、二種類の指定方法があります。
|
| 7 |
+
|
| 8 |
+
- マスク画像を用いる方法
|
| 9 |
+
- 透明度(アルファチャネル)を使用する方法
|
| 10 |
+
|
| 11 |
+
なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。
|
| 12 |
+
|
| 13 |
+
### マスク画像を用いる方法
|
| 14 |
+
|
| 15 |
+
学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。
|
| 16 |
+
|
| 17 |
+
- 学習画像
|
| 18 |
+

|
| 19 |
+
- マスク画像
|
| 20 |
+

|
| 21 |
+
|
| 22 |
+
```.toml
|
| 23 |
+
[[datasets.subsets]]
|
| 24 |
+
image_dir = "/path/to/a_zundamon"
|
| 25 |
+
caption_extension = ".txt"
|
| 26 |
+
conditioning_data_dir = "/path/to/a_zundamon_mask"
|
| 27 |
+
num_repeats = 8
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。
|
| 31 |
+
|
| 32 |
+
DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。
|
| 33 |
+
|
| 34 |
+
### 透明度(アルファチャネル)を使用する方法
|
| 35 |
+
|
| 36 |
+
学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。
|
| 37 |
+
|
| 38 |
+

|
| 39 |
+
|
| 40 |
+
※それぞれの画像は透過PNG
|
| 41 |
+
|
| 42 |
+
学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。
|
| 43 |
+
|
| 44 |
+
```toml
|
| 45 |
+
[[datasets.subsets]]
|
| 46 |
+
image_dir = "/path/to/image/dir"
|
| 47 |
+
caption_extension = ".txt"
|
| 48 |
+
num_repeats = 8
|
| 49 |
+
alpha_mask = true
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## 学習時の注意事項
|
| 53 |
+
|
| 54 |
+
- 現時点では DreamBooth 方式の dataset のみ対応しています。
|
| 55 |
+
- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。
|
| 56 |
+
- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証)
|
| 57 |
+
- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。
|
masked_loss_README.md
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Masked Loss
|
| 2 |
+
|
| 3 |
+
Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background.
|
| 4 |
+
|
| 5 |
+
There are two ways to specify the mask for masked loss.
|
| 6 |
+
|
| 7 |
+
- Using a mask image
|
| 8 |
+
- Using transparency (alpha channel) of the image
|
| 9 |
+
|
| 10 |
+
The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html).
|
| 11 |
+
|
| 12 |
+
### Using a mask image
|
| 13 |
+
|
| 14 |
+
This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image.
|
| 15 |
+
|
| 16 |
+
- Training image
|
| 17 |
+

|
| 18 |
+
- Mask image
|
| 19 |
+

|
| 20 |
+
|
| 21 |
+
```.toml
|
| 22 |
+
[[datasets.subsets]]
|
| 23 |
+
image_dir = "/path/to/a_zundamon"
|
| 24 |
+
caption_extension = ".txt"
|
| 25 |
+
conditioning_data_dir = "/path/to/a_zundamon_mask"
|
| 26 |
+
num_repeats = 8
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. It also supports grayscale (127 gives a loss weight of 0.5). The R channel of the mask image is used currently.
|
| 30 |
+
|
| 31 |
+
Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details.
|
| 32 |
+
|
| 33 |
+
### Using transparency (alpha channel) of the image
|
| 34 |
+
|
| 35 |
+
The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5).
|
| 36 |
+
|
| 37 |
+

|
| 38 |
+
|
| 39 |
+
※Each image is a transparent PNG
|
| 40 |
+
|
| 41 |
+
Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this.
|
| 42 |
+
|
| 43 |
+
```toml
|
| 44 |
+
[[datasets.subsets]]
|
| 45 |
+
image_dir = "/path/to/image/dir"
|
| 46 |
+
caption_extension = ".txt"
|
| 47 |
+
num_repeats = 8
|
| 48 |
+
alpha_mask = true
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Notes on training
|
| 52 |
+
|
| 53 |
+
- At the moment, only the dataset in the DreamBooth method is supported.
|
| 54 |
+
- The mask is applied after the size is reduced to 1/8, which is the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Some dilations of the mask may be necessary.
|
| 55 |
+
- If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified)
|
| 56 |
+
- In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched.
|