import torch
import argparse
from safetensors.torch import save_file, safe_open
from tqdm import tqdm
import sys


def normalize_key(key):
    """Strip the 'model.diffusion_model.' prefix from a key if it exists.

    Lets tensors be matched between checkpoints saved with and without that
    wrapper prefix. Keys without the prefix are returned unchanged.
    """
    prefix = 'model.diffusion_model.'
    if key.startswith(prefix):
        return key[len(prefix):]
    return key


def get_torch_dtype(dtype_str: str):
    """Convert a precision name ("fp32", "fp16" or "bf16") to a torch.dtype.

    Raises:
        ValueError: if ``dtype_str`` is not one of the supported names.
    """
    if dtype_str == "fp32":
        return torch.float32
    if dtype_str == "fp16":
        return torch.float16
    if dtype_str == "bf16":
        return torch.bfloat16
    raise ValueError(f"Unsupported dtype: {dtype_str}")


def randomized_svd(matrix, rank, n_oversamples=10, n_power_iter=0):
    """Approximate the leading ``rank`` singular triplets of a 2-D matrix.

    Randomized range finder: multiply ``matrix`` by a random Gaussian test
    matrix, orthonormalize the product with QR, optionally refine that basis
    with power (subspace) iterations, then take an exact SVD of the small
    projected matrix ``Q.T @ matrix``.

    The whole computation runs in float32 because ``torch.linalg.qr`` and
    ``torch.linalg.svd`` do not support fp16/bf16 inputs.

    Args:
        matrix: 2-D tensor of shape (m, n).
        rank: number of singular triplets to keep; clamped to min(m, n).
        n_oversamples: extra random directions sampled beyond ``rank`` to
            improve accuracy. Ignored when ``rank`` already covers the whole
            matrix.
        n_power_iter: number of power iterations. Each costs two extra
            matmuls and improves accuracy when singular values decay slowly;
            0 gives the plain range finder.

    Returns:
        ``(U, S, Vh)`` float32 tensors of shapes (m, k), (k,), (k, n) with
        ``k = min(rank, m, n)``.

    Raises:
        ValueError: if ``rank`` < 1 or ``n_oversamples`` / ``n_power_iter`` < 0.
    """
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")
    if n_oversamples < 0 or n_power_iter < 0:
        raise ValueError("n_oversamples and n_power_iter must be non-negative")

    a = matrix.float()
    max_rank = min(a.shape)
    if rank >= max_rank:
        rank = max_rank
        n_oversamples = 0
    target_rank = min(rank + n_oversamples, max_rank)

    omega = torch.randn((a.shape[1], target_rank), device=a.device, dtype=a.dtype)
    q, _ = torch.linalg.qr(a @ omega)
    for _ in range(n_power_iter):
        # Re-orthonormalize after every product so round-off does not wash
        # out the weaker singular directions.
        q, _ = torch.linalg.qr(a.T @ q)
        q, _ = torch.linalg.qr(a @ q)

    u_b, s, vh = torch.linalg.svd(q.T @ a, full_matrices=False)
    u = q @ u_b
    return u[:, :rank], s[:rank], vh[:rank, :]


def _weight_base_name(norm_key):
    """Drop only the trailing '.weight' from a key; inner '.weight' text is kept."""
    suffix = '.weight'
    return norm_key[:-len(suffix)] if norm_key.endswith(suffix) else norm_key


def _low_rank_factors(delta_w, rank, method, oversamples=10, power_iters=0):
    """Split a 2-D delta into ``(lora_up, lora_down)`` with ``up @ down ≈ delta_w``.

    The singular values are folded into ``lora_up`` so the product needs no
    extra scaling. Both factors are float32.
    """
    delta_w = delta_w.float()  # torch.linalg.svd rejects fp16/bf16 inputs
    if method == 'rsvd':
        u, s, vh = randomized_svd(delta_w, rank, n_oversamples=oversamples,
                                  n_power_iter=power_iters)
    else:
        u, s, vh = torch.linalg.svd(delta_w, full_matrices=False)
    k = min(rank, s.size(0))
    # Broadcasting S across columns equals U @ diag(S) without building a k x k matrix.
    return u[:, :k] * s[:k], vh[:k, :]


