|
|
|
|
|
|
|
|
import argparse
import sys

import torch
from safetensors import safe_open
|
|
|
|
|
|
|
|
def _normalize_keys(keys, prefix: str) -> tuple[dict[str, str], list[str]]:
    """Map prefix-stripped tensor names to the original names of one file.

    Args:
        keys: Iterable of original tensor names read from a single file.
        prefix (str): Prefix removed from the start of each name; names that
            do not start with it are used unchanged.

    Returns:
        tuple: ``(mapping, shadowed)``. ``mapping`` maps normalized name ->
        original name. ``shadowed`` lists original names that were dropped
        because a later key normalized to the same name (e.g. the file holds
        both ``prefix + "x"`` and ``"x"``).
    """
    mapping: dict[str, str] = {}
    shadowed: list[str] = []
    for key in keys:
        norm_key = key.removeprefix(prefix)
        if norm_key in mapping:
            # The later key wins (as a plain dict comprehension would), but the
            # loser is remembered so it is reported instead of silently skipped.
            shadowed.append(mapping[norm_key])
        mapping[norm_key] = key
    return mapping, shadowed


def _tensors_equal(tensor1: torch.Tensor, tensor2: torch.Tensor) -> bool:
    """Return True if two tensors of identical dtype and shape hold equal values.

    ``torch.equal`` alone reports NaN != NaN, which would flag a tensor as
    different from an exact copy of itself. For floating-point tensors, NaNs
    at the same positions are therefore treated as equal.
    """
    if torch.equal(tensor1, tensor2):
        return True
    if not tensor1.is_floating_point():
        return False
    nan1 = torch.isnan(tensor1)
    nan2 = torch.isnan(tensor2)
    return torch.equal(nan1, nan2) and torch.equal(tensor1[~nan1], tensor2[~nan2])


def _describe_difference(tensor1: torch.Tensor, tensor2: torch.Tensor) -> "str | None":
    """Return a short reason why two tensors differ, or None if they match.

    dtype and shape are checked explicitly before comparing values, so a
    dtype or shape mismatch is reported as such instead of being passed to
    ``torch.equal``.
    """
    reasons = []
    if tensor1.dtype != tensor2.dtype:
        reasons.append(f"dtype {tensor1.dtype} vs {tensor2.dtype}")
    if tensor1.shape != tensor2.shape:
        reasons.append(f"shape {tuple(tensor1.shape)} vs {tuple(tensor2.shape)}")
    if not reasons and not _tensors_equal(tensor1, tensor2):
        reasons.append("values differ")
    return "; ".join(reasons) or None


def _print_report(filepath1: str, filepath2: str, map1: dict, map2: dict, results: dict) -> None:
    """Print the human-readable summary of a comparison to stdout.

    Args:
        filepath1 (str): Path of the first file (display only).
        filepath2 (str): Path of the second file (display only).
        map1 (dict): Normalized name -> original name for file 1.
        map2 (dict): Normalized name -> original name for file 2.
        results (dict): Difference lists as built by ``compare_safetensors``.
    """
    print("\n" + "=" * 60)
    print("📊 Safetensor Comparison Results")
    print("=" * 60)
    print(f"File 1: {filepath1}")
    print(f"File 2: {filepath2}")
    print("-" * 60)

    if not any(results.values()):
        print("\n✅ The files are identical after normalization. No differences found.")
        print("=" * 60 + "\n")
        return

    different = results["different_content"]
    if different:
        print(f"\n↔️ Tensors with Different Content ({len(different)}):")
        for norm_key, reason in different:
            print(f"  - Normalized Key: {norm_key} [{reason}]")
            print(f"    (File 1 Original: {map1[norm_key]})")
            print(f"    (File 2 Original: {map2[norm_key]})")

    only_sections = (
        ("➖", "File 1", results["only_in_file1"], map1),
        ("➕", "File 2", results["only_in_file2"], map2),
    )
    for marker, label, only_keys, mapping in only_sections:
        if only_keys:
            print(f"\n{marker} Tensors Only in {label} ({len(only_keys)}):")
            for norm_key in only_keys:
                print(f"  - Normalized Key: {norm_key} (Original: {mapping[norm_key]})")

    shadow_sections = (
        ("File 1", results["shadowed_in_file1"]),
        ("File 2", results["shadowed_in_file2"]),
    )
    for label, shadowed in shadow_sections:
        if shadowed:
            print(
                f"\n⚠️ Tensors in {label} Not Compared "
                f"(name collides after prefix removal) ({len(shadowed)}):"
            )
            for original_key in shadowed:
                print(f"  - Original: {original_key}")

    print("\n" + "=" * 60 + "\n")


