|
|
import os |
|
|
import argparse |
|
|
import torch |
|
|
from safetensors.torch import load_file, save_file |
|
|
from safetensors import safe_open |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
def _resize_lora_pair(down_weight, up_weight, new_dim):
    """Re-factorize one LoRA down/up pair at a new rank using truncated SVD.

    The product ``up @ down`` is rebuilt in float32, decomposed, and only the
    largest singular components are kept. The singular values are folded into
    the new down matrix; the new up matrix is the truncated ``U``.

    Args:
        down_weight (torch.Tensor): Down/A weight, 2-D ``(rank, in)`` or
            4-D ``(rank, in, kh, kw)``.
        up_weight (torch.Tensor): Up/B weight, ``(out, rank)`` or 4-D with
            trailing singleton dimensions.
        new_dim (int): Requested rank (must be >= 1).

    Returns:
        tuple: ``(new_down, new_up, rank)`` where ``rank`` is the rank that was
        actually used, i.e. ``new_dim`` clamped to the number of singular
        values available for this module. Both tensors are returned in the
        dtype of ``up_weight`` and on the device of the inputs.

    Raises:
        ValueError: If the two weights do not share an inner (rank) dimension
            once flattened to 2-D.
    """
    original_dtype = up_weight.dtype
    # Capture the shape BEFORE flattening so a conv kernel size can be restored.
    down_shape = tuple(down_weight.shape)
    is_conv = down_weight.ndim == 4

    # flatten(1) is a no-op for 2-D weights; float32 keeps the matmul and SVD
    # numerically stable when the checkpoint is stored in half precision.
    down_2d = down_weight.flatten(1).to(torch.float32)
    up_2d = up_weight.flatten(1).to(torch.float32)
    if up_2d.shape[1] != down_2d.shape[0]:
        raise ValueError(
            f"Incompatible LoRA pair: up {tuple(up_weight.shape)} vs "
            f"down {down_shape}"
        )

    full_weight = up_2d @ down_2d

    # Reduced SVD: only the first min(out, in) components exist anyway, and the
    # full square U / Vh would be sliced away immediately.
    U, S, Vh = torch.linalg.svd(full_weight, full_matrices=False)

    # A module cannot hold more components than it has singular values.
    rank = min(new_dim, S.shape[0])
    new_up = U[:, :rank]
    new_down = S[:rank].unsqueeze(1) * Vh[:rank, :]  # == diag(S) @ Vh

    if is_conv:
        # Restore the original input-channel / kernel layout of the down conv.
        new_down = new_down.reshape(rank, *down_shape[1:])
        new_up = new_up.reshape(up_weight.shape[0], rank, 1, 1)

    return new_down.to(original_dtype), new_up.to(original_dtype), rank


