#!/usr/bin/env python
"""Copy a predefined set of tensors from one .safetensors file into another."""
import argparse
import sys

from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm

# Names of the tensors to copy into the output file.
# A set keeps the collection duplicate-free and makes membership checks cheap.
KEYS_TO_KEEP = {
    "distilled_guidance_layer.in_proj.bias",
    "distilled_guidance_layer.in_proj.weight",
    "distilled_guidance_layer.layers.0.in_layer.bias",
    "distilled_guidance_layer.layers.0.in_layer.weight",
    "distilled_guidance_layer.layers.0.out_layer.bias",
    "distilled_guidance_layer.layers.0.out_layer.weight",
    "distilled_guidance_layer.layers.1.in_layer.bias",
    "distilled_guidance_layer.layers.1.in_layer.weight",
    "distilled_guidance_layer.layers.1.out_layer.bias",
    "distilled_guidance_layer.layers.1.out_layer.weight",
    "distilled_guidance_layer.layers.2.in_layer.bias",
    "distilled_guidance_layer.layers.2.in_layer.weight",
    "distilled_guidance_layer.layers.2.out_layer.bias",
    "distilled_guidance_layer.layers.2.out_layer.weight",
    "distilled_guidance_layer.layers.3.in_layer.bias",
    "distilled_guidance_layer.layers.3.in_layer.weight",
    "distilled_guidance_layer.layers.3.out_layer.bias",
    "distilled_guidance_layer.layers.3.out_layer.weight",
    "distilled_guidance_layer.layers.4.in_layer.bias",
    "distilled_guidance_layer.layers.4.in_layer.weight",
    "distilled_guidance_layer.layers.4.out_layer.bias",
    "distilled_guidance_layer.layers.4.out_layer.weight",
    "distilled_guidance_layer.norms.0.scale",
    "distilled_guidance_layer.norms.1.scale",
    "distilled_guidance_layer.norms.2.scale",
    "distilled_guidance_layer.norms.3.scale",
    "distilled_guidance_layer.norms.4.scale",
    "distilled_guidance_layer.out_proj.bias",
    "distilled_guidance_layer.out_proj.weight",
}


def extract_safetensors_layers(input_path: str, output_path: str, keys_to_keep: set) -> bool:
    """Copy the tensors named in *keys_to_keep* from one safetensors file to another.

    Keys that are absent from the source file are reported and skipped. If none
    of the requested keys are present, no output file is written. Progress and
    error messages are printed to stdout; exceptions are reported there rather
    than propagated.

    Args:
        input_path: Path of the .safetensors file to read.
        output_path: Path of the .safetensors file to write.
        keys_to_keep: Names of the tensors to extract.

    Returns:
        True if the output file was written, False otherwise (missing input
        file, no matching keys, or any other error).
    """
    tensors_to_save = {}
    print(f"▶️ Reading from: {input_path}")

    try:
        with safe_open(input_path, framework="pt", device="cpu") as f:
            # Decide up front which keys exist instead of relying on the
            # exception type raised by get_tensor() for a missing name; a
            # single absent key must not abort the whole extraction.
            available_keys = set(f.keys())

            # Sorted so that progress and warnings appear in a stable order.
            for key in tqdm(sorted(keys_to_keep), desc="Extracting Tensors"):
                if key not in available_keys:
                    print(f"\n⚠️ Warning: Key not found in source file and will be skipped: {key}")
                    continue
                tensors_to_save[key] = f.get_tensor(key)

        print("\n--- Summary ---")
        print(f"Layers specified to keep: {len(keys_to_keep)}")
        print(f"Layers found and extracted: {len(tensors_to_save)}")

        if not tensors_to_save:
            print("\n❌ Error: No matching layers were found. The output file will not be created.")
            return False

        print(f"\n▶️ Saving {len(tensors_to_save)} tensors to: {output_path}")
        save_file(tensors_to_save, output_path)
        print("\n✅ Extraction complete!")
        return True

    except FileNotFoundError:
        print(f"\n❌ Error: The file '{input_path}' was not found.")
    except Exception as e:
        # Top-level boundary of the CLI: report the problem instead of a traceback.
        print(f"\n❌ An error occurred: {e}")
    return False


def main() -> int:
    """Parse command-line arguments, run the extraction, and return an exit code."""
    parser = argparse.ArgumentParser(
        description="Extracts a predefined list of layers from a .safetensors file and saves them to a new file.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "input_file",
        type=str,
        help="Path to the source .safetensors file.",
    )
    parser.add_argument(
        "output_file",
        type=str,
        help="Path to save the new, extracted .safetensors file.",
    )
    args = parser.parse_args()

    succeeded = extract_safetensors_layers(args.input_file, args.output_file, KEYS_TO_KEEP)
    # Non-zero exit status lets shell scripts and pipelines detect a failed run.
    return 0 if succeeded else 1


if __name__ == "__main__":
    sys.exit(main())