# python-script-dump / extract_lora.py
# Uploaded by anyMODE (commit c66ab51, "Upload extract_lora.py", verified; 6.66 kB).
import torch
import argparse
from safetensors.torch import save_file, safe_open
from tqdm import tqdm
import sys
def normalize_key(key):
    """Return *key* with a leading 'model.diffusion_model.' namespace removed.

    Keys that do not start with that exact prefix are returned unchanged, so
    checkpoints with and without the wrapper namespace map to the same names.
    """
    diffusion_prefix = 'model.diffusion_model.'
    if not key.startswith(diffusion_prefix):
        return key
    return key[len(diffusion_prefix):]
def get_torch_dtype(dtype_str: str):
    """Translate a precision name into the matching ``torch.dtype``.

    Recognised names are 'fp32', 'fp16' and 'bf16'.

    Raises:
        ValueError: if *dtype_str* is not one of the recognised names.
    """
    supported = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    for name, torch_dtype in supported:
        if dtype_str == name:
            return torch_dtype
    raise ValueError(f"Unsupported dtype: {dtype_str}")
def randomized_svd(matrix, rank, n_oversamples=10, n_power_iters=0, generator=None):
    """Approximate the leading ``rank`` singular triplets of a 2-D matrix.

    Implements the randomized range finder: multiply ``matrix`` by a random
    Gaussian test matrix, orthonormalize the product with QR, then take an
    exact SVD of the small projected matrix ``Q.T @ matrix``.

    All arithmetic is done in float32. ``torch.linalg.qr`` / ``torch.linalg.svd``
    only accept float/double inputs, and fp16/bf16 matmuls may be unsupported
    or overflow, so half-precision inputs are promoted once up front.

    Args:
        matrix: 2-D tensor of shape (m, n), any floating dtype.
        rank: Number of singular triplets to return (>= 1). It is clamped to
            min(m, n); in that case no oversampling is used.
        n_oversamples: Extra random directions sampled beyond ``rank`` to
            improve accuracy (>= 0).
        n_power_iters: Number of power (subspace) iterations (>= 0). They
            sharpen the range estimate when the spectrum decays slowly.
            0 reproduces the plain single-pass sketch.
        generator: Optional ``torch.Generator`` for reproducible sketches.
            It must be on the same device as ``matrix``.

    Returns:
        (U, S, Vh) float32 tensors with shapes (m, k), (k,), (k, n), where
        k = min(rank, m, n).

    Raises:
        ValueError: if ``matrix`` is not 2-D, ``rank`` < 1, or
            ``n_oversamples`` / ``n_power_iters`` is negative.
    """
    if matrix.dim() != 2:
        raise ValueError(f"randomized_svd expects a 2-D matrix, got shape {tuple(matrix.shape)}")
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")
    if n_oversamples < 0 or n_power_iters < 0:
        raise ValueError("n_oversamples and n_power_iters must be non-negative")

    a = matrix.float()
    max_rank = min(a.shape)
    if rank >= max_rank:
        # The sketch already spans the whole matrix; oversampling is pointless.
        rank = max_rank
        n_oversamples = 0
    target_rank = min(rank + n_oversamples, max_rank)

    omega = torch.randn(
        (a.shape[1], target_rank), device=a.device, dtype=a.dtype, generator=generator
    )
    q, _ = torch.linalg.qr(a @ omega)
    for _ in range(n_power_iters):
        # Re-orthonormalize after each multiplication; without it the basis
        # collapses onto the dominant singular vector in finite precision.
        z, _ = torch.linalg.qr(a.T @ q)
        q, _ = torch.linalg.qr(a @ z)

    b = q.T @ a
    u_b, s, vh = torch.linalg.svd(b, full_matrices=False)
    u = q @ u_b
    return u[:, :rank], s[:rank], vh[:rank, :]
def _lora_factors(delta_w, args):
    """Factor a weight delta (>= 2-D) into (lora_down, lora_up) of rank <= args.rank.

    Dimensions after the first are flattened, so a delta of shape
    (out, d1, d2, ...) is decomposed as an (out, d1*d2*...) matrix. The
    returned factors satisfy lora_up @ lora_down ~= that flattened delta.
    """
    if delta_w.dim() > 2:
        delta_w = delta_w.reshape(delta_w.shape[0], -1)
    # torch.linalg.svd/qr only accept float/double inputs, so decompose in
    # float32 regardless of --precision (fp16/bf16 would fail on every layer).
    delta_w = delta_w.float()
    if args.method == 'rsvd':
        U, S, Vh = randomized_svd(delta_w, args.rank, n_oversamples=args.oversamples)
    else:
        U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
    current_rank = min(args.rank, S.size(0))
    lora_down = Vh[:current_rank, :]
    # Fold the singular values into the up projection (U * S == U @ diag(S))
    # so that lora_up @ lora_down reconstructs the truncated delta directly.
    lora_up = U[:, :current_rank] * S[:current_rank]
    return lora_down, lora_up


