#!/usr/bin/env python
"""Remove a predefined set of tensors from a .safetensors file.

Usage: script.py INPUT.safetensors OUTPUT.safetensors

Every tensor whose key is listed in ``KEYS_TO_REMOVE`` is dropped; all other
tensors, plus the input file's header metadata, are written to the output file.
"""
import argparse
import sys

from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm

_PREFIX = "model.diffusion_model"
_DOUBLE_BLOCK_COUNT = 19  # double_blocks.0 .. double_blocks.18
_SINGLE_BLOCK_COUNT = 38  # single_blocks.0 .. single_blocks.37

# This is the set of layers that will be removed from the file.
# Using a set provides fast lookups. The keys are generated rather than listed
# by hand so that no block index can be accidentally skipped or mistyped.
KEYS_TO_REMOVE = {
    # double_blocks.N.{img,txt}_mod.lin.{bias,weight}
    *(
        f"{_PREFIX}.double_blocks.{i}.{stream}_mod.lin.{param}"
        for i in range(_DOUBLE_BLOCK_COUNT)
        for stream in ("img", "txt")
        for param in ("bias", "weight")
    ),
    # single_blocks.N.modulation.lin.{bias,weight}
    *(
        f"{_PREFIX}.single_blocks.{i}.modulation.lin.{param}"
        for i in range(_SINGLE_BLOCK_COUNT)
        for param in ("bias", "weight")
    ),
    f"{_PREFIX}.final_layer.adaLN_modulation.1.bias",
    f"{_PREFIX}.final_layer.adaLN_modulation.1.weight",
    # Deliberately KEPT in the output; uncomment to remove these as well:
    # *(
    #     f"{_PREFIX}.{embedder}.{layer}.{param}"
    #     for embedder in ("guidance_in", "time_in", "vector_in")
    #     for layer in ("in_layer", "out_layer")
    #     for param in ("bias", "weight")
    # ),
}


def prune_safetensors_file(input_path: str, output_path: str, keys_to_remove: set) -> bool:
    """
    Reads a safetensors file, removes specified layers, and saves to a new file.

    Only tensors whose key is exactly equal to an entry of ``keys_to_remove``
    are dropped; every other tensor, and the input's header metadata, is copied
    to ``output_path``.

    Args:
        input_path: Path of the source .safetensors file.
        output_path: Path the pruned .safetensors file is written to.
        keys_to_remove: Exact tensor keys to drop (any iterable of str). Keys
            that do not occur in the input are reported but otherwise ignored.

    Returns:
        True if the pruned file was written; False if an error occurred. Errors
        are printed to stderr instead of being raised.
    """
    removal = set(keys_to_remove)
    tensors_to_keep = {}

    print(f"▶️ Reading from: {input_path}")
    try:
        # Open the source file for reading
        with safe_open(input_path, framework="pt", device="cpu") as f:
            all_keys = list(f.keys())
            metadata = f.metadata()  # header __metadata__ (may be None)

            # Iterate through all tensors, keeping only the ones not in the removal list
            for key in tqdm(all_keys, desc="Filtering Tensors"):
                if key not in removal:
                    tensors_to_keep[key] = f.get_tensor(key)
    except FileNotFoundError:
        print(f"\n❌ Error: The file '{input_path}' was not found.", file=sys.stderr)
        return False
    except Exception as e:  # top-level CLI boundary: report instead of traceback
        print(f"\n❌ An error occurred: {e}", file=sys.stderr)
        return False

    original_count = len(all_keys)
    removed_count = original_count - len(tensors_to_keep)

    print("\n--- Summary ---")
    print(f"Original tensor count: {original_count}")
    print(f"Tensors removed: {removed_count}")
    print(f"New tensor count: {len(tensors_to_keep)}")

    missing = removal.difference(all_keys)
    if removed_count == 0 and removal:
        print("\n⚠️ Warning: None of the specified layers to remove were found in the input file.")
    elif missing:
        # Partial match: usually means the list targets a differently sized model.
        print(f"\n⚠️ Warning: {len(missing)} of the specified layers to remove were not found in the input file.")

    # Save the filtered tensors to the new file
    print(f"\n▶️ Saving {len(tensors_to_keep)} tensors to: {output_path}")
    try:
        save_file(tensors_to_keep, output_path, metadata=metadata)
    except Exception as e:  # e.g. unwritable or nonexistent output directory
        print(f"\n❌ An error occurred: {e}", file=sys.stderr)
        return False

    print("\n✅ Pruning complete!")
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Removes a predefined list of layers from a .safetensors file and saves it as a new file.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "input_file",
        type=str,
        help="Path to the source .safetensors file.",
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="Path to save the new, pruned .safetensors file.",
    )
    args = parser.parse_args()

    ok = prune_safetensors_file(args.input_file, args.output_file, KEYS_TO_REMOVE)
    # Non-zero exit status lets shell pipelines detect a failed prune.
    sys.exit(0 if ok else 1)