# lora_redim.py
import os
import argparse
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm

def resize_lora_model(model_path, output_path, new_dim, device, method):
"""
Resizes the LoRA dimension of a model using SVD or Randomized SVD.
Also scales the alpha value(s) proportionally.
Args:
model_path (str): Path to the LoRA model to resize.
output_path (str): Path to save the new resized model.
new_dim (int): The target new dimension for the LoRA weights.
device (str): The device to run calculations on ('cuda' or 'cpu').
method (str): The resizing method to use ('svd' or 'randomized_svd').
"""
print(f"Loading model from: {model_path}")
# Load the model onto CPU memory first to avoid VRAM issues with large models
model = load_file(model_path, device="cpu")
new_model = {}
    # --- Metadata & Weight Inspection ---
    metadata = None
    original_dim = None
    alpha = None
try:
with safe_open(model_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
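            # safetensors metadata values are always strings, e.g.
            # {'ss_network_dim': '32', 'ss_network_alpha': '16.0'}, hence the casts below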
if metadata:
if 'ss_network_dim' in metadata:
original_dim = int(metadata['ss_network_dim'])
print(f"Original dimension (from metadata): {original_dim}")
if 'ss_network_alpha' in metadata:
alpha = float(metadata['ss_network_alpha'])
print(f"Original alpha (from metadata): {alpha}")
except Exception as e:
print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")
# Infer original_dim from weights if not in metadata
if original_dim is None:
for key in model.keys():
if key.endswith((".lora_down.weight", ".lora_A.weight")):
original_dim = model[key].shape[0]
print(f"Inferred original dimension from weights: {original_dim}")
break
if original_dim is None:
print("Error: Could not determine original LoRA dimension.")
return
    if original_dim == new_dim:
        print("New dimension matches the original dimension; nothing to do.")
        return
# Infer alpha from weights if not in metadata
if alpha is None:
for key in model.keys():
if key.endswith(".alpha"):
alpha = model[key].item()
print(f"Inferred alpha from weights: {alpha}")
break
# Fallback for alpha if still not found
if alpha is None:
alpha = float(original_dim)
print(f"Alpha not found, falling back to using dimension value: {alpha}")
# --- Tensor Processing ---
# Calculate the scaling ratio for alpha
ratio = new_dim / original_dim
print(f"Dimension resize ratio: {ratio:.4f}")
lora_keys_to_process = set()
for key in model.keys():
if 'lora_' in key and key.endswith('.weight'):
base_key = key.split('.lora_')[0]
lora_keys_to_process.add(base_key)
if not lora_keys_to_process:
print("Error: No LoRA weights found in the model.")
return
print(f"\nFound {len(lora_keys_to_process)} LoRA modules to resize...")
print(f"Using '{method}' method for resizing.")
for base_key in tqdm(sorted(list(lora_keys_to_process)), desc="Resizing modules"):
try:
down_key, up_key = None, None
# Determine the correct key names for down and up weights
if base_key + ".lora_down.weight" in model:
down_key = base_key + ".lora_down.weight"
up_key = base_key + ".lora_up.weight"
elif base_key + ".lora_A.weight" in model:
down_key = base_key + ".lora_A.weight"
up_key = base_key + ".lora_B.weight"
else:
continue
down_weight = model[down_key]
up_weight = model[up_key]
original_dtype = up_weight.dtype
# Move weights to the selected device for processing
down_weight = down_weight.to(device, dtype=torch.float32)
up_weight = up_weight.to(device, dtype=torch.float32)
# Handle both linear and convolutional layers
conv2d = down_weight.ndim == 4
            if conv2d:
                conv_shape = down_weight.shape
                up_shape = up_weight.shape
                # Flatten to 2D so the SVD can operate on plain matrices:
                # down (r, in_ch, kh, kw) -> (r, in_ch*kh*kw); up (out_ch, r, 1, 1) -> (out_ch, r)
                down_weight = down_weight.flatten(1)
                up_weight = up_weight.flatten(1)
# Reconstruct the full weight matrix
full_weight = up_weight @ down_weight
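            # Shapes: up (out, r) @ down (r, in) = full (out, in), so the SVD below
            # re-factors the same linear map at the new rank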
            if method == 'svd':
                # --- Full SVD Resizing (Accurate) ---
                # The thin SVD (full_matrices=False) is sufficient since we truncate anyway
                U, S, Vh = torch.linalg.svd(full_weight, full_matrices=False)
                # Truncate the SVD components to the new dimension
                U = U[:, :new_dim]
                S = S[:new_dim]
                Vh = Vh[:new_dim, :]
# Distribute singular values (S) back to the new matrices
# A common practice is to take the square root for balanced distribution
S_sqrt = torch.sqrt(S)
new_up = U @ torch.diag(S_sqrt)
new_down = torch.diag(S_sqrt) @ Vh
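                # Note: new_up @ new_down = U @ diag(S) @ Vh, which by the
                # Eckart-Young theorem is the best rank-new_dim approximation
                # of full_weight in the Frobenius norm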
elif method == 'randomized_svd':
# --- Randomized SVD Resizing (Fast Approximation) ---
U, S, V = torch.svd_lowrank(full_weight, q=new_dim)
Vh = V.T
# Distribute singular values like in the full SVD method
S_sqrt = torch.sqrt(S)
new_up = U @ torch.diag(S_sqrt)
new_down = torch.diag(S_sqrt) @ Vh
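                # A possible refinement (not applied here): oversample with
                # q=min(new_dim + 8, *full_weight.shape) and keep only the top
                # new_dim components, which tends to tighten the approximation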
            if conv2d:
                # Restore the 4D conv shapes at the new rank. The up kernel must be
                # 1x1 here, otherwise the matmul above would not have been valid.
                new_down = new_down.reshape(new_dim, conv_shape[1], conv_shape[2], conv_shape[3])
                new_up = new_up.reshape(up_shape[0], new_dim, 1, 1)
            # Store the new resized weights
            new_model[down_key] = new_down.contiguous().to(original_dtype)
            new_model[up_key] = new_up.contiguous().to(original_dtype)
            # If a per-module alpha exists, scale it proportionally
            alpha_key = base_key + ".alpha"
            if alpha_key in model:
                original_alpha_tensor = model[alpha_key]
                # Compute the new alpha, preserving the original dtype
                new_alpha_value = original_alpha_tensor.item() * ratio
                new_model[alpha_key] = torch.tensor(new_alpha_value, dtype=original_alpha_tensor.dtype)
except Exception as e:
print(f"Warning: Failed to process {base_key}. Error: {e}")
continue
# Copy all non-LoRA tensors from the original model
for key, value in model.items():
if ".lora_" not in key:
            # Skip per-module alphas that were already rescaled above
            if ".alpha" not in key or key not in new_model:
new_model[key] = value
    # Carry over the original metadata, updating the dimension and the globally scaled alpha
    new_metadata = dict(metadata) if metadata else {}
    new_metadata['ss_network_dim'] = str(new_dim)
    new_alpha = alpha * ratio
    new_metadata['ss_network_alpha'] = str(new_alpha)
print(f"\nNew global alpha scaled to: {new_alpha:.2f}")
# Move all tensors to CPU before saving
if device != 'cpu':
print("\nMoving processed tensors to CPU for saving...")
for key in tqdm(new_model.keys(), desc="Finalizing"):
if isinstance(new_model[key], torch.Tensor):
new_model[key] = new_model[key].cpu()
print(f"\nSaving resized model to: {output_path}")
save_file(new_model, output_path, metadata=new_metadata)
print("Done! 🎉")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Resize a LoRA model to a new dimension and scale its alpha proportionally.",
        formatter_class=argparse.RawTextHelpFormatter
    )
parser.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
parser.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
parser.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
parser.add_argument("--device", type=str, default=None,
help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.")
parser.add_argument(
"--method",
type=str,
default="svd",
choices=["svd", "randomized_svd"],
help="""Resizing method:
'svd' (default): Accurate but slower. Uses full SVD for optimal weight preservation.
'randomized_svd': Faster approximation of SVD. Excellent for speed on large models."""
)
args = parser.parse_args()
if args.device:
device = args.device
else:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
resize_lora_model(args.model_path, args.output_path, args.new_dim, device, args.method)
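
# Example invocation (hypothetical file names):
#   python lora_redim.py my_lora_dim32.safetensors my_lora_dim16.safetensors 16 --method randomized_svd
#
# A quick sanity check is to reload the output and confirm the new rank:
#   dims = {v.shape[0] for k, v in load_file("my_lora_dim16.safetensors").items()
#           if k.endswith((".lora_down.weight", ".lora_A.weight"))}
#   assert dims == {16}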