import os
import argparse

import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm

_SUPPORTED_METHODS = ("svd", "randomized_svd")


def _low_rank_factors(full_weight, rank, method):
    """Factor a 2-D matrix into (up, down) with exactly `rank` inner dimension.

    Args:
        full_weight (torch.Tensor): 2-D matrix of shape (out, in).
        rank (int): Inner dimension of the returned factors (must be > 0).
        method (str): 'svd' or 'randomized_svd'.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: `up` of shape (out, rank) and
        `down` of shape (rank, in) such that `up @ down` approximates
        `full_weight`.

    Raises:
        ValueError: If `method` is not one of the supported methods.
    """
    # A matrix cannot carry more than min(out, in) singular directions, so
    # decompose at that rank at most and zero-pad up to `rank` afterwards.
    effective_rank = min(rank, min(full_weight.shape))

    if method == "svd":
        # full_matrices=False avoids building the full square U / Vh, which
        # would be sliced away immediately anyway.
        U, S, Vh = torch.linalg.svd(full_weight, full_matrices=False)
        U = U[:, :effective_rank]
        S = S[:effective_rank]
        Vh = Vh[:effective_rank, :]
    elif method == "randomized_svd":
        U, S, V = torch.svd_lowrank(full_weight, q=effective_rank)
        Vh = V.T
    else:
        raise ValueError(f"Unknown resizing method: {method}")

    # Split the singular values evenly between both factors (sqrt on each
    # side); broadcasting is equivalent to multiplying by diag(sqrt(S)).
    s_sqrt = torch.sqrt(S)
    new_up = U * s_sqrt
    new_down = s_sqrt.unsqueeze(1) * Vh

    # Zero-pad so every module ends up with the requested rank; the padded
    # rows/columns contribute nothing to up @ down.
    missing = rank - effective_rank
    if missing > 0:
        new_up = torch.cat(
            [new_up, new_up.new_zeros(new_up.shape[0], missing)], dim=1
        )
        new_down = torch.cat(
            [new_down, new_down.new_zeros(missing, new_down.shape[1])], dim=0
        )
    return new_up, new_down


