# File size: 8,930 Bytes
# NOTE: file-viewer residue (alternating commit hashes 59d2585 / 40b438e and a
# 1-221 line-number gutter) was here; it is not part of the program.
import os
import argparse
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
from tqdm import tqdm
# (down, up) weight-name suffixes of the two LoRA naming schemes handled here.
_LORA_SUFFIX_PAIRS = (
    (".lora_down.weight", ".lora_up.weight"),
    (".lora_A.weight", ".lora_B.weight"),
)
_RESIZE_METHODS = ("svd", "randomized_svd")


def _read_network_metadata(model_path):
    """Return ``(dim, alpha)`` from the safetensors header; ``None`` for each value not usable."""
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata() or {}
    except Exception as e:  # best effort: both values can still be inferred from the tensors
        print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")
        return None, None

    original_dim = None
    alpha = None
    # Parse the two fields independently so one malformed value does not discard the other.
    if 'ss_network_dim' in metadata:
        try:
            original_dim = int(metadata['ss_network_dim'])
            print(f"Original dimension (from metadata): {original_dim}")
        except ValueError:
            print("Metadata 'ss_network_dim' is not an integer; dimension will be inferred.")
    if 'ss_network_alpha' in metadata:
        try:
            alpha = float(metadata['ss_network_alpha'])
            print(f"Original alpha (from metadata): {alpha}")
        except ValueError:
            print("Metadata 'ss_network_alpha' is not a number; alpha will be inferred.")
    return original_dim, alpha


def _find_lora_pairs(model):
    """Return sorted ``(base_key, down_key, up_key)`` triples for every down weight in ``model``."""
    pairs = []
    for key in model:
        for down_suffix, up_suffix in _LORA_SUFFIX_PAIRS:
            if key.endswith(down_suffix):
                base_key = key[:-len(down_suffix)]
                pairs.append((base_key, key, base_key + up_suffix))
                break
    return sorted(pairs)


def _resize_lora_pair(down_weight, up_weight, new_dim, method):
    """Return ``(new_down, new_up)`` of rank ``new_dim`` whose product approximates ``up @ down``."""
    down_shape = tuple(down_weight.shape)
    up_ndim = up_weight.ndim
    # Convolution kernels are flattened to matrices for the decomposition.
    down_2d = down_weight.flatten(1) if down_weight.ndim > 2 else down_weight
    up_2d = up_weight.flatten(1) if up_ndim > 2 else up_weight

    # Reconstruct the full weight delta.
    full_weight = up_2d @ down_2d
    out_features, in_features = full_weight.shape
    # No more than min(out, in) singular directions exist; the rest is zero padding.
    keep = min(new_dim, out_features, in_features)

    if method == 'svd':
        # Reduced SVD: only the leading components are used, so full U / Vh are wasted work.
        U, S, Vh = torch.linalg.svd(full_weight, full_matrices=False)
    elif method == 'randomized_svd':
        U, S, V = torch.svd_lowrank(full_weight, q=keep)
        Vh = V.T
    else:
        raise ValueError(f"Unknown resizing method: {method}")

    # Split each singular value evenly (square root) between the two factors.
    S_sqrt = torch.sqrt(S[:keep])
    new_up = full_weight.new_zeros((out_features, new_dim))
    new_down = full_weight.new_zeros((new_dim, in_features))
    new_up[:, :keep] = U[:, :keep] * S_sqrt
    new_down[:keep, :] = S_sqrt.unsqueeze(1) * Vh[:keep, :]

    # Restore the original tensor layouts (the up kernel is 1x1 in the trailing dims).
    if len(down_shape) > 2:
        new_down = new_down.reshape(new_dim, *down_shape[1:])
    if up_ndim > 2:
        new_up = new_up.reshape(out_features, new_dim, *([1] * (up_ndim - 2)))
    return new_down, new_up


