#!/usr/bin/env python
"""Compare two .safetensors files tensor-by-tensor, ignoring a layer-name prefix."""
import argparse
import sys
from typing import Iterable, Optional

import torch
from safetensors import safe_open

# Prefix stripped from tensor names before matching them across the two files.
DEFAULT_PREFIX_TO_IGNORE = "model.diffusion_model."

# Status codes returned by compare_safetensors (cmp/diff convention).
EXIT_OK = 0          # files identical after normalization
EXIT_DIFFERENT = 1   # comparison finished and found differences
EXIT_ERROR = 2       # comparison could not be completed


def _normalize_keys(keys: Iterable[str], prefix: str):
    """Map each prefix-stripped key to its original key.

    Returns ``(mapping, collisions)``. ``collisions`` maps a normalized key to
    every original key that normalizes to it (e.g. both ``"x"`` and
    ``prefix + "x"`` in the same file). Only the first such key is kept in
    ``mapping``; the collision is returned so it can be reported instead of
    one tensor being silently dropped.
    """
    mapping: dict[str, str] = {}
    collisions: dict[str, list[str]] = {}
    for key in keys:
        norm_key = key.removeprefix(prefix)
        if norm_key in mapping:
            collisions.setdefault(norm_key, [mapping[norm_key]]).append(key)
        else:
            mapping[norm_key] = key
    return mapping, collisions


def _tensor_mismatch(tensor1: torch.Tensor, tensor2: torch.Tensor) -> Optional[str]:
    """Return None if the tensors match, else a short reason they differ."""
    # Check dtype and shape first so the report says *why* tensors differ and
    # a dtype change is never hidden behind the value comparison.
    if tensor1.dtype != tensor2.dtype:
        return f"dtype {tensor1.dtype} vs {tensor2.dtype}"
    if tensor1.shape != tensor2.shape:
        return f"shape {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}"
    if torch.equal(tensor1, tensor2):
        return None
    # torch.equal treats NaN as unequal to itself, so identical tensors that
    # contain NaN would be flagged. Retry with a zero-tolerance comparison
    # that treats NaNs in the same positions as equal.
    if tensor1.is_floating_point():
        try:
            if torch.isclose(tensor1, tensor2, rtol=0.0, atol=0.0, equal_nan=True).all():
                return None
        except RuntimeError:
            # isclose is not implemented for every dtype; keep the result
            # torch.equal already gave (tensors differ).
            pass
    return "values differ"


def _compare_files(filepath1: str, filepath2: str, prefix: str):
    """Open both files and classify every normalized tensor name.

    Returns ``(map1, map2, results)`` where ``map1``/``map2`` map normalized
    keys to original keys. ``results`` has these entries:

    - ``"only_in_file1"`` / ``"only_in_file2"``: sorted lists of keys.
    - ``"different_content"``: dict mapping a key to its mismatch reason.
    - ``"collisions_file1"`` / ``"collisions_file2"``: dicts mapping a key to
      the ambiguous original names.
    """
    # Both files stay open for the whole loop so tensors are read lazily, one pair at a time.
    with safe_open(filepath1, framework="pt", device="cpu") as f1, \
            safe_open(filepath2, framework="pt", device="cpu") as f2:
        map1, collisions1 = _normalize_keys(f1.keys(), prefix)
        map2, collisions2 = _normalize_keys(f2.keys(), prefix)

        common_keys = sorted(map1.keys() & map2.keys())
        print(f"Comparing {len(common_keys)} common tensors...")

        different: dict[str, str] = {}
        for norm_key in common_keys:
            reason = _tensor_mismatch(
                f1.get_tensor(map1[norm_key]),
                f2.get_tensor(map2[norm_key]),
            )
            if reason is not None:
                different[norm_key] = reason

    results = {
        "only_in_file1": sorted(map1.keys() - map2.keys()),
        "only_in_file2": sorted(map2.keys() - map1.keys()),
        "different_content": different,
        "collisions_file1": collisions1,
        "collisions_file2": collisions2,
    }
    return map1, map2, results


