| # /// script | |
| # dependencies = [ | |
| # "torch", | |
| # "numpy", | |
| # ] | |
| # /// | |
| """Simple utilities for running the models.""" | |
| import torch | |
def to_dtype(dtype_str: str):
    """Map a dtype name to the matching torch dtype.

    ``"float16"`` and ``"bfloat16"`` select the corresponding reduced-precision
    dtypes; every other value falls back to ``torch.float32``.
    """
    reduced_precision = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
    )
    for name, dtype in reduced_precision:
        if dtype_str == name:
            return dtype
    # Anything unrecognized uses the full-precision default.
    return torch.float32
def tensor_stats(t: torch.Tensor) -> str:
    """Build a one-line summary of a tensor.

    The summary lists the tensor's shape, dtype and device followed by the
    mean and standard deviation of its elements, both formatted with six
    decimal places. The reported dtype is always the dtype of ``t`` itself.

    Args:
        t: Tensor to summarize.

    Returns:
        A string of the form
        ``"shape=(...), dtype=..., device=..., mean=..., std=..."``.
    """
    # torch's mean()/std() raise on integer and bool tensors, so those are
    # reduced on a float32 view; floating-point and complex tensors are
    # reduced as-is.
    if t.is_floating_point() or t.is_complex():
        values = t
    else:
        values = t.float()
    return (
        f"shape={tuple(t.shape)}, "
        f"dtype={t.dtype}, "
        f"device={t.device}, "
        f"mean={values.mean().item():.6f}, "
        f"std={values.std().item():.6f}"
    )
def set_seed(seed: int):
    """Seed torch's random number generators for reproducible runs.

    Always seeds the default torch generator. When CUDA is available, the
    CUDA generators are seeded as well and cuDNN is switched to
    deterministic mode with benchmarking turned off.
    """
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    # Seed the current CUDA device first, then every visible device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Disable benchmarking and request deterministic cuDNN behaviour.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False