python-script-dump / prune_layers.py
anyMODE's picture
Upload 4 files
54c3b42 verified
#!/usr/bin/env python
import argparse
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
# Names of the tensors listed for removal, generated from their naming
# pattern instead of being spelled out one by one:
#   * double_blocks 0-18: {img,txt}_mod.lin.{bias,weight}
#   * single_blocks 0-37: modulation.lin.{bias,weight}
#   * final_layer.adaLN_modulation.1.{bias,weight}
# The guidance_in, time_in and vector_in in_layer/out_layer parameters are
# intentionally absent from this set.
# A set is used so that membership tests are constant-time.
KEYS_TO_REMOVE = {
    f"model.diffusion_model.double_blocks.{block}.{stream}_mod.lin.{param}"
    for block in range(19)
    for stream in ("img", "txt")
    for param in ("bias", "weight")
} | {
    f"model.diffusion_model.single_blocks.{block}.modulation.lin.{param}"
    for block in range(38)
    for param in ("bias", "weight")
} | {
    "model.diffusion_model.final_layer.adaLN_modulation.1.bias",
    "model.diffusion_model.final_layer.adaLN_modulation.1.weight",
}
def prune_safetensors_file(input_path: str, output_path: str, keys_to_remove: set):
    """Copy a safetensors file, leaving out the tensors named in ``keys_to_remove``.

    Every tensor whose key is not in ``keys_to_remove`` is loaded on the CPU
    and written to ``output_path``. A short summary of the tensor counts is
    printed before saving.

    Args:
        input_path: Path of the source ``.safetensors`` file.
        output_path: Path the pruned file is written to.
        keys_to_remove: Exact tensor names to drop. Names that do not occur
            in the input file are simply never matched.

    Returns:
        None. Failures (including a missing input file) are reported on
        stdout instead of being raised.
    """
    tensors_to_keep = {}
    print(f"▶️ Reading from: {input_path}")
    try:
        # Open the source file for reading
        with safe_open(input_path, framework="pt", device="cpu") as f:
            all_keys = f.keys()
            original_count = len(all_keys)
            for key in tqdm(all_keys, desc="Filtering Tensors"):
                # Filter on the caller-supplied names. The previous
                # substring test (`'modulation' not in key`) ignored
                # `keys_to_remove` entirely, so e.g. the `img_mod.lin` /
                # `txt_mod.lin` entries were never removed.
                if key not in keys_to_remove:
                    tensors_to_keep[key] = f.get_tensor(key)

        removed_count = original_count - len(tensors_to_keep)
        print("\n--- Summary ---")
        print(f"Original tensor count: {original_count}")
        print(f"Tensors removed: {removed_count}")
        print(f"New tensor count: {len(tensors_to_keep)}")
        # Something was requested but nothing matched: most likely the key
        # names do not fit this checkpoint, so tell the user.
        if removed_count == 0 and keys_to_remove:
            print("\n⚠️ Warning: None of the specified layers to remove were found in the input file.")

        # Save the filtered tensors to the new file
        print(f"\n▶️ Saving {len(tensors_to_keep)} tensors to: {output_path}")
        save_file(tensors_to_keep, output_path)
        print("\n✅ Pruning complete!")
    except FileNotFoundError:
        print(f"\n❌ Error: The file '{input_path}' was not found.")
    except Exception as e:
        # Reporting boundary of the script: print any other failure rather
        # than letting a traceback escape.
        print(f"\n❌ An error occurred: {e}")
if __name__ == "__main__":
    # Both command-line arguments are required positional paths, registered
    # in this order: source file first, destination file second.
    positional_specs = (
        ("input_file", "Path to the source .safetensors file."),
        ("output_file", "Path to save the new, pruned .safetensors file."),
    )
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Removes a predefined list of layers from a .safetensors file and saves it as a new file.",
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)
    options = cli.parse_args()
    # Prune using the module-level removal set.
    prune_safetensors_file(options.input_file, options.output_file, KEYS_TO_REMOVE)