anyMODE committed
Commit 40b438e · verified · 1 parent: c66ab51

Upload lora_redim.py


Change alpha with the same ratio as rank
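
This follows the kohya-ss convention (the `ss_network_dim` / `ss_network_alpha` metadata keys below), where a LoRA delta is applied as W' = W + (alpha / dim) * up @ down. Multiplying alpha by new_dim / original_dim therefore keeps the effective multiplier alpha / dim unchanged across the resize. A minimal sketch of the rule, with made-up numbers rather than values from any particular model:

    # Hypothetical LoRA: dim 64, alpha 32, resized down to dim 16.
    original_dim, original_alpha = 64, 32.0
    new_dim = 16
    new_alpha = original_alpha * (new_dim / original_dim)  # -> 8.0
    # The applied scale alpha / dim stays 0.5 before and after the resize.
    assert new_alpha / new_dim == original_alpha / original_dim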

Files changed (1)
  1. lora_redim.py +94 -45
lora_redim.py CHANGED
@@ -6,18 +6,21 @@ from safetensors import safe_open
 from tqdm import tqdm
 
 
-def resize_lora_model(model_path, output_path, new_dim, device):
+def resize_lora_model(model_path, output_path, new_dim, device, method):
     """
-    Resizes the LoRA dimension of a model using SVD for optimal weight preservation.
+    Resizes the LoRA dimension of a model using SVD or Randomized SVD.
+    Also scales the alpha value(s) proportionally.
 
     Args:
         model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
         new_dim (int): The target new dimension for the LoRA weights.
         device (str): The device to run calculations on ('cuda' or 'cpu').
+        method (str): The resizing method to use ('svd' or 'randomized_svd').
     """
     print(f"Loading model from: {model_path}")
-    model = load_file(model_path)
+    # Load the model onto CPU memory first to avoid VRAM issues with large models
+    model = load_file(model_path, device="cpu")
     new_model = {}
 
     # --- Metadata & Weight Inspection ---
@@ -44,6 +47,14 @@ def resize_lora_model(model_path, output_path, new_dim, device):
                 print(f"Inferred original dimension from weights: {original_dim}")
                 break
 
+    if original_dim is None:
+        print("Error: Could not determine original LoRA dimension.")
+        return
+
+    if original_dim == new_dim:
+        print("Error: New dimension is the same as the original dimension. No changes to make.")
+        return
+
     # Infer alpha from weights if not in metadata
     if alpha is None:
         for key in model.keys():
@@ -52,16 +63,19 @@ def resize_lora_model(model_path, output_path, new_dim, device):
                 print(f"Inferred alpha from weights: {alpha}")
                 break
 
-    # Fallback for alpha if still not found
-    if alpha is None and original_dim is not None:
+    # Fallback for alpha if still not found
+    if alpha is None:
         alpha = float(original_dim)
-        print(f"Alpha not found, falling back to using dimension: {alpha}")
+        print(f"Alpha not found, falling back to using dimension value: {alpha}")
 
     # --- Tensor Processing ---
+    # Calculate the scaling ratio for alpha
+    ratio = new_dim / original_dim
+    print(f"Dimension resize ratio: {ratio:.4f}")
+
     lora_keys_to_process = set()
     for key in model.keys():
         if 'lora_' in key and key.endswith('.weight'):
-            # Get the base name (e.g., "lora_unet_down_blocks_0_attentions_0_proj_in")
             base_key = key.split('.lora_')[0]
             lora_keys_to_process.add(base_key)
 
@@ -70,12 +84,13 @@ def resize_lora_model(model_path, output_path, new_dim, device):
         return
 
     print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")
+    print(f"Using '{method}' method for resizing.")
 
     for base_key in tqdm(sorted(list(lora_keys_to_process)), desc="Resizing modules"):
         try:
             down_key, up_key = None, None
 
-            # Determine naming convention
+            # Determine the correct key names for down and up weights
            if base_key + ".lora_down.weight" in model:
                 down_key = base_key + ".lora_down.weight"
                 up_key = base_key + ".lora_up.weight"
@@ -85,71 +100,97 @@ def resize_lora_model(model_path, output_path, new_dim, device):
             else:
                 continue
 
-            # Move weights to the selected device for calculation
-            down_weight = model[down_key].to(device)
-            up_weight = model[up_key].to(device)
-
-            # --- SVD Resizing ---
+            down_weight = model[down_key]
+            up_weight = model[up_key]
             original_dtype = up_weight.dtype
 
-            # Combine the two matrices to get the full weight update
+            # Move weights to the selected device for processing
+            down_weight = down_weight.to(device, dtype=torch.float32)
+            up_weight = up_weight.to(device, dtype=torch.float32)
+
+            # Handle both linear and convolutional layers
             conv2d = down_weight.ndim == 4
             if conv2d:
-                # For conv layers, treat spatial dims as batch dims
+                conv_shape = down_weight.shape
                 down_weight = down_weight.flatten(1)
                 up_weight = up_weight.flatten(1)
 
+            # Reconstruct the full weight matrix
             full_weight = up_weight @ down_weight
 
-            # Always cast to float32 for SVD, as some devices (CPU, and some GPUs) don't support bfloat16
-            U, S, Vh = torch.linalg.svd(full_weight.to(torch.float32))
+            if method == 'svd':
+                # --- Full SVD Resizing (Accurate) ---
+                U, S, Vh = torch.linalg.svd(full_weight)
+
+                # Truncate the SVD components to the new dimension
+                U = U[:, :new_dim]
+                S = S[:new_dim]
+                Vh = Vh[:new_dim, :]
 
-            # Truncate or pad the SVD components
-            U = U[:, :new_dim]
-            S = S[:new_dim]
-            Vh = Vh[:new_dim, :]
+                # Distribute singular values (S) back to the new matrices
+                # A common practice is to take the square root for balanced distribution
+                S_sqrt = torch.sqrt(S)
+                new_up = U @ torch.diag(S_sqrt)
+                new_down = torch.diag(S_sqrt) @ Vh
 
-            # Reconstruct the new low-rank matrices
-            new_down = torch.diag(S) @ Vh
-            new_up = U
+            elif method == 'randomized_svd':
+                # --- Randomized SVD Resizing (Fast Approximation) ---
+                U, S, V = torch.svd_lowrank(full_weight, q=new_dim)
+                Vh = V.T
+
+                # Distribute singular values like in the full SVD method
+                S_sqrt = torch.sqrt(S)
+                new_up = U @ torch.diag(S_sqrt)
+                new_down = torch.diag(S_sqrt) @ Vh
 
-            # Reshape back to original conv format if necessary
             if conv2d:
-                new_down = new_down.reshape(new_dim, down_weight.shape[1], 1, 1)
-                new_up = new_up.reshape(up_weight.shape[0], new_dim, 1, 1)
+                new_down = new_down.reshape(new_dim, conv_shape[1], conv_shape[2], conv_shape[3])
+                new_up = new_up.reshape(up_weight.shape[0], new_dim, 1, 1)
 
-            # Move back to CPU and original dtype for saving
-            new_model[down_key] = new_down.contiguous().to(original_dtype).cpu()
-            new_model[up_key] = new_up.contiguous().to(original_dtype).cpu()
+            # Store the new resized weights
+            new_model[down_key] = new_down.contiguous().to(original_dtype)
+            new_model[up_key] = new_up.contiguous().to(original_dtype)
 
-            # Copy alpha tensor if it exists for this key
+            # If a per-module alpha exists, scale it proportionally.
             alpha_key = base_key + ".alpha"
             if alpha_key in model:
-                new_model[alpha_key] = model[alpha_key]
-
-        except KeyError:
+                original_alpha_tensor = model[alpha_key]
+                # Calculate new alpha and create a new tensor with the same dtype
+                new_alpha_value = original_alpha_tensor.item() * ratio
+                new_model[alpha_key] = torch.tensor(new_alpha_value, dtype=original_alpha_tensor.dtype)
+
+        except Exception as e:
+            print(f"Warning: Failed to process {base_key}. Error: {e}")
             continue
 
-    # Copy non-LoRA tensors
+    # Copy all non-LoRA tensors from the original model
     for key, value in model.items():
         if ".lora_" not in key:
-            new_model[key] = value
+            # Ensure we don't copy an old alpha that has already been processed
+            if ".alpha" not in key or key not in new_model:
+                new_model[key] = value
 
-    # --- Save New Model ---
+    # Update metadata with the new dimension and the globally scaled alpha
     new_metadata = {'ss_network_dim': str(new_dim)}
-    if alpha is not None and original_dim is not None and original_dim > 0:
-        new_alpha = alpha * (new_dim / original_dim)
-        new_metadata['ss_network_alpha'] = str(new_alpha)
-        print(f"\nNew alpha scaled to: {new_alpha:.2f}")
+    new_alpha = alpha * ratio
+    new_metadata['ss_network_alpha'] = str(new_alpha)
+    print(f"\nNew global alpha scaled to: {new_alpha:.2f}")
+
+    # Move all tensors to CPU before saving
+    if device != 'cpu':
+        print("\nMoving processed tensors to CPU for saving...")
+        for key in tqdm(new_model.keys(), desc="Finalizing"):
+            if isinstance(new_model[key], torch.Tensor):
+                new_model[key] = new_model[key].cpu()
 
     print(f"\nSaving resized model to: {output_path}")
     save_file(new_model, output_path, metadata=new_metadata)
-    print("Done!")
+    print("Done! 🎉")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Resize a LoRA model to a new dimension using SVD.",
+        description="Resize a LoRA model to a new dimension and scale alpha proportionally.",
         formatter_class=argparse.RawTextHelpFormatter
     )
     parser.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
@@ -157,6 +198,15 @@ if __name__ == "__main__":
     parser.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
     parser.add_argument("--device", type=str, default=None,
                         help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.")
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="svd",
+        choices=["svd", "randomized_svd"],
+        help="""Resizing method:
+'svd' (default): Accurate but slower. Uses full SVD for optimal weight preservation.
+'randomized_svd': Faster approximation of SVD. Excellent for speed on large models."""
+    )
 
     args = parser.parse_args()
 
@@ -167,5 +217,4 @@ if __name__ == "__main__":
 
     print(f"Using device: {device}")
 
-    resize_lora_model(args.model_path, args.output_path, args.new_dim, device)
-
+    resize_lora_model(args.model_path, args.output_path, args.new_dim, device, args.method)
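
For reference, a standalone sketch of the square-root split used in the hunks above (hypothetical shapes, PyTorch only; not part of the commit). Because diag(S) = diag(sqrt(S)) @ diag(sqrt(S)), the two balanced factors multiply back to exactly the rank-truncated U diag(S) Vh, which by the Eckart–Young theorem is the best rank-new_dim approximation of up @ down:

    import torch

    torch.manual_seed(0)
    up = torch.randn(320, 64)    # lora_up weight: (out_features, old_dim)
    down = torch.randn(64, 768)  # lora_down weight: (old_dim, in_features)
    new_dim = 16

    full = up @ down
    U, S, Vh = torch.linalg.svd(full)
    U, S, Vh = U[:, :new_dim], S[:new_dim], Vh[:new_dim, :]

    # Balanced factors from the square-root split of the singular values.
    S_sqrt = torch.sqrt(S)
    new_up, new_down = U @ torch.diag(S_sqrt), torch.diag(S_sqrt) @ Vh

    # Product is unchanged versus putting all of S on one side.
    assert torch.allclose(new_up @ new_down, U @ torch.diag(S) @ Vh, atol=1e-3)

Assuming the script keeps its uploaded name, a run that exercises the new flag might look like: python lora_redim.py input.safetensors output.safetensors 16 --method randomized_svd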