|
|
|
|
|
|
|
|
import argparse |
|
|
from safetensors import safe_open |
|
|
from safetensors.torch import save_file |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
# Names of every tensor under the ``distilled_guidance_layer.`` prefix that the
# script extracts (29 in total):
#   * ``in_proj`` / ``out_proj``                      -> bias + weight each
#   * ``layers.0`` .. ``layers.4``, ``in_layer`` and
#     ``out_layer`` of each                           -> bias + weight each
#   * ``norms.0`` .. ``norms.4``                      -> scale
KEYS_TO_KEEP = (
    {
        f"distilled_guidance_layer.{proj}.{param}"
        for proj in ("in_proj", "out_proj")
        for param in ("bias", "weight")
    }
    | {
        f"distilled_guidance_layer.layers.{idx}.{sub}.{param}"
        for idx in range(5)
        for sub in ("in_layer", "out_layer")
        for param in ("bias", "weight")
    }
    | {f"distilled_guidance_layer.norms.{idx}.scale" for idx in range(5)}
)
|
|
|
|
|
def extract_safetensors_layers(input_path: str, output_path: str, keys_to_keep: set) -> None:
    """Copy a chosen subset of tensors from one safetensors file into another.

    Args:
        input_path: Path of the source ``.safetensors`` file.
        output_path: Path the extracted tensors are written to.
        keys_to_keep: Tensor names to copy. Names that are absent from the
            source file are reported with a warning and skipped.

    Returns:
        None. Progress, a summary, and any error are printed to stdout; errors
        are reported rather than raised, and no output file is written when
        none of the requested tensors exist in the source.
    """
    tensors_to_save = {}

    print(f"▶️ Reading from: {input_path}")
    try:
        with safe_open(input_path, framework="pt", device="cpu") as f:
            # safetensors signals an unknown tensor name with its own
            # SafetensorError rather than KeyError, so membership is checked
            # against the file's key list instead of relying on `except KeyError`.
            available = set(f.keys())
            # Sorted so progress and warnings appear in a stable order.
            for key in tqdm(sorted(keys_to_keep), desc="Extracting Tensors"):
                if key not in available:
                    print(f"\n⚠️ Warning: Key not found in source file and will be skipped: {key}")
                    continue
                tensors_to_save[key] = f.get_tensor(key)
    except FileNotFoundError:
        print(f"\n❌ Error: The file '{input_path}' was not found.")
        return
    except Exception as e:
        # Top-level boundary for the CLI: report and stop instead of a traceback.
        print(f"\n❌ An error occurred: {e}")
        return

    print("\n--- Summary ---")
    print(f"Layers specified to keep: {len(keys_to_keep)}")
    print(f"Layers found and extracted: {len(tensors_to_save)}")

    if not tensors_to_save:
        print("\n❌ Error: No matching layers were found. The output file will not be created.")
        return

    print(f"\n▶️ Saving {len(tensors_to_save)} tensors to: {output_path}")
    # Handled separately from the read phase so that a failure while writing
    # (e.g. a missing output directory) is not reported as a missing input file.
    try:
        save_file(tensors_to_save, output_path)
    except Exception as e:
        print(f"\n❌ An error occurred: {e}")
        return

    print("\n✅ Extraction complete!")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: takes two positional paths (source, destination)
    # and extracts the tensors named in KEYS_TO_KEEP from the former to the latter.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Extracts a predefined list of layers from a .safetensors file and saves them to a new file.",
    )

    # Both arguments are plain string positionals; register them in order.
    positional_specs = (
        ("input_file", "Path to the source .safetensors file."),
        ("output_file", "Path to save the new, extracted .safetensors file."),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, type=str, help=arg_help)

    options = cli.parse_args()

    extract_safetensors_layers(options.input_file, options.output_file, KEYS_TO_KEEP)
|
|
|