python-script-dump / listlayers.py
anyMODE's picture
Upload 4 files
54c3b42 verified
raw
history blame
1.42 kB
#!/usr/bin/env python
import argparse
import sys
from typing import Optional

from safetensors import safe_open
def list_safetensor_layers(filepath: str) -> Optional[int]:
    """Print the name and shape of every tensor in a .safetensors file.

    Shapes are obtained through ``get_slice(name).get_shape()`` instead of
    ``get_tensor(name).shape``, so tensor data does not have to be loaded
    just to report its dimensions.

    Args:
        filepath: The path to the .safetensors file.

    Returns:
        The number of tensors listed on success, or ``None`` if the file
        could not be opened or read. Errors are reported on stderr rather
        than raised.
    """
    try:
        print(f"\n📄 Tensors in: {filepath}\n" + "=" * 50)
        with safe_open(filepath, framework="pt", device="cpu") as f:
            names = list(f.keys())
            for name in names:
                # get_slice() gives access to metadata (shape) without
                # materializing the tensor, unlike get_tensor().
                shape = tuple(f.get_slice(name).get_shape())
                print(f"- {name:<50} | Shape: {shape}")
        print("=" * 50 + f"\n✅ Found {len(names)} total tensors.\n")
        return len(names)
    except FileNotFoundError:
        print(f"❌ Error: The file '{filepath}' was not found.", file=sys.stderr)
    except Exception as e:  # top-level CLI boundary: report instead of crashing
        print(f"❌ An error occurred: {e}", file=sys.stderr)
        print("Please ensure the file is a valid .safetensors file.", file=sys.stderr)
    return None
if __name__ == "__main__":
    # Script entry point: accept a single positional path and list the
    # tensors stored in that file.
    cli = argparse.ArgumentParser(
        description="List all layers (tensors) and their shapes in a .safetensors file.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument("filepath", type=str, help="Path to the .safetensors file.")
    list_safetensor_layers(cli.parse_args().filepath)