"""Extract a LoRA from the difference between two SafeTensors checkpoints via SVD."""

import argparse
import sys

import torch
from safetensors.torch import save_file, safe_open
from tqdm import tqdm


def normalize_key(key):
    """Strips the 'model.diffusion_model.' prefix from a key if it exists."""
    prefix = 'model.diffusion_model.'
    if key.startswith(prefix):
        return key[len(prefix):]
    return key


def get_torch_dtype(dtype_str: str):
    """Converts a string to a torch.dtype object."""
    if dtype_str == "fp32":
        return torch.float32
    if dtype_str == "fp16":
        return torch.float16
    if dtype_str == "bf16":
        return torch.bfloat16
    raise ValueError(f"Unsupported dtype: {dtype_str}")


def randomized_svd(matrix, rank, n_oversamples=10):
    """Performs Randomized SVD for a faster approximation."""
    max_rank = min(matrix.shape)
    if rank >= max_rank:
        rank = max_rank
        n_oversamples = 0

    target_rank = min(rank + n_oversamples, max_rank)

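    # Sketch the range of the matrix with a random Gaussian test matrix
    # (rank + oversamples columns), as in standard randomized SVD.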
    P = torch.randn((matrix.shape[1], target_rank), device=matrix.device, dtype=matrix.dtype)
    Y = matrix @ P

    # Orthonormalize the sample; computed in float32 since the QR/SVD kernels
    # generally do not support half precision.
    Q, _ = torch.linalg.qr(Y.float())

    # Project the matrix into the low-dimensional subspace and decompose the small matrix.
    B = Q.T @ matrix.float()
    U_b, S, Vh = torch.linalg.svd(B, full_matrices=False)
    U = Q @ U_b

    # Truncate to the requested rank.
    U = U[:, :rank]
    S = S[:rank]
    Vh = Vh[:rank, :]

    return U, S, Vh


def extract_and_svd_lora(args):
    """Main function to extract, decompose, and save the LoRA."""
    print(f"Loading base model A: {args.model_a}")
    print(f"Loading finetuned model B: {args.model_b}")
    print(f"Using decomposition method: {args.method}")

    lora_tensors = {}
    dtype = get_torch_dtype(args.precision)

    with safe_open(args.model_a, framework="pt", device="cpu") as f_a, \
         safe_open(args.model_b, framework="pt", device="cpu") as f_b:

        keys_a_original = set(f_a.keys())
        keys_b_original = set(f_b.keys())
        print(f"\nFound {len(keys_a_original)} keys in model A.")
        print(f"Found {len(keys_b_original)} keys in model B.")

        normalized_keys_a = {normalize_key(k): k for k in keys_a_original}
        normalized_keys_b = {normalize_key(k): k for k in keys_b_original}

        common_normalized_keys = set(normalized_keys_a.keys()).intersection(set(normalized_keys_b.keys()))
        print(f"Found {len(common_normalized_keys)} common keys after normalization.\n")

        processable_keys = {k for k in common_normalized_keys if
                            (k.endswith('.weight') or k.endswith('.bias')) and 'lora_' not in k}

        if not processable_keys:
            print("No common weight or bias keys found to process. Check if models are compatible.")
            sys.exit(1)

        print(f"Found {len(processable_keys)} common keys to process.")

        for norm_key in tqdm(sorted(list(processable_keys)), desc="Processing Layers"):
            try:
                original_key_a = normalized_keys_a[norm_key]
                original_key_b = normalized_keys_b[norm_key]

                tensor_a = f_a.get_tensor(original_key_a).to(device=args.device, dtype=dtype)
                tensor_b = f_b.get_tensor(original_key_b).to(device=args.device, dtype=dtype)

                if tensor_a.shape != tensor_b.shape:
                    tqdm.write(f"Skipping key {norm_key} due to shape mismatch")
                    continue

                delta = tensor_b - tensor_a

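                # Only weight matrices are decomposed below; bias deltas are computed
                # but currently not written to the output file.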
                if norm_key.endswith('.weight'):
                    delta_w = delta
                    if delta_w.dim() < 2:
                        tqdm.write(f"Skipping weight key {norm_key} as it's not a 2D matrix.")
                        continue
                    if delta_w.dim() > 2:
                        # Flatten higher-dimensional weights (e.g. conv kernels) into a 2D matrix for SVD.
                        delta_w = delta_w.view(delta_w.shape[0], -1)

                    if args.method == 'rsvd':
                        U, S, Vh = randomized_svd(delta_w, args.rank, n_oversamples=args.oversamples)
                    else:
                        # Full SVD, computed in float32 so half-precision inputs also work,
                        # then truncated to the target rank.
                        U, S, Vh = torch.linalg.svd(delta_w.float(), full_matrices=False)
                        current_rank = min(args.rank, S.size(0))
                        U = U[:, :current_rank]
                        S = S[:current_rank]
                        Vh = Vh[:current_rank, :]
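
                    # delta_w ≈ lora_up @ lora_down: the singular values are folded into the up matrix.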
                    lora_down = Vh
                    lora_up = U @ torch.diag(S)

                    base_name = norm_key.replace('.weight', '')
                    prefixed_base_name = f"diffusion_model.{base_name}"
                    lora_down_name = f"{prefixed_base_name}.lora_down.weight"
                    lora_up_name = f"{prefixed_base_name}.lora_up.weight"

                    # Store the factors in float32 regardless of the compute precision.
                    lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
                    lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)

            except Exception as e:
                tqdm.write(f"Failed to process key {norm_key}: {e}")

    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return

    print(f"\nSaving {len(lora_tensors)} tensors to {args.output}...")
    save_file(lora_tensors, args.output)
    print("✅ Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract a LoRA/LoRA+ from two SafeTensors checkpoints.")
    parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint.")
    parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint.")
    parser.add_argument("output", type=str, help="Path to save the output file.")

    parser.add_argument("--rank", type=int, required=True, help="The target rank for the decomposition.")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="Informational alpha value for scaling. This value is NOT saved in the file.")
    parser.add_argument("--method", type=str, default="rsvd", choices=["svd", "rsvd"], help="Decomposition method.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                        help="Device to use for computation.")
    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
                        help="Precision for calculations.")
    parser.add_argument("--oversamples", type=int, default=10,
                        help="Oversampling parameter for randomized SVD; higher values improve accuracy.")

    args = parser.parse_args()

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"

    extract_and_svd_lora(args)
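
# Example invocation (script name and paths are illustrative):
#   python extract_lora.py base.safetensors finetuned.safetensors extracted_lora.safetensors \
#       --rank 32 --method rsvd --precision fp32 --device cuda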