File size: 5,112 Bytes
54c3b42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#!/usr/bin/env python

import argparse
import torch
from safetensors import safe_open


def _build_key_map(keys, prefix):
    """Map each prefix-stripped key to its original key and collect collisions.

    Two distinct original keys collide when they normalize to the same name
    (for example ``"x"`` and ``prefix + "x"`` inside one file). The later key
    wins, and every overwritten pair is returned as
    ``(normalized_key, dropped_original, kept_original)`` so the caller can
    report it instead of losing a tensor silently.
    """
    key_map = {}
    collisions = []
    for key in keys:
        norm_key = key.removeprefix(prefix)
        if norm_key in key_map:
            collisions.append((norm_key, key_map[norm_key], key))
        key_map[norm_key] = key
    return key_map, collisions


def _tensors_match(tensor1, tensor2):
    """Return True only if dtype, shape and all element values are equal."""
    # Check the metadata explicitly so that a dtype or shape mismatch is always
    # reported as a difference, instead of depending on how torch.equal treats
    # mismatched inputs.
    if tensor1.dtype != tensor2.dtype or tensor1.shape != tensor2.shape:
        return False
    return torch.equal(tensor1, tensor2)


def _print_report(filepath1, filepath2, results, map1, map2):
    """Print the human-readable summary of the collected differences."""
    print("\n" + "=" * 60)
    print("🔍 Safetensor Comparison Results")
    print("=" * 60)
    print(f"File 1: {filepath1}")
    print(f"File 2: {filepath2}")
    print("-" * 60)

    # Nothing in any category means the files match after normalization.
    if not any(results.values()):
        print("\n✅ The files are identical after normalization. No differences found.")
        print("=" * 60 + "\n")
        return

    # Report tensors with different content
    if results["different_content"]:
        print(f"\n↔️ Tensors with Different Content ({len(results['different_content'])}):")
        for norm_key in results["different_content"]:
            print(f"  - Normalized Key: {norm_key}")
            print(f"    (File 1 Original: {map1[norm_key]})")
            print(f"    (File 2 Original: {map2[norm_key]})")

    # Report tensors only in file 1
    if results["only_in_file1"]:
        print(f"\n→ Tensors Only in File 1 ({len(results['only_in_file1'])}):")
        for norm_key in results["only_in_file1"]:
            print(f"  - Normalized Key: {norm_key} (Original: {map1[norm_key]})")

    # Report tensors only in file 2
    if results["only_in_file2"]:
        print(f"\n← Tensors Only in File 2 ({len(results['only_in_file2'])}):")
        for norm_key in results["only_in_file2"]:
            print(f"  - Normalized Key: {norm_key} (Original: {map2[norm_key]})")

    print("\n" + "=" * 60 + "\n")


def compare_safetensors(
    filepath1: str,
    filepath2: str,
    prefix_to_ignore: str = "model.diffusion_model.",
) -> None:
    """
    Compares two .safetensors files, ignoring a specific prefix on layer names,
    and prints a summary of the differences.

    Tensors are matched by their key with ``prefix_to_ignore`` stripped from
    the front. A matched pair counts as different when the dtype, the shape
    or any element value differs. Errors while opening or reading the files
    are printed to stdout rather than raised.

    Args:
        filepath1 (str): Path to the first .safetensors file.
        filepath2 (str): Path to the second .safetensors file.
        prefix_to_ignore (str): Prefix removed from the start of every tensor
            key before matching. Pass "" to compare the keys unchanged.

    Returns:
        None. All results are written to stdout.
    """
    # Dictionary to hold results
    results = {
        "only_in_file1": [],
        "only_in_file2": [],
        "different_content": [],
    }

    print("\nLoading files and preparing for comparison...")
    print(f"Ignoring prefix: '{prefix_to_ignore}'")

    # Only loading and comparing sit inside the try block, so the "invalid
    # file" hint below is never shown for a failure in the reporting code.
    try:
        # Use 'with' to ensure files are closed properly
        with safe_open(filepath1, framework="pt", device="cpu") as f1, \
                safe_open(filepath2, framework="pt", device="cpu") as f2:

            # Maps from the normalized key (suffix) to the original key
            map1, collisions1 = _build_key_map(f1.keys(), prefix_to_ignore)
            map2, collisions2 = _build_key_map(f2.keys(), prefix_to_ignore)

            # A collision means one tensor of that file is left out of the
            # comparison, so make it visible.
            for label, collisions in (("File 1", collisions1), ("File 2", collisions2)):
                for norm_key, dropped, kept in collisions:
                    print(
                        f"⚠️ Warning ({label}): keys '{dropped}' and '{kept}' both "
                        f"normalize to '{norm_key}'; only '{kept}' is compared."
                    )

            # 1. Find normalized keys (layers) unique to each file
            results["only_in_file1"] = sorted(map1.keys() - map2.keys())
            results["only_in_file2"] = sorted(map2.keys() - map1.keys())

            # 2. Compare the content of the tensors present in both files
            common_normalized_keys = sorted(map1.keys() & map2.keys())
            print(f"Comparing {len(common_normalized_keys)} common tensors...")

            for norm_key in common_normalized_keys:
                # Each tensor is fetched through its file's original key
                tensor1 = f1.get_tensor(map1[norm_key])
                tensor2 = f2.get_tensor(map2[norm_key])
                if not _tensors_match(tensor1, tensor2):
                    results["different_content"].append(norm_key)

    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a file. Details: {e}")
        return
    except Exception as e:
        # Top-level boundary of the tool: report the problem and stop.
        print(f"❌ An error occurred: {e}")
        print("Please ensure both files are valid .safetensors files.")
        return

    _print_report(filepath1, filepath2, results, map1, map2)


if __name__ == "__main__":
    # Command-line entry point: read the two file paths and run the comparison.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
    )

    # Both positional arguments are plain string paths; register them in order.
    positional_args = (
        ("file1", "Path to the first .safetensors file."),
        ("file2", "Path to the second .safetensors file."),
    )
    for arg_name, arg_help in positional_args:
        cli.add_argument(arg_name, type=str, help=arg_help)

    parsed = cli.parse_args()

    # Hand the parsed paths to the comparison routine.
    compare_safetensors(parsed.file1, parsed.file2)