|
|
|
|
|
|
|
|
import argparse |
|
|
from safetensors import safe_open |
|
|
from safetensors.torch import save_file |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
# Exact tensor names handed to prune_safetensors_file() as the removal list by
# the CLI entry point below. Every name sits under "model.diffusion_model." and
# is a bias/weight pair of a modulation linear layer:
#   * double_blocks 0..18 -> separate image ("img_mod") and text ("txt_mod") layers
#   * final_layer         -> adaLN_modulation.1
#   * single_blocks 0..37 -> modulation.lin
# 19*2*2 + 2 + 38*2 = 154 names in total.
KEYS_TO_REMOVE = {
    *(
        f"model.diffusion_model.double_blocks.{block}.{stream}_mod.lin.{param}"
        for block in range(19)
        for stream in ("img", "txt")
        for param in ("bias", "weight")
    ),
    "model.diffusion_model.final_layer.adaLN_modulation.1.bias",
    "model.diffusion_model.final_layer.adaLN_modulation.1.weight",
    *(
        f"model.diffusion_model.single_blocks.{block}.modulation.lin.{param}"
        for block in range(38)
        for param in ("bias", "weight")
    ),
}
|
|
|
|
|
def prune_safetensors_file(input_path: str, output_path: str, keys_to_remove: set) -> bool:
    """Copy a safetensors file to ``output_path``, dropping the listed tensors.

    Args:
        input_path: Path of the source ``.safetensors`` file.
        output_path: Path the pruned ``.safetensors`` file is written to.
        keys_to_remove: Exact tensor names to drop. Any iterable of ``str``
            works; names not present in the input are ignored but reported.

    Returns:
        True if the pruned file was written, False if an error occurred.
        Errors are printed rather than raised.
    """
    # Build the lookup set once: O(1) membership per key, and it also accepts
    # lists/tuples of names, not just sets.
    remove = frozenset(keys_to_remove)

    print(f"▶️ Reading from: {input_path}")

    # --- Read phase: only errors raised here can mean "input not found". ---
    try:
        with safe_open(input_path, framework="pt", device="cpu") as f:
            all_keys = list(f.keys())
            # Filter on exact key membership. The previous substring test
            # ("'modulation' not in key") ignored keys_to_remove entirely: it
            # kept the double_blocks img_mod/txt_mod layers and would drop
            # any unrelated key containing "modulation".
            tensors_to_keep = {
                key: f.get_tensor(key)
                for key in tqdm(all_keys, desc="Filtering Tensors")
                if key not in remove
            }
    except FileNotFoundError:
        print(f"\n❌ Error: The file '{input_path}' was not found.")
        return False
    except Exception as e:  # CLI boundary: report instead of crashing
        print(f"\n❌ An error occurred: {e}")
        return False

    original_count = len(all_keys)
    removed_count = original_count - len(tensors_to_keep)

    print("\n--- Summary ---")
    print(f"Original tensor count: {original_count}")
    print(f"Tensors removed: {removed_count}")
    print(f"New tensor count: {len(tensors_to_keep)}")

    # Requested names that do not exist in the input file.
    missing = remove.difference(all_keys)
    if removed_count == 0 and remove:
        print("\n⚠️ Warning: None of the specified layers to remove were found in the input file.")
    elif missing:
        print(f"\n⚠️ Warning: {len(missing)} of the specified layers to remove were not found in the input file.")

    # --- Write phase: errors here concern output_path, not the input. ---
    print(f"\n▶️ Saving {len(tensors_to_keep)} tensors to: {output_path}")
    try:
        save_file(tensors_to_keep, output_path)
    except Exception as e:  # CLI boundary: report instead of crashing
        print(f"\n❌ An error occurred: {e}")
        return False

    print("\n✅ Pruning complete!")
    return True
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the two positional paths and prune the
    # module-level KEYS_TO_REMOVE from the input file.
    cli = argparse.ArgumentParser(
        description="Removes a predefined list of layers from a .safetensors file and saves it as a new file.",
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # Both positionals are plain string paths; declare them from one table.
    positional_specs = (
        ("input_file", "Path to the source .safetensors file."),
        ("output_file", "Path to save the new, pruned .safetensors file."),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)

    cli_args = cli.parse_args()
    prune_safetensors_file(cli_args.input_file, cli_args.output_file, KEYS_TO_REMOVE)
|
|
|