# python-script-dump / lora_dims.py
# Uploaded by anyMODE ("Upload lora_dims.py", commit 042c5bc, verified).
import os
import argparse
from safetensors import safe_open
def _rank_from_weights(handle):
    """Return the LoRA rank inferred from tensor shapes, or None if not found.

    Only tensor shapes are inspected (via ``get_slice(...).get_shape()``), so
    no tensor data is materialized just to learn its dimensions.

    Args:
        handle: An open safetensors file handle exposing ``keys()`` and
            ``get_slice(key)``.

    Returns:
        The rank taken from the first usable LoRA weight key, or None when no
        such key exists.
    """
    for key in handle.keys():
        if key.endswith("lora_down.weight"):
            # The shape of lora_down.weight is (rank, in_features)
            axis = 0
        elif key.endswith(("lora_B.weight", "lora_up.weight")):
            # The shape of lora_up/lora_B is (out_features, rank)
            axis = 1
        else:
            continue
        shape = handle.get_slice(key).get_shape()
        # Skip tensors whose rank is too small instead of failing the file.
        if len(shape) > axis:
            return shape[axis]
    return None


def get_lora_dimensions_from_directory(directory_path):
    """Print the LoRA network dimension of every .safetensors file in a tree.

    The directory is scanned recursively in sorted order. For each file the
    ``ss_network_dim`` metadata entry is preferred; when it is missing or
    empty, the dimension is inferred from the shape of a LoRA weight tensor.
    Files that cannot be opened or parsed are reported and skipped.

    Args:
        directory_path (str): The path to the directory to scan.
    """
    print(f"Scanning for LoRA models in: '{directory_path}'...\n")
    found_models = 0

    for root, dirs, files in os.walk(directory_path):
        # Sorting in place makes os.walk descend into subdirectories in a
        # deterministic order, matching the sorted per-directory file order.
        dirs.sort()
        for filename in sorted(files):
            if not filename.lower().endswith(".safetensors"):
                continue
            file_path = os.path.join(root, filename)

            try:
                # safe_open reads the header without loading the whole file
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    metadata = f.metadata()
                    if not metadata:
                        print(f"- {filename}: No metadata found. Checking weights...")
                    # LoRA training scripts like Kohya's SS store the dimension here
                    network_dim = metadata.get("ss_network_dim") if metadata else None
                    # Only consult the weights when metadata gave no answer.
                    dim_from_weights = None if network_dim else _rank_from_weights(f)
            except Exception as e:
                # Best-effort scan: one unreadable file must not stop the run.
                print(f"Could not process {filename}. Error: {e}")
                continue

            if network_dim:
                print(f"- {filename}: Dimension = {network_dim} (from metadata)")
                found_models += 1
            elif dim_from_weights is not None:
                print(f"- {filename}: Dimension = {dim_from_weights} (from weights)")
                found_models += 1
            else:
                print(f"- {filename}: (Could not determine dimension from metadata or weights)")

    if found_models == 0:
        print("\nNo LoRA models with dimension information were found in the specified directory.")
    else:
        print(f"\nScan complete. Found {found_models} models with dimension info.")
if __name__ == "__main__":
    # Command-line entry point: takes one positional argument, the directory
    # to scan for LoRA (.safetensors) files.
    parser = argparse.ArgumentParser(
        description="Get the network dimensions of LoRA models in a directory.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "directory",
        type=str,
        help="The path to the directory containing your LoRA (.safetensors) files.",
    )
    args = parser.parse_args()

    # Reject a bad path through argparse so the message goes to stderr and the
    # process exits with a non-zero status instead of reporting success.
    if not os.path.isdir(args.directory):
        parser.error(f"The path '{args.directory}' is not a valid directory.")

    get_lora_dimensions_from_directory(args.directory)