Upload 4 files
Browse files- comparemodels.py +130 -0
- extract_layers.py +99 -0
- listlayers.py +47 -0
- prune_layers.py +238 -0
comparemodels.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import torch
|
| 5 |
+
from safetensors import safe_open
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def compare_safetensors(filepath1: str, filepath2: str):
|
| 9 |
+
"""
|
| 10 |
+
Compares two .safetensors files, ignoring a specific prefix on layer names,
|
| 11 |
+
and prints a summary of the differences.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
filepath1 (str): Path to the first .safetensors file.
|
| 15 |
+
filepath2 (str): Path to the second .safetensors file.
|
| 16 |
+
"""
|
| 17 |
+
# The prefix to ignore on layer names
|
| 18 |
+
prefix_to_ignore = "model.diffusion_model."
|
| 19 |
+
|
| 20 |
+
# Dictionaries to hold results
|
| 21 |
+
results = {
|
| 22 |
+
"only_in_file1": [],
|
| 23 |
+
"only_in_file2": [],
|
| 24 |
+
"different_content": [],
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
print("\nLoading files and preparing for comparison...")
|
| 28 |
+
print(f"Ignoring prefix: '{prefix_to_ignore}'")
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
# Use 'with' to ensure files are closed properly
|
| 32 |
+
with safe_open(filepath1, framework="pt", device="cpu") as f1, \
|
| 33 |
+
safe_open(filepath2, framework="pt", device="cpu") as f2:
|
| 34 |
+
|
| 35 |
+
# Create maps from the normalized key (suffix) to the original key
|
| 36 |
+
map1 = {key.removeprefix(prefix_to_ignore): key for key in f1.keys()}
|
| 37 |
+
map2 = {key.removeprefix(prefix_to_ignore): key for key in f2.keys()}
|
| 38 |
+
|
| 39 |
+
# Get the set of normalized tensor keys from each file
|
| 40 |
+
normalized_keys1 = set(map1.keys())
|
| 41 |
+
normalized_keys2 = set(map2.keys())
|
| 42 |
+
|
| 43 |
+
# 1. Find normalized keys (layers) unique to each file
|
| 44 |
+
results["only_in_file1"] = sorted(list(normalized_keys1 - normalized_keys2))
|
| 45 |
+
results["only_in_file2"] = sorted(list(normalized_keys2 - normalized_keys1))
|
| 46 |
+
|
| 47 |
+
# 2. Find normalized keys present in both files to compare their content
|
| 48 |
+
common_normalized_keys = normalized_keys1.intersection(normalized_keys2)
|
| 49 |
+
print(f"Comparing {len(common_normalized_keys)} common tensors...")
|
| 50 |
+
|
| 51 |
+
for norm_key in sorted(list(common_normalized_keys)):
|
| 52 |
+
# Get the original key for each file using the maps
|
| 53 |
+
original_key1 = map1[norm_key]
|
| 54 |
+
original_key2 = map2[norm_key]
|
| 55 |
+
|
| 56 |
+
# Get the tensor from each file using its original key
|
| 57 |
+
tensor1 = f1.get_tensor(original_key1)
|
| 58 |
+
tensor2 = f2.get_tensor(original_key2)
|
| 59 |
+
|
| 60 |
+
# Compare tensors for equality
|
| 61 |
+
if not torch.equal(tensor1, tensor2):
|
| 62 |
+
# Store the normalized key if content differs
|
| 63 |
+
results["different_content"].append(norm_key)
|
| 64 |
+
|
| 65 |
+
# --- Print the results ---
|
| 66 |
+
print("\n" + "=" * 60)
|
| 67 |
+
print("🔍 Safetensor Comparison Results")
|
| 68 |
+
print("=" * 60)
|
| 69 |
+
print(f"File 1: {filepath1}")
|
| 70 |
+
print(f"File 2: {filepath2}")
|
| 71 |
+
print("-" * 60)
|
| 72 |
+
|
| 73 |
+
# Check if any differences were found at all
|
| 74 |
+
total_diffs = len(results["only_in_file1"]) + len(results["only_in_file2"]) + len(results["different_content"])
|
| 75 |
+
if total_diffs == 0:
|
| 76 |
+
print("\n✅ The files are identical after normalization. No differences found.")
|
| 77 |
+
print("=" * 60 + "\n")
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
# Report tensors with different content
|
| 81 |
+
if results["different_content"]:
|
| 82 |
+
print(f"\n↔️ Tensors with Different Content ({len(results['different_content'])}):")
|
| 83 |
+
for norm_key in results["different_content"]:
|
| 84 |
+
print(f" - Normalized Key: {norm_key}")
|
| 85 |
+
print(f" (File 1 Original: {map1[norm_key]})")
|
| 86 |
+
print(f" (File 2 Original: {map2[norm_key]})")
|
| 87 |
+
|
| 88 |
+
# Report tensors only in file 1
|
| 89 |
+
if results["only_in_file1"]:
|
| 90 |
+
print(f"\n→ Tensors Only in File 1 ({len(results['only_in_file1'])}):")
|
| 91 |
+
for norm_key in results["only_in_file1"]:
|
| 92 |
+
print(f" - Normalized Key: {norm_key} (Original: {map1[norm_key]})")
|
| 93 |
+
|
| 94 |
+
# Report tensors only in file 2
|
| 95 |
+
if results["only_in_file2"]:
|
| 96 |
+
print(f"\n← Tensors Only in File 2 ({len(results['only_in_file2'])}):")
|
| 97 |
+
for norm_key in results["only_in_file2"]:
|
| 98 |
+
print(f" - Normalized Key: {norm_key} (Original: {map2[norm_key]})")
|
| 99 |
+
|
| 100 |
+
print("\n" + "=" * 60 + "\n")
|
| 101 |
+
|
| 102 |
+
except FileNotFoundError as e:
|
| 103 |
+
print(f"❌ Error: Could not find a file. Details: {e}")
|
| 104 |
+
except Exception as e:
|
| 105 |
+
print(f"❌ An error occurred: {e}")
|
| 106 |
+
print("Please ensure both files are valid .safetensors files.")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if __name__ == "__main__":
|
| 110 |
+
# --- Argument Parser Setup ---
|
| 111 |
+
parser = argparse.ArgumentParser(
|
| 112 |
+
description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
|
| 113 |
+
formatter_class=argparse.RawTextHelpFormatter
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
parser.add_argument(
|
| 117 |
+
"file1",
|
| 118 |
+
type=str,
|
| 119 |
+
help="Path to the first .safetensors file."
|
| 120 |
+
)
|
| 121 |
+
parser.add_argument(
|
| 122 |
+
"file2",
|
| 123 |
+
type=str,
|
| 124 |
+
help="Path to the second .safetensors file."
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
args = parser.parse_args()
|
| 128 |
+
|
| 129 |
+
# --- Run the function ---
|
| 130 |
+
compare_safetensors(args.file1, args.file2)
|
extract_layers.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from safetensors import safe_open
|
| 5 |
+
from safetensors.torch import save_file
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# This set contains the specific layers we want to find and save.
|
| 9 |
+
# Using a set is efficient for checking if a layer should be kept.
|
| 10 |
+
KEYS_TO_KEEP = {
|
| 11 |
+
"distilled_guidance_layer.in_proj.bias",
|
| 12 |
+
"distilled_guidance_layer.in_proj.weight",
|
| 13 |
+
"distilled_guidance_layer.layers.0.in_layer.bias",
|
| 14 |
+
"distilled_guidance_layer.layers.0.in_layer.weight",
|
| 15 |
+
"distilled_guidance_layer.layers.0.out_layer.bias",
|
| 16 |
+
"distilled_guidance_layer.layers.0.out_layer.weight",
|
| 17 |
+
"distilled_guidance_layer.layers.1.in_layer.bias",
|
| 18 |
+
"distilled_guidance_layer.layers.1.in_layer.weight",
|
| 19 |
+
"distilled_guidance_layer.layers.1.out_layer.bias",
|
| 20 |
+
"distilled_guidance_layer.layers.1.out_layer.weight",
|
| 21 |
+
"distilled_guidance_layer.layers.2.in_layer.bias",
|
| 22 |
+
"distilled_guidance_layer.layers.2.in_layer.weight",
|
| 23 |
+
"distilled_guidance_layer.layers.2.out_layer.bias",
|
| 24 |
+
"distilled_guidance_layer.layers.2.out_layer.weight",
|
| 25 |
+
"distilled_guidance_layer.layers.3.in_layer.bias",
|
| 26 |
+
"distilled_guidance_layer.layers.3.in_layer.weight",
|
| 27 |
+
"distilled_guidance_layer.layers.3.out_layer.bias",
|
| 28 |
+
"distilled_guidance_layer.layers.3.out_layer.weight",
|
| 29 |
+
"distilled_guidance_layer.layers.4.in_layer.bias",
|
| 30 |
+
"distilled_guidance_layer.layers.4.in_layer.weight",
|
| 31 |
+
"distilled_guidance_layer.layers.4.out_layer.bias",
|
| 32 |
+
"distilled_guidance_layer.layers.4.out_layer.weight",
|
| 33 |
+
"distilled_guidance_layer.norms.0.scale",
|
| 34 |
+
"distilled_guidance_layer.norms.1.scale",
|
| 35 |
+
"distilled_guidance_layer.norms.2.scale",
|
| 36 |
+
"distilled_guidance_layer.norms.3.scale",
|
| 37 |
+
"distilled_guidance_layer.norms.4.scale",
|
| 38 |
+
"distilled_guidance_layer.out_proj.bias",
|
| 39 |
+
"distilled_guidance_layer.out_proj.weight",
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
def extract_safetensors_layers(input_path: str, output_path: str, keys_to_keep: set):
|
| 43 |
+
"""
|
| 44 |
+
Reads a safetensors file, extracts specified layers, and saves them to a new file.
|
| 45 |
+
"""
|
| 46 |
+
tensors_to_save = {}
|
| 47 |
+
|
| 48 |
+
print(f"▶️ Reading from: {input_path}")
|
| 49 |
+
try:
|
| 50 |
+
# Open the source file for reading
|
| 51 |
+
with safe_open(input_path, framework="pt", device="cpu") as f:
|
| 52 |
+
|
| 53 |
+
# Iterate through the keys we want to keep
|
| 54 |
+
for key in tqdm(keys_to_keep, desc="Extracting Tensors"):
|
| 55 |
+
try:
|
| 56 |
+
# If the key exists in the file, add its tensor to our dictionary
|
| 57 |
+
tensors_to_save[key] = f.get_tensor(key)
|
| 58 |
+
except KeyError:
|
| 59 |
+
print(f"\n⚠️ Warning: Key not found in source file and will be skipped: {key}")
|
| 60 |
+
|
| 61 |
+
print("\n--- Summary ---")
|
| 62 |
+
print(f"Layers specified to keep: {len(keys_to_keep)}")
|
| 63 |
+
print(f"Layers found and extracted: {len(tensors_to_save)}")
|
| 64 |
+
|
| 65 |
+
if not tensors_to_save:
|
| 66 |
+
print("\n❌ Error: No matching layers were found. The output file will not be created.")
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
# Save the extracted tensors to the new file
|
| 70 |
+
print(f"\n▶️ Saving {len(tensors_to_save)} tensors to: {output_path}")
|
| 71 |
+
save_file(tensors_to_save, output_path)
|
| 72 |
+
|
| 73 |
+
print("\n✅ Extraction complete!")
|
| 74 |
+
|
| 75 |
+
except FileNotFoundError:
|
| 76 |
+
print(f"\n❌ Error: The file '{input_path}' was not found.")
|
| 77 |
+
except Exception as e:
|
| 78 |
+
print(f"\n❌ An error occurred: {e}")
|
| 79 |
+
|
| 80 |
+
if __name__ == "__main__":
|
| 81 |
+
parser = argparse.ArgumentParser(
|
| 82 |
+
description="Extracts a predefined list of layers from a .safetensors file and saves them to a new file.",
|
| 83 |
+
formatter_class=argparse.RawTextHelpFormatter
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
parser.add_argument(
|
| 87 |
+
"input_file",
|
| 88 |
+
type=str,
|
| 89 |
+
help="Path to the source .safetensors file."
|
| 90 |
+
)
|
| 91 |
+
parser.add_argument(
|
| 92 |
+
"output_file",
|
| 93 |
+
type=str,
|
| 94 |
+
help="Path to save the new, extracted .safetensors file."
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
args = parser.parse_args()
|
| 98 |
+
|
| 99 |
+
extract_safetensors_layers(args.input_file, args.output_file, KEYS_TO_KEEP)
|
listlayers.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from safetensors import safe_open
|
| 5 |
+
|
| 6 |
+
def list_safetensor_layers(filepath: str):
|
| 7 |
+
"""
|
| 8 |
+
Opens a .safetensors file and prints the name and shape of each tensor.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
filepath (str): The path to the .safetensors file.
|
| 12 |
+
"""
|
| 13 |
+
try:
|
| 14 |
+
print(f"\n📄 Tensors in: {filepath}\n" + "="*50)
|
| 15 |
+
|
| 16 |
+
total_tensors = 0
|
| 17 |
+
with safe_open(filepath, framework="pt", device="cpu") as f:
|
| 18 |
+
for key in f.keys():
|
| 19 |
+
tensor = f.get_tensor(key)
|
| 20 |
+
print(f"- {key:<50} | Shape: {tensor.shape}")
|
| 21 |
+
total_tensors += 1
|
| 22 |
+
|
| 23 |
+
print("="*50 + f"\n✅ Found {total_tensors} total tensors.\n")
|
| 24 |
+
|
| 25 |
+
except FileNotFoundError:
|
| 26 |
+
print(f"❌ Error: The file '{filepath}' was not found.")
|
| 27 |
+
except Exception as e:
|
| 28 |
+
print(f"❌ An error occurred: {e}")
|
| 29 |
+
print("Please ensure the file is a valid .safetensors file.")
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
# --- Argument Parser Setup ---
|
| 33 |
+
parser = argparse.ArgumentParser(
|
| 34 |
+
description="List all layers (tensors) and their shapes in a .safetensors file.",
|
| 35 |
+
formatter_class=argparse.RawTextHelpFormatter
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
parser.add_argument(
|
| 39 |
+
"filepath",
|
| 40 |
+
type=str,
|
| 41 |
+
help="Path to the .safetensors file."
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
args = parser.parse_args()
|
| 45 |
+
|
| 46 |
+
# --- Run the function ---
|
| 47 |
+
list_safetensor_layers(args.filepath)
|
prune_layers.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
from safetensors import safe_open
|
| 5 |
+
from safetensors.torch import save_file
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# This is the set of layers that will be removed from the file.
|
| 9 |
+
# Using a set provides fast lookups.
|
| 10 |
+
KEYS_TO_REMOVE = {
|
| 11 |
+
"model.diffusion_model.double_blocks.0.img_mod.lin.bias",
|
| 12 |
+
"model.diffusion_model.double_blocks.0.img_mod.lin.weight",
|
| 13 |
+
"model.diffusion_model.double_blocks.0.txt_mod.lin.bias",
|
| 14 |
+
"model.diffusion_model.double_blocks.0.txt_mod.lin.weight",
|
| 15 |
+
"model.diffusion_model.double_blocks.1.img_mod.lin.bias",
|
| 16 |
+
"model.diffusion_model.double_blocks.1.img_mod.lin.weight",
|
| 17 |
+
"model.diffusion_model.double_blocks.1.txt_mod.lin.bias",
|
| 18 |
+
"model.diffusion_model.double_blocks.1.txt_mod.lin.weight",
|
| 19 |
+
"model.diffusion_model.double_blocks.10.img_mod.lin.bias",
|
| 20 |
+
"model.diffusion_model.double_blocks.10.img_mod.lin.weight",
|
| 21 |
+
"model.diffusion_model.double_blocks.10.txt_mod.lin.bias",
|
| 22 |
+
"model.diffusion_model.double_blocks.10.txt_mod.lin.weight",
|
| 23 |
+
"model.diffusion_model.double_blocks.11.img_mod.lin.bias",
|
| 24 |
+
"model.diffusion_model.double_blocks.11.img_mod.lin.weight",
|
| 25 |
+
"model.diffusion_model.double_blocks.11.txt_mod.lin.bias",
|
| 26 |
+
"model.diffusion_model.double_blocks.11.txt_mod.lin.weight",
|
| 27 |
+
"model.diffusion_model.double_blocks.12.img_mod.lin.bias",
|
| 28 |
+
"model.diffusion_model.double_blocks.12.img_mod.lin.weight",
|
| 29 |
+
"model.diffusion_model.double_blocks.12.txt_mod.lin.bias",
|
| 30 |
+
"model.diffusion_model.double_blocks.12.txt_mod.lin.weight",
|
| 31 |
+
"model.diffusion_model.double_blocks.13.img_mod.lin.bias",
|
| 32 |
+
"model.diffusion_model.double_blocks.13.img_mod.lin.weight",
|
| 33 |
+
"model.diffusion_model.double_blocks.13.txt_mod.lin.bias",
|
| 34 |
+
"model.diffusion_model.double_blocks.13.txt_mod.lin.weight",
|
| 35 |
+
"model.diffusion_model.double_blocks.14.img_mod.lin.bias",
|
| 36 |
+
"model.diffusion_model.double_blocks.14.img_mod.lin.weight",
|
| 37 |
+
"model.diffusion_model.double_blocks.14.txt_mod.lin.bias",
|
| 38 |
+
"model.diffusion_model.double_blocks.14.txt_mod.lin.weight",
|
| 39 |
+
"model.diffusion_model.double_blocks.15.img_mod.lin.bias",
|
| 40 |
+
"model.diffusion_model.double_blocks.15.img_mod.lin.weight",
|
| 41 |
+
"model.diffusion_model.double_blocks.15.txt_mod.lin.bias",
|
| 42 |
+
"model.diffusion_model.double_blocks.15.txt_mod.lin.weight",
|
| 43 |
+
"model.diffusion_model.double_blocks.16.img_mod.lin.bias",
|
| 44 |
+
"model.diffusion_model.double_blocks.16.img_mod.lin.weight",
|
| 45 |
+
"model.diffusion_model.double_blocks.16.txt_mod.lin.bias",
|
| 46 |
+
"model.diffusion_model.double_blocks.16.txt_mod.lin.weight",
|
| 47 |
+
"model.diffusion_model.double_blocks.17.img_mod.lin.bias",
|
| 48 |
+
"model.diffusion_model.double_blocks.17.img_mod.lin.weight",
|
| 49 |
+
"model.diffusion_model.double_blocks.17.txt_mod.lin.bias",
|
| 50 |
+
"model.diffusion_model.double_blocks.17.txt_mod.lin.weight",
|
| 51 |
+
"model.diffusion_model.double_blocks.18.img_mod.lin.bias",
|
| 52 |
+
"model.diffusion_model.double_blocks.18.img_mod.lin.weight",
|
| 53 |
+
"model.diffusion_model.double_blocks.18.txt_mod.lin.bias",
|
| 54 |
+
"model.diffusion_model.double_blocks.18.txt_mod.lin.weight",
|
| 55 |
+
"model.diffusion_model.double_blocks.2.img_mod.lin.bias",
|
| 56 |
+
"model.diffusion_model.double_blocks.2.img_mod.lin.weight",
|
| 57 |
+
"model.diffusion_model.double_blocks.2.txt_mod.lin.bias",
|
| 58 |
+
"model.diffusion_model.double_blocks.2.txt_mod.lin.weight",
|
| 59 |
+
"model.diffusion_model.double_blocks.3.img_mod.lin.bias",
|
| 60 |
+
"model.diffusion_model.double_blocks.3.img_mod.lin.weight",
|
| 61 |
+
"model.diffusion_model.double_blocks.3.txt_mod.lin.bias",
|
| 62 |
+
"model.diffusion_model.double_blocks.3.txt_mod.lin.weight",
|
| 63 |
+
"model.diffusion_model.double_blocks.4.img_mod.lin.bias",
|
| 64 |
+
"model.diffusion_model.double_blocks.4.img_mod.lin.weight",
|
| 65 |
+
"model.diffusion_model.double_blocks.4.txt_mod.lin.bias",
|
| 66 |
+
"model.diffusion_model.double_blocks.4.txt_mod.lin.weight",
|
| 67 |
+
"model.diffusion_model.double_blocks.5.img_mod.lin.bias",
|
| 68 |
+
"model.diffusion_model.double_blocks.5.img_mod.lin.weight",
|
| 69 |
+
"model.diffusion_model.double_blocks.5.txt_mod.lin.bias",
|
| 70 |
+
"model.diffusion_model.double_blocks.5.txt_mod.lin.weight",
|
| 71 |
+
"model.diffusion_model.double_blocks.6.img_mod.lin.bias",
|
| 72 |
+
"model.diffusion_model.double_blocks.6.img_mod.lin.weight",
|
| 73 |
+
"model.diffusion_model.double_blocks.6.txt_mod.lin.bias",
|
| 74 |
+
"model.diffusion_model.double_blocks.6.txt_mod.lin.weight",
|
| 75 |
+
"model.diffusion_model.double_blocks.7.img_mod.lin.bias",
|
| 76 |
+
"model.diffusion_model.double_blocks.7.img_mod.lin.weight",
|
| 77 |
+
"model.diffusion_model.double_blocks.7.txt_mod.lin.bias",
|
| 78 |
+
"model.diffusion_model.double_blocks.7.txt_mod.lin.weight",
|
| 79 |
+
"model.diffusion_model.double_blocks.8.img_mod.lin.bias",
|
| 80 |
+
"model.diffusion_model.double_blocks.8.img_mod.lin.weight",
|
| 81 |
+
"model.diffusion_model.double_blocks.8.txt_mod.lin.bias",
|
| 82 |
+
"model.diffusion_model.double_blocks.8.txt_mod.lin.weight",
|
| 83 |
+
"model.diffusion_model.double_blocks.9.img_mod.lin.bias",
|
| 84 |
+
"model.diffusion_model.double_blocks.9.img_mod.lin.weight",
|
| 85 |
+
"model.diffusion_model.double_blocks.9.txt_mod.lin.bias",
|
| 86 |
+
"model.diffusion_model.double_blocks.9.txt_mod.lin.weight",
|
| 87 |
+
"model.diffusion_model.final_layer.adaLN_modulation.1.bias",
|
| 88 |
+
"model.diffusion_model.final_layer.adaLN_modulation.1.weight",
|
| 89 |
+
# "model.diffusion_model.guidance_in.in_layer.bias",
|
| 90 |
+
# "model.diffusion_model.guidance_in.in_layer.weight",
|
| 91 |
+
# "model.diffusion_model.guidance_in.out_layer.bias",
|
| 92 |
+
# "model.diffusion_model.guidance_in.out_layer.weight",
|
| 93 |
+
"model.diffusion_model.single_blocks.0.modulation.lin.bias",
|
| 94 |
+
"model.diffusion_model.single_blocks.0.modulation.lin.weight",
|
| 95 |
+
"model.diffusion_model.single_blocks.1.modulation.lin.bias",
|
| 96 |
+
"model.diffusion_model.single_blocks.1.modulation.lin.weight",
|
| 97 |
+
"model.diffusion_model.single_blocks.10.modulation.lin.bias",
|
| 98 |
+
"model.diffusion_model.single_blocks.10.modulation.lin.weight",
|
| 99 |
+
"model.diffusion_model.single_blocks.11.modulation.lin.bias",
|
| 100 |
+
"model.diffusion_model.single_blocks.11.modulation.lin.weight",
|
| 101 |
+
"model.diffusion_model.single_blocks.12.modulation.lin.bias",
|
| 102 |
+
"model.diffusion_model.single_blocks.12.modulation.lin.weight",
|
| 103 |
+
"model.diffusion_model.single_blocks.13.modulation.lin.bias",
|
| 104 |
+
"model.diffusion_model.single_blocks.13.modulation.lin.weight",
|
| 105 |
+
"model.diffusion_model.single_blocks.14.modulation.lin.bias",
|
| 106 |
+
"model.diffusion_model.single_blocks.14.modulation.lin.weight",
|
| 107 |
+
"model.diffusion_model.single_blocks.15.modulation.lin.bias",
|
| 108 |
+
"model.diffusion_model.single_blocks.15.modulation.lin.weight",
|
| 109 |
+
"model.diffusion_model.single_blocks.16.modulation.lin.bias",
|
| 110 |
+
"model.diffusion_model.single_blocks.16.modulation.lin.weight",
|
| 111 |
+
"model.diffusion_model.single_blocks.17.modulation.lin.bias",
|
| 112 |
+
"model.diffusion_model.single_blocks.17.modulation.lin.weight",
|
| 113 |
+
"model.diffusion_model.single_blocks.18.modulation.lin.bias",
|
| 114 |
+
"model.diffusion_model.single_blocks.18.modulation.lin.weight",
|
| 115 |
+
"model.diffusion_model.single_blocks.19.modulation.lin.bias",
|
| 116 |
+
"model.diffusion_model.single_blocks.19.modulation.lin.weight",
|
| 117 |
+
"model.diffusion_model.single_blocks.2.modulation.lin.bias",
|
| 118 |
+
"model.diffusion_model.single_blocks.2.modulation.lin.weight",
|
| 119 |
+
"model.diffusion_model.single_blocks.20.modulation.lin.bias",
|
| 120 |
+
"model.diffusion_model.single_blocks.20.modulation.lin.weight",
|
| 121 |
+
"model.diffusion_model.single_blocks.21.modulation.lin.bias",
|
| 122 |
+
"model.diffusion_model.single_blocks.21.modulation.lin.weight",
|
| 123 |
+
"model.diffusion_model.single_blocks.22.modulation.lin.bias",
|
| 124 |
+
"model.diffusion_model.single_blocks.22.modulation.lin.weight",
|
| 125 |
+
"model.diffusion_model.single_blocks.23.modulation.lin.bias",
|
| 126 |
+
"model.diffusion_model.single_blocks.23.modulation.lin.weight",
|
| 127 |
+
"model.diffusion_model.single_blocks.24.modulation.lin.bias",
|
| 128 |
+
"model.diffusion_model.single_blocks.24.modulation.lin.weight",
|
| 129 |
+
"model.diffusion_model.single_blocks.25.modulation.lin.bias",
|
| 130 |
+
"model.diffusion_model.single_blocks.25.modulation.lin.weight",
|
| 131 |
+
"model.diffusion_model.single_blocks.26.modulation.lin.bias",
|
| 132 |
+
"model.diffusion_model.single_blocks.26.modulation.lin.weight",
|
| 133 |
+
"model.diffusion_model.single_blocks.27.modulation.lin.bias",
|
| 134 |
+
"model.diffusion_model.single_blocks.27.modulation.lin.weight",
|
| 135 |
+
"model.diffusion_model.single_blocks.28.modulation.lin.bias",
|
| 136 |
+
"model.diffusion_model.single_blocks.28.modulation.lin.weight",
|
| 137 |
+
"model.diffusion_model.single_blocks.29.modulation.lin.bias",
|
| 138 |
+
"model.diffusion_model.single_blocks.29.modulation.lin.weight",
|
| 139 |
+
"model.diffusion_model.single_blocks.3.modulation.lin.bias",
|
| 140 |
+
"model.diffusion_model.single_blocks.3.modulation.lin.weight",
|
| 141 |
+
"model.diffusion_model.single_blocks.30.modulation.lin.bias",
|
| 142 |
+
"model.diffusion_model.single_blocks.30.modulation.lin.weight",
|
| 143 |
+
"model.diffusion_model.single_blocks.31.modulation.lin.bias",
|
| 144 |
+
"model.diffusion_model.single_blocks.31.modulation.lin.weight",
|
| 145 |
+
"model.diffusion_model.single_blocks.32.modulation.lin.bias",
|
| 146 |
+
"model.diffusion_model.single_blocks.32.modulation.lin.weight",
|
| 147 |
+
"model.diffusion_model.single_blocks.33.modulation.lin.bias",
|
| 148 |
+
"model.diffusion_model.single_blocks.33.modulation.lin.weight",
|
| 149 |
+
"model.diffusion_model.single_blocks.34.modulation.lin.bias",
|
| 150 |
+
"model.diffusion_model.single_blocks.34.modulation.lin.weight",
|
| 151 |
+
"model.diffusion_model.single_blocks.35.modulation.lin.bias",
|
| 152 |
+
"model.diffusion_model.single_blocks.35.modulation.lin.weight",
|
| 153 |
+
"model.diffusion_model.single_blocks.36.modulation.lin.bias",
|
| 154 |
+
"model.diffusion_model.single_blocks.36.modulation.lin.weight",
|
| 155 |
+
"model.diffusion_model.single_blocks.37.modulation.lin.bias",
|
| 156 |
+
"model.diffusion_model.single_blocks.37.modulation.lin.weight",
|
| 157 |
+
"model.diffusion_model.single_blocks.4.modulation.lin.bias",
|
| 158 |
+
"model.diffusion_model.single_blocks.4.modulation.lin.weight",
|
| 159 |
+
"model.diffusion_model.single_blocks.5.modulation.lin.bias",
|
| 160 |
+
"model.diffusion_model.single_blocks.5.modulation.lin.weight",
|
| 161 |
+
"model.diffusion_model.single_blocks.6.modulation.lin.bias",
|
| 162 |
+
"model.diffusion_model.single_blocks.6.modulation.lin.weight",
|
| 163 |
+
"model.diffusion_model.single_blocks.7.modulation.lin.bias",
|
| 164 |
+
"model.diffusion_model.single_blocks.7.modulation.lin.weight",
|
| 165 |
+
"model.diffusion_model.single_blocks.8.modulation.lin.bias",
|
| 166 |
+
"model.diffusion_model.single_blocks.8.modulation.lin.weight",
|
| 167 |
+
"model.diffusion_model.single_blocks.9.modulation.lin.bias",
|
| 168 |
+
"model.diffusion_model.single_blocks.9.modulation.lin.weight",
|
| 169 |
+
# "model.diffusion_model.time_in.in_layer.bias",
|
| 170 |
+
# "model.diffusion_model.time_in.in_layer.weight",
|
| 171 |
+
# "model.diffusion_model.time_in.out_layer.bias",
|
| 172 |
+
# "model.diffusion_model.time_in.out_layer.weight",
|
| 173 |
+
# "model.diffusion_model.vector_in.in_layer.bias",
|
| 174 |
+
# "model.diffusion_model.vector_in.in_layer.weight",
|
| 175 |
+
# "model.diffusion_model.vector_in.out_layer.bias",
|
| 176 |
+
# "model.diffusion_model.vector_in.out_layer.weight",
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
def prune_safetensors_file(input_path: str, output_path: str, keys_to_remove: set):
|
| 180 |
+
"""
|
| 181 |
+
Reads a safetensors file, removes specified layers, and saves to a new file.
|
| 182 |
+
"""
|
| 183 |
+
tensors_to_keep = {}
|
| 184 |
+
original_count = 0
|
| 185 |
+
|
| 186 |
+
print(f"▶️ Reading from: {input_path}")
|
| 187 |
+
try:
|
| 188 |
+
# Open the source file for reading
|
| 189 |
+
with safe_open(input_path, framework="pt", device="cpu") as f:
|
| 190 |
+
all_keys = f.keys()
|
| 191 |
+
original_count = len(all_keys)
|
| 192 |
+
|
| 193 |
+
# Iterate through all tensors, keeping only the ones not in the removal list
|
| 194 |
+
for key in tqdm(all_keys, desc="Filtering Tensors"):
|
| 195 |
+
if 'modulation' not in key :
|
| 196 |
+
tensors_to_keep[key] = f.get_tensor(key)
|
| 197 |
+
|
| 198 |
+
removed_count = original_count - len(tensors_to_keep)
|
| 199 |
+
|
| 200 |
+
print("\n--- Summary ---")
|
| 201 |
+
print(f"Original tensor count: {original_count}")
|
| 202 |
+
print(f"Tensors removed: {removed_count}")
|
| 203 |
+
print(f"New tensor count: {len(tensors_to_keep)}")
|
| 204 |
+
|
| 205 |
+
if removed_count == 0 and len(keys_to_remove) > 0:
|
| 206 |
+
print("\n⚠️ Warning: None of the specified layers to remove were found in the input file.")
|
| 207 |
+
|
| 208 |
+
# Save the filtered tensors to the new file
|
| 209 |
+
print(f"\n▶️ Saving {len(tensors_to_keep)} tensors to: {output_path}")
|
| 210 |
+
save_file(tensors_to_keep, output_path)
|
| 211 |
+
|
| 212 |
+
print("\n✅ Pruning complete!")
|
| 213 |
+
|
| 214 |
+
except FileNotFoundError:
|
| 215 |
+
print(f"\n❌ Error: The file '{input_path}' was not found.")
|
| 216 |
+
except Exception as e:
|
| 217 |
+
print(f"\n❌ An error occurred: {e}")
|
| 218 |
+
|
| 219 |
+
if __name__ == "__main__":
|
| 220 |
+
parser = argparse.ArgumentParser(
|
| 221 |
+
description="Removes a predefined list of layers from a .safetensors file and saves it as a new file.",
|
| 222 |
+
formatter_class=argparse.RawTextHelpFormatter
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
parser.add_argument(
|
| 226 |
+
"input_file",
|
| 227 |
+
type=str,
|
| 228 |
+
help="Path to the source .safetensors file."
|
| 229 |
+
)
|
| 230 |
+
parser.add_argument(
|
| 231 |
+
"output_file",
|
| 232 |
+
type=str,
|
| 233 |
+
help="Path to save the new, pruned .safetensors file."
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
args = parser.parse_args()
|
| 237 |
+
|
| 238 |
+
prune_safetensors_file(args.input_file, args.output_file, KEYS_TO_REMOVE)
|