import argparse
import os
import sys

from safetensors import safe_open

# Suffixes of a LoRA down-projection tensor. Its shape is (rank, in_features, ...),
# so the rank is dimension 0. "lora_down" is the Kohya name, "lora_A" the PEFT name.
_DOWN_SUFFIXES = ("lora_down.weight", "lora_A.weight")
# Suffixes of a LoRA up-projection tensor. Its shape is (out_features, rank, ...),
# so the rank is dimension 1.
_UP_SUFFIXES = ("lora_up.weight", "lora_B.weight")


def _dimension_from_metadata(metadata):
    """Return the ``ss_network_dim`` metadata value, or None if absent/unusable.

    Args:
        metadata (dict[str, str] | None): Header metadata from ``safe_open``.

    Returns:
        str | None: The dimension exactly as stored in the metadata, or None.
    """
    if not metadata:
        return None
    # LoRA training scripts like Kohya's SS store the dimension here.
    network_dim = metadata.get("ss_network_dim")
    # Metadata values are strings; a writer that stringified a missing value
    # may have stored "None", which is not a real dimension -> use weights instead.
    if not network_dim or network_dim.strip().lower() == "none":
        return None
    return network_dim


def _dimension_from_weights(f):
    """Infer the LoRA rank from the shape of the first LoRA projection tensor.

    Only tensor headers are read (via ``get_slice``); tensor data is not loaded.

    Args:
        f: An open ``safe_open`` handle.

    Returns:
        int | None: The rank, or None if no recognised LoRA tensor key is found.
    """
    for key in f.keys():
        if key.endswith(_DOWN_SUFFIXES):
            return f.get_slice(key).get_shape()[0]
        if key.endswith(_UP_SUFFIXES):
            return f.get_slice(key).get_shape()[1]
    return None


def get_lora_dimensions_from_directory(directory_path):
    """
    Scan a directory tree for .safetensors files and print each LoRA's network dimension.

    The dimension is taken from the ``ss_network_dim`` metadata entry when present;
    otherwise it is inferred from the shape of a LoRA up/down projection tensor.
    Files that cannot be opened or parsed are reported and skipped.

    Args:
        directory_path (str): The path to the directory to scan (recursively).
    """
    print(f"Scanning for LoRA models in: '{directory_path}'...\n")
    found_models = 0

    for root, dirs, files in os.walk(directory_path):
        # Sorting in place makes os.walk descend into subdirectories in a stable
        # order, matching the sorted order used for files below.
        dirs.sort()
        for filename in sorted(files):
            if not filename.lower().endswith(".safetensors"):
                continue
            file_path = os.path.join(root, filename)
            try:
                # safe_open reads the header lazily instead of loading the whole file.
                with safe_open(file_path, framework="pt", device="cpu") as f:
                    metadata = f.metadata()
                    if not metadata:
                        print(f"- {filename}: No metadata found. Checking weights...")

                    network_dim = _dimension_from_metadata(metadata)
                    if network_dim:
                        print(f"- {filename}: Dimension = {network_dim} (from metadata)")
                        found_models += 1
                        continue

                    dim_from_weights = _dimension_from_weights(f)

                if dim_from_weights is not None:
                    print(f"- {filename}: Dimension = {dim_from_weights} (from weights)")
                    found_models += 1
                else:
                    print(f"- {filename}: (Could not determine dimension from metadata or weights)")
            except Exception as e:  # top-level per-file boundary: report and keep scanning
                print(f"Could not process {filename}. Error: {e}")

    if found_models == 0:
        print("\nNo LoRA models with dimension information were found in the specified directory.")
    else:
        print(f"\nScan complete. Found {found_models} models with dimension info.")


def main():
    """Parse the command line and scan the requested directory."""
    parser = argparse.ArgumentParser(
        description="Get the network dimensions of LoRA models in a directory.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "directory",
        type=str,
        help="The path to the directory containing your LoRA (.safetensors) files.",
    )
    args = parser.parse_args()

    if not os.path.isdir(args.directory):
        # Report on stderr and exit non-zero so shell scripts can detect the failure.
        print(f"Error: The path '{args.directory}' is not a valid directory.", file=sys.stderr)
        sys.exit(1)

    get_lora_dimensions_from_directory(args.directory)


if __name__ == "__main__":
    main()