def compare_safetensors(
    filepath1: str,
    filepath2: str,
    prefix_to_ignore: str = "model.diffusion_model.",
) -> int:
    """Compare two .safetensors files and print a summary of the differences.

    Tensor names are normalized by stripping ``prefix_to_ignore`` from the
    start of each key before matching. For example, ``model.diffusion_model.foo``
    in one file is compared against ``foo`` in the other. Tensors that share
    a normalized name are reported as different when their dtype, shape or
    values differ. NaNs at matching positions count as equal.

    Args:
        filepath1 (str): Path to the first .safetensors file.
        filepath2 (str): Path to the second .safetensors file.
        prefix_to_ignore (str): Key prefix stripped before matching. Pass
            ``""`` to compare keys verbatim.

    Returns:
        int: A ``cmp``/``diff``-style status: ``0`` if the files are identical
        after normalization, ``1`` if differences were found, and ``2`` if a
        file could not be read or compared. Error details are printed to
        stderr.
    """
    print("\nLoading files and preparing for comparison...")
    print(f"Ignoring prefix: '{prefix_to_ignore}'")

    try:
        with safe_open(filepath1, framework="pt", device="cpu") as f1, \
                safe_open(filepath2, framework="pt", device="cpu") as f2:
            map1, shadowed1 = _normalize_keys(f1.keys(), prefix_to_ignore)
            map2, shadowed2 = _normalize_keys(f2.keys(), prefix_to_ignore)

            common_keys = sorted(map1.keys() & map2.keys())
            print(f"Comparing {len(common_keys)} common tensors...")

            different_content = []
            for norm_key in common_keys:
                # Load one pair of tensors at a time rather than whole files.
                reason = _describe_difference(
                    f1.get_tensor(map1[norm_key]),
                    f2.get_tensor(map2[norm_key]),
                )
                if reason is not None:
                    different_content.append((norm_key, reason))
    except FileNotFoundError as e:
        print(f"❌ Error: Could not find a file. Details: {e}", file=sys.stderr)
        return 2
    except Exception as e:  # CLI boundary: report any read/parse failure.
        print(f"❌ An error occurred: {e}", file=sys.stderr)
        print("Please ensure both files are valid .safetensors files.", file=sys.stderr)
        return 2

    results = {
        "only_in_file1": sorted(map1.keys() - map2.keys()),
        "only_in_file2": sorted(map2.keys() - map1.keys()),
        "different_content": different_content,
        "shadowed_in_file1": shadowed1,
        "shadowed_in_file2": shadowed2,
    }
    _print_report(filepath1, filepath2, map1, map2, results)
    return 1 if any(results.values()) else 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: compare the two files given as positional args.
    parser = argparse.ArgumentParser(
        description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "file1",
        type=str,
        help="Path to the first .safetensors file.",
    )
    parser.add_argument(
        "file2",
        type=str,
        help="Path to the second .safetensors file.",
    )
    args = parser.parse_args()

    # Propagate the comparison outcome as the process exit status so the script
    # is usable from shells and CI: an int return value becomes the exit code,
    # while a None return value exits with status 0.
    sys.exit(compare_safetensors(args.file1, args.file2))