Allex21 committed
Commit c11c83e · verified · 1 Parent(s): eac965b

Upload 89 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ libbitsandbytes_cuda116.dll filter=lfs diff=lfs merge=lfs -text
+ libbitsandbytes_cuda118.dll filter=lfs diff=lfs merge=lfs -text
FUNDING.yml ADDED
@@ -0,0 +1,3 @@
+ # These are supported funding model platforms
+
+ github: kohya-ss
__init__.py ADDED
File without changes
adafactor_fused.py ADDED
@@ -0,0 +1,106 @@
1
+ import math
2
+ import torch
3
+ from transformers import Adafactor
4
+
5
+ @torch.no_grad()
6
+ def adafactor_step_param(self, p, group):
7
+ if p.grad is None:
8
+ return
9
+ grad = p.grad
10
+ if grad.dtype in {torch.float16, torch.bfloat16}:
11
+ grad = grad.float()
12
+ if grad.is_sparse:
13
+ raise RuntimeError("Adafactor does not support sparse gradients.")
14
+
15
+ state = self.state[p]
16
+ grad_shape = grad.shape
17
+
18
+ factored, use_first_moment = Adafactor._get_options(group, grad_shape)
19
+ # State Initialization
20
+ if len(state) == 0:
21
+ state["step"] = 0
22
+
23
+ if use_first_moment:
24
+ # Exponential moving average of gradient values
25
+ state["exp_avg"] = torch.zeros_like(grad)
26
+ if factored:
27
+ state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
28
+ state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
29
+ else:
30
+ state["exp_avg_sq"] = torch.zeros_like(grad)
31
+
32
+ state["RMS"] = 0
33
+ else:
34
+ if use_first_moment:
35
+ state["exp_avg"] = state["exp_avg"].to(grad)
36
+ if factored:
37
+ state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
38
+ state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
39
+ else:
40
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
41
+
42
+ p_data_fp32 = p
43
+ if p.dtype in {torch.float16, torch.bfloat16}:
44
+ p_data_fp32 = p_data_fp32.float()
45
+
46
+ state["step"] += 1
47
+ state["RMS"] = Adafactor._rms(p_data_fp32)
48
+ lr = Adafactor._get_lr(group, state)
49
+
50
+ beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
51
+ update = (grad ** 2) + group["eps"][0]
52
+ if factored:
53
+ exp_avg_sq_row = state["exp_avg_sq_row"]
54
+ exp_avg_sq_col = state["exp_avg_sq_col"]
55
+
56
+ exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
57
+ exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
58
+
59
+ # Approximation of exponential moving average of square of gradient
60
+ update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
61
+ update.mul_(grad)
62
+ else:
63
+ exp_avg_sq = state["exp_avg_sq"]
64
+
65
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
66
+ update = exp_avg_sq.rsqrt().mul_(grad)
67
+
68
+ update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
69
+ update.mul_(lr)
70
+
71
+ if use_first_moment:
72
+ exp_avg = state["exp_avg"]
73
+ exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
74
+ update = exp_avg
75
+
76
+ if group["weight_decay"] != 0:
77
+ p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
78
+
79
+ p_data_fp32.add_(-update)
80
+
81
+ if p.dtype in {torch.float16, torch.bfloat16}:
82
+ p.copy_(p_data_fp32)
83
+
84
+
85
+ @torch.no_grad()
86
+ def adafactor_step(self, closure=None):
87
+ """
88
+ Performs a single optimization step
89
+
90
+ Arguments:
91
+ closure (callable, optional): A closure that reevaluates the model
92
+ and returns the loss.
93
+ """
94
+ loss = None
95
+ if closure is not None:
96
+ loss = closure()
97
+
98
+ for group in self.param_groups:
99
+ for p in group["params"]:
100
+ adafactor_step_param(self, p, group)
101
+
102
+ return loss
103
+
104
+ def patch_adafactor_fused(optimizer: Adafactor):
105
+ optimizer.step_param = adafactor_step_param.__get__(optimizer)
106
+ optimizer.step = adafactor_step.__get__(optimizer)
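
A minimal usage sketch of the fused stepping pattern this file enables (illustrative only, not part of the commit; the import path and hyperparameters are assumptions):

```python
# Hypothetical sketch: patch an existing transformers Adafactor instance so the
# update can be applied one parameter at a time (e.g. from a fused backward hook).
import torch
from transformers import Adafactor
from adafactor_fused import patch_adafactor_fused  # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = Adafactor(model.parameters(), lr=1e-3, relative_step=False, scale_parameter=False)
patch_adafactor_fused(optimizer)  # binds step_param()/step() onto this instance

loss = model(torch.randn(4, 8)).sum()
loss.backward()
for group in optimizer.param_groups:
    for p in group["params"]:
        optimizer.step_param(p, group)  # apply the Adafactor update for this parameter only
optimizer.zero_grad()
```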
attention.py ADDED
@@ -0,0 +1,119 @@
1
+ import os
2
+ import torch
3
+ from functools import cache, wraps
4
+
5
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
6
+
7
+ # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
8
+
9
+ sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
10
+ attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
11
+
12
+ # Find something divisible with the input_tokens
13
+ @cache
14
+ def find_split_size(original_size, slice_block_size, slice_rate=2):
15
+ split_size = original_size
16
+ while True:
17
+ if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
18
+ return split_size
19
+ split_size = split_size - 1
20
+ if split_size <= 1:
21
+ return 1
22
+ return split_size
23
+
24
+
25
+ # Find slice sizes for SDPA
26
+ @cache
27
+ def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
28
+ batch_size, attn_heads, query_len, _ = query_shape
29
+ _, _, key_len, _ = key_shape
30
+
31
+ slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
32
+
33
+ split_batch_size = batch_size
34
+ split_head_size = attn_heads
35
+ split_query_size = query_len
36
+
37
+ do_batch_split = False
38
+ do_head_split = False
39
+ do_query_split = False
40
+
41
+ if batch_size * slice_batch_size >= trigger_rate:
42
+ do_batch_split = True
43
+ split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
44
+
45
+ if split_batch_size * slice_batch_size > slice_rate:
46
+ slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
47
+ do_head_split = True
48
+ split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
49
+
50
+ if split_head_size * slice_head_size > slice_rate:
51
+ slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
52
+ do_query_split = True
53
+ split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
54
+
55
+ return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
56
+
57
+
58
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
59
+ @wraps(torch.nn.functional.scaled_dot_product_attention)
60
+ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
61
+ if query.device.type != "xpu":
62
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
63
+ is_unsqueezed = False
64
+ if len(query.shape) == 3:
65
+ query = query.unsqueeze(0)
66
+ is_unsqueezed = True
67
+ if len(key.shape) == 3:
68
+ key = key.unsqueeze(0)
69
+ if len(value.shape) == 3:
70
+ value = value.unsqueeze(0)
71
+ do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
72
+
73
+ # Slice SDPA
74
+ if do_batch_split:
75
+ batch_size, attn_heads, query_len, _ = query.shape
76
+ _, _, _, head_dim = value.shape
77
+ hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
78
+ if attn_mask is not None:
79
+ attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
80
+ for ib in range(batch_size // split_batch_size):
81
+ start_idx = ib * split_batch_size
82
+ end_idx = (ib + 1) * split_batch_size
83
+ if do_head_split:
84
+ for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
85
+ start_idx_h = ih * split_head_size
86
+ end_idx_h = (ih + 1) * split_head_size
87
+ if do_query_split:
88
+ for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
89
+ start_idx_q = iq * split_query_size
90
+ end_idx_q = (iq + 1) * split_query_size
91
+ hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
92
+ query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
93
+ key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
94
+ value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
95
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
96
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
97
+ )
98
+ else:
99
+ hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
100
+ query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
101
+ key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
102
+ value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
103
+ attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
104
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
105
+ )
106
+ else:
107
+ hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
108
+ query[start_idx:end_idx, :, :, :],
109
+ key[start_idx:end_idx, :, :, :],
110
+ value[start_idx:end_idx, :, :, :],
111
+ attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
112
+ dropout_p=dropout_p, is_causal=is_causal, **kwargs
113
+ )
114
+ torch.xpu.synchronize(query.device)
115
+ else:
116
+ hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
117
+ if is_unsqueezed:
118
+ hidden_states = hidden_states.squeeze(0)
119
+ return hidden_states
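
A hedged sketch of how this wrapper is typically applied (illustrative, not part of the commit; the import path is an assumption): torch's SDPA entry point is replaced so XPU tensors go through the sliced path while other devices fall through unchanged.

```python
# Illustrative monkey-patch; on non-XPU devices the wrapper simply defers to
# the original torch.nn.functional.scaled_dot_product_attention.
import torch
from attention import dynamic_scaled_dot_product_attention  # assumed import path

torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention

q = torch.randn(1, 8, 64, 40)
k = torch.randn(1, 8, 64, 40)
v = torch.randn(1, 8, 64, 40)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 64, 40])
```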
attention_processors.py ADDED
@@ -0,0 +1,227 @@
1
+ import math
2
+ from typing import Any
3
+ from einops import rearrange
4
+ import torch
5
+ from diffusers.models.attention_processor import Attention
6
+
7
+
8
+ # flash attention forwards and backwards
9
+
10
+ # https://arxiv.org/abs/2205.14135
11
+
12
+ EPSILON = 1e-6
13
+
14
+
15
+ class FlashAttentionFunction(torch.autograd.function.Function):
16
+ @staticmethod
17
+ @torch.no_grad()
18
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
19
+ """Algorithm 2 in the paper"""
20
+
21
+ device = q.device
22
+ dtype = q.dtype
23
+ max_neg_value = -torch.finfo(q.dtype).max
24
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
25
+
26
+ o = torch.zeros_like(q)
27
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
28
+ all_row_maxes = torch.full(
29
+ (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device
30
+ )
31
+
32
+ scale = q.shape[-1] ** -0.5
33
+
34
+ if mask is None:
35
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
36
+ else:
37
+ mask = rearrange(mask, "b n -> b 1 1 n")
38
+ mask = mask.split(q_bucket_size, dim=-1)
39
+
40
+ row_splits = zip(
41
+ q.split(q_bucket_size, dim=-2),
42
+ o.split(q_bucket_size, dim=-2),
43
+ mask,
44
+ all_row_sums.split(q_bucket_size, dim=-2),
45
+ all_row_maxes.split(q_bucket_size, dim=-2),
46
+ )
47
+
48
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
49
+ q_start_index = ind * q_bucket_size - qk_len_diff
50
+
51
+ col_splits = zip(
52
+ k.split(k_bucket_size, dim=-2),
53
+ v.split(k_bucket_size, dim=-2),
54
+ )
55
+
56
+ for k_ind, (kc, vc) in enumerate(col_splits):
57
+ k_start_index = k_ind * k_bucket_size
58
+
59
+ attn_weights = (
60
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
61
+ )
62
+
63
+ if row_mask is not None:
64
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
65
+
66
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
67
+ causal_mask = torch.ones(
68
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
69
+ ).triu(q_start_index - k_start_index + 1)
70
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
71
+
72
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
73
+ attn_weights -= block_row_maxes
74
+ exp_weights = torch.exp(attn_weights)
75
+
76
+ if row_mask is not None:
77
+ exp_weights.masked_fill_(~row_mask, 0.0)
78
+
79
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(
80
+ min=EPSILON
81
+ )
82
+
83
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
84
+
85
+ exp_values = torch.einsum(
86
+ "... i j, ... j d -> ... i d", exp_weights, vc
87
+ )
88
+
89
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
90
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
91
+
92
+ new_row_sums = (
93
+ exp_row_max_diff * row_sums
94
+ + exp_block_row_max_diff * block_row_sums
95
+ )
96
+
97
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_(
98
+ (exp_block_row_max_diff / new_row_sums) * exp_values
99
+ )
100
+
101
+ row_maxes.copy_(new_row_maxes)
102
+ row_sums.copy_(new_row_sums)
103
+
104
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
105
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
106
+
107
+ return o
108
+
109
+ @staticmethod
110
+ @torch.no_grad()
111
+ def backward(ctx, do):
112
+ """Algorithm 4 in the paper"""
113
+
114
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
115
+ q, k, v, o, l, m = ctx.saved_tensors
116
+
117
+ device = q.device
118
+
119
+ max_neg_value = -torch.finfo(q.dtype).max
120
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
121
+
122
+ dq = torch.zeros_like(q)
123
+ dk = torch.zeros_like(k)
124
+ dv = torch.zeros_like(v)
125
+
126
+ row_splits = zip(
127
+ q.split(q_bucket_size, dim=-2),
128
+ o.split(q_bucket_size, dim=-2),
129
+ do.split(q_bucket_size, dim=-2),
130
+ mask,
131
+ l.split(q_bucket_size, dim=-2),
132
+ m.split(q_bucket_size, dim=-2),
133
+ dq.split(q_bucket_size, dim=-2),
134
+ )
135
+
136
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
137
+ q_start_index = ind * q_bucket_size - qk_len_diff
138
+
139
+ col_splits = zip(
140
+ k.split(k_bucket_size, dim=-2),
141
+ v.split(k_bucket_size, dim=-2),
142
+ dk.split(k_bucket_size, dim=-2),
143
+ dv.split(k_bucket_size, dim=-2),
144
+ )
145
+
146
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
147
+ k_start_index = k_ind * k_bucket_size
148
+
149
+ attn_weights = (
150
+ torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale
151
+ )
152
+
153
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
154
+ causal_mask = torch.ones(
155
+ (qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device
156
+ ).triu(q_start_index - k_start_index + 1)
157
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
158
+
159
+ exp_attn_weights = torch.exp(attn_weights - mc)
160
+
161
+ if row_mask is not None:
162
+ exp_attn_weights.masked_fill_(~row_mask, 0.0)
163
+
164
+ p = exp_attn_weights / lc
165
+
166
+ dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc)
167
+ dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc)
168
+
169
+ D = (doc * oc).sum(dim=-1, keepdims=True)
170
+ ds = p * scale * (dp - D)
171
+
172
+ dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc)
173
+ dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc)
174
+
175
+ dqc.add_(dq_chunk)
176
+ dkc.add_(dk_chunk)
177
+ dvc.add_(dv_chunk)
178
+
179
+ return dq, dk, dv, None, None, None, None
180
+
181
+
182
+ class FlashAttnProcessor:
183
+ def __call__(
184
+ self,
185
+ attn: Attention,
186
+ hidden_states,
187
+ encoder_hidden_states=None,
188
+ attention_mask=None,
189
+ ) -> Any:
190
+ q_bucket_size = 512
191
+ k_bucket_size = 1024
192
+
193
+ h = attn.heads
194
+ q = attn.to_q(hidden_states)
195
+
196
+ encoder_hidden_states = (
197
+ encoder_hidden_states
198
+ if encoder_hidden_states is not None
199
+ else hidden_states
200
+ )
201
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
202
+
203
+ if hasattr(attn, "hypernetwork") and attn.hypernetwork is not None:
204
+ context_k, context_v = attn.hypernetwork.forward(
205
+ hidden_states, encoder_hidden_states
206
+ )
207
+ context_k = context_k.to(hidden_states.dtype)
208
+ context_v = context_v.to(hidden_states.dtype)
209
+ else:
210
+ context_k = encoder_hidden_states
211
+ context_v = encoder_hidden_states
212
+
213
+ k = attn.to_k(context_k)
214
+ v = attn.to_v(context_v)
215
+ del encoder_hidden_states, hidden_states
216
+
217
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
218
+
219
+ out = FlashAttentionFunction.apply(
220
+ q, k, v, attention_mask, False, q_bucket_size, k_bucket_size
221
+ )
222
+
223
+ out = rearrange(out, "b h n d -> b n (h d)")
224
+
225
+ out = attn.to_out[0](out)
226
+ out = attn.to_out[1](out)
227
+ return out
blip.py ADDED
@@ -0,0 +1,245 @@
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ '''
8
+ import warnings
9
+ warnings.filterwarnings("ignore")
10
+
11
+ # from models.vit import VisionTransformer, interpolate_pos_embed
12
+ # from models.med import BertConfig, BertModel, BertLMHeadModel
13
+ from blip.vit import VisionTransformer, interpolate_pos_embed
14
+ from blip.med import BertConfig, BertModel, BertLMHeadModel
15
+ from transformers import BertTokenizer
16
+
17
+ import torch
18
+ from torch import nn
19
+ import torch.nn.functional as F
20
+
21
+ import os
22
+ from urllib.parse import urlparse
23
+ from timm.models.hub import download_cached_file
24
+ from library.utils import setup_logging
25
+ setup_logging()
26
+ import logging
27
+ logger = logging.getLogger(__name__)
28
+
29
+ class BLIP_Base(nn.Module):
30
+ def __init__(self,
31
+ med_config = 'configs/med_config.json',
32
+ image_size = 224,
33
+ vit = 'base',
34
+ vit_grad_ckpt = False,
35
+ vit_ckpt_layer = 0,
36
+ ):
37
+ """
38
+ Args:
39
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
40
+ image_size (int): input image size
41
+ vit (str): model size of vision transformer
42
+ """
43
+ super().__init__()
44
+
45
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
46
+ self.tokenizer = init_tokenizer()
47
+ med_config = BertConfig.from_json_file(med_config)
48
+ med_config.encoder_width = vision_width
49
+ self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
50
+
51
+
52
+ def forward(self, image, caption, mode):
53
+
54
+ assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal"
55
+ text = self.tokenizer(caption, return_tensors="pt").to(image.device)
56
+
57
+ if mode=='image':
58
+ # return image features
59
+ image_embeds = self.visual_encoder(image)
60
+ return image_embeds
61
+
62
+ elif mode=='text':
63
+ # return text features
64
+ text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
65
+ return_dict = True, mode = 'text')
66
+ return text_output.last_hidden_state
67
+
68
+ elif mode=='multimodal':
69
+ # return multimodel features
70
+ image_embeds = self.visual_encoder(image)
71
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
72
+
73
+ text.input_ids[:,0] = self.tokenizer.enc_token_id
74
+ output = self.text_encoder(text.input_ids,
75
+ attention_mask = text.attention_mask,
76
+ encoder_hidden_states = image_embeds,
77
+ encoder_attention_mask = image_atts,
78
+ return_dict = True,
79
+ )
80
+ return output.last_hidden_state
81
+
82
+
83
+
84
+ class BLIP_Decoder(nn.Module):
85
+ def __init__(self,
86
+ med_config = 'configs/med_config.json',
87
+ image_size = 384,
88
+ vit = 'base',
89
+ vit_grad_ckpt = False,
90
+ vit_ckpt_layer = 0,
91
+ prompt = 'a picture of ',
92
+ ):
93
+ """
94
+ Args:
95
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
96
+ image_size (int): input image size
97
+ vit (str): model size of vision transformer
98
+ """
99
+ super().__init__()
100
+
101
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
102
+ self.tokenizer = init_tokenizer()
103
+ med_config = BertConfig.from_json_file(med_config)
104
+ med_config.encoder_width = vision_width
105
+ self.text_decoder = BertLMHeadModel(config=med_config)
106
+
107
+ self.prompt = prompt
108
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
109
+
110
+
111
+ def forward(self, image, caption):
112
+
113
+ image_embeds = self.visual_encoder(image)
114
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
115
+
116
+ text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device)
117
+
118
+ text.input_ids[:,0] = self.tokenizer.bos_token_id
119
+
120
+ decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100)
121
+ decoder_targets[:,:self.prompt_length] = -100
122
+
123
+ decoder_output = self.text_decoder(text.input_ids,
124
+ attention_mask = text.attention_mask,
125
+ encoder_hidden_states = image_embeds,
126
+ encoder_attention_mask = image_atts,
127
+ labels = decoder_targets,
128
+ return_dict = True,
129
+ )
130
+ loss_lm = decoder_output.loss
131
+
132
+ return loss_lm
133
+
134
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0):
135
+ image_embeds = self.visual_encoder(image)
136
+
137
+ # recent version of transformers seems to do repeat_interleave automatically
138
+ # if not sample:
139
+ # image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
140
+
141
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
142
+ model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts}
143
+
144
+ prompt = [self.prompt] * image.size(0)
145
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
146
+ input_ids[:,0] = self.tokenizer.bos_token_id
147
+ input_ids = input_ids[:, :-1]
148
+
149
+ if sample:
150
+ #nucleus sampling
151
+ outputs = self.text_decoder.generate(input_ids=input_ids,
152
+ max_length=max_length,
153
+ min_length=min_length,
154
+ do_sample=True,
155
+ top_p=top_p,
156
+ num_return_sequences=1,
157
+ eos_token_id=self.tokenizer.sep_token_id,
158
+ pad_token_id=self.tokenizer.pad_token_id,
159
+ repetition_penalty=1.1,
160
+ **model_kwargs)
161
+ else:
162
+ #beam search
163
+ outputs = self.text_decoder.generate(input_ids=input_ids,
164
+ max_length=max_length,
165
+ min_length=min_length,
166
+ num_beams=num_beams,
167
+ eos_token_id=self.tokenizer.sep_token_id,
168
+ pad_token_id=self.tokenizer.pad_token_id,
169
+ repetition_penalty=repetition_penalty,
170
+ **model_kwargs)
171
+
172
+ captions = []
173
+ for output in outputs:
174
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
175
+ captions.append(caption[len(self.prompt):])
176
+ return captions
177
+
178
+
179
+ def blip_decoder(pretrained='',**kwargs):
180
+ model = BLIP_Decoder(**kwargs)
181
+ if pretrained:
182
+ model,msg = load_checkpoint(model,pretrained)
183
+ assert(len(msg.missing_keys)==0)
184
+ return model
185
+
186
+ def blip_feature_extractor(pretrained='',**kwargs):
187
+ model = BLIP_Base(**kwargs)
188
+ if pretrained:
189
+ model,msg = load_checkpoint(model,pretrained)
190
+ assert(len(msg.missing_keys)==0)
191
+ return model
192
+
193
+ def init_tokenizer():
194
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
195
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
196
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
197
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
198
+ return tokenizer
199
+
200
+
201
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
202
+
203
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
204
+ if vit=='base':
205
+ vision_width = 768
206
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
207
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
208
+ drop_path_rate=0 or drop_path_rate
209
+ )
210
+ elif vit=='large':
211
+ vision_width = 1024
212
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
213
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
214
+ drop_path_rate=0.1 or drop_path_rate
215
+ )
216
+ return visual_encoder, vision_width
217
+
218
+ def is_url(url_or_filename):
219
+ parsed = urlparse(url_or_filename)
220
+ return parsed.scheme in ("http", "https")
221
+
222
+ def load_checkpoint(model,url_or_filename):
223
+ if is_url(url_or_filename):
224
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
225
+ checkpoint = torch.load(cached_file, map_location='cpu')
226
+ elif os.path.isfile(url_or_filename):
227
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
228
+ else:
229
+ raise RuntimeError('checkpoint url or path is invalid')
230
+
231
+ state_dict = checkpoint['model']
232
+
233
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
234
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
235
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
236
+ model.visual_encoder_m)
237
+ for key in model.state_dict().keys():
238
+ if key in state_dict.keys():
239
+ if state_dict[key].shape!=model.state_dict()[key].shape:
240
+ del state_dict[key]
241
+
242
+ msg = model.load_state_dict(state_dict,strict=False)
243
+ logger.info('load checkpoint from %s'%url_or_filename)
244
+ return model,msg
245
+
cache_latents.py ADDED
@@ -0,0 +1,205 @@
1
+ # latentsのdiskへの事前キャッシュを行う / cache latents to disk
2
+
3
+ import argparse
4
+ import math
5
+ from multiprocessing import Value
6
+ import os
7
+
8
+ from accelerate.utils import set_seed
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from library import config_util
13
+ from library import train_util
14
+ from library import sdxl_train_util
15
+ from library.config_util import (
16
+ ConfigSanitizer,
17
+ BlueprintGenerator,
18
+ )
19
+ from library.utils import setup_logging, add_logging_arguments
20
+ setup_logging()
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ def cache_to_disk(args: argparse.Namespace) -> None:
27
+ setup_logging(args, reset=True)
28
+ train_util.prepare_dataset_args(args, True)
29
+
30
+ # check cache latents arg
31
+ assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"
32
+
33
+ use_dreambooth_method = args.in_json is None
34
+
35
+ if args.seed is not None:
36
+ set_seed(args.seed) # 乱数系列を初期化する
37
+
38
+ # tokenizerを準備する:datasetを動かすために必要
39
+ if args.sdxl:
40
+ tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
41
+ tokenizers = [tokenizer1, tokenizer2]
42
+ else:
43
+ tokenizer = train_util.load_tokenizer(args)
44
+ tokenizers = [tokenizer]
45
+
46
+ # データセットを準備する
47
+ if args.dataset_class is None:
48
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
49
+ if args.dataset_config is not None:
50
+ logger.info(f"Load dataset config from {args.dataset_config}")
51
+ user_config = config_util.load_user_config(args.dataset_config)
52
+ ignored = ["train_data_dir", "in_json"]
53
+ if any(getattr(args, attr) is not None for attr in ignored):
54
+ logger.warning(
55
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
56
+ ", ".join(ignored)
57
+ )
58
+ )
59
+ else:
60
+ if use_dreambooth_method:
61
+ logger.info("Using DreamBooth method.")
62
+ user_config = {
63
+ "datasets": [
64
+ {
65
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
66
+ args.train_data_dir, args.reg_data_dir
67
+ )
68
+ }
69
+ ]
70
+ }
71
+ else:
72
+ logger.info("Training with captions.")
73
+ user_config = {
74
+ "datasets": [
75
+ {
76
+ "subsets": [
77
+ {
78
+ "image_dir": args.train_data_dir,
79
+ "metadata_file": args.in_json,
80
+ }
81
+ ]
82
+ }
83
+ ]
84
+ }
85
+
86
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
87
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
88
+ else:
89
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
90
+
91
+ # datasetのcache_latentsを呼ばなければ、生の画像が返る
92
+
93
+ current_epoch = Value("i", 0)
94
+ current_step = Value("i", 0)
95
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
96
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
97
+
98
+ # acceleratorを準備する
99
+ logger.info("prepare accelerator")
100
+ args.deepspeed = False
101
+ accelerator = train_util.prepare_accelerator(args)
102
+
103
+ # mixed precisionに対応した型を用意しておき適宜castする
104
+ weight_dtype, _ = train_util.prepare_dtype(args)
105
+ vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
106
+
107
+ # モデルを読み込む
108
+ logger.info("load model")
109
+ if args.sdxl:
110
+ (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
111
+ else:
112
+ _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
113
+
114
+ if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える
115
+ vae.set_use_memory_efficient_attention_xformers(args.xformers)
116
+ vae.to(accelerator.device, dtype=vae_dtype)
117
+ vae.requires_grad_(False)
118
+ vae.eval()
119
+
120
+ # dataloaderを準備する
121
+ train_dataset_group.set_caching_mode("latents")
122
+
123
+ # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
124
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
125
+
126
+ train_dataloader = torch.utils.data.DataLoader(
127
+ train_dataset_group,
128
+ batch_size=1,
129
+ shuffle=True,
130
+ collate_fn=collator,
131
+ num_workers=n_workers,
132
+ persistent_workers=args.persistent_data_loader_workers,
133
+ )
134
+
135
+ # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
136
+ train_dataloader = accelerator.prepare(train_dataloader)
137
+
138
+ # データ取得のためのループ
139
+ for batch in tqdm(train_dataloader):
140
+ b_size = len(batch["images"])
141
+ vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
142
+ flip_aug = batch["flip_aug"]
143
+ alpha_mask = batch["alpha_mask"]
144
+ random_crop = batch["random_crop"]
145
+ bucket_reso = batch["bucket_reso"]
146
+
147
+ # バッチを分割して処理する
148
+ for i in range(0, b_size, vae_batch_size):
149
+ images = batch["images"][i : i + vae_batch_size]
150
+ absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
151
+ resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]
152
+
153
+ image_infos = []
154
+ for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
155
+ image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
156
+ image_info.image = image
157
+ image_info.bucket_reso = bucket_reso
158
+ image_info.resized_size = resized_size
159
+ image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"
160
+
161
+ if args.skip_existing:
162
+ if train_util.is_disk_cached_latents_is_expected(
163
+ image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask
164
+ ):
165
+ logger.warning(f"Skipping {image_info.latents_npz} because it already exists.")
166
+ continue
167
+
168
+ image_infos.append(image_info)
169
+
170
+ if len(image_infos) > 0:
171
+ train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop)
172
+
173
+ accelerator.wait_for_everyone()
174
+ accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
175
+
176
+
177
+ def setup_parser() -> argparse.ArgumentParser:
178
+ parser = argparse.ArgumentParser()
179
+
180
+ add_logging_arguments(parser)
181
+ train_util.add_sd_models_arguments(parser)
182
+ train_util.add_training_arguments(parser, True)
183
+ train_util.add_dataset_arguments(parser, True, True, True)
184
+ config_util.add_config_arguments(parser)
185
+ parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
186
+ parser.add_argument(
187
+ "--no_half_vae",
188
+ action="store_true",
189
+ help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
190
+ )
191
+ parser.add_argument(
192
+ "--skip_existing",
193
+ action="store_true",
194
+ help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
195
+ )
196
+ return parser
197
+
198
+
199
+ if __name__ == "__main__":
200
+ parser = setup_parser()
201
+
202
+ args = parser.parse_args()
203
+ args = train_util.read_config_from_file(args, parser)
204
+
205
+ cache_to_disk(args)
cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,197 @@
1
+ # text encoder出力のdiskへの事前キャッシュを行う / cache text encoder outputs to disk in advance
2
+
3
+ import argparse
4
+ import math
5
+ from multiprocessing import Value
6
+ import os
7
+
8
+ from accelerate.utils import set_seed
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from library import config_util
13
+ from library import train_util
14
+ from library import sdxl_train_util
15
+ from library.config_util import (
16
+ ConfigSanitizer,
17
+ BlueprintGenerator,
18
+ )
19
+ from library.utils import setup_logging, add_logging_arguments
20
+ setup_logging()
21
+ import logging
22
+ logger = logging.getLogger(__name__)
23
+
24
+ def cache_to_disk(args: argparse.Namespace) -> None:
25
+ setup_logging(args, reset=True)
26
+ train_util.prepare_dataset_args(args, True)
27
+
28
+ # check cache arg
29
+ assert (
30
+ args.cache_text_encoder_outputs_to_disk
31
+ ), "cache_text_encoder_outputs_to_disk must be True / cache_text_encoder_outputs_to_diskはTrueである必要があります"
32
+
33
+ # できるだけ準備はしておくが今のところSDXLのみしか動かない
34
+ assert (
35
+ args.sdxl
36
+ ), "cache_text_encoder_outputs_to_disk is only available for SDXL / cache_text_encoder_outputs_to_diskはSDXLのみ利用可能です"
37
+
38
+ use_dreambooth_method = args.in_json is None
39
+
40
+ if args.seed is not None:
41
+ set_seed(args.seed) # 乱数系列を初期化する
42
+
43
+ # tokenizerを準備する:datasetを動かすために必要
44
+ if args.sdxl:
45
+ tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
46
+ tokenizers = [tokenizer1, tokenizer2]
47
+ else:
48
+ tokenizer = train_util.load_tokenizer(args)
49
+ tokenizers = [tokenizer]
50
+
51
+ # データセットを準備する
52
+ if args.dataset_class is None:
53
+ blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
54
+ if args.dataset_config is not None:
55
+ logger.info(f"Load dataset config from {args.dataset_config}")
56
+ user_config = config_util.load_user_config(args.dataset_config)
57
+ ignored = ["train_data_dir", "in_json"]
58
+ if any(getattr(args, attr) is not None for attr in ignored):
59
+ logger.warning(
60
+ "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
61
+ ", ".join(ignored)
62
+ )
63
+ )
64
+ else:
65
+ if use_dreambooth_method:
66
+ logger.info("Using DreamBooth method.")
67
+ user_config = {
68
+ "datasets": [
69
+ {
70
+ "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
71
+ args.train_data_dir, args.reg_data_dir
72
+ )
73
+ }
74
+ ]
75
+ }
76
+ else:
77
+ logger.info("Training with captions.")
78
+ user_config = {
79
+ "datasets": [
80
+ {
81
+ "subsets": [
82
+ {
83
+ "image_dir": args.train_data_dir,
84
+ "metadata_file": args.in_json,
85
+ }
86
+ ]
87
+ }
88
+ ]
89
+ }
90
+
91
+ blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
92
+ train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
93
+ else:
94
+ train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)
95
+
96
+ current_epoch = Value("i", 0)
97
+ current_step = Value("i", 0)
98
+ ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
99
+ collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
100
+
101
+ # acceleratorを準備する
102
+ logger.info("prepare accelerator")
103
+ args.deepspeed = False
104
+ accelerator = train_util.prepare_accelerator(args)
105
+
106
+ # mixed precisionに対応した型を用意しておき適宜castする
107
+ weight_dtype, _ = train_util.prepare_dtype(args)
108
+
109
+ # モデルを読み込む
110
+ logger.info("load model")
111
+ if args.sdxl:
112
+ (_, text_encoder1, text_encoder2, _, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
113
+ text_encoders = [text_encoder1, text_encoder2]
114
+ else:
115
+ text_encoder1, _, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)
116
+ text_encoders = [text_encoder1]
117
+
118
+ for text_encoder in text_encoders:
119
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
120
+ text_encoder.requires_grad_(False)
121
+ text_encoder.eval()
122
+
123
+ # dataloaderを準備する
124
+ train_dataset_group.set_caching_mode("text")
125
+
126
+ # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
127
+ n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
128
+
129
+ train_dataloader = torch.utils.data.DataLoader(
130
+ train_dataset_group,
131
+ batch_size=1,
132
+ shuffle=True,
133
+ collate_fn=collator,
134
+ num_workers=n_workers,
135
+ persistent_workers=args.persistent_data_loader_workers,
136
+ )
137
+
138
+ # acceleratorを使ってモデルを準備する:マルチGPUで使えるようになるはず
139
+ train_dataloader = accelerator.prepare(train_dataloader)
140
+
141
+ # データ取得のためのループ
142
+ for batch in tqdm(train_dataloader):
143
+ absolute_paths = batch["absolute_paths"]
144
+ input_ids1_list = batch["input_ids1_list"]
145
+ input_ids2_list = batch["input_ids2_list"]
146
+
147
+ image_infos = []
148
+ for absolute_path, input_ids1, input_ids2 in zip(absolute_paths, input_ids1_list, input_ids2_list):
149
+ image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
150
+ image_info.text_encoder_outputs_npz = os.path.splitext(absolute_path)[0] + train_util.TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
151
+ image_info
152
+
153
+ if args.skip_existing:
154
+ if os.path.exists(image_info.text_encoder_outputs_npz):
155
+ logger.warning(f"Skipping {image_info.text_encoder_outputs_npz} because it already exists.")
156
+ continue
157
+
158
+ image_info.input_ids1 = input_ids1
159
+ image_info.input_ids2 = input_ids2
160
+ image_infos.append(image_info)
161
+
162
+ if len(image_infos) > 0:
163
+ b_input_ids1 = torch.stack([image_info.input_ids1 for image_info in image_infos])
164
+ b_input_ids2 = torch.stack([image_info.input_ids2 for image_info in image_infos])
165
+ train_util.cache_batch_text_encoder_outputs(
166
+ image_infos, tokenizers, text_encoders, args.max_token_length, True, b_input_ids1, b_input_ids2, weight_dtype
167
+ )
168
+
169
+ accelerator.wait_for_everyone()
170
+ accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
171
+
172
+
173
+ def setup_parser() -> argparse.ArgumentParser:
174
+ parser = argparse.ArgumentParser()
175
+
176
+ add_logging_arguments(parser)
177
+ train_util.add_sd_models_arguments(parser)
178
+ train_util.add_training_arguments(parser, True)
179
+ train_util.add_dataset_arguments(parser, True, True, True)
180
+ config_util.add_config_arguments(parser)
181
+ sdxl_train_util.add_sdxl_training_arguments(parser)
182
+ parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
183
+ parser.add_argument(
184
+ "--skip_existing",
185
+ action="store_true",
186
+ help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
187
+ )
188
+ return parser
189
+
190
+
191
+ if __name__ == "__main__":
192
+ parser = setup_parser()
193
+
194
+ args = parser.parse_args()
195
+ args = train_util.read_config_from_file(args, parser)
196
+
197
+ cache_to_disk(args)
canny.py ADDED
@@ -0,0 +1,34 @@
+ import argparse
+ import cv2
+
+ import logging
+ from library.utils import setup_logging
+ setup_logging()
+ logger = logging.getLogger(__name__)
+
+ def canny(args):
+     img = cv2.imread(args.input)
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+     canny_img = cv2.Canny(img, args.thres1, args.thres2)
+     # canny_img = 255 - canny_img
+
+     cv2.imwrite(args.output, canny_img)
+     logger.info("done!")
+
+
+ def setup_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--input", type=str, default=None, help="input path")
+     parser.add_argument("--output", type=str, default=None, help="output path")
+     parser.add_argument("--thres1", type=int, default=32, help="thres1")
+     parser.add_argument("--thres2", type=int, default=224, help="thres2")
+
+     return parser
+
+
+ if __name__ == '__main__':
+     parser = setup_parser()
+
+     args = parser.parse_args()
+     canny(args)
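
For reference, the script can also be driven programmatically; the file names below are placeholders (illustrative, not part of the commit):

```python
# Hypothetical invocation of canny.py without the command line.
from canny import canny, setup_parser

args = setup_parser().parse_args(
    ["--input", "lineart.png", "--output", "lineart_canny.png", "--thres1", "32", "--thres2", "224"]
)
canny(args)  # writes the Canny edge map of lineart.png to lineart_canny.png
```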
cextension.py ADDED
@@ -0,0 +1,54 @@
+ import ctypes as ct
+ from pathlib import Path
+ from warnings import warn
+
+ from .cuda_setup.main import evaluate_cuda_setup
+
+
+ class CUDALibrary_Singleton(object):
+     _instance = None
+
+     def __init__(self):
+         raise RuntimeError("Call get_instance() instead")
+
+     def initialize(self):
+         binary_name = evaluate_cuda_setup()
+         package_dir = Path(__file__).parent
+         binary_path = package_dir / binary_name
+
+         if not binary_path.exists():
+             print(f"CUDA SETUP: TODO: compile library for specific version: {binary_name}")
+             legacy_binary_name = "libbitsandbytes.so"
+             print(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
+             binary_path = package_dir / legacy_binary_name
+             if not binary_path.exists():
+                 print('CUDA SETUP: CUDA detection failed. Either CUDA driver not installed, CUDA not installed, or you have multiple conflicting CUDA libraries!')
+                 print('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
+                 raise Exception('CUDA SETUP: Setup Failed!')
+             # self.lib = ct.cdll.LoadLibrary(binary_path)
+             self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
+         else:
+             print(f"CUDA SETUP: Loading binary {binary_path}...")
+             # self.lib = ct.cdll.LoadLibrary(binary_path)
+             self.lib = ct.cdll.LoadLibrary(str(binary_path)) # $$$
+
+     @classmethod
+     def get_instance(cls):
+         if cls._instance is None:
+             cls._instance = cls.__new__(cls)
+             cls._instance.initialize()
+         return cls._instance
+
+
+ lib = CUDALibrary_Singleton.get_instance().lib
+ try:
+     lib.cadam32bit_g32
+     lib.get_context.restype = ct.c_void_p
+     lib.get_cusparse.restype = ct.c_void_p
+     COMPILED_WITH_CUDA = True
+ except AttributeError:
+     warn(
+         "The installed version of bitsandbytes was compiled without GPU support. "
+         "8-bit optimizers and GPU quantization are unavailable."
+     )
+     COMPILED_WITH_CUDA = False
check_lora_weights.py ADDED
@@ -0,0 +1,48 @@
+ import argparse
+ import os
+ import torch
+ from safetensors.torch import load_file
+ from library.utils import setup_logging
+ setup_logging()
+ import logging
+ logger = logging.getLogger(__name__)
+
+ def main(file):
+     logger.info(f"loading: {file}")
+     if os.path.splitext(file)[1] == ".safetensors":
+         sd = load_file(file)
+     else:
+         sd = torch.load(file, map_location="cpu")
+
+     values = []
+
+     keys = list(sd.keys())
+     for key in keys:
+         if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key:
+             values.append((key, sd[key]))
+     print(f"number of LoRA modules: {len(values)}")
+
+     if args.show_all_keys:
+         for key in [k for k in keys if k not in values]:
+             values.append((key, sd[key]))
+         print(f"number of all modules: {len(values)}")
+
+     for key, value in values:
+         value = value.to(torch.float32)
+         print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
+
+
+ def setup_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser()
+     parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
+     parser.add_argument("-s", "--show_all_keys", action="store_true", help="show all keys / 全てのキーを表示する")
+
+     return parser
+
+
+ if __name__ == "__main__":
+     parser = setup_parser()
+
+     args = parser.parse_args()
+
+     main(args.file)
clean_captions_and_tags.py ADDED
@@ -0,0 +1,194 @@
1
+ # このスクリプトのライセンスは、Apache License 2.0とします
2
+ # (c) 2022 Kohya S. @kohya_ss
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import json
8
+ import re
9
+
10
+ from tqdm import tqdm
11
+ from library.utils import setup_logging
12
+ setup_logging()
13
+ import logging
14
+ logger = logging.getLogger(__name__)
15
+
16
+ PATTERN_HAIR_LENGTH = re.compile(r', (long|short|medium) hair, ')
17
+ PATTERN_HAIR_CUT = re.compile(r', (bob|hime) cut, ')
18
+ PATTERN_HAIR = re.compile(r', ([\w\-]+) hair, ')
19
+ PATTERN_WORD = re.compile(r', ([\w\-]+|hair ornament), ')
20
+
21
+ # 複数人がいるとき、複数の髪色や目の色が定義されていれば削除する
22
+ PATTERNS_REMOVE_IN_MULTI = [
23
+ PATTERN_HAIR_LENGTH,
24
+ PATTERN_HAIR_CUT,
25
+ re.compile(r', [\w\-]+ eyes, '),
26
+ re.compile(r', ([\w\-]+ sleeves|sleeveless), '),
27
+ # 複数の髪型定義がある場合は削除する
28
+ re.compile(
29
+ r', (ponytail|braid|ahoge|twintails|[\w\-]+ bun|single hair bun|single side bun|two side up|two tails|[\w\-]+ braid|sidelocks), '),
30
+ ]
31
+
32
+
33
+ def clean_tags(image_key, tags):
34
+ # replace '_' to ' '
35
+ tags = tags.replace('^_^', '^@@@^')
36
+ tags = tags.replace('_', ' ')
37
+ tags = tags.replace('^@@@^', '^_^')
38
+
39
+ # remove rating: deepdanbooruのみ
40
+ tokens = tags.split(", rating")
41
+ if len(tokens) == 1:
42
+ # WD14 taggerのときはこちらになるのでメッセージは出さない
43
+ # logger.info("no rating:")
44
+ # logger.info(f"{image_key} {tags}")
45
+ pass
46
+ else:
47
+ if len(tokens) > 2:
48
+ logger.info("multiple ratings:")
49
+ logger.info(f"{image_key} {tags}")
50
+ tags = tokens[0]
51
+
52
+ tags = ", " + tags.replace(", ", ", , ") + ", " # カンマ付きで検索をするための身も蓋もない対策
53
+
54
+ # 複数の人物がいる場合は髪色等のタグを削除する
55
+ if 'girls' in tags or 'boys' in tags:
56
+ for pat in PATTERNS_REMOVE_IN_MULTI:
57
+ found = pat.findall(tags)
58
+ if len(found) > 1: # 二つ以上、タグがある
59
+ tags = pat.sub("", tags)
60
+
61
+ # 髪の特殊対応
62
+ srch_hair_len = PATTERN_HAIR_LENGTH.search(tags) # 髪の長さタグは例外なので避けておく(全員が同じ髪の長さの場合)
63
+ if srch_hair_len:
64
+ org = srch_hair_len.group()
65
+ tags = PATTERN_HAIR_LENGTH.sub(", @@@, ", tags)
66
+
67
+ found = PATTERN_HAIR.findall(tags)
68
+ if len(found) > 1:
69
+ tags = PATTERN_HAIR.sub("", tags)
70
+
71
+ if srch_hair_len:
72
+ tags = tags.replace(", @@@, ", org) # 戻す
73
+
74
+ # white shirtとshirtみたいな重複タグの削除
75
+ found = PATTERN_WORD.findall(tags)
76
+ for word in found:
77
+ if re.search(f", ((\w+) )+{word}, ", tags):
78
+ tags = tags.replace(f", {word}, ", "")
79
+
80
+ tags = tags.replace(", , ", ", ")
81
+ assert tags.startswith(", ") and tags.endswith(", ")
82
+ tags = tags[2:-2]
83
+ return tags
84
+
85
+
86
+ # 上から順に検索、置換される
87
+ # ('置換元文字列', '置換後文字列')
88
+ CAPTION_REPLACEMENTS = [
89
+ ('anime anime', 'anime'),
90
+ ('young ', ''),
91
+ ('anime girl', 'girl'),
92
+ ('cartoon female', 'girl'),
93
+ ('cartoon lady', 'girl'),
94
+ ('cartoon character', 'girl'), # a or ~s
95
+ ('cartoon woman', 'girl'),
96
+ ('cartoon women', 'girls'),
97
+ ('cartoon girl', 'girl'),
98
+ ('anime female', 'girl'),
99
+ ('anime lady', 'girl'),
100
+ ('anime character', 'girl'), # a or ~s
101
+ ('anime woman', 'girl'),
102
+ ('anime women', 'girls'),
103
+ ('lady', 'girl'),
104
+ ('female', 'girl'),
105
+ ('woman', 'girl'),
106
+ ('women', 'girls'),
107
+ ('people', 'girls'),
108
+ ('person', 'girl'),
109
+ ('a cartoon figure', 'a figure'),
110
+ ('a cartoon image', 'an image'),
111
+ ('a cartoon picture', 'a picture'),
112
+ ('an anime cartoon image', 'an image'),
113
+ ('a cartoon anime drawing', 'a drawing'),
114
+ ('a cartoon drawing', 'a drawing'),
115
+ ('girl girl', 'girl'),
116
+ ]
117
+
118
+
119
+ def clean_caption(caption):
120
+ for rf, rt in CAPTION_REPLACEMENTS:
121
+ replaced = True
122
+ while replaced:
123
+ bef = caption
124
+ caption = caption.replace(rf, rt)
125
+ replaced = bef != caption
126
+ return caption
127
+
128
+
129
+ def main(args):
130
+ if os.path.exists(args.in_json):
131
+ logger.info(f"loading existing metadata: {args.in_json}")
132
+ with open(args.in_json, "rt", encoding='utf-8') as f:
133
+ metadata = json.load(f)
134
+ else:
135
+ logger.error("no metadata / メタデータファイルがありません")
136
+ return
137
+
138
+ logger.info("cleaning captions and tags.")
139
+ image_keys = list(metadata.keys())
140
+ for image_key in tqdm(image_keys):
141
+ tags = metadata[image_key].get('tags')
142
+ if tags is None:
143
+ logger.error(f"image does not have tags / メタデータにタグがありません: {image_key}")
144
+ else:
145
+ org = tags
146
+ tags = clean_tags(image_key, tags)
147
+ metadata[image_key]['tags'] = tags
148
+ if args.debug and org != tags:
149
+ logger.info("FROM: " + org)
150
+ logger.info("TO: " + tags)
151
+
152
+ caption = metadata[image_key].get('caption')
153
+ if caption is None:
154
+ logger.error(f"image does not have caption / メタデータにキャプションがありません: {image_key}")
155
+ else:
156
+ org = caption
157
+ caption = clean_caption(caption)
158
+ metadata[image_key]['caption'] = caption
159
+ if args.debug and org != caption:
160
+ logger.info("FROM: " + org)
161
+ logger.info("TO: " + caption)
162
+
163
+ # metadataを書き出して終わり
164
+ logger.info(f"writing metadata: {args.out_json}")
165
+ with open(args.out_json, "wt", encoding='utf-8') as f:
166
+ json.dump(metadata, f, indent=2)
167
+ logger.info("done!")
168
+
169
+
170
+ def setup_parser() -> argparse.ArgumentParser:
171
+ parser = argparse.ArgumentParser()
172
+ # parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
173
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
174
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
175
+ parser.add_argument("--debug", action="store_true", help="debug mode")
176
+
177
+ return parser
178
+
179
+
180
+ if __name__ == '__main__':
181
+ parser = setup_parser()
182
+
183
+ args, unknown = parser.parse_known_args()
184
+ if len(unknown) == 1:
185
+ logger.warning("WARNING: train_data_dir argument is removed. This script will not work with three arguments in future. Please specify two arguments: in_json and out_json.")
186
+ logger.warning("All captions and tags in the metadata are processed.")
187
+ logger.warning("警告: train_data_dir引数は不要になりました。将来的には三つの引数を指定すると動かなくなる予定です。読み込み元のメタデータと書き出し先の二つの引数だけ指定してください。")
188
+ logger.warning("メタデータ内のすべてのキャプションとタグが処理されます。")
189
+ args.in_json = args.out_json
190
+ args.out_json = unknown[0]
191
+ elif len(unknown) > 0:
192
+ raise ValueError(f"error: unrecognized arguments: {unknown}")
193
+
194
+ main(args)
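
A small sketch of what the cleaning functions do to a single entry (illustrative; the sample strings are made up and not part of the commit):

```python
# Hypothetical example of the caption/tag normalization in this script.
from clean_captions_and_tags import clean_caption, clean_tags

print(clean_caption("an anime cartoon image of a cartoon girl"))
# -> "an image of a girl" (replacements are applied repeatedly, in list order)

print(clean_tags("img_0001", "1girl, long_hair, blue_eyes, smile"))
# -> "1girl, long hair, blue eyes, smile"
# underscores become spaces; duplicate hair/eye tags are pruned only when
# multiple people ("girls"/"boys") appear in the tag string
```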
config_README-en.md ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Original Source by kohya-ss
2
+
3
+ First version:
4
+ A.I Translation by Model: NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO, editing by Darkstorm2150
5
+
6
+ Some parts are manually added.
7
+
8
+ # Config Readme
9
+
10
+ This README is about the configuration files that can be passed with the `--dataset_config` option.
11
+
12
+ ## Overview
13
+
14
+ By passing a configuration file, users can make detailed settings.
15
+
16
+ * Multiple datasets can be configured
17
+ * For example, by setting `resolution` for each dataset, they can be mixed and trained.
18
+ * In training methods that support both the DreamBooth approach and the fine-tuning approach, datasets of the DreamBooth method and the fine-tuning method can be mixed.
19
+ * Settings can be changed for each subset
20
+ * A subset is a partition of the dataset by image directory or metadata. Several subsets make up a dataset.
21
+ * Options such as `keep_tokens` and `flip_aug` can be set for each subset. On the other hand, options such as `resolution` and `batch_size` can be set for each dataset, and their values are common among subsets belonging to the same dataset. More details will be provided later.
22
+
23
+ The configuration file format can be JSON or TOML. Considering the ease of writing, it is recommended to use [TOML](https://toml.io/ja/v1.0.0-rc.2). The following explanation assumes the use of TOML.
24
+
25
+
26
+ Here is an example of a configuration file written in TOML.
27
+
28
+ ```toml
29
+ [general]
30
+ shuffle_caption = true
31
+ caption_extension = '.txt'
32
+ keep_tokens = 1
33
+
34
+ # This is a DreamBooth-style dataset
35
+ [[datasets]]
36
+ resolution = 512
37
+ batch_size = 4
38
+ keep_tokens = 2
39
+
40
+ [[datasets.subsets]]
41
+ image_dir = 'C:\hoge'
42
+ class_tokens = 'hoge girl'
43
+ # This subset uses keep_tokens = 2 (the value of the parent datasets)
44
+
45
+ [[datasets.subsets]]
46
+ image_dir = 'C:\fuga'
47
+ class_tokens = 'fuga boy'
48
+ keep_tokens = 3
49
+
50
+ [[datasets.subsets]]
51
+ is_reg = true
52
+ image_dir = 'C:\reg'
53
+ class_tokens = 'human'
54
+ keep_tokens = 1
55
+
56
+ # This is a fine-tuning dataset
57
+ [[datasets]]
58
+ resolution = [768, 768]
59
+ batch_size = 2
60
+
61
+ [[datasets.subsets]]
62
+ image_dir = 'C:\piyo'
63
+ metadata_file = 'C:\piyo\piyo_md.json'
64
+ # This subset uses keep_tokens = 1 (the value of [general])
65
+ ```
66
+
67
+ In this example, three directories are trained as a DreamBooth-style dataset at 512x512 (batch size 4), and one directory is trained as a fine-tuning dataset at 768x768 (batch size 2).
68
+
69
+ ## Settings for datasets and subsets
70
+
71
+ Settings for datasets and subsets are divided into several registration locations.
72
+
73
+ * `[general]`
74
+ * This is where options that apply to all datasets or all subsets are specified.
75
+ * If there are options with the same name in the dataset-specific or subset-specific settings, the dataset-specific or subset-specific settings take precedence.
76
+ * `[[datasets]]`
77
+ * `datasets` is where settings for datasets are registered. This is where options that apply individually to each dataset are specified.
78
+ * If there are subset-specific settings, the subset-specific settings take precedence.
79
+ * `[[datasets.subsets]]`
80
+ * `datasets.subsets` is where settings for subsets are registered. This is where options that apply individually to each subset are specified.
81
+
82
+ Here is an image showing the correspondence between image directories and registration locations in the previous example.
83
+
84
+ ```
85
+ C:\
86
+ ├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
87
+ ├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
88
+ ├─ reg -> [[datasets.subsets]] No.3 ┘ |
89
+ └─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
90
+ ```
91
+
92
+ The image directory corresponds to each `[[datasets.subsets]]`. Then, multiple `[[datasets.subsets]]` are combined to form one `[[datasets]]`. All `[[datasets]]` and `[[datasets.subsets]]` belong to `[general]`.
93
+
94
+ The available options for each registration location may differ, but if the same option is specified, the value in the lower registration location will take precedence. You can check how the `keep_tokens` option is handled in the previous example for better understanding.
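+
+ As a rough illustration (this is only a sketch, not the training scripts' actual implementation), the precedence can be thought of as a fallback lookup from the most specific registration location to the most general one:
+
+ ```python
+ # Minimal sketch of the precedence rule, assuming plain dicts for each location.
+ # The most specific location that defines the option wins: subset > dataset > general.
+ def resolve_option(name, subset_cfg, dataset_cfg, general_cfg, default=None):
+     for cfg in (subset_cfg, dataset_cfg, general_cfg):
+         if name in cfg:
+             return cfg[name]
+     return default
+
+ general = {"keep_tokens": 1}
+ dataset = {"keep_tokens": 2}
+ subset_a = {}                  # falls back to the dataset value
+ subset_b = {"keep_tokens": 3}  # uses its own value
+ print(resolve_option("keep_tokens", subset_a, dataset, general))  # 2
+ print(resolve_option("keep_tokens", subset_b, dataset, general))  # 3
+ ```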
95
+
96
+ In addition, the available options vary depending on which methods the training approach supports:
97
+
98
+ * Options specific to the DreamBooth method
99
+ * Options specific to the fine-tuning method
100
+ * Options available when using the caption dropout technique
101
+
102
+ In training methods that support both, the DreamBooth approach and the fine-tuning approach can be used together.
103
+ When doing so, note that the method is determined per dataset, so DreamBooth subsets and fine-tuning subsets cannot be mixed within the same dataset.
104
+ In other words, if you want to use both methods together, you need to set up subsets of different methods belonging to different datasets.
105
+
106
+ In terms of program behavior, a subset is treated as a fine-tuning subset if it has the `metadata_file` option. Therefore, for subsets belonging to the same dataset, there is no problem as long as either all of them have the `metadata_file` option or none of them do.
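+
+ As a sketch of this rule (mirroring the check performed by `config_util.py`, not quoting it verbatim), the per-dataset decision looks like this:
+
+ ```python
+ # Sketch: a dataset is fine-tuning style only if every subset has "metadata_file";
+ # it is DreamBooth style only if none of them has it. Anything else is an error.
+ def dataset_method(subsets):
+     if all("metadata_file" in s for s in subsets):
+         return "fine-tuning"
+     if all("metadata_file" not in s for s in subsets):
+         return "DreamBooth"
+     raise ValueError("DreamBooth and fine-tuning subsets cannot be mixed in one dataset")
+
+ print(dataset_method([{"image_dir": "C:\\hoge"}, {"image_dir": "C:\\fuga"}]))  # DreamBooth
+ ```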
107
+
108
+ Below, the available options will be explained. For options with the same name as the command-line argument, the explanation will be omitted in principle. Please refer to other READMEs.
109
+
110
+ ### Common options for all learning methods
111
+
112
+ These are options that can be specified regardless of the learning method.
113
+
114
+ #### Data set specific options
115
+
116
+ These are options related to the configuration of the data set. They cannot be described in `datasets.subsets`.
117
+
118
+
119
+ | Option Name | Example Setting | `[general]` | `[[datasets]]` |
120
+ | ---- | ---- | ---- | ---- |
121
+ | `batch_size` | `1` | o | o |
122
+ | `bucket_no_upscale` | `true` | o | o |
123
+ | `bucket_reso_steps` | `64` | o | o |
124
+ | `enable_bucket` | `true` | o | o |
125
+ | `max_bucket_reso` | `1024` | o | o |
126
+ | `min_bucket_reso` | `128` | o | o |
127
+ | `resolution` | `256`, `[512, 512]` | o | o |
128
+
129
+ * `batch_size`
130
+ * This corresponds to the command-line argument `--train_batch_size`.
131
+ * `max_bucket_reso`, `min_bucket_reso`
132
+ * Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`.
133
+
134
+ These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each.
135
+
136
+ #### Options for Subsets
137
+
138
+ These options are related to subset configuration.
139
+
140
+ | Option Name | Example | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
141
+ | ---- | ---- | ---- | ---- | ---- |
142
+ | `color_aug` | `false` | o | o | o |
143
+ | `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
144
+ | `flip_aug` | `true` | o | o | o |
145
+ | `keep_tokens` | `2` | o | o | o |
146
+ | `num_repeats` | `10` | o | o | o |
147
+ | `random_crop` | `false` | o | o | o |
148
+ | `shuffle_caption` | `true` | o | o | o |
149
+ | `caption_prefix` | `"masterpiece, best quality, "` | o | o | o |
150
+ | `caption_suffix` | `", from side"` | o | o | o |
151
+ | `caption_separator` | (not specified) | o | o | o |
152
+ | `keep_tokens_separator` | `"|||"` | o | o | o |
153
+ | `secondary_separator` | `";;;"` | o | o | o |
154
+ | `enable_wildcard` | `true` | o | o | o |
155
+
156
+ * `num_repeats`
157
+ * Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
158
+ * `caption_prefix`, `caption_suffix`
159
+ * Specifies the prefix and suffix strings to be appended to the captions. Shuffling is performed with these strings included. Be cautious when using `keep_tokens`.
160
+ * `caption_separator`
161
+ * Specifies the string to separate the tags. The default is `,`. This option is usually not necessary to set.
162
+ * `keep_tokens_separator`
163
+ * Specifies the string to separate the parts to be fixed in the caption. For example, if you specify `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh`, the parts `aaa, bbb` and `ggg, hhh` will remain, and the rest will be shuffled and dropped. The comma in between is not necessary. As a result, the prompt will be `aaa, bbb, eee, ccc, fff, ggg, hhh` or `aaa, bbb, fff, ccc, eee, ggg, hhh`, etc.
164
+ * `secondary_separator`
165
+ * Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
166
+ * `enable_wildcard`
167
+ * Enables wildcard notation. This will be explained later.
168
+
169
+ ### DreamBooth-specific options
170
+
171
+ DreamBooth-specific options exist only as options for subsets.
172
+
173
+ #### Subset-specific options
174
+
175
+ Options related to the configuration of DreamBooth subsets.
176
+
177
+ | Option Name | Example Setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
178
+ | ---- | ---- | ---- | ---- | ---- |
179
+ | `image_dir` | `'C:\hoge'` | - | - | o (required) |
180
+ | `caption_extension` | `".txt"` | o | o | o |
181
+ | `class_tokens` | `"sks girl"` | - | - | o |
182
+ | `cache_info` | `false` | o | o | o |
183
+ | `is_reg` | `false` | - | - | o |
184
+
185
+ First, note that `image_dir` must be a path under which the image files are placed directly. In the traditional DreamBooth method images had to be placed in subdirectories, so this is not compatible with that specification. Also, naming a folder something like "5_cat" will not apply the repeat count or the class name to the images. If you want to set these individually, you need to specify them explicitly with `num_repeats` and `class_tokens`.
186
+
187
+ * `image_dir`
188
+ * Specifies the path to the image directory. This is a required option.
189
+ * Images must be placed directly under the directory.
190
+ * `class_tokens`
191
+ * Sets the class tokens.
192
+ * Only used during training when a corresponding caption file does not exist. The determination of whether or not to use it is made on a per-image basis. If `class_tokens` is not specified and a caption file is not found, an error will occur.
193
+ * `cache_info`
194
+ * Specifies whether to cache the image size and caption. If not specified, it is set to `false`. The cache is saved in `metadata_cache.json` in `image_dir`.
195
+ * Caching speeds up the loading of the dataset after the first time. It is effective when dealing with thousands of images or more.
196
+ * `is_reg`
197
+ * Specifies whether the images in the subset are regularization images. If not specified, it is set to `false`, meaning that the images are not regularization images.
198
+
199
+ ### Fine-tuning method specific options
200
+
201
+ Fine-tuning-specific options exist only as options for subsets.
202
+
203
+ #### Subset-specific options
204
+
205
+ These options are related to the configuration of the fine-tuning method's subsets.
206
+
207
+ | Option name | Example setting | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
208
+ | ---- | ---- | ---- | ---- | ---- |
209
+ | `image_dir` | `'C:\hoge'` | - | - | o |
210
+ | `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o (required) |
211
+
212
+ * `image_dir`
213
+ * Specify the path to the image directory. Unlike the DreamBooth method, specifying it is not mandatory, but it is recommended to do so.
214
+ * It does not need to be specified if the metadata file was generated with the `--full_path` option on the command line.
215
+ * The images must be placed directly under the directory.
216
+ * `metadata_file`
217
+ * Specify the path to the metadata file used for the subset. This is a required option.
218
+ * It is equivalent to the command-line argument `--in_json`.
219
+ * Due to the specification that a metadata file must be specified for each subset, it is recommended to avoid creating a metadata file with images from different directories as a single metadata file. It is strongly recommended to prepare a separate metadata file for each image directory and register them as separate subsets.
220
+
221
+ ### Options available when caption dropout method can be used
222
+
223
+ The options related to caption dropout exist only as options for subsets. Regardless of whether the DreamBooth method or the fine-tuning method is used, they can be specified as long as the training method supports caption dropout.
224
+
225
+ #### Subset-specific options
226
+
227
+ Options related to the setting of subsets that caption dropout can be used for.
228
+
229
+ | Option Name | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
230
+ | ---- | ---- | ---- | ---- |
231
+ | `caption_dropout_every_n_epochs` | o | o | o |
232
+ | `caption_dropout_rate` | o | o | o |
233
+ | `caption_tag_dropout_rate` | o | o | o |
234
+
235
+ ## Behavior when there are duplicate subsets
236
+
237
+ In the case of the DreamBooth dataset, if there are multiple `image_dir` directories with the same content, they are considered to be duplicate subsets. For the fine-tuning dataset, if there are multiple `metadata_file` files with the same content, they are considered to be duplicate subsets. If duplicate subsets exist in the dataset, subsequent subsets will be ignored.
238
+
239
+ However, if they belong to different datasets, they are not considered duplicates. For example, if you have subsets with the same `image_dir` in different datasets, they will not be considered duplicates. This is useful when you want to train with the same image but with different resolutions.
240
+
241
+ ```toml
242
+ # If data sets exist separately, they are not considered duplicates and are both used for training.
243
+
244
+ [[datasets]]
245
+ resolution = 512
246
+
247
+ [[datasets.subsets]]
248
+ image_dir = 'C:\hoge'
249
+
250
+ [[datasets]]
251
+ resolution = 768
252
+
253
+ [[datasets.subsets]]
254
+ image_dir = 'C:\hoge'
255
+ ```
256
+
257
+ ## Command Line Argument and Configuration File
258
+
259
+ There are options in the configuration file that have overlapping roles with command line argument options.
260
+
261
+ The following command line argument options are ignored if a configuration file is passed:
262
+
263
+ * `--train_data_dir`
264
+ * `--reg_data_dir`
265
+ * `--in_json`
266
+
267
+ For the following command line argument options, if both the command line argument and the corresponding configuration file option are specified at the same time, the value in the configuration file takes precedence. Unless noted otherwise, the option names are the same in both.
268
+
269
+ | Command Line Argument Option | Prioritized Configuration File Option |
270
+ | ------------------------------- | ------------------------------------- |
271
+ | `--bucket_no_upscale` | |
272
+ | `--bucket_reso_steps` | |
273
+ | `--caption_dropout_every_n_epochs` | |
274
+ | `--caption_dropout_rate` | |
275
+ | `--caption_extension` | |
276
+ | `--caption_tag_dropout_rate` | |
277
+ | `--color_aug` | |
278
+ | `--dataset_repeats` | `num_repeats` |
279
+ | `--enable_bucket` | |
280
+ | `--face_crop_aug_range` | |
281
+ | `--flip_aug` | |
282
+ | `--keep_tokens` | |
283
+ | `--min_bucket_reso` | |
284
+ | `--random_crop` | |
285
+ | `--resolution` | |
286
+ | `--shuffle_caption` | |
287
+ | `--train_batch_size` | `batch_size` |
288
+
289
+ ## Error Guide
290
+
291
+ Currently, an external library is used to check whether the configuration file is written correctly, but the checking is not fully polished, so the error messages can be hard to understand. We plan to improve this in the future.
292
+
293
+ As a temporary measure, we will list common errors and their solutions. If you encounter an error even though it should be correct or if the error content is not understandable, please contact us as it may be a bug.
294
+
295
+ * `voluptuous.error.MultipleInvalid: required key not provided @ ...`: This error occurs when a required option is not provided. It is highly likely that you forgot to specify the option or misspelled the option name.
296
+ * The error location is indicated by `...` in the error message. For example, if you encounter an error like `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']`, it means that the `image_dir` option does not exist in the 0th `subsets` of the 0th `datasets` setting.
297
+ * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: This error occurs when the format of the specified value is incorrect. The `int` part changes depending on the target option. The example settings in this README may be helpful.
298
+ * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: This error occurs when there is an option name that is not supported. It is highly likely that you misspelled the option name or mistakenly included it.
299
+
300
+ ## Miscellaneous
301
+
302
+ ### Multi-line captions
303
+
304
+ By setting `enable_wildcard = true`, multiple-line captions are also enabled. If the caption file consists of multiple lines, one line is randomly selected as the caption.
305
+
306
+ ```txt
307
+ 1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
308
+ a girl with a microphone standing on a stage
309
+ detailed digital art of a girl with a microphone on a stage
310
+ ```
311
+
312
+ It can be combined with wildcard notation.
313
+
314
+ In metadata files, you can also specify multiple-line captions. In the `.json` metadata file, use `\n` to represent a line break. If the caption file consists of multiple lines, `merge_captions_to_metadata.py` will create a metadata file in this format.
315
+
316
+ The tags in the metadata (`tags`) are added to each line of the caption.
317
+
318
+ ```json
319
+ {
320
+ "/path/to/image.png": {
321
+ "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
322
+ "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
323
+ },
324
+ ...
325
+ }
326
+ ```
327
+
328
+ In this case, the actual caption will be `a cartoon of a frog with the word frog on it, open mouth, simple background ...`, `test multiline caption1, open mouth, simple background ...`, `test multiline caption2, open mouth, simple background ...`, etc.
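+
+ As a rough sketch (a hypothetical helper, not the repository's actual code), one caption line is chosen at random and the tags are appended to it:
+
+ ```python
+ import random
+
+ def build_caption(entry, separator=", "):
+     # pick one of the caption lines at random, then append the tags
+     caption = random.choice(entry["caption"].split("\n"))
+     if entry.get("tags"):
+         caption = caption + separator + entry["tags"]
+     return caption
+
+ entry = {
+     "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1",
+     "tags": "open mouth, simple background, standing",
+ }
+ print(build_caption(entry))
+ # e.g. "test multiline caption1, open mouth, simple background, standing"
+ ```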
329
+
330
+ ### Example of configuration file : `secondary_separator`, wildcard notation, `keep_tokens_separator`, etc.
331
+
332
+ ```toml
333
+ [general]
334
+ flip_aug = true
335
+ color_aug = false
336
+ resolution = [1024, 1024]
337
+
338
+ [[datasets]]
339
+ batch_size = 6
340
+ enable_bucket = true
341
+ bucket_no_upscale = true
342
+ caption_extension = ".txt"
343
+ keep_tokens_separator= "|||"
344
+ shuffle_caption = true
345
+ caption_tag_dropout_rate = 0.1
346
+ secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
347
+ enable_wildcard = true # 同上 / same as above
348
+
349
+ [[datasets.subsets]]
350
+ image_dir = "/path/to/image_dir"
351
+ num_repeats = 1
352
+
353
+ # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
354
+ caption_prefix = "1girl, hatsune miku, vocaloid |||"
355
+
356
+ # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
357
+ # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
358
+ caption_suffix = ", anime screencap ||| masterpiece, rating: general"
359
+ ```
360
+
361
+ ### Example of caption, secondary_separator notation: `secondary_separator = ";;;"`
362
+
363
+ ```txt
364
+ 1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
365
+ ```
366
+ The part `sky;;;cloud;;;day` is replaced with `sky,cloud,day` without shuffling or dropping. When shuffling and dropping are enabled, it is processed as a whole (as one tag). For example, it becomes `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (shuffled) or `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (dropped).
367
+
368
+ ### Example of caption, enable_wildcard notation: `enable_wildcard = true`
369
+
370
+ ```txt
371
+ 1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
372
+ ```
373
+ `simple` or `white` is randomly selected, and it becomes `simple background` or `white background`.
374
+
375
+ ```txt
376
+ 1girl, hatsune miku, vocaloid, {{retro style}}
377
+ ```
378
+ If you want to include `{` or `}` in the tag string, double them like `{{` or `}}` (in this example, the actual caption used for training is `{retro style}`).
379
+
380
+ ### Example of caption, `keep_tokens_separator` notation: `keep_tokens_separator = "|||"`
381
+
382
+ ```txt
383
+ 1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
384
+ ```
385
+ It becomes `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` or `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` etc.
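+
+ For illustration, the following is a small, self-contained sketch of how a caption using `keep_tokens_separator` and `secondary_separator` could be processed. It follows the behaviour described above but is not the training scripts' actual implementation:
+
+ ```python
+ import random
+
+ def process_caption(caption, keep_sep="|||", secondary_sep=";;;", sep=", "):
+     # split into (fixed head, shuffled middle, fixed tail); assumes two keep_sep occurrences
+     head, middle, tail = [p.strip(" ,") for p in caption.split(keep_sep)]
+     tags = [t.strip() for t in middle.split(",") if t.strip()]
+     random.shuffle(tags)                                   # shuffle_caption = true
+     tags = [t.replace(secondary_sep, ",") for t in tags]   # ;;; groups stay together as one tag
+     return sep.join([head, sep.join(tags), tail])
+
+ print(process_caption(
+     "1girl, hatsune miku, vocaloid ||| stage, microphone, sky;;;cloud;;;day ||| best quality, rating: general"
+ ))
+ # e.g. "1girl, hatsune miku, vocaloid, microphone, sky,cloud,day, stage, best quality, rating: general"
+ ```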
386
+
config_README-ja.md ADDED
@@ -0,0 +1,388 @@
1
+ `--dataset_config` で渡すことができる設定ファイルに関する説明です。
2
+
3
+ ## 概要
4
+
5
+ 設定ファイルを渡すことにより、ユーザが細かい設定を行えるようにします。
6
+
7
+ * 複数のデータセットが設定可能になります
8
+ * 例えば `resolution` をデータセットごとに設定して、それらを混合して学習できます。
9
+ * DreamBooth の手法と fine tuning の手法の両方に対応している学習方法では、DreamBooth 方式と fine tuning 方式のデータセットを混合することが可能です。
10
+ * サブセットごとに設定を変更することが可能になります
11
+ * データセットを画像ディレクトリ別またはメタデータ別に分割したものがサブセットです。いくつかのサブセットが集まってデータセットを構成します。
12
+ * `keep_tokens` や `flip_aug` 等のオプションはサブセットごとに設定可能です。一方、`resolution` や `batch_size` といったオプションはデータセットごとに設定可能で、同じデータセットに属するサブセットでは値が共通になります。詳しくは後述します。
13
+
14
+ 設定ファイルの形式は JSON か TOML を利用できます。記述のしやすさを考えると [TOML](https://toml.io/ja/v1.0.0-rc.2) を利用するのがオススメです。以下、TOML の利用を前提に説明します。
15
+
16
+ TOML で記述した設定ファイルの例です。
17
+
18
+ ```toml
19
+ [general]
20
+ shuffle_caption = true
21
+ caption_extension = '.txt'
22
+ keep_tokens = 1
23
+
24
+ # これは DreamBooth 方式のデータセット
25
+ [[datasets]]
26
+ resolution = 512
27
+ batch_size = 4
28
+ keep_tokens = 2
29
+
30
+ [[datasets.subsets]]
31
+ image_dir = 'C:\hoge'
32
+ class_tokens = 'hoge girl'
33
+ # このサブセットは keep_tokens = 2 (所属する datasets の値が使われる)
34
+
35
+ [[datasets.subsets]]
36
+ image_dir = 'C:\fuga'
37
+ class_tokens = 'fuga boy'
38
+ keep_tokens = 3
39
+
40
+ [[datasets.subsets]]
41
+ is_reg = true
42
+ image_dir = 'C:\reg'
43
+ class_tokens = 'human'
44
+ keep_tokens = 1
45
+
46
+ # これは fine tuning 方式のデータセット
47
+ [[datasets]]
48
+ resolution = [768, 768]
49
+ batch_size = 2
50
+
51
+ [[datasets.subsets]]
52
+ image_dir = 'C:\piyo'
53
+ metadata_file = 'C:\piyo\piyo_md.json'
54
+ # このサブセットは keep_tokens = 1 (general の値が使われる)
55
+ ```
56
+
57
+ この例では、3 つのディレクトリを DreamBooth 方式のデータセットとして 512x512 (batch size 4) で学習させ、1 つのディレクトリを fine tuning 方式のデータセットとして 768x768 (batch size 2) で学習させることになります。
58
+
59
+ ## データセット・サブセットに関する設定
60
+
61
+ データセット・サブセットに関する設定は、登録可能な箇所がいくつかに分かれています。
62
+
63
+ * `[general]`
64
+ * 全データセットまたは全サブセットに適用されるオプションを指定する箇所です。
65
+ * データセットごとの設定及びサブセットごとの設定に同名のオプションが存在していた場合には、データセット・サブセットごとの設定が優先されます。
66
+ * `[[datasets]]`
67
+ * `datasets` はデータセットに関する設定の登録箇所になります。各データセットに個別に適用されるオプションを指定する箇所です。
68
+ * サブセットごとの設定が存在していた場合には、サブセットごとの設定が優先されます。
69
+ * `[[datasets.subsets]]`
70
+ * `datasets.subsets` はサブセットに関する設定の登録箇所になります。各サブセットに個別に適用されるオプションを指定する箇所です。
71
+
72
+ 先程の例における、画像ディレクトリと登録箇所の対応に関するイメージ図です。
73
+
74
+ ```
75
+ C:\
76
+ ├─ hoge -> [[datasets.subsets]] No.1 ┐ ┐
77
+ ├─ fuga -> [[datasets.subsets]] No.2 |-> [[datasets]] No.1 |-> [general]
78
+ ├─ reg -> [[datasets.subsets]] No.3 ┘ |
79
+ └─ piyo -> [[datasets.subsets]] No.4 --> [[datasets]] No.2 ┘
80
+ ```
81
+
82
+ 画像ディレクトリがそれぞれ1つの `[[datasets.subsets]]` に対応しています。そして `[[datasets.subsets]]` が1つ以上組み合わさって1つの `[[datasets]]` を構成します。`[general]` には全ての `[[datasets]]`, `[[datasets.subsets]]` が属します。
83
+
84
+ 登録箇所ごとに指定可能なオプションは異なりますが、同名のオプションが指定された場合は下位の登録箇所にある値が優先されます。先程の例の `keep_tokens` オプションの扱われ方を確認してもらうと理解しやすいかと思います。
85
+
86
+ 加えて、学習方法が対応している手法によっても指定可能なオプションが変化します。
87
+
88
+ * DreamBooth 方式専用のオプション
89
+ * fine tuning 方式専用のオプション
90
+ * caption dropout の手法が使える場合のオプション
91
+
92
+ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学習方法では、両者を併用することができます。
93
+ 併用する際の注意点として、DreamBooth 方式なのか fine tuning 方式なのかはデータセット単位で判別を行っているため、同じデータセット中に DreamBooth 方式のサブセットと fine tuning 方式のサブセットを混在させることはできません。
94
+ つまり、これらを併用したい場合には異なる方式のサブセットが異なるデータセットに所属するように設定する必要があります。
95
+
96
+ プログラムの挙動としては、後述する `metadata_file` オプションが存在していたら fine tuning 方式のサブセットだと判断します。
97
+ そのため、同一のデータセットに所属するサブセットについて言うと、「全てが `metadata_file` オプションを持つ」か「全てが `metadata_file` オプションを持たない」かのどちらかになっていれば問題ありません。
98
+
99
+ 以下、利用可能なオプションを説明します。コマンドライン引数と名称が同一のオプションについては、基本的に説明を割愛します。他の README を参照してください。
100
+
101
+ ### 全学習方法で共通のオプション
102
+
103
+ 学習方法によらずに指定可能なオプションです。
104
+
105
+ #### データセット向けオプション
106
+
107
+ データセットの設定に関わるオプションです。`datasets.subsets` には記述できません。
108
+
109
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` |
110
+ | ---- | ---- | ---- | ---- |
111
+ | `batch_size` | `1` | o | o |
112
+ | `bucket_no_upscale` | `true` | o | o |
113
+ | `bucket_reso_steps` | `64` | o | o |
114
+ | `enable_bucket` | `true` | o | o |
115
+ | `max_bucket_reso` | `1024` | o | o |
116
+ | `min_bucket_reso` | `128` | o | o |
117
+ | `resolution` | `256`, `[512, 512]` | o | o |
118
+
119
+ * `batch_size`
120
+ * コマンドライン引数の `--train_batch_size` と同等です。
121
+ * `max_bucket_reso`, `min_bucket_reso`
122
+ * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。
123
+
124
+ これらの設定はデータセットごとに固定です。
125
+ つまり、データセットに所属するサブセットはこれらの設定を共有することになります。
126
+ 例えば解像度が異なるデータセットを用意したい場合は、上に挙げた例のように別々のデータセットとして定義すれば別々の解像度を設定可能です。
127
+
128
+ #### サブセット向けオプション
129
+
130
+ サブセットの設定に関わるオプションです。
131
+
132
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
133
+ | ---- | ---- | ---- | ---- | ---- |
134
+ | `color_aug` | `false` | o | o | o |
135
+ | `face_crop_aug_range` | `[1.0, 3.0]` | o | o | o |
136
+ | `flip_aug` | `true` | o | o | o |
137
+ | `keep_tokens` | `2` | o | o | o |
138
+ | `num_repeats` | `10` | o | o | o |
139
+ | `random_crop` | `false` | o | o | o |
140
+ | `shuffle_caption` | `true` | o | o | o |
141
+ | `caption_prefix` | `“masterpiece, best quality, ”` | o | o | o |
142
+ | `caption_suffix` | `“, from side”` | o | o | o |
143
+ | `caption_separator` | (通常は設定しません) | o | o | o |
144
+ | `keep_tokens_separator` | `“|||”` | o | o | o |
145
+ | `secondary_separator` | `“;;;”` | o | o | o |
146
+ | `enable_wildcard` | `true` | o | o | o |
147
+
148
+ * `num_repeats`
149
+ * サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
150
+ * `caption_prefix`, `caption_suffix`
151
+ * キャプションの前、後に付与する文字列を指定します。シャッフルはこれらの文字列を含めた状態で行われます。`keep_tokens` を指定する場合には注意してください。
152
+
153
+ * `caption_separator`
154
+ * タグを区切る文字列を指定します。デフォルトは `,` です。このオプションは通常は設定する必要はありません。
155
+
156
+ * `keep_tokens_separator`
157
+ * キャプションで固定したい部分を区切る文字列を指定します。たとえば `aaa, bbb ||| ccc, ddd, eee, fff ||| ggg, hhh` のように指定すると、`aaa, bbb` と `ggg, hhh` の部分はシャッフル、drop されず残ります。間のカンマは不要です。結果としてプロンプトは `aaa, bbb, eee, ccc, fff, ggg, hhh` や `aaa, bbb, fff, ccc, eee, ggg, hhh` などになります。
158
+
159
+ * `secondary_separator`
160
+ * 追加の区切り文字を指定します。この区切り文字で区切られた部分は一つのタグとして扱われ、シャッフル、drop されます。その後、`caption_separator` に置き換えられます。たとえば `aaa;;;bbb;;;ccc` のように指定すると、`aaa,bbb,ccc` に置き換えられるか、まとめて drop されます。
161
+
162
+ * `enable_wildcard`
163
+ * ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
164
+
165
+ ### DreamBooth 方式専用のオプション
166
+
167
+ DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。
168
+
169
+ #### サブセット向けオプション
170
+
171
+ DreamBooth 方式のサブセットの設定に関わるオプションです。
172
+
173
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
174
+ | ---- | ---- | ---- | ---- | ---- |
175
+ | `image_dir` | `‘C:\hoge’` | - | - | o(必須) |
176
+ | `caption_extension` | `".txt"` | o | o | o |
177
+ | `class_tokens` | `“sks girl”` | - | - | o |
178
+ | `cache_info` | `false` | o | o | o |
179
+ | `is_reg` | `false` | - | - | o |
180
+
181
+ まず注意点として、 `image_dir` には画像ファイルが直下に置かれているパスを指定する必要があります。従来の DreamBooth の手法ではサブディレクトリに画像を置く必要がありましたが、そちらとは仕様に互換性がありません。また、`5_cat` のようなフォルダ名にしても、画像の繰り返し回数とクラス名は反映されません。これらを個別に設定したい場合、`num_repeats` と `class_tokens` で明示的に指定する必要があることに注意してください。
182
+
183
+ * `image_dir`
184
+ * 画像ディレクトリのパスを指定します。指定必須オプションです。
185
+ * 画像はディレクトリ直下に置かれている必要があります。
186
+ * `class_tokens`
187
+ * クラストークンを設定します。
188
+ * 画像に対応する caption ファイルが存在しない場合にのみ学習時に利用されます。利用するかどうかの判定は画像ごとに行います。`class_tokens` を指定しなかった場合に caption ファイルも見つからなかった場合にはエラーになります。
189
+ * `cache_info`
190
+ * 画像サイズ、キャプションをキャッシュするかどうかを指定します。指定しなかった場合は `false` になります。キャッシュは `image_dir` に `metadata_cache.json` というファイル名で保存されます。
191
+ * キャッシュを行うと、二回目以降のデータセット読み込みが高速化されます。数千枚以上の画像を扱う場合には有効です。
192
+ * `is_reg`
193
+ * サブセットの画像が正規化用かどうかを指定します。指定しなかった場合は `false` として、つまり正規化画像ではないとして扱います。
194
+
195
+ ### fine tuning 方式専用のオプション
196
+
197
+ fine tuning 方式のオプションは、サブセット向けオプションのみ存在します。
198
+
199
+ #### サブセット向けオプション
200
+
201
+ fine tuning 方式のサブセットの設定に関わるオプションです。
202
+
203
+ | オプション名 | 設定例 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
204
+ | ---- | ---- | ---- | ---- | ---- |
205
+ | `image_dir` | `‘C:\hoge’` | - | - | o |
206
+ | `metadata_file` | `'C:\piyo\piyo_md.json'` | - | - | o(必須) |
207
+
208
+ * `image_dir`
209
+ * 画像ディレクトリのパスを指定します。DreamBooth の手法の方とは異なり指定は必須ではありませんが、設定することを推奨します。
210
+ * 指定する必要がない状況としては、メタデータファイルの生成時に `--full_path` を付与して実行していた場合です。
211
+ * 画像はディレクトリ直下に置かれている必要があります。
212
+ * `metadata_file`
213
+ * サブセットで利用されるメタデータファイルのパスを指定します。指定必須オプションです。
214
+ * コマンドライン引数の `--in_json` と同等です。
215
+ * サブセットごとにメタデータファイルを指定する必要がある仕様上、ディレクトリを跨いだメタデータを1つのメタデータファイルとして作成することは避けた方が良いでしょう。画像ディレクトリごとにメタデータファイルを用意し、それらを別々のサブセットとして登録することを強く推奨します。
216
+
217
+ ### caption dropout の手法が使える場合に指定可能なオプション
218
+
219
+ caption dropout の手法が使える場合のオプションは、サブセット向けオプションのみ存在します。
220
+ DreamBooth 方式か fine tuning 方式かに関わらず、caption dropout に対応している学習方法であれば指定可能です。
221
+
222
+ #### サブセット向けオプション
223
+
224
+ caption dropout が使えるサブセットの設定に関わるオプションです。
225
+
226
+ | オプション名 | `[general]` | `[[datasets]]` | `[[dataset.subsets]]` |
227
+ | ---- | ---- | ---- | ---- |
228
+ | `caption_dropout_every_n_epochs` | o | o | o |
229
+ | `caption_dropout_rate` | o | o | o |
230
+ | `caption_tag_dropout_rate` | o | o | o |
231
+
232
+ ## 重複したサブセットが存在する時の挙動
233
+
234
+ DreamBooth 方式のデータセットの場合、その中にある `image_dir` が同一のサブセットは重複していると見なされます。
235
+ fine tuning 方式のデータセットの場合は、その中にある `metadata_file` が同一のサブセットは重複していると見なされます。
236
+ データセット中に重複したサブセットが存在する場合、2個目以降は無視されます。
237
+
238
+ 一方、異なるデータセットに所属している場合は、重複しているとは見なされません。
239
+ 例えば、以下のように同一の `image_dir` を持つサブセットを別々のデータセットに入れた場合には、重複していないと見なします。
240
+ これは、同じ画像でも異なる解像度で学習したい場合に役立ちます。
241
+
242
+ ```toml
243
+ # 別々のデータセットに存在している場合は重複とは見なされず、両方とも学習に使われる
244
+
245
+ [[datasets]]
246
+ resolution = 512
247
+
248
+ [[datasets.subsets]]
249
+ image_dir = 'C:\hoge'
250
+
251
+ [[datasets]]
252
+ resolution = 768
253
+
254
+ [[datasets.subsets]]
255
+ image_dir = 'C:\hoge'
256
+ ```
257
+
258
+ ## コマンドライン引数との併用
259
+
260
+ 設定ファイルのオプションの中には、コマンドライン引数のオプションと役割が重複しているものがあります。
261
+
262
+ 以下に挙げるコマンドライン引数のオプションは、設定ファイルを渡した場合には無視されます。
263
+
264
+ * `--train_data_dir`
265
+ * `--reg_data_dir`
266
+ * `--in_json`
267
+
268
+ 以下に挙げるコマンドライン引数のオプションは、コマンドライン引数と設定ファイルで同時に指定された場合、コマンドライン引数の値よりも設定ファイルの値が優先されます。特に断りがなければ同名のオプションとなります。
269
+
270
+ | コマンドライン引数のオプション | 優先される設定ファイルのオプション |
271
+ | ---------------------------------- | ---------------------------------- |
272
+ | `--bucket_no_upscale` | |
273
+ | `--bucket_reso_steps` | |
274
+ | `--caption_dropout_every_n_epochs` | |
275
+ | `--caption_dropout_rate` | |
276
+ | `--caption_extension` | |
277
+ | `--caption_tag_dropout_rate` | |
278
+ | `--color_aug` | |
279
+ | `--dataset_repeats` | `num_repeats` |
280
+ | `--enable_bucket` | |
281
+ | `--face_crop_aug_range` | |
282
+ | `--flip_aug` | |
283
+ | `--keep_tokens` | |
284
+ | `--min_bucket_reso` | |
285
+ | `--random_crop` | |
286
+ | `--resolution` | |
287
+ | `--shuffle_caption` | |
288
+ | `--train_batch_size` | `batch_size` |
289
+
290
+ ## エラーの手引き
291
+
292
+ 現在、外部ライブラリを利用して設定ファイルの記述が正しいかどうかをチェックしているのですが、整備が行き届いておらずエラーメッセージがわかりづらいという問題があります。
293
+ 将来的にはこの問題の改善に取り組む予定です。
294
+
295
+ 次善策として、頻出のエラーとその対処法について載せておきます。
296
+ 正しいはずなのにエラーが出る場合、エラー内容がどうしても分からない場合は、バグかもしれないのでご連絡ください。
297
+
298
+ * `voluptuous.error.MultipleInvalid: required key not provided @ ...`: 指定必須のオプションが指定されていないというエラーです。指定を忘れているか、オプション名を間違って記述している可能性が高いです。
299
+ * `...` の箇所にはエラーが発生した場所が載っています。例えば `voluptuous.error.MultipleInvalid: required key not provided @ data['datasets'][0]['subsets'][0]['image_dir']` のようなエラーが出たら、0 番目の `datasets` 中の 0 番目の `subsets` の設定に `image_dir` が存在しないということになります。
300
+ * `voluptuous.error.MultipleInvalid: expected int for dictionary value @ ...`: 指定する値の形式が不正というエラーです。値の形式が間違っている可能性が高いです。`int` の部分は対象となるオプションによって変わります。この README に載っているオプションの「設定例」が役立つかもしれません。
301
+ * `voluptuous.error.MultipleInvalid: extra keys not allowed @ ...`: 対応していないオプション名が存在している場合に発生するエラーです。オプション名を間違って記述しているか、誤って紛れ込んでいる可能性が高いです。
302
+
303
+ ## その他
304
+
305
+ ### 複数行キャプション
306
+
307
+ `enable_wildcard = true` を設定することで、複数行キャプションも同時に有効になります。キャプションファイルが複数の行からなる場合、ランダムに一つの行が選ばれてキャプションとして利用されます。
308
+
309
+ ```txt
310
+ 1girl, hatsune miku, vocaloid, upper body, looking at viewer, microphone, stage
311
+ a girl with a microphone standing on a stage
312
+ detailed digital art of a girl with a microphone on a stage
313
+ ```
314
+
315
+ ワイルドカード記法と組み合わせることも可能です。
316
+
317
+ メタデータファイルでも同様に複数行キャプションを指定することができます。メタデータの .json 内には、`\n` を使って改行を表現してください。キャプションファイルが複数行からなる場合、`merge_captions_to_metadata.py` を使うと、この形式でメタデータファイルが作成されます。
318
+
319
+ メタデータのタグ (`tags`) は、キャプションの各行に追加されます。
320
+
321
+ ```json
322
+ {
323
+ "/path/to/image.png": {
324
+ "caption": "a cartoon of a frog with the word frog on it\ntest multiline caption1\ntest multiline caption2",
325
+ "tags": "open mouth, simple background, standing, no humans, animal, black background, frog, animal costume, animal focus"
326
+ },
327
+ ...
328
+ }
329
+ ```
330
+
331
+ この場合、実際のキャプションは `a cartoon of a frog with the word frog on it, open mouth, simple background ...` または `test multiline caption1, open mouth, simple background ...`、 `test multiline caption2, open mouth, simple background ...` 等になります。
332
+
333
+ ### 設定ファイルの記述例:追加の区切り文字、ワイルドカード記法、`keep_tokens_separator` 等
334
+
335
+ ```toml
336
+ [general]
337
+ flip_aug = true
338
+ color_aug = false
339
+ resolution = [1024, 1024]
340
+
341
+ [[datasets]]
342
+ batch_size = 6
343
+ enable_bucket = true
344
+ bucket_no_upscale = true
345
+ caption_extension = ".txt"
346
+ keep_tokens_separator= "|||"
347
+ shuffle_caption = true
348
+ caption_tag_dropout_rate = 0.1
349
+ secondary_separator = ";;;" # subset 側に書くこともできます / can be written in the subset side
350
+ enable_wildcard = true # 同上 / same as above
351
+
352
+ [[datasets.subsets]]
353
+ image_dir = "/path/to/image_dir"
354
+ num_repeats = 1
355
+
356
+ # ||| の前後はカンマは不要です(自動的に追加されます) / No comma is required before and after ||| (it is added automatically)
357
+ caption_prefix = "1girl, hatsune miku, vocaloid |||"
358
+
359
+ # ||| の後はシャッフル、drop されず残ります / After |||, it is not shuffled or dropped and remains
360
+ # 単純に文字列として連結されるので、カンマなどは自分で入れる必要があります / It is simply concatenated as a string, so you need to put commas yourself
361
+ caption_suffix = ", anime screencap ||| masterpiece, rating: general"
362
+ ```
363
+
364
+ ### キャプション記述例、secondary_separator 記法:`secondary_separator = ";;;"` の場合
365
+
366
+ ```txt
367
+ 1girl, hatsune miku, vocaloid, upper body, looking at viewer, sky;;;cloud;;;day, outdoors
368
+ ```
369
+ `sky;;;cloud;;;day` の部分はシャッフル、drop されず `sky,cloud,day` に置換されます。シャッフル、drop が有効な場合、まとめて(一つのタグとして)処理されます。つまり `vocaloid, 1girl, upper body, sky,cloud,day, outdoors, hatsune miku` (シャッフル)や `vocaloid, 1girl, outdoors, looking at viewer, upper body, hatsune miku` (drop されたケース)などになります。
370
+
371
+ ### キャプション記述例、ワイルドカード記法: `enable_wildcard = true` の場合
372
+
373
+ ```txt
374
+ 1girl, hatsune miku, vocaloid, upper body, looking at viewer, {simple|white} background
375
+ ```
376
+ ランダムに `simple` または `white` が選ばれ、`simple background` または `white background` になります。
377
+
378
+ ```txt
379
+ 1girl, hatsune miku, vocaloid, {{retro style}}
380
+ ```
381
+ タグ文字列に `{` や `}` そのものを含めたい場合は `{{` や `}}` のように二つ重ねてください(この例では実際に学習に用いられるキャプションは `{retro style}` になります)。
382
+
383
+ ### キャプション記述例、`keep_tokens_separator` 記法: `keep_tokens_separator = "|||"` の場合
384
+
385
+ ```txt
386
+ 1girl, hatsune miku, vocaloid ||| stage, microphone, white shirt, smile ||| best quality, rating: general
387
+ ```
388
+ `1girl, hatsune miku, vocaloid, microphone, stage, white shirt, best quality, rating: general` や `1girl, hatsune miku, vocaloid, white shirt, smile, stage, microphone, best quality, rating: general` などになります。
config_util.py ADDED
@@ -0,0 +1,721 @@
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import (
14
+ List,
15
+ Optional,
16
+ Sequence,
17
+ Tuple,
18
+ Union,
19
+ )
20
+
21
+ import toml
22
+ import voluptuous
23
+ from voluptuous import (
24
+ Any,
25
+ ExactSequence,
26
+ MultipleInvalid,
27
+ Object,
28
+ Required,
29
+ Schema,
30
+ )
31
+ from transformers import CLIPTokenizer
32
+
33
+ from . import train_util
34
+ from .train_util import (
35
+ DreamBoothSubset,
36
+ FineTuningSubset,
37
+ ControlNetSubset,
38
+ DreamBoothDataset,
39
+ FineTuningDataset,
40
+ ControlNetDataset,
41
+ DatasetGroup,
42
+ )
43
+ from .utils import setup_logging
44
+
45
+ setup_logging()
46
+ import logging
47
+
48
+ logger = logging.getLogger(__name__)
49
+
50
+
51
+ def add_config_arguments(parser: argparse.ArgumentParser):
52
+ parser.add_argument(
53
+ "--dataset_config", type=Path, default=None, help="config file for detail settings / 詳細な設定用の設定ファイル"
54
+ )
55
+
56
+
57
+ # TODO: inherit Params class in Subset, Dataset
58
+
59
+
60
+ @dataclass
61
+ class BaseSubsetParams:
62
+ image_dir: Optional[str] = None
63
+ num_repeats: int = 1
64
+ shuffle_caption: bool = False
65
+ caption_separator: str = ","
66
+ keep_tokens: int = 0
67
+ keep_tokens_separator: Optional[str] = None
68
+ secondary_separator: Optional[str] = None
69
+ enable_wildcard: bool = False
70
+ color_aug: bool = False
71
+ flip_aug: bool = False
72
+ face_crop_aug_range: Optional[Tuple[float, float]] = None
73
+ random_crop: bool = False
74
+ caption_prefix: Optional[str] = None
75
+ caption_suffix: Optional[str] = None
76
+ caption_dropout_rate: float = 0.0
77
+ caption_dropout_every_n_epochs: int = 0
78
+ caption_tag_dropout_rate: float = 0.0
79
+ token_warmup_min: int = 1
80
+ token_warmup_step: float = 0
81
+
82
+
83
+ @dataclass
84
+ class DreamBoothSubsetParams(BaseSubsetParams):
85
+ is_reg: bool = False
86
+ class_tokens: Optional[str] = None
87
+ caption_extension: str = ".caption"
88
+ cache_info: bool = False
89
+ alpha_mask: bool = False
90
+
91
+
92
+ @dataclass
93
+ class FineTuningSubsetParams(BaseSubsetParams):
94
+ metadata_file: Optional[str] = None
95
+ alpha_mask: bool = False
96
+
97
+
98
+ @dataclass
99
+ class ControlNetSubsetParams(BaseSubsetParams):
100
+ conditioning_data_dir: str = None
101
+ caption_extension: str = ".caption"
102
+ cache_info: bool = False
103
+
104
+
105
+ @dataclass
106
+ class BaseDatasetParams:
107
+ tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
108
+ max_token_length: int = None
109
+ resolution: Optional[Tuple[int, int]] = None
110
+ network_multiplier: float = 1.0
111
+ debug_dataset: bool = False
112
+
113
+
114
+ @dataclass
115
+ class DreamBoothDatasetParams(BaseDatasetParams):
116
+ batch_size: int = 1
117
+ enable_bucket: bool = False
118
+ min_bucket_reso: int = 256
119
+ max_bucket_reso: int = 1024
120
+ bucket_reso_steps: int = 64
121
+ bucket_no_upscale: bool = False
122
+ prior_loss_weight: float = 1.0
123
+
124
+
125
+ @dataclass
126
+ class FineTuningDatasetParams(BaseDatasetParams):
127
+ batch_size: int = 1
128
+ enable_bucket: bool = False
129
+ min_bucket_reso: int = 256
130
+ max_bucket_reso: int = 1024
131
+ bucket_reso_steps: int = 64
132
+ bucket_no_upscale: bool = False
133
+
134
+
135
+ @dataclass
136
+ class ControlNetDatasetParams(BaseDatasetParams):
137
+ batch_size: int = 1
138
+ enable_bucket: bool = False
139
+ min_bucket_reso: int = 256
140
+ max_bucket_reso: int = 1024
141
+ bucket_reso_steps: int = 64
142
+ bucket_no_upscale: bool = False
143
+
144
+
145
+ @dataclass
146
+ class SubsetBlueprint:
147
+ params: Union[DreamBoothSubsetParams, FineTuningSubsetParams]
148
+
149
+
150
+ @dataclass
151
+ class DatasetBlueprint:
152
+ is_dreambooth: bool
153
+ is_controlnet: bool
154
+ params: Union[DreamBoothDatasetParams, FineTuningDatasetParams]
155
+ subsets: Sequence[SubsetBlueprint]
156
+
157
+
158
+ @dataclass
159
+ class DatasetGroupBlueprint:
160
+ datasets: Sequence[DatasetBlueprint]
161
+
162
+
163
+ @dataclass
164
+ class Blueprint:
165
+ dataset_group: DatasetGroupBlueprint
166
+
167
+
168
+ class ConfigSanitizer:
169
+ # @curry
170
+ @staticmethod
171
+ def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
172
+ Schema(ExactSequence([klass, klass]))(value)
173
+ return tuple(value)
174
+
175
+ # @curry
176
+ @staticmethod
177
+ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
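+ # accepts either a scalar or a two-element sequence and normalizes it to a 2-tuple,
+ # e.g. resolution = 256 -> (256, 256), resolution = [512, 768] -> (512, 768)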
178
+ Schema(Any(klass, ExactSequence([klass, klass])))(value)
179
+ try:
180
+ Schema(klass)(value)
181
+ return (value, value)
182
+ except:
183
+ return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
184
+
185
+ # subset schema
186
+ SUBSET_ASCENDABLE_SCHEMA = {
187
+ "color_aug": bool,
188
+ "face_crop_aug_range": functools.partial(__validate_and_convert_twodim.__func__, float),
189
+ "flip_aug": bool,
190
+ "num_repeats": int,
191
+ "random_crop": bool,
192
+ "shuffle_caption": bool,
193
+ "keep_tokens": int,
194
+ "keep_tokens_separator": str,
195
+ "secondary_separator": str,
196
+ "caption_separator": str,
197
+ "enable_wildcard": bool,
198
+ "token_warmup_min": int,
199
+ "token_warmup_step": Any(float, int),
200
+ "caption_prefix": str,
201
+ "caption_suffix": str,
202
+ }
203
+ # DO means DropOut
204
+ DO_SUBSET_ASCENDABLE_SCHEMA = {
205
+ "caption_dropout_every_n_epochs": int,
206
+ "caption_dropout_rate": Any(float, int),
207
+ "caption_tag_dropout_rate": Any(float, int),
208
+ }
209
+ # DB means DreamBooth
210
+ DB_SUBSET_ASCENDABLE_SCHEMA = {
211
+ "caption_extension": str,
212
+ "class_tokens": str,
213
+ "cache_info": bool,
214
+ }
215
+ DB_SUBSET_DISTINCT_SCHEMA = {
216
+ Required("image_dir"): str,
217
+ "is_reg": bool,
218
+ "alpha_mask": bool,
219
+ }
220
+ # FT means FineTuning
221
+ FT_SUBSET_DISTINCT_SCHEMA = {
222
+ Required("metadata_file"): str,
223
+ "image_dir": str,
224
+ "alpha_mask": bool,
225
+ }
226
+ CN_SUBSET_ASCENDABLE_SCHEMA = {
227
+ "caption_extension": str,
228
+ "cache_info": bool,
229
+ }
230
+ CN_SUBSET_DISTINCT_SCHEMA = {
231
+ Required("image_dir"): str,
232
+ Required("conditioning_data_dir"): str,
233
+ }
234
+
235
+ # datasets schema
236
+ DATASET_ASCENDABLE_SCHEMA = {
237
+ "batch_size": int,
238
+ "bucket_no_upscale": bool,
239
+ "bucket_reso_steps": int,
240
+ "enable_bucket": bool,
241
+ "max_bucket_reso": int,
242
+ "min_bucket_reso": int,
243
+ "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
244
+ "network_multiplier": float,
245
+ }
246
+
247
+ # options handled by argparse but not handled by user config
248
+ ARGPARSE_SPECIFIC_SCHEMA = {
249
+ "debug_dataset": bool,
250
+ "max_token_length": Any(None, int),
251
+ "prior_loss_weight": Any(float, int),
252
+ }
253
+ # for handling default None value of argparse
254
+ ARGPARSE_NULLABLE_OPTNAMES = [
255
+ "face_crop_aug_range",
256
+ "resolution",
257
+ ]
258
+ # prepare map because option name may differ among argparse and user config
259
+ ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
260
+ "train_batch_size": "batch_size",
261
+ "dataset_repeats": "num_repeats",
262
+ }
263
+
264
+ def __init__(self, support_dreambooth: bool, support_finetuning: bool, support_controlnet: bool, support_dropout: bool) -> None:
265
+ assert support_dreambooth or support_finetuning or support_controlnet, (
266
+ "Neither DreamBooth mode nor fine tuning mode nor controlnet mode specified. Please specify one mode or more."
267
+ + " / DreamBooth モードか fine tuning モードか controlnet モードのどれも指定されていません。1つ以上指定してください。"
268
+ )
269
+
270
+ self.db_subset_schema = self.__merge_dict(
271
+ self.SUBSET_ASCENDABLE_SCHEMA,
272
+ self.DB_SUBSET_DISTINCT_SCHEMA,
273
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
274
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
275
+ )
276
+
277
+ self.ft_subset_schema = self.__merge_dict(
278
+ self.SUBSET_ASCENDABLE_SCHEMA,
279
+ self.FT_SUBSET_DISTINCT_SCHEMA,
280
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
281
+ )
282
+
283
+ self.cn_subset_schema = self.__merge_dict(
284
+ self.SUBSET_ASCENDABLE_SCHEMA,
285
+ self.CN_SUBSET_DISTINCT_SCHEMA,
286
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
287
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
288
+ )
289
+
290
+ self.db_dataset_schema = self.__merge_dict(
291
+ self.DATASET_ASCENDABLE_SCHEMA,
292
+ self.SUBSET_ASCENDABLE_SCHEMA,
293
+ self.DB_SUBSET_ASCENDABLE_SCHEMA,
294
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
295
+ {"subsets": [self.db_subset_schema]},
296
+ )
297
+
298
+ self.ft_dataset_schema = self.__merge_dict(
299
+ self.DATASET_ASCENDABLE_SCHEMA,
300
+ self.SUBSET_ASCENDABLE_SCHEMA,
301
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
302
+ {"subsets": [self.ft_subset_schema]},
303
+ )
304
+
305
+ self.cn_dataset_schema = self.__merge_dict(
306
+ self.DATASET_ASCENDABLE_SCHEMA,
307
+ self.SUBSET_ASCENDABLE_SCHEMA,
308
+ self.CN_SUBSET_ASCENDABLE_SCHEMA,
309
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
310
+ {"subsets": [self.cn_subset_schema]},
311
+ )
312
+
313
+ if support_dreambooth and support_finetuning:
314
+
315
+ def validate_flex_dataset(dataset_config: dict):
316
+ subsets_config = dataset_config.get("subsets", [])
317
+
318
+ if support_controlnet and all(["conditioning_data_dir" in subset for subset in subsets_config]):
319
+ return Schema(self.cn_dataset_schema)(dataset_config)
320
+ # check dataset meets FT style
321
+ # NOTE: all FT subsets should have "metadata_file"
322
+ elif all(["metadata_file" in subset for subset in subsets_config]):
323
+ return Schema(self.ft_dataset_schema)(dataset_config)
324
+ # check dataset meets DB style
325
+ # NOTE: all DB subsets should have no "metadata_file"
326
+ elif all(["metadata_file" not in subset for subset in subsets_config]):
327
+ return Schema(self.db_dataset_schema)(dataset_config)
328
+ else:
329
+ raise voluptuous.Invalid(
330
+ "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuninのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
331
+ )
332
+
333
+ self.dataset_schema = validate_flex_dataset
334
+ elif support_dreambooth:
335
+ if support_controlnet:
336
+ self.dataset_schema = self.cn_dataset_schema
337
+ else:
338
+ self.dataset_schema = self.db_dataset_schema
339
+ elif support_finetuning:
340
+ self.dataset_schema = self.ft_dataset_schema
341
+ elif support_controlnet:
342
+ self.dataset_schema = self.cn_dataset_schema
343
+
344
+ self.general_schema = self.__merge_dict(
345
+ self.DATASET_ASCENDABLE_SCHEMA,
346
+ self.SUBSET_ASCENDABLE_SCHEMA,
347
+ self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
348
+ self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
349
+ self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
350
+ )
351
+
352
+ self.user_config_validator = Schema(
353
+ {
354
+ "general": self.general_schema,
355
+ "datasets": [self.dataset_schema],
356
+ }
357
+ )
358
+
359
+ self.argparse_schema = self.__merge_dict(
360
+ self.general_schema,
361
+ self.ARGPARSE_SPECIFIC_SCHEMA,
362
+ {optname: Any(None, self.general_schema[optname]) for optname in self.ARGPARSE_NULLABLE_OPTNAMES},
363
+ {a_name: self.general_schema[c_name] for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()},
364
+ )
365
+
366
+ self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)
367
+
368
+ def sanitize_user_config(self, user_config: dict) -> dict:
369
+ try:
370
+ return self.user_config_validator(user_config)
371
+ except MultipleInvalid:
372
+ # TODO: エラー発生時のメッセージをわかりやすくする
373
+ logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
374
+ raise
375
+
376
+ # NOTE: In nature, argument parser result is not needed to be sanitize
377
+ # However this will help us to detect program bug
378
+ def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
379
+ try:
380
+ return self.argparse_config_validator(argparse_namespace)
381
+ except MultipleInvalid:
382
+ # XXX: this should be a bug
383
+ logger.error(
384
+ "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
385
+ )
386
+ raise
387
+
388
+ # NOTE: value would be overwritten by latter dict if there is already the same key
389
+ @staticmethod
390
+ def __merge_dict(*dict_list: dict) -> dict:
391
+ merged = {}
392
+ for schema in dict_list:
393
+ # merged |= schema
394
+ for k, v in schema.items():
395
+ merged[k] = v
396
+ return merged
397
+
398
+
399
+ class BlueprintGenerator:
400
+ BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
401
+
402
+ def __init__(self, sanitizer: ConfigSanitizer):
403
+ self.sanitizer = sanitizer
404
+
405
+ # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
406
+ def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
407
+ sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
408
+ sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)
409
+
410
+ # convert argparse namespace to dict like config
411
+ # NOTE: it is ok to have extra entries in dict
412
+ optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
413
+ argparse_config = {
414
+ optname_map.get(optname, optname): value for optname, value in vars(sanitized_argparse_namespace).items()
415
+ }
416
+
417
+ general_config = sanitized_user_config.get("general", {})
418
+
419
+ dataset_blueprints = []
420
+ for dataset_config in sanitized_user_config.get("datasets", []):
421
+ # NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
422
+ subsets = dataset_config.get("subsets", [])
423
+ is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
424
+ is_controlnet = all(["conditioning_data_dir" in subset for subset in subsets])
425
+ if is_controlnet:
426
+ subset_params_klass = ControlNetSubsetParams
427
+ dataset_params_klass = ControlNetDatasetParams
428
+ elif is_dreambooth:
429
+ subset_params_klass = DreamBoothSubsetParams
430
+ dataset_params_klass = DreamBoothDatasetParams
431
+ else:
432
+ subset_params_klass = FineTuningSubsetParams
433
+ dataset_params_klass = FineTuningDatasetParams
434
+
435
+ subset_blueprints = []
436
+ for subset_config in subsets:
437
+ params = self.generate_params_by_fallbacks(
438
+ subset_params_klass, [subset_config, dataset_config, general_config, argparse_config, runtime_params]
439
+ )
440
+ subset_blueprints.append(SubsetBlueprint(params))
441
+
442
+ params = self.generate_params_by_fallbacks(
443
+ dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
444
+ )
445
+ dataset_blueprints.append(DatasetBlueprint(is_dreambooth, is_controlnet, params, subset_blueprints))
446
+
447
+ dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
448
+
449
+ return Blueprint(dataset_group_blueprint)
450
+
451
+ @staticmethod
452
+ def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
453
+ name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
454
+ search_value = BlueprintGenerator.search_value
455
+ default_params = asdict(param_klass())
456
+ param_names = default_params.keys()
457
+
458
+ params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}
459
+
460
+ return param_klass(**params)
461
+
462
+ @staticmethod
463
+ def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
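+ # returns the first non-None value for `key`, scanning the fallback dicts in the
+ # given order (most specific first); returns default_value if none of them sets it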
464
+ for cand in fallbacks:
465
+ value = cand.get(key)
466
+ if value is not None:
467
+ return value
468
+
469
+ return default_value
470
+
471
+
472
+ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
473
+ datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
474
+
475
+ for dataset_blueprint in dataset_group_blueprint.datasets:
476
+ if dataset_blueprint.is_controlnet:
477
+ subset_klass = ControlNetSubset
478
+ dataset_klass = ControlNetDataset
479
+ elif dataset_blueprint.is_dreambooth:
480
+ subset_klass = DreamBoothSubset
481
+ dataset_klass = DreamBoothDataset
482
+ else:
483
+ subset_klass = FineTuningSubset
484
+ dataset_klass = FineTuningDataset
485
+
486
+ subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
487
+ dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
488
+ datasets.append(dataset)
489
+
490
+ # print info
491
+ info = ""
492
+ for i, dataset in enumerate(datasets):
493
+ is_dreambooth = isinstance(dataset, DreamBoothDataset)
494
+ is_controlnet = isinstance(dataset, ControlNetDataset)
495
+ info += dedent(
496
+ f"""\
497
+ [Dataset {i}]
498
+ batch_size: {dataset.batch_size}
499
+ resolution: {(dataset.width, dataset.height)}
500
+ enable_bucket: {dataset.enable_bucket}
501
+ network_multiplier: {dataset.network_multiplier}
502
+ """
503
+ )
504
+
505
+ if dataset.enable_bucket:
506
+ info += indent(
507
+ dedent(
508
+ f"""\
509
+ min_bucket_reso: {dataset.min_bucket_reso}
510
+ max_bucket_reso: {dataset.max_bucket_reso}
511
+ bucket_reso_steps: {dataset.bucket_reso_steps}
512
+ bucket_no_upscale: {dataset.bucket_no_upscale}
513
+ \n"""
514
+ ),
515
+ " ",
516
+ )
517
+ else:
518
+ info += "\n"
519
+
520
+ for j, subset in enumerate(dataset.subsets):
521
+ info += indent(
522
+ dedent(
523
+ f"""\
524
+ [Subset {j} of Dataset {i}]
525
+ image_dir: "{subset.image_dir}"
526
+ image_count: {subset.img_count}
527
+ num_repeats: {subset.num_repeats}
528
+ shuffle_caption: {subset.shuffle_caption}
529
+ keep_tokens: {subset.keep_tokens}
530
+ keep_tokens_separator: {subset.keep_tokens_separator}
531
+ caption_separator: {subset.caption_separator}
532
+ secondary_separator: {subset.secondary_separator}
533
+ enable_wildcard: {subset.enable_wildcard}
534
+ caption_dropout_rate: {subset.caption_dropout_rate}
535
+ caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
536
+ caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
537
+ caption_prefix: {subset.caption_prefix}
538
+ caption_suffix: {subset.caption_suffix}
539
+ color_aug: {subset.color_aug}
540
+ flip_aug: {subset.flip_aug}
541
+ face_crop_aug_range: {subset.face_crop_aug_range}
542
+ random_crop: {subset.random_crop}
543
+ token_warmup_min: {subset.token_warmup_min},
544
+ token_warmup_step: {subset.token_warmup_step},
545
+ alpha_mask: {subset.alpha_mask},
546
+ """
547
+ ),
548
+ " ",
549
+ )
550
+
551
+ if is_dreambooth:
552
+ info += indent(
553
+ dedent(
554
+ f"""\
555
+ is_reg: {subset.is_reg}
556
+ class_tokens: {subset.class_tokens}
557
+ caption_extension: {subset.caption_extension}
558
+ \n"""
559
+ ),
560
+ " ",
561
+ )
562
+ elif not is_controlnet:
563
+ info += indent(
564
+ dedent(
565
+ f"""\
566
+ metadata_file: {subset.metadata_file}
567
+ \n"""
568
+ ),
569
+ " ",
570
+ )
571
+
572
+ logger.info(f"{info}")
573
+
574
+ # make buckets first because it determines the length of dataset
575
+ # and set the same seed for all datasets
576
+ seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
577
+ for i, dataset in enumerate(datasets):
578
+ logger.info(f"[Dataset {i}]")
579
+ dataset.make_buckets()
580
+ dataset.set_seed(seed)
581
+
582
+ return DatasetGroup(datasets)
583
+
584
+
585
+ def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
586
+ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
587
+ tokens = name.split("_")
588
+ try:
589
+ n_repeats = int(tokens[0])
590
+ except ValueError as e:
591
+ logger.warning(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
592
+ return 0, ""
593
+ caption_by_folder = "_".join(tokens[1:])
594
+ return n_repeats, caption_by_folder
595
+
596
+ def generate(base_dir: Optional[str], is_reg: bool):
597
+ if base_dir is None:
598
+ return []
599
+
600
+ base_dir: Path = Path(base_dir)
601
+ if not base_dir.is_dir():
602
+ return []
603
+
604
+ subsets_config = []
605
+ for subdir in base_dir.iterdir():
606
+ if not subdir.is_dir():
607
+ continue
608
+
609
+ num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
610
+ if num_repeats < 1:
611
+ continue
612
+
613
+ subset_config = {"image_dir": str(subdir), "num_repeats": num_repeats, "is_reg": is_reg, "class_tokens": class_tokens}
614
+ subsets_config.append(subset_config)
615
+
616
+ return subsets_config
617
+
618
+ subsets_config = []
619
+ subsets_config += generate(train_data_dir, False)
620
+ subsets_config += generate(reg_data_dir, True)
621
+
622
+ return subsets_config
623
+
624
+
625
+ def generate_controlnet_subsets_config_by_subdirs(
626
+ train_data_dir: Optional[str] = None, conditioning_data_dir: Optional[str] = None, caption_extension: str = ".txt"
627
+ ):
628
+ def generate(base_dir: Optional[str]):
629
+ if base_dir is None:
630
+ return []
631
+
632
+ base_dir: Path = Path(base_dir)
633
+ if not base_dir.is_dir():
634
+ return []
635
+
636
+ subsets_config = []
637
+ subset_config = {
638
+ "image_dir": train_data_dir,
639
+ "conditioning_data_dir": conditioning_data_dir,
640
+ "caption_extension": caption_extension,
641
+ "num_repeats": 1,
642
+ }
643
+ subsets_config.append(subset_config)
644
+
645
+ return subsets_config
646
+
647
+ subsets_config = []
648
+ subsets_config += generate(train_data_dir)
649
+
650
+ return subsets_config
651
+
652
+
653
+ def load_user_config(file: str) -> dict:
654
+ file: Path = Path(file)
655
+ if not file.is_file():
656
+ raise ValueError(f"file not found / ファイルが見つかりません: {file}")
657
+
658
+ if file.name.lower().endswith(".json"):
659
+ try:
660
+ with open(file, "r") as f:
661
+ config = json.load(f)
662
+ except Exception:
663
+ logger.error(
664
+ f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
665
+ )
666
+ raise
667
+ elif file.name.lower().endswith(".toml"):
668
+ try:
669
+ config = toml.load(file)
670
+ except Exception:
671
+ logger.error(
672
+ f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
673
+ )
674
+ raise
675
+ else:
676
+ raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
677
+
678
+ return config
679
+
680
+
681
+ # for config test
682
+ if __name__ == "__main__":
683
+ parser = argparse.ArgumentParser()
684
+ parser.add_argument("--support_dreambooth", action="store_true")
685
+ parser.add_argument("--support_finetuning", action="store_true")
686
+ parser.add_argument("--support_controlnet", action="store_true")
687
+ parser.add_argument("--support_dropout", action="store_true")
688
+ parser.add_argument("dataset_config")
689
+ config_args, remain = parser.parse_known_args()
690
+
691
+ parser = argparse.ArgumentParser()
692
+ train_util.add_dataset_arguments(
693
+ parser, config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout
694
+ )
695
+ train_util.add_training_arguments(parser, config_args.support_dreambooth)
696
+ argparse_namespace = parser.parse_args(remain)
697
+ train_util.prepare_dataset_args(argparse_namespace, config_args.support_finetuning)
698
+
699
+ logger.info("[argparse_namespace]")
700
+ logger.info(f"{vars(argparse_namespace)}")
701
+
702
+ user_config = load_user_config(config_args.dataset_config)
703
+
704
+ logger.info("")
705
+ logger.info("[user_config]")
706
+ logger.info(f"{user_config}")
707
+
708
+ sanitizer = ConfigSanitizer(
709
+ config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout
710
+ )
711
+ sanitized_user_config = sanitizer.sanitize_user_config(user_config)
712
+
713
+ logger.info("")
714
+ logger.info("[sanitized_user_config]")
715
+ logger.info(f"{sanitized_user_config}")
716
+
717
+ blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)
718
+
719
+ logger.info("")
720
+ logger.info("[blueprint]")
721
+ logger.info(f"{blueprint}")
control_net_lllite.py ADDED
@@ -0,0 +1,449 @@
1
+ import os
2
+ from typing import Optional, List, Type
3
+ import torch
4
+ from library import sdxl_original_unet
5
+ from library.utils import setup_logging
6
+ setup_logging()
7
+ import logging
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # input_blocksに適用するかどうか / if True, input_blocks are not applied
11
+ SKIP_INPUT_BLOCKS = False
12
+
13
+ # output_blocksに適用するかどうか / if True, output_blocks are not applied
14
+ SKIP_OUTPUT_BLOCKS = True
15
+
16
+ # conv2dに適用するかどうか / if True, conv2d are not applied
17
+ SKIP_CONV2D = False
18
+
19
+ # transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
20
+ # if True, only transformer_blocks are applied, and ResBlocks are not applied
21
+ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
22
+
23
+ # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
24
+ ATTN1_2_ONLY = True
25
+
26
+ # Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
27
+ ATTN_QKV_ONLY = True
28
+
29
+ # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
30
+ # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
31
+ ATTN1_ETC_ONLY = False # True
32
+
33
+ # transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
34
+ # max index of transformer_blocks. if None, apply to all transformer_blocks
35
+ TRANSFORMER_MAX_BLOCK_INDEX = None
36
+
37
+
38
+ class LLLiteModule(torch.nn.Module):
39
+ def __init__(self, depth, cond_emb_dim, name, org_module, mlp_dim, dropout=None, multiplier=1.0):
40
+ super().__init__()
41
+
42
+ self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
43
+ self.lllite_name = name
44
+ self.cond_emb_dim = cond_emb_dim
45
+ self.org_module = [org_module]
46
+ self.dropout = dropout
47
+ self.multiplier = multiplier
48
+
49
+ if self.is_conv2d:
50
+ in_dim = org_module.in_channels
51
+ else:
52
+ in_dim = org_module.in_features
53
+
54
+ # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
55
+ # conditioning1 embeds conditioning image. it is not called for each timestep
56
+ modules = []
57
+ modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
58
+ if depth == 1:
59
+ modules.append(torch.nn.ReLU(inplace=True))
60
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
61
+ elif depth == 2:
62
+ modules.append(torch.nn.ReLU(inplace=True))
63
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
64
+ elif depth == 3:
65
+ # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
66
+ modules.append(torch.nn.ReLU(inplace=True))
67
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
68
+ modules.append(torch.nn.ReLU(inplace=True))
69
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
70
+
71
+ self.conditioning1 = torch.nn.Sequential(*modules)
72
+
73
+ # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
74
+ # midでconditioning image embeddingと入力を結合する
75
+ # upで元の次元数に戻す
76
+ # これらはtimestepごとに呼ばれる
77
+ # reduce the number of input dimensions with down. inspired by LoRA
78
+ # combine conditioning image embedding and input with mid
79
+ # restore to the original dimension with up
80
+ # these are called for each timestep
81
+
82
+ if self.is_conv2d:
83
+ self.down = torch.nn.Sequential(
84
+ torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
85
+ torch.nn.ReLU(inplace=True),
86
+ )
87
+ self.mid = torch.nn.Sequential(
88
+ torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
89
+ torch.nn.ReLU(inplace=True),
90
+ )
91
+ self.up = torch.nn.Sequential(
92
+ torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
93
+ )
94
+ else:
95
+ # midの前にconditioningをreshapeすること / reshape conditioning before mid
96
+ self.down = torch.nn.Sequential(
97
+ torch.nn.Linear(in_dim, mlp_dim),
98
+ torch.nn.ReLU(inplace=True),
99
+ )
100
+ self.mid = torch.nn.Sequential(
101
+ torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
102
+ torch.nn.ReLU(inplace=True),
103
+ )
104
+ self.up = torch.nn.Sequential(
105
+ torch.nn.Linear(mlp_dim, in_dim),
106
+ )
107
+
108
+ # Zero-Convにする / set to Zero-Conv
109
+ torch.nn.init.zeros_(self.up[0].weight) # zero conv
110
+
111
+ self.depth = depth # 1~3
112
+ self.cond_emb = None
113
+ self.batch_cond_only = False # Trueなら推論時のcondにのみ適用する / if True, apply only to cond at inference
114
+ self.use_zeros_for_batch_uncond = False # Trueならuncondのconditioningを0にする / if True, set uncond conditioning to 0
115
+
116
+ # batch_cond_onlyとuse_zeros_for_batch_uncondはどちらも適用すると生成画像の色味がおかしくなるので実際には使えそうにない
117
+ # Controlの種類によっては使えるかも
118
+ # both batch_cond_only and use_zeros_for_batch_uncond make the color of the generated image strange, so it doesn't seem to be usable in practice
119
+ # it may be available depending on the type of Control
120
+
121
+ def set_cond_image(self, cond_image):
122
+ r"""
123
+ 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
124
+ / call the model inside, so if necessary, surround it with torch.no_grad()
125
+ """
126
+ if cond_image is None:
127
+ self.cond_emb = None
128
+ return
129
+
130
+ # timestepごとに呼ばれないので、あらかじめ計算しておく / it is not called for each timestep, so calculate it in advance
131
+ # logger.info(f"C {self.lllite_name}, cond_image.shape={cond_image.shape}")
132
+ cx = self.conditioning1(cond_image)
133
+ if not self.is_conv2d:
134
+ # reshape / b,c,h,w -> b,h*w,c
135
+ n, c, h, w = cx.shape
136
+ cx = cx.view(n, c, h * w).permute(0, 2, 1)
137
+ self.cond_emb = cx
138
+
139
+ def set_batch_cond_only(self, cond_only, zeros):
140
+ self.batch_cond_only = cond_only
141
+ self.use_zeros_for_batch_uncond = zeros
142
+
143
+ def apply_to(self):
144
+ self.org_forward = self.org_module[0].forward
145
+ self.org_module[0].forward = self.forward
146
+
147
+ def forward(self, x):
148
+ r"""
149
+ 学習用の便利forward。元のモジュールのforwardを呼び出す
150
+ / convenient forward for training. call the forward of the original module
151
+ """
152
+ if self.multiplier == 0.0 or self.cond_emb is None:
153
+ return self.org_forward(x)
154
+
155
+ cx = self.cond_emb
156
+
157
+ if not self.batch_cond_only and x.shape[0] // 2 == cx.shape[0]: # inference only
158
+ cx = cx.repeat(2, 1, 1, 1) if self.is_conv2d else cx.repeat(2, 1, 1)
159
+ if self.use_zeros_for_batch_uncond:
160
+ cx[0::2] = 0.0 # uncond is zero
161
+ # logger.info(f"C {self.lllite_name}, x.shape={x.shape}, cx.shape={cx.shape}")
162
+
163
+ # downで入力の次元数を削減し、conditioning image embeddingと結合する
164
+ # 加算ではなくchannel方向に結合することで、うまいこと混ぜてくれることを期待している
165
+ # down reduces the number of input dimensions and combines it with conditioning image embedding
166
+ # we expect that it will mix well by combining in the channel direction instead of adding
167
+
168
+ cx = torch.cat([cx, self.down(x if not self.batch_cond_only else x[1::2])], dim=1 if self.is_conv2d else 2)
169
+ cx = self.mid(cx)
170
+
171
+ if self.dropout is not None and self.training:
172
+ cx = torch.nn.functional.dropout(cx, p=self.dropout)
173
+
174
+ cx = self.up(cx) * self.multiplier
175
+
176
+ # residual (x) を加算して元のforwardを呼び出す / add residual (x) and call the original forward
177
+ if self.batch_cond_only:
178
+ zx = torch.zeros_like(x)
179
+ zx[1::2] += cx
180
+ cx = zx
181
+
182
+ x = self.org_forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
183
+ return x
184
+
185
+
186
+ class ControlNetLLLite(torch.nn.Module):
187
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
188
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
189
+
190
+ def __init__(
191
+ self,
192
+ unet: sdxl_original_unet.SdxlUNet2DConditionModel,
193
+ cond_emb_dim: int = 16,
194
+ mlp_dim: int = 16,
195
+ dropout: Optional[float] = None,
196
+ varbose: Optional[bool] = False,
197
+ multiplier: Optional[float] = 1.0,
198
+ ) -> None:
199
+ super().__init__()
200
+ # self.unets = [unet]
201
+
202
+ def create_modules(
203
+ root_module: torch.nn.Module,
204
+ target_replace_modules: List[torch.nn.Module],
205
+ module_class: Type[object],
206
+ ) -> List[torch.nn.Module]:
207
+ prefix = "lllite_unet"
208
+
209
+ modules = []
210
+ for name, module in root_module.named_modules():
211
+ if module.__class__.__name__ in target_replace_modules:
212
+ for child_name, child_module in module.named_modules():
213
+ is_linear = child_module.__class__.__name__ == "Linear"
214
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
215
+
216
+ if is_linear or (is_conv2d and not SKIP_CONV2D):
217
+ # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
218
+ # block index to depth: depth is used to calculate conditioning size and channels
219
+ block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
220
+ index1 = int(index1)
221
+ if block_name == "input_blocks":
222
+ if SKIP_INPUT_BLOCKS:
223
+ continue
224
+ depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
225
+ elif block_name == "middle_block":
226
+ depth = 3
227
+ elif block_name == "output_blocks":
228
+ if SKIP_OUTPUT_BLOCKS:
229
+ continue
230
+ depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
231
+ if int(index2) >= 2:
232
+ depth -= 1
233
+ else:
234
+ raise NotImplementedError()
235
+
236
+ lllite_name = prefix + "." + name + "." + child_name
237
+ lllite_name = lllite_name.replace(".", "_")
238
+
239
+ if TRANSFORMER_MAX_BLOCK_INDEX is not None:
240
+ p = lllite_name.find("transformer_blocks")
241
+ if p >= 0:
242
+ tf_index = int(lllite_name[p:].split("_")[2])
243
+ if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
244
+ continue
245
+
246
+ # time embは適用外とする
247
+ # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
248
+ # time emb is not applied
249
+ # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
250
+ if "emb_layers" in lllite_name or (
251
+ "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
252
+ ):
253
+ continue
254
+
255
+ if ATTN1_2_ONLY:
256
+ if not ("attn1" in lllite_name or "attn2" in lllite_name):
257
+ continue
258
+ if ATTN_QKV_ONLY:
259
+ if "to_out" in lllite_name:
260
+ continue
261
+
262
+ if ATTN1_ETC_ONLY:
263
+ if "proj_out" in lllite_name:
264
+ pass
265
+ elif "attn1" in lllite_name and (
266
+ "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
267
+ ):
268
+ pass
269
+ elif "ff_net_2" in lllite_name:
270
+ pass
271
+ else:
272
+ continue
273
+
274
+ module = module_class(
275
+ depth,
276
+ cond_emb_dim,
277
+ lllite_name,
278
+ child_module,
279
+ mlp_dim,
280
+ dropout=dropout,
281
+ multiplier=multiplier,
282
+ )
283
+ modules.append(module)
284
+ return modules
285
+
286
+ target_modules = ControlNetLLLite.UNET_TARGET_REPLACE_MODULE
287
+ if not TRANSFORMER_ONLY:
288
+ target_modules = target_modules + ControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
289
+
290
+ # create module instances
291
+ self.unet_modules: List[LLLiteModule] = create_modules(unet, target_modules, LLLiteModule)
292
+ logger.info(f"create ControlNet LLLite for U-Net: {len(self.unet_modules)} modules.")
293
+
294
+ def forward(self, x):
295
+ return x # dummy
296
+
297
+ def set_cond_image(self, cond_image):
298
+ r"""
299
+ 中でモデルを呼び出すので必要ならwith torch.no_grad()で囲む
300
+ / call the model inside, so if necessary, surround it with torch.no_grad()
301
+ """
302
+ for module in self.unet_modules:
303
+ module.set_cond_image(cond_image)
304
+
305
+ def set_batch_cond_only(self, cond_only, zeros):
306
+ for module in self.unet_modules:
307
+ module.set_batch_cond_only(cond_only, zeros)
308
+
309
+ def set_multiplier(self, multiplier):
310
+ for module in self.unet_modules:
311
+ module.multiplier = multiplier
312
+
313
+ def load_weights(self, file):
314
+ if os.path.splitext(file)[1] == ".safetensors":
315
+ from safetensors.torch import load_file
316
+
317
+ weights_sd = load_file(file)
318
+ else:
319
+ weights_sd = torch.load(file, map_location="cpu")
320
+
321
+ info = self.load_state_dict(weights_sd, False)
322
+ return info
323
+
324
+ def apply_to(self):
325
+ logger.info("applying LLLite for U-Net...")
326
+ for module in self.unet_modules:
327
+ module.apply_to()
328
+ self.add_module(module.lllite_name, module)
329
+
330
+ # マージできるかどうかを返す / return whether the weights can be merged
331
+ def is_mergeable(self):
332
+ return False
333
+
334
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
335
+ raise NotImplementedError()
336
+
337
+ def enable_gradient_checkpointing(self):
338
+ # not supported
339
+ pass
340
+
341
+ def prepare_optimizer_params(self):
342
+ self.requires_grad_(True)
343
+ return self.parameters()
344
+
345
+ def prepare_grad_etc(self):
346
+ self.requires_grad_(True)
347
+
348
+ def on_epoch_start(self):
349
+ self.train()
350
+
351
+ def get_trainable_params(self):
352
+ return self.parameters()
353
+
354
+ def save_weights(self, file, dtype, metadata):
355
+ if metadata is not None and len(metadata) == 0:
356
+ metadata = None
357
+
358
+ state_dict = self.state_dict()
359
+
360
+ if dtype is not None:
361
+ for key in list(state_dict.keys()):
362
+ v = state_dict[key]
363
+ v = v.detach().clone().to("cpu").to(dtype)
364
+ state_dict[key] = v
365
+
366
+ if os.path.splitext(file)[1] == ".safetensors":
367
+ from safetensors.torch import save_file
368
+
369
+ save_file(state_dict, file, metadata)
370
+ else:
371
+ torch.save(state_dict, file)
372
+
373
+
374
+ if __name__ == "__main__":
375
+ # デバッグ用 / for debug
376
+
377
+ # sdxl_original_unet.USE_REENTRANT = False
378
+
379
+ # test shape etc
380
+ logger.info("create unet")
381
+ unet = sdxl_original_unet.SdxlUNet2DConditionModel()
382
+ unet.to("cuda").to(torch.float16)
383
+
384
+ logger.info("create ControlNet-LLLite")
385
+ control_net = ControlNetLLLite(unet, 32, 64)
386
+ control_net.apply_to()
387
+ control_net.to("cuda")
388
+
389
+ logger.info(control_net)
390
+
391
+ # logger.info number of parameters
392
+ logger.info(f"number of parameters {sum(p.numel() for p in control_net.parameters() if p.requires_grad)}")
393
+
394
+ input()
395
+
396
+ unet.set_use_memory_efficient_attention(True, False)
397
+ unet.set_gradient_checkpointing(True)
398
+ unet.train() # for gradient checkpointing
399
+
400
+ control_net.train()
401
+
402
+ # # visualize
403
+ # import torchviz
404
+ # logger.info("run visualize")
405
+ # controlnet.set_control(conditioning_image)
406
+ # output = unet(x, t, ctx, y)
407
+ # logger.info("make_dot")
408
+ # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
409
+ # logger.info("render")
410
+ # image.format = "svg" # "png"
411
+ # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
412
+ # input()
413
+
414
+ import bitsandbytes
415
+
416
+ optimizer = bitsandbytes.adam.Adam8bit(control_net.prepare_optimizer_params(), 1e-3)
417
+
418
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
419
+
420
+ logger.info("start training")
421
+ steps = 10
422
+
423
+ sample_param = [p for p in control_net.named_parameters() if "up" in p[0]][0]
424
+ for step in range(steps):
425
+ logger.info(f"step {step}")
426
+
427
+ batch_size = 1
428
+ conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
429
+ x = torch.randn(batch_size, 4, 128, 128).cuda()
430
+ t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
431
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
432
+ y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
433
+
434
+ with torch.cuda.amp.autocast(enabled=True):
435
+ control_net.set_cond_image(conditioning_image)
436
+
437
+ output = unet(x, t, ctx, y)
438
+ target = torch.randn_like(output)
439
+ loss = torch.nn.functional.mse_loss(output, target)
440
+
441
+ scaler.scale(loss).backward()
442
+ scaler.step(optimizer)
443
+ scaler.update()
444
+ optimizer.zero_grad(set_to_none=True)
445
+ logger.info(f"{sample_param}")
446
+
447
+ # from safetensors.torch import save_file
448
+
449
+ # save_file(control_net.state_dict(), "logs/control_net.safetensors")
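The __main__ block above is a training smoke test; at inference time the module is instead attached to an existing SDXL U-Net. A minimal sketch using only the methods defined in this file (the weights path and conditioning tensor are placeholders, and cond_emb_dim/mlp_dim must match the values the weights were trained with):

unet = sdxl_original_unet.SdxlUNet2DConditionModel()
control_net = ControlNetLLLite(unet, cond_emb_dim=16, mlp_dim=16)
control_net.apply_to()                                      # hook the targeted Linear/Conv2d forwards first
control_net.load_weights("controlnet_lllite.safetensors")   # hypothetical path; load after apply_to so keys match
control_net.to("cuda", dtype=torch.float16)

with torch.no_grad():
    control_net.set_cond_image(cond_image)  # e.g. (B, 3, 1024, 1024) scaled to [-1, 1], as in the debug block
# every subsequent unet(x, t, ctx, y) call now adds the LLLite residuals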
control_net_lllite_for_train.py ADDED
@@ -0,0 +1,501 @@
1
+ # cond_imageをU-Netのforwardで渡すバージョンのControlNet-LLLite検証用実装
2
+ # ControlNet-LLLite implementation for verification with cond_image passed in U-Net's forward
3
+
4
+ import os
5
+ import re
6
+ from typing import Optional, List, Type
7
+ import torch
8
+ from library import sdxl_original_unet
9
+ from library.utils import setup_logging
10
+
11
+ setup_logging()
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # input_blocksに適用するかどうか / if True, input_blocks are not applied
17
+ SKIP_INPUT_BLOCKS = False
18
+
19
+ # output_blocksに適用するかどうか / if True, output_blocks are not applied
20
+ SKIP_OUTPUT_BLOCKS = True
21
+
22
+ # conv2dに適用するかどうか / if True, conv2d are not applied
23
+ SKIP_CONV2D = False
24
+
25
+ # transformer_blocksのみに適用するかどうか。Trueの場合、ResBlockには適用されない
26
+ # if True, only transformer_blocks are applied, and ResBlocks are not applied
27
+ TRANSFORMER_ONLY = True # if True, SKIP_CONV2D is ignored because conv2d is not used in transformer_blocks
28
+
29
+ # Trueならattn1とattn2にのみ適用し、ffなどには適用しない / if True, apply only to attn1 and attn2, not to ff etc.
30
+ ATTN1_2_ONLY = True
31
+
32
+ # Trueならattn1のQKV、attn2のQにのみ適用する、ATTN1_2_ONLY指定時のみ有効 / if True, apply only to attn1 QKV and attn2 Q, only valid when ATTN1_2_ONLY is specified
33
+ ATTN_QKV_ONLY = True
34
+
35
+ # Trueならattn1やffなどにのみ適用し、attn2などには適用しない / if True, apply only to attn1 and ff, not to attn2
36
+ # ATTN1_2_ONLYと同時にTrueにできない / cannot be True at the same time as ATTN1_2_ONLY
37
+ ATTN1_ETC_ONLY = False # True
38
+
39
+ # transformer_blocksの最大インデックス。Noneなら全てのtransformer_blocksに適用
40
+ # max index of transformer_blocks. if None, apply to all transformer_blocks
41
+ TRANSFORMER_MAX_BLOCK_INDEX = None
42
+
43
+ ORIGINAL_LINEAR = torch.nn.Linear
44
+ ORIGINAL_CONV2D = torch.nn.Conv2d
45
+
46
+
47
+ def add_lllite_modules(module: torch.nn.Module, in_dim: int, depth, cond_emb_dim, mlp_dim) -> None:
48
+ # conditioning1はconditioning imageを embedding する。timestepごとに呼ばれない
49
+ # conditioning1 embeds conditioning image. it is not called for each timestep
50
+ modules = []
51
+ modules.append(ORIGINAL_CONV2D(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size
52
+ if depth == 1:
53
+ modules.append(torch.nn.ReLU(inplace=True))
54
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
55
+ elif depth == 2:
56
+ modules.append(torch.nn.ReLU(inplace=True))
57
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
58
+ elif depth == 3:
59
+ # kernel size 8は大きすぎるので、4にする / kernel size 8 is too large, so set it to 4
60
+ modules.append(torch.nn.ReLU(inplace=True))
61
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
62
+ modules.append(torch.nn.ReLU(inplace=True))
63
+ modules.append(ORIGINAL_CONV2D(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
64
+
65
+ module.lllite_conditioning1 = torch.nn.Sequential(*modules)
66
+
67
+ # downで入力の次元数を削減する。LoRAにヒントを得ていることにする
68
+ # midでconditioning image embeddingと入力を結合する
69
+ # upで元の次元数に戻す
70
+ # これらはtimestepごとに呼ばれる
71
+ # reduce the number of input dimensions with down. inspired by LoRA
72
+ # combine conditioning image embedding and input with mid
73
+ # restore to the original dimension with up
74
+ # these are called for each timestep
75
+
76
+ module.lllite_down = torch.nn.Sequential(
77
+ ORIGINAL_LINEAR(in_dim, mlp_dim),
78
+ torch.nn.ReLU(inplace=True),
79
+ )
80
+ module.lllite_mid = torch.nn.Sequential(
81
+ ORIGINAL_LINEAR(mlp_dim + cond_emb_dim, mlp_dim),
82
+ torch.nn.ReLU(inplace=True),
83
+ )
84
+ module.lllite_up = torch.nn.Sequential(
85
+ ORIGINAL_LINEAR(mlp_dim, in_dim),
86
+ )
87
+
88
+ # Zero-Convにする / set to Zero-Conv
89
+ torch.nn.init.zeros_(module.lllite_up[0].weight) # zero conv
90
+
91
+
92
+ class LLLiteLinear(ORIGINAL_LINEAR):
93
+ def __init__(self, in_features: int, out_features: int, **kwargs):
94
+ super().__init__(in_features, out_features, **kwargs)
95
+ self.enabled = False
96
+
97
+ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
98
+ self.enabled = True
99
+ self.lllite_name = name
100
+ self.cond_emb_dim = cond_emb_dim
101
+ self.dropout = dropout
102
+ self.multiplier = multiplier # ignored
103
+
104
+ in_dim = self.in_features
105
+ add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
106
+
107
+ self.cond_image = None
108
+
109
+ def set_cond_image(self, cond_image):
110
+ self.cond_image = cond_image
111
+
112
+ def forward(self, x):
113
+ if not self.enabled:
114
+ return super().forward(x)
115
+
116
+ cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible
117
+
118
+ # reshape / b,c,h,w -> b,h*w,c
119
+ n, c, h, w = cx.shape
120
+ cx = cx.view(n, c, h * w).permute(0, 2, 1)
121
+
122
+ cx = torch.cat([cx, self.lllite_down(x)], dim=2)
123
+ cx = self.lllite_mid(cx)
124
+
125
+ if self.dropout is not None and self.training:
126
+ cx = torch.nn.functional.dropout(cx, p=self.dropout)
127
+
128
+ cx = self.lllite_up(cx) * self.multiplier
129
+
130
+ x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
131
+ return x
132
+
133
+
134
+ class LLLiteConv2d(ORIGINAL_CONV2D):
135
+ def __init__(self, in_channels: int, out_channels: int, kernel_size, **kwargs):
136
+ super().__init__(in_channels, out_channels, kernel_size, **kwargs)
137
+ self.enabled = False
138
+
139
+ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplier=1.0):
140
+ self.enabled = True
141
+ self.lllite_name = name
142
+ self.cond_emb_dim = cond_emb_dim
143
+ self.dropout = dropout
144
+ self.multiplier = multiplier # ignored
145
+
146
+ in_dim = self.in_channels
147
+ add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim)
148
+
149
+ self.cond_image = None
150
+ self.cond_emb = None
151
+
152
+ def set_cond_image(self, cond_image):
153
+ self.cond_image = cond_image
154
+ self.cond_emb = None
155
+
156
+ def forward(self, x): # , cond_image=None):
157
+ if not self.enabled:
158
+ return super().forward(x)
159
+
160
+ cx = self.lllite_conditioning1(self.cond_image)
161
+
162
+ cx = torch.cat([cx, self.lllite_down(x)], dim=1)
163
+ cx = self.lllite_mid(cx)
164
+
165
+ if self.dropout is not None and self.training:
166
+ cx = torch.nn.functional.dropout(cx, p=self.dropout)
167
+
168
+ cx = self.lllite_up(cx) * self.multiplier
169
+
170
+ x = super().forward(x + cx) # ここで元のモジュールを呼び出す / call the original module here
171
+ return x
172
+
173
+
174
+ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DConditionModel):
175
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
176
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
177
+ LLLITE_PREFIX = "lllite_unet"
178
+
179
+ def __init__(self, **kwargs):
180
+ super().__init__(**kwargs)
181
+
182
+ def apply_lllite(
183
+ self,
184
+ cond_emb_dim: int = 16,
185
+ mlp_dim: int = 16,
186
+ dropout: Optional[float] = None,
187
+ varbose: Optional[bool] = False,
188
+ multiplier: Optional[float] = 1.0,
189
+ ) -> None:
190
+ def apply_to_modules(
191
+ root_module: torch.nn.Module,
192
+ target_replace_modules: List[torch.nn.Module],
193
+ ) -> List[torch.nn.Module]:
194
+ prefix = "lllite_unet"
195
+
196
+ modules = []
197
+ for name, module in root_module.named_modules():
198
+ if module.__class__.__name__ in target_replace_modules:
199
+ for child_name, child_module in module.named_modules():
200
+ is_linear = child_module.__class__.__name__ == "LLLiteLinear"
201
+ is_conv2d = child_module.__class__.__name__ == "LLLiteConv2d"
202
+
203
+ if is_linear or (is_conv2d and not SKIP_CONV2D):
204
+ # block indexからdepthを計算: depthはconditioningのサイズやチャネルを計算するのに使う
205
+ # block index to depth: depth is used to calculate conditioning size and channels
206
+ block_name, index1, index2 = (name + "." + child_name).split(".")[:3]
207
+ index1 = int(index1)
208
+ if block_name == "input_blocks":
209
+ if SKIP_INPUT_BLOCKS:
210
+ continue
211
+ depth = 1 if index1 <= 2 else (2 if index1 <= 5 else 3)
212
+ elif block_name == "middle_block":
213
+ depth = 3
214
+ elif block_name == "output_blocks":
215
+ if SKIP_OUTPUT_BLOCKS:
216
+ continue
217
+ depth = 3 if index1 <= 2 else (2 if index1 <= 5 else 1)
218
+ if int(index2) >= 2:
219
+ depth -= 1
220
+ else:
221
+ raise NotImplementedError()
222
+
223
+ lllite_name = prefix + "." + name + "." + child_name
224
+ lllite_name = lllite_name.replace(".", "_")
225
+
226
+ if TRANSFORMER_MAX_BLOCK_INDEX is not None:
227
+ p = lllite_name.find("transformer_blocks")
228
+ if p >= 0:
229
+ tf_index = int(lllite_name[p:].split("_")[2])
230
+ if tf_index > TRANSFORMER_MAX_BLOCK_INDEX:
231
+ continue
232
+
233
+ # time embは適用外とする
234
+ # attn2のconditioning (CLIPからの入力) はshapeが違うので適用できない
235
+ # time emb is not applied
236
+ # attn2 conditioning (input from CLIP) cannot be applied because the shape is different
237
+ if "emb_layers" in lllite_name or (
238
+ "attn2" in lllite_name and ("to_k" in lllite_name or "to_v" in lllite_name)
239
+ ):
240
+ continue
241
+
242
+ if ATTN1_2_ONLY:
243
+ if not ("attn1" in lllite_name or "attn2" in lllite_name):
244
+ continue
245
+ if ATTN_QKV_ONLY:
246
+ if "to_out" in lllite_name:
247
+ continue
248
+
249
+ if ATTN1_ETC_ONLY:
250
+ if "proj_out" in lllite_name:
251
+ pass
252
+ elif "attn1" in lllite_name and (
253
+ "to_k" in lllite_name or "to_v" in lllite_name or "to_out" in lllite_name
254
+ ):
255
+ pass
256
+ elif "ff_net_2" in lllite_name:
257
+ pass
258
+ else:
259
+ continue
260
+
261
+ child_module.set_lllite(depth, cond_emb_dim, lllite_name, mlp_dim, dropout, multiplier)
262
+ modules.append(child_module)
263
+
264
+ return modules
265
+
266
+ target_modules = SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE
267
+ if not TRANSFORMER_ONLY:
268
+ target_modules = target_modules + SdxlUNet2DConditionModelControlNetLLLite.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
269
+
270
+ # create module instances
271
+ self.lllite_modules = apply_to_modules(self, target_modules)
272
+ logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
273
+
274
+ # def prepare_optimizer_params(self):
275
+ def prepare_params(self):
276
+ train_params = []
277
+ non_train_params = []
278
+ for name, p in self.named_parameters():
279
+ if "lllite" in name:
280
+ train_params.append(p)
281
+ else:
282
+ non_train_params.append(p)
283
+ logger.info(f"count of trainable parameters: {len(train_params)}")
284
+ logger.info(f"count of non-trainable parameters: {len(non_train_params)}")
285
+
286
+ for p in non_train_params:
287
+ p.requires_grad_(False)
288
+
289
+ # without this, an error occurs in the optimizer
290
+ # RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
291
+ non_train_params[0].requires_grad_(True)
292
+
293
+ for p in train_params:
294
+ p.requires_grad_(True)
295
+
296
+ return train_params
297
+
298
+ # def prepare_grad_etc(self):
299
+ # self.requires_grad_(True)
300
+
301
+ # def on_epoch_start(self):
302
+ # self.train()
303
+
304
+ def get_trainable_params(self):
305
+ return [p[1] for p in self.named_parameters() if "lllite" in p[0]]
306
+
307
+ def save_lllite_weights(self, file, dtype, metadata):
308
+ if metadata is not None and len(metadata) == 0:
309
+ metadata = None
310
+
311
+ org_state_dict = self.state_dict()
312
+
313
+ # copy LLLite keys from org_state_dict to state_dict with key conversion
314
+ state_dict = {}
315
+ for key in org_state_dict.keys():
316
+ # split with ".lllite"
317
+ pos = key.find(".lllite")
318
+ if pos < 0:
319
+ continue
320
+ lllite_key = SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "." + key[:pos]
321
+ lllite_key = lllite_key.replace(".", "_") + key[pos:]
322
+ lllite_key = lllite_key.replace(".lllite_", ".")
323
+ state_dict[lllite_key] = org_state_dict[key]
324
+
325
+ if dtype is not None:
326
+ for key in list(state_dict.keys()):
327
+ v = state_dict[key]
328
+ v = v.detach().clone().to("cpu").to(dtype)
329
+ state_dict[key] = v
330
+
331
+ if os.path.splitext(file)[1] == ".safetensors":
332
+ from safetensors.torch import save_file
333
+
334
+ save_file(state_dict, file, metadata)
335
+ else:
336
+ torch.save(state_dict, file)
337
+
338
+ def load_lllite_weights(self, file, non_lllite_unet_sd=None):
339
+ r"""
340
+ LLLiteの重みを読み込まない(initされた値を使う)場合はfileにNoneを指定する。
341
+ この場合、non_lllite_unet_sdにはU-Netのstate_dictを指定する。
342
+
343
+ If you do not want to load LLLite weights (use initialized values), specify None for file.
344
+ In this case, specify the state_dict of U-Net for non_lllite_unet_sd.
345
+ """
346
+ if not file:
347
+ state_dict = self.state_dict()
348
+ for key in non_lllite_unet_sd:
349
+ if key in state_dict:
350
+ state_dict[key] = non_lllite_unet_sd[key]
351
+ info = self.load_state_dict(state_dict, False)
352
+ return info
353
+
354
+ if os.path.splitext(file)[1] == ".safetensors":
355
+ from safetensors.torch import load_file
356
+
357
+ weights_sd = load_file(file)
358
+ else:
359
+ weights_sd = torch.load(file, map_location="cpu")
360
+
361
+ # module_name = module_name.replace("_block", "@blocks")
362
+ # module_name = module_name.replace("_layer", "@layer")
363
+ # module_name = module_name.replace("to_", "to@")
364
+ # module_name = module_name.replace("time_embed", "time@embed")
365
+ # module_name = module_name.replace("label_emb", "label@emb")
366
+ # module_name = module_name.replace("skip_connection", "skip@connection")
367
+ # module_name = module_name.replace("proj_in", "proj@in")
368
+ # module_name = module_name.replace("proj_out", "proj@out")
369
+ pattern = re.compile(r"(_block|_layer|to_|time_embed|label_emb|skip_connection|proj_in|proj_out)")
370
+
371
+ # convert to lllite with U-Net state dict
372
+ state_dict = non_lllite_unet_sd.copy() if non_lllite_unet_sd is not None else {}
373
+ for key in weights_sd.keys():
374
+ # split with "."
375
+ pos = key.find(".")
376
+ if pos < 0:
377
+ continue
378
+
379
+ module_name = key[:pos]
380
+ weight_name = key[pos + 1 :] # exclude "."
381
+ module_name = module_name.replace(SdxlUNet2DConditionModelControlNetLLLite.LLLITE_PREFIX + "_", "")
382
+
383
+ # これはうまくいかない。逆変換を考えなかった設計が悪い / this does not work well. bad design because I didn't think about inverse conversion
384
+ # module_name = module_name.replace("_", ".")
385
+
386
+ # ださいけどSDXLのU-Netの "_" を "@" に変換する / ugly but convert "_" of SDXL U-Net to "@"
387
+ matches = pattern.findall(module_name)
388
+ if matches is not None:
389
+ for m in matches:
390
+ logger.info(f"{module_name} {m}")
391
+ module_name = module_name.replace(m, m.replace("_", "@"))
392
+ module_name = module_name.replace("_", ".")
393
+ module_name = module_name.replace("@", "_")
394
+
395
+ lllite_key = module_name + ".lllite_" + weight_name
396
+
397
+ state_dict[lllite_key] = weights_sd[key]
398
+
399
+ info = self.load_state_dict(state_dict, False)
400
+ return info
401
+
402
+ def forward(self, x, timesteps=None, context=None, y=None, cond_image=None, **kwargs):
403
+ for m in self.lllite_modules:
404
+ m.set_cond_image(cond_image)
405
+ return super().forward(x, timesteps, context, y, **kwargs)
406
+
407
+
408
+ def replace_unet_linear_and_conv2d():
409
+ logger.info("replace torch.nn.Linear and torch.nn.Conv2d with LLLiteLinear and LLLiteConv2d in U-Net")
410
+ sdxl_original_unet.torch.nn.Linear = LLLiteLinear
411
+ sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d
412
+
413
+
414
+ if __name__ == "__main__":
415
+ # デバッグ用 / for debug
416
+
417
+ # sdxl_original_unet.USE_REENTRANT = False
418
+ replace_unet_linear_and_conv2d()
419
+
420
+ # test shape etc
421
+ logger.info("create unet")
422
+ unet = SdxlUNet2DConditionModelControlNetLLLite()
423
+
424
+ logger.info("enable ControlNet-LLLite")
425
+ unet.apply_lllite(32, 64, None, False, 1.0)
426
+ unet.to("cuda") # .to(torch.float16)
427
+
428
+ # from safetensors.torch import load_file
429
+
430
+ # model_sd = load_file(r"E:\Work\SD\Models\sdxl\sd_xl_base_1.0_0.9vae.safetensors")
431
+ # unet_sd = {}
432
+
433
+ # # copy U-Net keys from unet_state_dict to state_dict
434
+ # prefix = "model.diffusion_model."
435
+ # for key in model_sd.keys():
436
+ # if key.startswith(prefix):
437
+ # converted_key = key[len(prefix) :]
438
+ # unet_sd[converted_key] = model_sd[key]
439
+
440
+ # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
441
+ # logger.info(info)
442
+
443
+ # logger.info(unet)
444
+
445
+ # logger.info number of parameters
446
+ params = unet.prepare_params()
447
+ logger.info(f"number of parameters {sum(p.numel() for p in params)}")
448
+ # logger.info("type any key to continue")
449
+ # input()
450
+
451
+ unet.set_use_memory_efficient_attention(True, False)
452
+ unet.set_gradient_checkpointing(True)
453
+ unet.train() # for gradient checkpointing
454
+
455
+ # # visualize
456
+ # import torchviz
457
+ # logger.info("run visualize")
458
+ # controlnet.set_control(conditioning_image)
459
+ # output = unet(x, t, ctx, y)
460
+ # logger.info("make_dot")
461
+ # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
462
+ # logger.info("render")
463
+ # image.format = "svg" # "png"
464
+ # image.render("NeuralNet") # すごく時間がかかるので注意 / be careful because it takes a long time
465
+ # input()
466
+
467
+ import bitsandbytes
468
+
469
+ optimizer = bitsandbytes.adam.Adam8bit(params, 1e-3)
470
+
471
+ scaler = torch.cuda.amp.GradScaler(enabled=True)
472
+
473
+ logger.info("start training")
474
+ steps = 10
475
+ batch_size = 1
476
+
477
+ sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
478
+ for step in range(steps):
479
+ logger.info(f"step {step}")
480
+
481
+ conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
482
+ x = torch.randn(batch_size, 4, 128, 128).cuda()
483
+ t = torch.randint(low=0, high=10, size=(batch_size,)).cuda()
484
+ ctx = torch.randn(batch_size, 77, 2048).cuda()
485
+ y = torch.randn(batch_size, sdxl_original_unet.ADM_IN_CHANNELS).cuda()
486
+
487
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
488
+ output = unet(x, t, ctx, y, conditioning_image)
489
+ target = torch.randn_like(output)
490
+ loss = torch.nn.functional.mse_loss(output, target)
491
+
492
+ scaler.scale(loss).backward()
493
+ scaler.step(optimizer)
494
+ scaler.update()
495
+ optimizer.zero_grad(set_to_none=True)
496
+ logger.info(sample_param)
497
+
498
+ # from safetensors.torch import save_file
499
+
500
+ # logger.info("save weights")
501
+ # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
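As the debug block shows, the Linear/Conv2d replacement has to happen before the U-Net is instantiated so that the LLLite subclasses are baked into the model. A minimal training-side sketch (unet_sd is a hypothetical dict of base SDXL U-Net weights, keyed as in the commented-out example above):

replace_unet_linear_and_conv2d()                 # must run before SdxlUNet2DConditionModelControlNetLLLite()
unet = SdxlUNet2DConditionModelControlNetLLLite()
unet.apply_lllite(32, 64, None, False, 1.0)      # cond_emb_dim, mlp_dim, dropout, varbose, multiplier
unet.load_lllite_weights(None, unet_sd)          # file=None: keep initialized LLLite weights, copy base U-Net weights
params = unet.prepare_params()                   # only parameters containing "lllite" stay trainable

output = unet(x, t, ctx, y, cond_image)          # cond_image is forwarded to every enabled LLLite module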
convert_diffusers20_original_sd.py ADDED
@@ -0,0 +1,163 @@
1
+ # convert Diffusers v1.x/v2.0 model to original Stable Diffusion
2
+
3
+ import argparse
4
+ import os
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline
7
+
8
+ import library.model_util as model_util
9
+ from library.utils import setup_logging
10
+ setup_logging()
11
+ import logging
12
+ logger = logging.getLogger(__name__)
13
+
14
+ def convert(args):
15
+ # 引数を確認する / check arguments
16
+ load_dtype = torch.float16 if args.fp16 else None
17
+
18
+ save_dtype = None
19
+ if args.fp16 or args.save_precision_as == "fp16":
20
+ save_dtype = torch.float16
21
+ elif args.bf16 or args.save_precision_as == "bf16":
22
+ save_dtype = torch.bfloat16
23
+ elif args.float or args.save_precision_as == "float":
24
+ save_dtype = torch.float
25
+
26
+ is_load_ckpt = os.path.isfile(args.model_to_load)
27
+ is_save_ckpt = len(os.path.splitext(args.model_to_save)[1]) > 0
28
+
29
+ assert not is_load_ckpt or args.v1 != args.v2, "v1 or v2 is required to load checkpoint / checkpointの読み込みにはv1/v2指定が必要です"
30
+ # assert (
31
+ # is_save_ckpt or args.reference_model is not None
32
+ # ), f"reference model is required to save as Diffusers / Diffusers形式での保存には参照モデルが必要です"
33
+
34
+ # モデルを読み込む / load the model
35
+ msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
36
+ logger.info(f"loading {msg}: {args.model_to_load}")
37
+
38
+ if is_load_ckpt:
39
+ v2_model = args.v2
40
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(
41
+ v2_model, args.model_to_load, unet_use_linear_projection_in_v2=args.unet_use_linear_projection
42
+ )
43
+ else:
44
+ pipe = StableDiffusionPipeline.from_pretrained(
45
+ args.model_to_load, torch_dtype=load_dtype, tokenizer=None, safety_checker=None, variant=args.variant
46
+ )
47
+ text_encoder = pipe.text_encoder
48
+ vae = pipe.vae
49
+ unet = pipe.unet
50
+
51
+ if args.v1 == args.v2:
52
+ # 自動判定する / auto-detect v1/v2
53
+ v2_model = unet.config.cross_attention_dim == 1024
54
+ logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
55
+ else:
56
+ v2_model = not args.v1
57
+
58
+ # 変換して保存する / convert and save
59
+ msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
60
+ logger.info(f"converting and saving as {msg}: {args.model_to_save}")
61
+
62
+ if is_save_ckpt:
63
+ original_model = args.model_to_load if is_load_ckpt else None
64
+ key_count = model_util.save_stable_diffusion_checkpoint(
65
+ v2_model,
66
+ args.model_to_save,
67
+ text_encoder,
68
+ unet,
69
+ original_model,
70
+ args.epoch,
71
+ args.global_step,
72
+ None if args.metadata is None else eval(args.metadata),
73
+ save_dtype=save_dtype,
74
+ vae=vae,
75
+ )
76
+ logger.info(f"model saved. total converted state_dict keys: {key_count}")
77
+ else:
78
+ logger.info(
79
+ f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
80
+ )
81
+ model_util.save_diffusers_checkpoint(
82
+ v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
83
+ )
84
+ logger.info("model saved.")
85
+
86
+
87
+ def setup_parser() -> argparse.ArgumentParser:
88
+ parser = argparse.ArgumentParser()
89
+ parser.add_argument(
90
+ "--v1", action="store_true", help="load v1.x model (v1 or v2 is required to load checkpoint) / 1.xのモデルを読み込む"
91
+ )
92
+ parser.add_argument(
93
+ "--v2", action="store_true", help="load v2.0 model (v1 or v2 is required to load checkpoint) / 2.0のモデルを読み込む"
94
+ )
95
+ parser.add_argument(
96
+ "--unet_use_linear_projection",
97
+ action="store_true",
98
+ help="When saving v2 model as Diffusers, set U-Net config to `use_linear_projection=true` (to match stabilityai's model) / Diffusers形式でv2モデルを保存するときにU-Netの設定を`use_linear_projection=true`にする(stabilityaiのモデルと合わせる)",
99
+ )
100
+ parser.add_argument(
101
+ "--fp16",
102
+ action="store_true",
103
+ help="load as fp16 (Diffusers only) and save as fp16 (checkpoint only) / fp16形式で読み込み(Diffusers形式のみ対応)、保存する(checkpointのみ対応)",
104
+ )
105
+ parser.add_argument("--bf16", action="store_true", help="save as bf16 (checkpoint only) / bf16形式で保存する(checkpointのみ対応)")
106
+ parser.add_argument(
107
+ "--float", action="store_true", help="save as float (checkpoint only) / float(float32)形式で保存する(checkpointのみ対応)"
108
+ )
109
+ parser.add_argument(
110
+ "--save_precision_as",
111
+ type=str,
112
+ default="no",
113
+ choices=["fp16", "bf16", "float"],
114
+ help="save precision, do not specify with --fp16/--bf16/--float / 保存する精度、--fp16/--bf16/--floatと併用しないでください",
115
+ )
116
+ parser.add_argument("--epoch", type=int, default=0, help="epoch to write to checkpoint / checkpointに記録するepoch数の値")
117
+ parser.add_argument(
118
+ "--global_step", type=int, default=0, help="global_step to write to checkpoint / checkpointに記録するglobal_stepの値"
119
+ )
120
+ parser.add_argument(
121
+ "--metadata",
122
+ type=str,
123
+ default=None,
124
+ help='モデルに保存されるメタデータ、Pythonの辞書形式で指定 / metadata: metadata written in to the model in Python Dictionary. Example metadata: \'{"name": "model_name", "resolution": "512x512"}\'',
125
+ )
126
+ parser.add_argument(
127
+ "--variant",
128
+ type=str,
129
+ default=None,
130
+ help="読み込むDiffusersのvariantを指定する、例: fp16 / variant: Diffusers variant to load. Example: fp16",
131
+ )
132
+ parser.add_argument(
133
+ "--reference_model",
134
+ type=str,
135
+ default=None,
136
+ help="scheduler/tokenizerのコピー元Diffusersモデル、Diffusers形式で保存するときに使用される、省略時は`runwayml/stable-diffusion-v1-5` または `stabilityai/stable-diffusion-2-1` / reference Diffusers model to copy scheduler/tokenizer config from, used when saving as Diffusers format, default is `runwayml/stable-diffusion-v1-5` or `stabilityai/stable-diffusion-2-1`",
137
+ )
138
+ parser.add_argument(
139
+ "--use_safetensors",
140
+ action="store_true",
141
+ help="use safetensors format to save Diffusers model (checkpoint depends on the file extension) / Duffusersモデルをsafetensors形式で保存する(checkpointは拡張子で自動判定)",
142
+ )
143
+
144
+ parser.add_argument(
145
+ "model_to_load",
146
+ type=str,
147
+ default=None,
148
+ help="model to load: checkpoint file or Diffusers model's directory / 読み込むモデル、checkpointかDiffusers形式モデルのディレクトリ",
149
+ )
150
+ parser.add_argument(
151
+ "model_to_save",
152
+ type=str,
153
+ default=None,
154
+ help="model to save: checkpoint (with extension) or Diffusers model's directory (without extension) / 変換後のモデル、拡張子がある場合はcheckpoint、ない場合はDiffusersモデルとして保存",
155
+ )
156
+ return parser
157
+
158
+
159
+ if __name__ == "__main__":
160
+ parser = setup_parser()
161
+
162
+ args = parser.parse_args()
163
+ convert(args)
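For example, converting a Diffusers v1.x model directory to a single fp16 checkpoint (paths are placeholders; the v1/v2 flag can be omitted here because it is auto-detected from cross_attention_dim when loading Diffusers) would be run as `python convert_diffusers20_original_sd.py --fp16 path/to/diffusers_model model.safetensors`, while the reverse direction requires `--v1` or `--v2` and an output directory without an extension.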
custom_train_functions.py ADDED
@@ -0,0 +1,559 @@
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def prepare_scheduler_for_custom_training(noise_scheduler, device):
15
+ if hasattr(noise_scheduler, "all_snr"):
16
+ return
17
+
18
+ alphas_cumprod = noise_scheduler.alphas_cumprod
19
+ sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
20
+ sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
21
+ alpha = sqrt_alphas_cumprod
22
+ sigma = sqrt_one_minus_alphas_cumprod
23
+ all_snr = (alpha / sigma) ** 2
24
+
25
+ noise_scheduler.all_snr = all_snr.to(device)
26
+
27
+
28
+ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
29
+ # fix beta: zero terminal SNR
30
+ logger.info(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
31
+
32
+ def enforce_zero_terminal_snr(betas):
33
+ # Convert betas to alphas_bar_sqrt
34
+ alphas = 1 - betas
35
+ alphas_bar = alphas.cumprod(0)
36
+ alphas_bar_sqrt = alphas_bar.sqrt()
37
+
38
+ # Store old values.
39
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
40
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
41
+ # Shift so last timestep is zero.
42
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
43
+ # Scale so first timestep is back to old value.
44
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
45
+
46
+ # Convert alphas_bar_sqrt to betas
47
+ alphas_bar = alphas_bar_sqrt**2
48
+ alphas = alphas_bar[1:] / alphas_bar[:-1]
49
+ alphas = torch.cat([alphas_bar[0:1], alphas])
50
+ betas = 1 - alphas
51
+ return betas
52
+
53
+ betas = noise_scheduler.betas
54
+ betas = enforce_zero_terminal_snr(betas)
55
+ alphas = 1.0 - betas
56
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
57
+
58
+ # logger.info(f"original: {noise_scheduler.betas}")
59
+ # logger.info(f"fixed: {betas}")
60
+
61
+ noise_scheduler.betas = betas
62
+ noise_scheduler.alphas = alphas
63
+ noise_scheduler.alphas_cumprod = alphas_cumprod
64
+
65
+
66
+ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
67
+ snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
68
+ min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
69
+ if v_prediction:
70
+ snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
71
+ else:
72
+ snr_weight = torch.div(min_snr_gamma, snr).float().to(loss.device)
73
+ loss = loss * snr_weight
74
+ return loss
75
+
76
+
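+ # Worked example of the min-SNR weighting above (illustrative numbers, not taken from this file):
+ #   all_snr[t] = (sqrt_alphas_cumprod[t] / sqrt_one_minus_alphas_cumprod[t]) ** 2
+ #              = alphas_cumprod[t] / (1 - alphas_cumprod[t])
+ #   epsilon-prediction: weight = min(SNR, gamma) / SNR
+ #     gamma = 5, SNR = 20 (low-noise timestep)  -> weight = 5 / 20 = 0.25
+ #     gamma = 5, SNR = 2  (high-noise timestep) -> weight = 2 / 2  = 1.0
+ #   v-prediction:       weight = min(SNR, gamma) / (SNR + 1)
+ #     gamma = 5, SNR = 20 -> weight = 5 / 21 ≈ 0.238
+ #   i.e. low-noise (high-SNR) timesteps are down-weighted, which is what --min_snr_gamma is for.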
77
+ def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
78
+ scale = get_snr_scale(timesteps, noise_scheduler)
79
+ loss = loss * scale
80
+ return loss
81
+
82
+
83
+ def get_snr_scale(timesteps, noise_scheduler):
84
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
85
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
86
+ scale = snr_t / (snr_t + 1)
87
+ # # show debug info
88
+ # logger.info(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
89
+ return scale
90
+
91
+
92
+ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
93
+ scale = get_snr_scale(timesteps, noise_scheduler)
94
+ # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
95
+ loss = loss + loss / scale * v_pred_like_loss
96
+ return loss
97
+
98
+
99
+ def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
100
+ snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
101
+ snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
102
+ if v_prediction:
103
+ weight = 1 / (snr_t + 1)
104
+ else:
105
+ weight = 1 / torch.sqrt(snr_t)
106
+ loss = weight * loss
107
+ return loss
108
+
109
+
110
+ # TODO train_utilと分散しているのでどちらかに寄せる / TODO: these are split between here and train_util, consolidate into one place
111
+
112
+
113
+ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
114
+ parser.add_argument(
115
+ "--min_snr_gamma",
116
+ type=float,
117
+ default=None,
118
+ help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
119
+ )
120
+ parser.add_argument(
121
+ "--scale_v_pred_loss_like_noise_pred",
122
+ action="store_true",
123
+ help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
124
+ )
125
+ parser.add_argument(
126
+ "--v_pred_like_loss",
127
+ type=float,
128
+ default=None,
129
+ help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
130
+ )
131
+ parser.add_argument(
132
+ "--debiased_estimation_loss",
133
+ action="store_true",
134
+ help="debiased estimation loss / debiased estimation loss",
135
+ )
136
+ if support_weighted_captions:
137
+ parser.add_argument(
138
+ "--weighted_captions",
139
+ action="store_true",
140
+ default=False,
141
+ help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
142
+ )
143
+
144
+
145
+ re_attention = re.compile(
146
+ r"""
147
+ \\\(|
148
+ \\\)|
149
+ \\\[|
150
+ \\]|
151
+ \\\\|
152
+ \\|
153
+ \(|
154
+ \[|
155
+ :([+-]?[.\d]+)\)|
156
+ \)|
157
+ ]|
158
+ [^\\()\[\]:]+|
159
+ :
160
+ """,
161
+ re.X,
162
+ )
163
+
164
+
165
+ def parse_prompt_attention(text):
166
+ """
167
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
168
+ Accepted tokens are:
169
+ (abc) - increases attention to abc by a multiplier of 1.1
170
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
171
+ [abc] - decreases attention to abc by a multiplier of 1.1
172
+ \( - literal character '('
173
+ \[ - literal character '['
174
+ \) - literal character ')'
175
+ \] - literal character ']'
176
+ \\ - literal character '\'
177
+ anything else - just text
178
+ >>> parse_prompt_attention('normal text')
179
+ [['normal text', 1.0]]
180
+ >>> parse_prompt_attention('an (important) word')
181
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
182
+ >>> parse_prompt_attention('(unbalanced')
183
+ [['unbalanced', 1.1]]
184
+ >>> parse_prompt_attention('\(literal\]')
185
+ [['(literal]', 1.0]]
186
+ >>> parse_prompt_attention('(unnecessary)(parens)')
187
+ [['unnecessaryparens', 1.1]]
188
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
189
+ [['a ', 1.0],
190
+ ['house', 1.5730000000000004],
191
+ [' ', 1.1],
192
+ ['on', 1.0],
193
+ [' a ', 1.1],
194
+ ['hill', 0.55],
195
+ [', sun, ', 1.1],
196
+ ['sky', 1.4641000000000006],
197
+ ['.', 1.1]]
198
+ """
199
+
200
+ res = []
201
+ round_brackets = []
202
+ square_brackets = []
203
+
204
+ round_bracket_multiplier = 1.1
205
+ square_bracket_multiplier = 1 / 1.1
206
+
207
+ def multiply_range(start_position, multiplier):
208
+ for p in range(start_position, len(res)):
209
+ res[p][1] *= multiplier
210
+
211
+ for m in re_attention.finditer(text):
212
+ text = m.group(0)
213
+ weight = m.group(1)
214
+
215
+ if text.startswith("\\"):
216
+ res.append([text[1:], 1.0])
217
+ elif text == "(":
218
+ round_brackets.append(len(res))
219
+ elif text == "[":
220
+ square_brackets.append(len(res))
221
+ elif weight is not None and len(round_brackets) > 0:
222
+ multiply_range(round_brackets.pop(), float(weight))
223
+ elif text == ")" and len(round_brackets) > 0:
224
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
225
+ elif text == "]" and len(square_brackets) > 0:
226
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
227
+ else:
228
+ res.append([text, 1.0])
229
+
230
+ for pos in round_brackets:
231
+ multiply_range(pos, round_bracket_multiplier)
232
+
233
+ for pos in square_brackets:
234
+ multiply_range(pos, square_bracket_multiplier)
235
+
236
+ if len(res) == 0:
237
+ res = [["", 1.0]]
238
+
239
+ # merge runs of identical weights
240
+ i = 0
241
+ while i + 1 < len(res):
242
+ if res[i][1] == res[i + 1][1]:
243
+ res[i][0] += res[i + 1][0]
244
+ res.pop(i + 1)
245
+ else:
246
+ i += 1
247
+
248
+ return res
249
+
250
+
251
+ def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
252
+ r"""
253
+ Tokenize a list of prompts and return its tokens with weights of each token.
254
+
255
+ No padding, starting or ending token is included.
256
+ """
257
+ tokens = []
258
+ weights = []
259
+ truncated = False
260
+ for text in prompt:
261
+ texts_and_weights = parse_prompt_attention(text)
262
+ text_token = []
263
+ text_weight = []
264
+ for word, weight in texts_and_weights:
265
+ # tokenize and discard the starting and the ending token
266
+ token = tokenizer(word).input_ids[1:-1]
267
+ text_token += token
268
+ # copy the weight by length of token
269
+ text_weight += [weight] * len(token)
270
+ # stop if the text is too long (longer than truncation limit)
271
+ if len(text_token) > max_length:
272
+ truncated = True
273
+ break
274
+ # truncate
275
+ if len(text_token) > max_length:
276
+ truncated = True
277
+ text_token = text_token[:max_length]
278
+ text_weight = text_weight[:max_length]
279
+ tokens.append(text_token)
280
+ weights.append(text_weight)
281
+ if truncated:
282
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
283
+ return tokens, weights
284
+
285
+
286
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
287
+ r"""
288
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
289
+ """
290
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
291
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
292
+ for i in range(len(tokens)):
293
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
294
+ if no_boseos_middle:
295
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
296
+ else:
297
+ w = []
298
+ if len(weights[i]) == 0:
299
+ w = [1.0] * weights_length
300
+ else:
301
+ for j in range(max_embeddings_multiples):
302
+ w.append(1.0) # weight for starting token in this chunk
303
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
304
+ w.append(1.0) # weight for ending token in this chunk
305
+ w += [1.0] * (weights_length - len(w))
306
+ weights[i] = w[:]
307
+
308
+ return tokens, weights
309
+
310
+
311
+ def get_unweighted_text_embeddings(
312
+ tokenizer,
313
+ text_encoder,
314
+ text_input: torch.Tensor,
315
+ chunk_length: int,
316
+ clip_skip: int,
317
+ eos: int,
318
+ pad: int,
319
+ no_boseos_middle: Optional[bool] = True,
320
+ ):
321
+ """
322
+ When the length of tokens is a multiple of the capacity of the text encoder,
323
+ it should be split into chunks and sent to the text encoder individually.
324
+ """
325
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
326
+ if max_embeddings_multiples > 1:
327
+ text_embeddings = []
328
+ for i in range(max_embeddings_multiples):
329
+ # extract the i-th chunk
330
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
331
+
332
+ # cover the head and the tail by the starting and the ending tokens
333
+ text_input_chunk[:, 0] = text_input[0, 0]
334
+ if pad == eos: # v1
335
+ text_input_chunk[:, -1] = text_input[0, -1]
336
+ else: # v2
337
+ for j in range(len(text_input_chunk)):
338
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # the last token is a normal token, not EOS/PAD
339
+ text_input_chunk[j, -1] = eos
340
+ if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
341
+ text_input_chunk[j, 1] = eos
342
+
343
+ if clip_skip is None or clip_skip == 1:
344
+ text_embedding = text_encoder(text_input_chunk)[0]
345
+ else:
346
+ enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
347
+ text_embedding = enc_out["hidden_states"][-clip_skip]
348
+ text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
349
+
350
+ if no_boseos_middle:
351
+ if i == 0:
352
+ # discard the ending token
353
+ text_embedding = text_embedding[:, :-1]
354
+ elif i == max_embeddings_multiples - 1:
355
+ # discard the starting token
356
+ text_embedding = text_embedding[:, 1:]
357
+ else:
358
+ # discard both starting and ending tokens
359
+ text_embedding = text_embedding[:, 1:-1]
360
+
361
+ text_embeddings.append(text_embedding)
362
+ text_embeddings = torch.concat(text_embeddings, axis=1)
363
+ else:
364
+ if clip_skip is None or clip_skip == 1:
365
+ text_embeddings = text_encoder(text_input)[0]
366
+ else:
367
+ enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
368
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
369
+ text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
370
+ return text_embeddings
371
+
372
+
373
+ def get_weighted_text_embeddings(
374
+ tokenizer,
375
+ text_encoder,
376
+ prompt: Union[str, List[str]],
377
+ device,
378
+ max_embeddings_multiples: Optional[int] = 3,
379
+ no_boseos_middle: Optional[bool] = False,
380
+ clip_skip=None,
381
+ ):
382
+ r"""
383
+ Prompts can be assigned with local weights using brackets. For example,
384
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
385
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
386
+
387
+ Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
388
+
389
+ Args:
390
+ prompt (`str` or `List[str]`):
391
+ The prompt or prompts to guide the image generation.
392
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
393
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
394
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
395
+ If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
396
+ ending token in each of the chunk in the middle.
397
+ skip_parsing (`bool`, *optional*, defaults to `False`):
398
+ Skip the parsing of brackets.
399
+ skip_weighting (`bool`, *optional*, defaults to `False`):
400
+ Skip the weighting. When the parsing is skipped, it is forced True.
401
+ """
402
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
403
+ if isinstance(prompt, str):
404
+ prompt = [prompt]
405
+
406
+ prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
407
+
408
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
409
+ max_length = max([len(token) for token in prompt_tokens])
410
+
411
+ max_embeddings_multiples = min(
412
+ max_embeddings_multiples,
413
+ (max_length - 1) // (tokenizer.model_max_length - 2) + 1,
414
+ )
415
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
416
+ max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
417
+
418
+ # pad the length of tokens and weights
419
+ bos = tokenizer.bos_token_id
420
+ eos = tokenizer.eos_token_id
421
+ pad = tokenizer.pad_token_id
422
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
423
+ prompt_tokens,
424
+ prompt_weights,
425
+ max_length,
426
+ bos,
427
+ eos,
428
+ no_boseos_middle=no_boseos_middle,
429
+ chunk_length=tokenizer.model_max_length,
430
+ )
431
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
432
+
433
+ # get the embeddings
434
+ text_embeddings = get_unweighted_text_embeddings(
435
+ tokenizer,
436
+ text_encoder,
437
+ prompt_tokens,
438
+ tokenizer.model_max_length,
439
+ clip_skip,
440
+ eos,
441
+ pad,
442
+ no_boseos_middle=no_boseos_middle,
443
+ )
444
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
445
+
446
+ # assign weights to the prompts and normalize in the sense of mean
447
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
448
+ text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
449
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
450
+ text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
451
+
452
+ return text_embeddings
453
+
454
+
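A minimal usage sketch for the weighted-embedding path above, assuming the functions defined in this file are in scope and using the SD v1 CLIP text encoder "openai/clip-vit-large-patch14" as an illustrative choice (any compatible tokenizer/encoder pair should work):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

with torch.no_grad():
    embeddings = get_weighted_text_embeddings(
        tokenizer,
        text_encoder,
        ["a (very beautiful:1.3) landscape, [blurry]"],  # weighted syntax handled by parse_prompt_attention
        device=torch.device("cpu"),
        max_embeddings_multiples=3,
        clip_skip=None,
    )
print(embeddings.shape)  # (1, 77, 768) for a short single prompt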
455
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
456
+ def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
457
+ b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
458
+ u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
459
+ for i in range(iterations):
460
+ r = random.random() * 2 + 2 # Rather than always going 2x,
461
+ wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
462
+ noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
463
+ if wn == 1 or hn == 1:
464
+ break # Lowest resolution is 1x1
465
+ return noise / noise.std() # Scaled back to roughly unit variance
466
+
467
+
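A quick sketch of calling the multi-resolution noise helper above on a hypothetical batch of SD latents (4 channels, 64x64); the shapes are illustrative only:

import torch

device = torch.device("cpu")
noise = torch.randn(4, 4, 64, 64, device=device)  # stand-in for latent-shaped Gaussian noise
noise = pyramid_noise_like(noise, device, iterations=6, discount=0.4)
print(noise.std())  # rescaled back to roughly unit variance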
468
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
469
+ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
470
+ if noise_offset is None:
471
+ return noise
472
+ if adaptive_noise_scale is not None:
473
+ # latent shape: (batch_size, channels, height, width)
474
+ # abs mean value for each channel
475
+ latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
476
+
477
+ # multiply adaptive noise scale to the mean value and add it to the noise offset
478
+ noise_offset = noise_offset + adaptive_noise_scale * latent_mean
479
+ noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
480
+
481
+ noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
482
+ return noise
483
+
484
+
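For the offset-noise helper above, the offset is a per-sample, per-channel shift; with adaptive_noise_scale the shift additionally grows with the mean absolute brightness of the latents. A small sketch under those assumptions (shapes are illustrative):

import torch

latents = torch.randn(4, 4, 64, 64)  # hypothetical batch of latents
noise = torch.randn_like(latents)

noisy = apply_noise_offset(latents, noise, noise_offset=0.1, adaptive_noise_scale=None)
noisy_adaptive = apply_noise_offset(latents, noise, noise_offset=0.1, adaptive_noise_scale=0.01)
print(noisy.shape, noisy_adaptive.shape)  # both keep the latent shape (4, 4, 64, 64)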
485
+ def apply_masked_loss(loss, batch):
486
+ if "conditioning_images" in batch:
487
+ # conditioning image is -1 to 1. we need to convert it to 0 to 1
488
+ mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
489
+ mask_image = mask_image / 2 + 0.5
490
+ # print(f"conditioning_image: {mask_image.shape}")
491
+ elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
492
+ # alpha mask is 0 to 1
493
+ mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension
494
+ # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}")
495
+ else:
496
+ return loss
497
+
498
+ # resize to the same size as the loss
499
+ mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
500
+ loss = loss * mask_image
501
+ return loss
502
+
503
+
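A minimal sketch of the masked-loss path above, assuming an element-wise (un-reduced) latent-space loss and an "alpha_masks" entry in the batch; all names and shapes are illustrative:

import torch

loss = torch.rand(2, 4, 64, 64)                   # element-wise MSE loss in latent space
batch = {"alpha_masks": torch.ones(2, 512, 512)}  # full-resolution alpha masks in [0, 1]
masked = apply_masked_loss(loss, batch)
print(masked.shape)  # (2, 4, 64, 64); the mask was resized to the latent resolution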
504
+ """
505
+ ##########################################
506
+ # Perlin Noise
507
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
508
+ delta = (res[0] / shape[0], res[1] / shape[1])
509
+ d = (shape[0] // res[0], shape[1] // res[1])
510
+
511
+ grid = (
512
+ torch.stack(
513
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
514
+ dim=-1,
515
+ )
516
+ % 1
517
+ )
518
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
519
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
520
+
521
+ tile_grads = (
522
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
523
+ .repeat_interleave(d[0], 0)
524
+ .repeat_interleave(d[1], 1)
525
+ )
526
+ dot = lambda grad, shift: (
527
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
528
+ * grad[: shape[0], : shape[1]]
529
+ ).sum(dim=-1)
530
+
531
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
532
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
533
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
534
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
535
+ t = fade(grid[: shape[0], : shape[1]])
536
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
537
+
538
+
539
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
540
+ noise = torch.zeros(shape, device=device)
541
+ frequency = 1
542
+ amplitude = 1
543
+ for _ in range(octaves):
544
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
545
+ frequency *= 2
546
+ amplitude *= persistence
547
+ return noise
548
+
549
+
550
+ def perlin_noise(noise, device, octaves):
551
+ _, c, w, h = noise.shape
552
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
553
+ noise_perlin = []
554
+ for _ in range(c):
555
+ noise_perlin.append(perlin())
556
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
557
+ noise += noise_perlin # broadcast for each batch
558
+ return noise / noise.std() # Scaled back to roughly unit variance
559
+ """
deepspeed_utils.py ADDED
@@ -0,0 +1,139 @@
1
+ import os
2
+ import argparse
3
+ import torch
4
+ from accelerate import DeepSpeedPlugin, Accelerator
5
+
6
+ from .utils import setup_logging
7
+
8
+ setup_logging()
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def add_deepspeed_arguments(parser: argparse.ArgumentParser):
15
+ # DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
16
+ parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
17
+ parser.add_argument("--zero_stage", type=int, default=2, choices=[0, 1, 2, 3], help="Possible options are 0,1,2,3.")
18
+ parser.add_argument(
19
+ "--offload_optimizer_device",
20
+ type=str,
21
+ default=None,
22
+ choices=[None, "cpu", "nvme"],
23
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.",
24
+ )
25
+ parser.add_argument(
26
+ "--offload_optimizer_nvme_path",
27
+ type=str,
28
+ default=None,
29
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
30
+ )
31
+ parser.add_argument(
32
+ "--offload_param_device",
33
+ type=str,
34
+ default=None,
35
+ choices=[None, "cpu", "nvme"],
36
+ help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.",
37
+ )
38
+ parser.add_argument(
39
+ "--offload_param_nvme_path",
40
+ type=str,
41
+ default=None,
42
+ help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.",
43
+ )
44
+ parser.add_argument(
45
+ "--zero3_init_flag",
46
+ action="store_true",
47
+ help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
48
+ "Only applicable with ZeRO Stage-3.",
49
+ )
50
+ parser.add_argument(
51
+ "--zero3_save_16bit_model",
52
+ action="store_true",
53
+ help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.",
54
+ )
55
+ parser.add_argument(
56
+ "--fp16_master_weights_and_gradients",
57
+ action="store_true",
58
+ help="fp16_master_and_gradients requires optimizer to support keeping fp16 master and gradients while keeping the optimizer states in fp32.",
59
+ )
60
+
61
+
62
+ def prepare_deepspeed_args(args: argparse.Namespace):
63
+ if not args.deepspeed:
64
+ return
65
+
66
+ # To avoid RuntimeError: DataLoader worker exited unexpectedly with exit code 1.
67
+ args.max_data_loader_n_workers = 1
68
+
69
+
70
+ def prepare_deepspeed_plugin(args: argparse.Namespace):
71
+ if not args.deepspeed:
72
+ return None
73
+
74
+ try:
75
+ import deepspeed
76
+ except ImportError as e:
77
+ logger.error(
78
+ "deepspeed is not installed. please install deepspeed in your environment with following command. DS_BUILD_OPS=0 pip install deepspeed"
79
+ )
80
+ exit(1)
81
+
82
+ deepspeed_plugin = DeepSpeedPlugin(
83
+ zero_stage=args.zero_stage,
84
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
85
+ gradient_clipping=args.max_grad_norm,
86
+ offload_optimizer_device=args.offload_optimizer_device,
87
+ offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
88
+ offload_param_device=args.offload_param_device,
89
+ offload_param_nvme_path=args.offload_param_nvme_path,
90
+ zero3_init_flag=args.zero3_init_flag,
91
+ zero3_save_16bit_model=args.zero3_save_16bit_model,
92
+ )
93
+ deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = args.train_batch_size
94
+ deepspeed_plugin.deepspeed_config["train_batch_size"] = (
95
+ args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
96
+ )
97
+ deepspeed_plugin.set_mixed_precision(args.mixed_precision)
98
+ if args.mixed_precision.lower() == "fp16":
99
+ deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
100
+ if args.full_fp16 or args.fp16_master_weights_and_gradients:
101
+ if args.offload_optimizer_device == "cpu" and args.zero_stage == 2:
102
+ deepspeed_plugin.deepspeed_config["fp16"]["fp16_master_weights_and_grads"] = True
103
+ logger.info("[DeepSpeed] full fp16 enable.")
104
+ else:
105
+ logger.info(
106
+ "[DeepSpeed]full fp16, fp16_master_weights_and_grads currently only supported using ZeRO-Offload with DeepSpeedCPUAdam on ZeRO-2 stage."
107
+ )
108
+
109
+ if args.offload_optimizer_device is not None:
110
+ logger.info("[DeepSpeed] start to manually build cpu_adam.")
111
+ deepspeed.ops.op_builder.CPUAdamBuilder().load()
112
+ logger.info("[DeepSpeed] building cpu_adam done.")
113
+
114
+ return deepspeed_plugin
115
+
116
+
117
+ # Accelerate library does not support multiple models for deepspeed. So, we need to wrap multiple models into a single model.
118
+ def prepare_deepspeed_model(args: argparse.Namespace, **models):
119
+ # remove None from models
120
+ models = {k: v for k, v in models.items() if v is not None}
121
+
122
+ class DeepSpeedWrapper(torch.nn.Module):
123
+ def __init__(self, **kw_models) -> None:
124
+ super().__init__()
125
+ self.models = torch.nn.ModuleDict()
126
+
127
+ for key, model in kw_models.items():
128
+ if isinstance(model, list):
129
+ model = torch.nn.ModuleList(model)
130
+ assert isinstance(
131
+ model, torch.nn.Module
132
+ ), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
133
+ self.models.update(torch.nn.ModuleDict({key: model}))
134
+
135
+ def get_models(self):
136
+ return self.models
137
+
138
+ ds_model = DeepSpeedWrapper(**models)
139
+ return ds_model
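A small illustration of the wrapper above, with toy modules standing in for the networks that would normally be passed in (the args namespace here is a hypothetical stand-in and is not inspected by this function):

import argparse
import torch

args = argparse.Namespace(deepspeed=True)
unet = torch.nn.Linear(8, 8)          # toy stand-in for the U-Net
text_encoder = torch.nn.Linear(4, 4)  # toy stand-in for the text encoder

wrapped = prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder, vae=None)
print(list(wrapped.get_models().keys()))  # ['unet', 'text_encoder']; None entries are dropped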
dependabot.yml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ version: 2
3
+ updates:
4
+ - package-ecosystem: "github-actions"
5
+ directory: "/"
6
+ schedule:
7
+ interval: "monthly"
detect_face_rotate.py ADDED
@@ -0,0 +1,253 @@
1
+ # This script is licensed under Apache License 2.0, the same as train_dreambooth.py
2
+ # (c) 2022 Kohya S. @kohya_ss
3
+
4
+ # Detect faces in landscape images, rotate them so the face is upright, and crop a square centered on the face
5
+
6
+ # v2: extract max face if multiple faces are found
7
+ # v3: add crop_ratio option
8
+ # v4: add multiple faces extraction and min/max size
9
+
10
+ import argparse
11
+ import math
12
+ import cv2
13
+ import glob
14
+ import os
15
+ from anime_face_detector import create_detector
16
+ from tqdm import tqdm
17
+ import numpy as np
18
+ from library.utils import setup_logging, pil_resize
19
+ setup_logging()
20
+ import logging
21
+ logger = logging.getLogger(__name__)
22
+
23
+ KP_REYE = 11
24
+ KP_LEYE = 19
25
+
26
+ SCORE_THRES = 0.90
27
+
28
+
29
+ def detect_faces(detector, image, min_size):
30
+ preds = detector(image) # bgr
31
+ # logger.info(len(preds))
32
+
33
+ faces = []
34
+ for pred in preds:
35
+ bb = pred['bbox']
36
+ score = bb[-1]
37
+ if score < SCORE_THRES:
38
+ continue
39
+
40
+ left, top, right, bottom = bb[:4]
41
+ cx = int((left + right) / 2)
42
+ cy = int((top + bottom) / 2)
43
+ fw = int(right - left)
44
+ fh = int(bottom - top)
45
+
46
+ lex, ley = pred['keypoints'][KP_LEYE, 0:2]
47
+ rex, rey = pred['keypoints'][KP_REYE, 0:2]
48
+ angle = math.atan2(ley - rey, lex - rex)
49
+ angle = angle / math.pi * 180
50
+
51
+ faces.append((cx, cy, fw, fh, angle))
52
+
53
+ faces.sort(key=lambda x: max(x[2], x[3]), reverse=True) # largest faces first
54
+ return faces
55
+
56
+
57
+ def rotate_image(image, angle, cx, cy):
58
+ h, w = image.shape[0:2]
59
+ rot_mat = cv2.getRotationMatrix2D((cx, cy), angle, 1.0)
60
+
61
+ # # enlarge the image a little to allow for the rotation -> disabled for now
62
+ # nh = max(h, int(w * math.sin(angle)))
63
+ # nw = max(w, int(h * math.sin(angle)))
64
+ # if nh > h or nw > w:
65
+ # pad_y = nh - h
66
+ # pad_t = pad_y // 2
67
+ # pad_x = nw - w
68
+ # pad_l = pad_x // 2
69
+ # m = np.array([[0, 0, pad_l],
70
+ # [0, 0, pad_t]])
71
+ # rot_mat = rot_mat + m
72
+ # h, w = nh, nw
73
+ # cx += pad_l
74
+ # cy += pad_t
75
+
76
+ result = cv2.warpAffine(image, rot_mat, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
77
+ return result, cx, cy
78
+
79
+
80
+ def process(args):
81
+ assert (not args.resize_fit) or args.resize_face_size is None, f"resize_fit and resize_face_size can't both be specified / resize_fitとresize_face_sizeはどちらか片方しか指定できません"
82
+ assert args.crop_ratio is None or args.resize_face_size is None, f"resize_face_size cannot be specified together with crop_ratio / crop_ratio指定時はresize_face_sizeは指定できません"
83
+
84
+ # load the anime face detection model
85
+ logger.info("loading face detector.")
86
+ detector = create_detector('yolov3')
87
+
88
+ # parse the crop arguments
89
+ if args.crop_size is None:
90
+ crop_width = crop_height = None
91
+ else:
92
+ tokens = args.crop_size.split(',')
93
+ assert len(tokens) == 2, f"crop_size must be 'width,height' / crop_sizeは'幅,高さ'で指定してください"
94
+ crop_width, crop_height = [int(t) for t in tokens]
95
+
96
+ if args.crop_ratio is None:
97
+ crop_h_ratio = crop_v_ratio = None
98
+ else:
99
+ tokens = args.crop_ratio.split(',')
100
+ assert len(tokens) == 2, f"crop_ratio must be 'horizontal,vertical' / crop_ratioは'幅,高さ'の倍率で指定してください"
101
+ crop_h_ratio, crop_v_ratio = [float(t) for t in tokens]
102
+
103
+ # process the images
104
+ logger.info("processing.")
105
+ output_extension = ".png"
106
+
107
+ os.makedirs(args.dst_dir, exist_ok=True)
108
+ paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg")) + \
109
+ glob.glob(os.path.join(args.src_dir, "*.webp"))
110
+ for path in tqdm(paths):
111
+ basename = os.path.splitext(os.path.basename(path))[0]
112
+
113
+ # image = cv2.imread(path) # fails on Japanese (non-ASCII) file names
114
+ image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_UNCHANGED)
115
+ if len(image.shape) == 2:
116
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
117
+ if image.shape[2] == 4:
118
+ logger.warning(f"image has alpha. ignore / 画像の透明度が設定されているため無視します: {path}")
119
+ image = image[:, :, :3].copy() # without copy(), the alpha information apparently stays attached internally
120
+
121
+ h, w = image.shape[:2]
122
+
123
+ faces = detect_faces(detector, image, args.multiple_faces)
124
+ for i, face in enumerate(faces):
125
+ cx, cy, fw, fh, angle = face
126
+ face_size = max(fw, fh)
127
+ if args.min_size is not None and face_size < args.min_size:
128
+ continue
129
+ if args.max_size is not None and face_size >= args.max_size:
130
+ continue
131
+ face_suffix = f"_{i+1:02d}" if args.multiple_faces else ""
132
+
133
+ # rotate if the option is specified
134
+ face_img = image
135
+ if args.rotate:
136
+ face_img, cx, cy = rotate_image(face_img, angle, cx, cy)
137
+
138
+ # crop around the face if the option is specified
139
+ if crop_width is not None or crop_h_ratio is not None:
140
+ cur_crop_width, cur_crop_height = crop_width, crop_height
141
+ if crop_h_ratio is not None:
142
+ cur_crop_width = int(face_size * crop_h_ratio + .5)
143
+ cur_crop_height = int(face_size * crop_v_ratio + .5)
144
+
145
+ # resize if necessary
146
+ scale = 1.0
147
+ if args.resize_face_size is not None:
148
+ # resize based on the face size
149
+ scale = args.resize_face_size / face_size
150
+ if scale < cur_crop_width / w:
151
+ logger.warning(
152
+ f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
153
+ scale = cur_crop_width / w
154
+ if scale < cur_crop_height / h:
155
+ logger.warning(
156
+ f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
157
+ scale = cur_crop_height / h
158
+ elif crop_h_ratio is not None:
159
+ # do not resize when a ratio is specified
160
+ pass
161
+ else:
162
+ # crop size is specified explicitly
163
+ if w < cur_crop_width:
164
+ logger.warning(f"image width too small/ 画像の幅がcrop sizeより小さいので画質が劣化します: {path}")
165
+ scale = cur_crop_width / w
166
+ if h < cur_crop_height:
167
+ logger.warning(f"image height too small/ 画像の高さがcrop sizeより小さいので画質が劣化します: {path}")
168
+ scale = cur_crop_height / h
169
+ if args.resize_fit:
170
+ scale = max(cur_crop_width / w, cur_crop_height / h)
171
+
172
+ if scale != 1.0:
173
+ w = int(w * scale + .5)
174
+ h = int(h * scale + .5)
175
+ if scale < 1.0:
176
+ face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
177
+ else:
178
+ face_img = pil_resize(face_img, (w, h))
179
+ cx = int(cx * scale + .5)
180
+ cy = int(cy * scale + .5)
181
+ fw = int(fw * scale + .5)
182
+ fh = int(fh * scale + .5)
183
+
184
+ cur_crop_width = min(cur_crop_width, face_img.shape[1])
185
+ cur_crop_height = min(cur_crop_height, face_img.shape[0])
186
+
187
+ x = cx - cur_crop_width // 2
188
+ cx = cur_crop_width // 2
189
+ if x < 0:
190
+ cx = cx + x
191
+ x = 0
192
+ elif x + cur_crop_width > w:
193
+ cx = cx + (x + cur_crop_width - w)
194
+ x = w - cur_crop_width
195
+ face_img = face_img[:, x:x+cur_crop_width]
196
+
197
+ y = cy - cur_crop_height // 2
198
+ cy = cur_crop_height // 2
199
+ if y < 0:
200
+ cy = cy + y
201
+ y = 0
202
+ elif y + cur_crop_height > h:
203
+ cy = cy + (y + cur_crop_height - h)
204
+ y = h - cur_crop_height
205
+ face_img = face_img[y:y + cur_crop_height]
206
+
207
+ # # debug
208
+ # logger.info(path, cx, cy, angle)
209
+ # crp = cv2.resize(image, (image.shape[1]//8, image.shape[0]//8))
210
+ # cv2.imshow("image", crp)
211
+ # if cv2.waitKey() == 27:
212
+ # break
213
+ # cv2.destroyAllWindows()
214
+
215
+ # debug
216
+ if args.debug:
217
+ cv2.rectangle(face_img, (cx-fw//2, cy-fh//2), (cx+fw//2, cy+fh//2), (255, 0, 255), fw//20)
218
+
219
+ _, buf = cv2.imencode(output_extension, face_img)
220
+ with open(os.path.join(args.dst_dir, f"{basename}{face_suffix}_{cx:04d}_{cy:04d}_{fw:04d}_{fh:04d}{output_extension}"), "wb") as f:
221
+ buf.tofile(f)
222
+
223
+
224
+ def setup_parser() -> argparse.ArgumentParser:
225
+ parser = argparse.ArgumentParser()
226
+ parser.add_argument("--src_dir", type=str, help="directory to load images / 画像を読み込むディレクトリ")
227
+ parser.add_argument("--dst_dir", type=str, help="directory to save images / 画像を保存するディレクトリ")
228
+ parser.add_argument("--rotate", action="store_true", help="rotate images to align faces / 顔が正立するように画像を回転する")
229
+ parser.add_argument("--resize_fit", action="store_true",
230
+ help="resize to fit smaller side after cropping / 切り出し後の画像の短辺がcrop_sizeにあうようにリサイズする")
231
+ parser.add_argument("--resize_face_size", type=int, default=None,
232
+ help="resize image before cropping by face size / 切り出し前に顔がこのサイズになるようにリサイズする")
233
+ parser.add_argument("--crop_size", type=str, default=None,
234
+ help="crop images with 'width,height' pixels, face centered / 顔を中心として'幅,高さ'のサイズで切り出す")
235
+ parser.add_argument("--crop_ratio", type=str, default=None,
236
+ help="crop images with 'horizontal,vertical' ratio to face, face centered / 顔を中心として顔サイズの'幅倍率,高さ倍率'のサイズで切り出す")
237
+ parser.add_argument("--min_size", type=int, default=None,
238
+ help="minimum face size to output (included) / 処理対象とする顔の最小サイズ(この値以上)")
239
+ parser.add_argument("--max_size", type=int, default=None,
240
+ help="maximum face size to output (excluded) / 処理対象とする顔の最大サイズ(この値未満)")
241
+ parser.add_argument("--multiple_faces", action="store_true",
242
+ help="output each faces / 複数の顔が見つかった場合、それぞれを切り出す")
243
+ parser.add_argument("--debug", action="store_true", help="render rect for face / 処理後画像の顔位置に矩形を描画します")
244
+
245
+ return parser
246
+
247
+
248
+ if __name__ == '__main__':
249
+ parser = setup_parser()
250
+
251
+ args = parser.parse_args()
252
+
253
+ process(args)
device_utils.py ADDED
@@ -0,0 +1,89 @@
1
+ import functools
2
+ import gc
3
+
4
+ import torch
5
+ try:
6
+ # intel gpu support for pytorch older than 2.5
7
+ # ipex is not needed after pytorch 2.5
8
+ import intel_extension_for_pytorch as ipex # noqa
9
+ except Exception:
10
+ pass
11
+
12
+
13
+ try:
14
+ HAS_CUDA = torch.cuda.is_available()
15
+ except Exception:
16
+ HAS_CUDA = False
17
+
18
+ try:
19
+ HAS_MPS = torch.backends.mps.is_available()
20
+ except Exception:
21
+ HAS_MPS = False
22
+
23
+ try:
24
+ HAS_XPU = torch.xpu.is_available()
25
+ except Exception:
26
+ HAS_XPU = False
27
+
28
+
29
+ def clean_memory():
30
+ gc.collect()
31
+ if HAS_CUDA:
32
+ torch.cuda.empty_cache()
33
+ if HAS_XPU:
34
+ torch.xpu.empty_cache()
35
+ if HAS_MPS:
36
+ torch.mps.empty_cache()
37
+
38
+
39
+ def clean_memory_on_device(device: torch.device):
40
+ r"""
41
+ Clean memory on the specified device, will be called from training scripts.
42
+ """
43
+ gc.collect()
44
+
45
+ # device may "cuda" or "cuda:0", so we need to check the type of device
46
+ if device.type == "cuda":
47
+ torch.cuda.empty_cache()
48
+ if device.type == "xpu":
49
+ torch.xpu.empty_cache()
50
+ if device.type == "mps":
51
+ torch.mps.empty_cache()
52
+
53
+
54
+ @functools.lru_cache(maxsize=None)
55
+ def get_preferred_device() -> torch.device:
56
+ r"""
57
+ Do not call this function from training scripts. Use accelerator.device instead.
58
+ """
59
+ if HAS_CUDA:
60
+ device = torch.device("cuda")
61
+ elif HAS_XPU:
62
+ device = torch.device("xpu")
63
+ elif HAS_MPS:
64
+ device = torch.device("mps")
65
+ else:
66
+ device = torch.device("cpu")
67
+ print(f"get_preferred_device() -> {device}")
68
+ return device
69
+
70
+
71
+ def init_ipex():
72
+ """
73
+ Apply IPEX to CUDA hijacks using `library.ipex.ipex_init`.
74
+
75
+ This function should run right after importing torch and before doing anything else.
76
+
77
+ If xpu is not available, this function does nothing.
78
+ """
79
+ try:
80
+ if HAS_XPU:
81
+ from library.ipex import ipex_init
82
+
83
+ is_initialized, error_message = ipex_init()
84
+ if not is_initialized:
85
+ print("failed to initialize ipex:", error_message)
86
+ else:
87
+ return
88
+ except Exception as e:
89
+ print("failed to initialize ipex:", e)
diffusers.py ADDED
@@ -0,0 +1,47 @@
1
+ from functools import wraps
2
+ import torch
3
+ import diffusers # pylint: disable=import-error
4
+
5
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
6
+
7
+
8
+ # Diffusers FreeU
9
+ original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
10
+ @wraps(diffusers.utils.torch_utils.fourier_filter)
11
+ def fourier_filter(x_in, threshold, scale):
12
+ return_dtype = x_in.dtype
13
+ return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
14
+
15
+
16
+ # workaround for fp64 errors: float32 rotary position embedding for devices without float64 support
17
+ class FluxPosEmbed(torch.nn.Module):
18
+ def __init__(self, theta: int, axes_dim):
19
+ super().__init__()
20
+ self.theta = theta
21
+ self.axes_dim = axes_dim
22
+
23
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
24
+ n_axes = ids.shape[-1]
25
+ cos_out = []
26
+ sin_out = []
27
+ pos = ids.float()
28
+ for i in range(n_axes):
29
+ cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
30
+ self.axes_dim[i],
31
+ pos[:, i],
32
+ theta=self.theta,
33
+ repeat_interleave_real=True,
34
+ use_real=True,
35
+ freqs_dtype=torch.float32,
36
+ )
37
+ cos_out.append(cos)
38
+ sin_out.append(sin)
39
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
40
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
41
+ return freqs_cos, freqs_sin
42
+
43
+
44
+ def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
45
+ diffusers.utils.torch_utils.fourier_filter = fourier_filter
46
+ if not device_supports_fp64:
47
+ diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
dylora.py ADDED
@@ -0,0 +1,529 @@
1
+ # some codes are copied from:
2
+ # https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
3
+
4
+ # Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
5
+ # Changes made to the original code:
6
+ # 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
7
+ # ------------------------------------------------------------------------------------------
8
+ # Copyright (c) Microsoft Corporation. All rights reserved.
9
+ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
10
+ # ------------------------------------------------------------------------------------------
11
+
12
+ import math
13
+ import os
14
+ import random
15
+ from typing import Dict, List, Optional, Tuple, Type, Union
16
+ from diffusers import AutoencoderKL
17
+ from transformers import CLIPTextModel
18
+ import torch
19
+ from torch import nn
20
+ from library.utils import setup_logging
21
+
22
+ setup_logging()
23
+ import logging
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
+ class DyLoRAModule(torch.nn.Module):
29
+ """
30
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
31
+ """
32
+
33
+ # NOTE: support dropout in future
34
+ def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
35
+ super().__init__()
36
+ self.lora_name = lora_name
37
+ self.lora_dim = lora_dim
38
+ self.unit = unit
39
+ assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
40
+
41
+ if org_module.__class__.__name__ == "Conv2d":
42
+ in_dim = org_module.in_channels
43
+ out_dim = org_module.out_channels
44
+ else:
45
+ in_dim = org_module.in_features
46
+ out_dim = org_module.out_features
47
+
48
+ if type(alpha) == torch.Tensor:
49
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
50
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
51
+ self.scale = alpha / self.lora_dim
52
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
53
+
54
+ self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
55
+ self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
56
+
57
+ if self.is_conv2d and self.is_conv2d_3x3:
58
+ kernel_size = org_module.kernel_size
59
+ self.stride = org_module.stride
60
+ self.padding = org_module.padding
61
+ self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
62
+ self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
63
+ else:
64
+ self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
65
+ self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
66
+
67
+ # same as microsoft's
68
+ for lora in self.lora_A:
69
+ torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
70
+ for lora in self.lora_B:
71
+ torch.nn.init.zeros_(lora)
72
+
73
+ self.multiplier = multiplier
74
+ self.org_module = org_module # remove in applying
75
+
76
+ def apply_to(self):
77
+ self.org_forward = self.org_module.forward
78
+ self.org_module.forward = self.forward
79
+ del self.org_module
80
+
81
+ def forward(self, x):
82
+ result = self.org_forward(x)
83
+
84
+ # specify the dynamic rank
85
+ trainable_rank = random.randint(0, self.lora_dim - 1)
86
+ trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
87
+
88
+ # freeze some of the parameters and train the rest
89
+ for i in range(0, trainable_rank):
90
+ self.lora_A[i].requires_grad = False
91
+ self.lora_B[i].requires_grad = False
92
+ for i in range(trainable_rank, trainable_rank + self.unit):
93
+ self.lora_A[i].requires_grad = True
94
+ self.lora_B[i].requires_grad = True
95
+ for i in range(trainable_rank + self.unit, self.lora_dim):
96
+ self.lora_A[i].requires_grad = False
97
+ self.lora_B[i].requires_grad = False
98
+
99
+ lora_A = torch.cat(tuple(self.lora_A), dim=0)
100
+ lora_B = torch.cat(tuple(self.lora_B), dim=1)
101
+
102
+ # calculate with lora_A and lora_B
103
+ if self.is_conv2d_3x3:
104
+ ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
105
+ ab = torch.nn.functional.conv2d(ab, lora_B)
106
+ else:
107
+ ab = x
108
+ if self.is_conv2d:
109
+ ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
110
+
111
+ ab = torch.nn.functional.linear(ab, lora_A)
112
+ ab = torch.nn.functional.linear(ab, lora_B)
113
+
114
+ if self.is_conv2d:
115
+ ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
116
+
117
+ # the last factor scales lower ranks up (probably)
118
+ result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
119
+
120
+ # NOTE it might be faster to add to the weight first and then call linear/conv2d
121
+ return result
122
+
123
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
124
+ # make the state dict the same as a regular LoRA:
125
+ # nn.ParameterList entries get names like .lora_A.0, so cat them (as in forward) and swap them in
126
+ sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
127
+
128
+ lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
129
+ if self.is_conv2d and not self.is_conv2d_3x3:
130
+ lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
131
+
132
+ lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
133
+ if self.is_conv2d and not self.is_conv2d_3x3:
134
+ lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
135
+
136
+ sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
137
+ sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
138
+
139
+ i = 0
140
+ while True:
141
+ key_a = f"{self.lora_name}.lora_A.{i}"
142
+ key_b = f"{self.lora_name}.lora_B.{i}"
143
+ if key_a in sd:
144
+ sd.pop(key_a)
145
+ sd.pop(key_b)
146
+ else:
147
+ break
148
+ i += 1
149
+ return sd
150
+
151
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
152
+ # allow loading the same state dict as a regular LoRA (this approach was suggested by ChatGPT)
153
+ lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
154
+ lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
155
+
156
+ if lora_A_weight is None or lora_B_weight is None:
157
+ if strict:
158
+ raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
159
+ else:
160
+ return
161
+
162
+ if self.is_conv2d and not self.is_conv2d_3x3:
163
+ lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
164
+ lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
165
+
166
+ state_dict.update(
167
+ {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
168
+ )
169
+ state_dict.update(
170
+ {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
171
+ )
172
+
173
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
174
+
175
+
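A minimal usage sketch (illustrative names only): attach a DyLoRAModule to a plain Linear layer and run one forward pass. Since lora_B is zero-initialized, the output equals the original Linear output at this point.

import torch

linear = torch.nn.Linear(16, 32)
dylora = DyLoRAModule("lora_example", linear, multiplier=1.0, lora_dim=4, alpha=4, unit=1)
dylora.apply_to()  # swaps in the DyLoRA forward and drops the reference to the original module

x = torch.randn(2, 16)
out = linear(x)    # original output + (zero-initialized) DyLoRA delta at a random rank
print(out.shape)   # torch.Size([2, 32])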
176
+ def create_network(
177
+ multiplier: float,
178
+ network_dim: Optional[int],
179
+ network_alpha: Optional[float],
180
+ vae: AutoencoderKL,
181
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
182
+ unet,
183
+ **kwargs,
184
+ ):
185
+ if network_dim is None:
186
+ network_dim = 4 # default
187
+ if network_alpha is None:
188
+ network_alpha = 1.0
189
+
190
+ # extract dim/alpha for conv2d, and block dim
191
+ conv_dim = kwargs.get("conv_dim", None)
192
+ conv_alpha = kwargs.get("conv_alpha", None)
193
+ unit = kwargs.get("unit", None)
194
+ if conv_dim is not None:
195
+ conv_dim = int(conv_dim)
196
+ assert conv_dim == network_dim, "conv_dim must be same as network_dim"
197
+ if conv_alpha is None:
198
+ conv_alpha = 1.0
199
+ else:
200
+ conv_alpha = float(conv_alpha)
201
+
202
+ if unit is not None:
203
+ unit = int(unit)
204
+ else:
205
+ unit = 1
206
+
207
+ network = DyLoRANetwork(
208
+ text_encoder,
209
+ unet,
210
+ multiplier=multiplier,
211
+ lora_dim=network_dim,
212
+ alpha=network_alpha,
213
+ apply_to_conv=conv_dim is not None,
214
+ unit=unit,
215
+ varbose=True,
216
+ )
217
+
218
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
219
+ loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
220
+ loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
221
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
222
+ loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
223
+ loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
224
+ if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
225
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
226
+
227
+ return network
228
+
229
+
230
+ # Create network from weights for inference, weights are not loaded here (because can be merged)
231
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
232
+ if weights_sd is None:
233
+ if os.path.splitext(file)[1] == ".safetensors":
234
+ from safetensors.torch import load_file, safe_open
235
+
236
+ weights_sd = load_file(file)
237
+ else:
238
+ weights_sd = torch.load(file, map_location="cpu")
239
+
240
+ # get dim/alpha mapping
241
+ modules_dim = {}
242
+ modules_alpha = {}
243
+ for key, value in weights_sd.items():
244
+ if "." not in key:
245
+ continue
246
+
247
+ lora_name = key.split(".")[0]
248
+ if "alpha" in key:
249
+ modules_alpha[lora_name] = value
250
+ elif "lora_down" in key:
251
+ dim = value.size()[0]
252
+ modules_dim[lora_name] = dim
253
+ # logger.info(f"{lora_name} {value.size()} {dim}")
254
+
255
+ # support old LoRA without alpha
256
+ for key in modules_dim.keys():
257
+ if key not in modules_alpha:
258
+ modules_alpha[key] = modules_dim[key]
259
+
260
+ module_class = DyLoRAModule
261
+
262
+ network = DyLoRANetwork(
263
+ text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
264
+ )
265
+ return network, weights_sd
266
+
267
+
268
+ class DyLoRANetwork(torch.nn.Module):
269
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
270
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
271
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
272
+ LORA_PREFIX_UNET = "lora_unet"
273
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
274
+
275
+ def __init__(
276
+ self,
277
+ text_encoder,
278
+ unet,
279
+ multiplier=1.0,
280
+ lora_dim=4,
281
+ alpha=1,
282
+ apply_to_conv=False,
283
+ modules_dim=None,
284
+ modules_alpha=None,
285
+ unit=1,
286
+ module_class=DyLoRAModule,
287
+ varbose=False,
288
+ ) -> None:
289
+ super().__init__()
290
+ self.multiplier = multiplier
291
+
292
+ self.lora_dim = lora_dim
293
+ self.alpha = alpha
294
+ self.apply_to_conv = apply_to_conv
295
+
296
+ self.loraplus_lr_ratio = None
297
+ self.loraplus_unet_lr_ratio = None
298
+ self.loraplus_text_encoder_lr_ratio = None
299
+
300
+ if modules_dim is not None:
301
+ logger.info("create LoRA network from weights")
302
+ else:
303
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
304
+ if self.apply_to_conv:
305
+ logger.info("apply LoRA to Conv2d with kernel size (3,3).")
306
+
307
+ # create module instances
308
+ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
309
+ prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
310
+ loras = []
311
+ for name, module in root_module.named_modules():
312
+ if module.__class__.__name__ in target_replace_modules:
313
+ for child_name, child_module in module.named_modules():
314
+ is_linear = child_module.__class__.__name__ == "Linear"
315
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
316
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
317
+
318
+ if is_linear or is_conv2d:
319
+ lora_name = prefix + "." + name + "." + child_name
320
+ lora_name = lora_name.replace(".", "_")
321
+
322
+ dim = None
323
+ alpha = None
324
+ if modules_dim is not None:
325
+ if lora_name in modules_dim:
326
+ dim = modules_dim[lora_name]
327
+ alpha = modules_alpha[lora_name]
328
+ else:
329
+ if is_linear or is_conv2d_1x1 or apply_to_conv:
330
+ dim = self.lora_dim
331
+ alpha = self.alpha
332
+
333
+ if dim is None or dim == 0:
334
+ continue
335
+
336
+ # dropout and fan_in_fan_out is default
337
+ lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
338
+ loras.append(lora)
339
+ return loras
340
+
341
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
342
+
343
+ self.text_encoder_loras = []
344
+ for i, text_encoder in enumerate(text_encoders):
345
+ if len(text_encoders) > 1:
346
+ index = i + 1
347
+ logger.info(f"create LoRA for Text Encoder {index}")
348
+ else:
349
+ index = None
350
+ logger.info("create LoRA for Text Encoder")
351
+
352
+ text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
353
+ self.text_encoder_loras.extend(text_encoder_loras)
354
+
355
+ # self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
356
+ logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
357
+
358
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
359
+ target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
360
+ if modules_dim is not None or self.apply_to_conv:
361
+ target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
362
+
363
+ self.unet_loras = create_modules(True, unet, target_modules)
364
+ logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
365
+
366
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
367
+ self.loraplus_lr_ratio = loraplus_lr_ratio
368
+ self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
369
+ self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
370
+
371
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
372
+ logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
373
+
374
+ def set_multiplier(self, multiplier):
375
+ self.multiplier = multiplier
376
+ for lora in self.text_encoder_loras + self.unet_loras:
377
+ lora.multiplier = self.multiplier
378
+
379
+ def load_weights(self, file):
380
+ if os.path.splitext(file)[1] == ".safetensors":
381
+ from safetensors.torch import load_file
382
+
383
+ weights_sd = load_file(file)
384
+ else:
385
+ weights_sd = torch.load(file, map_location="cpu")
386
+
387
+ info = self.load_state_dict(weights_sd, False)
388
+ return info
389
+
390
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
391
+ if apply_text_encoder:
392
+ logger.info("enable LoRA for text encoder")
393
+ else:
394
+ self.text_encoder_loras = []
395
+
396
+ if apply_unet:
397
+ logger.info("enable LoRA for U-Net")
398
+ else:
399
+ self.unet_loras = []
400
+
401
+ for lora in self.text_encoder_loras + self.unet_loras:
402
+ lora.apply_to()
403
+ self.add_module(lora.lora_name, lora)
404
+
405
+ """
406
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
407
+ apply_text_encoder = apply_unet = False
408
+ for key in weights_sd.keys():
409
+ if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
410
+ apply_text_encoder = True
411
+ elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
412
+ apply_unet = True
413
+
414
+ if apply_text_encoder:
415
+ logger.info("enable LoRA for text encoder")
416
+ else:
417
+ self.text_encoder_loras = []
418
+
419
+ if apply_unet:
420
+ logger.info("enable LoRA for U-Net")
421
+ else:
422
+ self.unet_loras = []
423
+
424
+ for lora in self.text_encoder_loras + self.unet_loras:
425
+ sd_for_lora = {}
426
+ for key in weights_sd.keys():
427
+ if key.startswith(lora.lora_name):
428
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
429
+ lora.merge_to(sd_for_lora, dtype, device)
430
+
431
+ logger.info(f"weights are merged")
432
+ """
433
+
434
+ # it might be good to allow separate learning rates for the two Text Encoders
435
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
436
+ self.requires_grad_(True)
437
+ all_params = []
438
+
439
+ def assemble_params(loras, lr, ratio):
440
+ param_groups = {"lora": {}, "plus": {}}
441
+ for lora in loras:
442
+ for name, param in lora.named_parameters():
443
+ if ratio is not None and "lora_B" in name:
444
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
445
+ else:
446
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
447
+
448
+ params = []
449
+ for key in param_groups.keys():
450
+ param_data = {"params": param_groups[key].values()}
451
+
452
+ if len(param_data["params"]) == 0:
453
+ continue
454
+
455
+ if lr is not None:
456
+ if key == "plus":
457
+ param_data["lr"] = lr * ratio
458
+ else:
459
+ param_data["lr"] = lr
460
+
461
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
462
+ continue
463
+
464
+ params.append(param_data)
465
+
466
+ return params
467
+
468
+ if self.text_encoder_loras:
469
+ params = assemble_params(
470
+ self.text_encoder_loras,
471
+ text_encoder_lr if text_encoder_lr is not None else default_lr,
472
+ self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
473
+ )
474
+ all_params.extend(params)
475
+
476
+ if self.unet_loras:
477
+ params = assemble_params(
478
+ self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
479
+ )
480
+ all_params.extend(params)
481
+
482
+ return all_params
483
+
484
+ def enable_gradient_checkpointing(self):
485
+ # not supported
486
+ pass
487
+
488
+ def prepare_grad_etc(self, text_encoder, unet):
489
+ self.requires_grad_(True)
490
+
491
+ def on_epoch_start(self, text_encoder, unet):
492
+ self.train()
493
+
494
+ def get_trainable_params(self):
495
+ return self.parameters()
496
+
497
+ def save_weights(self, file, dtype, metadata):
498
+ if metadata is not None and len(metadata) == 0:
499
+ metadata = None
500
+
501
+ state_dict = self.state_dict()
502
+
503
+ if dtype is not None:
504
+ for key in list(state_dict.keys()):
505
+ v = state_dict[key]
506
+ v = v.detach().clone().to("cpu").to(dtype)
507
+ state_dict[key] = v
508
+
509
+ if os.path.splitext(file)[1] == ".safetensors":
510
+ from safetensors.torch import save_file
511
+ from library import train_util
512
+
513
+ # Precalculate model hashes to save time on indexing
514
+ if metadata is None:
515
+ metadata = {}
516
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
517
+ metadata["sshs_model_hash"] = model_hash
518
+ metadata["sshs_legacy_hash"] = legacy_hash
519
+
520
+ save_file(state_dict, file, metadata)
521
+ else:
522
+ torch.save(state_dict, file)
523
+
524
+ # mask is a tensor with values from 0 to 1
525
+ def set_region(self, sub_prompt_index, is_last_network, mask):
526
+ pass
527
+
528
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
529
+ pass
extract_lora_from_dylora.py ADDED
@@ -0,0 +1,128 @@
1
+ # Convert LoRA to different rank approximation (should only be used to go to lower rank)
2
+ # This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo
4
+
5
+ import argparse
6
+ import math
7
+ import os
8
+ import torch
9
+ from safetensors.torch import load_file, save_file, safe_open
10
+ from tqdm import tqdm
11
+ from library import train_util, model_util
12
+ import numpy as np
13
+ from library.utils import setup_logging
14
+ setup_logging()
15
+ import logging
16
+ logger = logging.getLogger(__name__)
17
+
18
+ def load_state_dict(file_name):
19
+ if model_util.is_safetensors(file_name):
20
+ sd = load_file(file_name)
21
+ with safe_open(file_name, framework="pt") as f:
22
+ metadata = f.metadata()
23
+ else:
24
+ sd = torch.load(file_name, map_location="cpu")
25
+ metadata = None
26
+
27
+ return sd, metadata
28
+
29
+
30
+ def save_to_file(file_name, model, metadata):
31
+ if model_util.is_safetensors(file_name):
32
+ save_file(model, file_name, metadata)
33
+ else:
34
+ torch.save(model, file_name)
35
+
36
+
37
+ def split_lora_model(lora_sd, unit):
38
+ max_rank = 0
39
+
40
+ # Extract loaded lora dim and alpha
41
+ for key, value in lora_sd.items():
42
+ if "lora_down" in key:
43
+ rank = value.size()[0]
44
+ if rank > max_rank:
45
+ max_rank = rank
46
+ logger.info(f"Max rank: {max_rank}")
47
+
48
+ rank = unit
49
+ split_models = []
50
+ new_alpha = None
51
+ while rank < max_rank:
52
+ logger.info(f"Splitting rank {rank}")
53
+ new_sd = {}
54
+ for key, value in lora_sd.items():
55
+ if "lora_down" in key:
56
+ new_sd[key] = value[:rank].contiguous()
57
+ elif "lora_up" in key:
58
+ new_sd[key] = value[:, :rank].contiguous()
59
+ else:
60
+ # なぜかscaleするとおかしくなる……
61
+ # this_rank = lora_sd[key.replace("alpha", "lora_down.weight")].size()[0]
62
+ # scale = math.sqrt(this_rank / rank) # rank is > unit
63
+ # logger.info(key, value.size(), this_rank, rank, value, scale)
64
+ # new_alpha = value * scale # always same
65
+ # new_sd[key] = new_alpha
66
+ new_sd[key] = value
67
+
68
+ split_models.append((new_sd, rank, new_alpha))
69
+ rank += unit
70
+
71
+ return max_rank, split_models
72
+
73
+
74
+ def split(args):
75
+ logger.info("loading Model...")
76
+ lora_sd, metadata = load_state_dict(args.model)
77
+
78
+ logger.info("Splitting Model...")
79
+ original_rank, split_models = split_lora_model(lora_sd, args.unit)
80
+
81
+ comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
82
+ for state_dict, new_rank, new_alpha in split_models:
83
+ # update metadata
84
+ if metadata is None:
85
+ new_metadata = {}
86
+ else:
87
+ new_metadata = metadata.copy()
88
+
89
+ new_metadata["ss_training_comment"] = f"split from DyLoRA, rank {original_rank} to {new_rank}; {comment}"
90
+ new_metadata["ss_network_dim"] = str(new_rank)
91
+ # new_metadata["ss_network_alpha"] = str(new_alpha.float().numpy())
92
+
93
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
94
+ metadata["sshs_model_hash"] = model_hash
95
+ metadata["sshs_legacy_hash"] = legacy_hash
96
+
97
+ filename, ext = os.path.splitext(args.save_to)
98
+ model_file_name = filename + f"-{new_rank:04d}{ext}"
99
+
100
+ logger.info(f"saving model to: {model_file_name}")
101
+ save_to_file(model_file_name, state_dict, new_metadata)
102
+
103
+
104
+ def setup_parser() -> argparse.ArgumentParser:
105
+ parser = argparse.ArgumentParser()
106
+
107
+ parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
108
+ parser.add_argument(
109
+ "--save_to",
110
+ type=str,
111
+ default=None,
112
+ help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
113
+ )
114
+ parser.add_argument(
115
+ "--model",
116
+ type=str,
117
+ default=None,
118
+ help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
119
+ )
120
+
121
+ return parser
122
+
123
+
124
+ if __name__ == "__main__":
125
+ parser = setup_parser()
126
+
127
+ args = parser.parse_args()
128
+ split(args)
extract_lora_from_models.py ADDED
@@ -0,0 +1,360 @@
1
+ # extract approximating LoRA by svd from two SD models
2
+ # The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
3
+ # Thanks to cloneofsimo!
4
+
5
+ import argparse
6
+ import json
7
+ import os
8
+ import time
9
+ import torch
10
+ from safetensors.torch import load_file, save_file
11
+ from tqdm import tqdm
12
+ from library import sai_model_spec, model_util, sdxl_model_util
13
+ import lora
14
+ from library.utils import setup_logging
15
+ setup_logging()
16
+ import logging
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # CLAMP_QUANTILE = 0.99
20
+ # MIN_DIFF = 1e-1
21
+
22
+
23
+ def save_to_file(file_name, model, state_dict, dtype):
24
+ if dtype is not None:
25
+ for key in list(state_dict.keys()):
26
+ if type(state_dict[key]) == torch.Tensor:
27
+ state_dict[key] = state_dict[key].to(dtype)
28
+
29
+ if os.path.splitext(file_name)[1] == ".safetensors":
30
+ save_file(model, file_name)
31
+ else:
32
+ torch.save(model, file_name)
33
+
34
+
35
+ def svd(
36
+ model_org=None,
37
+ model_tuned=None,
38
+ save_to=None,
39
+ dim=4,
40
+ v2=None,
41
+ sdxl=None,
42
+ conv_dim=None,
43
+ v_parameterization=None,
44
+ device=None,
45
+ save_precision=None,
46
+ clamp_quantile=0.99,
47
+ min_diff=0.01,
48
+ no_metadata=False,
49
+ load_precision=None,
50
+ load_original_model_to=None,
51
+ load_tuned_model_to=None,
52
+ ):
53
+ def str_to_dtype(p):
54
+ if p == "float":
55
+ return torch.float
56
+ if p == "fp16":
57
+ return torch.float16
58
+ if p == "bf16":
59
+ return torch.bfloat16
60
+ return None
61
+
62
+ assert v2 != sdxl or (not v2 and not sdxl), "v2 and sdxl cannot be specified at the same time / v2とsdxlは同時に指定できません"
63
+ if v_parameterization is None:
64
+ v_parameterization = v2
65
+
66
+ load_dtype = str_to_dtype(load_precision) if load_precision else None
67
+ save_dtype = str_to_dtype(save_precision)
68
+ work_device = "cpu"
69
+
70
+ # load models
71
+ if not sdxl:
72
+ logger.info(f"loading original SD model : {model_org}")
73
+ text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
74
+ text_encoders_o = [text_encoder_o]
75
+ if load_dtype is not None:
76
+ text_encoder_o = text_encoder_o.to(load_dtype)
77
+ unet_o = unet_o.to(load_dtype)
78
+
79
+ logger.info(f"loading tuned SD model : {model_tuned}")
80
+ text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
81
+ text_encoders_t = [text_encoder_t]
82
+ if load_dtype is not None:
83
+ text_encoder_t = text_encoder_t.to(load_dtype)
84
+ unet_t = unet_t.to(load_dtype)
85
+
86
+ model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
87
+ else:
88
+ device_org = load_original_model_to if load_original_model_to else "cpu"
89
+ device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
90
+
91
+ logger.info(f"loading original SDXL model : {model_org}")
92
+ text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
93
+ sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
94
+ )
95
+ text_encoders_o = [text_encoder_o1, text_encoder_o2]
96
+ if load_dtype is not None:
97
+ text_encoder_o1 = text_encoder_o1.to(load_dtype)
98
+ text_encoder_o2 = text_encoder_o2.to(load_dtype)
99
+ unet_o = unet_o.to(load_dtype)
100
+
101
+ logger.info(f"loading original SDXL model : {model_tuned}")
102
+ text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
103
+ sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
104
+ )
105
+ text_encoders_t = [text_encoder_t1, text_encoder_t2]
106
+ if load_dtype is not None:
107
+ text_encoder_t1 = text_encoder_t1.to(load_dtype)
108
+ text_encoder_t2 = text_encoder_t2.to(load_dtype)
109
+ unet_t = unet_t.to(load_dtype)
110
+
111
+ model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0
112
+
113
+ # create LoRA network to extract weights: Use dim (rank) as alpha
114
+ if conv_dim is None:
115
+ kwargs = {}
116
+ else:
117
+ kwargs = {"conv_dim": conv_dim, "conv_alpha": conv_dim}
118
+
119
+ lora_network_o = lora.create_network(1.0, dim, dim, None, text_encoders_o, unet_o, **kwargs)
120
+ lora_network_t = lora.create_network(1.0, dim, dim, None, text_encoders_t, unet_t, **kwargs)
121
+ assert len(lora_network_o.text_encoder_loras) == len(
122
+ lora_network_t.text_encoder_loras
123
+ ), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
124
+
125
+ # get diffs
126
+ diffs = {}
127
+ text_encoder_different = False
128
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
129
+ lora_name = lora_o.lora_name
130
+ module_o = lora_o.org_module
131
+ module_t = lora_t.org_module
132
+ diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
133
+
134
+ # clear weight to save memory
135
+ module_o.weight = None
136
+ module_t.weight = None
137
+
138
+ # Text Encoder might be same
139
+ if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
140
+ text_encoder_different = True
141
+ logger.info(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")
142
+
143
+ diffs[lora_name] = diff
144
+
145
+ # clear target Text Encoder to save memory
146
+ for text_encoder in text_encoders_t:
147
+ del text_encoder
148
+
149
+ if not text_encoder_different:
150
+ logger.warning("Text encoder is same. Extract U-Net only.")
151
+ lora_network_o.text_encoder_loras = []
152
+ diffs = {} # clear diffs
153
+
154
+ for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
155
+ lora_name = lora_o.lora_name
156
+ module_o = lora_o.org_module
157
+ module_t = lora_t.org_module
158
+ diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
159
+
160
+ # clear weight to save memory
161
+ module_o.weight = None
162
+ module_t.weight = None
163
+
164
+ diffs[lora_name] = diff
165
+
166
+ # clear LoRA network, target U-Net to save memory
167
+ del lora_network_o
168
+ del lora_network_t
169
+ del unet_t
170
+
171
+ # make LoRA with svd
172
+ logger.info("calculating by svd")
173
+ lora_weights = {}
174
+ with torch.no_grad():
175
+ for lora_name, mat in tqdm(list(diffs.items())):
176
+ if device:
177
+ mat = mat.to(device)
178
+ mat = mat.to(torch.float) # calc by float
179
+
180
+ # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
181
+ conv2d = len(mat.size()) == 4
182
+ kernel_size = None if not conv2d else mat.size()[2:4]
183
+ conv2d_3x3 = conv2d and kernel_size != (1, 1)
184
+
185
+ rank = dim if not conv2d_3x3 or conv_dim is None else conv_dim
186
+ out_dim, in_dim = mat.size()[0:2]
187
+
188
+ if device:
189
+ mat = mat.to(device)
190
+
191
+ # logger.info(lora_name, mat.size(), mat.device, rank, in_dim, out_dim)
192
+ rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
193
+
194
+ if conv2d:
195
+ if conv2d_3x3:
196
+ mat = mat.flatten(start_dim=1)
197
+ else:
198
+ mat = mat.squeeze()
199
+
200
+ U, S, Vh = torch.linalg.svd(mat)
201
+
202
+ U = U[:, :rank]
203
+ S = S[:rank]
204
+ U = U @ torch.diag(S)
205
+
206
+ Vh = Vh[:rank, :]
207
+
208
+ dist = torch.cat([U.flatten(), Vh.flatten()])
209
+ hi_val = torch.quantile(dist, clamp_quantile)
210
+ low_val = -hi_val
211
+
212
+ U = U.clamp(low_val, hi_val)
213
+ Vh = Vh.clamp(low_val, hi_val)
214
+
215
+ if conv2d:
216
+ U = U.reshape(out_dim, rank, 1, 1)
217
+ Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
218
+
219
+ U = U.to(work_device, dtype=save_dtype).contiguous()
220
+ Vh = Vh.to(work_device, dtype=save_dtype).contiguous()
221
+
222
+ lora_weights[lora_name] = (U, Vh)
223
+
224
+ # make state dict for LoRA
225
+ lora_sd = {}
226
+ for lora_name, (up_weight, down_weight) in lora_weights.items():
227
+ lora_sd[lora_name + ".lora_up.weight"] = up_weight
228
+ lora_sd[lora_name + ".lora_down.weight"] = down_weight
229
+ lora_sd[lora_name + ".alpha"] = torch.tensor(down_weight.size()[0])
230
+
231
+ # load state dict to LoRA and save it
232
+ lora_network_save, lora_sd = lora.create_network_from_weights(1.0, None, None, text_encoders_o, unet_o, weights_sd=lora_sd)
233
+ lora_network_save.apply_to(text_encoders_o, unet_o) # create internal module references for state_dict
234
+
235
+ info = lora_network_save.load_state_dict(lora_sd)
236
+ logger.info(f"Loading extracted LoRA weights: {info}")
237
+
238
+ dir_name = os.path.dirname(save_to)
239
+ if dir_name and not os.path.exists(dir_name):
240
+ os.makedirs(dir_name, exist_ok=True)
241
+
242
+ # minimum metadata
243
+ net_kwargs = {}
244
+ if conv_dim is not None:
245
+ net_kwargs["conv_dim"] = str(conv_dim)
246
+ net_kwargs["conv_alpha"] = str(float(conv_dim))
247
+
248
+ metadata = {
249
+ "ss_v2": str(v2),
250
+ "ss_base_model_version": model_version,
251
+ "ss_network_module": "networks.lora",
252
+ "ss_network_dim": str(dim),
253
+ "ss_network_alpha": str(float(dim)),
254
+ "ss_network_args": json.dumps(net_kwargs),
255
+ }
256
+
257
+ if not no_metadata:
258
+ title = os.path.splitext(os.path.basename(save_to))[0]
259
+ sai_metadata = sai_model_spec.build_metadata(None, v2, v_parameterization, sdxl, True, False, time.time(), title=title)
260
+ metadata.update(sai_metadata)
261
+
262
+ lora_network_save.save_weights(save_to, save_dtype, metadata)
263
+ logger.info(f"LoRA weights are saved to: {save_to}")
264
+
265
+
266
+ def setup_parser() -> argparse.ArgumentParser:
267
+ parser = argparse.ArgumentParser()
268
+ parser.add_argument("--v2", action="store_true", help="load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む")
269
+ parser.add_argument(
270
+ "--v_parameterization",
271
+ action="store_true",
272
+ default=None,
273
+ help="make LoRA metadata for v-parameterization (default is same to v2) / 作成するLoRAのメタデータにv-parameterization用と設定する(省略時はv2と同じ)",
274
+ )
275
+ parser.add_argument(
276
+ "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
277
+ )
278
+ parser.add_argument(
279
+ "--load_precision",
280
+ type=str,
281
+ default=None,
282
+ choices=[None, "float", "fp16", "bf16"],
283
+ help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
284
+ )
285
+ parser.add_argument(
286
+ "--save_precision",
287
+ type=str,
288
+ default=None,
289
+ choices=[None, "float", "fp16", "bf16"],
290
+ help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat",
291
+ )
292
+ parser.add_argument(
293
+ "--model_org",
294
+ type=str,
295
+ default=None,
296
+ required=True,
297
+ help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors",
298
+ )
299
+ parser.add_argument(
300
+ "--model_tuned",
301
+ type=str,
302
+ default=None,
303
+ required=True,
304
+ help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors",
305
+ )
306
+ parser.add_argument(
307
+ "--save_to",
308
+ type=str,
309
+ default=None,
310
+ required=True,
311
+ help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors",
312
+ )
313
+ parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
314
+ parser.add_argument(
315
+ "--conv_dim",
316
+ type=int,
317
+ default=None,
318
+ help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)",
319
+ )
320
+ parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
321
+ parser.add_argument(
322
+ "--clamp_quantile",
323
+ type=float,
324
+ default=0.99,
325
+ help="Quantile clamping value, float, (0-1). Default = 0.99 / 値をクランプするための分位点、float、(0-1)。デフォルトは0.99",
326
+ )
327
+ parser.add_argument(
328
+ "--min_diff",
329
+ type=float,
330
+ default=0.01,
331
+ help="Minimum difference between finetuned model and base to consider them different enough to extract, float, (0-1). Default = 0.01 /"
332
+ + "LoRAを抽出するために元モデルと派生モデルの差分の最小値、float、(0-1)。デフォルトは0.01",
333
+ )
334
+ parser.add_argument(
335
+ "--no_metadata",
336
+ action="store_true",
337
+ help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
338
+ + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
339
+ )
340
+ parser.add_argument(
341
+ "--load_original_model_to",
342
+ type=str,
343
+ default=None,
344
+ help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
345
+ )
346
+ parser.add_argument(
347
+ "--load_tuned_model_to",
348
+ type=str,
349
+ default=None,
350
+ help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
351
+ )
352
+
353
+ return parser
354
+
355
+
356
+ if __name__ == "__main__":
357
+ parser = setup_parser()
358
+
359
+ args = parser.parse_args()
360
+ svd(**vars(args))
fine_tune_README_ja.md ADDED
@@ -0,0 +1,140 @@
1
+ NovelAIの提案した学習手法、自動キャプショニング、タグ付け、Windows+VRAM 12GB(SD v1.xの場合)環境等に対応したfine tuningです。ここでfine tuningとは、モデルを画像とキャプションで学習することを指します(LoRAやTextual Inversion、Hypernetworksは含みません)。
2
+
3
+ [学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
4
+
5
+ # 概要
6
+
7
+ Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。
8
+
9
+ * CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。
10
+ * 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。
11
+ * トークン長を75から225に拡張する。
12
+ * BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。
13
+ * Hypernetworkの学習にも対応する。
14
+ * Stable Diffusion v2.0(baseおよび768/v)に対応。
15
+ * VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。
16
+
17
+ デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。
18
+
19
+ # 追加機能について
20
+
21
+ ## CLIPの出力の変更
22
+
23
+ プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。
24
+ 元のまま、最後の層の出力を用いることも可能です。
25
+
26
+ ※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。
27
+
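As a rough illustration only of what "using the second-to-last layer" means, here is a minimal sketch with Hugging Face `transformers` (the model name and the final-layer-norm handling are assumptions for this example, not the training scripts' actual implementation):

```python
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer("a photo of a cat", return_tensors="pt")
out = text_encoder(**tokens, output_hidden_states=True)

last_layer = out.last_hidden_state   # default SD v1 behavior (clip_skip=1)
penultimate = out.hidden_states[-2]  # second-to-last layer (clip_skip=2)
# Note: real implementations typically also apply the final layer norm to `penultimate`.
```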
28
+ ## 正方形以外の解像度での学習
29
+
30
+ Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。
31
+ 学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。
32
+
33
+ 機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。
34
+
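To make the bucketing idea concrete, a minimal sketch (illustration only; the function name and size limits are made up, this is not the actual implementation) of enumerating bucket resolutions in 64-pixel steps without exceeding the reference area:

```python
def make_buckets(max_area=512 * 512, step=64, min_size=256, max_size=1024):
    """Enumerate (width, height) pairs in `step`-pixel units whose area stays within max_area."""
    buckets = set()
    width = min_size
    while width <= max_size:
        # largest multiple of `step` such that width * height <= max_area
        height = min(max_size, (max_area // width) // step * step)
        if height >= min_size:
            buckets.add((width, height))
            buckets.add((height, width))  # the rotated bucket is also valid
        width += step
    return sorted(buckets)

print(make_buckets())  # includes (256, 1024), (384, 640), (512, 512), ...
```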
35
+ ## トークン長の75から225への拡張
36
+
37
+ Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。
38
+ ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。
39
+
40
+ ※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。
41
+
42
+ ※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。
43
+
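A minimal sketch of the "split into chunks and concatenate" idea described above (illustration only; the actual scripts also handle BOS/EOS insertion, padding and prompt weighting):

```python
import torch

def encode_long_prompt(text_encoder, input_ids, chunk_length=77):
    # input_ids: (batch, 3 * 77) -- three 77-token chunks, each assumed to be
    # already wrapped with its own BOS/EOS tokens by the caller
    chunks = input_ids.split(chunk_length, dim=1)           # 3 tensors of (batch, 77)
    hidden = [text_encoder(chunk)[0] for chunk in chunks]   # each (batch, 77, dim)
    return torch.cat(hidden, dim=1)                         # (batch, 231, dim)
```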
44
+ # 学習の手順
45
+
46
+ あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
47
+
48
+ ## データの準備
49
+
50
+ [学習データの準備について](./train_README-ja.md) を参照してください。fine tuningではメタデータを用いるfine tuning方式のみ対応しています。
51
+
52
+ ## 学習の実行
53
+ たとえば以下のように実行します。以下は省メモリ化のための設定です。それぞれの行を必要に応じて書き換えてください。
54
+
55
+ ```
56
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
57
+ --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
58
+ --output_dir=<学習したモデルの出力先フォルダ>
59
+ --output_name=<学習したモデル出力時のファイル名>
60
+ --dataset_config=<データ準備で作成した.tomlファイル>
61
+ --save_model_as=safetensors
62
+ --learning_rate=5e-6 --max_train_steps=10000
63
+ --use_8bit_adam --xformers --gradient_checkpointing
64
+ --mixed_precision=fp16
65
+ ```
66
+
67
+ `num_cpu_threads_per_process` には通常は1を指定するとよいようです。
68
+
69
+ `pretrained_model_name_or_path` に追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。
70
+
71
+ `output_dir` に学習後のモデルを保存するフォルダを指定します。`output_name` にモデルのファイル名を拡張子を除いて指定します。`save_model_as` でsafetensors形式での保存を指定しています。
72
+
73
+ `dataset_config` に `.toml` ファイルを指定します。ファイル内でのバッチサイズ指定は、当初はメモリ消費を抑えるために `1` としてください。
74
+
75
+ 学習させるステップ数 `max_train_steps` を10000とします。学習率 `learning_rate` はここでは5e-6を指定しています。
76
+
77
+ 省メモリ化のため `mixed_precision="fp16"` を指定します(RTX30 シリーズ以降では `bf16` も指定できます。環境整備時にaccelerateに行った設定と合わせてください)。また `gradient_checkpointing` を指定します。
78
+
79
+ オプティマイザ(モデルを学習データにあうように最適化=学習させるクラス)にメモリ消費の少ない 8bit AdamW を使うため、 `optimizer_type="AdamW8bit"` を指定します。
80
+
81
+ `xformers` オプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合やエラーとなる場合(環境にもよりますが `mixed_precision="no"` の場合など)、代わりに `mem_eff_attn` オプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。
82
+
83
+ ある程度メモリがある場合は、`.toml` ファイルを編集してバッチサイズをたとえば `4` くらいに増やしてください(高速化と精度向上の可能性があります)。
84
+
85
+ ### よく使われるオプションについて
86
+
87
+ 以下の場合にはオプションに関するドキュメントを参照してください。
88
+
89
+ - Stable Diffusion 2.xまたはそこからの派生モデルを学習する
90
+ - clip skipを2以上を前提としたモデルを学習する
91
+ - 75トークンを超えたキャプションで学習する
92
+
93
+ ### バッチサイズについて
94
+
95
+ モデル全体を学習するためLoRA等の学習に比べるとメモリ消費量は多くなります(DreamBoothと同じ)。
96
+
97
+ ### 学習率について
98
+
99
+ 1e-6から5e-6程度が一般的なようです。他のfine tuningの例なども参照してみてください。
100
+
101
+ ### 以前の形式のデータセット指定をした場合のコマンドライン
102
+
103
+ 解像度やバッチサイズをオプションで指定します。コマンドラインの例は以下の通りです。
104
+
105
+ ```
106
+ accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
107
+ --pretrained_model_name_or_path=model.ckpt
108
+ --in_json meta_lat.json
109
+ --train_data_dir=train_data
110
+ --output_dir=fine_tuned
111
+ --shuffle_caption
112
+ --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000
113
+ --use_8bit_adam --xformers --gradient_checkpointing
114
+ --mixed_precision=bf16
115
+ --save_every_n_epochs=4
116
+ ```
117
+
118
+ <!--
119
+ ### 勾配をfp16とした学習(実験的機能)
120
+ full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。これによりSD1.xの512*512サイズでは8GB未満、SD2.xの512*512サイズで12GB未満のVRAM使用量で学習できるようです。
121
+
122
+ あらかじめaccelerate configでfp16を指定し、オプションでmixed_precision="fp16"としてください(bf16では動作しません)。
123
+
124
+ メモリ使用量を最小化するためには、xformers、use_8bit_adam、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。
125
+ (余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。)
126
+
127
+ PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。
128
+ -->
129
+
130
+ # fine tuning特有のその他の主なオプション
131
+
132
+ すべてのオプションについては別文書を参照してください。
133
+
134
+ ## `train_text_encoder`
135
+ Text Encoderも学習対象とします。メモリ使用量が若干増加します。
136
+
137
+ 通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。
138
+
139
+ ## `diffusers_xformers`
140
+ スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。
gen_img_README-ja.md ADDED
@@ -0,0 +1,487 @@
1
+ SD 1.xおよび2.xのモデル、当リポジトリで学習したLoRA、ControlNet(v1.0のみ動作確認)などに対応した、Diffusersベースの推論(画像生成)スクリプトです。コマンドラインから用います。
2
+
3
+ # 概要
4
+
5
+ * Diffusers (v0.10.2) ベースの推論(画像生成)スクリプト。
6
+ * SD 1.xおよび2.x (base/v-parameterization)モデルに対応。
7
+ * txt2img、img2img、inpaintingに対応。
8
+ * 対話モード、およびファイルからのプロンプト読み込み、連続生成に対応。
9
+ * プロンプト1行あたりの生成枚数を指定可能。
10
+ * 全体の繰り返し回数を指定可能。
11
+ * `fp16`だけでなく`bf16`にも対応。
12
+ * xformersに対応し高速生成が可能。
13
+ * xformersにより省メモリ生成を行いますが、Automatic 1111氏のWeb UIほど最適化していないため、512*512の画像生成でおおむね6GB程度のVRAMを使用します。
14
+ * プロンプトの225トークンへの拡張。ネガティブプロンプト、重みづけに対応。
15
+ * Diffusersの各種samplerに対応(Web UIよりもsampler数は少ないです)。
16
+ * Text Encoderのclip skip(最後からn番目の層の出力を用いる)に対応。
17
+ * VAEの別途読み込み。
18
+ * CLIP Guided Stable Diffusion、VGG16 Guided Stable Diffusion、Highres. fix、upscale対応。
19
+ * Highres. fixはWeb UIの実装を全く確認していない独自実装のため、出力結果は異なるかもしれません。
20
+ * LoRA対応。適用率指定、複数LoRA同時利用、重みのマージに対応。
21
+ * Text EncoderとU-Netで別の適用率を指定することはできません。
22
+ * Attention Coupleに対応。
23
+ * ControlNet v1.0に対応。
24
+ * 途中でモデルを切り替えることはできませんが、バッチファイルを組むことで対応できます。
25
+ * 個人的に欲しくなった機能をいろいろ追加。
26
+
27
+ 機能追加時にすべてのテストを行っているわけではないため、以前の機能に影響が出て一部機能が動かない可能性があります。何か問題があればお知らせください。
28
+
29
+ # 基本的な使い方
30
+
31
+ ## 対話モードでの画像生成
32
+
33
+ 以下のように入力してください。
34
+
35
+ ```batchfile
36
+ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先> --xformers --fp16 --interactive
37
+ ```
38
+
39
+ `--ckpt`オプションにモデル(Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ)、`--outdir`オプションに画像の出力先フォルダを指定します。
40
+
41
+ `--xformers`オプションでxformersの使用を指定します(xformersを使わない場合は外してください)。`--fp16`オプションでfp16(半精度)での推論を行います。RTX 30系のGPUでは `--bf16`オプションでbf16(bfloat16)での推論を行うこともできます。
42
+
43
+ `--interactive`オプションで対話モードを指定しています。
44
+
45
+ Stable Diffusion 2.0(またはそこからの追加学習モデル)を使う場合は`--v2`オプションを追加してください。v-parameterizationを使うモデル(`768-v-ema.ckpt`およびそこからの追加学習モデル)を使う場合はさらに`--v_parameterization`を追加してください。
46
+
47
+ `--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
48
+
49
+ `Type prompt:`と表示されたらプロンプトを入力してください。
50
+
51
+ ![image](https://user-images.githubusercontent.com/52813779/235343115-f3b8ac82-456d-4aab-9724-0cc73c4534aa.png)
52
+
53
+ ※画像が表示されずエラーになる場合、headless(画面表示機能なし)のOpenCVがインストールされているかもしれません。`pip install opencv-python`として通常のOpenCVを入れてください。または`--no_preview`オプションで画像表示を止めてください。
54
+
55
+ 画像ウィンドウを選択してから何らかのキーを押すとウィンドウが閉じ、次のプロンプトが入力できます。プロンプトでCtrl+Z、エンターの順に打鍵するとスクリプトを閉じます。
56
+
57
+ ## 単一のプロンプトで画像を一括生成
58
+
59
+ 以下のように入力します(実際には1行で入力します)。
60
+
61
+ ```batchfile
62
+ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
63
+ --xformers --fp16 --images_per_prompt <生成枚数> --prompt "<プロンプト>"
64
+ ```
65
+
66
+ `--images_per_prompt`オプションで、プロンプト1件当たりの生成枚数を指定します。`--prompt`オプションでプロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
67
+
68
+ `--batch_size`オプションでバッチサイズを指定できます(後述)。
69
+
70
+ ## ファイルからプロンプトを読み込み一括生成
71
+
72
+ 以下のように入力します。
73
+
74
+ ```batchfile
75
+ python gen_img_diffusers.py --ckpt <モデル名> --outdir <画像出力先>
76
+ --xformers --fp16 --from_file <プロンプトファイル名>
77
+ ```
78
+
79
+ `--from_file`オプションで、プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。`--images_per_prompt`オプションを指定して1行あたり生成枚数を指定できます。
80
+
81
+ ## ネガティブプロンプト、重みづけの使用
82
+
83
+ プロンプトオプション(プロンプト内で`--x`のように指定、後述)で`--n`を書くと、以降がネガティブプロンプトとなります。
84
+
85
+ またAUTOMATIC1111氏のWeb UIと同様の `()` や` []` 、`(xxx:1.3)` などによる重みづけが可能です(実装はDiffusersの[Long Prompt Weighting Stable Diffusion](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#long-prompt-weighting-stable-diffusion)からコピーしたものです)。
86
+
87
+ コマンドラインからのプロンプト指定、ファイルからのプロンプト読み込みでも同様に指定できます。
88
+
89
+ ![image](https://user-images.githubusercontent.com/52813779/235343128-e79cd768-ec59-46f5-8395-fce9bdc46208.png)
90
+
91
+ # 主なオプション
92
+
93
+ コマンドラインから指定してください。
94
+
95
+ ## モデルの指定
96
+
97
+ - `--ckpt <モデル名>`:モデル名を指定します。`--ckpt`オプションは必須です。Stable Diffusionのcheckpointファイル、またはDiffusersのモデルフォルダ、Hugging FaceのモデルIDを指定できます。
98
+
99
+ - `--v2`:Stable Diffusion 2.x系のモデルを使う場合に指定します。1.x系の場合には指定不要です。
100
+
101
+ - `--v_parameterization`:v-parameterizationを使うモデルを使う場合に指定します(`768-v-ema.ckpt`およびそこからの追加学習モデル、Waifu Diffusion v1.5など)。
102
+
103
+ `--v2`の指定有無が間違っているとモデル読み込み時にエラーになります。`--v_parameterization`の指定有無が間違っていると茶色い画像が表示されます。
104
+
105
+ - `--vae`:使用するVAEを指定します。未指定時はモデル内のVAEを使用します。
106
+
107
+ ## 画像生成と出力
108
+
109
+ - `--interactive`:インタラクティブモードで動作します。プロンプトを入力すると画像が生成されます。
110
+
111
+ - `--prompt <プロンプト>`:プロンプトを指定します。スペースを含む場合はダブルクォーテーションで囲んでください。
112
+
113
+ - `--from_file <プロンプトファイル名>`:プロンプトが記述されたファイルを指定します。1行1プロンプトで記述してください。なお画像サイズやguidance scaleはプロンプトオプション(後述)で指定できます。
114
+
115
+ - `--W <画像幅>`:画像の幅を指定します。デフォルトは`512`です。
116
+
117
+ - `--H <画像高さ>`:画像の高さを指定します。デフォルトは`512`です。
118
+
119
+ - `--steps <ステップ数>`:サンプリングステップ数を指定します。デフォルトは`50`です。
120
+
121
+ - `--scale <ガイダンススケール>`:unconditionalガイダンススケールを指定します。デフォルトは`7.5`です。
122
+
123
+ - `--sampler <サンプラー名>`:サンプラーを指定します。デフォルトは`ddim`です。Diffusersで提供されているddim、pndm、dpmsolver、dpmsolver++、lms、euler、euler_aが指定可能です(後ろの三つはk_lms、k_euler、k_euler_aでも指定できます)。
124
+
125
+ - `--outdir <画像出力先フォルダ>`:画像の出力先を指定します。
126
+
127
+ - `--images_per_prompt <生成枚数>`:プロンプト1件当たりの生成枚数を指定します。デフォルトは`1`です。
128
+
129
+ - `--clip_skip <スキップ数>`:CLIPの後ろから何番目の層を使うかを指定します。省略時は最後の層を使います。
130
+
131
+ - `--max_embeddings_multiples <倍数>`:CLIPの入出力長をデフォルト(75)の何倍にするかを指定します。未指定時は75のままです。たとえば3を指定すると入出力長が225になります。
132
+
133
+ - `--negative_scale` : unconditioningのguidance scaleを個別に指定します。[gcem156氏のこちらの記事](https://note.com/gcem156/n/ne9a53e4a6f43)を参考に実装したものです。
134
+
135
+ ## メモリ使用量や生成速度の調整
136
+
137
+ - `--batch_size <バッチサイズ>`:バッチサイズを指定します。デフォルトは`1`です。バッチサイズが大きいとメモリを多く消費しますが、生成速度が速くなります。
138
+
139
+ - `--vae_batch_size <VAEのバッチサイズ>`:VAEのバッチサイズを指定します。デフォルトはバッチサイズと同じです。
140
+ VAEのほうがメモリを多く消費するため、デノイジング後(stepが100%になった後)でメモリ不足になる場合があります。このような場合にはVAEのバッチサイズを小さくしてください。
141
+
142
+ - `--xformers`:xformersを使う場合に指定します。
143
+
144
+ - `--fp16`:fp16(半精度)での推論を行います。`fp16`と`bf16`をどちらも指定しない場合はfp32(単精度)での推論を行います。
145
+
146
+ - `--bf16`:bf16(bfloat16)での推論を行います。RTX 30系のGPUでのみ指定可能です。`--bf16`オプションはRTX 30系以外のGPUではエラーになります。`fp16`よりも`bf16`のほうが推論結果がNaNになる(真っ黒の画像になる)可能性が低いようです。
147
+
148
+ ## 追加ネットワーク(LoRA等)の使用
149
+
150
+ - `--network_module`:使用する追加ネットワークを指定します。LoRAの場合は`--network_module networks.lora`と指定します。複数のLoRAを使用する場合は`--network_module networks.lora networks.lora networks.lora`のように指定します。
151
+
152
+ - `--network_weights`:使用する追加ネットワークの重みファイルを指定します。`--network_weights model.safetensors`のように指定します。複数のLoRAを使用する場合は`--network_weights model1.safetensors model2.safetensors model3.safetensors`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
153
+
154
+ - `--network_mul`:使用する追加ネットワークの重みを何倍にするかを指定します。デフォルトは`1`です。`--network_mul 0.8`のように指定します。複数のLoRAを使用する場合は`--network_mul 0.4 0.5 0.7`のように指定します。引数の数は`--network_module`で指定した数と同じにしてください。
155
+
156
+ - `--network_merge`:使用する追加ネットワークの重みを`--network_mul`に指定した重みであらかじめマージします。`--network_pre_calc` と同時に使用できません。プロンプトオプションの`--am`、およびRegional LoRAは使用できなくなりますが、LoRA未使用時と同じ程度まで生成が高速化されます。
157
+
158
+ - `--network_pre_calc`:使用する追加ネットワークの重みを生成ごとにあらかじめ計算します。プロンプトオプションの`--am`が使用できます。LoRA未使用時と同じ程度まで生成は高速化されますが、生成前に重みを計算する時間が必要で、またメモリ使用量も若干増加します。Regional LoRA使用時は無効になります 。
159
+
160
+ # 主なオプションの指定例
161
+
162
+ 次は同一プロンプトで64枚をバッチサイズ4で一括生成する例です。
163
+
164
+ ```batchfile
165
+ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
166
+ --xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
167
+ --steps 32 --batch_size 4 --images_per_prompt 64
168
+ --prompt "beautiful flowers --n monochrome"
169
+ ```
170
+
171
+ 次はファイルに書かれたプロンプトを、それぞれ10枚ずつ、バッチサイズ4で一括生成する例です。
172
+
173
+ ```batchfile
174
+ python gen_img_diffusers.py --ckpt model.ckpt --outdir outputs
175
+ --xformers --fp16 --W 512 --H 704 --scale 12.5 --sampler k_euler_a
176
+ --steps 32 --batch_size 4 --images_per_prompt 10
177
+ --from_file prompts.txt
178
+ ```
179
+
180
+ Textual Inversion(後述)およびLoRAの使用例です。
181
+
182
+ ```batchfile
183
+ python gen_img_diffusers.py --ckpt model.safetensors
184
+ --scale 8 --steps 48 --outdir txt2img --xformers
185
+ --W 512 --H 768 --fp16 --sampler k_euler_a
186
+ --textual_inversion_embeddings goodembed.safetensors negprompt.pt
187
+ --network_module networks.lora networks.lora
188
+ --network_weights model1.safetensors model2.safetensors
189
+ --network_mul 0.4 0.8
190
+ --clip_skip 2 --max_embeddings_multiples 1
191
+ --batch_size 8 --images_per_prompt 1 --interactive
192
+ ```
193
+
194
+ # プロンプトオプション
195
+
196
+ プロンプト内で、`--n`のように「ハイフンふたつ+アルファベットn文字」でプロンプトから各種オプションの指定が可能です。対話モード、コマンドライン、ファイル、いずれからプロンプトを指定する場合でも有効です。
197
+
198
+ プロンプトのオプション指定`--n`の前後にはスペースを入れてください。
199
+
200
+ - `--n`:ネガティブプロンプトを指定します。
201
+
202
+ - `--w`:画像幅を指定します。コマンドラインからの指定を上書きします。
203
+
204
+ - `--h`:画像高さを指定します。コマンドラインからの指定を上書きします。
205
+
206
+ - `--s`:ステップ数を指定します。コマンドラインからの指定を上書きします。
207
+
208
+ - `--d`:この画像の乱数seedを指定します。`--images_per_prompt`を指定している場合は「--d 1,2,3,4」のようにカンマ区切りで複数指定してください。
209
+ ※様々な理由により、Web UIとは同じ乱数seedでも生成される画像が異なる場合があります。
210
+
211
+ - `--l`:guidance scaleを指定します。コマンドラインからの指定を上書きします。
212
+
213
+ - `--t`:img2img(後述)のstrengthを指定します。コ��ンドラインからの指定を上書きします。
214
+
215
+ - `--nl`:ネガティブプロンプトのguidance scaleを指定します(後述)。コマンドラインからの指定を上書きします。
216
+
217
+ - `--am`:追加ネットワークの重みを指定します。コマンドラインからの指定を上書きします。複数の追加ネットワークを使用する場合は`--am 0.8,0.5,0.3`のように __カンマ区切りで__ 指定します。
218
+
219
+ ※これらのオプションを指定すると、バッチサイズよりも小さいサイズでバッチが実行される場合があります(これらの値が異なると一括生成できないため)。(あまり気にしなくて大丈夫ですが、ファイルからプロンプトを読み込み生成する場合は、これらの値が同一のプロンプトを並べておくと効率が良くなります。)
220
+
221
+ 例:
222
+ ```
223
+ (masterpiece, best quality), 1girl, in shirt and plated skirt, standing at street under cherry blossoms, upper body, [from below], kind smile, looking at another, [goodembed] --n realistic, real life, (negprompt), (lowres:1.1), (worst quality:1.2), (low quality:1.1), bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, normal quality, jpeg artifacts, signature, watermark, username, blurry --w 960 --h 640 --s 28 --d 1
224
+ ```
225
+
226
+ ![image](https://user-images.githubusercontent.com/52813779/235343446-25654172-fff4-4aaf-977a-20d262b51676.png)
227
+
228
+ # img2img
229
+
230
+ ## オプション
231
+
232
+ - `--image_path`:img2imgに利用する画像を指定します。`--image_path template.png`のように指定します。フォルダを指定すると、そのフォルダの画像を順次利用します。
233
+
234
+ - `--strength`:img2imgのstrengthを指定します。`--strength 0.8`のように指定します。デフォルトは`0.8`です。
235
+
236
+ - `--sequential_file_name`:ファイル名を連番にするかどうかを指定します。指定すると生成されるファイル名が`im_000001.png`からの連番になります。
237
+
238
+ - `--use_original_file_name`:指定すると生成ファイル名がオリジナルのファイル名と同じになります。
239
+
240
+ ## コマンドラインからの実行例
241
+
242
+ ```batchfile
243
+ python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
244
+ --outdir outputs --xformers --fp16 --scale 12.5 --sampler k_euler --steps 32
245
+ --image_path template.png --strength 0.8
246
+ --prompt "1girl, cowboy shot, brown hair, pony tail, brown eyes,
247
+ sailor school uniform, outdoors
248
+ --n lowres, bad anatomy, bad hands, error, missing fingers, cropped,
249
+ worst quality, low quality, normal quality, jpeg artifacts, (blurry),
250
+ hair ornament, glasses"
251
+ --batch_size 8 --images_per_prompt 32
252
+ ```
253
+
254
+ `--image_path`オプションにフォルダを指定すると、そのフォルダの画像を順次読み込みます。生成される枚数は画像枚数ではなく、プロンプト数になりますので、`--images_per_prompt`オプションを指定してimg2imgする画像の枚数とプロンプト数を合わせてください。
255
+
256
+ ファイルはファイル名でソートして読み込みます。なおソート順は文字列順となりますので(`1.jpg→2.jpg→10.jpg`ではなく`1.jpg→10.jpg→2.jpg`の順)、頭を0埋めするなどしてご対応ください(`01.jpg→02.jpg→10.jpg`)。
257
+
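For reference, a small hypothetical Python snippet to zero-pad numeric file names so that string order matches numeric order (the folder name is just an example; adjust the padding width to your files):

```python
import os

folder = "src_image"  # hypothetical folder name
for name in os.listdir(folder):
    stem, ext = os.path.splitext(name)
    if stem.isdigit():
        # 1.jpg -> 001.jpg, 10.jpg -> 010.jpg
        os.rename(os.path.join(folder, name),
                  os.path.join(folder, f"{int(stem):03d}{ext}"))
```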
258
+ ## img2imgを利用したupscale
259
+
260
+ img2img時にコマンドラインオプションの`--W`と`--H`で生成画像サイズを指定すると、元画像をそのサイズにリサイズしてからimg2imgを行います。
261
+
262
+ またimg2imgの元画像がこのスクリプトで生成した画像の場合、プロンプトを省略すると、元画像のメタデータからプロンプトを取得しそのまま用います。これによりHighres. fixの2nd stageの動作だけを行うことができます。
263
+
264
+ ## img2img時のinpainting
265
+
266
+ 画像およびマスク画像を指定してinpaintingできます(inpaintingモデルには対応しておらず、単にマスク領域を対象にimg2imgするだけです)。
267
+
268
+ オプションは以下の通りです。
269
+
270
+ - `--mask_image`:マスク画像を指定します。`--image_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。
271
+
272
+ マスク画像はグレースケール画像で、白の部分がinpaintingされます。境界をグラデーションしておくとなんとなく滑らかになりますのでお勧めです。
273
+
274
+ ![image](https://user-images.githubusercontent.com/52813779/235343795-9eaa6d98-02ff-4f32-b089-80d1fc482453.png)
275
+
276
+ # その他の機能
277
+
278
+ ## Textual Inversion
279
+
280
+ `--textual_inversion_embeddings`オプションで使用するembeddingsを指定します(複数指定可)。拡張子を除いたファイル名をプロンプト内で使用することで、そのembeddingsを利用します(Web UIと同様の使用法です)。ネガティブプロンプト内でも使用できます。
281
+
282
+ モデルとして、当リポジトリで学習したTextual Inversionモデル、およびWeb UIで学習したTextual Inversionモデル(画像埋め込みは非対応)を利用できます。
283
+
284
+ ## Extended Textual Inversion
285
+
286
+ `--textual_inversion_embeddings`の代わりに`--XTI_embeddings`オプションを指定してください。使用法は`--textual_inversion_embeddings`と同じです。
287
+
288
+ ## Highres. fix
289
+
290
+ AUTOMATIC1111氏のWeb UIにある機能の類似機能です(独自実装のためもしかしたらいろいろ異なるかもしれません)。最初に小さめの画像を生成し、その画像を元にimg2imgすることで、画像全体の破綻を防ぎつつ大きな解像度の画像を生成します。
291
+
292
+ 2nd stageのstep数は`--steps` と`--strength`オプションの値から計算されます(`steps*strength`)。
293
+
294
+ img2imgと併用できません。
295
+
296
+ 以下のオプションがあります。
297
+
298
+ - `--highres_fix_scale`:Highres. fixを有効にして、1st stageで生成する画像のサイズを、倍率で指定します。最終出力が1024x1024で、最初に512x512の画像を生成する場合は`--highres_fix_scale 0.5`のように指定します。Web UIでの指定の逆数になっていますのでご注意ください。
299
+
300
+ - `--highres_fix_steps`:1st stageの画像のステップ数を指定します。デフォルトは`28`です。
301
+
302
+ - `--highres_fix_save_1st`:1st stageの画像を保存するかどうかを指定します。
303
+
304
+ - `--highres_fix_latents_upscaling`:指定すると2nd stageの画像生成時に1st stageの画像をlatentベースでupscalingします(bilinearのみ対応)。未指定時は画像をLANCZOS4でupscalingします。
305
+
306
+ - `--highres_fix_upscaler`:2nd stageに任意のupscalerを利用します。現在は`--highres_fix_upscaler tools.latent_upscaler` のみ対応しています。
307
+
308
+ - `--highres_fix_upscaler_args`:`--highres_fix_upscaler`で指定したupscalerに渡す引数を指定します。
309
+ `tools.latent_upscaler`の場合は、`--highres_fix_upscaler_args "weights=D:\Work\SD\Models\others\etc\upscaler-v1-e100-220.safetensors"`のように重みファイルを指定します。
310
+
311
+ コマンドラインの例です。
312
+
313
+ ```batchfile
314
+ python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
315
+ --n_iter 1 --scale 7.5 --W 1024 --H 1024 --batch_size 1 --outdir ../txt2img
316
+ --steps 48 --sampler ddim --fp16
317
+ --xformers
318
+ --images_per_prompt 1 --interactive
319
+ --highres_fix_scale 0.5 --highres_fix_steps 28 --strength 0.5
320
+ ```
321
+
322
+ ## ControlNet
323
+
324
+ 現在はControlNet 1.0のみ動作確認しています。プリプロセスはCannyのみサポートしています。
325
+
326
+ 以下のオプションがあります。
327
+
328
+ - `--control_net_models`:ControlNetのモデルファイルを指定します。
329
+ 複数指定すると、それらをstepごとに切り替えて利用します(Web UIのControlNet拡張の実装と異なります)。diffと通常の両方をサポートします。
330
+
331
+ - `--guide_image_path`:ControlNetに使うヒント画像を指定します。`--image_path`と同様にフォルダを指定すると、そのフォルダの画像を順次利用します。Canny以外のモデルの場合には、あらかじめプリプロセスを行っておいてください。
332
+
333
+ - `--control_net_preps`:ControlNetのプリプロセスを指定します。`--control_net_models`と同様に複数指定可能です。現在はcannyのみ対応しています。対象モデルでプリプロセスを使用しない場合は `none` を指定します。
334
+ cannyの場合 `--control_net_preps canny_63_191`のように、閾値1と2を'_'で区切って指定できます。
335
+
336
+ - `--control_net_weights`:ControlNetの適用時の重みを指定します(`1.0`で通常、`0.5`なら半分の影響力で適用)。`--control_net_models`と同様に複数指定可能です。
337
+
338
+ - `--control_net_ratios`:ControlNetを適用するstepの範囲を指定します。`0.5`の場合は、step数の半分までControlNetを適用します。`--control_net_models`と同様に複数指定可能です。
339
+
340
+ コマンドラインの例です。
341
+
342
+ ```batchfile
343
+ python gen_img_diffusers.py --ckpt model_ckpt --scale 8 --steps 48 --outdir txt2img --xformers
344
+ --W 512 --H 768 --bf16 --sampler k_euler_a
345
+ --control_net_models diff_control_sd15_canny.safetensors --control_net_weights 1.0
346
+ --guide_image_path guide.png --control_net_ratios 1.0 --interactive
347
+ ```
348
+
349
+ ## Attention Couple + Regional LoRA
350
+
351
+ プロンプトをいくつかの部分に分割し、それぞれのプロンプトを画像内のどの領域に適用するかを指定できる機能です。個別のオプションはありませんが、`--mask_path`とプロンプトで指定します。
352
+
353
+ まず、プロンプトで` AND `を利用して、複数部分を定義します。最初の3つに対して領域指定ができ、以降の部分は画像全体へ適用されます。ネガティブプロンプトは画像全体に適用されます。
354
+
355
+ 以下ではANDで3つの部分を定義しています。
356
+
357
+ ```
358
+ shs 2girls, looking at viewer, smile AND bsb 2girls, looking back AND 2girls --n bad quality, worst quality
359
+ ```
360
+
361
+ 次にマスク画像を用意します。マスク画像はカラーの画像で、RGBの各チャネルがプロンプトのANDで区切られた部分に対応します。またあるチャネルの値がすべて0の場合、画像全体に適用されます。
362
+
363
+ 上記の例では、Rチャネルが`shs 2girls, looking at viewer, smile`、Gチャネルが`bsb 2girls, looking back`に、Bチャネルが`2girls`に対応します。次のようなマスク画像を使用すると、Bチャネルに指定がありませんので、`2girls`は画像全体に適用されます。
364
+
365
+ ![image](https://user-images.githubusercontent.com/52813779/235343061-b4dc9392-3dae-4831-8347-1e9ae5054251.png)
366
+
367
+ マスク画像は`--mask_path`で指定します。現在は1枚のみ対応しています。指定した画像サイズに自動的にリサイズされ適用されます。
368
+
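As an illustration only (any paint tool works just as well), a mask whose left half is assigned to the first AND section (R channel) and whose right half is assigned to the second (G channel) could be generated like this:

```python
import numpy as np
from PIL import Image

width, height = 1024, 512
mask = np.zeros((height, width, 3), dtype=np.uint8)
mask[:, : width // 2, 0] = 255   # R channel: left half -> first AND section
mask[:, width // 2 :, 1] = 255   # G channel: right half -> second AND section
# B channel stays 0, so the third AND section applies to the whole image
Image.fromarray(mask).save("mask.png")
```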
369
+ ControlNetと組み合わせることも可能です(細かい位置指定にはControlNetとの組み合わせを推奨します)。
370
+
371
+ LoRAを指定すると、`--network_weights`で指定した複数のLoRAがそれぞれANDの各部分に対応します。現在の制約として、LoRAの数はANDの部分の数と同じである必要があります。
372
+
373
+ ## CLIP Guided Stable Diffusion
374
+
375
+ DiffusersのCommunity Examplesの[こちらのcustom pipeline](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#clip-guided-stable-diffusion)からソースをコピー、変更したものです。
376
+
377
+ 通常のプロンプトによる生成指定に加えて、追加でより大規模のCLIPでプロンプトのテキストの特徴量を取得し、生成中の画像の特徴量がそのテキストの特徴量に近づくよう、生成される画像をコントロールします(私のざっくりとした理解です)。大きめのCLIPを使いますのでVRAM使用量はかなり増加し(VRAM 8GBでは512*512でも厳しいかもしれません)、生成時間も掛かります。
378
+
379
+ なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
380
+
381
+ `--clip_guidance_scale`オプションにどの程度、CLIPの特徴量を反映するかを数値で指定します。先のサンプルでは100になっていますので、そのあたりから始めて増減すると良いようです。
382
+
383
+ デフォルトではプロンプトの先頭75トークン(重みづけの特殊文字を除く)がCLIPに渡されます。プロンプトの`--c`オプションで、通常のプロンプトではなく、CLIPに渡すテキストを別に指定できます(たとえばCLIPはDreamBoothのidentifier(識別子)や「1girl」などのモデル特有の単語は認識できないと思われますので、それらを省いたテキストが良いと思われます)。
384
+
385
+ コマンドラインの例です。
386
+
387
+ ```batchfile
388
+ python gen_img_diffusers.py --ckpt v1-5-pruned-emaonly.ckpt --n_iter 1
389
+ --scale 2.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img --steps 36
390
+ --sampler ddim --fp16 --opt_channels_last --xformers --images_per_prompt 1
391
+ --interactive --clip_guidance_scale 100
392
+ ```
393
+
394
+ ## CLIP Image Guided Stable Diffusion
395
+
396
+ テキストではなくCLIPに別の画像を渡し、その特徴量に近づくよう生成をコントロールする機能です。`--clip_image_guidance_scale`オプションで適用量の数値を、`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
397
+
398
+ コマンドラインの例です。
399
+
400
+ ```batchfile
401
+ python gen_img_diffusers.py --ckpt trinart_characters_it4_v1_vae_merged.ckpt
402
+ --n_iter 1 --scale 7.5 --W 512 --H 512 --batch_size 1 --outdir ../txt2img
403
+ --steps 80 --sampler ddim --fp16 --opt_channels_last --xformers
404
+ --images_per_prompt 1 --interactive --clip_image_guidance_scale 100
405
+ --guide_image_path YUKA160113420I9A4104_TP_V.jpg
406
+ ```
407
+
408
+ ### VGG16 Guided Stable Diffusion
409
+
410
+ 指定した画像に近づくように画像生成する機能です。通常のプロンプトによる生成指定に加えて、追加でVGG16の特徴量を取得し、生成中の画像が指定したガイド画像に近づくよう、生成される画像をコントロールします。img2imgでの使用をお勧めします(通常の生成では画像がぼやけた感じになります)。CLIP Guided Stable Diffusionの仕組みを流用した独自の機能です。またアイデアはVGGを利用したスタイル変換から拝借しています。
411
+
412
+ なお選択できるサンプラーはDDIM、PNDM、LMSのみとなります。
413
+
414
+ `--vgg16_guidance_scale`オプションにどの程度、VGG16特徴量を反映するかを数値で指定します。試した感じでは100くらいから始めて増減すると良いようです。`--guide_image_path`オプションでguideに使用する画像(ファイルまたはフォルダ)を指定してください。
415
+
416
+ 複数枚の画像を一括でimg2img変換し、元画像をガイド画像とする場合、`--guide_image_path`と`--image_path`に同じ値を指定すればOKです。
417
+
418
+ コマンドラインの例です。
419
+
420
+ ```batchfile
421
+ python gen_img_diffusers.py --ckpt wd-v1-3-full-pruned-half.ckpt
422
+ --n_iter 1 --scale 5.5 --steps 60 --outdir ../txt2img
423
+ --xformers --sampler ddim --fp16 --W 512 --H 704
424
+ --batch_size 1 --images_per_prompt 1
425
+ --prompt "picturesque, 1girl, solo, anime face, skirt, beautiful face
426
+ --n lowres, bad anatomy, bad hands, error, missing fingers,
427
+ cropped, worst quality, low quality, normal quality,
428
+ jpeg artifacts, blurry, 3d, bad face, monochrome --d 1"
429
+ --strength 0.8 --image_path ..\src_image
430
+ --vgg16_guidance_scale 100 --guide_image_path ..\src_image
431
+ ```
432
+
433
+ `--vgg16_guidance_layer`で特徴量取得に使用するVGG16のレイヤー番号を指定できます(デフォルトは20でconv4-2のReLUです)。上の層ほど画風を表現し、下の層ほどコンテンツを表現するといわれています。
434
+
435
+ ![image](https://user-images.githubusercontent.com/52813779/235343813-3c1f0d7a-4fb3-4274-98e4-b92d76b551df.png)
436
+
437
+ # その他のオプション
438
+
439
+ - `--no_preview` : 対話モードでプレビュー画像を表示しません。OpenCVがインストールされていない場合や、出力されたファイルを直接確認する場合に指定してください。
440
+
441
+ - `--n_iter` : 生成を繰り返す回数を指定します。デフォルトは1です。プロンプトをファイルから読み込むとき、複数回の生成を行いたい場合に指定します。
442
+
443
+ - `--tokenizer_cache_dir` : トークナイザーのキャッシュディレクトリを指定します。(作業中)
444
+
445
+ - `--seed` : 乱数seedを指定します。1枚生成時はその画像のseed、複数枚生成時は各画像のseedを生成するための乱数のseedになります(`--from_file`で複数画像生成するとき、`--seed`オプションを指定すると複数回実行したときに各画像が同じseedになります)。
446
+
447
+ - `--iter_same_seed` : プロンプトに乱数seedの指定がないとき、`--n_iter`の繰り返し内ではすべて同じseedを使います。`--from_file`で指定した複数のプロンプト間でseedを統一して比較するときに使います。
448
+
449
+ - `--diffusers_xformers` : Diffuserのxformersを使用します。
450
+
451
+ - `--opt_channels_last` : 推論時にテンソルのチャンネルを最後に配置します。場合によっては高速化されることがあります。
452
+
453
+ - `--network_show_meta` : 追加ネットワークのメタデータを表示します。
454
+
455
+
456
+ ---
457
+
458
+ # About Gradual Latent
459
+
460
+ Gradual Latent is a Hires fix that gradually increases the size of the latent. `gen_img.py`, `sdxl_gen_img.py`, and `gen_img_diffusers.py` have the following options.
461
+
462
+ - `--gradual_latent_timesteps`: Specifies the timestep to start increasing the size of the latent. The default is None, which means Gradual Latent is not used. Please try around 750 at first.
463
+ - `--gradual_latent_ratio`: Specifies the initial size of the latent. The default is 0.5, which means it starts with half the default latent size.
464
+ - `--gradual_latent_ratio_step`: Specifies the ratio to increase the size of the latent. The default is 0.125, which means the latent size is gradually increased to 0.625, 0.75, 0.875, 1.0.
465
+ - `--gradual_latent_ratio_every_n_steps`: Specifies the interval to increase the size of the latent. The default is 3, which means the latent size is increased every 3 steps.
466
+
467
+ Each option can also be specified with prompt options, `--glt`, `--glr`, `--gls`, `--gle`.
468
+
469
+ __Please specify `euler_a` for the sampler.__ The sampler's source code is modified, so it will not work with other samplers.
470
+
471
+ It is more effective with SD 1.5. It is quite subtle with SDXL.
472
+
473
+ # Gradual Latent について
474
+
475
+ latentのサイズを徐々に大きくしていくHires fixです。`gen_img.py` 、`sdxl_gen_img.py` 、`gen_img_diffusers.py` に以下のオプションが追加されています。
476
+
477
+ - `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。750 くらいから始めてみてください。
478
+ - `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
479
+ - `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
480
+ - `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。
481
+
482
+ それぞれのオプションは、プロンプトオプション、`--glt`、`--glr`、`--gls`、`--gle` でも指定できます。
483
+
484
+ サンプラーに手を加えているため、__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。
485
+
486
+ SD 1.5 のほうが効果があります。SDXL ではかなり微妙です。
487
+
gradscaler.py ADDED
@@ -0,0 +1,183 @@
1
+ from collections import defaultdict
2
+ import torch
3
+ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
4
+ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import
5
+
6
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long
7
+
8
+ device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
9
+ OptState = ipex.cpu.autocast._grad_scaler.OptState
10
+ _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
11
+ _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
12
+
13
+ def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
14
+ per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
15
+ per_device_found_inf = _MultiDeviceReplicator(found_inf)
16
+
17
+ # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
18
+ # There could be hundreds of grads, so we'd like to iterate through them just once.
19
+ # However, we don't know their devices or dtypes in advance.
20
+
21
+ # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
22
+ # Google says mypy struggles with defaultdicts type annotations.
23
+ per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
24
+ # sync grad to master weight
25
+ if hasattr(optimizer, "sync_grad"):
26
+ optimizer.sync_grad()
27
+ with torch.no_grad():
28
+ for group in optimizer.param_groups:
29
+ for param in group["params"]:
30
+ if param.grad is None:
31
+ continue
32
+ if (not allow_fp16) and param.grad.dtype == torch.float16:
33
+ raise ValueError("Attempting to unscale FP16 gradients.")
34
+ if param.grad.is_sparse:
35
+ # is_coalesced() == False means the sparse grad has values with duplicate indices.
36
+ # coalesce() deduplicates indices and adds all values that have the same index.
37
+ # For scaled fp16 values, there's a good chance coalescing will cause overflow,
38
+ # so we should check the coalesced _values().
39
+ if param.grad.dtype is torch.float16:
40
+ param.grad = param.grad.coalesce()
41
+ to_unscale = param.grad._values()
42
+ else:
43
+ to_unscale = param.grad
44
+
45
+ # -: is there a way to split by device and dtype without appending in the inner loop?
46
+ to_unscale = to_unscale.to("cpu")
47
+ per_device_and_dtype_grads[to_unscale.device][
48
+ to_unscale.dtype
49
+ ].append(to_unscale)
50
+
51
+ for _, per_dtype_grads in per_device_and_dtype_grads.items():
52
+ for grads in per_dtype_grads.values():
53
+ core._amp_foreach_non_finite_check_and_unscale_(
54
+ grads,
55
+ per_device_found_inf.get("cpu"),
56
+ per_device_inv_scale.get("cpu"),
57
+ )
58
+
59
+ return per_device_found_inf._per_device_tensors
60
+
61
+ def unscale_(self, optimizer):
62
+ """
63
+ Divides ("unscales") the optimizer's gradient tensors by the scale factor.
64
+ :meth:`unscale_` is optional, serving cases where you need to
65
+ :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
66
+ between the backward pass(es) and :meth:`step`.
67
+ If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
68
+ Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
69
+ ...
70
+ scaler.scale(loss).backward()
71
+ scaler.unscale_(optimizer)
72
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
73
+ scaler.step(optimizer)
74
+ scaler.update()
75
+ Args:
76
+ optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
77
+ .. warning::
78
+ :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
79
+ and only after all gradients for that optimizer's assigned parameters have been accumulated.
80
+ Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
81
+ .. warning::
82
+ :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
83
+ """
84
+ if not self._enabled:
85
+ return
86
+
87
+ self._check_scale_growth_tracker("unscale_")
88
+
89
+ optimizer_state = self._per_optimizer_states[id(optimizer)]
90
+
91
+ if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
92
+ raise RuntimeError(
93
+ "unscale_() has already been called on this optimizer since the last update()."
94
+ )
95
+ elif optimizer_state["stage"] is OptState.STEPPED:
96
+ raise RuntimeError("unscale_() is being called after step().")
97
+
98
+ # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
99
+ assert self._scale is not None
100
+ if device_supports_fp64:
101
+ inv_scale = self._scale.double().reciprocal().float()
102
+ else:
103
+ inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
104
+ found_inf = torch.full(
105
+ (1,), 0.0, dtype=torch.float32, device=self._scale.device
106
+ )
107
+
108
+ optimizer_state["found_inf_per_device"] = self._unscale_grads_(
109
+ optimizer, inv_scale, found_inf, False
110
+ )
111
+ optimizer_state["stage"] = OptState.UNSCALED
112
+
113
+ def update(self, new_scale=None):
114
+ """
115
+ Updates the scale factor.
116
+ If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
117
+ to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
118
+ the scale is multiplied by ``growth_factor`` to increase it.
119
+ Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
120
+ used directly, it's used to fill GradScaler's internal scale tensor. So if
121
+ ``new_scale`` was a tensor, later in-place changes to that tensor will not further
122
+ affect the scale GradScaler uses internally.)
123
+ Args:
124
+ new_scale (float or :class:`torch.FloatTensor`, optional, default=None): New scale factor.
125
+ .. warning::
126
+ :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
127
+ been invoked for all optimizers used this iteration.
128
+ """
129
+ if not self._enabled:
130
+ return
131
+
132
+ _scale, _growth_tracker = self._check_scale_growth_tracker("update")
133
+
134
+ if new_scale is not None:
135
+ # Accept a new user-defined scale.
136
+ if isinstance(new_scale, float):
137
+ self._scale.fill_(new_scale) # type: ignore[union-attr]
138
+ else:
139
+ reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
140
+ assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
141
+ assert new_scale.numel() == 1, reason
142
+ assert new_scale.requires_grad is False, reason
143
+ self._scale.copy_(new_scale) # type: ignore[union-attr]
144
+ else:
145
+ # Consume shared inf/nan data collected from optimizers to update the scale.
146
+ # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
147
+ found_infs = [
148
+ found_inf.to(device="cpu", non_blocking=True)
149
+ for state in self._per_optimizer_states.values()
150
+ for found_inf in state["found_inf_per_device"].values()
151
+ ]
152
+
153
+ assert len(found_infs) > 0, "No inf checks were recorded prior to update."
154
+
155
+ found_inf_combined = found_infs[0]
156
+ if len(found_infs) > 1:
157
+ for i in range(1, len(found_infs)):
158
+ found_inf_combined += found_infs[i]
159
+
160
+ to_device = _scale.device
161
+ _scale = _scale.to("cpu")
162
+ _growth_tracker = _growth_tracker.to("cpu")
163
+
164
+ core._amp_update_scale_(
165
+ _scale,
166
+ _growth_tracker,
167
+ found_inf_combined,
168
+ self._growth_factor,
169
+ self._backoff_factor,
170
+ self._growth_interval,
171
+ )
172
+
173
+ _scale = _scale.to(to_device)
174
+ _growth_tracker = _growth_tracker.to(to_device)
175
+ # To prepare for next iteration, clear the data collected from optimizers this iteration.
176
+ self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)
177
+
178
+ def gradscaler_init():
179
+ torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
180
+ torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
181
+ torch.xpu.amp.GradScaler.unscale_ = unscale_
182
+ torch.xpu.amp.GradScaler.update = update
183
+ return torch.xpu.amp.GradScaler
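Note: the patched scaler above follows the standard GradScaler API, only on "xpu" tensors. Below is a minimal usage sketch, assuming an Intel XPU device with intel_extension_for_pytorch installed and gradscaler_init() already executed; the toy model, data, and optimizer are hypothetical and only illustrate the call sequence.

import torch

model = torch.nn.Linear(8, 1).to("xpu")                  # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.xpu.amp.GradScaler()                      # class installed by gradscaler_init()

x = torch.randn(4, 8, device="xpu")
y = torch.randn(4, 1, device="xpu")

optimizer.zero_grad()
with torch.autocast("xpu", dtype=torch.float16):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.step(optimizer)          # skipped automatically when inf/nan gradients are found
scaler.update()                 # grows or backs off the scale as described in update() above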
hijacks.py ADDED
@@ -0,0 +1,367 @@
1
+ import os
2
+ from functools import wraps
3
+ from contextlib import nullcontext
4
+ import torch
5
+ import numpy as np
6
+
7
+ device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
8
+ if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
9
+ try:
10
+ x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
11
+ del x
12
+ torch.xpu.empty_cache()
13
+ can_allocate_plus_4gb = True
14
+ except Exception:
15
+ can_allocate_plus_4gb = False
16
+ else:
17
+ can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
18
+
19
+ # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
20
+
21
+ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstring, unused-argument, too-few-public-methods
22
+ def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
23
+ if isinstance(device_ids, list) and len(device_ids) > 1:
24
+ print("IPEX backend doesn't support DataParallel on multiple XPU devices")
25
+ return module.to("xpu")
26
+
27
+ def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
28
+ return nullcontext()
29
+
30
+ @property
31
+ def is_cuda(self):
32
+ return self.device.type == 'xpu' or self.device.type == 'cuda'
33
+
34
+ def check_device(device):
35
+ return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
36
+
37
+ def return_xpu(device):
38
+ return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
39
+
40
+
41
+ # Autocast
42
+ original_autocast_init = torch.amp.autocast_mode.autocast.__init__
43
+ @wraps(torch.amp.autocast_mode.autocast.__init__)
44
+ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
45
+ if device_type == "cuda":
46
+ return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
47
+ else:
48
+ return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
49
+
50
+ # Latent Antialias CPU Offload:
51
+ original_interpolate = torch.nn.functional.interpolate
52
+ @wraps(torch.nn.functional.interpolate)
53
+ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
54
+ if mode in {'bicubic', 'bilinear'}:
55
+ return_device = tensor.device
56
+ return_dtype = tensor.dtype
57
+ return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
58
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias).to(return_device, dtype=return_dtype)
59
+ else:
60
+ return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
61
+ align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
62
+
63
+
64
+ # Diffusers Float64 (Alchemist GPUs don't support 64-bit floats):
65
+ original_from_numpy = torch.from_numpy
66
+ @wraps(torch.from_numpy)
67
+ def from_numpy(ndarray):
68
+ if ndarray.dtype == float:
69
+ return original_from_numpy(ndarray.astype('float32'))
70
+ else:
71
+ return original_from_numpy(ndarray)
72
+
73
+ original_as_tensor = torch.as_tensor
74
+ @wraps(torch.as_tensor)
75
+ def as_tensor(data, dtype=None, device=None):
76
+ if check_device(device):
77
+ device = return_xpu(device)
78
+ if isinstance(data, np.ndarray) and data.dtype == float and not (
79
+ (isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
80
+ return original_as_tensor(data, dtype=torch.float32, device=device)
81
+ else:
82
+ return original_as_tensor(data, dtype=dtype, device=device)
83
+
84
+
85
+ if can_allocate_plus_4gb:
86
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
87
+ else:
88
+ # 32 bit attention workarounds for Alchemist:
89
+ try:
90
+ from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
91
+ except Exception: # pylint: disable=broad-exception-caught
92
+ original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
93
+
94
+ @wraps(torch.nn.functional.scaled_dot_product_attention)
95
+ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
96
+ if query.dtype != key.dtype:
97
+ key = key.to(dtype=query.dtype)
98
+ if query.dtype != value.dtype:
99
+ value = value.to(dtype=query.dtype)
100
+ if attn_mask is not None and query.dtype != attn_mask.dtype:
101
+ attn_mask = attn_mask.to(dtype=query.dtype)
102
+ return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
103
+
104
+ # Data Type Errors:
105
+ original_torch_bmm = torch.bmm
106
+ @wraps(torch.bmm)
107
+ def torch_bmm(input, mat2, *, out=None):
108
+ if input.dtype != mat2.dtype:
109
+ mat2 = mat2.to(input.dtype)
110
+ return original_torch_bmm(input, mat2, out=out)
111
+
112
+ # Diffusers FreeU
113
+ original_fft_fftn = torch.fft.fftn
114
+ @wraps(torch.fft.fftn)
115
+ def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
116
+ return_dtype = input.dtype
117
+ return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
118
+
119
+ # Diffusers FreeU
120
+ original_fft_ifftn = torch.fft.ifftn
121
+ @wraps(torch.fft.ifftn)
122
+ def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
123
+ return_dtype = input.dtype
124
+ return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
125
+
126
+ # A1111 FP16
127
+ original_functional_group_norm = torch.nn.functional.group_norm
128
+ @wraps(torch.nn.functional.group_norm)
129
+ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
130
+ if weight is not None and input.dtype != weight.data.dtype:
131
+ input = input.to(dtype=weight.data.dtype)
132
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
133
+ bias.data = bias.data.to(dtype=weight.data.dtype)
134
+ return original_functional_group_norm(input, num_groups, weight=weight, bias=bias, eps=eps)
135
+
136
+ # A1111 BF16
137
+ original_functional_layer_norm = torch.nn.functional.layer_norm
138
+ @wraps(torch.nn.functional.layer_norm)
139
+ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
140
+ if weight is not None and input.dtype != weight.data.dtype:
141
+ input = input.to(dtype=weight.data.dtype)
142
+ if bias is not None and weight is not None and bias.data.dtype != weight.data.dtype:
143
+ bias.data = bias.data.to(dtype=weight.data.dtype)
144
+ return original_functional_layer_norm(input, normalized_shape, weight=weight, bias=bias, eps=eps)
145
+
146
+ # Training
147
+ original_functional_linear = torch.nn.functional.linear
148
+ @wraps(torch.nn.functional.linear)
149
+ def functional_linear(input, weight, bias=None):
150
+ if input.dtype != weight.data.dtype:
151
+ input = input.to(dtype=weight.data.dtype)
152
+ if bias is not None and bias.data.dtype != weight.data.dtype:
153
+ bias.data = bias.data.to(dtype=weight.data.dtype)
154
+ return original_functional_linear(input, weight, bias=bias)
155
+
156
+ original_functional_conv1d = torch.nn.functional.conv1d
157
+ @wraps(torch.nn.functional.conv1d)
158
+ def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
159
+ if input.dtype != weight.data.dtype:
160
+ input = input.to(dtype=weight.data.dtype)
161
+ if bias is not None and bias.data.dtype != weight.data.dtype:
162
+ bias.data = bias.data.to(dtype=weight.data.dtype)
163
+ return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
164
+
165
+ original_functional_conv2d = torch.nn.functional.conv2d
166
+ @wraps(torch.nn.functional.conv2d)
167
+ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
168
+ if input.dtype != weight.data.dtype:
169
+ input = input.to(dtype=weight.data.dtype)
170
+ if bias is not None and bias.data.dtype != weight.data.dtype:
171
+ bias.data = bias.data.to(dtype=weight.data.dtype)
172
+ return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
173
+
174
+ # LTX Video
175
+ original_functional_conv3d = torch.nn.functional.conv3d
176
+ @wraps(torch.nn.functional.conv3d)
177
+ def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
178
+ if input.dtype != weight.data.dtype:
179
+ input = input.to(dtype=weight.data.dtype)
180
+ if bias is not None and bias.data.dtype != weight.data.dtype:
181
+ bias.data = bias.data.to(dtype=weight.data.dtype)
182
+ return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
183
+
184
+ # SwinIR BF16:
185
+ original_functional_pad = torch.nn.functional.pad
186
+ @wraps(torch.nn.functional.pad)
187
+ def functional_pad(input, pad, mode='constant', value=None):
188
+ if mode == 'reflect' and input.dtype == torch.bfloat16:
189
+ return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
190
+ else:
191
+ return original_functional_pad(input, pad, mode=mode, value=value)
192
+
193
+
194
+ original_torch_tensor = torch.tensor
195
+ @wraps(torch.tensor)
196
+ def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
197
+ global device_supports_fp64
198
+ if check_device(device):
199
+ device = return_xpu(device)
200
+ if not device_supports_fp64:
201
+ if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
202
+ if dtype == torch.float64:
203
+ dtype = torch.float32
204
+ elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
205
+ dtype = torch.float32
206
+ return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
207
+
208
+ original_Tensor_to = torch.Tensor.to
209
+ @wraps(torch.Tensor.to)
210
+ def Tensor_to(self, device=None, *args, **kwargs):
211
+ if check_device(device):
212
+ return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
213
+ else:
214
+ return original_Tensor_to(self, device, *args, **kwargs)
215
+
216
+ original_Tensor_cuda = torch.Tensor.cuda
217
+ @wraps(torch.Tensor.cuda)
218
+ def Tensor_cuda(self, device=None, *args, **kwargs):
219
+ if check_device(device):
220
+ return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
221
+ else:
222
+ return original_Tensor_cuda(self, device, *args, **kwargs)
223
+
224
+ original_Tensor_pin_memory = torch.Tensor.pin_memory
225
+ @wraps(torch.Tensor.pin_memory)
226
+ def Tensor_pin_memory(self, device=None, *args, **kwargs):
227
+ if device is None:
228
+ device = "xpu"
229
+ if check_device(device):
230
+ return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
231
+ else:
232
+ return original_Tensor_pin_memory(self, device, *args, **kwargs)
233
+
234
+ original_UntypedStorage_init = torch.UntypedStorage.__init__
235
+ @wraps(torch.UntypedStorage.__init__)
236
+ def UntypedStorage_init(*args, device=None, **kwargs):
237
+ if check_device(device):
238
+ return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
239
+ else:
240
+ return original_UntypedStorage_init(*args, device=device, **kwargs)
241
+
242
+ original_UntypedStorage_cuda = torch.UntypedStorage.cuda
243
+ @wraps(torch.UntypedStorage.cuda)
244
+ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
245
+ if check_device(device):
246
+ return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
247
+ else:
248
+ return original_UntypedStorage_cuda(self, device, *args, **kwargs)
249
+
250
+ original_torch_empty = torch.empty
251
+ @wraps(torch.empty)
252
+ def torch_empty(*args, device=None, **kwargs):
253
+ if check_device(device):
254
+ return original_torch_empty(*args, device=return_xpu(device), **kwargs)
255
+ else:
256
+ return original_torch_empty(*args, device=device, **kwargs)
257
+
258
+ original_torch_randn = torch.randn
259
+ @wraps(torch.randn)
260
+ def torch_randn(*args, device=None, dtype=None, **kwargs):
261
+ if dtype is bytes:
262
+ dtype = None
263
+ if check_device(device):
264
+ return original_torch_randn(*args, device=return_xpu(device), **kwargs)
265
+ else:
266
+ return original_torch_randn(*args, device=device, **kwargs)
267
+
268
+ original_torch_ones = torch.ones
269
+ @wraps(torch.ones)
270
+ def torch_ones(*args, device=None, **kwargs):
271
+ if check_device(device):
272
+ return original_torch_ones(*args, device=return_xpu(device), **kwargs)
273
+ else:
274
+ return original_torch_ones(*args, device=device, **kwargs)
275
+
276
+ original_torch_zeros = torch.zeros
277
+ @wraps(torch.zeros)
278
+ def torch_zeros(*args, device=None, **kwargs):
279
+ if check_device(device):
280
+ return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
281
+ else:
282
+ return original_torch_zeros(*args, device=device, **kwargs)
283
+
284
+ original_torch_full = torch.full
285
+ @wraps(torch.full)
286
+ def torch_full(*args, device=None, **kwargs):
287
+ if check_device(device):
288
+ return original_torch_full(*args, device=return_xpu(device), **kwargs)
289
+ else:
290
+ return original_torch_full(*args, device=device, **kwargs)
291
+
292
+ original_torch_linspace = torch.linspace
293
+ @wraps(torch.linspace)
294
+ def torch_linspace(*args, device=None, **kwargs):
295
+ if check_device(device):
296
+ return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
297
+ else:
298
+ return original_torch_linspace(*args, device=device, **kwargs)
299
+
300
+ original_torch_load = torch.load
301
+ @wraps(torch.load)
302
+ def torch_load(f, map_location=None, *args, **kwargs):
303
+ if map_location is None:
304
+ map_location = "xpu"
305
+ if check_device(map_location):
306
+ return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
307
+ else:
308
+ return original_torch_load(f, *args, map_location=map_location, **kwargs)
309
+
310
+ original_torch_Generator = torch.Generator
311
+ @wraps(torch.Generator)
312
+ def torch_Generator(device=None):
313
+ if check_device(device):
314
+ return original_torch_Generator(return_xpu(device))
315
+ else:
316
+ return original_torch_Generator(device)
317
+
318
+ @wraps(torch.cuda.synchronize)
319
+ def torch_cuda_synchronize(device=None):
320
+ if check_device(device):
321
+ return torch.xpu.synchronize(return_xpu(device))
322
+ else:
323
+ return torch.xpu.synchronize(device)
324
+
325
+
326
+ # Hijack Functions:
327
+ def ipex_hijacks(legacy=True):
328
+ global device_supports_fp64, can_allocate_plus_4gb
329
+ if legacy and float(torch.__version__[:3]) < 2.5:
330
+ torch.nn.functional.interpolate = interpolate
331
+ torch.tensor = torch_tensor
332
+ torch.Tensor.to = Tensor_to
333
+ torch.Tensor.cuda = Tensor_cuda
334
+ torch.Tensor.pin_memory = Tensor_pin_memory
335
+ torch.UntypedStorage.__init__ = UntypedStorage_init
336
+ torch.UntypedStorage.cuda = UntypedStorage_cuda
337
+ torch.empty = torch_empty
338
+ torch.randn = torch_randn
339
+ torch.ones = torch_ones
340
+ torch.zeros = torch_zeros
341
+ torch.full = torch_full
342
+ torch.linspace = torch_linspace
343
+ torch.load = torch_load
344
+ torch.Generator = torch_Generator
345
+ torch.cuda.synchronize = torch_cuda_synchronize
346
+
347
+ torch.backends.cuda.sdp_kernel = return_null_context
348
+ torch.nn.DataParallel = DummyDataParallel
349
+ torch.UntypedStorage.is_cuda = is_cuda
350
+ torch.amp.autocast_mode.autocast.__init__ = autocast_init
351
+
352
+ torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
353
+ torch.nn.functional.group_norm = functional_group_norm
354
+ torch.nn.functional.layer_norm = functional_layer_norm
355
+ torch.nn.functional.linear = functional_linear
356
+ torch.nn.functional.conv1d = functional_conv1d
357
+ torch.nn.functional.conv2d = functional_conv2d
358
+ torch.nn.functional.conv3d = functional_conv3d
359
+ torch.nn.functional.pad = functional_pad
360
+
361
+ torch.bmm = torch_bmm
362
+ torch.fft.fftn = fft_fftn
363
+ torch.fft.ifftn = fft_ifftn
364
+ if not device_supports_fp64:
365
+ torch.from_numpy = from_numpy
366
+ torch.as_tensor = as_tensor
367
+ return device_supports_fp64, can_allocate_plus_4gb
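A minimal sketch of enabling these hijacks; it assumes an Intel XPU build of PyTorch with IPEX, and the import path library.ipex.hijacks is only an assumption about where this file lives in the package.

import torch
from library.ipex.hijacks import ipex_hijacks  # assumed module path for this file

supports_fp64, can_alloc_4gb = ipex_hijacks(legacy=True)
# With the legacy hijacks active (torch < 2.5), "cuda" device arguments are
# transparently rewritten to "xpu", so CUDA-oriented code runs unchanged:
t = torch.ones(2, 2, device="cuda:0")  # actually allocated on xpu:0
print(t.device, supports_fp64, can_alloc_4gb)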
huggingface_util.py ADDED
@@ -0,0 +1,84 @@
1
+ from typing import Union, BinaryIO
2
+ from huggingface_hub import HfApi
3
+ from pathlib import Path
4
+ import argparse
5
+ import os
6
+ from library.utils import fire_in_thread
7
+ from library.utils import setup_logging
8
+ setup_logging()
9
+ import logging
10
+ logger = logging.getLogger(__name__)
11
+
12
+ def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
13
+ api = HfApi(
14
+ token=token,
15
+ )
16
+ try:
17
+ api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
18
+ return True
19
+ except:
20
+ return False
21
+
22
+
23
+ def upload(
24
+ args: argparse.Namespace,
25
+ src: Union[str, Path, bytes, BinaryIO],
26
+ dest_suffix: str = "",
27
+ force_sync_upload: bool = False,
28
+ ):
29
+ repo_id = args.huggingface_repo_id
30
+ repo_type = args.huggingface_repo_type
31
+ token = args.huggingface_token
32
+ path_in_repo = args.huggingface_path_in_repo + dest_suffix if args.huggingface_path_in_repo is not None else None
33
+ private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
34
+ api = HfApi(token=token)
35
+ if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
36
+ try:
37
+ api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
38
+ except Exception as e: # RepositoryNotFoundError has been confirmed, but catch broadly in case other errors occur
39
+ logger.error("===========================================")
40
+ logger.error(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
41
+ logger.error("===========================================")
42
+
43
+ is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
44
+
45
+ def uploader():
46
+ try:
47
+ if is_folder:
48
+ api.upload_folder(
49
+ repo_id=repo_id,
50
+ repo_type=repo_type,
51
+ folder_path=src,
52
+ path_in_repo=path_in_repo,
53
+ )
54
+ else:
55
+ api.upload_file(
56
+ repo_id=repo_id,
57
+ repo_type=repo_type,
58
+ path_or_fileobj=src,
59
+ path_in_repo=path_in_repo,
60
+ )
61
+ except Exception as e: # RuntimeError has been confirmed, but catch broadly in case other errors occur
62
+ logger.error("===========================================")
63
+ logger.error(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
64
+ logger.error("===========================================")
65
+
66
+ if args.async_upload and not force_sync_upload:
67
+ fire_in_thread(uploader)
68
+ else:
69
+ uploader()
70
+
71
+
72
+ def list_dir(
73
+ repo_id: str,
74
+ subfolder: str,
75
+ repo_type: str,
76
+ revision: str = "main",
77
+ token: str = None,
78
+ ):
79
+ api = HfApi(
80
+ token=token,
81
+ )
82
+ repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
83
+ file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
84
+ return file_list
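For reference, upload() only reads a handful of attributes from the argparse namespace, so it can also be driven programmatically. A hedged sketch follows; the repo id, token, and file paths are placeholders, and the import path is an assumption.

import argparse
from library import huggingface_util  # assumed import path for this module

args = argparse.Namespace(
    huggingface_repo_id="user/my-model",      # placeholder repo id
    huggingface_repo_type="model",
    huggingface_token=None,                   # or a real token string
    huggingface_path_in_repo="checkpoints",
    huggingface_repo_visibility=None,         # anything other than "public" -> private repo
    async_upload=False,
)
huggingface_util.upload(args, src="output/last.safetensors", dest_suffix="/last")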
hypernetwork.py ADDED
@@ -0,0 +1,223 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from diffusers.models.attention_processor import (
4
+ Attention,
5
+ AttnProcessor2_0,
6
+ SlicedAttnProcessor,
7
+ XFormersAttnProcessor
8
+ )
9
+
10
+ try:
11
+ import xformers.ops
12
+ except:
13
+ xformers = None
14
+
15
+
16
+ loaded_networks = []
17
+
18
+
19
+ def apply_single_hypernetwork(
20
+ hypernetwork, hidden_states, encoder_hidden_states
21
+ ):
22
+ context_k, context_v = hypernetwork.forward(hidden_states, encoder_hidden_states)
23
+ return context_k, context_v
24
+
25
+
26
+ def apply_hypernetworks(context_k, context_v, layer=None):
27
+ if len(loaded_networks) == 0:
28
+ return context_v, context_v
29
+ for hypernetwork in loaded_networks:
30
+ context_k, context_v = hypernetwork.forward(context_k, context_v)
31
+
32
+ context_k = context_k.to(dtype=context_k.dtype)
33
+ context_v = context_v.to(dtype=context_k.dtype)
34
+
35
+ return context_k, context_v
36
+
37
+
38
+
39
+ def xformers_forward(
40
+ self: XFormersAttnProcessor,
41
+ attn: Attention,
42
+ hidden_states: torch.Tensor,
43
+ encoder_hidden_states: torch.Tensor = None,
44
+ attention_mask: torch.Tensor = None,
45
+ ):
46
+ batch_size, sequence_length, _ = (
47
+ hidden_states.shape
48
+ if encoder_hidden_states is None
49
+ else encoder_hidden_states.shape
50
+ )
51
+
52
+ attention_mask = attn.prepare_attention_mask(
53
+ attention_mask, sequence_length, batch_size
54
+ )
55
+
56
+ query = attn.to_q(hidden_states)
57
+
58
+ if encoder_hidden_states is None:
59
+ encoder_hidden_states = hidden_states
60
+ elif attn.norm_cross:
61
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
62
+
63
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
64
+
65
+ key = attn.to_k(context_k)
66
+ value = attn.to_v(context_v)
67
+
68
+ query = attn.head_to_batch_dim(query).contiguous()
69
+ key = attn.head_to_batch_dim(key).contiguous()
70
+ value = attn.head_to_batch_dim(value).contiguous()
71
+
72
+ hidden_states = xformers.ops.memory_efficient_attention(
73
+ query,
74
+ key,
75
+ value,
76
+ attn_bias=attention_mask,
77
+ op=self.attention_op,
78
+ scale=attn.scale,
79
+ )
80
+ hidden_states = hidden_states.to(query.dtype)
81
+ hidden_states = attn.batch_to_head_dim(hidden_states)
82
+
83
+ # linear proj
84
+ hidden_states = attn.to_out[0](hidden_states)
85
+ # dropout
86
+ hidden_states = attn.to_out[1](hidden_states)
87
+ return hidden_states
88
+
89
+
90
+ def sliced_attn_forward(
91
+ self: SlicedAttnProcessor,
92
+ attn: Attention,
93
+ hidden_states: torch.Tensor,
94
+ encoder_hidden_states: torch.Tensor = None,
95
+ attention_mask: torch.Tensor = None,
96
+ ):
97
+ batch_size, sequence_length, _ = (
98
+ hidden_states.shape
99
+ if encoder_hidden_states is None
100
+ else encoder_hidden_states.shape
101
+ )
102
+ attention_mask = attn.prepare_attention_mask(
103
+ attention_mask, sequence_length, batch_size
104
+ )
105
+
106
+ query = attn.to_q(hidden_states)
107
+ dim = query.shape[-1]
108
+ query = attn.head_to_batch_dim(query)
109
+
110
+ if encoder_hidden_states is None:
111
+ encoder_hidden_states = hidden_states
112
+ elif attn.norm_cross:
113
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
114
+
115
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
116
+
117
+ key = attn.to_k(context_k)
118
+ value = attn.to_v(context_v)
119
+ key = attn.head_to_batch_dim(key)
120
+ value = attn.head_to_batch_dim(value)
121
+
122
+ batch_size_attention, query_tokens, _ = query.shape
123
+ hidden_states = torch.zeros(
124
+ (batch_size_attention, query_tokens, dim // attn.heads),
125
+ device=query.device,
126
+ dtype=query.dtype,
127
+ )
128
+
129
+ for i in range(batch_size_attention // self.slice_size):
130
+ start_idx = i * self.slice_size
131
+ end_idx = (i + 1) * self.slice_size
132
+
133
+ query_slice = query[start_idx:end_idx]
134
+ key_slice = key[start_idx:end_idx]
135
+ attn_mask_slice = (
136
+ attention_mask[start_idx:end_idx] if attention_mask is not None else None
137
+ )
138
+
139
+ attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
140
+
141
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
142
+
143
+ hidden_states[start_idx:end_idx] = attn_slice
144
+
145
+ hidden_states = attn.batch_to_head_dim(hidden_states)
146
+
147
+ # linear proj
148
+ hidden_states = attn.to_out[0](hidden_states)
149
+ # dropout
150
+ hidden_states = attn.to_out[1](hidden_states)
151
+
152
+ return hidden_states
153
+
154
+
155
+ def v2_0_forward(
156
+ self: AttnProcessor2_0,
157
+ attn: Attention,
158
+ hidden_states,
159
+ encoder_hidden_states=None,
160
+ attention_mask=None,
161
+ ):
162
+ batch_size, sequence_length, _ = (
163
+ hidden_states.shape
164
+ if encoder_hidden_states is None
165
+ else encoder_hidden_states.shape
166
+ )
167
+ inner_dim = hidden_states.shape[-1]
168
+
169
+ if attention_mask is not None:
170
+ attention_mask = attn.prepare_attention_mask(
171
+ attention_mask, sequence_length, batch_size
172
+ )
173
+ # scaled_dot_product_attention expects attention_mask shape to be
174
+ # (batch, heads, source_length, target_length)
175
+ attention_mask = attention_mask.view(
176
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
177
+ )
178
+
179
+ query = attn.to_q(hidden_states)
180
+
181
+ if encoder_hidden_states is None:
182
+ encoder_hidden_states = hidden_states
183
+ elif attn.norm_cross:
184
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
185
+
186
+ context_k, context_v = apply_hypernetworks(hidden_states, encoder_hidden_states)
187
+
188
+ key = attn.to_k(context_k)
189
+ value = attn.to_v(context_v)
190
+
191
+ head_dim = inner_dim // attn.heads
192
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
193
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
194
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
195
+
196
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
197
+ # TODO: add support for attn.scale when we move to Torch 2.1
198
+ hidden_states = F.scaled_dot_product_attention(
199
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
200
+ )
201
+
202
+ hidden_states = hidden_states.transpose(1, 2).reshape(
203
+ batch_size, -1, attn.heads * head_dim
204
+ )
205
+ hidden_states = hidden_states.to(query.dtype)
206
+
207
+ # linear proj
208
+ hidden_states = attn.to_out[0](hidden_states)
209
+ # dropout
210
+ hidden_states = attn.to_out[1](hidden_states)
211
+ return hidden_states
212
+
213
+
214
+ def replace_attentions_for_hypernetwork():
215
+ import diffusers.models.attention_processor
216
+
217
+ diffusers.models.attention_processor.XFormersAttnProcessor.__call__ = (
218
+ xformers_forward
219
+ )
220
+ diffusers.models.attention_processor.SlicedAttnProcessor.__call__ = (
221
+ sliced_attn_forward
222
+ )
223
+ diffusers.models.attention_processor.AttnProcessor2_0.__call__ = v2_0_forward
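The patching above routes the attention key/value projections through every entry in loaded_networks; any object whose forward(context_k, context_v) returns two tensors can be registered. A minimal sketch with a no-op stand-in, for illustration only (the import path is an assumption):

import torch
from library import hypernetwork  # assumed import path for this module

class IdentityHypernetwork(torch.nn.Module):
    # stand-in: returns the contexts unchanged, just to show the expected interface
    def forward(self, context_k, context_v):
        return context_k, context_v

hypernetwork.loaded_networks.append(IdentityHypernetwork())
hypernetwork.replace_attentions_for_hypernetwork()
# From here on, the diffusers attention processors feed k/v through loaded_networks.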
hypernetwork_nai.py ADDED
@@ -0,0 +1,96 @@
1
+ # NAI compatible
2
+
3
+ import torch
4
+
5
+
6
+ class HypernetworkModule(torch.nn.Module):
7
+ def __init__(self, dim, multiplier=1.0):
8
+ super().__init__()
9
+
10
+ linear1 = torch.nn.Linear(dim, dim * 2)
11
+ linear2 = torch.nn.Linear(dim * 2, dim)
12
+ linear1.weight.data.normal_(mean=0.0, std=0.01)
13
+ linear1.bias.data.zero_()
14
+ linear2.weight.data.normal_(mean=0.0, std=0.01)
15
+ linear2.bias.data.zero_()
16
+ linears = [linear1, linear2]
17
+
18
+ self.linear = torch.nn.Sequential(*linears)
19
+ self.multiplier = multiplier
20
+
21
+ def forward(self, x):
22
+ return x + self.linear(x) * self.multiplier
23
+
24
+
25
+ class Hypernetwork(torch.nn.Module):
26
+ enable_sizes = [320, 640, 768, 1280]
27
+ # return self.modules[Hypernetwork.enable_sizes.index(size)]
28
+
29
+ def __init__(self, multiplier=1.0) -> None:
30
+ super().__init__()
31
+ self.modules = []
32
+ for size in Hypernetwork.enable_sizes:
33
+ self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
34
+ self.register_module(f"{size}_0", self.modules[-1][0])
35
+ self.register_module(f"{size}_1", self.modules[-1][1])
36
+
37
+ def apply_to_stable_diffusion(self, text_encoder, vae, unet):
38
+ blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
39
+ for block in blocks:
40
+ for subblk in block:
41
+ if 'SpatialTransformer' in str(type(subblk)):
42
+ for tf_block in subblk.transformer_blocks:
43
+ for attn in [tf_block.attn1, tf_block.attn2]:
44
+ size = attn.context_dim
45
+ if size in Hypernetwork.enable_sizes:
46
+ attn.hypernetwork = self
47
+ else:
48
+ attn.hypernetwork = None
49
+
50
+ def apply_to_diffusers(self, text_encoder, vae, unet):
51
+ blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
52
+ for block in blocks:
53
+ if hasattr(block, 'attentions'):
54
+ for subblk in block.attentions:
55
+ if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
56
+ for tf_block in subblk.transformer_blocks:
57
+ for attn in [tf_block.attn1, tf_block.attn2]:
58
+ size = attn.to_k.in_features
59
+ if size in Hypernetwork.enable_sizes:
60
+ attn.hypernetwork = self
61
+ else:
62
+ attn.hypernetwork = None
63
+ return True # TODO error checking
64
+
65
+ def forward(self, x, context):
66
+ size = context.shape[-1]
67
+ assert size in Hypernetwork.enable_sizes
68
+ module = self.modules[Hypernetwork.enable_sizes.index(size)]
69
+ return module[0].forward(context), module[1].forward(context)
70
+
71
+ def load_from_state_dict(self, state_dict):
72
+ # old ver to new ver
73
+ changes = {
74
+ 'linear1.bias': 'linear.0.bias',
75
+ 'linear1.weight': 'linear.0.weight',
76
+ 'linear2.bias': 'linear.1.bias',
77
+ 'linear2.weight': 'linear.1.weight',
78
+ }
79
+ for key_from, key_to in changes.items():
80
+ if key_from in state_dict:
81
+ state_dict[key_to] = state_dict[key_from]
82
+ del state_dict[key_from]
83
+
84
+ for size, sd in state_dict.items():
85
+ if type(size) == int:
86
+ self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
87
+ self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
88
+ return True
89
+
90
+ def get_state_dict(self):
91
+ state_dict = {}
92
+ for i, size in enumerate(Hypernetwork.enable_sizes):
93
+ sd0 = self.modules[i][0].state_dict()
94
+ sd1 = self.modules[i][1].state_dict()
95
+ state_dict[size] = [sd0, sd1]
96
+ return state_dict
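A short sketch of how this class might be used to load an NAI-format hypernetwork and attach it to a diffusers UNet; the file path, the unet variable, and the import path are placeholders/assumptions.

import torch
from library.hypernetwork_nai import Hypernetwork  # assumed import path

state_dict = torch.load("my_hypernetwork.pt", map_location="cpu")  # placeholder path
hyper = Hypernetwork(multiplier=1.0)
hyper.load_from_state_dict(state_dict)
# hyper.apply_to_diffusers(None, None, unet)  # with a diffusers UNet2DConditionModel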
latent_upscaler.py ADDED
@@ -0,0 +1,354 @@
1
+ # Script for easily calling the upscaler from external code
2
+ # The model definition is included so this file works standalone
3
+
4
+ import argparse
5
+ import glob
6
+ import os
7
+ import cv2
8
+ from diffusers import AutoencoderKL
9
+
10
+ from typing import Dict, List
11
+ import numpy as np
12
+
13
+ import torch
14
+ from library.device_utils import init_ipex, get_preferred_device
15
+ init_ipex()
16
+
17
+ from torch import nn
18
+ from tqdm import tqdm
19
+ from PIL import Image
20
+ from library.utils import setup_logging
21
+ setup_logging()
22
+ import logging
23
+ logger = logging.getLogger(__name__)
24
+
25
+ class ResidualBlock(nn.Module):
26
+ def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
27
+ super(ResidualBlock, self).__init__()
28
+
29
+ if out_channels is None:
30
+ out_channels = in_channels
31
+
32
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
33
+ self.bn1 = nn.BatchNorm2d(out_channels)
34
+ self.relu1 = nn.ReLU(inplace=True)
35
+
36
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=False)
37
+ self.bn2 = nn.BatchNorm2d(out_channels)
38
+
39
+ self.relu2 = nn.ReLU(inplace=True) # this ReLU might be better applied before adding the residual
40
+
41
+ # initialize weights
42
+ self._initialize_weights()
43
+
44
+ def _initialize_weights(self):
45
+ for m in self.modules():
46
+ if isinstance(m, nn.Conv2d):
47
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
48
+ if m.bias is not None:
49
+ nn.init.constant_(m.bias, 0)
50
+ elif isinstance(m, nn.BatchNorm2d):
51
+ nn.init.constant_(m.weight, 1)
52
+ nn.init.constant_(m.bias, 0)
53
+ elif isinstance(m, nn.Linear):
54
+ nn.init.normal_(m.weight, 0, 0.01)
55
+ nn.init.constant_(m.bias, 0)
56
+
57
+ def forward(self, x):
58
+ residual = x
59
+
60
+ out = self.conv1(x)
61
+ out = self.bn1(out)
62
+ out = self.relu1(out)
63
+
64
+ out = self.conv2(out)
65
+ out = self.bn2(out)
66
+
67
+ out += residual
68
+
69
+ out = self.relu2(out)
70
+
71
+ return out
72
+
73
+
74
+ class Upscaler(nn.Module):
75
+ def __init__(self):
76
+ super(Upscaler, self).__init__()
77
+
78
+ # define layers
79
+ # latent has 4 channels
80
+
81
+ self.conv1 = nn.Conv2d(4, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
82
+ self.bn1 = nn.BatchNorm2d(128)
83
+ self.relu1 = nn.ReLU(inplace=True)
84
+
85
+ # resblocks
86
+ # brute force with 20 blocks: stacking more blocks should widen the receptive field more than increasing the channel count
87
+ self.resblock1 = ResidualBlock(128)
88
+ self.resblock2 = ResidualBlock(128)
89
+ self.resblock3 = ResidualBlock(128)
90
+ self.resblock4 = ResidualBlock(128)
91
+ self.resblock5 = ResidualBlock(128)
92
+ self.resblock6 = ResidualBlock(128)
93
+ self.resblock7 = ResidualBlock(128)
94
+ self.resblock8 = ResidualBlock(128)
95
+ self.resblock9 = ResidualBlock(128)
96
+ self.resblock10 = ResidualBlock(128)
97
+ self.resblock11 = ResidualBlock(128)
98
+ self.resblock12 = ResidualBlock(128)
99
+ self.resblock13 = ResidualBlock(128)
100
+ self.resblock14 = ResidualBlock(128)
101
+ self.resblock15 = ResidualBlock(128)
102
+ self.resblock16 = ResidualBlock(128)
103
+ self.resblock17 = ResidualBlock(128)
104
+ self.resblock18 = ResidualBlock(128)
105
+ self.resblock19 = ResidualBlock(128)
106
+ self.resblock20 = ResidualBlock(128)
107
+
108
+ # last convs
109
+ self.conv2 = nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
110
+ self.bn2 = nn.BatchNorm2d(64)
111
+ self.relu2 = nn.ReLU(inplace=True)
112
+
113
+ self.conv3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
114
+ self.bn3 = nn.BatchNorm2d(64)
115
+ self.relu3 = nn.ReLU(inplace=True)
116
+
117
+ # final conv: output 4 channels
118
+ self.conv_final = nn.Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
119
+
120
+ # initialize weights
121
+ self._initialize_weights()
122
+
123
+ def _initialize_weights(self):
124
+ for m in self.modules():
125
+ if isinstance(m, nn.Conv2d):
126
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
127
+ if m.bias is not None:
128
+ nn.init.constant_(m.bias, 0)
129
+ elif isinstance(m, nn.BatchNorm2d):
130
+ nn.init.constant_(m.weight, 1)
131
+ nn.init.constant_(m.bias, 0)
132
+ elif isinstance(m, nn.Linear):
133
+ nn.init.normal_(m.weight, 0, 0.01)
134
+ nn.init.constant_(m.bias, 0)
135
+
136
+ # initialize final conv weights to 0: 流行りのzero conv
137
+ nn.init.constant_(self.conv_final.weight, 0)
138
+
139
+ def forward(self, x):
140
+ inp = x
141
+
142
+ x = self.conv1(x)
143
+ x = self.bn1(x)
144
+ x = self.relu1(x)
145
+
146
+ # adding the residual after passing through several resblocks should improve both accuracy and training speed
147
+ residual = x
148
+ x = self.resblock1(x)
149
+ x = self.resblock2(x)
150
+ x = self.resblock3(x)
151
+ x = self.resblock4(x)
152
+ x = x + residual
153
+ residual = x
154
+ x = self.resblock5(x)
155
+ x = self.resblock6(x)
156
+ x = self.resblock7(x)
157
+ x = self.resblock8(x)
158
+ x = x + residual
159
+ residual = x
160
+ x = self.resblock9(x)
161
+ x = self.resblock10(x)
162
+ x = self.resblock11(x)
163
+ x = self.resblock12(x)
164
+ x = x + residual
165
+ residual = x
166
+ x = self.resblock13(x)
167
+ x = self.resblock14(x)
168
+ x = self.resblock15(x)
169
+ x = self.resblock16(x)
170
+ x = x + residual
171
+ residual = x
172
+ x = self.resblock17(x)
173
+ x = self.resblock18(x)
174
+ x = self.resblock19(x)
175
+ x = self.resblock20(x)
176
+ x = x + residual
177
+
178
+ x = self.conv2(x)
179
+ x = self.bn2(x)
180
+ x = self.relu2(x)
181
+ x = self.conv3(x)
182
+ x = self.bn3(x)
183
+
184
+ # it seems better not to apply a ReLU here
185
+
186
+ x = self.conv_final(x)
187
+
188
+ # network estimates the difference between the input and the output
189
+ x = x + inp
190
+
191
+ return x
192
+
193
+ def support_latents(self) -> bool:
194
+ return False
195
+
196
+ def upscale(
197
+ self,
198
+ vae: AutoencoderKL,
199
+ lowreso_images: List[Image.Image],
200
+ lowreso_latents: torch.Tensor,
201
+ dtype: torch.dtype,
202
+ width: int,
203
+ height: int,
204
+ batch_size: int = 1,
205
+ vae_batch_size: int = 1,
206
+ ):
207
+ # assertion
208
+ assert lowreso_images is not None, "Upscaler requires lowreso image"
209
+
210
+ # make upsampled image with lanczos4
211
+ upsampled_images = []
212
+ for lowreso_image in lowreso_images:
213
+ upsampled_image = np.array(lowreso_image.resize((width, height), Image.LANCZOS))
214
+ upsampled_images.append(upsampled_image)
215
+
216
+ # convert to tensor: this tensor is too large to be converted to cuda
217
+ upsampled_images = [torch.from_numpy(upsampled_image).permute(2, 0, 1).float() for upsampled_image in upsampled_images]
218
+ upsampled_images = torch.stack(upsampled_images, dim=0)
219
+ upsampled_images = upsampled_images.to(dtype)
220
+
221
+ # normalize to [-1, 1]
222
+ upsampled_images = upsampled_images / 127.5 - 1.0
223
+
224
+ # convert upsample images to latents with batch size
225
+ # logger.info("Encoding upsampled (LANCZOS4) images...")
226
+ upsampled_latents = []
227
+ for i in tqdm(range(0, upsampled_images.shape[0], vae_batch_size)):
228
+ batch = upsampled_images[i : i + vae_batch_size].to(vae.device)
229
+ with torch.no_grad():
230
+ batch = vae.encode(batch).latent_dist.sample()
231
+ upsampled_latents.append(batch)
232
+
233
+ upsampled_latents = torch.cat(upsampled_latents, dim=0)
234
+
235
+ # upscale (refine) latents with this model with batch size
236
+ logger.info("Upscaling latents...")
237
+ upscaled_latents = []
238
+ for i in range(0, upsampled_latents.shape[0], batch_size):
239
+ with torch.no_grad():
240
+ upscaled_latents.append(self.forward(upsampled_latents[i : i + batch_size]))
241
+ upscaled_latents = torch.cat(upscaled_latents, dim=0)
242
+
243
+ return upscaled_latents * 0.18215
244
+
245
+
246
+ # external interface: returns a model
247
+ def create_upscaler(**kwargs):
248
+ weights = kwargs["weights"]
249
+ model = Upscaler()
250
+
251
+ logger.info(f"Loading weights from {weights}...")
252
+ if os.path.splitext(weights)[1] == ".safetensors":
253
+ from safetensors.torch import load_file
254
+
255
+ sd = load_file(weights)
256
+ else:
257
+ sd = torch.load(weights, map_location=torch.device("cpu"))
258
+ model.load_state_dict(sd)
259
+ return model
260
+
261
+
262
+ # another interface: upscale images with a model for given images from command line
263
+ def upscale_images(args: argparse.Namespace):
264
+ DEVICE = get_preferred_device()
265
+ us_dtype = torch.float16 # TODO: support fp32/bf16
266
+ os.makedirs(args.output_dir, exist_ok=True)
267
+
268
+ # load VAE with Diffusers
269
+ assert args.vae_path is not None, "VAE path is required"
270
+ logger.info(f"Loading VAE from {args.vae_path}...")
271
+ vae = AutoencoderKL.from_pretrained(args.vae_path, subfolder="vae")
272
+ vae.to(DEVICE, dtype=us_dtype)
273
+
274
+ # prepare model
275
+ logger.info("Preparing model...")
276
+ upscaler: Upscaler = create_upscaler(weights=args.weights)
277
+ # logger.info("Loading weights from", args.weights)
278
+ # upscaler.load_state_dict(torch.load(args.weights))
279
+ upscaler.eval()
280
+ upscaler.to(DEVICE, dtype=us_dtype)
281
+
282
+ # load images
283
+ image_paths = glob.glob(args.image_pattern)
284
+ images = []
285
+ for image_path in image_paths:
286
+ image = Image.open(image_path)
287
+ image = image.convert("RGB")
288
+
289
+ # make divisible by 8
290
+ width = image.width
291
+ height = image.height
292
+ if width % 8 != 0:
293
+ width = width - (width % 8)
294
+ if height % 8 != 0:
295
+ height = height - (height % 8)
296
+ if width != image.width or height != image.height:
297
+ image = image.crop((0, 0, width, height))
298
+
299
+ images.append(image)
300
+
301
+ # debug output
302
+ if args.debug:
303
+ for image, image_path in zip(images, image_paths):
304
+ image_debug = image.resize((image.width * 2, image.height * 2), Image.LANCZOS)
305
+
306
+ basename = os.path.basename(image_path)
307
+ basename_wo_ext, ext = os.path.splitext(basename)
308
+ dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_lanczos4{ext}")
309
+ image_debug.save(dest_file_name)
310
+
311
+ # upscale
312
+ logger.info("Upscaling...")
313
+ upscaled_latents = upscaler.upscale(
314
+ vae, images, None, us_dtype, width * 2, height * 2, batch_size=args.batch_size, vae_batch_size=args.vae_batch_size
315
+ )
316
+ upscaled_latents /= 0.18215
317
+
318
+ # decode with batch
319
+ logger.info("Decoding...")
320
+ upscaled_images = []
321
+ for i in tqdm(range(0, upscaled_latents.shape[0], args.vae_batch_size)):
322
+ with torch.no_grad():
323
+ batch = vae.decode(upscaled_latents[i : i + args.vae_batch_size]).sample
324
+ batch = batch.to("cpu")
325
+ upscaled_images.append(batch)
326
+ upscaled_images = torch.cat(upscaled_images, dim=0)
327
+
328
+ # tensor to numpy
329
+ upscaled_images = upscaled_images.permute(0, 2, 3, 1).numpy()
330
+ upscaled_images = (upscaled_images + 1.0) * 127.5
331
+ upscaled_images = upscaled_images.clip(0, 255).astype(np.uint8)
332
+
333
+ upscaled_images = upscaled_images[..., ::-1]
334
+
335
+ # save images
336
+ for i, image in enumerate(upscaled_images):
337
+ basename = os.path.basename(image_paths[i])
338
+ basename_wo_ext, ext = os.path.splitext(basename)
339
+ dest_file_name = os.path.join(args.output_dir, f"{basename_wo_ext}_upscaled{ext}")
340
+ cv2.imwrite(dest_file_name, image)
341
+
342
+
343
+ if __name__ == "__main__":
344
+ parser = argparse.ArgumentParser()
345
+ parser.add_argument("--vae_path", type=str, default=None, help="VAE path")
346
+ parser.add_argument("--weights", type=str, default=None, help="Weights path")
347
+ parser.add_argument("--image_pattern", type=str, default=None, help="Image pattern")
348
+ parser.add_argument("--output_dir", type=str, default=".", help="Output directory")
349
+ parser.add_argument("--batch_size", type=int, default=4, help="Batch size")
350
+ parser.add_argument("--vae_batch_size", type=int, default=1, help="VAE batch size")
351
+ parser.add_argument("--debug", action="store_true", help="Debug mode")
352
+
353
+ args = parser.parse_args()
354
+ upscale_images(args)
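Since the script is argparse-driven, upscale_images() can also be called programmatically. A sketch with placeholder paths follows; note the VAE path must point at a diffusers model directory containing a "vae" subfolder, because from_pretrained(..., subfolder="vae") is used above, and the import path is an assumption.

import argparse
from latent_upscaler import upscale_images  # assumed import path for this script

args = argparse.Namespace(
    vae_path="path/to/diffusers-sd-model",    # placeholder; needs a "vae" subfolder
    weights="path/to/upscaler.safetensors",   # placeholder upscaler weights
    image_pattern="inputs/*.png",
    output_dir="outputs",
    batch_size=1,
    vae_batch_size=1,
    debug=False,
)
upscale_images(args)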
libbitsandbytes_cpu.dll ADDED
Binary file (76.3 kB).
 
libbitsandbytes_cuda116.dll ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:88f7bd2916ca3effc43f88492f1e1b9088d13cb5be3b4a3a4aede6aa3bf8d412
3
+ size 4724224
libbitsandbytes_cuda118.dll ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dc34709b8dcb078cbcdd65e5684f116cb395644d12b9c9fb144af5455bb1c18
3
+ size 14026752
logo_aihub.png ADDED
lora.py ADDED
@@ -0,0 +1,1410 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ import math
7
+ import os
8
+ from typing import Dict, List, Optional, Tuple, Type, Union
9
+ from diffusers import AutoencoderKL
10
+ from transformers import CLIPTextModel
11
+ import numpy as np
12
+ import torch
13
+ import re
14
+ from library.utils import setup_logging
15
+ from library.sdxl_original_unet import SdxlUNet2DConditionModel
16
+
17
+ setup_logging()
18
+ import logging
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
23
+
24
+
25
+ class LoRAModule(torch.nn.Module):
26
+ """
27
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ lora_name,
33
+ org_module: torch.nn.Module,
34
+ multiplier=1.0,
35
+ lora_dim=4,
36
+ alpha=1,
37
+ dropout=None,
38
+ rank_dropout=None,
39
+ module_dropout=None,
40
+ ):
41
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
42
+ super().__init__()
43
+ self.lora_name = lora_name
44
+
45
+ if org_module.__class__.__name__ == "Conv2d":
46
+ in_dim = org_module.in_channels
47
+ out_dim = org_module.out_channels
48
+ else:
49
+ in_dim = org_module.in_features
50
+ out_dim = org_module.out_features
51
+
52
+ # if limit_rank:
53
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
54
+ # if self.lora_dim != lora_dim:
55
+ # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
56
+ # else:
57
+ self.lora_dim = lora_dim
58
+
59
+ if org_module.__class__.__name__ == "Conv2d":
60
+ kernel_size = org_module.kernel_size
61
+ stride = org_module.stride
62
+ padding = org_module.padding
63
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
64
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
65
+ else:
66
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
67
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
68
+
69
+ if type(alpha) == torch.Tensor:
70
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
71
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
72
+ self.scale = alpha / self.lora_dim
73
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
74
+
75
+ # same as microsoft's
76
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
77
+ torch.nn.init.zeros_(self.lora_up.weight)
78
+
79
+ self.multiplier = multiplier
80
+ self.org_module = org_module # remove in applying
81
+ self.dropout = dropout
82
+ self.rank_dropout = rank_dropout
83
+ self.module_dropout = module_dropout
84
+
85
+ def apply_to(self):
86
+ self.org_forward = self.org_module.forward
87
+ self.org_module.forward = self.forward
88
+ del self.org_module
89
+
90
+ def forward(self, x):
91
+ org_forwarded = self.org_forward(x)
92
+
93
+ # module dropout
94
+ if self.module_dropout is not None and self.training:
95
+ if torch.rand(1) < self.module_dropout:
96
+ return org_forwarded
97
+
98
+ lx = self.lora_down(x)
99
+
100
+ # normal dropout
101
+ if self.dropout is not None and self.training:
102
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
103
+
104
+ # rank dropout
105
+ if self.rank_dropout is not None and self.training:
106
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
107
+ if len(lx.size()) == 3:
108
+ mask = mask.unsqueeze(1) # for Text Encoder
109
+ elif len(lx.size()) == 4:
110
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
111
+ lx = lx * mask
112
+
113
+ # scaling for rank dropout: treat as if the rank is changed
114
+ # this could be computed from the mask, but rank_dropout is used here expecting an augmentation-like effect
115
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability
116
+ else:
117
+ scale = self.scale
118
+
119
+ lx = self.lora_up(lx)
120
+
121
+ return org_forwarded + lx * self.multiplier * scale
122
+
123
+
124
+ class LoRAInfModule(LoRAModule):
125
+ def __init__(
126
+ self,
127
+ lora_name,
128
+ org_module: torch.nn.Module,
129
+ multiplier=1.0,
130
+ lora_dim=4,
131
+ alpha=1,
132
+ **kwargs,
133
+ ):
134
+ # no dropout for inference
135
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
136
+
137
+ self.org_module_ref = [org_module] # keep a reference so it can be accessed later
138
+ self.enabled = True
139
+
140
+ # check regional or not by lora_name
141
+ self.text_encoder = False
142
+ if lora_name.startswith("lora_te_"):
143
+ self.regional = False
144
+ self.use_sub_prompt = True
145
+ self.text_encoder = True
146
+ elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
147
+ self.regional = False
148
+ self.use_sub_prompt = True
149
+ elif "time_emb" in lora_name:
150
+ self.regional = False
151
+ self.use_sub_prompt = False
152
+ else:
153
+ self.regional = True
154
+ self.use_sub_prompt = False
155
+
156
+ self.network: LoRANetwork = None
157
+
158
+ def set_network(self, network):
159
+ self.network = network
160
+
161
+ # merge the LoRA weights into the frozen original module
162
+ def merge_to(self, sd, dtype, device):
163
+ # get up/down weight
164
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
165
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
166
+
167
+ # extract weight from org_module
168
+ org_sd = self.org_module.state_dict()
169
+ weight = org_sd["weight"].to(torch.float)
170
+
171
+ # merge weight
172
+ if len(weight.size()) == 2:
173
+ # linear
174
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
175
+ elif down_weight.size()[2:4] == (1, 1):
176
+ # conv2d 1x1
177
+ weight = (
178
+ weight
179
+ + self.multiplier
180
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
181
+ * self.scale
182
+ )
183
+ else:
184
+ # conv2d 3x3
185
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
186
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
187
+ weight = weight + self.multiplier * conved * self.scale
188
+
189
+ # set weight to org_module
190
+ org_sd["weight"] = weight.to(dtype)
191
+ self.org_module.load_state_dict(org_sd)
192
+
193
+ # return this module's weight so the merge can be restored later
194
+ def get_weight(self, multiplier=None):
195
+ if multiplier is None:
196
+ multiplier = self.multiplier
197
+
198
+ # get up/down weight from module
199
+ up_weight = self.lora_up.weight.to(torch.float)
200
+ down_weight = self.lora_down.weight.to(torch.float)
201
+
202
+ # pre-calculated weight
203
+ if len(down_weight.size()) == 2:
204
+ # linear
205
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
206
+ elif down_weight.size()[2:4] == (1, 1):
207
+ # conv2d 1x1
208
+ weight = (
209
+ self.multiplier
210
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
211
+ * self.scale
212
+ )
213
+ else:
214
+ # conv2d 3x3
215
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
216
+ weight = self.multiplier * conved * self.scale
217
+
218
+ return weight
219
+
220
+ def set_region(self, region):
221
+ self.region = region
222
+ self.region_mask = None
223
+
224
+ def default_forward(self, x):
225
+ # logger.info(f"default_forward {self.lora_name} {x.size()}")
226
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
227
+
228
+ def forward(self, x):
229
+ if not self.enabled:
230
+ return self.org_forward(x)
231
+
232
+ if self.network is None or self.network.sub_prompt_index is None:
233
+ return self.default_forward(x)
234
+ if not self.regional and not self.use_sub_prompt:
235
+ return self.default_forward(x)
236
+
237
+ if self.regional:
238
+ return self.regional_forward(x)
239
+ else:
240
+ return self.sub_prompt_forward(x)
241
+
242
+ def get_mask_for_x(self, x):
243
+ # calculate size from shape of x
244
+ if len(x.size()) == 4:
245
+ h, w = x.size()[2:4]
246
+ area = h * w
247
+ else:
248
+ area = x.size()[1]
249
+
250
+ mask = self.network.mask_dic.get(area, None)
251
+ if mask is None or len(x.size()) == 2:
252
+ # emb_layers in SDXL doesn't have mask
253
+ # if "emb" not in self.lora_name:
254
+ # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
255
+ mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
256
+ return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
257
+ if len(x.size()) == 3:
258
+ mask = torch.reshape(mask, (1, -1, 1))
259
+ return mask
260
+
261
+ def regional_forward(self, x):
262
+ if "attn2_to_out" in self.lora_name:
263
+ return self.to_out_forward(x)
264
+
265
+ if self.network.mask_dic is None: # sub_prompt_index >= 3
266
+ return self.default_forward(x)
267
+
268
+ # apply mask for LoRA result
269
+ lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
270
+ mask = self.get_mask_for_x(lx)
271
+ # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
272
+ # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
273
+ # mask = mask.squeeze(-1)
274
+ lx = lx * mask
275
+
276
+ x = self.org_forward(x)
277
+ x = x + lx
278
+
279
+ if "attn2_to_q" in self.lora_name and self.network.is_last_network:
280
+ x = self.postp_to_q(x)
281
+
282
+ return x
283
+
284
+ def postp_to_q(self, x):
285
+ # repeat x to num_sub_prompts
286
+ has_real_uncond = x.size()[0] // self.network.batch_size == 3
287
+ qc = self.network.batch_size # uncond
288
+ qc += self.network.batch_size * self.network.num_sub_prompts # cond
289
+ if has_real_uncond:
290
+ qc += self.network.batch_size # real_uncond
291
+
292
+ query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
293
+ query[: self.network.batch_size] = x[: self.network.batch_size]
294
+
295
+ for i in range(self.network.batch_size):
296
+ qi = self.network.batch_size + i * self.network.num_sub_prompts
297
+ query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
298
+
299
+ if has_real_uncond:
300
+ query[-self.network.batch_size :] = x[-self.network.batch_size :]
301
+
302
+ # logger.info(f"postp_to_q {self.lora_name} {x.size()} {query.size()} {self.network.num_sub_prompts}")
303
+ return query
304
+
305
+ def sub_prompt_forward(self, x):
306
+ if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
307
+ return self.org_forward(x)
308
+
309
+ emb_idx = self.network.sub_prompt_index
310
+ if not self.text_encoder:
311
+ emb_idx += self.network.batch_size
312
+
313
+ # apply sub prompt of X
314
+ lx = x[emb_idx :: self.network.num_sub_prompts]
315
+ lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
316
+
317
+ # logger.info(f"sub_prompt_forward {self.lora_name} {x.size()} {lx.size()} {emb_idx}")
318
+
319
+ x = self.org_forward(x)
320
+ x[emb_idx :: self.network.num_sub_prompts] += lx
321
+
322
+ return x
323
+
324
+ def to_out_forward(self, x):
325
+ # logger.info(f"to_out_forward {self.lora_name} {x.size()} {self.network.is_last_network}")
326
+
327
+ if self.network.is_last_network:
328
+ masks = [None] * self.network.num_sub_prompts
329
+ self.network.shared[self.lora_name] = (None, masks)
330
+ else:
331
+ lx, masks = self.network.shared[self.lora_name]
332
+
333
+ # call own LoRA
334
+ x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
335
+ lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
336
+
337
+ if self.network.is_last_network:
338
+ lx = torch.zeros(
339
+ (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
340
+ )
341
+ self.network.shared[self.lora_name] = (lx, masks)
342
+
343
+ # logger.info(f"to_out_forward {lx.size()} {lx1.size()} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
344
+ lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
345
+ masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
346
+
347
+ # if not last network, return x and masks
348
+ x = self.org_forward(x)
349
+ if not self.network.is_last_network:
350
+ return x
351
+
352
+ lx, masks = self.network.shared.pop(self.lora_name)
353
+
354
+ # if last network, combine separated x with mask weighted sum
355
+ has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
356
+
357
+ out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
358
+ out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
359
+ if has_real_uncond:
360
+ out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
361
+
362
+ # logger.info(f"to_out_forward {self.lora_name} {self.network.sub_prompt_index} {self.network.num_sub_prompts}")
363
+ # if num_sub_prompts > num of LoRAs, fill with zero
364
+ for i in range(len(masks)):
365
+ if masks[i] is None:
366
+ masks[i] = torch.zeros_like(masks[0])
367
+
368
+ mask = torch.cat(masks)
369
+ mask_sum = torch.sum(mask, dim=0) + 1e-4
370
+ for i in range(self.network.batch_size):
371
+ # 1枚の画像ごとに処理する / process one image of the batch at a time
372
+ lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
373
+ lx1 = lx1 * mask
374
+ lx1 = torch.sum(lx1, dim=0)
375
+
376
+ xi = self.network.batch_size + i * self.network.num_sub_prompts
377
+ x1 = x[xi : xi + self.network.num_sub_prompts]
378
+ x1 = x1 * mask
379
+ x1 = torch.sum(x1, dim=0)
380
+ x1 = x1 / mask_sum
381
+
382
+ x1 = x1 + lx1
383
+ out[self.network.batch_size + i] = x1
384
+
385
+ # logger.info(f"to_out_forward {x.size()} {out.size()} {has_real_uncond}")
386
+ return out
387
+
388
+
389
+ def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]:
390
+ down_lr_weight = nw_kwargs.get("down_lr_weight", None)
391
+ mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
392
+ up_lr_weight = nw_kwargs.get("up_lr_weight", None)
393
+
394
+ # 以上のいずれにも設定がない場合は無効としてNoneを返す / if none of these is set, return None (disabled)
395
+ if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
396
+ return None
397
+
398
+ # extract learning rate weight for each block
399
+ if down_lr_weight is not None:
400
+ # if some parameters are not set, use zero
401
+ if "," in down_lr_weight:
402
+ down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
403
+
404
+ if mid_lr_weight is not None:
405
+ mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")]
406
+
407
+ if up_lr_weight is not None:
408
+ if "," in up_lr_weight:
409
+ up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
410
+
411
+ return get_block_lr_weight(
412
+ is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
413
+ )
414
+
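For orientation, a sketch of the kwargs shape this parser expects; the values below are illustrative and mirror how --network_args key=value pairs arrive as strings (SD1.x assumed, so down/up take 12 entries):

# hypothetical input for parse_block_lr_kwargs
nw_kwargs = {
    "down_lr_weight": "0,0,0,0,0,0,1,1,1,1,1,1",  # per-block multipliers, shallow -> deep
    "mid_lr_weight": "1",
    "up_lr_weight": "sine+.5",                    # preset curve name plus a base offset
    "block_lr_zero_threshold": "0.1",             # weights not above this are zeroed
}
block_lr_weight = parse_block_lr_kwargs(is_sdxl=False, nw_kwargs=nw_kwargs)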
415
+
416
+ def create_network(
417
+ multiplier: float,
418
+ network_dim: Optional[int],
419
+ network_alpha: Optional[float],
420
+ vae: AutoencoderKL,
421
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
422
+ unet,
423
+ neuron_dropout: Optional[float] = None,
424
+ **kwargs,
425
+ ):
426
+ # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
427
+ is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
428
+
429
+ if network_dim is None:
430
+ network_dim = 4 # default
431
+ if network_alpha is None:
432
+ network_alpha = 1.0
433
+
434
+ # extract dim/alpha for conv2d, and block dim
435
+ conv_dim = kwargs.get("conv_dim", None)
436
+ conv_alpha = kwargs.get("conv_alpha", None)
437
+ if conv_dim is not None:
438
+ conv_dim = int(conv_dim)
439
+ if conv_alpha is None:
440
+ conv_alpha = 1.0
441
+ else:
442
+ conv_alpha = float(conv_alpha)
443
+
444
+ # block dim/alpha/lr
445
+ block_dims = kwargs.get("block_dims", None)
446
+ block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
447
+
448
+ # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする / if any of these is specified, enable per-block dim (rank)
449
+ if block_dims is not None or block_lr_weight is not None:
450
+ block_alphas = kwargs.get("block_alphas", None)
451
+ conv_block_dims = kwargs.get("conv_block_dims", None)
452
+ conv_block_alphas = kwargs.get("conv_block_alphas", None)
453
+
454
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
455
+ is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
456
+ )
457
+
458
+ # remove block dim/alpha without learning rate
459
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
460
+ is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight
461
+ )
462
+
463
+ else:
464
+ block_alphas = None
465
+ conv_block_dims = None
466
+ conv_block_alphas = None
467
+
468
+ # rank/module dropout
469
+ rank_dropout = kwargs.get("rank_dropout", None)
470
+ if rank_dropout is not None:
471
+ rank_dropout = float(rank_dropout)
472
+ module_dropout = kwargs.get("module_dropout", None)
473
+ if module_dropout is not None:
474
+ module_dropout = float(module_dropout)
475
+
476
+ # すごく引数が多いな ( ^ω^)・・・
477
+ network = LoRANetwork(
478
+ text_encoder,
479
+ unet,
480
+ multiplier=multiplier,
481
+ lora_dim=network_dim,
482
+ alpha=network_alpha,
483
+ dropout=neuron_dropout,
484
+ rank_dropout=rank_dropout,
485
+ module_dropout=module_dropout,
486
+ conv_lora_dim=conv_dim,
487
+ conv_alpha=conv_alpha,
488
+ block_dims=block_dims,
489
+ block_alphas=block_alphas,
490
+ conv_block_dims=conv_block_dims,
491
+ conv_block_alphas=conv_block_alphas,
492
+ varbose=True,
493
+ is_sdxl=is_sdxl,
494
+ )
495
+
496
+ loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
497
+ loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
498
+ loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
499
+ loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
500
+ loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
501
+ loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
502
+ if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
503
+ network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
504
+
505
+ if block_lr_weight is not None:
506
+ network.set_block_lr_weight(block_lr_weight)
507
+
508
+ return network
509
+
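For reference, a sketch of how a training script might call this factory; the model objects are assumed to be loaded elsewhere, and keyword values typically arrive as strings when passed via --network_args:

# hypothetical wiring; vae / text_encoder / unet are assumed to exist
network = create_network(
    multiplier=1.0,
    network_dim=16,
    network_alpha=8.0,
    vae=vae,
    text_encoder=text_encoder,
    unet=unet,
    neuron_dropout=0.1,
    conv_dim="8",    # cast to int inside
    conv_alpha="4",  # cast to float inside
)
network.apply_to(text_encoder, unet, apply_text_encoder=True, apply_unet=True)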
510
+
511
+ # このメソッドは外部から呼び出される可能性を考慮しておく / keep in mind this method may be called from outside
512
+ # network_dim, network_alpha にはデフォルト値が入っている。 / network_dim and network_alpha already hold default values
513
+ # block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている / block_dims and block_alphas are either both None or both set
514
+ # conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている / conv_dim and conv_alpha are either both None or both set
515
+ def get_block_dims_and_alphas(
516
+ is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
517
+ ):
518
+ if not is_sdxl:
519
+ num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS
520
+ else:
521
+ # 1+9+3+9+1=23, no LoRA for emb_layers (0)
522
+ num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
523
+
524
+ def parse_ints(s):
525
+ return [int(i) for i in s.split(",")]
526
+
527
+ def parse_floats(s):
528
+ return [float(i) for i in s.split(",")]
529
+
530
+ # block_dimsとblock_alphasをパースする。必ず値が入る / parse block_dims and block_alphas; they always end up with values
531
+ if block_dims is not None:
532
+ block_dims = parse_ints(block_dims)
533
+ assert len(block_dims) == num_total_blocks, (
534
+ f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given"
535
+ + f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})"
536
+ )
537
+ else:
538
+ logger.warning(
539
+ f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
540
+ )
541
+ block_dims = [network_dim] * num_total_blocks
542
+
543
+ if block_alphas is not None:
544
+ block_alphas = parse_floats(block_alphas)
545
+ assert (
546
+ len(block_alphas) == num_total_blocks
547
+ ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
548
+ else:
549
+ logger.warning(
550
+ f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
551
+ )
552
+ block_alphas = [network_alpha] * num_total_blocks
553
+
554
+ # conv_block_dimsとconv_block_alphasを、指定がある場合のみパースする。指定がなければconv_dimとconv_alphaを使う / parse conv_block_dims and conv_block_alphas only when given; otherwise fall back to conv_dim and conv_alpha
555
+ if conv_block_dims is not None:
556
+ conv_block_dims = parse_ints(conv_block_dims)
557
+ assert (
558
+ len(conv_block_dims) == num_total_blocks
559
+ ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
560
+
561
+ if conv_block_alphas is not None:
562
+ conv_block_alphas = parse_floats(conv_block_alphas)
563
+ assert (
564
+ len(conv_block_alphas) == num_total_blocks
565
+ ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
566
+ else:
567
+ if conv_alpha is None:
568
+ conv_alpha = 1.0
569
+ logger.warning(
570
+ f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
571
+ )
572
+ conv_block_alphas = [conv_alpha] * num_total_blocks
573
+ else:
574
+ if conv_dim is not None:
575
+ logger.warning(
576
+ f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
577
+ )
578
+ conv_block_dims = [conv_dim] * num_total_blocks
579
+ conv_block_alphas = [conv_alpha] * num_total_blocks
580
+ else:
581
+ conv_block_dims = None
582
+ conv_block_alphas = None
583
+
584
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
585
+
586
+
587
+ # 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく / define per-block multipliers for layer-wise learning rates; kept outside the class so it can be called externally
588
+ # 戻り値は block ごとの倍率のリスト / the return value is a list of per-block multipliers
589
+ def get_block_lr_weight(
590
+ is_sdxl,
591
+ down_lr_weight: Union[str, List[float]],
592
+ mid_lr_weight: List[float],
593
+ up_lr_weight: Union[str, List[float]],
594
+ zero_threshold: float,
595
+ ) -> Optional[List[float]]:
596
+ # パラメータ未指定時は何もせず、今までと同じ動作とする / if no parameter is given, do nothing and keep the previous behavior
597
+ if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
598
+ return None
599
+
600
+ if not is_sdxl:
601
+ max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS
602
+ max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS
603
+ else:
604
+ max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS
605
+ max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS
606
+
607
+ def get_list(name_with_suffix) -> List[float]:
608
+ import math
609
+
610
+ tokens = name_with_suffix.split("+")
611
+ name = tokens[0]
612
+ base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
613
+
614
+ if name == "cosine":
615
+ return [
616
+ math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr
617
+ for i in reversed(range(max_len_for_down_or_up))
618
+ ]
619
+ elif name == "sine":
620
+ return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)]
621
+ elif name == "linear":
622
+ return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)]
623
+ elif name == "reverse_linear":
624
+ return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))]
625
+ elif name == "zeros":
626
+ return [0.0 + base_lr] * max_len_for_down_or_up
627
+ else:
628
+ logger.error(
629
+ "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
630
+ % (name)
631
+ )
632
+ return None
633
+
634
+ if type(down_lr_weight) == str:
635
+ down_lr_weight = get_list(down_lr_weight)
636
+ if type(up_lr_weight) == str:
637
+ up_lr_weight = get_list(up_lr_weight)
638
+
639
+ if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or (
640
+ down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up
641
+ ):
642
+ logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up)
643
+ logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up)
644
+ up_lr_weight = up_lr_weight[:max_len_for_down_or_up]
645
+ down_lr_weight = down_lr_weight[:max_len_for_down_or_up]
646
+
647
+ if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid:
648
+ logger.warning("mid_weight is too long. Parameters after %d-th are ignored." % max_len_for_mid)
649
+ logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid)
650
+ mid_lr_weight = mid_lr_weight[:max_len_for_mid]
651
+
652
+ if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or (
653
+ down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up
654
+ ):
655
+ logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up)
656
+ logger.warning(
657
+ "down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up
658
+ )
659
+
660
+ if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up:
661
+ down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight))
662
+ if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up:
663
+ up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight))
664
+
665
+ if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid:
666
+ logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid)
667
+ logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid)
668
+ mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight))
669
+
670
+ if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
671
+ logger.info("apply block learning rate / 階層別学習率を適用します。")
672
+ if down_lr_weight != None:
673
+ down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
674
+ logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
675
+ else:
676
+ down_lr_weight = [1.0] * max_len_for_down_or_up
677
+ logger.info("down_lr_weight: all 1.0, すべて1.0")
678
+
679
+ if mid_lr_weight != None:
680
+ mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight]
681
+ logger.info(f"mid_lr_weight: {mid_lr_weight}")
682
+ else:
683
+ mid_lr_weight = [1.0] * max_len_for_mid
684
+ logger.info("mid_lr_weight: all 1.0, すべて1.0")
685
+
686
+ if up_lr_weight != None:
687
+ up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
688
+ logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
689
+ else:
690
+ up_lr_weight = [1.0] * max_len_for_down_or_up
691
+ logger.info("up_lr_weight: all 1.0, すべて1.0")
692
+
693
+ lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight
694
+
695
+ if is_sdxl:
696
+ lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out
697
+
698
+ assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or (
699
+ is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1
700
+ ), f"lr_weight length is invalid: {len(lr_weight)}"
701
+
702
+ return lr_weight
703
+
704
+
705
+ # lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく / remove blocks whose lr_weight is 0 from block_dims; may be called from outside
706
+ def remove_block_dims_and_alphas(
707
+ is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]]
708
+ ):
709
+ if block_lr_weight is not None:
710
+ for i, lr in enumerate(block_lr_weight):
711
+ if lr == 0:
712
+ block_dims[i] = 0
713
+ if conv_block_dims is not None:
714
+ conv_block_dims[i] = 0
715
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
716
+
717
+
718
+ # 外部から呼び出す可能性を考慮しておく / may be called from outside
719
+ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int:
720
+ block_idx = -1 # invalid lora name
721
+ if not is_sdxl:
722
+ m = RE_UPDOWN.search(lora_name)
723
+ if m:
724
+ g = m.groups()
725
+ i = int(g[1])
726
+ j = int(g[3])
727
+ if g[2] == "resnets":
728
+ idx = 3 * i + j
729
+ elif g[2] == "attentions":
730
+ idx = 3 * i + j
731
+ elif g[2] == "upsamplers" or g[2] == "downsamplers":
732
+ idx = 3 * i + 2
733
+
734
+ if g[0] == "down":
735
+ block_idx = 1 + idx # 0に該当するLoRAは存在しない
736
+ elif g[0] == "up":
737
+ block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
738
+ elif "mid_block_" in lora_name:
739
+ block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
740
+ else:
741
+ # copy from sdxl_train
742
+ if lora_name.startswith("lora_unet_"):
743
+ name = lora_name[len("lora_unet_") :]
744
+ if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA
745
+ block_idx = 0 # 0
746
+ elif name.startswith("input_blocks_"): # 1-9
747
+ block_idx = 1 + int(name.split("_")[2])
748
+ elif name.startswith("middle_block_"): # 10-12
749
+ block_idx = 10 + int(name.split("_")[2])
750
+ elif name.startswith("output_blocks_"): # 13-21
751
+ block_idx = 13 + int(name.split("_")[2])
752
+ elif name.startswith("out_"): # 22, out, no LoRA
753
+ block_idx = 22
754
+
755
+ return block_idx
756
+
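Illustrative index mappings under the scheme above, assuming RE_UPDOWN (defined earlier in this file) matches the up/down block names as usual:

# SD1.x: down_blocks.1.attentions.0 -> idx 3*1+0 = 3, so block_idx = 1 + 3 = 4
get_block_index("lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q")       # -> 4
# SDXL: input_blocks are offset by 1, middle_block by 10, output_blocks by 13
get_block_index("lora_unet_input_blocks_4_1_transformer_blocks_0_attn1_to_q", is_sdxl=True)   # -> 5
get_block_index("lora_unet_middle_block_1_transformer_blocks_0_attn1_to_q", is_sdxl=True)     # -> 11
get_block_index("lora_unet_output_blocks_2_1_transformer_blocks_0_attn1_to_q", is_sdxl=True)  # -> 15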
757
+
758
+ def convert_diffusers_to_sai_if_needed(weights_sd):
759
+ # only supports U-Net LoRA modules
760
+
761
+ found_up_down_blocks = False
762
+ for k in list(weights_sd.keys()):
763
+ if "down_blocks" in k:
764
+ found_up_down_blocks = True
765
+ break
766
+ if "up_blocks" in k:
767
+ found_up_down_blocks = True
768
+ break
769
+ if not found_up_down_blocks:
770
+ return
771
+
772
+ from library.sdxl_model_util import make_unet_conversion_map
773
+
774
+ unet_conversion_map = make_unet_conversion_map()
775
+ unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
776
+
777
+ # # add extra conversion
778
+ # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv"
779
+
780
+ logger.info(f"Converting LoRA keys from Diffusers to SAI")
781
+ lora_unet_prefix = "lora_unet_"
782
+ for k in list(weights_sd.keys()):
783
+ if not k.startswith(lora_unet_prefix):
784
+ continue
785
+
786
+ unet_module_name = k[len(lora_unet_prefix) :].split(".")[0]
787
+
788
+ found = False
+ # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small
789
+ for hf_module_name, sd_module_name in unet_conversion_map.items():
790
+ if hf_module_name in unet_module_name:
791
+ new_key = (
792
+ lora_unet_prefix
793
+ + unet_module_name.replace(hf_module_name, sd_module_name)
794
+ + k[len(lora_unet_prefix) + len(unet_module_name) :]
795
+ )
796
+ weights_sd[new_key] = weights_sd.pop(k)
797
+ found = True
798
+ break
799
+
800
+ if not found:
801
+ logger.warning(f"Key {k} is not found in unet_conversion_map")
802
+
803
+
804
+ # Create network from weights for inference; weights are not loaded here (because they may be merged instead)
805
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
806
+ # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
807
+ is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel)
808
+
809
+ if weights_sd is None:
810
+ if os.path.splitext(file)[1] == ".safetensors":
811
+ from safetensors.torch import load_file, safe_open
812
+
813
+ weights_sd = load_file(file)
814
+ else:
815
+ weights_sd = torch.load(file, map_location="cpu")
816
+
817
+ # if keys are Diffusers based, convert to SAI based
818
+ if is_sdxl:
819
+ convert_diffusers_to_sai_if_needed(weights_sd)
820
+
821
+ # get dim/alpha mapping
822
+ modules_dim = {}
823
+ modules_alpha = {}
824
+ for key, value in weights_sd.items():
825
+ if "." not in key:
826
+ continue
827
+
828
+ lora_name = key.split(".")[0]
829
+ if "alpha" in key:
830
+ modules_alpha[lora_name] = value
831
+ elif "lora_down" in key:
832
+ dim = value.size()[0]
833
+ modules_dim[lora_name] = dim
834
+ # logger.info(lora_name, value.size(), dim)
835
+
836
+ # support old LoRA without alpha
837
+ for key in modules_dim.keys():
838
+ if key not in modules_alpha:
839
+ modules_alpha[key] = modules_dim[key]
840
+
841
+ module_class = LoRAInfModule if for_inference else LoRAModule
842
+
843
+ network = LoRANetwork(
844
+ text_encoder,
845
+ unet,
846
+ multiplier=multiplier,
847
+ modules_dim=modules_dim,
848
+ modules_alpha=modules_alpha,
849
+ module_class=module_class,
850
+ is_sdxl=is_sdxl,
851
+ )
852
+
853
+ # block lr
854
+ block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs)
855
+ if block_lr_weight is not None:
856
+ network.set_block_lr_weight(block_lr_weight)
857
+
858
+ return network, weights_sd
859
+
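A sketch of inference-time usage; the file path is a placeholder and vae / text_encoder / unet are assumed to be loaded already:

network, weights_sd = create_network_from_weights(
    multiplier=0.8,
    file="my_lora.safetensors",  # placeholder path
    vae=vae,
    text_encoder=text_encoder,
    unet=unet,
    for_inference=True,          # instantiate LoRAInfModule
)
network.apply_to(text_encoder, unet)
info = network.load_state_dict(weights_sd, strict=False)
# alternatively, merge instead of hooking the forward passes:
# network.merge_to(text_encoder, unet, weights_sd, torch.float16, "cuda")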
860
+
861
+ class LoRANetwork(torch.nn.Module):
862
+ NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数
863
+ NUM_OF_MID_BLOCKS = 1
864
+ SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23
865
+ SDXL_NUM_OF_MID_BLOCKS = 3
866
+
867
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
868
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
869
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
870
+ LORA_PREFIX_UNET = "lora_unet"
871
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
872
+
873
+ # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER
874
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
875
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
876
+
877
+ def __init__(
878
+ self,
879
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
880
+ unet,
881
+ multiplier: float = 1.0,
882
+ lora_dim: int = 4,
883
+ alpha: float = 1,
884
+ dropout: Optional[float] = None,
885
+ rank_dropout: Optional[float] = None,
886
+ module_dropout: Optional[float] = None,
887
+ conv_lora_dim: Optional[int] = None,
888
+ conv_alpha: Optional[float] = None,
889
+ block_dims: Optional[List[int]] = None,
890
+ block_alphas: Optional[List[float]] = None,
891
+ conv_block_dims: Optional[List[int]] = None,
892
+ conv_block_alphas: Optional[List[float]] = None,
893
+ modules_dim: Optional[Dict[str, int]] = None,
894
+ modules_alpha: Optional[Dict[str, int]] = None,
895
+ module_class: Type[object] = LoRAModule,
896
+ varbose: Optional[bool] = False,
897
+ is_sdxl: Optional[bool] = False,
898
+ ) -> None:
899
+ """
900
+ LoRA network: すごく引数が多いが、パターンは以下の通り / there are a lot of arguments, but the usage patterns are:
901
+ 1. lora_dimとalphaを指定 / specify lora_dim and alpha
902
+ 2. lora_dim、alpha、conv_lora_dim、conv_alphaを指定 / specify lora_dim, alpha, conv_lora_dim and conv_alpha
903
+ 3. block_dimsとblock_alphasを指定 : Conv2d3x3には適用しない / specify block_dims and block_alphas: not applied to Conv2d 3x3
904
+ 4. block_dims、block_alphas、conv_block_dims、conv_block_alphasを指定 : Conv2d3x3にも適用する / specify block_dims, block_alphas, conv_block_dims and conv_block_alphas: also applied to Conv2d 3x3
905
+ 5. modules_dimとmodules_alphaを指定 (推論用) / specify modules_dim and modules_alpha (for inference)
906
+ """
907
+ super().__init__()
908
+ self.multiplier = multiplier
909
+
910
+ self.lora_dim = lora_dim
911
+ self.alpha = alpha
912
+ self.conv_lora_dim = conv_lora_dim
913
+ self.conv_alpha = conv_alpha
914
+ self.dropout = dropout
915
+ self.rank_dropout = rank_dropout
916
+ self.module_dropout = module_dropout
917
+
918
+ self.loraplus_lr_ratio = None
919
+ self.loraplus_unet_lr_ratio = None
920
+ self.loraplus_text_encoder_lr_ratio = None
921
+
922
+ if modules_dim is not None:
923
+ logger.info(f"create LoRA network from weights")
924
+ elif block_dims is not None:
925
+ logger.info(f"create LoRA network from block_dims")
926
+ logger.info(
927
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
928
+ )
929
+ logger.info(f"block_dims: {block_dims}")
930
+ logger.info(f"block_alphas: {block_alphas}")
931
+ if conv_block_dims is not None:
932
+ logger.info(f"conv_block_dims: {conv_block_dims}")
933
+ logger.info(f"conv_block_alphas: {conv_block_alphas}")
934
+ else:
935
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
936
+ logger.info(
937
+ f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
938
+ )
939
+ if self.conv_lora_dim is not None:
940
+ logger.info(
941
+ f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
942
+ )
943
+
944
+ # create module instances
945
+ def create_modules(
946
+ is_unet: bool,
947
+ text_encoder_idx: Optional[int], # None, 1, 2
948
+ root_module: torch.nn.Module,
949
+ target_replace_modules: List[torch.nn.Module],
950
+ ) -> List[LoRAModule]:
951
+ prefix = (
952
+ self.LORA_PREFIX_UNET
953
+ if is_unet
954
+ else (
955
+ self.LORA_PREFIX_TEXT_ENCODER
956
+ if text_encoder_idx is None
957
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
958
+ )
959
+ )
960
+ loras = []
961
+ skipped = []
962
+ for name, module in root_module.named_modules():
963
+ if module.__class__.__name__ in target_replace_modules:
964
+ for child_name, child_module in module.named_modules():
965
+ is_linear = child_module.__class__.__name__ == "Linear"
966
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
967
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
968
+
969
+ if is_linear or is_conv2d:
970
+ lora_name = prefix + "." + name + "." + child_name
971
+ lora_name = lora_name.replace(".", "_")
972
+
973
+ dim = None
974
+ alpha = None
975
+
976
+ if modules_dim is not None:
977
+ # モジュール指定あり
978
+ if lora_name in modules_dim:
979
+ dim = modules_dim[lora_name]
980
+ alpha = modules_alpha[lora_name]
981
+ elif is_unet and block_dims is not None:
982
+ # U-Netでblock_dims指定あり
983
+ block_idx = get_block_index(lora_name, is_sdxl)
984
+ if is_linear or is_conv2d_1x1:
985
+ dim = block_dims[block_idx]
986
+ alpha = block_alphas[block_idx]
987
+ elif conv_block_dims is not None:
988
+ dim = conv_block_dims[block_idx]
989
+ alpha = conv_block_alphas[block_idx]
990
+ else:
991
+ # 通常、すべて対象とする
992
+ if is_linear or is_conv2d_1x1:
993
+ dim = self.lora_dim
994
+ alpha = self.alpha
995
+ elif self.conv_lora_dim is not None:
996
+ dim = self.conv_lora_dim
997
+ alpha = self.conv_alpha
998
+
999
+ if dim is None or dim == 0:
1000
+ # skipした情報を出力 / record which modules were skipped
1001
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
1002
+ skipped.append(lora_name)
1003
+ continue
1004
+
1005
+ lora = module_class(
1006
+ lora_name,
1007
+ child_module,
1008
+ self.multiplier,
1009
+ dim,
1010
+ alpha,
1011
+ dropout=dropout,
1012
+ rank_dropout=rank_dropout,
1013
+ module_dropout=module_dropout,
1014
+ )
1015
+ loras.append(lora)
1016
+ return loras, skipped
1017
+
1018
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
1019
+
1020
+ # create LoRA for text encoder
1021
+ # 毎回すべてのモジュールを作るのは無駄なので要検討 / creating all modules every time is wasteful, needs review
1022
+ self.text_encoder_loras = []
1023
+ skipped_te = []
1024
+ for i, text_encoder in enumerate(text_encoders):
1025
+ if len(text_encoders) > 1:
1026
+ index = i + 1
1027
+ logger.info(f"create LoRA for Text Encoder {index}:")
1028
+ else:
1029
+ index = None
1030
+ logger.info(f"create LoRA for Text Encoder:")
1031
+
1032
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
1033
+ self.text_encoder_loras.extend(text_encoder_loras)
1034
+ skipped_te += skipped
1035
+ logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
1036
+
1037
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
1038
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
1039
+ if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
1040
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
1041
+
1042
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
1043
+ logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
1044
+
1045
+ skipped = skipped_te + skipped_un
1046
+ if varbose and len(skipped) > 0:
1047
+ logger.warning(
1048
+ f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
1049
+ )
1050
+ for name in skipped:
1051
+ logger.info(f"\t{name}")
1052
+
1053
+ self.block_lr_weight = None
1054
+ self.block_lr = False
1055
+
1056
+ # assertion
1057
+ names = set()
1058
+ for lora in self.text_encoder_loras + self.unet_loras:
1059
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
1060
+ names.add(lora.lora_name)
1061
+
1062
+ def set_multiplier(self, multiplier):
1063
+ self.multiplier = multiplier
1064
+ for lora in self.text_encoder_loras + self.unet_loras:
1065
+ lora.multiplier = self.multiplier
1066
+
1067
+ def set_enabled(self, is_enabled):
1068
+ for lora in self.text_encoder_loras + self.unet_loras:
1069
+ lora.enabled = is_enabled
1070
+
1071
+ def load_weights(self, file):
1072
+ if os.path.splitext(file)[1] == ".safetensors":
1073
+ from safetensors.torch import load_file
1074
+
1075
+ weights_sd = load_file(file)
1076
+ else:
1077
+ weights_sd = torch.load(file, map_location="cpu")
1078
+
1079
+ info = self.load_state_dict(weights_sd, False)
1080
+ return info
1081
+
1082
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
1083
+ if apply_text_encoder:
1084
+ logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
1085
+ else:
1086
+ self.text_encoder_loras = []
1087
+
1088
+ if apply_unet:
1089
+ logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
1090
+ else:
1091
+ self.unet_loras = []
1092
+
1093
+ for lora in self.text_encoder_loras + self.unet_loras:
1094
+ lora.apply_to()
1095
+ self.add_module(lora.lora_name, lora)
1096
+
1097
+ # マージできるかどうかを返す / return whether the weights can be merged
1098
+ def is_mergeable(self):
1099
+ return True
1100
+
1101
+ # TODO refactor to common function with apply_to
1102
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
1103
+ apply_text_encoder = apply_unet = False
1104
+ for key in weights_sd.keys():
1105
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
1106
+ apply_text_encoder = True
1107
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
1108
+ apply_unet = True
1109
+
1110
+ if apply_text_encoder:
1111
+ logger.info("enable LoRA for text encoder")
1112
+ else:
1113
+ self.text_encoder_loras = []
1114
+
1115
+ if apply_unet:
1116
+ logger.info("enable LoRA for U-Net")
1117
+ else:
1118
+ self.unet_loras = []
1119
+
1120
+ for lora in self.text_encoder_loras + self.unet_loras:
1121
+ sd_for_lora = {}
1122
+ for key in weights_sd.keys():
1123
+ if key.startswith(lora.lora_name):
1124
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
1125
+ lora.merge_to(sd_for_lora, dtype, device)
1126
+
1127
+ logger.info(f"weights are merged")
1128
+
1129
+ # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない / define per-block learning-rate multipliers; the argument order is inverted, but ignore that for now
1130
+ def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]):
1131
+ self.block_lr = True
1132
+ self.block_lr_weight = block_lr_weight
1133
+
1134
+ def get_lr_weight(self, block_idx: int) -> float:
1135
+ if not self.block_lr or self.block_lr_weight is None:
1136
+ return 1.0
1137
+ return self.block_lr_weight[block_idx]
1138
+
1139
+ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
1140
+ self.loraplus_lr_ratio = loraplus_lr_ratio
1141
+ self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
1142
+ self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
1143
+
1144
+ logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
1145
+ logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")
1146
+
1147
+ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも / it might be good to allow separate learning rates for the two Text Encoders
1148
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1149
+ # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?)
1150
+ # if (
1151
+ # self.loraplus_lr_ratio is not None
1152
+ # or self.loraplus_text_encoder_lr_ratio is not None
1153
+ # or self.loraplus_unet_lr_ratio is not None
1154
+ # ):
1155
+ # assert (
1156
+ # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower()
1157
+ # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません"
1158
+
1159
+ self.requires_grad_(True)
1160
+
1161
+ all_params = []
1162
+ lr_descriptions = []
1163
+
1164
+ def assemble_params(loras, lr, ratio):
1165
+ param_groups = {"lora": {}, "plus": {}}
1166
+ for lora in loras:
1167
+ for name, param in lora.named_parameters():
1168
+ if ratio is not None and "lora_up" in name:
1169
+ param_groups["plus"][f"{lora.lora_name}.{name}"] = param
1170
+ else:
1171
+ param_groups["lora"][f"{lora.lora_name}.{name}"] = param
1172
+
1173
+ params = []
1174
+ descriptions = []
1175
+ for key in param_groups.keys():
1176
+ param_data = {"params": param_groups[key].values()}
1177
+
1178
+ if len(param_data["params"]) == 0:
1179
+ continue
1180
+
1181
+ if lr is not None:
1182
+ if key == "plus":
1183
+ param_data["lr"] = lr * ratio
1184
+ else:
1185
+ param_data["lr"] = lr
1186
+
1187
+ if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
1188
+ logger.info("NO LR skipping!")
1189
+ continue
1190
+
1191
+ params.append(param_data)
1192
+ descriptions.append("plus" if key == "plus" else "")
1193
+
1194
+ return params, descriptions
1195
+
1196
+ if self.text_encoder_loras:
1197
+ params, descriptions = assemble_params(
1198
+ self.text_encoder_loras,
1199
+ text_encoder_lr if text_encoder_lr is not None else default_lr,
1200
+ self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
1201
+ )
1202
+ all_params.extend(params)
1203
+ lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions])
1204
+
1205
+ if self.unet_loras:
1206
+ if self.block_lr:
1207
+ is_sdxl = False
1208
+ for lora in self.unet_loras:
1209
+ if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name:
1210
+ is_sdxl = True
1211
+ break
1212
+
1213
+ # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 / group LoRAs by block so the learning-rate graph can be shown per block
1214
+ block_idx_to_lora = {}
1215
+ for lora in self.unet_loras:
1216
+ idx = get_block_index(lora.lora_name, is_sdxl)
1217
+ if idx not in block_idx_to_lora:
1218
+ block_idx_to_lora[idx] = []
1219
+ block_idx_to_lora[idx].append(lora)
1220
+
1221
+ # blockごとにパラメータを設定する / set parameters per block
1222
+ for idx, block_loras in block_idx_to_lora.items():
1223
+ params, descriptions = assemble_params(
1224
+ block_loras,
1225
+ (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx),
1226
+ self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
1227
+ )
1228
+ all_params.extend(params)
1229
+ lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions])
1230
+
1231
+ else:
1232
+ params, descriptions = assemble_params(
1233
+ self.unet_loras,
1234
+ unet_lr if unet_lr is not None else default_lr,
1235
+ self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
1236
+ )
1237
+ all_params.extend(params)
1238
+ lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])
1239
+
1240
+ return all_params, lr_descriptions
1241
+
1242
+ def enable_gradient_checkpointing(self):
1243
+ # not supported
1244
+ pass
1245
+
1246
+ def prepare_grad_etc(self, text_encoder, unet):
1247
+ self.requires_grad_(True)
1248
+
1249
+ def on_epoch_start(self, text_encoder, unet):
1250
+ self.train()
1251
+
1252
+ def get_trainable_params(self):
1253
+ return self.parameters()
1254
+
1255
+ def save_weights(self, file, dtype, metadata):
1256
+ if metadata is not None and len(metadata) == 0:
1257
+ metadata = None
1258
+
1259
+ state_dict = self.state_dict()
1260
+
1261
+ if dtype is not None:
1262
+ for key in list(state_dict.keys()):
1263
+ v = state_dict[key]
1264
+ v = v.detach().clone().to("cpu").to(dtype)
1265
+ state_dict[key] = v
1266
+
1267
+ if os.path.splitext(file)[1] == ".safetensors":
1268
+ from safetensors.torch import save_file
1269
+ from library import train_util
1270
+
1271
+ # Precalculate model hashes to save time on indexing
1272
+ if metadata is None:
1273
+ metadata = {}
1274
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
1275
+ metadata["sshs_model_hash"] = model_hash
1276
+ metadata["sshs_legacy_hash"] = legacy_hash
1277
+
1278
+ save_file(state_dict, file, metadata)
1279
+ else:
1280
+ torch.save(state_dict, file)
1281
+
1282
+ # mask is a tensor with values from 0 to 1
1283
+ def set_region(self, sub_prompt_index, is_last_network, mask):
1284
+ if mask.max() == 0:
1285
+ mask = torch.ones_like(mask)
1286
+
1287
+ self.mask = mask
1288
+ self.sub_prompt_index = sub_prompt_index
1289
+ self.is_last_network = is_last_network
1290
+
1291
+ for lora in self.text_encoder_loras + self.unet_loras:
1292
+ lora.set_network(self)
1293
+
1294
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
1295
+ self.batch_size = batch_size
1296
+ self.num_sub_prompts = num_sub_prompts
1297
+ self.current_size = (height, width)
1298
+ self.shared = shared
1299
+
1300
+ # create masks
1301
+ mask = self.mask
1302
+ mask_dic = {}
1303
+ mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
1304
+ ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
1305
+ dtype = ref_weight.dtype
1306
+ device = ref_weight.device
1307
+
1308
+ def resize_add(mh, mw):
1309
+ # logger.info(mh, mw, mh * mw)
1310
+ m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
1311
+ m = m.to(device, dtype=dtype)
1312
+ mask_dic[mh * mw] = m
1313
+
1314
+ h = height // 8
1315
+ w = width // 8
1316
+ for _ in range(4):
1317
+ resize_add(h, w)
1318
+ if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
1319
+ resize_add(h + h % 2, w + w % 2)
1320
+
1321
+ # deep shrink
1322
+ if ds_ratio is not None:
1323
+ hd = int(h * ds_ratio)
1324
+ wd = int(w * ds_ratio)
1325
+ resize_add(hd, wd)
1326
+
1327
+ h = (h + 1) // 2
1328
+ w = (w + 1) // 2
1329
+
1330
+ self.mask_dic = mask_dic
1331
+
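To make the resolution bookkeeping above concrete, a small worked example for a 512x512 generation (the odd-size and deep-shrink variants are omitted):

height = width = 512
h, w = height // 8, width // 8   # 64 x 64 latent
areas = []
for _ in range(4):
    areas.append(h * w)
    h, w = (h + 1) // 2, (w + 1) // 2
print(areas)  # [4096, 1024, 256, 64] -> the keys of mask_dic looked up in get_mask_for_x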
1332
+ def backup_weights(self):
1333
+ # 重みのバックアップを行う / back up the original weights
1334
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1335
+ for lora in loras:
1336
+ org_module = lora.org_module_ref[0]
1337
+ if not hasattr(org_module, "_lora_org_weight"):
1338
+ sd = org_module.state_dict()
1339
+ org_module._lora_org_weight = sd["weight"].detach().clone()
1340
+ org_module._lora_restored = True
1341
+
1342
+ def restore_weights(self):
1343
+ # 重みのリストアを行う / restore the original weights
1344
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1345
+ for lora in loras:
1346
+ org_module = lora.org_module_ref[0]
1347
+ if not org_module._lora_restored:
1348
+ sd = org_module.state_dict()
1349
+ sd["weight"] = org_module._lora_org_weight
1350
+ org_module.load_state_dict(sd)
1351
+ org_module._lora_restored = True
1352
+
1353
+ def pre_calculation(self):
1354
+ # 事前計算を行う / pre-calculate the merged weights
1355
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1356
+ for lora in loras:
1357
+ org_module = lora.org_module_ref[0]
1358
+ sd = org_module.state_dict()
1359
+
1360
+ org_weight = sd["weight"]
1361
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
1362
+ sd["weight"] = org_weight + lora_weight
1363
+ assert sd["weight"].shape == org_weight.shape
1364
+ org_module.load_state_dict(sd)
1365
+
1366
+ org_module._lora_restored = False
1367
+ lora.enabled = False
1368
+
1369
+ def apply_max_norm_regularization(self, max_norm_value, device):
1370
+ downkeys = []
1371
+ upkeys = []
1372
+ alphakeys = []
1373
+ norms = []
1374
+ keys_scaled = 0
1375
+
1376
+ state_dict = self.state_dict()
1377
+ for key in state_dict.keys():
1378
+ if "lora_down" in key and "weight" in key:
1379
+ downkeys.append(key)
1380
+ upkeys.append(key.replace("lora_down", "lora_up"))
1381
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
1382
+
1383
+ for i in range(len(downkeys)):
1384
+ down = state_dict[downkeys[i]].to(device)
1385
+ up = state_dict[upkeys[i]].to(device)
1386
+ alpha = state_dict[alphakeys[i]].to(device)
1387
+ dim = down.shape[0]
1388
+ scale = alpha / dim
1389
+
1390
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
1391
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
1392
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
1393
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
1394
+ else:
1395
+ updown = up @ down
1396
+
1397
+ updown *= scale
1398
+
1399
+ norm = updown.norm().clamp(min=max_norm_value / 2)
1400
+ desired = torch.clamp(norm, max=max_norm_value)
1401
+ ratio = desired.cpu() / norm.cpu()
1402
+ sqrt_ratio = ratio**0.5
1403
+ if ratio != 1:
1404
+ keys_scaled += 1
1405
+ state_dict[upkeys[i]] *= sqrt_ratio
1406
+ state_dict[downkeys[i]] *= sqrt_ratio
1407
+ scalednorm = updown.norm() * ratio
1408
+ norms.append(scalednorm.item())
1409
+
1410
+ return keys_scaled, sum(norms) / len(norms), max(norms)
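The loop above rescales both lora_up and lora_down by sqrt(ratio) so the merged update's norm stays at or below max_norm_value; a tiny numeric sketch of that arithmetic:

import torch

max_norm_value = 1.0
updown_norm = torch.tensor(4.0)                   # norm of scale * (up @ down)
norm = updown_norm.clamp(min=max_norm_value / 2)  # 4.0
desired = torch.clamp(norm, max=max_norm_value)   # 1.0
ratio = desired / norm                            # 0.25
sqrt_ratio = ratio ** 0.5                         # 0.5, applied to both up and down weights
# effective norm after scaling: updown_norm * ratio == 1.0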
lora_diffusers.py ADDED
@@ -0,0 +1,616 @@
1
+ # Diffusersで動くLoRA。このファイル単独で完結する。
2
+ # LoRA module for Diffusers. This file works independently.
3
+
4
+ import bisect
5
+ import math
6
+ import random
7
+ from typing import Any, Dict, List, Mapping, Optional, Union
8
+ from diffusers import UNet2DConditionModel
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ from transformers import CLIPTextModel
12
+
13
+ import torch
14
+ from library.device_utils import init_ipex, get_preferred_device
15
+ init_ipex()
16
+
17
+ from library.utils import setup_logging
18
+ setup_logging()
19
+ import logging
20
+ logger = logging.getLogger(__name__)
21
+
22
+ def make_unet_conversion_map() -> Dict[str, str]:
23
+ unet_conversion_map_layer = []
24
+
25
+ for i in range(3): # num_blocks is 3 in sdxl
26
+ # loop over downblocks/upblocks
27
+ for j in range(2):
28
+ # loop over resnets/attentions for downblocks
29
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
30
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
31
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
32
+
33
+ if i < 3:
34
+ # no attention layers in down_blocks.3
35
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
36
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
37
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
38
+
39
+ for j in range(3):
40
+ # loop over resnets/attentions for upblocks
41
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
42
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
43
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
44
+
45
+ # if i > 0: commentout for sdxl
46
+ # no attention layers in up_blocks.0
47
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
48
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
49
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
50
+
51
+ if i < 3:
52
+ # no downsample in down_blocks.3
53
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
54
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
55
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
56
+
57
+ # no upsample in up_blocks.3
58
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
59
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl
60
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
61
+
62
+ hf_mid_atn_prefix = "mid_block.attentions.0."
63
+ sd_mid_atn_prefix = "middle_block.1."
64
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
65
+
66
+ for j in range(2):
67
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
68
+ sd_mid_res_prefix = f"middle_block.{2*j}."
69
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
70
+
71
+ unet_conversion_map_resnet = [
72
+ # (stable-diffusion, HF Diffusers)
73
+ ("in_layers.0.", "norm1."),
74
+ ("in_layers.2.", "conv1."),
75
+ ("out_layers.0.", "norm2."),
76
+ ("out_layers.3.", "conv2."),
77
+ ("emb_layers.1.", "time_emb_proj."),
78
+ ("skip_connection.", "conv_shortcut."),
79
+ ]
80
+
81
+ unet_conversion_map = []
82
+ for sd, hf in unet_conversion_map_layer:
83
+ if "resnets" in hf:
84
+ for sd_res, hf_res in unet_conversion_map_resnet:
85
+ unet_conversion_map.append((sd + sd_res, hf + hf_res))
86
+ else:
87
+ unet_conversion_map.append((sd, hf))
88
+
89
+ for j in range(2):
90
+ hf_time_embed_prefix = f"time_embedding.linear_{j+1}."
91
+ sd_time_embed_prefix = f"time_embed.{j*2}."
92
+ unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix))
93
+
94
+ for j in range(2):
95
+ hf_label_embed_prefix = f"add_embedding.linear_{j+1}."
96
+ sd_label_embed_prefix = f"label_emb.0.{j*2}."
97
+ unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix))
98
+
99
+ unet_conversion_map.append(("input_blocks.0.0.", "conv_in."))
100
+ unet_conversion_map.append(("out.0.", "conv_norm_out."))
101
+ unet_conversion_map.append(("out.2.", "conv_out."))
102
+
103
+ sd_hf_conversion_map = {sd.replace(".", "_")[:-1]: hf.replace(".", "_")[:-1] for sd, hf in unet_conversion_map}
104
+ return sd_hf_conversion_map
105
+
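For reference, a few entries the resulting map is expected to contain (SAI/LDM key stem on the left, Diffusers stem on the right):

m = make_unet_conversion_map()
assert m["input_blocks_1_1"] == "down_blocks_0_attentions_0"
assert m["middle_block_1"] == "mid_block_attentions_0"
assert m["time_embed_0"] == "time_embedding_linear_1"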
106
+
107
+ UNET_CONVERSION_MAP = make_unet_conversion_map()
108
+
109
+
110
+ class LoRAModule(torch.nn.Module):
111
+ """
112
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
113
+ """
114
+
115
+ def __init__(
116
+ self,
117
+ lora_name,
118
+ org_module: torch.nn.Module,
119
+ multiplier=1.0,
120
+ lora_dim=4,
121
+ alpha=1,
122
+ ):
123
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
124
+ super().__init__()
125
+ self.lora_name = lora_name
126
+
127
+ if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
128
+ in_dim = org_module.in_channels
129
+ out_dim = org_module.out_channels
130
+ else:
131
+ in_dim = org_module.in_features
132
+ out_dim = org_module.out_features
133
+
134
+ self.lora_dim = lora_dim
135
+
136
+ if org_module.__class__.__name__ == "Conv2d" or org_module.__class__.__name__ == "LoRACompatibleConv":
137
+ kernel_size = org_module.kernel_size
138
+ stride = org_module.stride
139
+ padding = org_module.padding
140
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
141
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
142
+ else:
143
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
144
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
145
+
146
+ if type(alpha) == torch.Tensor:
147
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
148
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
149
+ self.scale = alpha / self.lora_dim
150
+ self.register_buffer("alpha", torch.tensor(alpha)) # 勾配計算に含めない / not included in gradient calculation
151
+
152
+ # same as microsoft's
153
+ torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
154
+ torch.nn.init.zeros_(self.lora_up.weight)
155
+
156
+ self.multiplier = multiplier
157
+ self.org_module = [org_module]
158
+ self.enabled = True
159
+ self.network: LoRANetwork = None
160
+ self.org_forward = None
161
+
162
+ # override org_module's forward method
163
+ def apply_to(self, multiplier=None):
164
+ if multiplier is not None:
165
+ self.multiplier = multiplier
166
+ if self.org_forward is None:
167
+ self.org_forward = self.org_module[0].forward
168
+ self.org_module[0].forward = self.forward
169
+
170
+ # restore org_module's forward method
171
+ def unapply_to(self):
172
+ if self.org_forward is not None:
173
+ self.org_module[0].forward = self.org_forward
174
+
175
+ # forward with lora
176
+ # scale is used by LoRACompatibleConv, but we ignore it because we have multiplier
177
+ def forward(self, x, scale=1.0):
178
+ if not self.enabled:
179
+ return self.org_forward(x)
180
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
181
+
182
+ def set_network(self, network):
183
+ self.network = network
184
+
185
+ # merge lora weight to org weight
186
+ def merge_to(self, multiplier=1.0):
187
+ # get lora weight
188
+ lora_weight = self.get_weight(multiplier)
189
+
190
+ # get org weight
191
+ org_sd = self.org_module[0].state_dict()
192
+ org_weight = org_sd["weight"]
193
+ weight = org_weight + lora_weight.to(org_weight.device, dtype=org_weight.dtype)
194
+
195
+ # set weight to org_module
196
+ org_sd["weight"] = weight
197
+ self.org_module[0].load_state_dict(org_sd)
198
+
199
+ # restore org weight from lora weight
200
+ def restore_from(self, multiplier=1.0):
201
+ # get lora weight
202
+ lora_weight = self.get_weight(multiplier)
203
+
204
+ # get org weight
205
+ org_sd = self.org_module[0].state_dict()
206
+ org_weight = org_sd["weight"]
207
+ weight = org_weight - lora_weight.to(org_weight.device, dtype=org_weight.dtype)
208
+
209
+ # set weight to org_module
210
+ org_sd["weight"] = weight
211
+ self.org_module[0].load_state_dict(org_sd)
212
+
213
+ # return lora weight
214
+ def get_weight(self, multiplier=None):
215
+ if multiplier is None:
216
+ multiplier = self.multiplier
217
+
218
+ # get up/down weight from module
219
+ up_weight = self.lora_up.weight.to(torch.float)
220
+ down_weight = self.lora_down.weight.to(torch.float)
221
+
222
+ # pre-calculated weight
223
+ if len(down_weight.size()) == 2:
224
+ # linear
225
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
226
+ elif down_weight.size()[2:4] == (1, 1):
227
+ # conv2d 1x1
228
+ weight = (
229
+ self.multiplier
230
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
231
+ * self.scale
232
+ )
233
+ else:
234
+ # conv2d 3x3
235
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
236
+ weight = self.multiplier * conved * self.scale
237
+
238
+ return weight
239
+
240
+
241
+ # Create network from weights for inference, weights are not loaded here
242
+ def create_network_from_weights(
243
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]], unet: UNet2DConditionModel, weights_sd: Dict, multiplier: float = 1.0
244
+ ):
245
+ # get dim/alpha mapping
246
+ modules_dim = {}
247
+ modules_alpha = {}
248
+ for key, value in weights_sd.items():
249
+ if "." not in key:
250
+ continue
251
+
252
+ lora_name = key.split(".")[0]
253
+ if "alpha" in key:
254
+ modules_alpha[lora_name] = value
255
+ elif "lora_down" in key:
256
+ dim = value.size()[0]
257
+ modules_dim[lora_name] = dim
258
+ # logger.info(f"{lora_name} {value.size()} {dim}")
259
+
260
+ # support old LoRA without alpha
261
+ for key in modules_dim.keys():
262
+ if key not in modules_alpha:
263
+ modules_alpha[key] = modules_dim[key]
264
+
265
+ return LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
266
+
267
+
268
+ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
269
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if hasattr(pipe, "text_encoder_2") else [pipe.text_encoder]
270
+ unet = pipe.unet
271
+
272
+ lora_network = create_network_from_weights(text_encoders, unet, weights_sd, multiplier=multiplier)
273
+ lora_network.load_state_dict(weights_sd)
274
+ lora_network.merge_to(multiplier=multiplier)
275
+
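A sketch of how this helper might be used with a Diffusers pipeline; the model id and LoRA path are placeholders, and an SDXL pipeline is assumed so that both text encoders are present:

from diffusers import StableDiffusionXLPipeline
from safetensors.torch import load_file

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
weights_sd = load_file("my_lora.safetensors")  # placeholder path
merge_lora_weights(pipe, weights_sd, multiplier=0.8)
image = pipe("a photo of a cat").images[0]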
276
+
277
+ # block weightや学習に対応しない簡易版 / simple version without block weight and training
278
+ class LoRANetwork(torch.nn.Module):
279
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
280
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
281
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
282
+ LORA_PREFIX_UNET = "lora_unet"
283
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
284
+
285
+ # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
286
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
287
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
288
+
289
+ def __init__(
290
+ self,
291
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
292
+ unet: UNet2DConditionModel,
293
+ multiplier: float = 1.0,
294
+ modules_dim: Optional[Dict[str, int]] = None,
295
+ modules_alpha: Optional[Dict[str, int]] = None,
296
+ varbose: Optional[bool] = False,
297
+ ) -> None:
298
+ super().__init__()
299
+ self.multiplier = multiplier
300
+
301
+ logger.info("create LoRA network from weights")
302
+
303
+ # convert SDXL Stability AI's U-Net modules to Diffusers
304
+ converted = self.convert_unet_modules(modules_dim, modules_alpha)
305
+ if converted:
306
+ logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
307
+
308
+ # create module instances
309
+ def create_modules(
310
+ is_unet: bool,
311
+ text_encoder_idx: Optional[int], # None, 1, 2
312
+ root_module: torch.nn.Module,
313
+ target_replace_modules: List[torch.nn.Module],
314
+ ) -> List[LoRAModule]:
315
+ prefix = (
316
+ self.LORA_PREFIX_UNET
317
+ if is_unet
318
+ else (
319
+ self.LORA_PREFIX_TEXT_ENCODER
320
+ if text_encoder_idx is None
321
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
322
+ )
323
+ )
324
+ loras = []
325
+ skipped = []
326
+ for name, module in root_module.named_modules():
327
+ if module.__class__.__name__ in target_replace_modules:
328
+ for child_name, child_module in module.named_modules():
329
+ is_linear = (
330
+ child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear"
331
+ )
332
+ is_conv2d = (
333
+ child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv"
334
+ )
335
+
336
+ if is_linear or is_conv2d:
337
+ lora_name = prefix + "." + name + "." + child_name
338
+ lora_name = lora_name.replace(".", "_")
339
+
340
+ if lora_name not in modules_dim:
341
+ # logger.info(f"skipped {lora_name} (not found in modules_dim)")
342
+ skipped.append(lora_name)
343
+ continue
344
+
345
+ dim = modules_dim[lora_name]
346
+ alpha = modules_alpha[lora_name]
347
+ lora = LoRAModule(
348
+ lora_name,
349
+ child_module,
350
+ self.multiplier,
351
+ dim,
352
+ alpha,
353
+ )
354
+ loras.append(lora)
355
+ return loras, skipped
356
+
357
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
358
+
359
+ # create LoRA for text encoder
360
+ # 毎回すべてのモジュールを作るのは無駄なので要検討 / it is wasteful to create all modules every time, need to consider
361
+ self.text_encoder_loras: List[LoRAModule] = []
362
+ skipped_te = []
363
+ for i, text_encoder in enumerate(text_encoders):
364
+ if len(text_encoders) > 1:
365
+ index = i + 1
366
+ else:
367
+ index = None
368
+
369
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
370
+ self.text_encoder_loras.extend(text_encoder_loras)
371
+ skipped_te += skipped
372
+ logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
373
+ if len(skipped_te) > 0:
374
+ logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
375
+
376
+ # extend U-Net target modules to include Conv2d 3x3
377
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
378
+
379
+ self.unet_loras: List[LoRAModule]
380
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
381
+ logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
382
+ if len(skipped_un) > 0:
383
+ logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
384
+
385
+ # assertion
386
+ names = set()
387
+ for lora in self.text_encoder_loras + self.unet_loras:
388
+ names.add(lora.lora_name)
389
+ for lora_name in modules_dim.keys():
390
+ assert lora_name in names, f"{lora_name} is not found in created LoRA modules."
391
+
392
+ # register modules so that load_state_dict works
393
+ for lora in self.text_encoder_loras + self.unet_loras:
394
+ self.add_module(lora.lora_name, lora)
395
+
396
+ # SDXL: convert SDXL Stability AI's U-Net modules to Diffusers
397
+ def convert_unet_modules(self, modules_dim, modules_alpha):
398
+ converted_count = 0
399
+ not_converted_count = 0
400
+
401
+ map_keys = list(UNET_CONVERSION_MAP.keys())
402
+ map_keys.sort()
403
+
404
+ for key in list(modules_dim.keys()):
405
+ if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
406
+ search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
407
+ position = bisect.bisect_right(map_keys, search_key)
408
+ map_key = map_keys[position - 1]
409
+ if search_key.startswith(map_key):
410
+ new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
411
+ modules_dim[new_key] = modules_dim[key]
412
+ modules_alpha[new_key] = modules_alpha[key]
413
+ del modules_dim[key]
414
+ del modules_alpha[key]
415
+ converted_count += 1
416
+ else:
417
+ not_converted_count += 1
418
+ assert (
419
+ converted_count == 0 or not_converted_count == 0
420
+ ), f"some modules are not converted: {converted_count} converted, {not_converted_count} not converted"
421
+ return converted_count
422
+
423
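convert_unet_modules above (and load_state_dict below) rename Stability-AI-style U-Net LoRA keys to Diffusers-style keys by sorting the conversion-map prefixes, probing with bisect_right, and confirming the candidate with startswith. The following is an editorial sketch of that lookup; the two map entries are made-up stand-ins, since UNET_CONVERSION_MAP itself is not shown in this excerpt.

# Editorial sketch (not part of the commit): prefix remapping via bisect_right.
import bisect

# hypothetical stand-in for UNET_CONVERSION_MAP (SAI prefix -> Diffusers prefix)
conversion_map = {
    "input_blocks_1_1_": "down_blocks_0_attentions_0_",
    "middle_block_1_": "mid_block_attentions_0_",
}
map_keys = sorted(conversion_map.keys())

def remap(name: str) -> str:
    position = bisect.bisect_right(map_keys, name)
    if position == 0:
        return name
    map_key = map_keys[position - 1]  # closest key that sorts at or before `name`
    if name.startswith(map_key):
        return name.replace(map_key, conversion_map[map_key])
    return name

print(remap("input_blocks_1_1_proj_in"))  # -> down_blocks_0_attentions_0_proj_in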
+ def set_multiplier(self, multiplier):
424
+ self.multiplier = multiplier
425
+ for lora in self.text_encoder_loras + self.unet_loras:
426
+ lora.multiplier = self.multiplier
427
+
428
+ def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
429
+ if apply_text_encoder:
430
+ logger.info("enable LoRA for text encoder")
431
+ for lora in self.text_encoder_loras:
432
+ lora.apply_to(multiplier)
433
+ if apply_unet:
434
+ logger.info("enable LoRA for U-Net")
435
+ for lora in self.unet_loras:
436
+ lora.apply_to(multiplier)
437
+
438
+ def unapply_to(self):
439
+ for lora in self.text_encoder_loras + self.unet_loras:
440
+ lora.unapply_to()
441
+
442
+ def merge_to(self, multiplier=1.0):
443
+ logger.info("merge LoRA weights to original weights")
444
+ for lora in tqdm(self.text_encoder_loras + self.unet_loras):
445
+ lora.merge_to(multiplier)
446
+ logger.info(f"weights are merged")
447
+
448
+ def restore_from(self, multiplier=1.0):
449
+ logger.info("restore LoRA weights from original weights")
450
+ for lora in tqdm(self.text_encoder_loras + self.unet_loras):
451
+ lora.restore_from(multiplier)
452
+ logger.info(f"weights are restored")
453
+
454
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
455
+ # convert SDXL Stability AI's state dict to a Diffusers-based state dict
456
+ map_keys = list(UNET_CONVERSION_MAP.keys()) # prefix of U-Net modules
457
+ map_keys.sort()
458
+ for key in list(state_dict.keys()):
459
+ if key.startswith(LoRANetwork.LORA_PREFIX_UNET + "_"):
460
+ search_key = key.replace(LoRANetwork.LORA_PREFIX_UNET + "_", "")
461
+ position = bisect.bisect_right(map_keys, search_key)
462
+ map_key = map_keys[position - 1]
463
+ if search_key.startswith(map_key):
464
+ new_key = key.replace(map_key, UNET_CONVERSION_MAP[map_key])
465
+ state_dict[new_key] = state_dict[key]
466
+ del state_dict[key]
467
+
468
+ # in case of V2, some weights have different shape, so we need to convert them
469
+ # because V2 LoRA is based on U-Net created by use_linear_projection=False
470
+ my_state_dict = self.state_dict()
471
+ for key in state_dict.keys():
472
+ if state_dict[key].size() != my_state_dict[key].size():
473
+ # logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
474
+ state_dict[key] = state_dict[key].view(my_state_dict[key].size())
475
+
476
+ return super().load_state_dict(state_dict, strict)
477
+
478
+
479
+ if __name__ == "__main__":
480
+ # sample code to use LoRANetwork
481
+ import os
482
+ import argparse
483
+ from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
484
+ import torch
485
+
486
+ device = get_preferred_device()
487
+
488
+ parser = argparse.ArgumentParser()
489
+ parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
490
+ parser.add_argument("--lora_weights", type=str, default=None, help="path to LoRA weights")
491
+ parser.add_argument("--sdxl", action="store_true", help="use SDXL model")
492
+ parser.add_argument("--prompt", type=str, default="A photo of cat", help="prompt text")
493
+ parser.add_argument("--negative_prompt", type=str, default="", help="negative prompt text")
494
+ parser.add_argument("--seed", type=int, default=0, help="random seed")
495
+ args = parser.parse_args()
496
+
497
+ image_prefix = args.model_id.replace("/", "_") + "_"
498
+
499
+ # load Diffusers model
500
+ logger.info(f"load model from {args.model_id}")
501
+ pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
502
+ if args.sdxl:
503
+ # use_safetensors=True does not work with 0.18.2
504
+ pipe = StableDiffusionXLPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
505
+ else:
506
+ pipe = StableDiffusionPipeline.from_pretrained(args.model_id, variant="fp16", torch_dtype=torch.float16)
507
+ pipe.to(device)
508
+ pipe.set_use_memory_efficient_attention_xformers(True)
509
+
510
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
511
+
512
+ # load LoRA weights
513
+ logger.info(f"load LoRA weights from {args.lora_weights}")
514
+ if os.path.splitext(args.lora_weights)[1] == ".safetensors":
515
+ from safetensors.torch import load_file
516
+
517
+ lora_sd = load_file(args.lora_weights)
518
+ else:
519
+ lora_sd = torch.load(args.lora_weights)
520
+
521
+ # create network from LoRA weights and load the weights
522
+ logger.info(f"create LoRA network")
523
+ lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
524
+
525
+ logger.info(f"load LoRA network weights")
526
+ lora_network.load_state_dict(lora_sd)
527
+
528
+ lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
529
+
530
+ # 必要があれば、元のモデルの重みをバックアップしておく
531
+ # back-up unet/text encoder weights if necessary
532
+ def detach_and_move_to_cpu(state_dict):
533
+ for k, v in state_dict.items():
534
+ state_dict[k] = v.detach().cpu()
535
+ return state_dict
536
+
537
+ org_unet_sd = pipe.unet.state_dict()
538
+ detach_and_move_to_cpu(org_unet_sd)
539
+
540
+ org_text_encoder_sd = pipe.text_encoder.state_dict()
541
+ detach_and_move_to_cpu(org_text_encoder_sd)
542
+
543
+ if args.sdxl:
544
+ org_text_encoder_2_sd = pipe.text_encoder_2.state_dict()
545
+ detach_and_move_to_cpu(org_text_encoder_2_sd)
546
+
547
+ def seed_everything(seed):
548
+ torch.manual_seed(seed)
549
+ torch.cuda.manual_seed_all(seed)
550
+ np.random.seed(seed)
551
+ random.seed(seed)
552
+
553
+ # create image with original weights
554
+ logger.info(f"create image with original weights")
555
+ seed_everything(args.seed)
556
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
557
+ image.save(image_prefix + "original.png")
558
+
559
+ # apply LoRA network to the model: slower than merge_to, but can be reverted easily
560
+ logger.info(f"apply LoRA network to the model")
561
+ lora_network.apply_to(multiplier=1.0)
562
+
563
+ logger.info(f"create image with applied LoRA")
564
+ seed_everything(args.seed)
565
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
566
+ image.save(image_prefix + "applied_lora.png")
567
+
568
+ # unapply LoRA network to the model
569
+ logger.info(f"unapply LoRA network to the model")
570
+ lora_network.unapply_to()
571
+
572
+ logger.info(f"create image with unapplied LoRA")
573
+ seed_everything(args.seed)
574
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
575
+ image.save(image_prefix + "unapplied_lora.png")
576
+
577
+ # merge LoRA network into the model: faster than apply_to, but requires a back-up of the original weights (or restore_from)
578
+ logger.info(f"merge LoRA network to the model")
579
+ lora_network.merge_to(multiplier=1.0)
580
+
581
+ logger.info(f"create image with LoRA")
582
+ seed_everything(args.seed)
583
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
584
+ image.save(image_prefix + "merged_lora.png")
585
+
586
+ # restore (unmerge) LoRA weights: numerically unstable
587
+ # revert the merged weights; due to rounding errors they may not exactly match the original weights
588
+ # restoring the original weights from a saved state_dict is the reliable approach
589
+ logger.info(f"restore (unmerge) LoRA weights")
590
+ lora_network.restore_from(multiplier=1.0)
591
+
592
+ logger.info(f"create image without LoRA")
593
+ seed_everything(args.seed)
594
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
595
+ image.save(image_prefix + "unmerged_lora.png")
596
+
597
+ # restore original weights
598
+ logger.info(f"restore original weights")
599
+ pipe.unet.load_state_dict(org_unet_sd)
600
+ pipe.text_encoder.load_state_dict(org_text_encoder_sd)
601
+ if args.sdxl:
602
+ pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
603
+
604
+ logger.info(f"create image with restored original weights")
605
+ seed_everything(args.seed)
606
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
607
+ image.save(image_prefix + "restore_original.png")
608
+
609
+ # use convenience function to merge LoRA weights
610
+ logger.info(f"merge LoRA weights with convenience function")
611
+ merge_lora_weights(pipe, lora_sd, multiplier=1.0)
612
+
613
+ logger.info(f"create image with merged LoRA weights")
614
+ seed_everything(args.seed)
615
+ image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
616
+ image.save(image_prefix + "convenience_merged_lora.png")
lora_fa.py ADDED
@@ -0,0 +1,1244 @@
1
+ # LoRA network module
2
+ # reference:
3
+ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
4
+ # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
5
+
6
+ # temporary implementation of LoRA-FA: https://arxiv.org/abs/2308.03303
7
+ # need to be refactored and merged to lora.py
8
+
9
+ import math
10
+ import os
11
+ from typing import Dict, List, Optional, Tuple, Type, Union
12
+ from diffusers import AutoencoderKL
13
+ from transformers import CLIPTextModel
14
+ import numpy as np
15
+ import torch
16
+ import re
17
+ from library.utils import setup_logging
18
+ setup_logging()
19
+ import logging
20
+ logger = logging.getLogger(__name__)
21
+
22
+ RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
23
+
24
+
25
+ class LoRAModule(torch.nn.Module):
26
+ """
27
+ Replaces the forward method of the original Linear instead of replacing the original Linear module itself.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ lora_name,
33
+ org_module: torch.nn.Module,
34
+ multiplier=1.0,
35
+ lora_dim=4,
36
+ alpha=1,
37
+ dropout=None,
38
+ rank_dropout=None,
39
+ module_dropout=None,
40
+ ):
41
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
42
+ super().__init__()
43
+ self.lora_name = lora_name
44
+
45
+ if org_module.__class__.__name__ == "Conv2d":
46
+ in_dim = org_module.in_channels
47
+ out_dim = org_module.out_channels
48
+ else:
49
+ in_dim = org_module.in_features
50
+ out_dim = org_module.out_features
51
+
52
+ # if limit_rank:
53
+ # self.lora_dim = min(lora_dim, in_dim, out_dim)
54
+ # if self.lora_dim != lora_dim:
55
+ # logger.info(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
56
+ # else:
57
+ self.lora_dim = lora_dim
58
+
59
+ if org_module.__class__.__name__ == "Conv2d":
60
+ kernel_size = org_module.kernel_size
61
+ stride = org_module.stride
62
+ padding = org_module.padding
63
+ self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
64
+ self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
65
+ else:
66
+ self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
67
+ self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
68
+
69
+ if type(alpha) == torch.Tensor:
70
+ alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
71
+ alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
72
+ self.scale = alpha / self.lora_dim
73
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
74
+
75
+ # # same as microsoft's
76
+ # torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
77
+
78
+ # according to the paper, initialize LoRA-A (down) as normal distribution
79
+ torch.nn.init.normal_(self.lora_down.weight, std=math.sqrt(2.0 / (in_dim + self.lora_dim)))
80
+
81
+ torch.nn.init.zeros_(self.lora_up.weight)
82
+
83
+ self.multiplier = multiplier
84
+ self.org_module = org_module # removed in apply_to
85
+ self.dropout = dropout
86
+ self.rank_dropout = rank_dropout
87
+ self.module_dropout = module_dropout
88
+
89
+ def get_trainable_params(self):
90
+ params = self.named_parameters()
91
+ trainable_params = []
92
+ for param in params:
93
+ if param[0] == "lora_up.weight": # up only
94
+ trainable_params.append(param[1])
95
+ return trainable_params
96
+
97
+ def requires_grad_(self, requires_grad: bool = True):
98
+ self.lora_up.requires_grad_(requires_grad)
99
+ self.lora_down.requires_grad_(False)
100
+ return self
101
+
102
+ def apply_to(self):
103
+ self.org_forward = self.org_module.forward
104
+ self.org_module.forward = self.forward
105
+ del self.org_module
106
+
107
+ def forward(self, x):
108
+ org_forwarded = self.org_forward(x)
109
+
110
+ # module dropout
111
+ if self.module_dropout is not None and self.training:
112
+ if torch.rand(1) < self.module_dropout:
113
+ return org_forwarded
114
+
115
+ lx = self.lora_down(x)
116
+
117
+ # normal dropout
118
+ if self.dropout is not None and self.training:
119
+ lx = torch.nn.functional.dropout(lx, p=self.dropout)
120
+
121
+ # rank dropout
122
+ if self.rank_dropout is not None and self.training:
123
+ mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
124
+ if len(lx.size()) == 3:
125
+ mask = mask.unsqueeze(1) # for Text Encoder
126
+ elif len(lx.size()) == 4:
127
+ mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d
128
+ lx = lx * mask
129
+
130
+ # scaling for rank dropout: treat as if the rank is changed
131
+ # this could also be computed from the mask, but rank_dropout is used here for its augmentation-like effect
132
+ scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # written explicitly for readability
133
+ else:
134
+ scale = self.scale
135
+
136
+ lx = self.lora_up(lx)
137
+
138
+ return org_forwarded + lx * self.multiplier * scale
139
+
140
+
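The class above is where the LoRA-FA behaviour lives: lora_down (the A matrix) gets a fixed normal initialisation and is frozen, while only lora_up (the B matrix, initialised to zero) is trained, as reflected in requires_grad_ and get_trainable_params. The following editorial sketch (toy dimensions, plain nn.Linear layers, not part of the commit) reproduces just that freezing pattern.

# Editorial sketch (not part of the commit): the LoRA-FA idea as implemented above --
# the down projection (A) is frozen after a random init and only the up projection (B) is trained.
import math
import torch

in_dim, out_dim, rank = 16, 8, 4
lora_down = torch.nn.Linear(in_dim, rank, bias=False)  # A: frozen
lora_up = torch.nn.Linear(rank, out_dim, bias=False)   # B: trainable, starts at zero

torch.nn.init.normal_(lora_down.weight, std=math.sqrt(2.0 / (in_dim + rank)))
torch.nn.init.zeros_(lora_up.weight)
lora_down.requires_grad_(False)  # mirrors LoRAModule.requires_grad_ / get_trainable_params

trainable = [p for p in list(lora_down.parameters()) + list(lora_up.parameters()) if p.requires_grad]
assert len(trainable) == 1 and trainable[0] is lora_up.weight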
141
+ class LoRAInfModule(LoRAModule):
142
+ def __init__(
143
+ self,
144
+ lora_name,
145
+ org_module: torch.nn.Module,
146
+ multiplier=1.0,
147
+ lora_dim=4,
148
+ alpha=1,
149
+ **kwargs,
150
+ ):
151
+ # no dropout for inference
152
+ super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
153
+
154
+ self.org_module_ref = [org_module] # keep a reference for later use
155
+ self.enabled = True
156
+
157
+ # check regional or not by lora_name
158
+ self.text_encoder = False
159
+ if lora_name.startswith("lora_te_"):
160
+ self.regional = False
161
+ self.use_sub_prompt = True
162
+ self.text_encoder = True
163
+ elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
164
+ self.regional = False
165
+ self.use_sub_prompt = True
166
+ elif "time_emb" in lora_name:
167
+ self.regional = False
168
+ self.use_sub_prompt = False
169
+ else:
170
+ self.regional = True
171
+ self.use_sub_prompt = False
172
+
173
+ self.network: LoRANetwork = None
174
+
175
+ def set_network(self, network):
176
+ self.network = network
177
+
178
+ # freeze and merge
179
+ def merge_to(self, sd, dtype, device):
180
+ # get up/down weight
181
+ up_weight = sd["lora_up.weight"].to(torch.float).to(device)
182
+ down_weight = sd["lora_down.weight"].to(torch.float).to(device)
183
+
184
+ # extract weight from org_module
185
+ org_sd = self.org_module.state_dict()
186
+ weight = org_sd["weight"].to(torch.float)
187
+
188
+ # merge weight
189
+ if len(weight.size()) == 2:
190
+ # linear
191
+ weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
192
+ elif down_weight.size()[2:4] == (1, 1):
193
+ # conv2d 1x1
194
+ weight = (
195
+ weight
196
+ + self.multiplier
197
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
198
+ * self.scale
199
+ )
200
+ else:
201
+ # conv2d 3x3
202
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
203
+ # logger.info(conved.size(), weight.size(), module.stride, module.padding)
204
+ weight = weight + self.multiplier * conved * self.scale
205
+
206
+ # set weight to org_module
207
+ org_sd["weight"] = weight.to(dtype)
208
+ self.org_module.load_state_dict(org_sd)
209
+
210
+ # return this module's weight so that the merge can be undone later
211
+ def get_weight(self, multiplier=None):
212
+ if multiplier is None:
213
+ multiplier = self.multiplier
214
+
215
+ # get up/down weight from module
216
+ up_weight = self.lora_up.weight.to(torch.float)
217
+ down_weight = self.lora_down.weight.to(torch.float)
218
+
219
+ # pre-calculated weight
220
+ if len(down_weight.size()) == 2:
221
+ # linear
222
+ weight = self.multiplier * (up_weight @ down_weight) * self.scale
223
+ elif down_weight.size()[2:4] == (1, 1):
224
+ # conv2d 1x1
225
+ weight = (
226
+ self.multiplier
227
+ * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
228
+ * self.scale
229
+ )
230
+ else:
231
+ # conv2d 3x3
232
+ conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
233
+ weight = self.multiplier * conved * self.scale
234
+
235
+ return weight
236
+
237
+ def set_region(self, region):
238
+ self.region = region
239
+ self.region_mask = None
240
+
241
+ def default_forward(self, x):
242
+ # logger.info("default_forward", self.lora_name, x.size())
243
+ return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
244
+
245
+ def forward(self, x):
246
+ if not self.enabled:
247
+ return self.org_forward(x)
248
+
249
+ if self.network is None or self.network.sub_prompt_index is None:
250
+ return self.default_forward(x)
251
+ if not self.regional and not self.use_sub_prompt:
252
+ return self.default_forward(x)
253
+
254
+ if self.regional:
255
+ return self.regional_forward(x)
256
+ else:
257
+ return self.sub_prompt_forward(x)
258
+
259
+ def get_mask_for_x(self, x):
260
+ # calculate size from shape of x
261
+ if len(x.size()) == 4:
262
+ h, w = x.size()[2:4]
263
+ area = h * w
264
+ else:
265
+ area = x.size()[1]
266
+
267
+ mask = self.network.mask_dic[area]
268
+ if mask is None:
269
+ raise ValueError(f"mask is None for resolution {area}")
270
+ if len(x.size()) != 4:
271
+ mask = torch.reshape(mask, (1, -1, 1))
272
+ return mask
273
+
274
+ def regional_forward(self, x):
275
+ if "attn2_to_out" in self.lora_name:
276
+ return self.to_out_forward(x)
277
+
278
+ if self.network.mask_dic is None: # sub_prompt_index >= 3
279
+ return self.default_forward(x)
280
+
281
+ # apply mask for LoRA result
282
+ lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
283
+ mask = self.get_mask_for_x(lx)
284
+ # logger.info("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
285
+ lx = lx * mask
286
+
287
+ x = self.org_forward(x)
288
+ x = x + lx
289
+
290
+ if "attn2_to_q" in self.lora_name and self.network.is_last_network:
291
+ x = self.postp_to_q(x)
292
+
293
+ return x
294
+
295
+ def postp_to_q(self, x):
296
+ # repeat x to num_sub_prompts
297
+ has_real_uncond = x.size()[0] // self.network.batch_size == 3
298
+ qc = self.network.batch_size # uncond
299
+ qc += self.network.batch_size * self.network.num_sub_prompts # cond
300
+ if has_real_uncond:
301
+ qc += self.network.batch_size # real_uncond
302
+
303
+ query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
304
+ query[: self.network.batch_size] = x[: self.network.batch_size]
305
+
306
+ for i in range(self.network.batch_size):
307
+ qi = self.network.batch_size + i * self.network.num_sub_prompts
308
+ query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
309
+
310
+ if has_real_uncond:
311
+ query[-self.network.batch_size :] = x[-self.network.batch_size :]
312
+
313
+ # logger.info("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
314
+ return query
315
+
316
+ def sub_prompt_forward(self, x):
317
+ if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
318
+ return self.org_forward(x)
319
+
320
+ emb_idx = self.network.sub_prompt_index
321
+ if not self.text_encoder:
322
+ emb_idx += self.network.batch_size
323
+
324
+ # apply sub prompt of X
325
+ lx = x[emb_idx :: self.network.num_sub_prompts]
326
+ lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
327
+
328
+ # logger.info("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
329
+
330
+ x = self.org_forward(x)
331
+ x[emb_idx :: self.network.num_sub_prompts] += lx
332
+
333
+ return x
334
+
335
+ def to_out_forward(self, x):
336
+ # logger.info("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
337
+
338
+ if self.network.is_last_network:
339
+ masks = [None] * self.network.num_sub_prompts
340
+ self.network.shared[self.lora_name] = (None, masks)
341
+ else:
342
+ lx, masks = self.network.shared[self.lora_name]
343
+
344
+ # call own LoRA
345
+ x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
346
+ lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
347
+
348
+ if self.network.is_last_network:
349
+ lx = torch.zeros(
350
+ (self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
351
+ )
352
+ self.network.shared[self.lora_name] = (lx, masks)
353
+
354
+ # logger.info("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
355
+ lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
356
+ masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
357
+
358
+ # if not last network, return x and masks
359
+ x = self.org_forward(x)
360
+ if not self.network.is_last_network:
361
+ return x
362
+
363
+ lx, masks = self.network.shared.pop(self.lora_name)
364
+
365
+ # if last network, combine separated x with mask weighted sum
366
+ has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
367
+
368
+ out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
369
+ out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
370
+ if has_real_uncond:
371
+ out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
372
+
373
+ # logger.info("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
374
+ # for i in range(len(masks)):
375
+ # if masks[i] is None:
376
+ # masks[i] = torch.zeros_like(masks[-1])
377
+
378
+ mask = torch.cat(masks)
379
+ mask_sum = torch.sum(mask, dim=0) + 1e-4
380
+ for i in range(self.network.batch_size):
381
+ # process each image in the batch separately
382
+ lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
383
+ lx1 = lx1 * mask
384
+ lx1 = torch.sum(lx1, dim=0)
385
+
386
+ xi = self.network.batch_size + i * self.network.num_sub_prompts
387
+ x1 = x[xi : xi + self.network.num_sub_prompts]
388
+ x1 = x1 * mask
389
+ x1 = torch.sum(x1, dim=0)
390
+ x1 = x1 / mask_sum
391
+
392
+ x1 = x1 + lx1
393
+ out[self.network.batch_size + i] = x1
394
+
395
+ # logger.info("to_out_forward", x.size(), out.size(), has_real_uncond)
396
+ return out
397
+
398
+
399
+ def parse_block_lr_kwargs(nw_kwargs):
400
+ down_lr_weight = nw_kwargs.get("down_lr_weight", None)
401
+ mid_lr_weight = nw_kwargs.get("mid_lr_weight", None)
402
+ up_lr_weight = nw_kwargs.get("up_lr_weight", None)
403
+
404
+ # if none of the above is specified, treat block lr as disabled and return None
405
+ if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None:
406
+ return None, None, None
407
+
408
+ # extract learning rate weight for each block
409
+ if down_lr_weight is not None:
410
+ # if some parameters are not set, use zero
411
+ if "," in down_lr_weight:
412
+ down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")]
413
+
414
+ if mid_lr_weight is not None:
415
+ mid_lr_weight = float(mid_lr_weight)
416
+
417
+ if up_lr_weight is not None:
418
+ if "," in up_lr_weight:
419
+ up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")]
420
+
421
+ down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight(
422
+ down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0))
423
+ )
424
+
425
+ return down_lr_weight, mid_lr_weight, up_lr_weight
426
+
427
+
428
+ def create_network(
429
+ multiplier: float,
430
+ network_dim: Optional[int],
431
+ network_alpha: Optional[float],
432
+ vae: AutoencoderKL,
433
+ text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
434
+ unet,
435
+ neuron_dropout: Optional[float] = None,
436
+ **kwargs,
437
+ ):
438
+ if network_dim is None:
439
+ network_dim = 4 # default
440
+ if network_alpha is None:
441
+ network_alpha = 1.0
442
+
443
+ # extract dim/alpha for conv2d, and block dim
444
+ conv_dim = kwargs.get("conv_dim", None)
445
+ conv_alpha = kwargs.get("conv_alpha", None)
446
+ if conv_dim is not None:
447
+ conv_dim = int(conv_dim)
448
+ if conv_alpha is None:
449
+ conv_alpha = 1.0
450
+ else:
451
+ conv_alpha = float(conv_alpha)
452
+
453
+ # block dim/alpha/lr
454
+ block_dims = kwargs.get("block_dims", None)
455
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
456
+
457
+ # if any of the above is specified, enable per-block dim (rank)
458
+ if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None:
459
+ block_alphas = kwargs.get("block_alphas", None)
460
+ conv_block_dims = kwargs.get("conv_block_dims", None)
461
+ conv_block_alphas = kwargs.get("conv_block_alphas", None)
462
+
463
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas(
464
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
465
+ )
466
+
467
+ # remove block dim/alpha without learning rate
468
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas(
469
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
470
+ )
471
+
472
+ else:
473
+ block_alphas = None
474
+ conv_block_dims = None
475
+ conv_block_alphas = None
476
+
477
+ # rank/module dropout
478
+ rank_dropout = kwargs.get("rank_dropout", None)
479
+ if rank_dropout is not None:
480
+ rank_dropout = float(rank_dropout)
481
+ module_dropout = kwargs.get("module_dropout", None)
482
+ if module_dropout is not None:
483
+ module_dropout = float(module_dropout)
484
+
485
+ # quite a lot of arguments here ( ^ω^)...
486
+ network = LoRANetwork(
487
+ text_encoder,
488
+ unet,
489
+ multiplier=multiplier,
490
+ lora_dim=network_dim,
491
+ alpha=network_alpha,
492
+ dropout=neuron_dropout,
493
+ rank_dropout=rank_dropout,
494
+ module_dropout=module_dropout,
495
+ conv_lora_dim=conv_dim,
496
+ conv_alpha=conv_alpha,
497
+ block_dims=block_dims,
498
+ block_alphas=block_alphas,
499
+ conv_block_dims=conv_block_dims,
500
+ conv_block_alphas=conv_block_alphas,
501
+ varbose=True,
502
+ )
503
+
504
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
505
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
506
+
507
+ return network
508
+
509
+
510
+ # keep in mind that this method may be called from outside
511
+ # network_dim and network_alpha already contain default values
512
+ # block_dims and block_alphas are either both None or both set
513
+ # conv_dim and conv_alpha are either both None or both set
514
+ def get_block_dims_and_alphas(
515
+ block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha
516
+ ):
517
+ num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1
518
+
519
+ def parse_ints(s):
520
+ return [int(i) for i in s.split(",")]
521
+
522
+ def parse_floats(s):
523
+ return [float(i) for i in s.split(",")]
524
+
525
+ # parse block_dims and block_alphas; both always end up with values
526
+ if block_dims is not None:
527
+ block_dims = parse_ints(block_dims)
528
+ assert (
529
+ len(block_dims) == num_total_blocks
530
+ ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
531
+ else:
532
+ logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
533
+ block_dims = [network_dim] * num_total_blocks
534
+
535
+ if block_alphas is not None:
536
+ block_alphas = parse_floats(block_alphas)
537
+ assert (
538
+ len(block_alphas) == num_total_blocks
539
+ ), f"block_alphas must have {num_total_blocks} elements / block_alphasは{num_total_blocks}個指定してください"
540
+ else:
541
+ logger.warning(
542
+ f"block_alphas is not specified. all alphas are set to {network_alpha} / block_alphasが指定されていません。すべてのalphaは{network_alpha}になります"
543
+ )
544
+ block_alphas = [network_alpha] * num_total_blocks
545
+
546
+ # parse conv_block_dims and conv_block_alphas only if specified; otherwise use conv_dim and conv_alpha
547
+ if conv_block_dims is not None:
548
+ conv_block_dims = parse_ints(conv_block_dims)
549
+ assert (
550
+ len(conv_block_dims) == num_total_blocks
551
+ ), f"conv_block_dims must have {num_total_blocks} elements / conv_block_dimsは{num_total_blocks}個指定してください"
552
+
553
+ if conv_block_alphas is not None:
554
+ conv_block_alphas = parse_floats(conv_block_alphas)
555
+ assert (
556
+ len(conv_block_alphas) == num_total_blocks
557
+ ), f"conv_block_alphas must have {num_total_blocks} elements / conv_block_alphasは{num_total_blocks}個指定してください"
558
+ else:
559
+ if conv_alpha is None:
560
+ conv_alpha = 1.0
561
+ logger.warning(
562
+ f"conv_block_alphas is not specified. all alphas are set to {conv_alpha} / conv_block_alphasが指定されていません。すべてのalphaは{conv_alpha}になります"
563
+ )
564
+ conv_block_alphas = [conv_alpha] * num_total_blocks
565
+ else:
566
+ if conv_dim is not None:
567
+ logger.warning(
568
+ f"conv_dim/alpha for all blocks are set to {conv_dim} and {conv_alpha} / すべてのブロックのconv_dimとalphaは{conv_dim}および{conv_alpha}になります"
569
+ )
570
+ conv_block_dims = [conv_dim] * num_total_blocks
571
+ conv_block_alphas = [conv_alpha] * num_total_blocks
572
+ else:
573
+ conv_block_dims = None
574
+ conv_block_alphas = None
575
+
576
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
577
+
578
+
579
+ # define per-block multipliers for layer-wise learning rates; this may be called from outside
580
+ def get_block_lr_weight(
581
+ down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold
582
+ ) -> Tuple[List[float], List[float], List[float]]:
583
+ # if no parameters are specified, do nothing and keep the previous behavior
584
+ if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None:
585
+ return None, None, None
586
+
587
+ max_len = LoRANetwork.NUM_OF_BLOCKS # number of up/down blocks in the full model
588
+
589
+ def get_list(name_with_suffix) -> List[float]:
590
+ import math
591
+
592
+ tokens = name_with_suffix.split("+")
593
+ name = tokens[0]
594
+ base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0
595
+
596
+ if name == "cosine":
597
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))]
598
+ elif name == "sine":
599
+ return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)]
600
+ elif name == "linear":
601
+ return [i / (max_len - 1) + base_lr for i in range(max_len)]
602
+ elif name == "reverse_linear":
603
+ return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))]
604
+ elif name == "zeros":
605
+ return [0.0 + base_lr] * max_len
606
+ else:
607
+ logger.error(
608
+ "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros"
609
+ % (name)
610
+ )
611
+ return None
612
+
613
+ if type(down_lr_weight) == str:
614
+ down_lr_weight = get_list(down_lr_weight)
615
+ if type(up_lr_weight) == str:
616
+ up_lr_weight = get_list(up_lr_weight)
617
+
618
+ if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
619
+ logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len)
620
+ logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len)
621
+ up_lr_weight = up_lr_weight[:max_len]
622
+ down_lr_weight = down_lr_weight[:max_len]
623
+
624
+ if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
625
+ logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len)
626
+ logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len)
627
+
628
+ if down_lr_weight != None and len(down_lr_weight) < max_len:
629
+ down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
630
+ if up_lr_weight != None and len(up_lr_weight) < max_len:
631
+ up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
632
+
633
+ if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
634
+ logger.info("apply block learning rate / 階層別学習率を適用します。")
635
+ if down_lr_weight != None:
636
+ down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight]
637
+ logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}")
638
+ else:
639
+ logger.info("down_lr_weight: all 1.0, すべて1.0")
640
+
641
+ if mid_lr_weight != None:
642
+ mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
643
+ logger.info(f"mid_lr_weight: {mid_lr_weight}")
644
+ else:
645
+ logger.info("mid_lr_weight: 1.0")
646
+
647
+ if up_lr_weight != None:
648
+ up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight]
649
+ logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}")
650
+ else:
651
+ logger.info("up_lr_weight: all 1.0, すべて1.0")
652
+
653
+ return down_lr_weight, mid_lr_weight, up_lr_weight
654
+
655
+
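The string presets handled by get_block_lr_weight above ("cosine", "sine", "linear", "reverse_linear", "zeros", optionally with a "+<base>" suffix) expand to NUM_OF_BLOCKS per-block multipliers, and anything at or below block_lr_zero_threshold is forced to zero, which later removes those blocks entirely. A small standalone sketch (editorial, not importing this module) of the "linear" preset:

# Editorial sketch (not part of the commit): expanding a block-lr preset the same way
# get_block_lr_weight's get_list does, then applying a zero threshold.
max_len = 12  # LoRANetwork.NUM_OF_BLOCKS

def linear_preset(base_lr: float = 0.0):
    return [i / (max_len - 1) + base_lr for i in range(max_len)]

weights = linear_preset()
zero_threshold = 0.1
weights = [w if w > zero_threshold else 0 for w in weights]
print(weights)  # the first two blocks are zeroed out, so their LoRA modules are dropped entirely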
656
+ # exclude blocks whose lr_weight is 0 from block_dims; this may be called from outside
657
+ def remove_block_dims_and_alphas(
658
+ block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight
659
+ ):
660
+ # set 0 to block dim without learning rate to remove the block
661
+ if down_lr_weight != None:
662
+ for i, lr in enumerate(down_lr_weight):
663
+ if lr == 0:
664
+ block_dims[i] = 0
665
+ if conv_block_dims is not None:
666
+ conv_block_dims[i] = 0
667
+ if mid_lr_weight != None:
668
+ if mid_lr_weight == 0:
669
+ block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
670
+ if conv_block_dims is not None:
671
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0
672
+ if up_lr_weight != None:
673
+ for i, lr in enumerate(up_lr_weight):
674
+ if lr == 0:
675
+ block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
676
+ if conv_block_dims is not None:
677
+ conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0
678
+
679
+ return block_dims, block_alphas, conv_block_dims, conv_block_alphas
680
+
681
+
682
+ # this may be called from outside
683
+ def get_block_index(lora_name: str) -> int:
684
+ block_idx = -1 # invalid lora name
685
+
686
+ m = RE_UPDOWN.search(lora_name)
687
+ if m:
688
+ g = m.groups()
689
+ i = int(g[1])
690
+ j = int(g[3])
691
+ if g[2] == "resnets":
692
+ idx = 3 * i + j
693
+ elif g[2] == "attentions":
694
+ idx = 3 * i + j
695
+ elif g[2] == "upsamplers" or g[2] == "downsamplers":
696
+ idx = 3 * i + 2
697
+
698
+ if g[0] == "down":
699
+ block_idx = 1 + idx # there is no LoRA corresponding to index 0
700
+ elif g[0] == "up":
701
+ block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx
702
+
703
+ elif "mid_block_" in lora_name:
704
+ block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12
705
+
706
+ return block_idx
707
+
708
+
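get_block_index above maps a lora_unet module name to a block index: down blocks land on 1-11 (index 0 is unused), the mid block on NUM_OF_BLOCKS (12), and up blocks from 13 upward. The expected results below are an editorial sketch derived from the formulas above; the module names are representative examples, not an exhaustive list.

# Editorial sketch (not part of the commit): expected get_block_index results for a few
# lora_unet module names, assuming the formulas in the function above.
examples = {
    "lora_unet_down_blocks_0_attentions_0_proj_in": 1,  # 1 + (3*0 + 0); index 0 is unused
    "lora_unet_down_blocks_1_downsamplers_0_conv": 6,   # 1 + (3*1 + 2)
    "lora_unet_mid_block_attentions_0_proj_out": 12,    # NUM_OF_BLOCKS
    "lora_unet_up_blocks_1_resnets_2_conv1": 18,        # 12 + 1 + (3*1 + 2)
}
for name, expected_idx in examples.items():
    print(f"{name} -> block {expected_idx}")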
709
+ # Create network from weights for inference; weights are not loaded here (because they may be merged instead)
710
+ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
711
+ if weights_sd is None:
712
+ if os.path.splitext(file)[1] == ".safetensors":
713
+ from safetensors.torch import load_file, safe_open
714
+
715
+ weights_sd = load_file(file)
716
+ else:
717
+ weights_sd = torch.load(file, map_location="cpu")
718
+
719
+ # get dim/alpha mapping
720
+ modules_dim = {}
721
+ modules_alpha = {}
722
+ for key, value in weights_sd.items():
723
+ if "." not in key:
724
+ continue
725
+
726
+ lora_name = key.split(".")[0]
727
+ if "alpha" in key:
728
+ modules_alpha[lora_name] = value
729
+ elif "lora_down" in key:
730
+ dim = value.size()[0]
731
+ modules_dim[lora_name] = dim
732
+ # logger.info(lora_name, value.size(), dim)
733
+
734
+ # support old LoRA without alpha
735
+ for key in modules_dim.keys():
736
+ if key not in modules_alpha:
737
+ modules_alpha[key] = modules_dim[key]
738
+
739
+ module_class = LoRAInfModule if for_inference else LoRAModule
740
+
741
+ network = LoRANetwork(
742
+ text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
743
+ )
744
+
745
+ # block lr
746
+ down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs)
747
+ if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
748
+ network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
749
+
750
+ return network, weights_sd
751
+
752
+
753
+ class LoRANetwork(torch.nn.Module):
754
+ NUM_OF_BLOCKS = 12 # number of up/down blocks in the full model
755
+
756
+ UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
757
+ UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
758
+ TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
759
+ LORA_PREFIX_UNET = "lora_unet"
760
+ LORA_PREFIX_TEXT_ENCODER = "lora_te"
761
+
762
+ # SDXL: must start with LORA_PREFIX_TEXT_ENCODER
763
+ LORA_PREFIX_TEXT_ENCODER1 = "lora_te1"
764
+ LORA_PREFIX_TEXT_ENCODER2 = "lora_te2"
765
+
766
+ def __init__(
767
+ self,
768
+ text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
769
+ unet,
770
+ multiplier: float = 1.0,
771
+ lora_dim: int = 4,
772
+ alpha: float = 1,
773
+ dropout: Optional[float] = None,
774
+ rank_dropout: Optional[float] = None,
775
+ module_dropout: Optional[float] = None,
776
+ conv_lora_dim: Optional[int] = None,
777
+ conv_alpha: Optional[float] = None,
778
+ block_dims: Optional[List[int]] = None,
779
+ block_alphas: Optional[List[float]] = None,
780
+ conv_block_dims: Optional[List[int]] = None,
781
+ conv_block_alphas: Optional[List[float]] = None,
782
+ modules_dim: Optional[Dict[str, int]] = None,
783
+ modules_alpha: Optional[Dict[str, int]] = None,
784
+ module_class: Type[object] = LoRAModule,
785
+ varbose: Optional[bool] = False,
786
+ ) -> None:
787
+ """
788
+ LoRA network: there are many arguments, but they follow these patterns:
789
+ 1. specify lora_dim and alpha
790
+ 2. specify lora_dim, alpha, conv_lora_dim and conv_alpha
791
+ 3. specify block_dims and block_alphas : not applied to Conv2d 3x3
792
+ 4. specify block_dims, block_alphas, conv_block_dims and conv_block_alphas : also applied to Conv2d 3x3
793
+ 5. specify modules_dim and modules_alpha (for inference)
794
+ """
795
+ super().__init__()
796
+ self.multiplier = multiplier
797
+
798
+ self.lora_dim = lora_dim
799
+ self.alpha = alpha
800
+ self.conv_lora_dim = conv_lora_dim
801
+ self.conv_alpha = conv_alpha
802
+ self.dropout = dropout
803
+ self.rank_dropout = rank_dropout
804
+ self.module_dropout = module_dropout
805
+
806
+ if modules_dim is not None:
807
+ logger.info(f"create LoRA network from weights")
808
+ elif block_dims is not None:
809
+ logger.info(f"create LoRA network from block_dims")
810
+ logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
811
+ logger.info(f"block_dims: {block_dims}")
812
+ logger.info(f"block_alphas: {block_alphas}")
813
+ if conv_block_dims is not None:
814
+ logger.info(f"conv_block_dims: {conv_block_dims}")
815
+ logger.info(f"conv_block_alphas: {conv_block_alphas}")
816
+ else:
817
+ logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
818
+ logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
819
+ if self.conv_lora_dim is not None:
820
+ logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
821
+
822
+ # create module instances
823
+ def create_modules(
824
+ is_unet: bool,
825
+ text_encoder_idx: Optional[int], # None, 1, 2
826
+ root_module: torch.nn.Module,
827
+ target_replace_modules: List[torch.nn.Module],
828
+ ) -> List[LoRAModule]:
829
+ prefix = (
830
+ self.LORA_PREFIX_UNET
831
+ if is_unet
832
+ else (
833
+ self.LORA_PREFIX_TEXT_ENCODER
834
+ if text_encoder_idx is None
835
+ else (self.LORA_PREFIX_TEXT_ENCODER1 if text_encoder_idx == 1 else self.LORA_PREFIX_TEXT_ENCODER2)
836
+ )
837
+ )
838
+ loras = []
839
+ skipped = []
840
+ for name, module in root_module.named_modules():
841
+ if module.__class__.__name__ in target_replace_modules:
842
+ for child_name, child_module in module.named_modules():
843
+ is_linear = child_module.__class__.__name__ == "Linear"
844
+ is_conv2d = child_module.__class__.__name__ == "Conv2d"
845
+ is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
846
+
847
+ if is_linear or is_conv2d:
848
+ lora_name = prefix + "." + name + "." + child_name
849
+ lora_name = lora_name.replace(".", "_")
850
+
851
+ dim = None
852
+ alpha = None
853
+
854
+ if modules_dim is not None:
855
+ # modules are explicitly specified
856
+ if lora_name in modules_dim:
857
+ dim = modules_dim[lora_name]
858
+ alpha = modules_alpha[lora_name]
859
+ elif is_unet and block_dims is not None:
860
+ # block_dims is specified for the U-Net
861
+ block_idx = get_block_index(lora_name)
862
+ if is_linear or is_conv2d_1x1:
863
+ dim = block_dims[block_idx]
864
+ alpha = block_alphas[block_idx]
865
+ elif conv_block_dims is not None:
866
+ dim = conv_block_dims[block_idx]
867
+ alpha = conv_block_alphas[block_idx]
868
+ else:
869
+ # normal case: target all modules
870
+ if is_linear or is_conv2d_1x1:
871
+ dim = self.lora_dim
872
+ alpha = self.alpha
873
+ elif self.conv_lora_dim is not None:
874
+ dim = self.conv_lora_dim
875
+ alpha = self.conv_alpha
876
+
877
+ if dim is None or dim == 0:
878
+ # record skipped modules so they can be reported
879
+ if is_linear or is_conv2d_1x1 or (self.conv_lora_dim is not None or conv_block_dims is not None):
880
+ skipped.append(lora_name)
881
+ continue
882
+
883
+ lora = module_class(
884
+ lora_name,
885
+ child_module,
886
+ self.multiplier,
887
+ dim,
888
+ alpha,
889
+ dropout=dropout,
890
+ rank_dropout=rank_dropout,
891
+ module_dropout=module_dropout,
892
+ )
893
+ loras.append(lora)
894
+ return loras, skipped
895
+
896
+ text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]
897
+
898
+ # create LoRA for text encoder
899
+ # creating all modules every time is wasteful; needs reconsideration
900
+ self.text_encoder_loras = []
901
+ skipped_te = []
902
+ for i, text_encoder in enumerate(text_encoders):
903
+ if len(text_encoders) > 1:
904
+ index = i + 1
905
+ logger.info(f"create LoRA for Text Encoder {index}:")
906
+ else:
907
+ index = None
908
+ logger.info(f"create LoRA for Text Encoder:")
909
+
910
+ text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
911
+ self.text_encoder_loras.extend(text_encoder_loras)
912
+ skipped_te += skipped
913
+ logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
914
+
915
+ # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
916
+ target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE
917
+ if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
918
+ target_modules += LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
919
+
920
+ self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
921
+ logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
922
+
923
+ skipped = skipped_te + skipped_un
924
+ if varbose and len(skipped) > 0:
925
+ logger.warning(
926
+ f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
927
+ )
928
+ for name in skipped:
929
+ logger.info(f"\t{name}")
930
+
931
+ self.up_lr_weight: List[float] = None
932
+ self.down_lr_weight: List[float] = None
933
+ self.mid_lr_weight: float = None
934
+ self.block_lr = False
935
+
936
+ # assertion
937
+ names = set()
938
+ for lora in self.text_encoder_loras + self.unet_loras:
939
+ assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
940
+ names.add(lora.lora_name)
941
+
942
+ def set_multiplier(self, multiplier):
943
+ self.multiplier = multiplier
944
+ for lora in self.text_encoder_loras + self.unet_loras:
945
+ lora.multiplier = self.multiplier
946
+
947
+ def load_weights(self, file):
948
+ if os.path.splitext(file)[1] == ".safetensors":
949
+ from safetensors.torch import load_file
950
+
951
+ weights_sd = load_file(file)
952
+ else:
953
+ weights_sd = torch.load(file, map_location="cpu")
954
+
955
+ info = self.load_state_dict(weights_sd, False)
956
+ return info
957
+
958
+ def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
959
+ if apply_text_encoder:
960
+ logger.info("enable LoRA for text encoder")
961
+ else:
962
+ self.text_encoder_loras = []
963
+
964
+ if apply_unet:
965
+ logger.info("enable LoRA for U-Net")
966
+ else:
967
+ self.unet_loras = []
968
+
969
+ for lora in self.text_encoder_loras + self.unet_loras:
970
+ lora.apply_to()
971
+ self.add_module(lora.lora_name, lora)
972
+
973
+ # return whether the weights can be merged
974
+ def is_mergeable(self):
975
+ return True
976
+
977
+ # TODO refactor to common function with apply_to
978
+ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
979
+ apply_text_encoder = apply_unet = False
980
+ for key in weights_sd.keys():
981
+ if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
982
+ apply_text_encoder = True
983
+ elif key.startswith(LoRANetwork.LORA_PREFIX_UNET):
984
+ apply_unet = True
985
+
986
+ if apply_text_encoder:
987
+ logger.info("enable LoRA for text encoder")
988
+ else:
989
+ self.text_encoder_loras = []
990
+
991
+ if apply_unet:
992
+ logger.info("enable LoRA for U-Net")
993
+ else:
994
+ self.unet_loras = []
995
+
996
+ for lora in self.text_encoder_loras + self.unet_loras:
997
+ sd_for_lora = {}
998
+ for key in weights_sd.keys():
999
+ if key.startswith(lora.lora_name):
1000
+ sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
1001
+ lora.merge_to(sd_for_lora, dtype, device)
1002
+
1003
+ logger.info(f"weights are merged")
1004
+
1005
+ # define per-block multipliers for layer-wise learning rates; the argument order is reversed, but leave it for now
1006
+ def set_block_lr_weight(
1007
+ self,
1008
+ up_lr_weight: List[float] = None,
1009
+ mid_lr_weight: float = None,
1010
+ down_lr_weight: List[float] = None,
1011
+ ):
1012
+ self.block_lr = True
1013
+ self.down_lr_weight = down_lr_weight
1014
+ self.mid_lr_weight = mid_lr_weight
1015
+ self.up_lr_weight = up_lr_weight
1016
+
1017
+ def get_lr_weight(self, lora: LoRAModule) -> float:
1018
+ lr_weight = 1.0
1019
+ block_idx = get_block_index(lora.lora_name)
1020
+ if block_idx < 0:
1021
+ return lr_weight
1022
+
1023
+ if block_idx < LoRANetwork.NUM_OF_BLOCKS:
1024
+ if self.down_lr_weight != None:
1025
+ lr_weight = self.down_lr_weight[block_idx]
1026
+ elif block_idx == LoRANetwork.NUM_OF_BLOCKS:
1027
+ if self.mid_lr_weight != None:
1028
+ lr_weight = self.mid_lr_weight
1029
+ elif block_idx > LoRANetwork.NUM_OF_BLOCKS:
1030
+ if self.up_lr_weight != None:
1031
+ lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1]
1032
+
1033
+ return lr_weight
1034
+
1035
+ # it might be good to allow separate learning rates for the two Text Encoders
1036
+ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
1037
+ self.requires_grad_(True)
1038
+ all_params = []
1039
+
1040
+ def enumerate_params(loras: List[LoRAModule]):
1041
+ params = []
1042
+ for lora in loras:
1043
+ # params.extend(lora.parameters())
1044
+ params.extend(lora.get_trainable_params())
1045
+ return params
1046
+
1047
+ if self.text_encoder_loras:
1048
+ param_data = {"params": enumerate_params(self.text_encoder_loras)}
1049
+ if text_encoder_lr is not None:
1050
+ param_data["lr"] = text_encoder_lr
1051
+ all_params.append(param_data)
1052
+
1053
+ if self.unet_loras:
1054
+ if self.block_lr:
1055
+ # group LoRA modules by block so the learning-rate graph can be drawn per block
1056
+ block_idx_to_lora = {}
1057
+ for lora in self.unet_loras:
1058
+ idx = get_block_index(lora.lora_name)
1059
+ if idx not in block_idx_to_lora:
1060
+ block_idx_to_lora[idx] = []
1061
+ block_idx_to_lora[idx].append(lora)
1062
+
1063
+ # set up a parameter group for each block
1064
+ for idx, block_loras in block_idx_to_lora.items():
1065
+ param_data = {"params": enumerate_params(block_loras)}
1066
+
1067
+ if unet_lr is not None:
1068
+ param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
1069
+ elif default_lr is not None:
1070
+ param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
1071
+ if ("lr" in param_data) and (param_data["lr"] == 0):
1072
+ continue
1073
+ all_params.append(param_data)
1074
+
1075
+ else:
1076
+ param_data = {"params": enumerate_params(self.unet_loras)}
1077
+ if unet_lr is not None:
1078
+ param_data["lr"] = unet_lr
1079
+ all_params.append(param_data)
1080
+
1081
+ return all_params
1082
+
1083
+ def enable_gradient_checkpointing(self):
1084
+ # not supported
1085
+ pass
1086
+
1087
+ def prepare_grad_etc(self, text_encoder, unet):
1088
+ self.requires_grad_(True)
1089
+
1090
+ def on_epoch_start(self, text_encoder, unet):
1091
+ self.train()
1092
+
1093
+ def get_trainable_params(self):
1094
+ return self.parameters()
1095
+
1096
+ def save_weights(self, file, dtype, metadata):
1097
+ if metadata is not None and len(metadata) == 0:
1098
+ metadata = None
1099
+
1100
+ state_dict = self.state_dict()
1101
+
1102
+ if dtype is not None:
1103
+ for key in list(state_dict.keys()):
1104
+ v = state_dict[key]
1105
+ v = v.detach().clone().to("cpu").to(dtype)
1106
+ state_dict[key] = v
1107
+
1108
+ if os.path.splitext(file)[1] == ".safetensors":
1109
+ from safetensors.torch import save_file
1110
+ from library import train_util
1111
+
1112
+ # Precalculate model hashes to save time on indexing
1113
+ if metadata is None:
1114
+ metadata = {}
1115
+ model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
1116
+ metadata["sshs_model_hash"] = model_hash
1117
+ metadata["sshs_legacy_hash"] = legacy_hash
1118
+
1119
+ save_file(state_dict, file, metadata)
1120
+ else:
1121
+ torch.save(state_dict, file)
1122
+
1123
+ # mask is a tensor with values from 0 to 1
1124
+ def set_region(self, sub_prompt_index, is_last_network, mask):
1125
+ if mask.max() == 0:
1126
+ mask = torch.ones_like(mask)
1127
+
1128
+ self.mask = mask
1129
+ self.sub_prompt_index = sub_prompt_index
1130
+ self.is_last_network = is_last_network
1131
+
1132
+ for lora in self.text_encoder_loras + self.unet_loras:
1133
+ lora.set_network(self)
1134
+
1135
+ def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
1136
+ self.batch_size = batch_size
1137
+ self.num_sub_prompts = num_sub_prompts
1138
+ self.current_size = (height, width)
1139
+ self.shared = shared
1140
+
1141
+ # create masks
1142
+ mask = self.mask
1143
+ mask_dic = {}
1144
+ mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
1145
+ ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
1146
+ dtype = ref_weight.dtype
1147
+ device = ref_weight.device
1148
+
1149
+ def resize_add(mh, mw):
1150
+ # logger.info(mh, mw, mh * mw)
1151
+ m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
1152
+ m = m.to(device, dtype=dtype)
1153
+ mask_dic[mh * mw] = m
1154
+
1155
+ h = height // 8
1156
+ w = width // 8
1157
+ for _ in range(4):
1158
+ resize_add(h, w)
1159
+ if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
1160
+ resize_add(h + h % 2, w + w % 2)
1161
+ h = (h + 1) // 2
1162
+ w = (w + 1) // 2
1163
+
1164
+ self.mask_dic = mask_dic
1165
+
1166
+ def backup_weights(self):
1167
+ # back up the original weights
1168
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1169
+ for lora in loras:
1170
+ org_module = lora.org_module_ref[0]
1171
+ if not hasattr(org_module, "_lora_org_weight"):
1172
+ sd = org_module.state_dict()
1173
+ org_module._lora_org_weight = sd["weight"].detach().clone()
1174
+ org_module._lora_restored = True
1175
+
1176
+ def restore_weights(self):
1177
+ # restore the original weights
1178
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1179
+ for lora in loras:
1180
+ org_module = lora.org_module_ref[0]
1181
+ if not org_module._lora_restored:
1182
+ sd = org_module.state_dict()
1183
+ sd["weight"] = org_module._lora_org_weight
1184
+ org_module.load_state_dict(sd)
1185
+ org_module._lora_restored = True
1186
+
1187
+ def pre_calculation(self):
1188
+ # pre-calculate: merge the LoRA weights into the original modules
1189
+ loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
1190
+ for lora in loras:
1191
+ org_module = lora.org_module_ref[0]
1192
+ sd = org_module.state_dict()
1193
+
1194
+ org_weight = sd["weight"]
1195
+ lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
1196
+ sd["weight"] = org_weight + lora_weight
1197
+ assert sd["weight"].shape == org_weight.shape
1198
+ org_module.load_state_dict(sd)
1199
+
1200
+ org_module._lora_restored = False
1201
+ lora.enabled = False
1202
+
1203
+ def apply_max_norm_regularization(self, max_norm_value, device):
1204
+ downkeys = []
1205
+ upkeys = []
1206
+ alphakeys = []
1207
+ norms = []
1208
+ keys_scaled = 0
1209
+
1210
+ state_dict = self.state_dict()
1211
+ for key in state_dict.keys():
1212
+ if "lora_down" in key and "weight" in key:
1213
+ downkeys.append(key)
1214
+ upkeys.append(key.replace("lora_down", "lora_up"))
1215
+ alphakeys.append(key.replace("lora_down.weight", "alpha"))
1216
+
1217
+ for i in range(len(downkeys)):
1218
+ down = state_dict[downkeys[i]].to(device)
1219
+ up = state_dict[upkeys[i]].to(device)
1220
+ alpha = state_dict[alphakeys[i]].to(device)
1221
+ dim = down.shape[0]
1222
+ scale = alpha / dim
1223
+
1224
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
1225
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
1226
+ elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
1227
+ updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
1228
+ else:
1229
+ updown = up @ down
1230
+
1231
+ updown *= scale
1232
+
1233
+ norm = updown.norm().clamp(min=max_norm_value / 2)
1234
+ desired = torch.clamp(norm, max=max_norm_value)
1235
+ ratio = desired.cpu() / norm.cpu()
1236
+ sqrt_ratio = ratio**0.5
1237
+ if ratio != 1:
1238
+ keys_scaled += 1
1239
+ state_dict[upkeys[i]] *= sqrt_ratio
1240
+ state_dict[downkeys[i]] *= sqrt_ratio
1241
+ scalednorm = updown.norm() * ratio
1242
+ norms.append(scalednorm.item())
1243
+
1244
+ return keys_scaled, sum(norms) / len(norms), max(norms)
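
Note on apply_max_norm_regularization above: the merged LoRA delta is (alpha / dim) * up @ down, so multiplying both the up and down weights by sqrt(ratio) scales that product, and hence its norm, by ratio, clamping it at max_norm_value. A minimal sketch of this identity with toy tensors (up, down and ratio are made-up values, not taken from the code above):

import torch

torch.manual_seed(0)
up = torch.randn(8, 4)     # stand-in for a lora_up weight
down = torch.randn(4, 16)  # stand-in for a lora_down weight
ratio = 0.5                # stand-in for desired / norm computed above

# Scaling both factors by sqrt(ratio) scales the merged update by ratio.
scaled = (up * ratio**0.5) @ (down * ratio**0.5)
assert torch.allclose(scaled, (up @ down) * ratio, atol=1e-5)
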
lora_interrogator.py ADDED
@@ -0,0 +1,146 @@
1
+
2
+
3
+ from tqdm import tqdm
4
+ from library import model_util
5
+ import library.train_util as train_util
6
+ import argparse
7
+ from transformers import CLIPTokenizer
8
+
9
+ import torch
10
+ from library.device_utils import init_ipex, get_preferred_device
11
+ init_ipex()
12
+
13
+ import library.model_util as model_util
14
+ import lora
15
+ from library.utils import setup_logging
16
+ setup_logging()
17
+ import logging
18
+ logger = logging.getLogger(__name__)
19
+
20
+ TOKENIZER_PATH = "openai/clip-vit-large-patch14"
21
+ V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from here
22
+
23
+ DEVICE = get_preferred_device()
24
+
25
+
26
+ def interrogate(args):
27
+ weights_dtype = torch.float16
28
+
29
+ # prepare everything
30
+ logger.info(f"loading SD model: {args.sd_model}")
31
+ args.pretrained_model_name_or_path = args.sd_model
32
+ args.vae = None
33
+ text_encoder, vae, unet, _ = train_util._load_target_model(args, weights_dtype, DEVICE)
34
+
35
+ logger.info(f"loading LoRA: {args.model}")
36
+ network, weights_sd = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
37
+
38
+ # check whether the LoRA has weights for the text encoder; ideally this should be done on the lora side
39
+ has_te_weight = False
40
+ for key in weights_sd.keys():
41
+ if 'lora_te' in key:
42
+ has_te_weight = True
43
+ break
44
+ if not has_te_weight:
45
+ logger.error("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
46
+ return
47
+ del vae
48
+
49
+ logger.info("loading tokenizer")
50
+ if args.v2:
51
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
52
+ else:
53
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
54
+
55
+ text_encoder.to(DEVICE, dtype=weights_dtype)
56
+ text_encoder.eval()
57
+ unet.to(DEVICE, dtype=weights_dtype)
58
+ unet.eval() # not strictly necessary, since the U-Net is never called
59
+
60
+ # examine the tokens one by one
61
+ token_id_start = 0
62
+ token_id_end = max(tokenizer.all_special_ids)
63
+ logger.info(f"interrogate tokens are: {token_id_start} to {token_id_end}")
64
+
65
+ def get_all_embeddings(text_encoder):
66
+ embs = []
67
+ with torch.no_grad():
68
+ for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
69
+ batch = []
70
+ for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
71
+ tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
72
+ # tokens = [tid] # this variant gives somewhat poorer results
73
+ batch.append(tokens)
74
+
75
+ # batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # including bos/eos seems to make the differences clearer [:, 1]
76
+ # support clip skip
77
+ batch = torch.tensor(batch).to(DEVICE)
78
+ if args.clip_skip is None:
79
+ encoder_hidden_states = text_encoder(batch)[0]
80
+ else:
81
+ enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
82
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
83
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
84
+ encoder_hidden_states = encoder_hidden_states.to("cpu")
85
+
86
+ embs.extend(encoder_hidden_states)
87
+ return torch.stack(embs)
88
+
89
+ logger.info("get original text encoder embeddings.")
90
+ orig_embs = get_all_embeddings(text_encoder)
91
+
92
+ network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
93
+ info = network.load_state_dict(weights_sd, strict=False)
94
+ logger.info(f"Loading LoRA weights: {info}")
95
+
96
+ network.to(DEVICE, dtype=weights_dtype)
97
+ network.eval()
98
+
99
+ del unet
100
+
101
+ logger.info("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
102
+ logger.info("get text encoder embeddings with lora.")
103
+ lora_embs = get_all_embeddings(text_encoder)
104
+
105
+ # compare: for now, simply use the mean absolute difference
106
+ logger.info("comparing...")
107
+ diffs = {}
108
+ for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
109
+ diff = torch.mean(torch.abs(orig_emb - lora_emb))
110
+ # diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # does not detect differences well
111
+ diff = float(diff.detach().to('cpu').numpy())
112
+ diffs[token_id_start + i] = diff
113
+
114
+ diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
115
+
116
+ # show the results
117
+ print("top 100:")
118
+ for i, (token, diff) in enumerate(diffs_sorted[:100]):
119
+ # if diff < 1e-6:
120
+ # break
121
+ string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
122
+ print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
123
+
124
+
125
+ def setup_parser() -> argparse.ArgumentParser:
126
+ parser = argparse.ArgumentParser()
127
+
128
+ parser.add_argument("--v2", action='store_true',
129
+ help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
130
+ parser.add_argument("--sd_model", type=str, default=None,
131
+ help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
132
+ parser.add_argument("--model", type=str, default=None,
133
+ help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
134
+ parser.add_argument("--batch_size", type=int, default=16,
135
+ help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
136
+ parser.add_argument("--clip_skip", type=int, default=None,
137
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
138
+
139
+ return parser
140
+
141
+
142
+ if __name__ == '__main__':
143
+ parser = setup_parser()
144
+
145
+ args = parser.parse_args()
146
+ interrogate(args)
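
For reference, a minimal sketch of how the interrogator above could be driven from Python (the checkpoint and LoRA paths and the clip_skip value are hypothetical placeholders, not taken from the diff):

# Equivalent to: python lora_interrogator.py --sd_model <ckpt> --model <lora> --clip_skip 2
parser = setup_parser()
args = parser.parse_args([
    "--sd_model", "sd-v1-5.safetensors",  # hypothetical SD checkpoint path
    "--model", "my_lora.safetensors",     # hypothetical LoRA file path
    "--clip_skip", "2",
])
interrogate(args)
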
lpw_stable_diffusion.py ADDED
@@ -0,0 +1,1233 @@
1
+ # copy from https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
2
+ # and modify to support SD2.x
3
+
4
+ import inspect
5
+ import re
6
+ from typing import Callable, List, Optional, Union
7
+
8
+ import numpy as np
9
+ import PIL.Image
10
+ import torch
11
+ from packaging import version
12
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
13
+
14
+ import diffusers
15
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
16
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
17
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
18
+ from diffusers.utils import logging
19
+
20
+ try:
21
+ from diffusers.utils import PIL_INTERPOLATION
22
+ except ImportError:
23
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
24
+ PIL_INTERPOLATION = {
25
+ "linear": PIL.Image.Resampling.BILINEAR,
26
+ "bilinear": PIL.Image.Resampling.BILINEAR,
27
+ "bicubic": PIL.Image.Resampling.BICUBIC,
28
+ "lanczos": PIL.Image.Resampling.LANCZOS,
29
+ "nearest": PIL.Image.Resampling.NEAREST,
30
+ }
31
+ else:
32
+ PIL_INTERPOLATION = {
33
+ "linear": PIL.Image.LINEAR,
34
+ "bilinear": PIL.Image.BILINEAR,
35
+ "bicubic": PIL.Image.BICUBIC,
36
+ "lanczos": PIL.Image.LANCZOS,
37
+ "nearest": PIL.Image.NEAREST,
38
+ }
39
+ # ------------------------------------------------------------------------------
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+ re_attention = re.compile(
44
+ r"""
45
+ \\\(|
46
+ \\\)|
47
+ \\\[|
48
+ \\]|
49
+ \\\\|
50
+ \\|
51
+ \(|
52
+ \[|
53
+ :([+-]?[.\d]+)\)|
54
+ \)|
55
+ ]|
56
+ [^\\()\[\]:]+|
57
+ :
58
+ """,
59
+ re.X,
60
+ )
61
+
62
+
63
+ def parse_prompt_attention(text):
64
+ """
65
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
66
+ Accepted tokens are:
67
+ (abc) - increases attention to abc by a multiplier of 1.1
68
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
69
+ [abc] - decreases attention to abc by a multiplier of 1.1
70
+ \( - literal character '('
71
+ \[ - literal character '['
72
+ \) - literal character ')'
73
+ \] - literal character ']'
74
+ \\ - literal character '\'
75
+ anything else - just text
76
+ >>> parse_prompt_attention('normal text')
77
+ [['normal text', 1.0]]
78
+ >>> parse_prompt_attention('an (important) word')
79
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
80
+ >>> parse_prompt_attention('(unbalanced')
81
+ [['unbalanced', 1.1]]
82
+ >>> parse_prompt_attention('\(literal\]')
83
+ [['(literal]', 1.0]]
84
+ >>> parse_prompt_attention('(unnecessary)(parens)')
85
+ [['unnecessaryparens', 1.1]]
86
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
87
+ [['a ', 1.0],
88
+ ['house', 1.5730000000000004],
89
+ [' ', 1.1],
90
+ ['on', 1.0],
91
+ [' a ', 1.1],
92
+ ['hill', 0.55],
93
+ [', sun, ', 1.1],
94
+ ['sky', 1.4641000000000006],
95
+ ['.', 1.1]]
96
+ """
97
+
98
+ res = []
99
+ round_brackets = []
100
+ square_brackets = []
101
+
102
+ round_bracket_multiplier = 1.1
103
+ square_bracket_multiplier = 1 / 1.1
104
+
105
+ def multiply_range(start_position, multiplier):
106
+ for p in range(start_position, len(res)):
107
+ res[p][1] *= multiplier
108
+
109
+ for m in re_attention.finditer(text):
110
+ text = m.group(0)
111
+ weight = m.group(1)
112
+
113
+ if text.startswith("\\"):
114
+ res.append([text[1:], 1.0])
115
+ elif text == "(":
116
+ round_brackets.append(len(res))
117
+ elif text == "[":
118
+ square_brackets.append(len(res))
119
+ elif weight is not None and len(round_brackets) > 0:
120
+ multiply_range(round_brackets.pop(), float(weight))
121
+ elif text == ")" and len(round_brackets) > 0:
122
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
123
+ elif text == "]" and len(square_brackets) > 0:
124
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
125
+ else:
126
+ res.append([text, 1.0])
127
+
128
+ for pos in round_brackets:
129
+ multiply_range(pos, round_bracket_multiplier)
130
+
131
+ for pos in square_brackets:
132
+ multiply_range(pos, square_bracket_multiplier)
133
+
134
+ if len(res) == 0:
135
+ res = [["", 1.0]]
136
+
137
+ # merge runs of identical weights
138
+ i = 0
139
+ while i + 1 < len(res):
140
+ if res[i][1] == res[i + 1][1]:
141
+ res[i][0] += res[i + 1][0]
142
+ res.pop(i + 1)
143
+ else:
144
+ i += 1
145
+
146
+ return res
147
+
148
+
149
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
150
+ r"""
151
+ Tokenize a list of prompts and return their tokens along with the weight of each token.
152
+
153
+ No padding, starting or ending token is included.
154
+ """
155
+ tokens = []
156
+ weights = []
157
+ truncated = False
158
+ for text in prompt:
159
+ texts_and_weights = parse_prompt_attention(text)
160
+ text_token = []
161
+ text_weight = []
162
+ for word, weight in texts_and_weights:
163
+ # tokenize and discard the starting and the ending token
164
+ token = pipe.tokenizer(word).input_ids[1:-1]
165
+ text_token += token
166
+ # copy the weight by length of token
167
+ text_weight += [weight] * len(token)
168
+ # stop if the text is too long (longer than truncation limit)
169
+ if len(text_token) > max_length:
170
+ truncated = True
171
+ break
172
+ # truncate
173
+ if len(text_token) > max_length:
174
+ truncated = True
175
+ text_token = text_token[:max_length]
176
+ text_weight = text_weight[:max_length]
177
+ tokens.append(text_token)
178
+ weights.append(text_weight)
179
+ if truncated:
180
+ logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
181
+ return tokens, weights
182
+
183
+
184
+ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
185
+ r"""
186
+ Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
187
+ """
188
+ max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
189
+ weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
190
+ for i in range(len(tokens)):
191
+ tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
192
+ if no_boseos_middle:
193
+ weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
194
+ else:
195
+ w = []
196
+ if len(weights[i]) == 0:
197
+ w = [1.0] * weights_length
198
+ else:
199
+ for j in range(max_embeddings_multiples):
200
+ w.append(1.0) # weight for starting token in this chunk
201
+ w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
202
+ w.append(1.0) # weight for ending token in this chunk
203
+ w += [1.0] * (weights_length - len(w))
204
+ weights[i] = w[:]
205
+
206
+ return tokens, weights
207
+
208
+
209
+ def get_unweighted_text_embeddings(
210
+ pipe: StableDiffusionPipeline,
211
+ text_input: torch.Tensor,
212
+ chunk_length: int,
213
+ clip_skip: int,
214
+ eos: int,
215
+ pad: int,
216
+ no_boseos_middle: Optional[bool] = True,
217
+ ):
218
+ """
219
+ When the length of tokens is a multiple of the capacity of the text encoder,
220
+ it should be split into chunks and sent to the text encoder individually.
221
+ """
222
+ max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
223
+ if max_embeddings_multiples > 1:
224
+ text_embeddings = []
225
+ for i in range(max_embeddings_multiples):
226
+ # extract the i-th chunk
227
+ text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
228
+
229
+ # cover the head and the tail by the starting and the ending tokens
230
+ text_input_chunk[:, 0] = text_input[0, 0]
231
+ if pad == eos: # v1
232
+ text_input_chunk[:, -1] = text_input[0, -1]
233
+ else: # v2
234
+ for j in range(len(text_input_chunk)):
235
+ if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # an ordinary token is at the end
236
+ text_input_chunk[j, -1] = eos
237
+ if text_input_chunk[j, 1] == pad: # only BOS, the rest is PAD
238
+ text_input_chunk[j, 1] = eos
239
+
240
+ if clip_skip is None or clip_skip == 1:
241
+ text_embedding = pipe.text_encoder(text_input_chunk)[0]
242
+ else:
243
+ enc_out = pipe.text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
244
+ text_embedding = enc_out["hidden_states"][-clip_skip]
245
+ text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
246
+
247
+ if no_boseos_middle:
248
+ if i == 0:
249
+ # discard the ending token
250
+ text_embedding = text_embedding[:, :-1]
251
+ elif i == max_embeddings_multiples - 1:
252
+ # discard the starting token
253
+ text_embedding = text_embedding[:, 1:]
254
+ else:
255
+ # discard both starting and ending tokens
256
+ text_embedding = text_embedding[:, 1:-1]
257
+
258
+ text_embeddings.append(text_embedding)
259
+ text_embeddings = torch.concat(text_embeddings, axis=1)
260
+ else:
261
+ if clip_skip is None or clip_skip == 1:
262
+ text_embeddings = pipe.text_encoder(text_input)[0]
263
+ else:
264
+ enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
265
+ text_embeddings = enc_out["hidden_states"][-clip_skip]
266
+ text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
267
+ return text_embeddings
268
+
269
+
270
+ def get_weighted_text_embeddings(
271
+ pipe: StableDiffusionPipeline,
272
+ prompt: Union[str, List[str]],
273
+ uncond_prompt: Optional[Union[str, List[str]]] = None,
274
+ max_embeddings_multiples: Optional[int] = 3,
275
+ no_boseos_middle: Optional[bool] = False,
276
+ skip_parsing: Optional[bool] = False,
277
+ skip_weighting: Optional[bool] = False,
278
+ clip_skip=None,
279
+ ):
280
+ r"""
281
+ Prompts can be assigned with local weights using brackets. For example,
282
+ prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
283
+ and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
284
+
285
+ Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
286
+
287
+ Args:
288
+ pipe (`StableDiffusionPipeline`):
289
+ Pipe to provide access to the tokenizer and the text encoder.
290
+ prompt (`str` or `List[str]`):
291
+ The prompt or prompts to guide the image generation.
292
+ uncond_prompt (`str` or `List[str]`):
293
+ The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
294
+ is provided, the embeddings of prompt and uncond_prompt are concatenated.
295
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
296
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
297
+ no_boseos_middle (`bool`, *optional*, defaults to `False`):
298
+ If the length of the text tokens is a multiple of the text encoder's capacity, whether to keep the starting and
299
+ ending tokens in each chunk in the middle.
300
+ skip_parsing (`bool`, *optional*, defaults to `False`):
301
+ Skip the parsing of brackets.
302
+ skip_weighting (`bool`, *optional*, defaults to `False`):
303
+ Skip the weighting. When parsing is skipped, this is forced to True.
304
+ """
305
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
306
+ if isinstance(prompt, str):
307
+ prompt = [prompt]
308
+
309
+ if not skip_parsing:
310
+ prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
311
+ if uncond_prompt is not None:
312
+ if isinstance(uncond_prompt, str):
313
+ uncond_prompt = [uncond_prompt]
314
+ uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2)
315
+ else:
316
+ prompt_tokens = [token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids]
317
+ prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
318
+ if uncond_prompt is not None:
319
+ if isinstance(uncond_prompt, str):
320
+ uncond_prompt = [uncond_prompt]
321
+ uncond_tokens = [
322
+ token[1:-1] for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids
323
+ ]
324
+ uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
325
+
326
+ # round up the longest length of tokens to a multiple of (model_max_length - 2)
327
+ max_length = max([len(token) for token in prompt_tokens])
328
+ if uncond_prompt is not None:
329
+ max_length = max(max_length, max([len(token) for token in uncond_tokens]))
330
+
331
+ max_embeddings_multiples = min(
332
+ max_embeddings_multiples,
333
+ (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
334
+ )
335
+ max_embeddings_multiples = max(1, max_embeddings_multiples)
336
+ max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
337
+
338
+ # pad the length of tokens and weights
339
+ bos = pipe.tokenizer.bos_token_id
340
+ eos = pipe.tokenizer.eos_token_id
341
+ pad = pipe.tokenizer.pad_token_id
342
+ prompt_tokens, prompt_weights = pad_tokens_and_weights(
343
+ prompt_tokens,
344
+ prompt_weights,
345
+ max_length,
346
+ bos,
347
+ eos,
348
+ no_boseos_middle=no_boseos_middle,
349
+ chunk_length=pipe.tokenizer.model_max_length,
350
+ )
351
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
352
+ if uncond_prompt is not None:
353
+ uncond_tokens, uncond_weights = pad_tokens_and_weights(
354
+ uncond_tokens,
355
+ uncond_weights,
356
+ max_length,
357
+ bos,
358
+ eos,
359
+ no_boseos_middle=no_boseos_middle,
360
+ chunk_length=pipe.tokenizer.model_max_length,
361
+ )
362
+ uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
363
+
364
+ # get the embeddings
365
+ text_embeddings = get_unweighted_text_embeddings(
366
+ pipe,
367
+ prompt_tokens,
368
+ pipe.tokenizer.model_max_length,
369
+ clip_skip,
370
+ eos,
371
+ pad,
372
+ no_boseos_middle=no_boseos_middle,
373
+ )
374
+ prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
375
+ if uncond_prompt is not None:
376
+ uncond_embeddings = get_unweighted_text_embeddings(
377
+ pipe,
378
+ uncond_tokens,
379
+ pipe.tokenizer.model_max_length,
380
+ clip_skip,
381
+ eos,
382
+ pad,
383
+ no_boseos_middle=no_boseos_middle,
384
+ )
385
+ uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
386
+
387
+ # assign weights to the prompts and normalize in the sense of mean
388
+ # TODO: should we normalize by chunk or in a whole (current implementation)?
389
+ if (not skip_parsing) and (not skip_weighting):
390
+ previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
391
+ text_embeddings *= prompt_weights.unsqueeze(-1)
392
+ current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
393
+ text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
394
+ if uncond_prompt is not None:
395
+ previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
396
+ uncond_embeddings *= uncond_weights.unsqueeze(-1)
397
+ current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
398
+ uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
399
+
400
+ if uncond_prompt is not None:
401
+ return text_embeddings, uncond_embeddings
402
+ return text_embeddings, None
403
+
404
+
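
As a quick illustration of the mean-preserving rescale performed by get_weighted_text_embeddings above: after each token embedding is multiplied by its weight, the result is rescaled so the per-prompt mean matches the unweighted embeddings. A minimal sketch with toy tensors (the shapes and weight values are made up, not taken from the pipeline):

import torch

emb = torch.arange(12, dtype=torch.float32).reshape(1, 4, 3) + 1.0  # (batch, tokens, dim) toy embeddings
weights = torch.tensor([[1.0, 1.1, 1.0, 0.9]])                      # toy per-token weights

previous_mean = emb.float().mean(axis=[-2, -1])
weighted = emb * weights.unsqueeze(-1)
current_mean = weighted.float().mean(axis=[-2, -1])
weighted = weighted * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)

# The per-prompt mean is unchanged by the weighting.
assert torch.allclose(weighted.mean(axis=[-2, -1]), previous_mean, atol=1e-5)
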
405
+ def preprocess_image(image):
406
+ w, h = image.size
407
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
408
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
409
+ image = np.array(image).astype(np.float32) / 255.0
410
+ image = image[None].transpose(0, 3, 1, 2)
411
+ image = torch.from_numpy(image)
412
+ return 2.0 * image - 1.0
413
+
414
+
415
+ def preprocess_mask(mask, scale_factor=8):
416
+ mask = mask.convert("L")
417
+ w, h = mask.size
418
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
419
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
420
+ mask = np.array(mask).astype(np.float32) / 255.0
421
+ mask = np.tile(mask, (4, 1, 1))
422
+ mask = mask[None].transpose(0, 1, 2, 3) # identity transpose; the batch dimension was already added by mask[None]
423
+ mask = 1 - mask # repaint white, keep black
424
+ mask = torch.from_numpy(mask)
425
+ return mask
426
+
427
+
428
+ def prepare_controlnet_image(
429
+ image: PIL.Image.Image,
430
+ width: int,
431
+ height: int,
432
+ batch_size: int,
433
+ num_images_per_prompt: int,
434
+ device: torch.device,
435
+ dtype: torch.dtype,
436
+ do_classifier_free_guidance: bool = False,
437
+ guess_mode: bool = False,
438
+ ):
439
+ if not isinstance(image, torch.Tensor):
440
+ if isinstance(image, PIL.Image.Image):
441
+ image = [image]
442
+
443
+ if isinstance(image[0], PIL.Image.Image):
444
+ images = []
445
+
446
+ for image_ in image:
447
+ image_ = image_.convert("RGB")
448
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
449
+ image_ = np.array(image_)
450
+ image_ = image_[None, :]
451
+ images.append(image_)
452
+
453
+ image = images
454
+
455
+ image = np.concatenate(image, axis=0)
456
+ image = np.array(image).astype(np.float32) / 255.0
457
+ image = image.transpose(0, 3, 1, 2)
458
+ image = torch.from_numpy(image)
459
+ elif isinstance(image[0], torch.Tensor):
460
+ image = torch.cat(image, dim=0)
461
+
462
+ image_batch_size = image.shape[0]
463
+
464
+ if image_batch_size == 1:
465
+ repeat_by = batch_size
466
+ else:
467
+ # image batch size is the same as prompt batch size
468
+ repeat_by = num_images_per_prompt
469
+
470
+ image = image.repeat_interleave(repeat_by, dim=0)
471
+
472
+ image = image.to(device=device, dtype=dtype)
473
+
474
+ if do_classifier_free_guidance and not guess_mode:
475
+ image = torch.cat([image] * 2)
476
+
477
+ return image
478
+
479
+
480
+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
481
+ r"""
482
+ Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for parsing
483
+ weights in the prompt.
484
+
485
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
486
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
487
+
488
+ Args:
489
+ vae ([`AutoencoderKL`]):
490
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
491
+ text_encoder ([`CLIPTextModel`]):
492
+ Frozen text-encoder. Stable Diffusion uses the text portion of
493
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
494
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
495
+ tokenizer (`CLIPTokenizer`):
496
+ Tokenizer of class
497
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
498
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
499
+ scheduler ([`SchedulerMixin`]):
500
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
501
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
502
+ safety_checker ([`StableDiffusionSafetyChecker`]):
503
+ Classification module that estimates whether generated images could be considered offensive or harmful.
504
+ Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
505
+ feature_extractor ([`CLIPFeatureExtractor`]):
506
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
507
+ """
508
+
509
+ # if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"):
510
+
511
+ def __init__(
512
+ self,
513
+ vae: AutoencoderKL,
514
+ text_encoder: CLIPTextModel,
515
+ tokenizer: CLIPTokenizer,
516
+ unet: UNet2DConditionModel,
517
+ scheduler: SchedulerMixin,
518
+ # clip_skip: int,
519
+ safety_checker: StableDiffusionSafetyChecker,
520
+ feature_extractor: CLIPFeatureExtractor,
521
+ requires_safety_checker: bool = True,
522
+ image_encoder: CLIPVisionModelWithProjection = None,
523
+ clip_skip: int = 1,
524
+ ):
525
+ super().__init__(
526
+ vae=vae,
527
+ text_encoder=text_encoder,
528
+ tokenizer=tokenizer,
529
+ unet=unet,
530
+ scheduler=scheduler,
531
+ safety_checker=safety_checker,
532
+ feature_extractor=feature_extractor,
533
+ requires_safety_checker=requires_safety_checker,
534
+ image_encoder=image_encoder,
535
+ )
536
+ self.custom_clip_skip = clip_skip
537
+ self.__init__additional__()
538
+
539
+ def __init__additional__(self):
540
+ if not hasattr(self, "vae_scale_factor"):
541
+ setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
542
+
543
+ @property
544
+ def _execution_device(self):
545
+ r"""
546
+ Returns the device on which the pipeline's models will be executed. After calling
547
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
548
+ hooks.
549
+ """
550
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
551
+ return self.device
552
+ for module in self.unet.modules():
553
+ if (
554
+ hasattr(module, "_hf_hook")
555
+ and hasattr(module._hf_hook, "execution_device")
556
+ and module._hf_hook.execution_device is not None
557
+ ):
558
+ return torch.device(module._hf_hook.execution_device)
559
+ return self.device
560
+
561
+ def _encode_prompt(
562
+ self,
563
+ prompt,
564
+ device,
565
+ num_images_per_prompt,
566
+ do_classifier_free_guidance,
567
+ negative_prompt,
568
+ max_embeddings_multiples,
569
+ ):
570
+ r"""
571
+ Encodes the prompt into text encoder hidden states.
572
+
573
+ Args:
574
+ prompt (`str` or `list(int)`):
575
+ prompt to be encoded
576
+ device: (`torch.device`):
577
+ torch device
578
+ num_images_per_prompt (`int`):
579
+ number of images that should be generated per prompt
580
+ do_classifier_free_guidance (`bool`):
581
+ whether to use classifier free guidance or not
582
+ negative_prompt (`str` or `List[str]`):
583
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
584
+ if `guidance_scale` is less than `1`).
585
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
586
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
587
+ """
588
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
589
+
590
+ if negative_prompt is None:
591
+ negative_prompt = [""] * batch_size
592
+ elif isinstance(negative_prompt, str):
593
+ negative_prompt = [negative_prompt] * batch_size
594
+ if batch_size != len(negative_prompt):
595
+ raise ValueError(
596
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
597
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
598
+ " the batch size of `prompt`."
599
+ )
600
+
601
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
602
+ pipe=self,
603
+ prompt=prompt,
604
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
605
+ max_embeddings_multiples=max_embeddings_multiples,
606
+ clip_skip=self.custom_clip_skip,
607
+ )
608
+ bs_embed, seq_len, _ = text_embeddings.shape
609
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
610
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
611
+
612
+ if do_classifier_free_guidance:
613
+ bs_embed, seq_len, _ = uncond_embeddings.shape
614
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
615
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
616
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
617
+
618
+ return text_embeddings
619
+
620
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
621
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
622
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
623
+
624
+ if strength < 0 or strength > 1:
625
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
626
+
627
+ if height % 8 != 0 or width % 8 != 0:
628
+ logger.info(f'{height} {width}')
629
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
630
+
631
+ if (callback_steps is None) or (
632
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
633
+ ):
634
+ raise ValueError(
635
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}."
636
+ )
637
+
638
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
639
+ if is_text2img:
640
+ return self.scheduler.timesteps.to(device), num_inference_steps
641
+ else:
642
+ # get the original timestep using init_timestep
643
+ offset = self.scheduler.config.get("steps_offset", 0)
644
+ init_timestep = int(num_inference_steps * strength) + offset
645
+ init_timestep = min(init_timestep, num_inference_steps)
646
+
647
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
648
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
649
+ return timesteps, num_inference_steps - t_start
650
+
651
+ def run_safety_checker(self, image, device, dtype):
652
+ if self.safety_checker is not None:
653
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
654
+ image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
655
+ else:
656
+ has_nsfw_concept = None
657
+ return image, has_nsfw_concept
658
+
659
+ def decode_latents(self, latents):
660
+ latents = 1 / 0.18215 * latents
661
+ image = self.vae.decode(latents).sample
662
+ image = (image / 2 + 0.5).clamp(0, 1)
663
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
664
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
665
+ return image
666
+
667
+ def prepare_extra_step_kwargs(self, generator, eta):
668
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
669
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
670
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
671
+ # and should be between [0, 1]
672
+
673
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
674
+ extra_step_kwargs = {}
675
+ if accepts_eta:
676
+ extra_step_kwargs["eta"] = eta
677
+
678
+ # check if the scheduler accepts generator
679
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
680
+ if accepts_generator:
681
+ extra_step_kwargs["generator"] = generator
682
+ return extra_step_kwargs
683
+
684
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
685
+ if image is None:
686
+ shape = (
687
+ batch_size,
688
+ self.unet.in_channels,
689
+ height // self.vae_scale_factor,
690
+ width // self.vae_scale_factor,
691
+ )
692
+
693
+ if latents is None:
694
+ if device.type == "mps":
695
+ # randn does not work reproducibly on mps
696
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
697
+ else:
698
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
699
+ else:
700
+ if latents.shape != shape:
701
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
702
+ latents = latents.to(device)
703
+
704
+ # scale the initial noise by the standard deviation required by the scheduler
705
+ latents = latents * self.scheduler.init_noise_sigma
706
+ return latents, None, None
707
+ else:
708
+ init_latent_dist = self.vae.encode(image).latent_dist
709
+ init_latents = init_latent_dist.sample(generator=generator)
710
+ init_latents = 0.18215 * init_latents
711
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
712
+ init_latents_orig = init_latents
713
+ shape = init_latents.shape
714
+
715
+ # add noise to latents using the timesteps
716
+ if device.type == "mps":
717
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
718
+ else:
719
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
720
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
721
+ return latents, init_latents_orig, noise
722
+
723
+ @torch.no_grad()
724
+ def __call__(
725
+ self,
726
+ prompt: Union[str, List[str]],
727
+ negative_prompt: Optional[Union[str, List[str]]] = None,
728
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
729
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
730
+ height: int = 512,
731
+ width: int = 512,
732
+ num_inference_steps: int = 50,
733
+ guidance_scale: float = 7.5,
734
+ strength: float = 0.8,
735
+ num_images_per_prompt: Optional[int] = 1,
736
+ eta: float = 0.0,
737
+ generator: Optional[torch.Generator] = None,
738
+ latents: Optional[torch.FloatTensor] = None,
739
+ max_embeddings_multiples: Optional[int] = 3,
740
+ output_type: Optional[str] = "pil",
741
+ return_dict: bool = True,
742
+ controlnet=None,
743
+ controlnet_image=None,
744
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
745
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
746
+ callback_steps: int = 1,
747
+ ):
748
+ r"""
749
+ Function invoked when calling the pipeline for generation.
750
+
751
+ Args:
752
+ prompt (`str` or `List[str]`):
753
+ The prompt or prompts to guide the image generation.
754
+ negative_prompt (`str` or `List[str]`, *optional*):
755
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
756
+ if `guidance_scale` is less than `1`).
757
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
758
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
759
+ process.
760
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
761
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
762
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
763
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
764
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
765
+ height (`int`, *optional*, defaults to 512):
766
+ The height in pixels of the generated image.
767
+ width (`int`, *optional*, defaults to 512):
768
+ The width in pixels of the generated image.
769
+ num_inference_steps (`int`, *optional*, defaults to 50):
770
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
771
+ expense of slower inference.
772
+ guidance_scale (`float`, *optional*, defaults to 7.5):
773
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
774
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
775
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
776
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
777
+ usually at the expense of lower image quality.
778
+ strength (`float`, *optional*, defaults to 0.8):
779
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
780
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
781
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
782
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
783
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
784
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
785
+ The number of images to generate per prompt.
786
+ eta (`float`, *optional*, defaults to 0.0):
787
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
788
+ [`schedulers.DDIMScheduler`], will be ignored for others.
789
+ generator (`torch.Generator`, *optional*):
790
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
791
+ deterministic.
792
+ latents (`torch.FloatTensor`, *optional*):
793
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
794
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
795
+ tensor will be generated by sampling using the supplied random `generator`.
796
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
797
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
798
+ output_type (`str`, *optional*, defaults to `"pil"`):
799
+ The output format of the generate image. Choose between
800
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
801
+ return_dict (`bool`, *optional*, defaults to `True`):
802
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
803
+ plain tuple.
804
+ controlnet (`diffusers.ControlNetModel`, *optional*):
805
+ A controlnet model to be used for the inference. If not provided, controlnet will be disabled.
806
+ controlnet_image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*):
807
+ `Image`, or tensor representing an image batch, to be used as the starting point for the controlnet
808
+ inference.
809
+ callback (`Callable`, *optional*):
810
+ A function that will be called every `callback_steps` steps during inference. The function will be
811
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
812
+ is_cancelled_callback (`Callable`, *optional*):
813
+ A function that will be called every `callback_steps` steps during inference. If the function returns
814
+ `True`, the inference will be cancelled.
815
+ callback_steps (`int`, *optional*, defaults to 1):
816
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
817
+ called at every step.
818
+
819
+ Returns:
820
+ `None` if cancelled by `is_cancelled_callback`,
821
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
822
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
823
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
824
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
825
+ (nsfw) content, according to the `safety_checker`.
826
+ """
827
+ if controlnet is not None and controlnet_image is None:
828
+ raise ValueError("controlnet_image must be provided if controlnet is not None.")
829
+
830
+ # 0. Default height and width to unet
831
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
832
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
833
+
834
+ # 1. Check inputs. Raise error if not correct
835
+ self.check_inputs(prompt, height, width, strength, callback_steps)
836
+
837
+ # 2. Define call parameters
838
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
839
+ device = self._execution_device
840
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
841
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
842
+ # corresponds to doing no classifier free guidance.
843
+ do_classifier_free_guidance = guidance_scale > 1.0
844
+
845
+ # 3. Encode input prompt
846
+ text_embeddings = self._encode_prompt(
847
+ prompt,
848
+ device,
849
+ num_images_per_prompt,
850
+ do_classifier_free_guidance,
851
+ negative_prompt,
852
+ max_embeddings_multiples,
853
+ )
854
+ dtype = text_embeddings.dtype
855
+
856
+ # 4. Preprocess image and mask
857
+ if isinstance(image, PIL.Image.Image):
858
+ image = preprocess_image(image)
859
+ if image is not None:
860
+ image = image.to(device=self.device, dtype=dtype)
861
+ if isinstance(mask_image, PIL.Image.Image):
862
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
863
+ if mask_image is not None:
864
+ mask = mask_image.to(device=self.device, dtype=dtype)
865
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
866
+ else:
867
+ mask = None
868
+
869
+ if controlnet_image is not None:
870
+ controlnet_image = prepare_controlnet_image(
871
+ controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
872
+ )
873
+
874
+ # 5. set timesteps
875
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
876
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
877
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
878
+
879
+ # 6. Prepare latent variables
880
+ latents, init_latents_orig, noise = self.prepare_latents(
881
+ image,
882
+ latent_timestep,
883
+ batch_size * num_images_per_prompt,
884
+ height,
885
+ width,
886
+ dtype,
887
+ device,
888
+ generator,
889
+ latents,
890
+ )
891
+
892
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
893
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
894
+
895
+ # 8. Denoising loop
896
+ for i, t in enumerate(self.progress_bar(timesteps)):
897
+ # expand the latents if we are doing classifier free guidance
898
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
899
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
900
+
901
+ unet_additional_args = {}
902
+ if controlnet is not None:
903
+ down_block_res_samples, mid_block_res_sample = controlnet(
904
+ latent_model_input,
905
+ t,
906
+ encoder_hidden_states=text_embeddings,
907
+ controlnet_cond=controlnet_image,
908
+ conditioning_scale=1.0,
909
+ guess_mode=False,
910
+ return_dict=False,
911
+ )
912
+ unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
913
+ unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
914
+
915
+ # predict the noise residual
916
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
917
+
918
+ # perform guidance
919
+ if do_classifier_free_guidance:
920
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
921
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
922
+
923
+ # compute the previous noisy sample x_t -> x_t-1
924
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
925
+
926
+ if mask is not None:
927
+ # masking
928
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
929
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
930
+
931
+ # call the callback, if provided
932
+ if i % callback_steps == 0:
933
+ if callback is not None:
934
+ callback(i, t, latents)
935
+ if is_cancelled_callback is not None and is_cancelled_callback():
936
+ return None
937
+
938
+ return latents
939
+
940
+ def latents_to_image(self, latents):
941
+ # 9. Post-processing
942
+ image = self.decode_latents(latents.to(self.vae.dtype))
943
+ image = self.numpy_to_pil(image)
944
+ return image
945
+
946
+ def text2img(
947
+ self,
948
+ prompt: Union[str, List[str]],
949
+ negative_prompt: Optional[Union[str, List[str]]] = None,
950
+ height: int = 512,
951
+ width: int = 512,
952
+ num_inference_steps: int = 50,
953
+ guidance_scale: float = 7.5,
954
+ num_images_per_prompt: Optional[int] = 1,
955
+ eta: float = 0.0,
956
+ generator: Optional[torch.Generator] = None,
957
+ latents: Optional[torch.FloatTensor] = None,
958
+ max_embeddings_multiples: Optional[int] = 3,
959
+ output_type: Optional[str] = "pil",
960
+ return_dict: bool = True,
961
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
962
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
963
+ callback_steps: int = 1,
964
+ ):
965
+ r"""
966
+ Function for text-to-image generation.
967
+ Args:
968
+ prompt (`str` or `List[str]`):
969
+ The prompt or prompts to guide the image generation.
970
+ negative_prompt (`str` or `List[str]`, *optional*):
971
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
972
+ if `guidance_scale` is less than `1`).
973
+ height (`int`, *optional*, defaults to 512):
974
+ The height in pixels of the generated image.
975
+ width (`int`, *optional*, defaults to 512):
976
+ The width in pixels of the generated image.
977
+ num_inference_steps (`int`, *optional*, defaults to 50):
978
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
979
+ expense of slower inference.
980
+ guidance_scale (`float`, *optional*, defaults to 7.5):
981
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
982
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
983
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
984
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
985
+ usually at the expense of lower image quality.
986
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
987
+ The number of images to generate per prompt.
988
+ eta (`float`, *optional*, defaults to 0.0):
989
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
990
+ [`schedulers.DDIMScheduler`], will be ignored for others.
991
+ generator (`torch.Generator`, *optional*):
992
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
993
+ deterministic.
994
+ latents (`torch.FloatTensor`, *optional*):
995
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
996
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
997
+ tensor will be generated by sampling using the supplied random `generator`.
998
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
999
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1000
+ output_type (`str`, *optional*, defaults to `"pil"`):
1001
+ The output format of the generated image. Choose between
1002
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1003
+ return_dict (`bool`, *optional*, defaults to `True`):
1004
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1005
+ plain tuple.
1006
+ callback (`Callable`, *optional*):
1007
+ A function that will be called every `callback_steps` steps during inference. The function will be
1008
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1009
+ is_cancelled_callback (`Callable`, *optional*):
1010
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1011
+ `True`, the inference will be cancelled.
1012
+ callback_steps (`int`, *optional*, defaults to 1):
1013
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1014
+ called at every step.
1015
+ Returns:
1016
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1017
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1018
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1019
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1020
+ (nsfw) content, according to the `safety_checker`.
1021
+ """
1022
+ return self.__call__(
1023
+ prompt=prompt,
1024
+ negative_prompt=negative_prompt,
1025
+ height=height,
1026
+ width=width,
1027
+ num_inference_steps=num_inference_steps,
1028
+ guidance_scale=guidance_scale,
1029
+ num_images_per_prompt=num_images_per_prompt,
1030
+ eta=eta,
1031
+ generator=generator,
1032
+ latents=latents,
1033
+ max_embeddings_multiples=max_embeddings_multiples,
1034
+ output_type=output_type,
1035
+ return_dict=return_dict,
1036
+ callback=callback,
1037
+ is_cancelled_callback=is_cancelled_callback,
1038
+ callback_steps=callback_steps,
1039
+ )
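For illustration only (not part of the upstream diff): assuming `pipe` is an already-constructed instance of the pipeline class that defines `text2img` above and is already on the GPU, a minimal call might look like the sketch below. The prompt, seed, and output path are placeholders.

```python
import torch

# Minimal sketch of calling the text2img wrapper defined above.
# `pipe` is an assumed pipeline instance; it is not created in this diff.
output = pipe.text2img(
    prompt="a watercolor painting of a lighthouse at dusk",
    negative_prompt="low quality, blurry",
    width=512,
    height=512,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=torch.Generator(device="cuda").manual_seed(1234),
)
image = output.images[0]  # StableDiffusionPipelineOutput when return_dict=True
image.save("text2img_sample.png")
```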
1040
+
1041
+ def img2img(
1042
+ self,
1043
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1044
+ prompt: Union[str, List[str]],
1045
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1046
+ strength: float = 0.8,
1047
+ num_inference_steps: Optional[int] = 50,
1048
+ guidance_scale: Optional[float] = 7.5,
1049
+ num_images_per_prompt: Optional[int] = 1,
1050
+ eta: Optional[float] = 0.0,
1051
+ generator: Optional[torch.Generator] = None,
1052
+ max_embeddings_multiples: Optional[int] = 3,
1053
+ output_type: Optional[str] = "pil",
1054
+ return_dict: bool = True,
1055
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1056
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1057
+ callback_steps: int = 1,
1058
+ ):
1059
+ r"""
1060
+ Function for image-to-image generation.
1061
+ Args:
1062
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1063
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1064
+ process.
1065
+ prompt (`str` or `List[str]`):
1066
+ The prompt or prompts to guide the image generation.
1067
+ negative_prompt (`str` or `List[str]`, *optional*):
1068
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1069
+ if `guidance_scale` is less than `1`).
1070
+ strength (`float`, *optional*, defaults to 0.8):
1071
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1072
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1073
+ number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1074
+ noise will be maximum and the denoising process will run for the full number of iterations specified in
1075
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1076
+ num_inference_steps (`int`, *optional*, defaults to 50):
1077
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1078
+ expense of slower inference. This parameter will be modulated by `strength`.
1079
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1080
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1081
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1082
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1083
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1084
+ usually at the expense of lower image quality.
1085
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1086
+ The number of images to generate per prompt.
1087
+ eta (`float`, *optional*, defaults to 0.0):
1088
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1089
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1090
+ generator (`torch.Generator`, *optional*):
1091
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1092
+ deterministic.
1093
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1094
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1095
+ output_type (`str`, *optional*, defaults to `"pil"`):
1096
+ The output format of the generated image. Choose between
1097
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1098
+ return_dict (`bool`, *optional*, defaults to `True`):
1099
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1100
+ plain tuple.
1101
+ callback (`Callable`, *optional*):
1102
+ A function that will be called every `callback_steps` steps during inference. The function will be
1103
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1104
+ is_cancelled_callback (`Callable`, *optional*):
1105
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1106
+ `True`, the inference will be cancelled.
1107
+ callback_steps (`int`, *optional*, defaults to 1):
1108
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1109
+ called at every step.
1110
+ Returns:
1111
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1112
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1113
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1114
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1115
+ (nsfw) content, according to the `safety_checker`.
1116
+ """
1117
+ return self.__call__(
1118
+ prompt=prompt,
1119
+ negative_prompt=negative_prompt,
1120
+ image=image,
1121
+ num_inference_steps=num_inference_steps,
1122
+ guidance_scale=guidance_scale,
1123
+ strength=strength,
1124
+ num_images_per_prompt=num_images_per_prompt,
1125
+ eta=eta,
1126
+ generator=generator,
1127
+ max_embeddings_multiples=max_embeddings_multiples,
1128
+ output_type=output_type,
1129
+ return_dict=return_dict,
1130
+ callback=callback,
1131
+ is_cancelled_callback=is_cancelled_callback,
1132
+ callback_steps=callback_steps,
1133
+ )
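Similarly, a hedged sketch of the `img2img` wrapper, reusing the same assumed `pipe` instance; the file paths are placeholders.

```python
from PIL import Image

# Start from an existing picture and nudge it toward the prompt.
init_image = Image.open("input.png").convert("RGB")
output = pipe.img2img(
    image=init_image,
    prompt="the same scene in heavy snowfall",
    strength=0.6,             # lower values stay closer to the input image
    num_inference_steps=50,
    guidance_scale=7.5,
)
output.images[0].save("img2img_sample.png")
```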
1134
+
1135
+ def inpaint(
1136
+ self,
1137
+ image: Union[torch.FloatTensor, PIL.Image.Image],
1138
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1139
+ prompt: Union[str, List[str]],
1140
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1141
+ strength: float = 0.8,
1142
+ num_inference_steps: Optional[int] = 50,
1143
+ guidance_scale: Optional[float] = 7.5,
1144
+ num_images_per_prompt: Optional[int] = 1,
1145
+ eta: Optional[float] = 0.0,
1146
+ generator: Optional[torch.Generator] = None,
1147
+ max_embeddings_multiples: Optional[int] = 3,
1148
+ output_type: Optional[str] = "pil",
1149
+ return_dict: bool = True,
1150
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1151
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
1152
+ callback_steps: int = 1,
1153
+ ):
1154
+ r"""
1155
+ Function for inpaint.
1156
+ Args:
1157
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
1158
+ `Image`, or tensor representing an image batch, that will be used as the starting point for the
1159
+ process. This is the image whose masked region will be inpainted.
1160
+ mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1161
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1162
+ replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1163
+ PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1164
+ contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1165
+ prompt (`str` or `List[str]`):
1166
+ The prompt or prompts to guide the image generation.
1167
+ negative_prompt (`str` or `List[str]`, *optional*):
1168
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1169
+ if `guidance_scale` is less than `1`).
1170
+ strength (`float`, *optional*, defaults to 0.8):
1171
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1172
+ is 1, the denoising process will be run on the masked area for the full number of iterations specified
1173
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1174
+ noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1175
+ num_inference_steps (`int`, *optional*, defaults to 50):
1176
+ The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1177
+ the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1178
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1179
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1180
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1181
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1182
+ 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1183
+ usually at the expense of lower image quality.
1184
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1185
+ The number of images to generate per prompt.
1186
+ eta (`float`, *optional*, defaults to 0.0):
1187
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1188
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1189
+ generator (`torch.Generator`, *optional*):
1190
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1191
+ deterministic.
1192
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1193
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
1194
+ output_type (`str`, *optional*, defaults to `"pil"`):
1195
+ The output format of the generated image. Choose between
1196
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1197
+ return_dict (`bool`, *optional*, defaults to `True`):
1198
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1199
+ plain tuple.
1200
+ callback (`Callable`, *optional*):
1201
+ A function that will be called every `callback_steps` steps during inference. The function will be
1202
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1203
+ is_cancelled_callback (`Callable`, *optional*):
1204
+ A function that will be called every `callback_steps` steps during inference. If the function returns
1205
+ `True`, the inference will be cancelled.
1206
+ callback_steps (`int`, *optional*, defaults to 1):
1207
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1208
+ called at every step.
1209
+ Returns:
1210
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1211
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1212
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1213
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1214
+ (nsfw) content, according to the `safety_checker`.
1215
+ """
1216
+ return self.__call__(
1217
+ prompt=prompt,
1218
+ negative_prompt=negative_prompt,
1219
+ image=image,
1220
+ mask_image=mask_image,
1221
+ num_inference_steps=num_inference_steps,
1222
+ guidance_scale=guidance_scale,
1223
+ strength=strength,
1224
+ num_images_per_prompt=num_images_per_prompt,
1225
+ eta=eta,
1226
+ generator=generator,
1227
+ max_embeddings_multiples=max_embeddings_multiples,
1228
+ output_type=output_type,
1229
+ return_dict=return_dict,
1230
+ callback=callback,
1231
+ is_cancelled_callback=is_cancelled_callback,
1232
+ callback_steps=callback_steps,
1233
+ )
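And a sketch of the `inpaint` wrapper with the same assumed `pipe`; the mask follows the semantics documented above (white pixels are repainted, black pixels are kept), and the file paths are placeholders.

```python
from PIL import Image

image = Image.open("photo.png").convert("RGB")
mask = Image.open("mask.png")   # converted to single-channel luminance internally
output = pipe.inpaint(
    image=image,
    mask_image=mask,
    prompt="a wooden park bench",
    strength=0.75,
    num_inference_steps=50,
)
output.images[0].save("inpaint_sample.png")
```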
main.py ADDED
@@ -0,0 +1,166 @@
1
+ """
2
+ extract factors the build is dependent on:
3
+ [X] compute capability
4
+ [ ] TODO: Q - What if we have multiple GPUs of different makes?
5
+ - CUDA version
6
+ - Software:
7
+ - CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
8
+ - CuBLAS-LT: full-build 8-bit optimizer
9
+ - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
10
+
11
+ evaluation:
12
+ - if paths faulty, return meaningful error
13
+ - else:
14
+ - determine CUDA version
15
+ - determine capabilities
16
+ - based on that set the default path
17
+ """
18
+
19
+ import ctypes
20
+
21
+ from .paths import determine_cuda_runtime_lib_path
22
+
23
+
24
+ def check_cuda_result(cuda, result_val):
25
+ # 3. Check for CUDA errors
26
+ if result_val != 0:
27
+ error_str = ctypes.c_char_p()
28
+ cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
29
+ print(f"CUDA exception! Error code: {error_str.value.decode()}")
30
+
31
+ def get_cuda_version(cuda, cudart_path):
32
+ # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
33
+ try:
34
+ cudart = ctypes.CDLL(cudart_path)
35
+ except OSError:
36
+ # TODO: shouldn't we error or at least warn here?
37
+ print(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
38
+ return None
39
+
40
+ version = ctypes.c_int()
41
+ check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
42
+ version = int(version.value)
43
+ major = version//1000
44
+ minor = (version-(major*1000))//10
45
+
46
+ if major < 11:
47
+ print('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!')
48
+
49
+ return f'{major}{minor}'
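As a sanity check on the arithmetic above: `cudaRuntimeGetVersion` packs the version as `major * 1000 + minor * 10`, so CUDA 11.8 is reported as 11080. A worked example (values are illustrative):

```python
# Worked example of the version decoding above (11.8 -> "118").
version = 11080
major = version // 1000                    # 11
minor = (version - (major * 1000)) // 10   # 8
assert f"{major}{minor}" == "118"
```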
50
+
51
+
52
+ def get_cuda_lib_handle():
53
+ # 1. find libcuda.so library (GPU driver) (/usr/lib)
54
+ try:
55
+ cuda = ctypes.CDLL("libcuda.so")
56
+ except OSError:
57
+ # TODO: shouldn't we error or at least warn here?
58
+ print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
59
+ return None
60
+ check_cuda_result(cuda, cuda.cuInit(0))
61
+
62
+ return cuda
63
+
64
+
65
+ def get_compute_capabilities(cuda):
66
+ """
67
+ 1. find libcuda.so library (GPU driver) (/usr/lib)
68
+ init_device -> init variables -> call function by reference
69
+ 2. call extern C function to determine CC
70
+ (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
71
+ 3. Check for CUDA errors
72
+ https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
73
+ # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
74
+ """
75
+
76
+
77
+ nGpus = ctypes.c_int()
78
+ cc_major = ctypes.c_int()
79
+ cc_minor = ctypes.c_int()
80
+
81
+ device = ctypes.c_int()
82
+
83
+ check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
84
+ ccs = []
85
+ for i in range(nGpus.value):
86
+ check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
87
+ ref_major = ctypes.byref(cc_major)
88
+ ref_minor = ctypes.byref(cc_minor)
89
+ # 2. call extern C function to determine CC
90
+ check_cuda_result(
91
+ cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)
92
+ )
93
+ ccs.append(f"{cc_major.value}.{cc_minor.value}")
94
+
95
+ return ccs
96
+
97
+
98
+ # def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
99
+ def get_compute_capability(cuda):
100
+ """
101
+ Extracts the highest compute capability from all available GPUs, as compute
102
+ capabilities are downwards compatible. If no GPUs are detected, it returns
103
+ None.
104
+ """
105
+ ccs = get_compute_capabilities(cuda)
106
+ if ccs: # an empty list means no GPUs were detected
107
+ # TODO: handle different compute capabilities properly; for now, take the last one reported
108
+ return ccs[-1]
109
+ return None
110
+
111
+
112
+ def evaluate_cuda_setup():
113
+ print('')
114
+ print('='*35 + 'BUG REPORT' + '='*35)
115
+ print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
116
+ print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
117
+ print('='*80)
118
+ return "libbitsandbytes_cuda116.dll" # $$$
119
+
120
+ binary_name = "libbitsandbytes_cpu.so"
121
+ #if not torch.cuda.is_available():
122
+ #print('No GPU detected. Loading CPU library...')
123
+ #return binary_name
124
+
125
+ cudart_path = determine_cuda_runtime_lib_path()
126
+ if cudart_path is None:
127
+ print(
128
+ "WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!"
129
+ )
130
+ return binary_name
131
+
132
+ print(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
133
+ cuda = get_cuda_lib_handle()
134
+ cc = get_compute_capability(cuda)
135
+ print(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
136
+ cuda_version_string = get_cuda_version(cuda, cudart_path)
137
+
138
+
139
+ if cc is None or cc == '': # no GPU detected
140
+ print(
141
+ "WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library..."
142
+ )
143
+ return binary_name
144
+
145
+ # 7.5 is the minimum CC for cuBLASLt
146
+ has_cublaslt = cc in ["7.5", "8.0", "8.6"]
147
+
148
+ # TODO:
149
+ # (1) CUDA missing cases (no CUDA installed, but CUDA driver (nvidia-smi) accessible)
150
+ # (2) Multiple CUDA versions installed
151
+
152
+ # we use ls -l instead of nvcc to determine the cuda version
153
+ # since most installations will have the libcudart.so installed, but not the compiler
154
+ print(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
155
+
156
+ def get_binary_name():
157
+ "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
158
+ bin_base_name = "libbitsandbytes_cuda"
159
+ if has_cublaslt:
160
+ return f"{bin_base_name}{cuda_version_string}.so"
161
+ else:
162
+ return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"
163
+
164
+ binary_name = get_binary_name()
165
+
166
+ return binary_name
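To make the selection logic concrete, here is an illustrative trace of `get_binary_name` for one plausible configuration. The values are examples rather than detected output, and on this Windows port the early `return` above short-circuits all of it.

```python
# Illustrative: a GPU with compute capability 8.6 and CUDA runtime 11.8 detected.
cc = "8.6"
cuda_version_string = "118"
has_cublaslt = cc in ["7.5", "8.0", "8.6"]          # True
suffix = "" if has_cublaslt else "_nocublaslt"
assert f"libbitsandbytes_cuda{cuda_version_string}{suffix}.so" == "libbitsandbytes_cuda118.so"
```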
make_captions.py ADDED
@@ -0,0 +1,210 @@
1
+ import argparse
2
+ import glob
3
+ import os
4
+ import json
5
+ import random
6
+ import sys
7
+
8
+ from pathlib import Path
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+ import torch
14
+ from library.device_utils import init_ipex, get_preferred_device
15
+ init_ipex()
16
+
17
+ from torchvision import transforms
18
+ from torchvision.transforms.functional import InterpolationMode
19
+ sys.path.append(os.path.dirname(__file__))
20
+ from blip.blip import blip_decoder, is_url
21
+ import library.train_util as train_util
22
+ from library.utils import setup_logging
23
+ setup_logging()
24
+ import logging
25
+ logger = logging.getLogger(__name__)
26
+
27
+ DEVICE = get_preferred_device()
28
+
29
+
30
+ IMAGE_SIZE = 384
31
+
32
+ # A square resize seems questionable, but the original source does it this way, so we follow it
33
+ IMAGE_TRANSFORM = transforms.Compose(
34
+ [
35
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
38
+ ]
39
+ )
40
+
41
+
42
+ # Would be nice to share this with the other scripts, but the processing differs slightly...
43
+ class ImageLoadingTransformDataset(torch.utils.data.Dataset):
44
+ def __init__(self, image_paths):
45
+ self.images = image_paths
46
+
47
+ def __len__(self):
48
+ return len(self.images)
49
+
50
+ def __getitem__(self, idx):
51
+ img_path = self.images[idx]
52
+
53
+ try:
54
+ image = Image.open(img_path).convert("RGB")
55
+ # convert to tensor temporarily so dataloader will accept it
56
+ tensor = IMAGE_TRANSFORM(image)
57
+ except Exception as e:
58
+ logger.error(f"Could not load image path / 画像を読み込めません: {img_path}, error: {e}")
59
+ return None
60
+
61
+ return (tensor, img_path)
62
+
63
+
64
+ def collate_fn_remove_corrupted(batch):
65
+ """Collate function that allows to remove corrupted examples in the
66
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
67
+ The 'None's in the batch are removed.
68
+ """
69
+ # Filter out all the Nones (corrupted examples)
70
+ batch = list(filter(lambda x: x is not None, batch))
71
+ return batch
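A tiny illustration of the collate behaviour described in the docstring (assuming the function above is importable; the tuple contents are stand-ins for the real `(tensor, path)` pairs):

```python
# None entries (corrupted images) are dropped; valid (tensor, path) pairs survive.
batch = [("tensor_a", "a.png"), None, ("tensor_b", "b.png")]
assert collate_fn_remove_corrupted(batch) == [("tensor_a", "a.png"), ("tensor_b", "b.png")]
```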
72
+
73
+
74
+ def main(args):
75
+ # fix the seed for reproducibility
76
+ seed = args.seed # + utils.get_rank()
77
+ torch.manual_seed(seed)
78
+ np.random.seed(seed)
79
+ random.seed(seed)
80
+
81
+ if not os.path.exists("blip"):
82
+ args.train_data_dir = os.path.abspath(args.train_data_dir) # convert to absolute path
83
+
84
+ cwd = os.getcwd()
85
+ logger.info(f"Current Working Directory is: {cwd}")
86
+ os.chdir("finetune")
87
+ if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
88
+ args.caption_weights = os.path.join("..", args.caption_weights)
89
+
90
+ logger.info(f"load images from {args.train_data_dir}")
91
+ train_data_dir_path = Path(args.train_data_dir)
92
+ image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
93
+ logger.info(f"found {len(image_paths)} images.")
94
+
95
+ logger.info(f"loading BLIP caption: {args.caption_weights}")
96
+ model = blip_decoder(pretrained=args.caption_weights, image_size=IMAGE_SIZE, vit="large", med_config="./blip/med_config.json")
97
+ model.eval()
98
+ model = model.to(DEVICE)
99
+ logger.info("BLIP loaded")
100
+
101
+ # run captioning
102
+ def run_batch(path_imgs):
103
+ imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
104
+
105
+ with torch.no_grad():
106
+ if args.beam_search:
107
+ captions = model.generate(
108
+ imgs, sample=False, num_beams=args.num_beams, max_length=args.max_length, min_length=args.min_length
109
+ )
110
+ else:
111
+ captions = model.generate(
112
+ imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length
113
+ )
114
+
115
+ for (image_path, _), caption in zip(path_imgs, captions):
116
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
117
+ f.write(caption + "\n")
118
+ if args.debug:
119
+ logger.info(f'{image_path} {caption}')
120
+
121
+ # option to use a DataLoader to speed up image loading
122
+ if args.max_data_loader_n_workers is not None:
123
+ dataset = ImageLoadingTransformDataset(image_paths)
124
+ data = torch.utils.data.DataLoader(
125
+ dataset,
126
+ batch_size=args.batch_size,
127
+ shuffle=False,
128
+ num_workers=args.max_data_loader_n_workers,
129
+ collate_fn=collate_fn_remove_corrupted,
130
+ drop_last=False,
131
+ )
132
+ else:
133
+ data = [[(None, ip)] for ip in image_paths]
134
+
135
+ b_imgs = []
136
+ for data_entry in tqdm(data, smoothing=0.0):
137
+ for data in data_entry:
138
+ if data is None:
139
+ continue
140
+
141
+ img_tensor, image_path = data
142
+ if img_tensor is None:
143
+ try:
144
+ raw_image = Image.open(image_path)
145
+ if raw_image.mode != "RGB":
146
+ raw_image = raw_image.convert("RGB")
147
+ img_tensor = IMAGE_TRANSFORM(raw_image)
148
+ except Exception as e:
149
+ logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
150
+ continue
151
+
152
+ b_imgs.append((image_path, img_tensor))
153
+ if len(b_imgs) >= args.batch_size:
154
+ run_batch(b_imgs)
155
+ b_imgs.clear()
156
+ if len(b_imgs) > 0:
157
+ run_batch(b_imgs)
158
+
159
+ logger.info("done!")
160
+
161
+
162
+ def setup_parser() -> argparse.ArgumentParser:
163
+ parser = argparse.ArgumentParser()
164
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
165
+ parser.add_argument(
166
+ "--caption_weights",
167
+ type=str,
168
+ default="https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth",
169
+ help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)",
170
+ )
171
+ parser.add_argument(
172
+ "--caption_extention",
173
+ type=str,
174
+ default=None,
175
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)",
176
+ )
177
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
178
+ parser.add_argument(
179
+ "--beam_search",
180
+ action="store_true",
181
+ help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)",
182
+ )
183
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
184
+ parser.add_argument(
185
+ "--max_data_loader_n_workers",
186
+ type=int,
187
+ default=None,
188
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
189
+ )
190
+ parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
191
+ parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
192
+ parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
193
+ parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
194
+ parser.add_argument("--seed", default=42, type=int, help="seed for reproducibility / 再現性を確保するための乱数seed")
195
+ parser.add_argument("--debug", action="store_true", help="debug mode")
196
+ parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
197
+
198
+ return parser
199
+
200
+
201
+ if __name__ == "__main__":
202
+ parser = setup_parser()
203
+
204
+ args = parser.parse_args()
205
+
206
+ # map the previously misspelled option back to the correct one
207
+ if args.caption_extention is not None:
208
+ args.caption_extension = args.caption_extention
209
+
210
+ main(args)
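A hedged sketch of driving the script programmatically rather than from the command line; the image directory is a placeholder and the options mirror the parser defined above.

```python
# Build args exactly as the __main__ block does, then run captioning.
parser = setup_parser()
args = parser.parse_args(
    ["/path/to/train_images", "--batch_size", "8", "--beam_search", "--caption_extension", ".caption"]
)
if args.caption_extention is not None:   # keep the backward-compatibility shim
    args.caption_extension = args.caption_extention
main(args)
```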
make_captions_by_git.py ADDED
@@ -0,0 +1,183 @@
1
+ import argparse
2
+ import os
3
+ import re
4
+
5
+ from pathlib import Path
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+
9
+ import torch
10
+ from library.device_utils import init_ipex, get_preferred_device
11
+ init_ipex()
12
+
13
+ from transformers import AutoProcessor, AutoModelForCausalLM
14
+ from transformers.generation.utils import GenerationMixin
15
+
16
+ import library.train_util as train_util
17
+ from library.utils import setup_logging
18
+ setup_logging()
19
+ import logging
20
+ logger = logging.getLogger(__name__)
21
+
22
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
+ PATTERN_REPLACE = [
25
+ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
26
+ re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
27
+ re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
28
+ re.compile(r"with the number \d+ on (it|\w+ \w+)"),
29
+ re.compile(r'with the words "'),
30
+ re.compile(r"word \w+ on it"),
31
+ re.compile(r"that says the word \w+ on it"),
32
+ re.compile("that says'the word \"( on it)?"),
33
+ ]
34
+
35
+ # remove "with the word xxxx" phrases, which are mostly false detections
36
+
37
+
38
+ def remove_words(captions, debug):
39
+ removed_caps = []
40
+ for caption in captions:
41
+ cap = caption
42
+ for pat in PATTERN_REPLACE:
43
+ cap = pat.sub("", cap)
44
+ if debug and cap != caption:
45
+ logger.info(caption)
46
+ logger.info(cap)
47
+ removed_caps.append(cap)
48
+ return removed_caps
49
+
50
+
51
+ def collate_fn_remove_corrupted(batch):
52
+ """Collate function that allows to remove corrupted examples in the
53
+ dataloader. It expects that the dataloader returns 'None' when that occurs.
54
+ The 'None's in the batch are removed.
55
+ """
56
+ # Filter out all the Nones (corrupted examples)
57
+ batch = list(filter(lambda x: x is not None, batch))
58
+ return batch
59
+
60
+
61
+ def main(args):
62
+ r"""
63
+ The following is commented out because batch sizes > 1 work as of transformers 4.30.2
64
+
65
+ # Patch GIT so it also works with batch sizes greater than 1: for transformers 4.26.0
66
+ org_prepare_input_ids_for_generation = GenerationMixin._prepare_input_ids_for_generation
67
+ curr_batch_size = [args.batch_size] # use a list so it can be swapped when the last batch is smaller than batch_size
68
+
69
+ # input_ids must have the same number of rows as the batch size; the batch size is not visible from this function, so it is passed in from outside
70
+ # trying to replace this any higher up would be a huge amount of work
71
+ def _prepare_input_ids_for_generation_patch(self, bos_token_id, encoder_outputs):
72
+ input_ids = org_prepare_input_ids_for_generation(self, bos_token_id, encoder_outputs)
73
+ if input_ids.size()[0] != curr_batch_size[0]:
74
+ input_ids = input_ids.repeat(curr_batch_size[0], 1)
75
+ return input_ids
76
+
77
+ GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
78
+ """
79
+
80
+ logger.info(f"load images from {args.train_data_dir}")
81
+ train_data_dir_path = Path(args.train_data_dir)
82
+ image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
83
+ logger.info(f"found {len(image_paths)} images.")
84
+
85
+ # ideally the model would be downloaded explicitly rather than relying on the cache
86
+ logger.info(f"loading GIT: {args.model_id}")
87
+ git_processor = AutoProcessor.from_pretrained(args.model_id)
88
+ git_model = AutoModelForCausalLM.from_pretrained(args.model_id).to(DEVICE)
89
+ logger.info("GIT loaded")
90
+
91
+ # run captioning
92
+ def run_batch(path_imgs):
93
+ imgs = [im for _, im in path_imgs]
94
+
95
+ # curr_batch_size[0] = len(path_imgs)
96
+ inputs = git_processor(images=imgs, return_tensors="pt").to(DEVICE) # images are in PIL format
97
+ generated_ids = git_model.generate(pixel_values=inputs.pixel_values, max_length=args.max_length)
98
+ captions = git_processor.batch_decode(generated_ids, skip_special_tokens=True)
99
+
100
+ if args.remove_words:
101
+ captions = remove_words(captions, args.debug)
102
+
103
+ for (image_path, _), caption in zip(path_imgs, captions):
104
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
105
+ f.write(caption + "\n")
106
+ if args.debug:
107
+ logger.info(f"{image_path} {caption}")
108
+
109
+ # option to use a DataLoader to speed up image loading
110
+ if args.max_data_loader_n_workers is not None:
111
+ dataset = train_util.ImageLoadingDataset(image_paths)
112
+ data = torch.utils.data.DataLoader(
113
+ dataset,
114
+ batch_size=args.batch_size,
115
+ shuffle=False,
116
+ num_workers=args.max_data_loader_n_workers,
117
+ collate_fn=collate_fn_remove_corrupted,
118
+ drop_last=False,
119
+ )
120
+ else:
121
+ data = [[(None, ip)] for ip in image_paths]
122
+
123
+ b_imgs = []
124
+ for data_entry in tqdm(data, smoothing=0.0):
125
+ for data in data_entry:
126
+ if data is None:
127
+ continue
128
+
129
+ image, image_path = data
130
+ if image is None:
131
+ try:
132
+ image = Image.open(image_path)
133
+ if image.mode != "RGB":
134
+ image = image.convert("RGB")
135
+ except Exception as e:
136
+ logger.error(f"Could not load image path / 画像を読み込めません: {image_path}, error: {e}")
137
+ continue
138
+
139
+ b_imgs.append((image_path, image))
140
+ if len(b_imgs) >= args.batch_size:
141
+ run_batch(b_imgs)
142
+ b_imgs.clear()
143
+
144
+ if len(b_imgs) > 0:
145
+ run_batch(b_imgs)
146
+
147
+ logger.info("done!")
148
+
149
+
150
+ def setup_parser() -> argparse.ArgumentParser:
151
+ parser = argparse.ArgumentParser()
152
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
153
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
154
+ parser.add_argument(
155
+ "--model_id",
156
+ type=str,
157
+ default="microsoft/git-large-textcaps",
158
+ help="model id for GIT in Hugging Face / 使用するGITのHugging FaceのモデルID",
159
+ )
160
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
161
+ parser.add_argument(
162
+ "--max_data_loader_n_workers",
163
+ type=int,
164
+ default=None,
165
+ help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
166
+ )
167
+ parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
168
+ parser.add_argument(
169
+ "--remove_words",
170
+ action="store_true",
171
+ help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する",
172
+ )
173
+ parser.add_argument("--debug", action="store_true", help="debug mode")
174
+ parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively / サブフォルダを再帰的に検索する")
175
+
176
+ return parser
177
+
178
+
179
+ if __name__ == "__main__":
180
+ parser = setup_parser()
181
+
182
+ args = parser.parse_args()
183
+ main(args)
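Analogously, a sketch for the GIT-based script; the path is a placeholder and the options are those defined in `setup_parser` above.

```python
# Programmatic invocation of the GIT captioning script.
parser = setup_parser()
args = parser.parse_args(["/path/to/train_images", "--batch_size", "4", "--remove_words"])
main(args)
```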
masked_loss_README-ja.md ADDED
@@ -0,0 +1,57 @@
1
+ ## マスクロスについて
2
+
3
+ マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。
4
+ たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。
5
+
6
+ マスクロスのマスクには、二種類の指定方法があります。
7
+
8
+ - マスク画像を用いる方法
9
+ - 透明度(アルファチャネル)を使用する方法
10
+
11
+ なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。
12
+
13
+ ### マスク画像を用いる方法
14
+
15
+ 学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。
16
+
17
+ - 学習画像
18
+ ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441)
19
+ - マスク画像
20
+ ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450)
21
+
22
+ ```.toml
23
+ [[datasets.subsets]]
24
+ image_dir = "/path/to/a_zundamon"
25
+ caption_extension = ".txt"
26
+ conditioning_data_dir = "/path/to/a_zundamon_mask"
27
+ num_repeats = 8
28
+ ```
29
+
30
+ マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。
31
+
32
+ DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。
33
+
34
+ ### 透明度(アルファチャネル)を使用する方法
35
+
36
+ 学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。
37
+
38
+ ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e)
39
+
40
+ ※それぞれの画像は透過PNG
41
+
42
+ 学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。
43
+
44
+ ```toml
45
+ [[datasets.subsets]]
46
+ image_dir = "/path/to/image/dir"
47
+ caption_extension = ".txt"
48
+ num_repeats = 8
49
+ alpha_mask = true
50
+ ```
51
+
52
+ ## 学習時の注意事項
53
+
54
+ - 現時点では DreamBooth 方式の dataset のみ対応しています。
55
+ - マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。
56
+ - マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証)
57
+ - `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。
masked_loss_README.md ADDED
@@ -0,0 +1,56 @@
1
+ ## Masked Loss
2
+
3
+ Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background.
4
+
5
+ There are two ways to specify the mask for masked loss.
6
+
7
+ - Using a mask image
8
+ - Using transparency (alpha channel) of the image
9
+
10
+ The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html).
11
+
12
+ ### Using a mask image
13
+
14
+ This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image.
15
+
16
+ - Training image
17
+ ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441)
18
+ - Mask image
19
+ ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450)
20
+
21
+ ```toml
22
+ [[datasets.subsets]]
23
+ image_dir = "/path/to/a_zundamon"
24
+ caption_extension = ".txt"
25
+ conditioning_data_dir = "/path/to/a_zundamon_mask"
26
+ num_repeats = 8
27
+ ```
28
+
29
+ The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. Grayscale is also supported (a value of 127 gives a loss weight of 0.5). Note that, strictly speaking, the R channel of the mask image is what is currently used.
30
+
31
+ Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details.
32
+
33
+ ### Using transparency (alpha channel) of the image
34
+
35
+ The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5).
36
+
37
+ ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e)
38
+
39
+ ※Each image is a transparent PNG
40
+
41
+ Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this.
42
+
43
+ ```toml
44
+ [[datasets.subsets]]
45
+ image_dir = "/path/to/image/dir"
46
+ caption_extension = ".txt"
47
+ num_repeats = 8
48
+ alpha_mask = true
49
+ ```
50
+
51
+ ## Notes on training
52
+
53
+ - At the moment, only the dataset in the DreamBooth method is supported.
54
+ - The mask is applied after being reduced to 1/8 of the image size, i.e. the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Slightly dilating the mask may be necessary.
55
+ - If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified)
56
+ - In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched.
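For intuition, the weighting described above can be sketched conceptually as follows. This is not the repository's actual implementation; the function name and tensor shapes are illustrative only, and the mask is assumed to be already normalized to [0, 1] (1 = train, 0 = ignore).

```python
import torch
import torch.nn.functional as F

def masked_mse_loss(model_pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # model_pred/target: (B, C, h, w) latent-space tensors; mask: (B, 1, H, W) image-space weights.
    # Downscale the mask to the latent resolution (1/8 of the image), as noted above.
    mask = F.interpolate(mask, size=model_pred.shape[-2:], mode="area")
    per_element = F.mse_loss(model_pred, target, reduction="none") * mask
    # Normalize by the weighted number of contributing elements so the loss scale
    # does not depend on how much of the image is masked out.
    denom = mask.sum() * model_pred.shape[1]
    return per_element.sum() / denom.clamp(min=1e-6)
```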