File size: 6,660 Bytes
faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de faa1b64 138d9de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import torch
import argparse
from safetensors.torch import save_file, safe_open
from tqdm import tqdm
import sys
def normalize_key(key):
    """Return ``key`` with a leading 'model.diffusion_model.' removed.

    Only a prefix at the very start of the key is dropped; keys that do not
    begin with it come back unchanged.
    """
    marker = 'model.diffusion_model.'
    # Offset 0 keeps the whole key, so one slice covers both cases.
    cut = len(marker) if key.startswith(marker) else 0
    return key[cut:]
def get_torch_dtype(dtype_str: str):
    """Translate a precision name into the matching ``torch.dtype``.

    Recognised names are "fp32", "fp16" and "bf16".

    Raises:
        ValueError: if ``dtype_str`` is none of the recognised names.
    """
    supported = (
        ("fp32", torch.float32),
        ("fp16", torch.float16),
        ("bf16", torch.bfloat16),
    )
    for name, torch_dtype in supported:
        if dtype_str == name:
            return torch_dtype
    raise ValueError(f"Unsupported dtype: {dtype_str}")
def randomized_svd(matrix, rank, n_oversamples=10):
    """Approximate the leading ``rank`` singular triplets of a 2-D tensor.

    The matrix is multiplied by a Gaussian test matrix with
    ``rank + n_oversamples`` columns (capped at ``min(matrix.shape)``), the
    product is orthonormalised with QR, and an exact SVD is taken of the
    small projected matrix.

    The input is upcast to float32 once and every step runs in float32, so
    a half/bfloat16 input does not lose precision in the random projection.

    Args:
        matrix: 2-D tensor to decompose.
        rank: number of singular triplets to return; values larger than
            ``min(matrix.shape)`` are clamped to it.
        n_oversamples: extra random directions sampled on top of ``rank``.

    Returns:
        Tuple ``(U, S, Vh)`` of float32 tensors with shapes ``(m, r)``,
        ``(r,)`` and ``(r, n)``, where ``r = min(rank, min(matrix.shape))``.

    Raises:
        ValueError: if ``matrix`` is not 2-D, ``rank`` is below 1, or
            ``n_oversamples`` is negative.
    """
    if matrix.dim() != 2:
        raise ValueError(f"matrix must be 2-D, got {matrix.dim()} dimensions")
    # Negative values would otherwise turn into negative slice bounds below
    # and silently return factors of the wrong size.
    if rank < 1:
        raise ValueError(f"rank must be >= 1, got {rank}")
    if n_oversamples < 0:
        raise ValueError(f"n_oversamples must be >= 0, got {n_oversamples}")

    work = matrix.float()
    max_rank = min(work.shape)
    rank = min(rank, max_rank)
    # The sketch can never be wider than the matrix's smaller dimension, so
    # a full-rank request leaves no room for oversampling.
    target_rank = min(rank + n_oversamples, max_rank)

    test_matrix = torch.randn(
        (work.shape[1], target_rank), device=work.device, dtype=work.dtype
    )
    Q, _ = torch.linalg.qr(work @ test_matrix)
    B = Q.T @ work
    U_b, S, Vh = torch.linalg.svd(B, full_matrices=False)
    U = Q @ U_b
    return U[:, :rank], S[:rank], Vh[:rank, :]
