# Provenance (text captured from the repository web page this script was copied from):
# python-script-dump / comparemodels.py — uploaded by anyMODE ("Upload 4 files"),
# commit 54c3b42 (verified), 5.11 kB.
#!/usr/bin/env python
import argparse
import torch
from safetensors import safe_open
def _normalize_keys(keys, prefix):
    """Strip ``prefix`` from each tensor name and map the result back.

    Args:
        keys: Iterable of original tensor names.
        prefix (str): Leading string removed from each name; names that do
            not start with it are kept unchanged.

    Returns:
        tuple: ``(mapping, collisions)``. ``mapping`` maps every normalized
        name to an original name. ``collisions`` maps each normalized name
        that more than one original name reduced to (for example both
        ``"x"`` and ``prefix + "x"``) to the list of those original names;
        for such names ``mapping`` keeps the last one seen.
    """
    groups = {}
    for key in keys:
        groups.setdefault(key.removeprefix(prefix), []).append(key)
    mapping = {norm: originals[-1] for norm, originals in groups.items()}
    collisions = {
        norm: originals for norm, originals in groups.items() if len(originals) > 1
    }
    return mapping, collisions


def _tensor_difference(tensor1, tensor2):
    """Describe how two tensors differ, or return None if they are identical.

    Dtype and shape are checked explicitly because ``torch.equal`` does not
    distinguish dtypes (an fp16 tensor and its fp32 upcast compare equal).
    For floating-point/complex tensors, NaNs at the same positions count as
    equal: ``torch.equal`` never treats NaN as equal to itself, which would
    make a file appear to differ from an exact copy of itself.
    """
    if tensor1.dtype != tensor2.dtype:
        return f"dtype {tensor1.dtype} vs {tensor2.dtype}"
    if tensor1.shape != tensor2.shape:
        return f"shape {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
    if torch.equal(tensor1, tensor2):
        return None
    if tensor1.is_floating_point() or tensor1.is_complex():
        nan1 = torch.isnan(tensor1)
        nan2 = torch.isnan(tensor2)
        # Same NaN layout and identical values everywhere else => identical.
        if torch.equal(nan1, nan2) and torch.equal(tensor1[~nan1], tensor2[~nan2]):
            return None
    return "values differ"


def _print_report(filepath1, filepath2, results, reasons, map1, map2):
    """Print the human-readable summary for ``compare_safetensors``.

    ``results`` holds the sorted normalized-key lists, ``reasons`` maps each
    key in ``results["different_content"]`` to a short explanation, and
    ``map1``/``map2`` map normalized keys back to each file's original names.
    """
    print("\n" + "=" * 60)
    print("🔍 Safetensor Comparison Results")
    print("=" * 60)
    print(f"File 1: {filepath1}")
    print(f"File 2: {filepath2}")
    print("-" * 60)
    if not any(results.values()):
        print("\n✅ The files are identical after normalization. No differences found.")
        print("=" * 60 + "\n")
        return
    different = results["different_content"]
    if different:
        print(f"\n↔️ Tensors with Different Content ({len(different)}):")
        for norm_key in different:
            print(f" - Normalized Key: {norm_key}")
            print(f" (Reason: {reasons[norm_key]})")
            print(f" (File 1 Original: {map1[norm_key]})")
            print(f" (File 2 Original: {map2[norm_key]})")
    only1 = results["only_in_file1"]
    if only1:
        print(f"\n→ Tensors Only in File 1 ({len(only1)}):")
        for norm_key in only1:
            print(f" - Normalized Key: {norm_key} (Original: {map1[norm_key]})")
    only2 = results["only_in_file2"]
    if only2:
        print(f"\n← Tensors Only in File 2 ({len(only2)}):")
        for norm_key in only2:
            print(f" - Normalized Key: {norm_key} (Original: {map2[norm_key]})")
    print("\n" + "=" * 60 + "\n")


def compare_safetensors(filepath1: str, filepath2: str,
                        prefix_to_ignore: str = "model.diffusion_model."):
    """
    Compares two .safetensors files, ignoring a specific prefix on layer names,
    and prints a summary of the differences.

    Tensor names are matched after ``prefix_to_ignore`` is stripped from the
    front. Matched tensors are compared by dtype, shape and values; NaNs at
    the same positions count as equal. If one file contains several names
    that normalize to the same key, a warning is printed and only the last
    of them is compared.

    Args:
        filepath1 (str): Path to the first .safetensors file.
        filepath2 (str): Path to the second .safetensors file.
        prefix_to_ignore (str): Prefix stripped from every tensor name before
            matching. Defaults to "model.diffusion_model.".

    Returns:
        dict | None: ``{"only_in_file1": [...], "only_in_file2": [...],
        "different_content": [...]}`` holding sorted normalized keys, or
        None if an ``Exception`` occurred while reading, comparing or
        reporting (the error is printed instead of propagated).
    """
    results = {
        "only_in_file1": [],
        "only_in_file2": [],
        "different_content": [],
    }
    reasons = {}  # normalized key -> why the two tensors differ
    print("\nLoading files and preparing for comparison...")
    print(f"Ignoring prefix: '{prefix_to_ignore}'")
    try:
        with safe_open(filepath1, framework="pt", device="cpu") as f1, \
                safe_open(filepath2, framework="pt", device="cpu") as f2:
            map1, collisions1 = _normalize_keys(f1.keys(), prefix_to_ignore)
            map2, collisions2 = _normalize_keys(f2.keys(), prefix_to_ignore)
            # Colliding names would otherwise be silently dropped from the comparison.
            for label, collisions in (("File 1", collisions1), ("File 2", collisions2)):
                for norm_key, originals in collisions.items():
                    print(f"⚠️ Warning: {label} has {len(originals)} tensors that normalize "
                          f"to '{norm_key}' ({', '.join(originals)}); only "
                          f"'{originals[-1]}' is compared.")
            keys1 = set(map1)
            keys2 = set(map2)
            results["only_in_file1"] = sorted(keys1 - keys2)
            results["only_in_file2"] = sorted(keys2 - keys1)
            common = sorted(keys1 & keys2)
            print(f"Comparing {len(common)} common tensors...")
            for norm_key in common:
                # Tensors are loaded one pair at a time to bound peak memory.
                reason = _tensor_difference(
                    f1.get_tensor(map1[norm_key]), f2.get_tensor(map2[norm_key])
                )
                if reason is not None:
                    results["different_content"].append(norm_key)
                    reasons[norm_key] = reason
        _print_report(filepath1, filepath2, results, reasons, map1, map2)
    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a file. Details: {e}")
        return None
    except Exception as e:
        print(f"❌ An error occurred: {e}")
        print("Please ensure both files are valid .safetensors files.")
        return None
    return results
if __name__ == "__main__":
    # Command-line entry point: take two positional .safetensors paths and
    # print their comparison.
    cli = argparse.ArgumentParser(
        description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    positional_args = (
        ("file1", "Path to the first .safetensors file."),
        ("file2", "Path to the second .safetensors file."),
    )
    for dest, help_text in positional_args:
        cli.add_argument(dest, type=str, help=help_text)
    parsed = cli.parse_args()
    compare_safetensors(parsed.file1, parsed.file2)