File size: 937 Bytes
93c5002 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
# /// script
# dependencies = [
# "torch",
# "numpy",
# ]
# ///
"""Simple utilities for running the models."""
import torch
def to_dtype(dtype_str: str):
    """Map a precision name onto the corresponding torch dtype.

    ``"float16"`` and ``"bfloat16"`` select the matching reduced-precision
    dtypes; every other value falls back to ``torch.float32``.
    """
    reduced_precision = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
    )
    for name, dtype in reduced_precision:
        if dtype_str == name:
            return dtype
    # Anything not listed above is treated as full precision.
    return torch.float32
def tensor_stats(t: torch.Tensor) -> str:
    """Generate stats string for a tensor.

    Args:
        t: Tensor to summarise. Floating-point and complex tensors are
            reduced as-is; integer and bool tensors are promoted to
            float32 for the statistics only.

    Returns:
        A single-line string of the form
        ``shape=..., dtype=..., device=..., mean=..., std=...`` where
        ``dtype`` is the dtype of ``t`` itself and ``mean``/``std`` are
        printed with six decimal places.
    """
    # mean()/std() raise RuntimeError on integer and bool tensors, so
    # promote those before reducing; the reported dtype stays t.dtype.
    if t.is_floating_point() or t.is_complex():
        values = t
    else:
        values = t.float()
    mean = values.mean().item()
    std = values.std().item()
    return (f"shape={tuple(t.shape)}, "
            f"dtype={t.dtype}, "
            f"device={t.device}, "
            f"mean={mean:.6f}, "
            f"std={std:.6f}")
def set_seed(seed: int):
    """Seed torch's random number generators so runs are repeatable.

    The default generator is always seeded. When CUDA is available the
    CUDA generators are seeded as well and cuDNN is switched to
    deterministic mode with autotuning disabled.
    """
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic kernels, no benchmark-driven algorithm selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False