anyMODE committed on
Commit 54c3b42 · verified · 1 parent: ab09044

Upload 4 files

Files changed (4)
  1. comparemodels.py +130 -0
  2. extract_layers.py +99 -0
  3. listlayers.py +47 -0
  4. prune_layers.py +238 -0
comparemodels.py ADDED
@@ -0,0 +1,130 @@
+ #!/usr/bin/env python
+
+ import argparse
+ import torch
+ from safetensors import safe_open
+
+
+ def compare_safetensors(filepath1: str, filepath2: str):
+     """
+     Compares two .safetensors files, ignoring a specific prefix on layer names,
+     and prints a summary of the differences.
+
+     Args:
+         filepath1 (str): Path to the first .safetensors file.
+         filepath2 (str): Path to the second .safetensors file.
+     """
+     # The prefix to ignore on layer names
+     prefix_to_ignore = "model.diffusion_model."
+
+     # Dictionaries to hold results
+     results = {
+         "only_in_file1": [],
+         "only_in_file2": [],
+         "different_content": [],
+     }
+
+     print("\nLoading files and preparing for comparison...")
+     print(f"Ignoring prefix: '{prefix_to_ignore}'")
+
+     try:
+         # Use 'with' to ensure files are closed properly
+         with safe_open(filepath1, framework="pt", device="cpu") as f1, \
+              safe_open(filepath2, framework="pt", device="cpu") as f2:
+
+             # Create maps from the normalized key (suffix) to the original key
+             map1 = {key.removeprefix(prefix_to_ignore): key for key in f1.keys()}
+             map2 = {key.removeprefix(prefix_to_ignore): key for key in f2.keys()}
+
+             # Get the set of normalized tensor keys from each file
+             normalized_keys1 = set(map1.keys())
+             normalized_keys2 = set(map2.keys())
+
+             # 1. Find normalized keys (layers) unique to each file
+             results["only_in_file1"] = sorted(normalized_keys1 - normalized_keys2)
+             results["only_in_file2"] = sorted(normalized_keys2 - normalized_keys1)
+
+             # 2. Find normalized keys present in both files to compare their content
+             common_normalized_keys = normalized_keys1.intersection(normalized_keys2)
+             print(f"Comparing {len(common_normalized_keys)} common tensors...")
+
+             for norm_key in sorted(common_normalized_keys):
+                 # Get the original key for each file using the maps
+                 original_key1 = map1[norm_key]
+                 original_key2 = map2[norm_key]
+
+                 # Get the tensor from each file using its original key
+                 tensor1 = f1.get_tensor(original_key1)
+                 tensor2 = f2.get_tensor(original_key2)
+
+                 # Compare tensors for equality
+                 if not torch.equal(tensor1, tensor2):
+                     # Store the normalized key if content differs
+                     results["different_content"].append(norm_key)
+
+         # --- Print the results ---
+         print("\n" + "=" * 60)
+         print("🔍 Safetensor Comparison Results")
+         print("=" * 60)
+         print(f"File 1: {filepath1}")
+         print(f"File 2: {filepath2}")
+         print("-" * 60)
+
+         # Check if any differences were found at all
+         total_diffs = len(results["only_in_file1"]) + len(results["only_in_file2"]) + len(results["different_content"])
+         if total_diffs == 0:
+             print("\n✅ The files are identical after normalization. No differences found.")
+             print("=" * 60 + "\n")
+             return
+
+         # Report tensors with different content
+         if results["different_content"]:
+             print(f"\n↔️ Tensors with Different Content ({len(results['different_content'])}):")
+             for norm_key in results["different_content"]:
+                 print(f" - Normalized Key: {norm_key}")
+                 print(f"   (File 1 Original: {map1[norm_key]})")
+                 print(f"   (File 2 Original: {map2[norm_key]})")
+
+         # Report tensors only in file 1
+         if results["only_in_file1"]:
+             print(f"\n→ Tensors Only in File 1 ({len(results['only_in_file1'])}):")
+             for norm_key in results["only_in_file1"]:
+                 print(f" - Normalized Key: {norm_key} (Original: {map1[norm_key]})")
+
+         # Report tensors only in file 2
+         if results["only_in_file2"]:
+             print(f"\n← Tensors Only in File 2 ({len(results['only_in_file2'])}):")
+             for norm_key in results["only_in_file2"]:
+                 print(f" - Normalized Key: {norm_key} (Original: {map2[norm_key]})")
+
+         print("\n" + "=" * 60 + "\n")
+
+     except FileNotFoundError as e:
+         print(f"❌ Error: Could not find a file. Details: {e}")
+     except Exception as e:
+         print(f"❌ An error occurred: {e}")
+         print("Please ensure both files are valid .safetensors files.")
+
+
+ if __name__ == "__main__":
+     # --- Argument Parser Setup ---
+     parser = argparse.ArgumentParser(
+         description="Compares two .safetensors files and lists the differences in their layers (tensors), ignoring a specific prefix.",
+         formatter_class=argparse.RawTextHelpFormatter
+     )
+
+     parser.add_argument(
+         "file1",
+         type=str,
+         help="Path to the first .safetensors file."
+     )
+     parser.add_argument(
+         "file2",
+         type=str,
+         help="Path to the second .safetensors file."
+     )
+
+     args = parser.parse_args()
+
+     # --- Run the function ---
+     compare_safetensors(args.file1, args.file2)
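For reference, a minimal usage sketch of the comparison script; the model paths below are hypothetical placeholders, not files from this commit:

    # Hypothetical command-line invocation:
    #   python comparemodels.py model_a.safetensors model_b.safetensors
    # The same check can be driven from Python, assuming comparemodels.py is importable:
    from comparemodels import compare_safetensors

    compare_safetensors("model_a.safetensors", "model_b.safetensors")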
extract_layers.py ADDED
@@ -0,0 +1,99 @@
+ #!/usr/bin/env python
+
+ import argparse
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+ from tqdm import tqdm
+
+ # This set contains the specific layers we want to find and save.
+ # Using a set is efficient for checking if a layer should be kept.
+ KEYS_TO_KEEP = {
+     "distilled_guidance_layer.in_proj.bias",
+     "distilled_guidance_layer.in_proj.weight",
+     "distilled_guidance_layer.layers.0.in_layer.bias",
+     "distilled_guidance_layer.layers.0.in_layer.weight",
+     "distilled_guidance_layer.layers.0.out_layer.bias",
+     "distilled_guidance_layer.layers.0.out_layer.weight",
+     "distilled_guidance_layer.layers.1.in_layer.bias",
+     "distilled_guidance_layer.layers.1.in_layer.weight",
+     "distilled_guidance_layer.layers.1.out_layer.bias",
+     "distilled_guidance_layer.layers.1.out_layer.weight",
+     "distilled_guidance_layer.layers.2.in_layer.bias",
+     "distilled_guidance_layer.layers.2.in_layer.weight",
+     "distilled_guidance_layer.layers.2.out_layer.bias",
+     "distilled_guidance_layer.layers.2.out_layer.weight",
+     "distilled_guidance_layer.layers.3.in_layer.bias",
+     "distilled_guidance_layer.layers.3.in_layer.weight",
+     "distilled_guidance_layer.layers.3.out_layer.bias",
+     "distilled_guidance_layer.layers.3.out_layer.weight",
+     "distilled_guidance_layer.layers.4.in_layer.bias",
+     "distilled_guidance_layer.layers.4.in_layer.weight",
+     "distilled_guidance_layer.layers.4.out_layer.bias",
+     "distilled_guidance_layer.layers.4.out_layer.weight",
+     "distilled_guidance_layer.norms.0.scale",
+     "distilled_guidance_layer.norms.1.scale",
+     "distilled_guidance_layer.norms.2.scale",
+     "distilled_guidance_layer.norms.3.scale",
+     "distilled_guidance_layer.norms.4.scale",
+     "distilled_guidance_layer.out_proj.bias",
+     "distilled_guidance_layer.out_proj.weight",
+ }
+
+ def extract_safetensors_layers(input_path: str, output_path: str, keys_to_keep: set):
+     """
+     Reads a safetensors file, extracts specified layers, and saves them to a new file.
+     """
+     tensors_to_save = {}
+
+     print(f"▶️ Reading from: {input_path}")
+     try:
+         # Open the source file for reading
+         with safe_open(input_path, framework="pt", device="cpu") as f:
+             # safetensors raises its own error type (not KeyError) for a missing
+             # key, so check membership against the file's key set explicitly.
+             available_keys = set(f.keys())
+
+             # Iterate through the keys we want to keep
+             for key in tqdm(sorted(keys_to_keep), desc="Extracting Tensors"):
+                 if key in available_keys:
+                     # The key exists in the file; add its tensor to our dictionary
+                     tensors_to_save[key] = f.get_tensor(key)
+                 else:
+                     print(f"\n⚠️ Warning: Key not found in source file and will be skipped: {key}")
+
+         print("\n--- Summary ---")
+         print(f"Layers specified to keep: {len(keys_to_keep)}")
+         print(f"Layers found and extracted: {len(tensors_to_save)}")
+
+         if not tensors_to_save:
+             print("\n❌ Error: No matching layers were found. The output file will not be created.")
+             return
+
+         # Save the extracted tensors to the new file
+         print(f"\n▶️ Saving {len(tensors_to_save)} tensors to: {output_path}")
+         save_file(tensors_to_save, output_path)
+
+         print("\n✅ Extraction complete!")
+
+     except FileNotFoundError:
+         print(f"\n❌ Error: The file '{input_path}' was not found.")
+     except Exception as e:
+         print(f"\n❌ An error occurred: {e}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Extracts a predefined list of layers from a .safetensors file and saves them to a new file.",
+         formatter_class=argparse.RawTextHelpFormatter
+     )
+
+     parser.add_argument(
+         "input_file",
+         type=str,
+         help="Path to the source .safetensors file."
+     )
+     parser.add_argument(
+         "output_file",
+         type=str,
+         help="Path to save the new, extracted .safetensors file."
+     )
+
+     args = parser.parse_args()
+
+     extract_safetensors_layers(args.input_file, args.output_file, KEYS_TO_KEEP)
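A hedged sanity-check sketch for the extractor's output; the paths are placeholders, and the check simply confirms every saved key belongs to the distilled guidance block:

    # Hypothetical invocation:
    #   python extract_layers.py flux_model.safetensors guidance_only.safetensors
    from safetensors import safe_open

    # Verify the extracted file holds only distilled_guidance_layer tensors.
    with safe_open("guidance_only.safetensors", framework="pt", device="cpu") as f:
        keys = set(f.keys())
    ok = all(k.startswith("distilled_guidance_layer.") for k in keys)
    print(f"Extracted {len(keys)} tensors;", "all guidance keys" if ok else "unexpected keys present")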
listlayers.py ADDED
@@ -0,0 +1,47 @@
+ #!/usr/bin/env python
+
+ import argparse
+ from safetensors import safe_open
+
+ def list_safetensor_layers(filepath: str):
+     """
+     Opens a .safetensors file and prints the name and shape of each tensor.
+
+     Args:
+         filepath (str): The path to the .safetensors file.
+     """
+     try:
+         print(f"\n📄 Tensors in: {filepath}\n" + "=" * 50)
+
+         total_tensors = 0
+         with safe_open(filepath, framework="pt", device="cpu") as f:
+             for key in f.keys():
+                 tensor = f.get_tensor(key)
+                 print(f"- {key:<50} | Shape: {tensor.shape}")
+                 total_tensors += 1
+
+         print("=" * 50 + f"\n✅ Found {total_tensors} total tensors.\n")
+
+     except FileNotFoundError:
+         print(f"❌ Error: The file '{filepath}' was not found.")
+     except Exception as e:
+         print(f"❌ An error occurred: {e}")
+         print("Please ensure the file is a valid .safetensors file.")
+
+ if __name__ == "__main__":
+     # --- Argument Parser Setup ---
+     parser = argparse.ArgumentParser(
+         description="List all layers (tensors) and their shapes in a .safetensors file.",
+         formatter_class=argparse.RawTextHelpFormatter
+     )
+
+     parser.add_argument(
+         "filepath",
+         type=str,
+         help="Path to the .safetensors file."
+     )
+
+     args = parser.parse_args()
+
+     # --- Run the function ---
+     list_safetensor_layers(args.filepath)
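As an aside, shapes can also be read without loading each tensor into memory; a sketch assuming the get_slice API available in recent safetensors releases (the path is a placeholder):

    from safetensors import safe_open

    # get_slice reads shape metadata from the file header without
    # materializing tensor data, which is faster for very large checkpoints.
    with safe_open("model.safetensors", framework="pt", device="cpu") as f:
        for key in f.keys():
            print(f"- {key:<50} | Shape: {f.get_slice(key).get_shape()}")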
prune_layers.py ADDED
@@ -0,0 +1,238 @@
+ #!/usr/bin/env python
+
+ import argparse
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+ from tqdm import tqdm
+
+ # This is the set of layers that will be removed from the file.
+ # Using a set provides fast lookups.
+ KEYS_TO_REMOVE = {
+     "model.diffusion_model.double_blocks.0.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.0.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.0.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.0.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.1.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.1.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.1.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.1.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.10.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.10.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.10.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.10.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.11.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.11.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.11.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.11.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.12.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.12.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.12.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.12.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.13.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.13.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.13.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.13.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.14.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.14.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.14.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.14.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.15.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.15.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.15.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.15.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.16.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.16.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.16.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.16.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.17.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.17.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.17.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.17.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.18.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.18.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.18.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.18.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.2.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.2.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.2.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.2.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.3.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.3.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.3.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.3.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.4.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.4.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.4.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.4.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.5.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.5.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.5.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.5.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.6.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.6.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.6.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.6.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.7.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.7.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.7.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.7.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.8.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.8.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.8.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.8.txt_mod.lin.weight",
+     "model.diffusion_model.double_blocks.9.img_mod.lin.bias",
+     "model.diffusion_model.double_blocks.9.img_mod.lin.weight",
+     "model.diffusion_model.double_blocks.9.txt_mod.lin.bias",
+     "model.diffusion_model.double_blocks.9.txt_mod.lin.weight",
+     "model.diffusion_model.final_layer.adaLN_modulation.1.bias",
+     "model.diffusion_model.final_layer.adaLN_modulation.1.weight",
+     # "model.diffusion_model.guidance_in.in_layer.bias",
+     # "model.diffusion_model.guidance_in.in_layer.weight",
+     # "model.diffusion_model.guidance_in.out_layer.bias",
+     # "model.diffusion_model.guidance_in.out_layer.weight",
+     "model.diffusion_model.single_blocks.0.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.0.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.1.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.1.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.10.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.10.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.11.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.11.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.12.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.12.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.13.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.13.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.14.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.14.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.15.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.15.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.16.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.16.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.17.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.17.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.18.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.18.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.19.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.19.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.2.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.2.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.20.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.20.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.21.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.21.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.22.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.22.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.23.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.23.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.24.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.24.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.25.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.25.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.26.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.26.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.27.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.27.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.28.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.28.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.29.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.29.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.3.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.3.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.30.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.30.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.31.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.31.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.32.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.32.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.33.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.33.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.34.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.34.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.35.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.35.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.36.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.36.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.37.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.37.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.4.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.4.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.5.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.5.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.6.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.6.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.7.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.7.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.8.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.8.modulation.lin.weight",
+     "model.diffusion_model.single_blocks.9.modulation.lin.bias",
+     "model.diffusion_model.single_blocks.9.modulation.lin.weight",
+     # "model.diffusion_model.time_in.in_layer.bias",
+     # "model.diffusion_model.time_in.in_layer.weight",
+     # "model.diffusion_model.time_in.out_layer.bias",
+     # "model.diffusion_model.time_in.out_layer.weight",
+     # "model.diffusion_model.vector_in.in_layer.bias",
+     # "model.diffusion_model.vector_in.in_layer.weight",
+     # "model.diffusion_model.vector_in.out_layer.bias",
+     # "model.diffusion_model.vector_in.out_layer.weight",
+ }
+
+ def prune_safetensors_file(input_path: str, output_path: str, keys_to_remove: set):
+     """
+     Reads a safetensors file, removes specified layers, and saves to a new file.
+     """
+     tensors_to_keep = {}
+     original_count = 0
+
+     print(f"▶️ Reading from: {input_path}")
+     try:
+         # Open the source file for reading
+         with safe_open(input_path, framework="pt", device="cpu") as f:
+             all_keys = f.keys()
+             original_count = len(all_keys)
+
+             # Iterate through all tensors, keeping only the ones not in the removal list
+             for key in tqdm(all_keys, desc="Filtering Tensors"):
+                 if key not in keys_to_remove:
+                     tensors_to_keep[key] = f.get_tensor(key)
+
+         removed_count = original_count - len(tensors_to_keep)
+
+         print("\n--- Summary ---")
+         print(f"Original tensor count: {original_count}")
+         print(f"Tensors removed: {removed_count}")
+         print(f"New tensor count: {len(tensors_to_keep)}")
+
+         if removed_count == 0 and len(keys_to_remove) > 0:
+             print("\n⚠️ Warning: None of the specified layers to remove were found in the input file.")
+
+         # Save the filtered tensors to the new file
+         print(f"\n▶️ Saving {len(tensors_to_keep)} tensors to: {output_path}")
+         save_file(tensors_to_keep, output_path)
+
+         print("\n✅ Pruning complete!")
+
+     except FileNotFoundError:
+         print(f"\n❌ Error: The file '{input_path}' was not found.")
+     except Exception as e:
+         print(f"\n❌ An error occurred: {e}")
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(
+         description="Removes a predefined list of layers from a .safetensors file and saves it as a new file.",
+         formatter_class=argparse.RawTextHelpFormatter
+     )
+
+     parser.add_argument(
+         "input_file",
+         type=str,
+         help="Path to the source .safetensors file."
+     )
+     parser.add_argument(
+         "output_file",
+         type=str,
+         help="Path to save the new, pruned .safetensors file."
+     )
+
+     args = parser.parse_args()
+
+     prune_safetensors_file(args.input_file, args.output_file, KEYS_TO_REMOVE)
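Finally, a hedged verification sketch for the pruning script; the paths are placeholders, and KEYS_TO_REMOVE is imported from the module above (safe, since its CLI code is guarded by `__main__`):

    # Hypothetical invocation:
    #   python prune_layers.py flux_model.safetensors flux_model_pruned.safetensors
    from safetensors import safe_open
    from prune_layers import KEYS_TO_REMOVE

    # Confirm that none of the pruned keys survive in the output file.
    with safe_open("flux_model_pruned.safetensors", framework="pt", device="cpu") as f:
        leftover = [k for k in f.keys() if k in KEYS_TO_REMOVE]
    print("Pruned keys remaining:", len(leftover))  # expected: 0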