anyMODE commited on
Commit
138d9de
·
verified ·
1 Parent(s): faa1b64

Upload extract_lora.py

Browse files
Files changed (1) hide show
  1. extract_lora.py +119 -87
extract_lora.py CHANGED
@@ -5,6 +5,14 @@ from tqdm import tqdm
5
  import sys
6
 
7
 
 
 
 
 
 
 
 
 
8
  def get_torch_dtype(dtype_str: str):
9
  """Converts a string to a torch.dtype object."""
10
  if dtype_str == "fp32":
@@ -16,124 +24,148 @@ def get_torch_dtype(dtype_str: str):
16
  raise ValueError(f"Unsupported dtype: {dtype_str}")
17
 
18
 
19
- def extract_and_svd_lora(model_a_path: str, model_b_path: str, output_path: str, rank: int, device: str, alpha: float,
20
- dtype: torch.dtype):
21
- """
22
- Extracts the difference between two models, applies SVD to reduce the rank,
23
- and saves the result as a LoRA file.
24
- """
25
- print(f"Loading base model A: {model_a_path}")
26
- print(f"Loading finetuned model B: {model_b_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  lora_tensors = {}
 
 
 
 
29
 
30
- with safe_open(model_a_path, framework="pt", device="cpu") as f_a, \
31
- safe_open(model_b_path, framework="pt", device="cpu") as f_b:
 
 
32
 
33
- keys_a = set(f_a.keys())
34
- keys_b = set(f_b.keys())
35
- common_keys = keys_a.intersection(keys_b)
36
 
37
- # Filter for processable layers (typically linear and conv weights)
38
- # We exclude biases and non-weight tensors.
39
- weight_keys = {k for k in common_keys if k.endswith('.weight') and 'lora_' not in k}
40
 
41
- if not weight_keys:
42
- print("No common weight keys found between the two models. Exiting.")
 
 
 
43
  sys.exit(1)
44
 
45
- print(f"Found {len(weight_keys)} common weight keys to process.")
46
 
47
- # Main processing loop with progress bar
48
- for key in tqdm(sorted(list(weight_keys)), desc="Processing Layers"):
49
  try:
50
- # Load tensors and move to the selected device and dtype
51
- tensor_a = f_a.get_tensor(key).to(device=device, dtype=dtype)
52
- tensor_b = f_b.get_tensor(key).to(device=device, dtype=dtype)
 
 
53
 
54
  if tensor_a.shape != tensor_b.shape:
55
- print(f"Skipping key {key} due to shape mismatch: A={tensor_a.shape}, B={tensor_b.shape}")
56
  continue
57
 
58
- # Calculate the difference (delta weight)
59
- delta_w = tensor_b - tensor_a
60
-
61
- # SVD works on 2D matrices. Reshape conv layers and other ND tensors.
62
- original_shape = delta_w.shape
63
- if delta_w.dim() > 2:
64
- delta_w = delta_w.view(original_shape[0], -1)
65
-
66
- # --- Core SVD Logic ---
67
- # ΔW ≈ U * S * Vh
68
- # U: Left singular vectors
69
- # S: Singular values (a 1D vector)
70
- # Vh: Right singular vectors (transposed)
71
- U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
72
-
73
- # Truncate to the desired rank
74
- current_rank = min(rank, S.size(0)) # Ensure rank is not > possible rank
75
- U = U[:, :current_rank]
76
- S = S[:current_rank]
77
- Vh = Vh[:current_rank, :]
78
-
79
- # --- Decompose into LoRA A and B matrices ---
80
- # LoRA A (lora_down) is Vh
81
- # LoRA B (lora_up) is U * S
82
- # We scale lora_up by the singular values to retain the magnitude
83
- lora_down = Vh
84
- lora_up = U @ torch.diag(S)
85
-
86
- # Reshape back to original conv format if necessary
87
- if len(original_shape) > 2:
88
- # For Conv2D, lora_down is (rank, in_channels * k_h * k_w)
89
- # and lora_up is (out_channels, rank). No reshape needed for up.
90
- pass # The matrix form is standard for LoRA conv layers
91
-
92
- # Create LoRA tensor names
93
- base_name = key.replace('.weight', '')
94
- lora_down_name = f"{base_name}.lora_down.weight"
95
- lora_up_name = f"{base_name}.lora_up.weight"
96
- alpha_name = f"{base_name}.alpha"
97
-
98
- # Store tensors, moving them to CPU for saving
99
- lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
100
- lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)
101
- lora_tensors[alpha_name] = torch.tensor(alpha).to(torch.float32)
102
 
