anyMODE committed
Commit 59d2585 · verified · 1 parent: 042c5bc

Upload lora_redim.py

Files changed (1)
  lora_redim.py  +171 -0
lora_redim.py ADDED
@@ -0,0 +1,171 @@
+ import argparse
+ import torch
+ from safetensors.torch import load_file, save_file
+ from safetensors import safe_open
+ from tqdm import tqdm
+ 
+ 
+ def resize_lora_model(model_path, output_path, new_dim, device):
+     """
+     Resizes the LoRA dimension of a model using SVD for optimal weight preservation.
+ 
+     Args:
+         model_path (str): Path to the LoRA model to resize.
+         output_path (str): Path to save the new resized model.
+         new_dim (int): The target new dimension (rank) for the LoRA weights.
+         device (str): The device to run calculations on ('cuda' or 'cpu').
+     """
+     print(f"Loading model from: {model_path}")
+     model = load_file(model_path)
+     new_model = {}
+ 
+     # --- Metadata & Weight Inspection ---
+     original_dim = None
+     alpha = None
+     metadata = None
+     try:
+         with safe_open(model_path, framework="pt", device="cpu") as f:
+             metadata = f.metadata()
+             if metadata:
+                 if 'ss_network_dim' in metadata:
+                     original_dim = int(metadata['ss_network_dim'])
+                     print(f"Original dimension (from metadata): {original_dim}")
+                 if 'ss_network_alpha' in metadata:
+                     alpha = float(metadata['ss_network_alpha'])
+                     print(f"Original alpha (from metadata): {alpha}")
+     except Exception as e:
+         print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")
+ 
+     # Infer original_dim from weights if not in metadata
+     if original_dim is None:
+         for key in model.keys():
+             if key.endswith((".lora_down.weight", ".lora_A.weight")):
+                 original_dim = model[key].shape[0]
+                 print(f"Inferred original dimension from weights: {original_dim}")
+                 break
+ 
+     # Infer alpha from weights if not in metadata
+     if alpha is None:
+         for key in model.keys():
+             if key.endswith(".alpha"):
+                 alpha = model[key].item()
+                 print(f"Inferred alpha from weights: {alpha}")
+                 break
+ 
+     # Fallback for alpha if still not found
+     if alpha is None and original_dim is not None:
+         alpha = float(original_dim)
+         print(f"Alpha not found, falling back to using the dimension: {alpha}")
+ 
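+     # Note: kohya-style loaders apply the LoRA update scaled by alpha / rank,
+     # so alpha must be rescaled whenever the rank changes (done per module
+     # below and again for the saved metadata)
+ 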
+     # --- Tensor Processing ---
+     lora_keys_to_process = set()
+     for key in model.keys():
+         if '.lora_' in key and key.endswith('.weight'):
+             # Get the base name (e.g., "lora_unet_down_blocks_0_attentions_0_proj_in")
+             base_key = key.split('.lora_')[0]
+             lora_keys_to_process.add(base_key)
+ 
+     if not lora_keys_to_process:
+         print("Error: No LoRA weights found in the model.")
+         return
+ 
+     print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")
+ 
+     for base_key in tqdm(sorted(lora_keys_to_process), desc="Resizing modules"):
+         try:
+             down_key, up_key = None, None
+ 
+             # Determine naming convention
+             if base_key + ".lora_down.weight" in model:
+                 down_key = base_key + ".lora_down.weight"
+                 up_key = base_key + ".lora_up.weight"
+             elif base_key + ".lora_A.weight" in model:
+                 down_key = base_key + ".lora_A.weight"
+                 up_key = base_key + ".lora_B.weight"
+             else:
+                 continue
+ 
+             # Move weights to the selected device for calculation
+             down_weight = model[down_key].to(device)
+             up_weight = model[up_key].to(device)
+ 
+             # --- SVD Resizing ---
+             original_dtype = up_weight.dtype
+ 
+             # Combine the two matrices to get the full weight update
+             conv2d = down_weight.ndim == 4
+             if conv2d:
+                 # For conv layers, fold the kernel's spatial dims into the
+                 # input dim so the update becomes a plain 2D matrix, and
+                 # remember the original shapes for the reshape afterwards
+                 down_shape = down_weight.shape  # (dim, in_ch, kh, kw)
+                 up_shape = up_weight.shape      # (out_ch, dim, 1, 1)
+                 down_weight = down_weight.flatten(1)
+                 up_weight = up_weight.flatten(1)
+ 
+             full_weight = up_weight @ down_weight
+ 
+             # Cast to float32 for the SVD (torch.linalg.svd does not support
+             # float16/bfloat16); the reduced SVD avoids materializing full U/Vh
+             U, S, Vh = torch.linalg.svd(full_weight.to(torch.float32), full_matrices=False)
+ 
+             # Truncate the SVD components to the target rank
+             U = U[:, :new_dim]
+             S = S[:new_dim]
+             Vh = Vh[:new_dim, :]
+ 
+             # Reconstruct the new low-rank matrices
+             new_down = torch.diag(S) @ Vh
+             new_up = U
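+             # Folding all of S into lora_down is one valid factorization;
+             # splitting sqrt(S) across both factors would give the same
+             # product while keeping the two matrices on a similar scale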
+ 
+             # Reshape back to the original conv format if necessary
+             if conv2d:
+                 new_down = new_down.reshape(new_dim, *down_shape[1:])
+                 new_up = new_up.reshape(up_shape[0], new_dim, 1, 1)
+ 
+             # Move back to CPU and original dtype for saving
+             new_model[down_key] = new_down.contiguous().to(original_dtype).cpu()
+             new_model[up_key] = new_up.contiguous().to(original_dtype).cpu()
+ 
+             # Rescale the per-module alpha, if present, so the effective
+             # scale (alpha / rank) is unchanged at the new rank
+             alpha_key = base_key + ".alpha"
+             if alpha_key in model:
+                 old_alpha = model[alpha_key].item()
+                 old_rank = model[down_key].shape[0]
+                 new_model[alpha_key] = torch.tensor(old_alpha * new_dim / old_rank,
+                                                     dtype=model[alpha_key].dtype)
+ 
+         except KeyError:
+             continue
+ 
+     # Copy any remaining non-LoRA tensors (without clobbering the rescaled alphas)
+     for key, value in model.items():
+         if key not in new_model and ".lora_" not in key:
+             new_model[key] = value
+ 
+     # --- Save New Model ---
+     # Carry over the original metadata and update the dim/alpha entries
+     new_metadata = dict(metadata) if metadata else {}
+     new_metadata['ss_network_dim'] = str(new_dim)
+     if alpha is not None and original_dim is not None and original_dim > 0:
+         new_alpha = alpha * (new_dim / original_dim)
+         new_metadata['ss_network_alpha'] = str(new_alpha)
+         print(f"\nNew alpha scaled to: {new_alpha:.2f}")
+ 
+     print(f"\nSaving resized model to: {output_path}")
+     save_file(new_model, output_path, metadata=new_metadata)
+     print("Done!")
+ 
+ 
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Resize a LoRA model to a new dimension using SVD.",
+         formatter_class=argparse.RawTextHelpFormatter
+     )
+     parser.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
+     parser.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
+     parser.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
+     parser.add_argument("--device", type=str, default=None,
+                         help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.")
+ 
+     args = parser.parse_args()
+ 
+     if args.device:
+         device = args.device
+     else:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+     print(f"Using device: {device}")
+ 
+     resize_lora_model(args.model_path, args.output_path, args.new_dim, device)
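
A typical invocation, with hypothetical file names:

    python lora_redim.py my_lora_dim32.safetensors my_lora_dim8.safetensors 8 --device cuda

A minimal post-hoc sanity check on the output (a sketch, assuming the hypothetical file name above):

    from safetensors.torch import load_file

    resized = load_file("my_lora_dim8.safetensors")  # hypothetical output path
    for key, tensor in resized.items():
        if key.endswith((".lora_down.weight", ".lora_A.weight")):
            # The rank is the first dimension of the down-projection
            assert tensor.shape[0] == 8, f"unexpected rank for {key}"
    print("All down-projection ranks match the requested dimension.")

Why SVD: truncating to the top-k singular triplets gives, by the Eckart-Young theorem, the best rank-k approximation of the merged update in the Frobenius norm. A self-contained illustration with arbitrary stand-in shapes:

    import torch

    torch.manual_seed(0)
    up = torch.randn(64, 32)      # stand-in for lora_up   (out_features x dim)
    down = torch.randn(32, 128)   # stand-in for lora_down (dim x in_features)
    full = up @ down              # merged weight update, rank <= 32

    k = 8                         # target rank
    U, S, Vh = torch.linalg.svd(full, full_matrices=False)
    new_up = U[:, :k]
    new_down = torch.diag(S[:k]) @ Vh[:k, :]

    err = torch.linalg.norm(full - new_up @ new_down)  # Frobenius norm
    print(f"Reconstruction error at rank {k}: {err:.4f}")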