def _print_report(filepath1: str, filepath2: str, map1, map2, results) -> int:
    """Print a human-readable summary; return EXIT_OK or EXIT_DIFFERENT."""
    rule = "=" * 60
    print("\n" + rule)
    print("🔍 Safetensor Comparison Results")
    print(rule)
    print(f"File 1: {filepath1}")
    print(f"File 2: {filepath2}")
    print("-" * 60)

    total_diffs = sum(len(entries) for entries in results.values())
    if total_diffs == 0:
        print("\n✅ The files are identical after normalization. No differences found.")
        print(rule + "\n")
        return EXIT_OK

    different = results["different_content"]
    if different:
        print(f"\n↔️ Tensors with Different Content ({len(different)}):")
        for norm_key, reason in different.items():
            print(f"  - Normalized Key: {norm_key} [{reason}]")
            print(f"    (File 1 Original: {map1[norm_key]})")
            print(f"    (File 2 Original: {map2[norm_key]})")

    only1 = results["only_in_file1"]
    if only1:
        print(f"\n→ Tensors Only in File 1 ({len(only1)}):")
        for norm_key in only1:
            print(f"  - Normalized Key: {norm_key} (Original: {map1[norm_key]})")

    only2 = results["only_in_file2"]
    if only2:
        print(f"\n← Tensors Only in File 2 ({len(only2)}):")
        for norm_key in only2:
            print(f"  - Normalized Key: {norm_key} (Original: {map2[norm_key]})")

    for label, collisions in (
        ("File 1", results["collisions_file1"]),
        ("File 2", results["collisions_file2"]),
    ):
        if collisions:
            # Several original names normalize to the same key; only the first was compared.
            print(f"\n⚠️ Ambiguous Keys in {label} ({len(collisions)}):")
            for norm_key, originals in sorted(collisions.items()):
                print(f"  - Normalized Key: {norm_key}")
                print(f"    (Originals: {', '.join(originals)}; compared: {originals[0]})")

    print("\n" + rule + "\n")
    return EXIT_DIFFERENT


def compare_safetensors(
    filepath1: str,
    filepath2: str,
    prefix_to_ignore: str = DEFAULT_PREFIX_TO_IGNORE,
) -> int:
    """
    Compares two .safetensors files, ignoring a prefix on layer names,
    and prints a summary of the differences.

    Args:
        filepath1 (str): Path to the first .safetensors file.
        filepath2 (str): Path to the second .safetensors file.
        prefix_to_ignore (str): Prefix stripped from every tensor name before
            matching names across files. Use "" to match names verbatim.

    Returns:
        int: EXIT_OK (0) if the files are identical after normalization,
        EXIT_DIFFERENT (1) if differences were found, EXIT_ERROR (2) if the
        files could not be read. Errors are printed to stderr, not raised.
    """
    print("\nLoading files and preparing for comparison...")
    print(f"Ignoring prefix: '{prefix_to_ignore}'")

    # Only the file-reading phase is guarded, so a failure while printing the
    # report is not misreported as an invalid .safetensors file.
    try:
        map1, map2, results = _compare_files(filepath1, filepath2, prefix_to_ignore)
    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a file. Details: {e}", file=sys.stderr)
        return EXIT_ERROR
    except Exception as e:  # CLI boundary: safetensors/torch raise varied exception types.
        print(f"❌ An error occurred: {e}", file=sys.stderr)
        print("Please ensure both files are valid .safetensors files.", file=sys.stderr)
        return EXIT_ERROR

    return _print_report(filepath1, filepath2, map1, map2, results)


def _main(argv=None) -> int:
    """Parse command-line arguments, run the comparison, and return the exit status."""
    parser = argparse.ArgumentParser(
        description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument("file1", type=str, help="Path to the first .safetensors file.")
    parser.add_argument("file2", type=str, help="Path to the second .safetensors file.")
    parser.add_argument(
        "--ignore-prefix",
        type=str,
        default=DEFAULT_PREFIX_TO_IGNORE,
        help="Prefix stripped from layer names before matching (default: %(default)r).\nPass '' to disable.",
    )
    parser.add_argument(
        "--exit-code",
        action="store_true",
        help="Exit with status 1 when differences are found (errors always exit with 2).",
    )
    args = parser.parse_args(argv)

    status = compare_safetensors(args.file1, args.file2, prefix_to_ignore=args.ignore_prefix)
    # Without --exit-code, keep the historical behaviour of exiting 0 when the
    # comparison succeeds, regardless of whether differences were found.
    if status == EXIT_DIFFERENT and not args.exit_code:
        return EXIT_OK
    return status


if __name__ == "__main__":
    sys.exit(_main())