103
  except Exception as e:
104
- print(f"Failed to process key {key}: {e}")
105
 
106
- # Save the final LoRA file
107
  if not lora_tensors:
108
  print("No tensors were processed. Output file will not be created.")
109
  return
110
 
111
- print(f"\nSaving {len(lora_tensors)} tensors to {output_path}...")
112
- save_file(lora_tensors, output_path)
113
  print("✅ Done!")
114
 
115
 
116
  if __name__ == "__main__":
117
- parser = argparse.ArgumentParser(description="Extract and SVD a LoRA from two SafeTensors checkpoints.")
118
-
119
- parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint in .safetensors format.")
120
- parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint in .safetensors format.")
121
- parser.add_argument("output", type=str, help="Path to save the output LoRA file in .safetensors format.")
122
-
123
- parser.add_argument("--rank", type=int, required=True, help="The target rank for the SVD.")
 
 
124
  parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
125
- help="Device to use for computation ('cuda' or 'cpu').")
126
- parser.add_argument("--alpha", type=float, default=1.0, help="The alpha (scaling) factor for the LoRA.")
127
  parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
128
- help="Precision to use for calculations.")
 
 
 
129
 
130
  args = parser.parse_args()
131
 
132
- # Device check
133
  if args.device == "cuda" and not torch.cuda.is_available():
134
  print("CUDA is not available. Falling back to CPU.")
135
  args.device = "cpu"
136
 
137
- dtype = get_torch_dtype(args.precision)
138
-
139
- extract_and_svd_lora(args.model_a, args.model_b, args.output, args.rank, args.device, args.alpha, dtype)
 
5
  import sys
6
 
7
 
8
+ def normalize_key(key):
9
+ """Strips the 'model.diffusion_model.' prefix from a key if it exists."""
10
+ prefix = 'model.diffusion_model.'
11
+ if key.startswith(prefix):
12
+ return key[len(prefix):]
13
+ return key
14
+
15
+
16
  def get_torch_dtype(dtype_str: str):
17
  """Converts a string to a torch.dtype object."""
18
  if dtype_str == "fp32":
 
24
  raise ValueError(f"Unsupported dtype: {dtype_str}")
25
 
26
 
27
+ def randomized_svd(matrix, rank, n_oversamples=10):
28
+ """Performs Randomized SVD for a faster approximation."""
29
+ max_rank = min(matrix.shape)
30
+ if rank >= max_rank:
31
+ rank = max_rank
32
+ n_oversamples = 0
33
+
34
+ target_rank = min(rank + n_oversamples, max_rank)
35
+
36
+ P = torch.randn((matrix.shape[1], target_rank), device=matrix.device, dtype=matrix.dtype)
37
+ Y = matrix @ P
38
+
39
+ Q, _ = torch.linalg.qr(Y.float())
40
+
41
+ B = Q.T @ matrix.float()
42
+
43
+ U_b, S, Vh = torch.linalg.svd(B, full_matrices=False)
44
+ U = Q @ U_b
45
+
46
+ U = U[:, :rank]
47
+ S = S[:rank]
48
+ Vh = Vh[:rank, :]
49
+
50
+ return U, S, Vh
51
+
52
+
53
+ def extract_and_svd_lora(args):
54
+ """Main function to extract, decompose, and save the LoRA."""
55
+ print(f"Loading base model A: {args.model_a}")
56
+ print(f"Loading finetuned model B: {args.model_b}")
57
+ print(f"Using decomposition method: {args.method}")
58
 
59
  lora_tensors = {}
60
+ dtype = get_torch_dtype(args.precision)
61
+
62
+ with safe_open(args.model_a, framework="pt", device="cpu") as f_a, \
63
+ safe_open(args.model_b, framework="pt", device="cpu") as f_b:
64
 
65
+ keys_a_original = set(f_a.keys())
66
+ keys_b_original = set(f_b.keys())
67
+ print(f"\nFound {len(keys_a_original)} keys in model A.")
68
+ print(f"Found {len(keys_b_original)} keys in model B.")
69
 
70
+ normalized_keys_a = {normalize_key(k): k for k in keys_a_original}
71
+ normalized_keys_b = {normalize_key(k): k for k in keys_b_original}
 
72
 
73
+ common_normalized_keys = set(normalized_keys_a.keys()).intersection(set(normalized_keys_b.keys()))
74
+ print(f"Found {len(common_normalized_keys)} common keys after normalization.\n")
 
