|
|
import argparse
import os
import sys

from safetensors import safe_open
|
|
|
|
|
def _dim_from_metadata(metadata):
    """Return the ``ss_network_dim`` value from *metadata*, or ``None`` if unusable.

    safetensors metadata is a string-to-string map, so an unset dimension may be
    serialized as the literal text "None". That, an empty string, and a missing
    key all count as "no dimension recorded", so the caller falls back to the
    weights instead of reporting "Dimension = None".
    """
    if not metadata:
        return None
    value = metadata.get("ss_network_dim")
    if value is None or str(value).strip().lower() in ("", "none"):
        return None
    return value


def _dim_from_weights(handle):
    """Infer the LoRA rank from the first recognised LoRA weight in *handle*.

    *handle* is an open ``safe_open`` file. Only each tensor's shape is read
    (``get_slice(key).get_shape()``), so the tensor itself is never materialised.
    Returns ``None`` when no key with a usable shape is found.
    """
    # (key suffix, axis holding the rank): down / A projections keep the rank on
    # axis 0, up / B projections on axis 1.
    rank_axis_by_suffix = (
        ("lora_down.weight", 0),
        ("lora_A.weight", 0),
        ("lora_up.weight", 1),
        ("lora_B.weight", 1),
    )
    for key in handle.keys():
        for suffix, axis in rank_axis_by_suffix:
            if key.endswith(suffix):
                shape = handle.get_slice(key).get_shape()
                # Skip degenerate shapes rather than raising IndexError.
                if len(shape) > axis:
                    return shape[axis]
    return None


def get_lora_dimensions_from_directory(directory_path):
    """
    Scan a directory tree for .safetensors files and report each LoRA network dimension.

    The dimension is read from the ``ss_network_dim`` metadata entry when present.
    Otherwise it is inferred from the shape of the first recognised LoRA weight
    (``lora_down``/``lora_A``/``lora_up``/``lora_B``). Results are printed to stdout.
    A file that cannot be opened or parsed is reported and skipped; it does not
    abort the scan.

    Args:
        directory_path (str): The path to the directory to scan (walked recursively).

    Returns:
        int: The number of files for which a dimension was determined.
    """
    print(f"Scanning for LoRA models in: '{directory_path}'...\n")

    found_models = 0

    for root, dirs, files in os.walk(directory_path):
        # os.walk visits subdirectories in arbitrary order; sorting in place
        # makes the traversal, and therefore the report, deterministic.
        dirs.sort()
        for filename in sorted(files):
            if not filename.lower().endswith(".safetensors"):
                continue

            file_path = os.path.join(root, filename)
            try:
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    metadata = f.metadata()
                    if not metadata:
                        print(f"- {filename}: No metadata found. Checking weights...")

                    network_dim = _dim_from_metadata(metadata)
                    if network_dim is not None:
                        print(f"- {filename}: Dimension = {network_dim} (from metadata)")
                        found_models += 1
                        continue

                    dim_from_weights = _dim_from_weights(f)

                if dim_from_weights is not None:
                    print(f"- {filename}: Dimension = {dim_from_weights} (from weights)")
                    found_models += 1
                else:
                    print(f"- {filename}: (Could not determine dimension from metadata or weights)")
            except Exception as e:
                # Per-file boundary: one unreadable file must not stop the whole scan.
                print(f"Could not process {filename}. Error: {e}")

    if found_models == 0:
        print("\nNo LoRA models with dimension information were found in the specified directory.")
    else:
        print(f"\nScan complete. Found {found_models} models with dimension info.")

    return found_models
|
|
|
|
|
def main(argv=None):
    """Command-line entry point: scan the given directory for LoRA dimensions.

    Args:
        argv (list[str] | None): Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        int: Process exit status. 0 on success, 1 if the path is not a directory.
    """
    parser = argparse.ArgumentParser(
        description="Get the network dimensions of LoRA models in a directory.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "directory",
        type=str,
        help="The path to the directory containing your LoRA (.safetensors) files."
    )
    args = parser.parse_args(argv)

    if not os.path.isdir(args.directory):
        # Report on stderr and signal failure so shell callers can detect it.
        print(f"Error: The path '{args.directory}' is not a valid directory.", file=sys.stderr)
        return 1

    get_lora_dimensions_from_directory(args.directory)
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
|
|
|
|
|
|
|
|
|
|
|