def resize_lora_model(model_path, output_path, new_dim, device, method):
    """
    Resizes the LoRA dimension of a model using SVD or Randomized SVD.
    Also scales the alpha value(s) proportionally.

    Each module's ``up @ down`` product is re-factorised at rank ``new_dim``.
    If ``new_dim`` exceeds the number of available singular values the factors
    are zero-padded, so every resized module ends up with the same rank.
    Per-module alphas are scaled by ``new_dim / <that module's old rank>`` so
    that ``alpha / rank`` is unchanged. A module that cannot be resized keeps
    its original weights and alpha (a warning is printed) instead of being
    dropped from the output file. All other tensors are copied unchanged.

    Args:
        model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
        new_dim (int): The target new dimension for the LoRA weights.
        device (str): The device to run calculations on ('cuda' or 'cpu').
        method (str): The resizing method to use ('svd' or 'randomized_svd').

    Returns:
        None. Problems with the inputs are reported with a printed error
        message and nothing is written.
    """
    if new_dim < 1:
        print("Error: New dimension must be a positive integer.")
        return
    if method not in _RESIZE_METHODS:
        print(f"Error: Unknown resizing method '{method}'. Use 'svd' or 'randomized_svd'.")
        return

    print(f"Loading model from: {model_path}")
    # Load the model onto CPU memory first to avoid VRAM issues with large models
    model = load_file(model_path, device="cpu")
    new_model = {}

    # --- Metadata & Weight Inspection ---
    original_dim, alpha = _read_network_metadata(model_path)
    lora_pairs = _find_lora_pairs(model)

    # Infer original_dim from weights if not in metadata
    if original_dim is None and lora_pairs:
        original_dim = model[lora_pairs[0][1]].shape[0]
        print(f"Inferred original dimension from weights: {original_dim}")
    if original_dim is None or original_dim < 1:
        print("Error: Could not determine original LoRA dimension.")
        return
    if original_dim == new_dim:
        print("Error: New dimension is the same as the original dimension. No changes to make.")
        return

    # Infer alpha from weights if not in metadata
    if alpha is None:
        for key in model:
            if key.endswith(".alpha"):
                alpha = model[key].item()
                print(f"Inferred alpha from weights: {alpha}")
                break
    # Fallback for alpha if still not found
    if alpha is None:
        alpha = float(original_dim)
        print(f"Alpha not found, falling back to using dimension value: {alpha}")

    # --- Tensor Processing ---
    # Ratio used for the global (metadata) alpha
    ratio = new_dim / original_dim
    print(f"Dimension resize ratio: {ratio:.4f}")

    if not lora_pairs:
        print("Error: No LoRA weights found in the model.")
        return

    print(f"\nFound {len(lora_pairs)} LoRA modules to resize...")
    print(f"Using '{method}' method for resizing.")

    for base_key, down_key, up_key in tqdm(lora_pairs, desc="Resizing modules"):
        try:
            if base_key + ".lora_mid.weight" in model:
                # A middle factor means up @ down is not the full weight delta.
                raise ValueError("modules with a 'lora_mid' weight are not supported")

            down_weight = model[down_key]
            up_weight = model[up_key]
            module_rank = down_weight.shape[0]

            new_down, new_up = _resize_lora_pair(
                down_weight.to(device, dtype=torch.float32),
                up_weight.to(device, dtype=torch.float32),
                new_dim,
                method,
            )

            # Scale a per-module alpha by this module's own rank change.
            alpha_key = base_key + ".alpha"
            new_alpha_tensor = None
            if alpha_key in model:
                original_alpha_tensor = model[alpha_key]
                new_alpha_value = original_alpha_tensor.item() * new_dim / module_rank
                new_alpha_tensor = torch.tensor(new_alpha_value, dtype=original_alpha_tensor.dtype)
        except Exception as e:
            # Best effort per module: the original tensors are copied below.
            print(f"Warning: Failed to process {base_key}; keeping its original weights. Error: {e}")
            continue

        # Store on CPU right away so resized tensors do not pile up in device memory.
        new_model[down_key] = new_down.to(device="cpu", dtype=down_weight.dtype).contiguous()
        new_model[up_key] = new_up.to(device="cpu", dtype=up_weight.dtype).contiguous()
        if new_alpha_tensor is not None:
            new_model[alpha_key] = new_alpha_tensor

    # Copy everything that was not replaced above: non-LoRA tensors, alphas of
    # untouched modules, and the weights of modules that could not be resized.
    for key, value in model.items():
        new_model.setdefault(key, value)

    # Update metadata with the new dimension and the globally scaled alpha
    new_alpha = alpha * ratio
    new_metadata = {
        'ss_network_dim': str(new_dim),
        'ss_network_alpha': str(new_alpha),
    }
    print(f"\nNew global alpha scaled to: {new_alpha:.2f}")

    print(f"\nSaving resized model to: {output_path}")
    save_file(new_model, output_path, metadata=new_metadata)
    print("Done! 🎉")
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, choose a device, run the resize.
    method_help = (
        "Resizing method:\n"
        "'svd' (default): Accurate but slower. Uses full SVD for optimal weight preservation.\n"
        "'randomized_svd': Faster approximation of SVD. Excellent for speed on large models."
    )

    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Resize a LoRA model to a new dimension and scales alpha proportionally.",
    )
    # Positional arguments, in the order they are given on the command line.
    cli.add_argument("model_path", type=str, help="Path to the LoRA model (.safetensors).")
    cli.add_argument("output_path", type=str, help="Path to save the resized LoRA model.")
    cli.add_argument("new_dim", type=int, help="The new LoRA dimension (rank).")
    # Optional arguments.
    cli.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.",
    )
    cli.add_argument(
        "--method",
        type=str,
        choices=["svd", "randomized_svd"],
        default="svd",
        help=method_help,
    )
    options = cli.parse_args()

    # An explicit --device wins; otherwise use CUDA when torch reports it as available.
    device = options.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    resize_lora_model(
        options.model_path,
        options.output_path,
        options.new_dim,
        device,
        options.method,
    )