75
 
76
+ processable_keys = {k for k in common_normalized_keys if
77
+ (k.endswith('.weight') or k.endswith('.bias')) and 'lora_' not in k}
78
+
79
+ if not processable_keys:
80
+ print("No common weight or bias keys found to process. Check if models are compatible.")
81
  sys.exit(1)
82
 
83
+ print(f"Found {len(processable_keys)} common keys to process.")
84
 
85
+ for norm_key in tqdm(sorted(list(processable_keys)), desc="Processing Layers"):
 
86
  try:
87
+ original_key_a = normalized_keys_a[norm_key]
88
+ original_key_b = normalized_keys_b[norm_key]
89
+
90
+ tensor_a = f_a.get_tensor(original_key_a).to(device=args.device, dtype=dtype)
91
+ tensor_b = f_b.get_tensor(original_key_b).to(device=args.device, dtype=dtype)
92
 
93
  if tensor_a.shape != tensor_b.shape:
94
+ tqdm.write(f"Skipping key {norm_key} due to shape mismatch")
95
  continue
96
 
97
+ delta = tensor_b - tensor_a
98
+
99
+ if norm_key.endswith('.weight'):
100
+ delta_w = delta
101
+ if delta_w.dim() < 2:
102
+ tqdm.write(f"Skipping weight key {norm_key} as it's not a 2D matrix.")
103
+ continue
104
+ if delta_w.dim() > 2:
105
+ delta_w = delta_w.view(delta_w.shape[0], -1)
106
+
107
+ if args.method == 'rsvd':
108
+ # Use the new oversamples argument
109
+ U, S, Vh = randomized_svd(delta_w, args.rank, n_oversamples=args.oversamples)
110
+ else:
111
+ U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
112
+ current_rank = min(args.rank, S.size(0))
113
+ U = U[:, :current_rank]
114
+ S = S[:current_rank]
115
+ Vh = Vh[:current_rank, :]
116
+
117
+ lora_down = Vh
118
+ lora_up = U @ torch.diag(S)
119
+
120
+ base_name = norm_key.replace('.weight', '')
121
+ prefixed_base_name = f"diffusion_model.{base_name}"
122
+ lora_down_name = f"{prefixed_base_name}.lora_down.weight"
123
+ lora_up_name = f"{prefixed_base_name}.lora_up.weight"
124
+
125
+ lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
126
+ lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)
127
+
128
+ elif norm_key.endswith('.bias'):
129
+ delta_b = delta
130
+ base_name = norm_key.replace('.bias', '')
131
+ prefixed_base_name = f"diffusion_model.{base_name}"
132
+ diff_b_name = f"{prefixed_base_name}.diff_b"
133
+ lora_tensors[diff_b_name] = delta_b.contiguous().cpu().to(torch.float32)
 
 
 
 
 
 
 
134
 
135
  except Exception as e:
136
+ tqdm.write(f"Failed to process key {norm_key}: {e}")
137
 
 
138
  if not lora_tensors:
139
  print("No tensors were processed. Output file will not be created.")
140
  return
141
 
142
+ print(f"\nSaving {len(lora_tensors)} tensors to {args.output}...")
143
+ save_file(lora_tensors, args.output)
144
  print("✅ Done!")
145
 
146
 
147
  if __name__ == "__main__":
148
+ parser = argparse.ArgumentParser(description="Extract a LoRA/LoRA+ from two SafeTensors checkpoints.")
149
+ parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint.")
150
+ parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint.")
151
+ parser.add_argument("output", type=str, help="Path to save the output file.")
152
+
153
+ parser.add_argument("--rank", type=int, required=True, help="The target rank for the decomposition.")
154
+ parser.add_argument("--alpha", type=float, default=1.0,
155
+ help="Informational alpha value for scaling. This value is NOT saved in the file.")
156
+ parser.add_argument("--method", type=str, default="rsvd", choices=["svd", "rsvd"], help="Decomposition method.")
157
  parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
158
+ help="Device to use for computation.")
 
159
  parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
160
+ help="Precision for calculations.")
161
+ # --- NEW ARGUMENT ---
162
+ parser.add_argument("--oversamples", type=int, default=10,
163
+ help="Oversampling parameter for Randomized SVD for better accuracy.")
164
 
165
  args = parser.parse_args()
166
 
 
167
  if args.device == "cuda" and not torch.cuda.is_available():
168
  print("CUDA is not available. Falling back to CPU.")
169
  args.device = "cpu"
170
 
171
+ extract_and_svd_lora(args)