|
|
|
|
|
|
|
|
import argparse
import sys

from safetensors import safe_open
|
|
|
|
|
def list_safetensor_layers(filepath: str) -> None:
    """
    Open a .safetensors file and print the name and shape of each tensor.

    Shapes are read through ``get_slice(...).get_shape()``, which consults
    the file header only, so tensor data is never materialized. This keeps
    listing a large checkpoint cheap.

    Failures are reported on stderr rather than raised, because this
    function is the body of a command-line tool.

    Args:
        filepath (str): The path to the .safetensors file.
    """
    try:
        print(f"\n🔍 Tensors in: {filepath}\n" + "=" * 50)

        total_tensors = 0
        with safe_open(filepath, framework="pt", device="cpu") as f:
            for key in f.keys():
                # get_slice() returns a lazy handle; get_shape() returns the
                # shape recorded in the header without loading the tensor.
                shape = f.get_slice(key).get_shape()
                print(f"- {key:<50} | Shape: {tuple(shape)}")
                total_tensors += 1

        print("=" * 50 + f"\n✅ Found {total_tensors} total tensors.\n")

    except FileNotFoundError:
        print(f"❌ Error: The file '{filepath}' was not found.", file=sys.stderr)
    except Exception as e:  # CLI boundary: report any open/parse failure.
        print(f"❌ An error occurred: {e}", file=sys.stderr)
        print("Please ensure the file is a valid .safetensors file.", file=sys.stderr)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: take one positional path and list its tensors.
    cli = argparse.ArgumentParser(
        description="List all layers (tensors) and their shapes in a .safetensors file.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument("filepath", type=str, help="Path to the .safetensors file.")
    list_safetensor_layers(cli.parse_args().filepath)
|
|
|