File size: 4,236 Bytes
042c5bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import argparse
from safetensors import safe_open

def get_lora_dimensions_from_directory(directory_path):
    """
    Scan a directory tree for .safetensors files and report each LoRA model's
    network dimension (rank).

    The dimension is read from the "ss_network_dim" metadata key (written by
    Kohya-style training scripts); when metadata is absent, it is inferred
    from the shapes of the LoRA weight tensors instead.

    Args:
        directory_path (str): The path to the directory to scan.
    """
    print(f"Scanning for LoRA models in: '{directory_path}'...\n")

    found_models = 0

    # Walk through the directory and its subdirectories
    for root, _, files in os.walk(directory_path):
        for filename in sorted(files):
            if not filename.lower().endswith(".safetensors"):
                continue
            file_path = os.path.join(root, filename)
            try:
                # safe_open reads the header/metadata lazily, so we never
                # load the whole checkpoint into memory.
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    metadata = f.metadata()

                    if not metadata:
                        # BUGFIX: messages previously printed the literal
                        # placeholder "(unknown)" instead of the file name,
                        # making the report useless for multiple files.
                        print(f"- {filename}: No metadata found. Checking weights...")

                    # LoRA training scripts like Kohya's SS store the dimension here
                    network_dim = metadata.get("ss_network_dim") if metadata else None

                    if network_dim:
                        print(f"- {filename}: Dimension = {network_dim} (from metadata)")
                        found_models += 1
                        continue

                    # Fallback: try to determine dimension from tensor shapes
                    dim_from_weights = None
                    for key in f.keys():
                        if key.endswith("lora_down.weight"):
                            # lora_down.weight has shape (rank, in_features),
                            # so the rank is the first axis.
                            dim_from_weights = f.get_tensor(key).shape[0]
                            break  # Found it, no need to check other keys
                        if key.endswith(("lora_B.weight", "lora_up.weight")):
                            # lora_up/lora_B weights have shape
                            # (out_features, rank): the rank is the second axis.
                            dim_from_weights = f.get_tensor(key).shape[1]
                            break  # Found it, no need to check other keys

                    if dim_from_weights is not None:
                        print(f"- {filename}: Dimension = {dim_from_weights} (from weights)")
                        found_models += 1
                    else:
                        print(f"- {filename}: (Could not determine dimension from metadata or weights)")

            except Exception as e:
                # Best-effort scan: report the failing file and keep going.
                print(f"Could not process {file_path}. Error: {e}")

    if found_models == 0:
        print("\nNo LoRA models with dimension information were found in the specified directory.")
    else:
        print(f"\nScan complete. Found {found_models} models with dimension info.")

if __name__ == "__main__":
    # Command-line interface: one positional argument naming the directory
    # of LoRA checkpoints to scan.
    arg_parser = argparse.ArgumentParser(
        description="Get the network dimensions of LoRA models in a directory.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    arg_parser.add_argument(
        "directory",
        type=str,
        help="The path to the directory containing your LoRA (.safetensors) files.",
    )
    cli_args = arg_parser.parse_args()

    # Validate the path before doing any work.
    if os.path.isdir(cli_args.directory):
        get_lora_dimensions_from_directory(cli_args.directory)
    else:
        print(f"Error: The path '{cli_args.directory}' is not a valid directory.")