File size: 4,020 Bytes
54c3b42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#!/usr/bin/env python

import argparse
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm

# Names of the tensors that get copied into the output file. Every name sits
# under the "distilled_guidance_layer." prefix: the in/out projections, five
# numbered layers (in_layer/out_layer, each with bias and weight), and five
# numbered norm scales. Kept as a set for cheap membership checks.
KEYS_TO_KEEP = {
    f"distilled_guidance_layer.{suffix}"
    for suffix in (
        "in_proj.bias",
        "in_proj.weight",
        "out_proj.bias",
        "out_proj.weight",
        # layers.0 .. layers.4, each with in_layer/out_layer bias + weight
        *(
            f"layers.{index}.{sublayer}.{param}"
            for index in range(5)
            for sublayer in ("in_layer", "out_layer")
            for param in ("bias", "weight")
        ),
        # norms.0 .. norms.4
        *(f"norms.{index}.scale" for index in range(5)),
    )
}

def extract_safetensors_layers(input_path: str, output_path: str, keys_to_keep: set) -> None:
    """
    Copy a chosen subset of tensors from one safetensors file into a new one.

    Args:
        input_path: Path of the safetensors file to read.
        output_path: Path the extracted tensors are saved to. Nothing is
            saved when none of the requested keys exist in the source.
        keys_to_keep: Names of the tensors to extract. Names missing from
            the source are reported with a warning and skipped.

    Returns:
        None. Progress, warnings and errors are printed to stdout; exceptions
        raised while reading or saving are caught and printed rather than
        propagated.
    """
    tensors_to_save = {}

    print(f"▶️  Reading from: {input_path}")
    try:
        # Open the source file for reading
        with safe_open(input_path, framework="pt", device="cpu") as f:
            # Decide presence from the file's own key listing instead of from
            # the exception get_tensor() raises for an unknown name.
            # NOTE(review): safetensors appears to signal a missing tensor
            # with its own error type rather than KeyError, in which case an
            # `except KeyError` would never fire and one missing key would
            # abort the whole run; confirm against the installed version.
            available_keys = set(f.keys())

            # Sorted so progress output and warnings appear in a stable order.
            for key in tqdm(sorted(keys_to_keep), desc="Extracting Tensors"):
                if key not in available_keys:
                    print(f"\n⚠️  Warning: Key not found in source file and will be skipped: {key}")
                    continue
                tensors_to_save[key] = f.get_tensor(key)

        print("\n--- Summary ---")
        print(f"Layers specified to keep: {len(keys_to_keep)}")
        print(f"Layers found and extracted: {len(tensors_to_save)}")

        if not tensors_to_save:
            print("\n❌ Error: No matching layers were found. The output file will not be created.")
            return

        # Save the extracted tensors to the new file
        print(f"\n▶️  Saving {len(tensors_to_save)} tensors to: {output_path}")
        save_file(tensors_to_save, output_path)

        print("\n✅ Extraction complete!")

    except FileNotFoundError:
        print(f"\n❌ Error: The file '{input_path}' was not found.")
    except Exception as e:
        # Top-level boundary for the CLI: report the failure instead of
        # letting a traceback escape.
        print(f"\n❌ An error occurred: {e}")

if __name__ == "__main__":
    # Script entry point: two positional paths (source, then destination).
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="Extracts a predefined list of layers from a .safetensors file and saves them to a new file.",
    )

    # Positional arguments, registered in the order they must appear.
    positional_arguments = (
        ("input_file", "Path to the source .safetensors file."),
        ("output_file", "Path to save the new, extracted .safetensors file."),
    )
    for argument_name, help_text in positional_arguments:
        cli.add_argument(argument_name, type=str, help=help_text)

    options = cli.parse_args()

    extract_safetensors_layers(options.input_file, options.output_file, KEYS_TO_KEEP)