def extract_and_svd_lora(args):
    """Extract, decompose, and save a LoRA approximating ``model_b - model_a``.

    Both checkpoints are opened with safetensors on the CPU. Keys are matched
    after stripping the 'model.diffusion_model.' prefix. For every shared
    '.weight' tensor with at least 2 dimensions, the difference is factored
    with full SVD or randomized SVD and written as
    'diffusion_model.<layer>.lora_down.weight' / '.lora_up.weight' (float32).
    Conv kernels (> 2 dims) are stored in the usual LoRA conv layout.
    Bias and 1-D weight differences have no low-rank form here and are not
    stored. Layers whose weights are identical are skipped.

    Args:
        args: argparse.Namespace with ``model_a``, ``model_b``, ``output``,
            ``rank``, ``method``, ``device``, ``precision`` and
            ``oversamples``. ``power_iters`` is optional (defaults to 0).
    """
    print(f"Loading base model A: {args.model_a}")
    print(f"Loading finetuned model B: {args.model_b}")
    print(f"Using decomposition method: {args.method}")

    dtype = get_torch_dtype(args.precision)
    oversamples = getattr(args, 'oversamples', 10)
    power_iters = getattr(args, 'power_iters', 0)
    lora_tensors = {}
    unchanged = 0
    failed = 0

    with safe_open(args.model_a, framework="pt", device="cpu") as f_a, \
            safe_open(args.model_b, framework="pt", device="cpu") as f_b:
        keys_a_original = set(f_a.keys())
        keys_b_original = set(f_b.keys())
        print(f"\nFound {len(keys_a_original)} keys in model A.")
        print(f"Found {len(keys_b_original)} keys in model B.")

        normalized_keys_a = {normalize_key(k): k for k in keys_a_original}
        normalized_keys_b = {normalize_key(k): k for k in keys_b_original}
        common_normalized_keys = normalized_keys_a.keys() & normalized_keys_b.keys()
        print(f"Found {len(common_normalized_keys)} common keys after normalization.\n")

        # Only '.weight' deltas can be written as a lora_up/lora_down pair.
        # Bias deltas are counted and skipped rather than being loaded onto
        # the device only to be discarded.
        candidate_keys = [k for k in common_normalized_keys if 'lora_' not in k]
        processable_keys = sorted(k for k in candidate_keys if k.endswith('.weight'))
        n_bias = sum(1 for k in candidate_keys if k.endswith('.bias'))

        if not processable_keys:
            print("No common weight keys found to process. Check if models are compatible.")
            sys.exit(1)
        print(f"Found {len(processable_keys)} common keys to process.")
        if n_bias:
            print(f"Ignoring {n_bias} bias keys (bias deltas cannot be stored as LoRA factors).")

        for norm_key in tqdm(processable_keys, desc="Processing Layers"):
            try:
                tensor_a = f_a.get_tensor(normalized_keys_a[norm_key]).to(device=args.device, dtype=dtype)
                tensor_b = f_b.get_tensor(normalized_keys_b[norm_key]).to(device=args.device, dtype=dtype)

                if tensor_a.shape != tensor_b.shape:
                    tqdm.write(f"Skipping key {norm_key} due to shape mismatch")
                    continue
                if tensor_a.dim() < 2:
                    tqdm.write(f"Skipping weight key {norm_key} as it's not a 2D matrix.")
                    continue

                delta = tensor_b - tensor_a
                del tensor_a, tensor_b  # release device memory before decomposing

                if not delta.any():
                    # Identical weights: the LoRA for this layer would be all zeros.
                    unchanged += 1
                    continue

                out_features = delta.shape[0]
                lora_up, lora_down = _low_rank_factors(
                    delta.reshape(out_features, -1), args.rank, args.method,
                    oversamples=oversamples, power_iters=power_iters)

                if delta.dim() > 2:
                    # Conv kernels: down = (rank, in, *kernel), up = (out, rank, 1, ..., 1),
                    # i.e. a kernel-sized conv followed by a 1x1 conv.
                    k = lora_down.shape[0]
                    lora_down = lora_down.reshape(k, *delta.shape[1:])
                    lora_up = lora_up.reshape(out_features, k, *([1] * (delta.dim() - 2)))

                prefixed_base_name = f"diffusion_model.{_weight_base_name(norm_key)}"
                lora_tensors[f"{prefixed_base_name}.lora_down.weight"] = \
                    lora_down.contiguous().cpu().to(torch.float32)
                lora_tensors[f"{prefixed_base_name}.lora_up.weight"] = \
                    lora_up.contiguous().cpu().to(torch.float32)
            except Exception as e:
                # Best-effort per layer: report the failure and continue with the rest.
                failed += 1
                tqdm.write(f"Failed to process key {norm_key}: {e}")

    if unchanged:
        print(f"Skipped {unchanged} layers whose weights are identical in both models.")
    if failed:
        print(f"{failed} layers failed to process (see messages above).")
    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return

    print(f"\nSaving {len(lora_tensors)} tensors to {args.output}...")
    save_file(lora_tensors, args.output)
    print("✅ Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract a LoRA/LoRA+ from two SafeTensors checkpoints.")
    parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint.")
    parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint.")
    parser.add_argument("output", type=str, help="Path to save the output file.")
    parser.add_argument("--rank", type=int, required=True, help="The target rank for the decomposition.")
    parser.add_argument("--alpha", type=float, default=1.0, help="Informational alpha value for scaling. This value is NOT saved in the file.")
    parser.add_argument("--method", type=str, default="rsvd", choices=["svd", "rsvd"], help="Decomposition method.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use for computation.")
    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Precision for calculations.")
    parser.add_argument("--oversamples", type=int, default=10, help="Oversampling parameter for Randomized SVD for better accuracy.")
    parser.add_argument("--power-iters", type=int, default=0, help="Power iterations for Randomized SVD (more = more accurate, slower).")
    args = parser.parse_args()

    if args.rank < 1:
        parser.error("--rank must be a positive integer.")
    if args.oversamples < 0:
        parser.error("--oversamples must be non-negative.")
    if args.power_iters < 0:
        parser.error("--power-iters must be non-negative.")

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"

    extract_and_svd_lora(args)