def resize_lora_model(model_path, output_path, new_dim, device):
    """Resize the rank of every LoRA module in a safetensors file using SVD.

    Each ``lora_down``/``lora_up`` (or ``lora_A``/``lora_B``) pair is
    multiplied out, decomposed, and re-factorized at the new rank. Per-module
    ``.alpha`` tensors and the ``ss_network_alpha`` metadata entry are scaled
    by ``new_rank / old_rank`` so that the ``alpha / rank`` multiplier, and
    therefore the strength of the applied LoRA, is unchanged.

    Args:
        model_path (str): Path to the LoRA model to resize.
        output_path (str): Path to save the new resized model.
        new_dim (int): The target new dimension (rank) for the LoRA weights.
            Modules that cannot hold that many components are clamped to
            their maximum possible rank.
        device (str): The device to run calculations on ('cuda' or 'cpu').

    Returns:
        None. The resized model is written to ``output_path``; if the input
        contains no LoRA weights an error is printed and nothing is written.

    Raises:
        ValueError: If ``new_dim`` is smaller than 1, or a down/up pair has
            incompatible shapes.
    """
    if new_dim < 1:
        raise ValueError(f"new_dim must be a positive integer, got {new_dim}")

    print(f"Loading model from: {model_path}")
    model = load_file(model_path)
    new_model = {}

    # --- Original dim / alpha: prefer the file metadata (best effort). ---
    original_dim = None
    alpha = None
    try:
        with safe_open(model_path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
        if metadata:
            if 'ss_network_dim' in metadata:
                original_dim = int(metadata['ss_network_dim'])
                print(f"Original dimension (from metadata): {original_dim}")
            if 'ss_network_alpha' in metadata:
                alpha = float(metadata['ss_network_alpha'])
                print(f"Original alpha (from metadata): {alpha}")
    except Exception as e:
        # Metadata is optional; fall through to inference from the tensors.
        print(f"Could not read metadata: {e}. Dimension and alpha will be inferred.")

    down_suffixes = (".lora_down.weight", ".lora_A.weight")

    if original_dim is None:
        for key in model:
            if key.endswith(down_suffixes):
                original_dim = model[key].shape[0]
                print(f"Inferred original dimension from weights: {original_dim}")
                break

    if alpha is None:
        for key in model:
            if key.endswith(".alpha"):
                alpha = model[key].item()
                print(f"Inferred alpha from weights: {alpha}")
                break

    if alpha is None and original_dim is not None:
        alpha = float(original_dim)
        print(f"Alpha not found, falling back to using dimension: {alpha}")

    # --- Collect modules: base key -> (down key, expected up key). ---
    suffix_pairs = (
        (".lora_down.weight", ".lora_up.weight"),
        (".lora_A.weight", ".lora_B.weight"),
    )
    modules = {}
    for key in model:
        for down_suffix, up_suffix in suffix_pairs:
            if key.endswith(down_suffix):
                base_key = key[: -len(down_suffix)]
                modules[base_key] = (key, base_key + up_suffix)
                break

    if not modules:
        print("Error: No LoRA weights found in the model.")
        return

    print(f"\nFound {len(modules)} LoRA modules to resize...")

    for base_key in tqdm(sorted(modules), desc="Resizing modules"):
        down_key, up_key = modules[base_key]
        if up_key not in model:
            # An unpaired down weight cannot be resized; report it instead of
            # dropping it silently. tqdm.write keeps the progress bar intact.
            tqdm.write(
                f"Warning: {down_key} has no matching {up_key}; "
                "module skipped and not written to the output."
            )
            continue

        down_weight = model[down_key].to(device)
        up_weight = model[up_key].to(device)
        old_rank = down_weight.shape[0]

        new_down, new_up, rank = _resize_lora_pair(down_weight, up_weight, new_dim)

        new_model[down_key] = new_down.contiguous().cpu()
        new_model[up_key] = new_up.contiguous().cpu()

        alpha_key = base_key + ".alpha"
        if alpha_key in model and old_rank > 0:
            # Keep alpha / rank constant so the effective LoRA scale is
            # preserved at the new rank.
            old_alpha = model[alpha_key]
            out_dtype = old_alpha.dtype if old_alpha.is_floating_point() else torch.float32
            scaled_alpha = old_alpha.to(torch.float32) * (rank / old_rank)
            new_model[alpha_key] = scaled_alpha.to(out_dtype)

    # Carry over every non-LoRA tensor, without overwriting the alpha values
    # that were rescaled above (their keys do not contain ".lora_").
    for key, value in model.items():
        if ".lora_" not in key and key not in new_model:
            new_model[key] = value

    new_metadata = {'ss_network_dim': str(new_dim)}
    if alpha is not None and original_dim is not None and original_dim > 0:
        new_alpha = alpha * (new_dim / original_dim)
        new_metadata['ss_network_alpha'] = str(new_alpha)
        print(f"\nNew alpha scaled to: {new_alpha:.2f}")

    print(f"\nSaving resized model to: {output_path}")
    save_file(new_model, output_path, metadata=new_metadata)
    print("Done!")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the three positional arguments plus the
    # optional device override, then run the resize.
    parser = argparse.ArgumentParser(
        description="Resize a LoRA model to a new dimension using SVD.",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    positional_specs = (
        ("model_path", str, "Path to the LoRA model (.safetensors)."),
        ("output_path", str, "Path to save the resized LoRA model."),
        ("new_dim", int, "The new LoRA dimension (rank)."),
    )
    for arg_name, arg_type, arg_help in positional_specs:
        parser.add_argument(arg_name, type=arg_type, help=arg_help)

    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). Autodetects if not specified.",
    )

    cli_args = parser.parse_args()

    # An explicit --device wins; otherwise use CUDA when it is available.
    device = cli_args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    resize_lora_model(
        cli_args.model_path,
        cli_args.output_path,
        cli_args.new_dim,
        device,
    )
|
|
|
|
|
|