File size: 5,209 Bytes
bf1fba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import argparse
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import os

def slerp(t1, t2, alpha):
    """
    Performs Spherical Linear Interpolation (SLERP) between two tensors.

    Both tensors are flattened and treated as single vectors. The angle
    between them determines the interpolation weights. When SLERP is
    ill-defined or numerically unstable, the function falls back to plain
    linear interpolation instead of producing NaN/Inf values. This happens
    when:

    * either tensor has zero norm (the angle is undefined),
    * the tensors are nearly parallel (angle close to 0), or
    * the tensors are nearly antiparallel (angle close to pi, so
      ``sin(theta)`` is close to 0).

    Args:
        t1: First tensor.
        t2: Second tensor, with the same shape as ``t1``.
        alpha: Interpolation factor; 0.0 weights ``t1`` fully, 1.0 weights
            ``t2`` fully.

    Returns:
        The interpolated tensor, with the shape and dtype of ``t1``.
    """
    # Threshold below which the angle (or its sine) is considered degenerate.
    eps = 1e-4

    # Ensure tensors are float32 for high precision calculations
    t1_float = t1.float()
    t2_float = t2.float()

    def _linear_fallback():
        # Interpolate in float32 and cast back, so the fallback also works
        # when the two inputs have different (floating point) dtypes.
        return torch.lerp(t1_float, t2_float, alpha).to(t1.dtype)

    # Flatten tensors to treat them as vectors
    t1_flat = t1_float.flatten()
    t2_flat = t2_float.flatten()

    # A zero-norm vector has no direction: dividing by the norm product
    # would yield NaN and poison the whole result.
    norm_product = torch.linalg.norm(t1_flat) * torch.linalg.norm(t2_flat)
    if norm_product == 0:
        return _linear_fallback()

    # Cosine of the angle between the vectors, clamped to the valid range
    # [-1.0, 1.0] to avoid numerical errors in acos.
    dot = torch.clamp(torch.sum(t1_flat * t2_flat) / norm_product, -1.0, 1.0)

    # Calculate the angle between the vectors
    theta = torch.acos(dot)
    sin_theta = torch.sin(theta)

    # Nearly parallel (theta ~ 0) or nearly antiparallel (theta ~ pi):
    # sin(theta) is ~0, so the SLERP factors would blow up. LERP is a stable
    # approximation here.
    if torch.abs(theta) < eps or torch.abs(sin_theta) < eps:
        return _linear_fallback()

    # SLERP formula
    factor1 = torch.sin((1.0 - alpha) * theta) / sin_theta
    factor2 = torch.sin(alpha * theta) / sin_theta

    # Interpolate the flattened tensors
    interp_flat = factor1 * t1_flat + factor2 * t2_flat

    # Reshape the result to the original tensor shape and cast back to the original dtype
    return interp_flat.reshape(t1.shape).to(t1.dtype)

def lerp(t1, t2, alpha):
    """
    Linearly blends ``t1`` towards ``t2`` by the weight ``alpha``.

    Thin wrapper around ``torch.lerp``: ``alpha=0.0`` yields ``t1`` and
    ``alpha=1.0`` yields ``t2``.
    """
    blended = torch.lerp(t1, end=t2, weight=alpha)
    return blended


