# /// script
# dependencies = [
#   "torch",
#   "numpy",
# ]
# ///
"""Simple utilities for running the models."""

import torch


def to_dtype(dtype_str: str) -> torch.dtype:
    """Convert a dtype name to the corresponding torch dtype.

    Args:
        dtype_str: Name of the dtype. ``"float16"`` and ``"bfloat16"`` are
            recognised explicitly.

    Returns:
        ``torch.float16`` or ``torch.bfloat16`` for the matching names;
        ``torch.float32`` for every other string.
    """
    if dtype_str == "float16":
        return torch.float16
    if dtype_str == "bfloat16":
        return torch.bfloat16
    # Any other name falls back to full precision.
    return torch.float32


def tensor_stats(t: torch.Tensor) -> str:
    """Return a one-line summary of a tensor.

    The summary lists shape, dtype, device, mean and standard deviation.
    The mean and std are formatted with six decimal places.

    Args:
        t: Tensor to summarise.

    Returns:
        A string of the form
        ``"shape=(...), dtype=..., device=..., mean=..., std=..."``.
        The reported dtype is always that of ``t`` itself.
    """
    # torch's mean()/std() reject integer and bool tensors, so compute the
    # statistics on a float32 copy in that case. Floating (and complex)
    # tensors are used as-is so their results are unchanged.
    if t.is_floating_point() or t.is_complex():
        values = t
    else:
        values = t.to(torch.float32)

    mean = values.mean().item()
    std = values.std().item()
    return (
        f"shape={tuple(t.shape)}, "
        f"dtype={t.dtype}, "
        f"device={t.device}, "
        f"mean={mean:.6f}, "
        f"std={std:.6f}"
    )


def set_seed(seed: int) -> None:
    """Seed the torch random number generators for reproducibility.

    Always calls ``torch.manual_seed``. When CUDA is available it also seeds
    the CUDA generators and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Seed value passed to the torch seeding functions.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic kernels and no autotuning.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False