File size: 4,020 Bytes
54c3b42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
#!/usr/bin/env python
import argparse
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
# Names of the tensors to extract from the source checkpoint: the
# distilled-guidance projection layers, the five in/out MLP layers, and
# their norms. A set gives O(1) membership tests when filtering tensors.
KEYS_TO_KEEP = (
    {
        f"distilled_guidance_layer.{proj}_proj.{param}"
        for proj in ("in", "out")
        for param in ("bias", "weight")
    }
    | {
        f"distilled_guidance_layer.layers.{idx}.{direction}_layer.{param}"
        for idx in range(5)
        for direction in ("in", "out")
        for param in ("bias", "weight")
    }
    | {f"distilled_guidance_layer.norms.{idx}.scale" for idx in range(5)}
)
def extract_safetensors_layers(input_path: str, output_path: str, keys_to_keep: set):
    """
    Extract the tensors named in *keys_to_keep* from a .safetensors file
    and save them to a new .safetensors file.

    Args:
        input_path: Path to the source .safetensors file.
        output_path: Path for the new file containing only the kept tensors.
        keys_to_keep: Tensor names to extract; names missing from the source
            file are skipped with a warning.

    No output file is written when none of the requested tensors are found.
    Errors are reported to stdout rather than raised (CLI-style behavior).
    """
    tensors_to_save = {}
    print(f"▶️ Reading from: {input_path}")
    try:
        # Open the source file for reading
        with safe_open(input_path, framework="pt", device="cpu") as f:
            # BUG FIX: get_tensor() raises safetensors.SafetensorError (not
            # KeyError) for a missing name, so the old `except KeyError` skip
            # path never fired and one missing key aborted the whole run.
            # Check membership against the file's actual key set instead.
            available_keys = set(f.keys())
            # Sort for a deterministic extraction order (sets are unordered).
            for key in tqdm(sorted(keys_to_keep), desc="Extracting Tensors"):
                if key in available_keys:
                    tensors_to_save[key] = f.get_tensor(key)
                else:
                    print(f"\n⚠️ Warning: Key not found in source file and will be skipped: {key}")
        print("\n--- Summary ---")
        print(f"Layers specified to keep: {len(keys_to_keep)}")
        print(f"Layers found and extracted: {len(tensors_to_save)}")
        if not tensors_to_save:
            print("\n❌ Error: No matching layers were found. The output file will not be created.")
            return
        # Save the extracted tensors to the new file
        print(f"\n▶️ Saving {len(tensors_to_save)} tensors to: {output_path}")
        save_file(tensors_to_save, output_path)
        print("\n✅ Extraction complete!")
    except FileNotFoundError:
        print(f"\n❌ Error: The file '{input_path}' was not found.")
    except Exception as e:
        # Top-level CLI boundary: report and return rather than traceback.
        print(f"\n❌ An error occurred: {e}")
if __name__ == "__main__":
    # CLI entry point: source path and destination path, then run extraction.
    arg_parser = argparse.ArgumentParser(
        description="Extracts a predefined list of layers from a .safetensors file and saves them to a new file.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    arg_parser.add_argument(
        "input_file", type=str, help="Path to the source .safetensors file."
    )
    arg_parser.add_argument(
        "output_file", type=str, help="Path to save the new, extracted .safetensors file."
    )
    cli_args = arg_parser.parse_args()
    extract_safetensors_layers(cli_args.input_file, cli_args.output_file, KEYS_TO_KEEP)
|