def main():
    """
    Command-line entry point: merge two Safetensor models into one file.

    Parses the command-line arguments (two input paths, an output path,
    ``--alpha`` and ``--method``), loads both models, and builds the merged
    tensor dictionary:

    * keys present in both models are interpolated with LERP or SLERP,
    * common keys whose shapes differ, or whose tensors are not floating
      point, are copied from model A unchanged,
    * keys present in only one model are copied as-is.

    The result is written to the output path with ``save_file``. If either
    input path does not exist, an error is printed and the function returns
    without writing anything.
    """
    parser = argparse.ArgumentParser(description="Merge two Safetensor models using either Linear (LERP) or Spherical (SLERP) interpolation.")
    parser.add_argument("model_a", type=str, help="Path to the first model (A).")
    parser.add_argument("model_b", type=str, help="Path to the second model (B).")
    parser.add_argument("output", type=str, help="Path to save the merged model.")
    parser.add_argument("--alpha", type=float, default=0.5, help="Interpolation factor (alpha). 0.0 is 100%% model A, 1.0 is 100%% model B. Default is 0.5.")
    parser.add_argument("--method", type=str, default="lerp", choices=["lerp", "slerp"], help="Merge method to use: 'lerp' (linear) or 'slerp' (spherical). Default is 'lerp'.")

    args = parser.parse_args()

    if not os.path.exists(args.model_a):
        print(f"Error: Model file not found at {args.model_a}")
        return
    if not os.path.exists(args.model_b):
        print(f"Error: Model file not found at {args.model_b}")
        return

    print(f"Loading model A from: {args.model_a}")
    tensors_a = load_file(args.model_a)

    print(f"Loading model B from: {args.model_b}")
    tensors_b = load_file(args.model_b)

    merged_tensors = {}

    # Find common and unique keys
    keys_a = set(tensors_a.keys())
    keys_b = set(tensors_b.keys())
    common_keys = keys_a.intersection(keys_b)
    keys_only_in_a = keys_a - keys_b
    keys_only_in_b = keys_b - keys_a

    print(f"\nFound {len(keys_a)} keys in {args.model_a}.")
    print(f"Found {len(keys_b)} keys in {args.model_b}.")
    print(f"-> Found {len(common_keys)} common keys.")
    print(f"-> Found {len(keys_only_in_a)} keys unique to model A.")
    print(f"-> Found {len(keys_only_in_b)} keys unique to model B.\n")

    if not common_keys and not keys_only_in_a and not keys_only_in_b:
        print("Warning: No tensors found to merge or copy. The output file will be empty.")
        save_file({}, args.output)
        print("Operation complete.")
        return

    print(f"Merging {len(common_keys)} common layers with alpha={args.alpha} using {args.method.upper()}...")
    for key in tqdm(common_keys, desc="Merging common layers"):
        tensor_a = tensors_a[key]
        tensor_b = tensors_b[key]

        if tensor_a.shape != tensor_b.shape:
            print(f"Warning: Skipping layer '{key}' due to shape mismatch: {tensor_a.shape} vs {tensor_b.shape}")
            merged_tensors[key] = tensor_a
            continue

        if not tensor_a.is_floating_point():
            print(f"Warning: Skipping merge for non-floating point tensor '{key}' (dtype: {tensor_a.dtype}). Copying from model A.")
            merged_tensors[key] = tensor_a
            continue

        # Model B may store the same key with a non-floating dtype; it cannot
        # be interpolated meaningfully, so keep model A's tensor.
        if not tensor_b.is_floating_point():
            print(f"Warning: Skipping merge for '{key}': model B tensor is non-floating point (dtype: {tensor_b.dtype}). Copying from model A.")
            merged_tensors[key] = tensor_a
            continue

        # torch.lerp requires both operands to share a dtype (e.g. merging an
        # fp16 checkpoint with an fp32 one would raise). Align B to A so the
        # merged tensor keeps model A's dtype.
        if tensor_b.dtype != tensor_a.dtype:
            tensor_b = tensor_b.to(tensor_a.dtype)

        if args.method == "slerp":
            merged_tensors[key] = slerp(tensor_a, tensor_b, args.alpha)
        else: # Default to lerp
            merged_tensors[key] = lerp(tensor_a, tensor_b, args.alpha)


    # Copy unique layers
    if keys_only_in_a:
        print(f"Copying {len(keys_only_in_a)} layers unique to model A...")
        for key in tqdm(keys_only_in_a, desc="Copying layers from A"):
            merged_tensors[key] = tensors_a[key]

    if keys_only_in_b:
        print(f"Copying {len(keys_only_in_b)} layers unique to model B...")
        for key in tqdm(keys_only_in_b, desc="Copying layers from B"):
            merged_tensors[key] = tensors_b[key]

    print(f"\nSaving merged model to: {args.output}")
    save_file(merged_tensors, args.output)
    print("Merge complete!")

# Run the merge only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()