def extract_and_svd_lora(args):
    """Extract a low-rank approximation of (model B - model A) and save it.

    Keys of both checkpoints are matched after ``normalize_key``. For every
    common ``.weight`` tensor with at least two dimensions, the difference
    B - A is flattened to 2-D, decomposed with the selected method and
    truncated to ``args.rank``. The factors are stored as
    ``diffusion_model.<layer>.lora_down.weight`` (``Vh``) and
    ``diffusion_model.<layer>.lora_up.weight`` (``U`` scaled by the singular
    values), as float32 CPU tensors. No alpha value is written.

    Args:
        args: namespace providing ``model_a``, ``model_b``, ``output``,
            ``rank``, ``method`` ('svd' or 'rsvd'), ``oversamples``,
            ``device`` and ``precision``.

    Raises:
        ValueError: if ``args.rank`` is below 1 or ``args.precision`` is
            not supported.
        SystemExit: if the two checkpoints share no weight or bias keys.
    """
    # A non-positive rank would become a negative/empty slice bound below and
    # silently produce factors of the wrong size.
    if args.rank < 1:
        raise ValueError(f"rank must be a positive integer, got {args.rank}")

    print(f"Loading base model A: {args.model_a}")
    print(f"Loading finetuned model B: {args.model_b}")
    print(f"Using decomposition method: {args.method}")

    weight_suffix = '.weight'
    lora_tensors = {}
    dtype = get_torch_dtype(args.precision)

    with safe_open(args.model_a, framework="pt", device="cpu") as f_a, \
            safe_open(args.model_b, framework="pt", device="cpu") as f_b:
        keys_a_original = set(f_a.keys())
        keys_b_original = set(f_b.keys())
        print(f"\nFound {len(keys_a_original)} keys in model A.")
        print(f"Found {len(keys_b_original)} keys in model B.")

        # Map normalized name -> name as stored in each file.
        normalized_keys_a = {normalize_key(k): k for k in keys_a_original}
        normalized_keys_b = {normalize_key(k): k for k in keys_b_original}
        common_normalized_keys = normalized_keys_a.keys() & normalized_keys_b.keys()
        print(f"Found {len(common_normalized_keys)} common keys after normalization.\n")

        processable_keys = {
            k for k in common_normalized_keys
            if k.endswith((weight_suffix, '.bias')) and 'lora_' not in k
        }
        if not processable_keys:
            print("No common weight or bias keys found to process. Check if models are compatible.")
            sys.exit(1)
        print(f"Found {len(processable_keys)} common keys to process.")

        for norm_key in tqdm(sorted(processable_keys), desc="Processing Layers"):
            # Only weights are decomposed. Biases are counted above but
            # produce no output, so skip them before reading any tensor.
            if not norm_key.endswith(weight_suffix):
                continue
            try:
                tensor_a = f_a.get_tensor(normalized_keys_a[norm_key]).to(
                    device=args.device, dtype=dtype)
                tensor_b = f_b.get_tensor(normalized_keys_b[norm_key]).to(
                    device=args.device, dtype=dtype)
                if tensor_a.shape != tensor_b.shape:
                    tqdm.write(f"Skipping key {norm_key} due to shape mismatch")
                    continue

                delta_w = tensor_b - tensor_a
                if delta_w.dim() < 2:
                    tqdm.write(f"Skipping weight key {norm_key} as it's not a 2D matrix.")
                    continue
                if delta_w.dim() > 2:
                    # Keep the first dimension and fold all others into columns.
                    delta_w = delta_w.flatten(start_dim=1)

                # The difference is taken in the requested precision, but the
                # decomposition runs in float32: torch.linalg.svd does not
                # accept half/bfloat16 input.
                delta_w = delta_w.float()

                if args.method == 'rsvd':
                    U, S, Vh = randomized_svd(
                        delta_w, args.rank, n_oversamples=args.oversamples)
                else:
                    U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)

                current_rank = min(args.rank, S.size(0))
                U = U[:, :current_rank]
                S = S[:current_rank]
                Vh = Vh[:current_rank, :]

                lora_down = Vh
                # Scale column i of U by S[i]; same result as U @ diag(S).
                lora_up = U * S.unsqueeze(0)

                # Strip only the trailing '.weight'; str.replace would also
                # remove the substring from the middle of a key.
                base_name = norm_key[:-len(weight_suffix)]
                prefixed_base_name = f"diffusion_model.{base_name}"
                lora_down_name = f"{prefixed_base_name}.lora_down.weight"
                lora_up_name = f"{prefixed_base_name}.lora_up.weight"
                lora_tensors[lora_down_name] = lora_down.contiguous().cpu().to(torch.float32)
                lora_tensors[lora_up_name] = lora_up.contiguous().cpu().to(torch.float32)
            except Exception as e:
                # Best effort: one failing layer is reported and the rest
                # are still processed.
                tqdm.write(f"Failed to process key {norm_key}: {e}")

    if not lora_tensors:
        print("No tensors were processed. Output file will not be created.")
        return

    print(f"\nSaving {len(lora_tensors)} tensors to {args.output}...")
    save_file(lora_tensors, args.output)
    print("✅ Done!")
# Command-line entry point: parse options, validate them, pick the device
# and run the extraction.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract a LoRA/LoRA+ from two SafeTensors checkpoints.")
    parser.add_argument("model_a", type=str, help="Path to the base model (A) checkpoint.")
    parser.add_argument("model_b", type=str, help="Path to the finetuned model (B) checkpoint.")
    parser.add_argument("output", type=str, help="Path to save the output file.")
    parser.add_argument("--rank", type=int, required=True, help="The target rank for the decomposition.")
    parser.add_argument("--alpha", type=float, default=1.0,
                        help="Informational alpha value for scaling. This value is NOT saved in the file.")
    parser.add_argument("--method", type=str, default="rsvd", choices=["svd", "rsvd"], help="Decomposition method.")
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"],
                        help="Device to use for computation.")
    parser.add_argument("--precision", type=str, default="fp32", choices=["fp32", "fp16", "bf16"],
                        help="Precision for calculations.")
    parser.add_argument("--oversamples", type=int, default=10,
                        help="Oversampling parameter for Randomized SVD for better accuracy.")
    args = parser.parse_args()

    # Reject values the decomposition cannot use: a rank below 1 or negative
    # oversampling ends up as a negative or empty slice/tensor size downstream.
    if args.rank < 1:
        parser.error(f"--rank must be a positive integer (got {args.rank}).")
    if args.oversamples < 0:
        parser.error(f"--oversamples must be zero or greater (got {args.oversamples}).")

    if args.device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        args.device = "cpu"

    extract_and_svd_lora(args)