def extract_and_svd_lora(args):
    """Extract a LoRA from the weight differences between two checkpoints and save it.

    Keys are matched after stripping the 'model.diffusion_model.' prefix.
    For every common '.weight' key (excluding keys containing 'lora_'), the
    delta B - A is decomposed with SVD or randomized SVD and truncated to
    ``args.rank``. The factors are stored as float32 tensors named
    ``diffusion_model.<name>.lora_down.weight`` and
    ``diffusion_model.<name>.lora_up.weight`` in ``args.output``.

    Not written to the output:
    - bias deltas;
    - 1-D weights;
    - keys whose shapes differ between the models;
    - keys that raise during processing (those are reported and skipped).

    Args:
        args: Parsed CLI namespace. It must provide model_a, model_b, output,
            rank, method ('svd' | 'rsvd'), device, precision and oversamples.

    Exits with status 1 when the models share no weight/bias keys.
    """
    print(f"Loading base model A: {args.model_a}")
    print(f"Loading finetuned model B: {args.model_b}")
    print(f"Using decomposition method: {args.method}")
    lora_tensors = {}
    dtype = get_torch_dtype(args.precision)
    with safe_open(args.model_a, framework="pt", device="cpu") as f_a, \
            safe_open(args.model_b, framework="pt", device="cpu") as f_b:
        keys_a_original = set(f_a.keys())
        keys_b_original = set(f_b.keys())
        print(f"\nFound {len(keys_a_original)} keys in model A.")
        print(f"Found {len(keys_b_original)} keys in model B.")
        normalized_keys_a = {normalize_key(k): k for k in keys_a_original}
        normalized_keys_b = {normalize_key(k): k for k in keys_b_original}
        common_normalized_keys = normalized_keys_a.keys() & normalized_keys_b.keys()
        print(f"Found {len(common_normalized_keys)} common keys after normalization.\n")
        processable_keys = {
            k for k in common_normalized_keys
            if k.endswith(('.weight', '.bias')) and 'lora_' not in k
        }
        if not processable_keys:
            print("No common weight or bias keys found to process. Check if models are compatible.")
            sys.exit(1)
        print(f"Found {len(processable_keys)} common keys to process.")

        # Only weight matrices can be expressed as lora_up @ lora_down. Bias
        # deltas were previously loaded and subtracted, then thrown away, so
        # skip them before any tensor is read.
        weight_keys = sorted(k for k in processable_keys if k.endswith('.weight'))
        skipped_bias = len(processable_keys) - len(weight_keys)
        if skipped_bias:
            print(f"Ignoring {skipped_bias} bias keys (bias deltas are not stored in the LoRA).")

        for norm_key in tqdm(weight_keys, desc="Processing Layers"):
            try:
                tensor_a = f_a.get_tensor(normalized_keys_a[norm_key]).to(device=args.device, dtype=dtype)
                tensor_b = f_b.get_tensor(normalized_keys_b[norm_key]).to(device=args.device, dtype=dtype)
                if tensor_a.shape != tensor_b.shape:
                    tqdm.write(f"Skipping key {norm_key} due to shape mismatch")
                    continue
                delta_w = tensor_b - tensor_a
                if delta_w.dim() < 2:
                    tqdm.write(f"Skipping weight key {norm_key} as it's not a 2D matrix.")
                    continue
                lora_down, lora_up = _lora_factors(delta_w, args)
                # Strip only the trailing '.weight'; str.replace would also
                # mangle inner segments such as 'attn.weight_proj.weight'.
                base_name = norm_key[:-len('.weight')]
                prefixed_base_name = f"diffusion_model.{base_name}"
                lora_down_name = f"{prefixed_base_name}.lora_down.weight"
                lora_up_name = f"{prefixed_base_name}.lora_up.weight"
                lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
                lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)
            except Exception as e:
                # Best-effort: one bad layer should not abort the whole extraction.
                tqdm.write(f"Failed to process key {norm_key}: {e}")
    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return
    print(f"\nSaving {len(lora_tensors)} tensors to {args.output}...")
    save_file(lora_tensors, args.output)
    print("✅ Done!")
if __name__ == "__main__":
    # Command-line entry point: parse options, validate them, pick a device,
    # then run the extraction.
    parser = argparse.ArgumentParser(description="Extract a LoRA/LoRA+ from two SafeTensors checkpoints.")
    parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint.")
    parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint.")
    parser.add_argument("output", type=str, help="Path to save the output file.")
    parser.add_argument("--rank", type=int, required=True, help="The target rank for the decomposition.")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="Informational alpha value for scaling. This value is NOT saved in the file.")
    parser.add_argument("--method", type=str, default="rsvd", choices=["svd", "rsvd"], help="Decomposition method.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                        help="Device to use for computation.")
    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
                        help="Precision for calculations.")
    parser.add_argument("--oversamples", type=int, default=10,
                        help="Oversampling parameter for Randomized SVD for better accuracy.")
    args = parser.parse_args()
    # Reject values that would otherwise make every layer fail (or silently
    # shrink the rank) deep inside the per-key processing loop.
    if args.rank < 1:
        parser.error("--rank must be a positive integer.")
    if args.oversamples < 0:
        parser.error("--oversamples must be non-negative.")
    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"
    extract_and_svd_lora(args)