File size: 5,112 Bytes
54c3b42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
#!/usr/bin/env python
import argparse
import sys

import torch
from safetensors import safe_open
def _normalize_keys(keys: list[str], prefix: str) -> dict[str, str]:
    """Map each key with ``prefix`` stripped (when present) back to its original key.

    If two original keys normalize to the same name (e.g. both ``"x"`` and
    ``prefix + "x"`` exist in one file), the later key wins. A warning is
    printed to stderr so the tensor that gets skipped is not lost silently.
    """
    mapping: dict[str, str] = {}
    for key in keys:
        norm_key = key.removeprefix(prefix)
        if norm_key in mapping:
            print(
                f"⚠️ Warning: '{mapping[norm_key]}' and '{key}' both normalize to "
                f"'{norm_key}'; only '{key}' will be compared.",
                file=sys.stderr,
            )
        mapping[norm_key] = key
    return mapping


def _tensors_equal(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool:
    """Return True if both tensors have the same dtype, shape and values.

    A bare ``torch.equal`` does not distinguish dtypes. It also treats NaN as
    unequal to NaN. Here a dtype change counts as a difference, and NaNs in
    the same positions count as equal.
    """
    if tensor1.dtype != tensor2.dtype or tensor1.shape != tensor2.shape:
        return False
    if torch.equal(tensor1, tensor2):
        return True
    if tensor1.is_floating_point():
        # Re-check with co-located NaNs treated as equal (torch.equal says NaN != NaN).
        both_nan = torch.isnan(tensor1) & torch.isnan(tensor2)
        return bool(torch.all((tensor1 == tensor2) | both_nan))
    return False


def _print_report(
    filepath1: str,
    filepath2: str,
    results: dict[str, list[str]],
    map1: dict[str, str],
    map2: dict[str, str],
) -> None:
    """Print a human-readable summary of the comparison ``results`` to stdout."""
    print("\n" + "=" * 60)
    print("🔍 Safetensor Comparison Results")
    print("=" * 60)
    print(f"File 1: {filepath1}")
    print(f"File 2: {filepath2}")
    print("-" * 60)

    # Every results value is a list; all empty means no differences at all.
    if not any(results.values()):
        print("\n✅ The files are identical after normalization. No differences found.")
        print("=" * 60 + "\n")
        return

    different = results["different_content"]
    if different:
        print(f"\n↔️ Tensors with Different Content ({len(different)}):")
        for norm_key in different:
            print(f" - Normalized Key: {norm_key}")
            print(f" (File 1 Original: {map1[norm_key]})")
            print(f" (File 2 Original: {map2[norm_key]})")

    only1 = results["only_in_file1"]
    if only1:
        print(f"\n→ Tensors Only in File 1 ({len(only1)}):")
        for norm_key in only1:
            print(f" - Normalized Key: {norm_key} (Original: {map1[norm_key]})")

    only2 = results["only_in_file2"]
    if only2:
        print(f"\n← Tensors Only in File 2 ({len(only2)}):")
        for norm_key in only2:
            print(f" - Normalized Key: {norm_key} (Original: {map2[norm_key]})")

    print("\n" + "=" * 60 + "\n")


def compare_safetensors(
    filepath1: str,
    filepath2: str,
    prefix_to_ignore: str = "model.diffusion_model.",
) -> "dict[str, list[str]] | None":
    """
    Compares two .safetensors files, ignoring a specific prefix on layer names,
    and prints a summary of the differences.

    Tensors are considered different if their dtype, shape or values differ.
    NaNs in matching positions are treated as equal.

    Args:
        filepath1 (str): Path to the first .safetensors file.
        filepath2 (str): Path to the second .safetensors file.
        prefix_to_ignore (str): Prefix stripped from tensor names before
            matching them across files. Defaults to "model.diffusion_model.".

    Returns:
        On success, a dict with sorted lists of normalized keys under
        "only_in_file1" and "only_in_file2", and the keys whose content
        differs under "different_content". Returns None if the files could
        not be read (the error is printed to stderr instead of raised).
    """
    results: dict[str, list[str]] = {
        "only_in_file1": [],
        "only_in_file2": [],
        "different_content": [],
    }

    print("\nLoading files and preparing for comparison...")
    print(f"Ignoring prefix: '{prefix_to_ignore}'")

    try:
        # 'with' ensures both files are closed even if a tensor read fails.
        with safe_open(filepath1, framework="pt", device="cpu") as f1, \
                safe_open(filepath2, framework="pt", device="cpu") as f2:
            # Normalized key (prefix stripped) -> original key, per file.
            map1 = _normalize_keys(f1.keys(), prefix_to_ignore)
            map2 = _normalize_keys(f2.keys(), prefix_to_ignore)

            normalized_keys1 = set(map1)
            normalized_keys2 = set(map2)

            # 1. Layers present in only one of the files.
            results["only_in_file1"] = sorted(normalized_keys1 - normalized_keys2)
            results["only_in_file2"] = sorted(normalized_keys2 - normalized_keys1)

            # 2. Layers present in both: compare their content.
            common_normalized_keys = sorted(normalized_keys1 & normalized_keys2)
            print(f"Comparing {len(common_normalized_keys)} common tensors...")
            for norm_key in common_normalized_keys:
                tensor1 = f1.get_tensor(map1[norm_key])
                tensor2 = f2.get_tensor(map2[norm_key])
                if not _tensors_equal(tensor1, tensor2):
                    results["different_content"].append(norm_key)
    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a file. Details: {e}", file=sys.stderr)
        return None
    except Exception as e:
        # Top-level boundary of a CLI tool: report instead of dumping a traceback.
        print(f"❌ An error occurred: {e}", file=sys.stderr)
        print("Please ensure both files are valid .safetensors files.", file=sys.stderr)
        return None

    _print_report(filepath1, filepath2, results, map1, map2)
    return results
if __name__ == "__main__":
    # Command-line entry point: take two positional file paths and compare them.
    cli = argparse.ArgumentParser(
        description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Both positional arguments share the same shape; only the ordinal differs.
    for arg_name, ordinal in (("file1", "first"), ("file2", "second")):
        cli.add_argument(
            arg_name,
            type=str,
            help=f"Path to the {ordinal} .safetensors file.",
        )

    cli_args = cli.parse_args()
    compare_safetensors(cli_args.file1, cli_args.file2)