Upload model_merge.py
Browse filesMerge models with SLERP or LERP
- model_merge.py +134 -0
model_merge.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
from safetensors.torch import load_file, save_file
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
def slerp(t1, t2, alpha):
    """
    Performs Spherical Linear Interpolation (SLERP) between two tensors.

    The tensors are flattened and treated as single vectors; the result is
    interpolated along the great-circle arc between their directions, then
    reshaped to t1's shape and cast back to t1's dtype.

    Args:
        t1: Start tensor (alpha = 0.0 returns t1).
        t2: End tensor (alpha = 1.0 returns t2); assumed to match t1's shape.
        alpha: Interpolation factor, typically in [0.0, 1.0].

    Returns:
        Interpolated tensor with t1's shape and dtype.
    """
    # Ensure tensors are float32 for high precision calculations.
    t1_float = t1.float()
    t2_float = t2.float()

    # Flatten tensors to treat them as vectors.
    t1_flat = t1_float.flatten()
    t2_flat = t2_float.flatten()

    norm1 = torch.linalg.norm(t1_flat)
    norm2 = torch.linalg.norm(t2_flat)

    # A zero-norm vector has no direction, so the angle is undefined;
    # dividing by the norm would yield NaNs that silently pass every check
    # below. Fall back to stable linear interpolation instead.
    if norm1 < 1e-12 or norm2 < 1e-12:
        return torch.lerp(t1, t2, alpha)

    # Cosine of the angle between the normalized vectors.
    dot = torch.sum(t1_flat * t2_flat) / (norm1 * norm2)

    # Clamp to [-1.0, 1.0] so rounding drift cannot push acos out of domain.
    dot = torch.clamp(dot, -1.0, 1.0)

    # Angle between the vectors.
    theta = torch.acos(dot)

    sin_theta = torch.sin(theta)

    # sin(theta) vanishes BOTH for nearly parallel vectors (theta ~ 0) and
    # for nearly anti-parallel ones (theta ~ pi); in either case the SLERP
    # factors below blow up. LERP is the stable approximation for the
    # parallel case and the only well-defined choice for the anti-parallel
    # one (the great-circle path is ambiguous there). Guarding on sin(theta)
    # covers both, whereas guarding only on theta misses the theta ~ pi case.
    if torch.abs(sin_theta) < 1e-4:
        return torch.lerp(t1, t2, alpha)

    # SLERP formula.
    factor1 = torch.sin((1.0 - alpha) * theta) / sin_theta
    factor2 = torch.sin(alpha * theta) / sin_theta

    # Interpolate the flattened tensors.
    interp_flat = factor1 * t1_flat + factor2 * t2_flat

    # Reshape the result to the original tensor shape and cast back to the
    # original dtype.
    return interp_flat.reshape(t1.shape).to(t1.dtype)
|
| 44 |
+
|
| 45 |
+
def lerp(t1, t2, alpha):
    """Linearly blend two tensors: (1 - alpha) * t1 + alpha * t2.

    Thin wrapper around torch.lerp kept for symmetry with slerp().
    """
    blended = torch.lerp(t1, t2, alpha)
    return blended
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def main():
    """CLI entry point: merge two .safetensors checkpoints with LERP or SLERP.

    Keys common to both models are interpolated with the chosen method;
    keys unique to either model are copied through unchanged. Layers whose
    shapes do not match, and non-floating-point tensors, are copied from
    model A with a warning.
    """
    parser = argparse.ArgumentParser(description="Merge two Safetensor models using either Linear (LERP) or Spherical (SLERP) interpolation.")
    parser.add_argument("model_a", type=str, help="Path to the first model (A).")
    parser.add_argument("model_b", type=str, help="Path to the second model (B).")
    parser.add_argument("output", type=str, help="Path to save the merged model.")
    parser.add_argument("--alpha", type=float, default=0.5, help="Interpolation factor (alpha). 0.0 is 100%% model A, 1.0 is 100%% model B. Default is 0.5.")
    parser.add_argument("--method", type=str, default="lerp", choices=["lerp", "slerp"], help="Merge method to use: 'lerp' (linear) or 'slerp' (spherical). Default is 'lerp'.")

    args = parser.parse_args()

    # Fail early with a readable message rather than a load_file traceback.
    if not os.path.exists(args.model_a):
        print(f"Error: Model file not found at {args.model_a}")
        return
    if not os.path.exists(args.model_b):
        print(f"Error: Model file not found at {args.model_b}")
        return

    print(f"Loading model A from: {args.model_a}")
    tensors_a = load_file(args.model_a)

    print(f"Loading model B from: {args.model_b}")
    tensors_b = load_file(args.model_b)

    merged_tensors = {}

    # Partition keys: common ones get merged, unique ones are copied verbatim.
    keys_a = set(tensors_a.keys())
    keys_b = set(tensors_b.keys())
    common_keys = keys_a.intersection(keys_b)
    keys_only_in_a = keys_a - keys_b
    keys_only_in_b = keys_b - keys_a

    print(f"\nFound {len(keys_a)} keys in {args.model_a}.")
    print(f"Found {len(keys_b)} keys in {args.model_b}.")
    print(f"-> Found {len(common_keys)} common keys.")
    print(f"-> Found {len(keys_only_in_a)} keys unique to model A.")
    print(f"-> Found {len(keys_only_in_b)} keys unique to model B.\n")

    if not common_keys and not keys_only_in_a and not keys_only_in_b:
        print("Warning: No tensors found to merge or copy. The output file will be empty.")
        save_file({}, args.output)
        print("Operation complete.")
        return

    print(f"Merging {len(common_keys)} common layers with alpha={args.alpha} using {args.method.upper()}...")
    for key in tqdm(common_keys, desc="Merging common layers"):
        # Mismatched shapes cannot be interpolated; keep model A's tensor.
        if tensors_a[key].shape != tensors_b[key].shape:
            print(f"Warning: Skipping layer '{key}' due to shape mismatch: {tensors_a[key].shape} vs {tensors_b[key].shape}")
            merged_tensors[key] = tensors_a[key]
            continue

        tensor_a = tensors_a[key]
        tensor_b = tensors_b[key]

        # Interpolating integer/bool tensors is meaningless and torch.lerp
        # rejects them. Check BOTH sides so a dtype mismatch (float in A,
        # int in B) cannot crash the merge — the original only inspected
        # tensor_a.
        if not tensor_a.is_floating_point() or not tensor_b.is_floating_point():
            print(f"Warning: Skipping merge for non-floating point tensor '{key}' (dtype: {tensor_a.dtype}). Copying from model A.")
            merged_tensors[key] = tensor_a
            continue

        if args.method == "slerp":
            merged_tensors[key] = slerp(tensor_a, tensor_b, args.alpha)
        else:  # Default to lerp
            merged_tensors[key] = lerp(tensor_a, tensor_b, args.alpha)

    # Copy unique layers through unchanged.
    if keys_only_in_a:
        print(f"Copying {len(keys_only_in_a)} layers unique to model A...")
        for key in tqdm(keys_only_in_a, desc="Copying layers from A"):
            merged_tensors[key] = tensors_a[key]

    if keys_only_in_b:
        print(f"Copying {len(keys_only_in_b)} layers unique to model B...")
        for key in tqdm(keys_only_in_b, desc="Copying layers from B"):
            merged_tensors[key] = tensors_b[key]

    print(f"\nSaving merged model to: {args.output}")
    save_file(merged_tensors, args.output)
    print("Merge complete!")
|
| 131 |
+
|
| 132 |
+
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
| 134 |
+
|