File size: 5,209 Bytes
bf1fba0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
import argparse
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import os
def slerp(t1, t2, alpha):
    """
    Performs Spherical Linear Interpolation (SLERP) between two tensors.

    The tensors are flattened and treated as high-dimensional vectors; the
    result is reshaped to t1's shape and cast back to t1's dtype.

    Args:
        t1: First tensor (its shape and dtype are used for the result).
        t2: Second tensor; must have the same number of elements as t1.
        alpha: Interpolation factor. 0.0 yields t1, 1.0 yields t2
            (up to floating-point precision).

    Returns:
        The interpolated tensor with the same shape and dtype as t1.
    """
    # Work in float32 for numerically stable trig regardless of input dtype.
    t1_flat = t1.float().flatten()
    t2_flat = t2.float().flatten()
    norm_product = torch.linalg.norm(t1_flat) * torch.linalg.norm(t2_flat)
    # A zero-norm tensor has no direction, so the angle (and SLERP) is
    # undefined; without this guard the division below yields NaN.
    if norm_product == 0:
        return torch.lerp(t1, t2, alpha)
    # Cosine of the angle between the two vectors, clamped to [-1, 1] so
    # acos never sees a value pushed out of range by rounding error.
    dot = torch.clamp(torch.sum(t1_flat * t2_flat) / norm_product, -1.0, 1.0)
    theta = torch.acos(dot)
    sin_theta = torch.sin(theta)
    # Fall back to LERP when the vectors are nearly parallel (theta ~ 0) or
    # nearly antipodal (theta ~ pi): in both cases sin(theta) ~ 0 and the
    # SLERP formula divides by ~0. For parallel vectors LERP is an accurate
    # approximation; for antipodal ones SLERP is inherently ill-defined
    # (infinitely many great-circle paths), so LERP is the stable choice.
    if theta < 1e-4 or torch.abs(sin_theta) < 1e-6:
        return torch.lerp(t1, t2, alpha)
    # Standard SLERP weights.
    factor1 = torch.sin((1.0 - alpha) * theta) / sin_theta
    factor2 = torch.sin(alpha * theta) / sin_theta
    interp_flat = factor1 * t1_flat + factor2 * t2_flat
    # Restore the caller's shape and dtype.
    return interp_flat.reshape(t1.shape).to(t1.dtype)
def lerp(t1, t2, alpha):
    """
    Linearly interpolate between two tensors.

    Equivalent to t1 + alpha * (t2 - t1): alpha=0.0 yields t1 and
    alpha=1.0 yields t2.
    """
    # Tensor.lerp is the method form of torch.lerp; behavior is identical.
    return t1.lerp(t2, alpha)
def main():
    """
    CLI entry point: merge two .safetensors checkpoints into one file.

    Tensors present in both models are interpolated (LERP or SLERP, chosen
    by --method) with weight --alpha; tensors unique to either model are
    copied through unchanged. Layers with mismatched shapes or
    non-floating-point dtypes are copied from model A instead of merged.
    """
    parser = argparse.ArgumentParser(description="Merge two Safetensor models using either Linear (LERP) or Spherical (SLERP) interpolation.")
    parser.add_argument("model_a", type=str, help="Path to the first model (A).")
    parser.add_argument("model_b", type=str, help="Path to the second model (B).")
    parser.add_argument("output", type=str, help="Path to save the merged model.")
    parser.add_argument("--alpha", type=float, default=0.5, help="Interpolation factor (alpha). 0.0 is 100%% model A, 1.0 is 100%% model B. Default is 0.5.")
    parser.add_argument("--method", type=str, default="lerp", choices=["lerp", "slerp"], help="Merge method to use: 'lerp' (linear) or 'slerp' (spherical). Default is 'lerp'.")
    args = parser.parse_args()

    # Fail early with a clear message instead of a load_file traceback.
    if not os.path.exists(args.model_a):
        print(f"Error: Model file not found at {args.model_a}")
        return
    if not os.path.exists(args.model_b):
        print(f"Error: Model file not found at {args.model_b}")
        return

    print(f"Loading model A from: {args.model_a}")
    tensors_a = load_file(args.model_a)
    print(f"Loading model B from: {args.model_b}")
    tensors_b = load_file(args.model_b)

    merged_tensors = {}

    # Partition the key sets: common layers get merged, unique layers are
    # carried over from whichever model has them.
    keys_a = set(tensors_a.keys())
    keys_b = set(tensors_b.keys())
    common_keys = keys_a.intersection(keys_b)
    keys_only_in_a = keys_a - keys_b
    keys_only_in_b = keys_b - keys_a

    print(f"\nFound {len(keys_a)} keys in {args.model_a}.")
    print(f"Found {len(keys_b)} keys in {args.model_b}.")
    print(f"-> Found {len(common_keys)} common keys.")
    print(f"-> Found {len(keys_only_in_a)} keys unique to model A.")
    print(f"-> Found {len(keys_only_in_b)} keys unique to model B.\n")

    if not common_keys and not keys_only_in_a and not keys_only_in_b:
        print("Warning: No tensors found to merge or copy. The output file will be empty.")
        save_file({}, args.output)
        print("Operation complete.")
        return

    print(f"Merging {len(common_keys)} common layers with alpha={args.alpha} using {args.method.upper()}...")
    # Iterate in sorted order so progress, warnings, and output insertion
    # order are reproducible; set iteration order is not deterministic.
    for key in tqdm(sorted(common_keys), desc="Merging common layers"):
        tensor_a = tensors_a[key]
        tensor_b = tensors_b[key]
        if tensor_a.shape != tensor_b.shape:
            print(f"Warning: Skipping layer '{key}' due to shape mismatch: {tensor_a.shape} vs {tensor_b.shape}")
            merged_tensors[key] = tensor_a
            continue
        # Integer/bool tensors (e.g. index buffers, step counters) cannot be
        # meaningfully interpolated; keep model A's copy. Checking both
        # tensors avoids a crash when only B is non-floating-point.
        if not tensor_a.is_floating_point() or not tensor_b.is_floating_point():
            print(f"Warning: Skipping merge for non-floating point tensor '{key}' (dtype: {tensor_a.dtype}). Copying from model A.")
            merged_tensors[key] = tensor_a
            continue
        # torch.lerp requires matching dtypes; align B to A's dtype so
        # mixed-precision checkpoints (e.g. fp16 vs fp32) merge instead of
        # raising.
        if tensor_a.dtype != tensor_b.dtype:
            tensor_b = tensor_b.to(tensor_a.dtype)
        if args.method == "slerp":
            merged_tensors[key] = slerp(tensor_a, tensor_b, args.alpha)
        else:  # Default to lerp
            merged_tensors[key] = lerp(tensor_a, tensor_b, args.alpha)

    # Carry over layers that exist in only one of the two checkpoints.
    if keys_only_in_a:
        print(f"Copying {len(keys_only_in_a)} layers unique to model A...")
        for key in tqdm(sorted(keys_only_in_a), desc="Copying layers from A"):
            merged_tensors[key] = tensors_a[key]
    if keys_only_in_b:
        print(f"Copying {len(keys_only_in_b)} layers unique to model B...")
        for key in tqdm(sorted(keys_only_in_b), desc="Copying layers from B"):
            merged_tensors[key] = tensors_b[key]

    print(f"\nSaving merged model to: {args.output}")
    save_file(merged_tensors, args.output)
    print("Merge complete!")
# Standard script guard: run the merger only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|