def resize_lora_model(model_path, output_path, new_dim, device, method):
    """Resize the LoRA rank of a safetensors model and rescale its alpha(s).

    Each module's `up @ down` product is rebuilt, decomposed with SVD (or
    randomized SVD) and re-factored at rank `new_dim`. Alpha values are
    multiplied by `new_dim / old_rank` so the `alpha / rank` ratio stays
    the same. Problems are reported with `print`; the function returns
    None in every case.

    Modules that cannot be resized are written out unchanged (original
    weights and alpha) rather than being dropped from the output file.

    Args:
        model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
        new_dim (int): The target new dimension for the LoRA weights.
        device (str): The device to run calculations on ('cuda' or 'cpu').
        method (str): The resizing method to use ('svd' or 'randomized_svd').
    """
    if new_dim <= 0:
        print("Error: New dimension must be a positive integer.")
        return
    if method not in _SUPPORTED_METHODS:
        print(f"Error: Unknown resizing method '{method}'.")
        return

    print(f"Loading model from: {model_path}")
    # Load the model onto CPU memory first to avoid VRAM issues with large models
    model = load_file(model_path, device="cpu")
    new_model = {}

    # --- Metadata & Weight Inspection ---
    original_dim = None
    alpha = None
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            if metadata:
                if 'ss_network_dim' in metadata:
                    original_dim = int(metadata['ss_network_dim'])
                    print(f"Original dimension (from metadata): {original_dim}")
                if 'ss_network_alpha' in metadata:
                    alpha = float(metadata['ss_network_alpha'])
                    print(f"Original alpha (from metadata): {alpha}")
    except Exception as e:
        # Best effort: unreadable or non-numeric metadata falls back to
        # inspecting the tensors themselves.
        print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")

    # Infer original_dim from weights if not in metadata
    if original_dim is None:
        for key in model:
            if key.endswith((".lora_down.weight", ".lora_A.weight")):
                original_dim = model[key].shape[0]
                print(f"Inferred original dimension from weights: {original_dim}")
                break

    if original_dim is None:
        print("Error: Could not determine original LoRA dimension.")
        return

    if original_dim == new_dim:
        print("Error: New dimension is the same as the original dimension. No changes to make.")
        return

    # Infer alpha from weights if not in metadata
    if alpha is None:
        for key in model:
            if key.endswith(".alpha"):
                alpha = model[key].item()
                print(f"Inferred alpha from weights: {alpha}")
                break

    # Fallback for alpha if still not found
    if alpha is None:
        alpha = float(original_dim)
        print(f"Alpha not found, falling back to using dimension value: {alpha}")

    # --- Tensor Processing ---
    # Global ratio, used for the alpha stored in the metadata
    ratio = new_dim / original_dim
    print(f"Dimension resize ratio: {ratio:.4f}")

    lora_keys_to_process = {
        key.split('.lora_')[0]
        for key in model
        if 'lora_' in key and key.endswith('.weight')
    }

    if not lora_keys_to_process:
        print("Error: No LoRA weights found in the model.")
        return

    print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")
    print(f"Using '{method}' method for resizing.")

    failed_modules = []
    for base_key in tqdm(sorted(lora_keys_to_process), desc="Resizing modules"):
        # Determine the correct key names for down and up weights
        if base_key + ".lora_down.weight" in model:
            down_key = base_key + ".lora_down.weight"
            up_key = base_key + ".lora_up.weight"
        elif base_key + ".lora_A.weight" in model:
            down_key = base_key + ".lora_A.weight"
            up_key = base_key + ".lora_B.weight"
        else:
            continue

        try:
            down_weight = model[down_key]
            up_weight = model[up_key]
            original_dtype = up_weight.dtype
            down_shape = down_weight.shape
            up_shape = up_weight.shape
            module_rank = down_shape[0]

            # Compute in float32 on the selected device. flatten(1) is a
            # no-op for linear (2-D) weights and folds the kernel dimensions
            # of convolutional weights into the input dimension.
            down_2d = down_weight.to(device, dtype=torch.float32).flatten(1)
            up_2d = up_weight.to(device, dtype=torch.float32).flatten(1)

            # Reconstruct the full weight matrix and re-factor it
            full_weight = up_2d @ down_2d
            new_up, new_down = _low_rank_factors(full_weight, new_dim, method)

            # Restore the original layouts: both factors regain their
            # trailing (kernel) dimensions, so a conv up weight stays 4-D.
            new_down = new_down.reshape(new_dim, *down_shape[1:])
            new_up = new_up.reshape(up_shape[0], new_dim, *up_shape[2:])

            # Scale the per-module alpha by this module's own rank so that
            # alpha / rank is preserved even when ranks differ per module.
            alpha_key = base_key + ".alpha"
            new_alpha_tensor = None
            if alpha_key in model:
                original_alpha_tensor = model[alpha_key]
                new_alpha_value = original_alpha_tensor.item() * (new_dim / module_rank)
                new_alpha_tensor = torch.tensor(
                    new_alpha_value, dtype=original_alpha_tensor.dtype
                )
        except Exception as e:
            print(f"Warning: Failed to process {base_key}. Error: {e}")
            # Keep the module as it was instead of dropping it; its original
            # alpha is carried over by the copy loop below.
            for original_key in (down_key, up_key):
                if original_key in model:
                    new_model[original_key] = model[original_key]
            failed_modules.append(base_key)
            continue

        # Commit only after every step succeeded so a module is never
        # stored half-resized.
        new_model[down_key] = new_down.contiguous().to(original_dtype)
        new_model[up_key] = new_up.contiguous().to(original_dtype)
        if new_alpha_tensor is not None:
            new_model[alpha_key] = new_alpha_tensor

    if failed_modules:
        print(f"Warning: {len(failed_modules)} module(s) could not be resized and were kept unchanged.")

    # Copy all non-LoRA tensors from the original model, skipping alphas
    # that were already rescaled above.
    for key, value in model.items():
        if ".lora_" not in key and key not in new_model:
            new_model[key] = value

    # Update metadata with the new dimension and the globally scaled alpha
    new_alpha = alpha * ratio
    new_metadata = {
        'ss_network_dim': str(new_dim),
        'ss_network_alpha': str(new_alpha),
    }
    print(f"\nNew global alpha scaled to: {new_alpha:.2f}")

    # Move all tensors to CPU before saving
    if device != 'cpu':
        print("\nMoving processed tensors to CPU for saving...")
        for key in tqdm(new_model.keys(), desc="Finalizing"):
            if isinstance(new_model[key], torch.Tensor):
                new_model[key] = new_model[key].cpu()

    print(f"\nSaving resized model to: {output_path}")
    save_file(new_model, output_path, metadata=new_metadata)
    print("Done! 🎉")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Resize a LoRA model to a new dimension and scales alpha proportionally.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
    parser.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
    parser.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
    parser.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.")
    parser.add_argument(
        "--method",
        type=str,
        default="svd",
        choices=["svd", "randomized_svd"],
        help="""Resizing method:
'svd' (default): Accurate but slower. Uses full SVD for optimal weight preservation.
'randomized_svd': Faster approximation of SVD. Excellent for speed on large models."""
    )

    args = parser.parse_args()

    # Prefer an explicitly requested device; otherwise use CUDA when available.
    if args.device:
        device = args.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using device: {device}")

    resize_lora_model(args.model_path, args.output_path, args.new_dim, device, args.method)