Spaces:
Sleeping
Sleeping
| """Random seed utilities.""" | |
| from __future__ import annotations | |
| import os | |
| import random | |
| from dataclasses import dataclass | |
| import numpy as np | |
| try: # pragma: no cover - optional dependency | |
| import torch | |
| except ImportError: # pragma: no cover - optional dependency | |
| torch = None # type: ignore | |
@dataclass
class SeedState:
    """Record of RNG seeds applied across libraries.

    Instances are returned by ``set_global_seeds`` so callers can log or
    persist exactly which seeds were applied.
    """

    # Seed passed to ``random.seed``.
    python: int
    # Seed passed to ``np.random.seed``.
    numpy: int
    # Seed passed to ``torch.manual_seed``; ``None`` when torch is unavailable.
    torch: int | None = None
def set_global_seeds(seed: int) -> SeedState:
    """Seed ``random``, ``numpy`` and ``torch`` (if available).

    Args:
        seed: Seed applied to every RNG. ``np.random.seed`` only accepts
            values in ``[0, 2**32)`` and raises ``ValueError`` otherwise.

    Returns:
        A ``SeedState`` recording the seed applied to each library; its
        ``torch`` field is ``None`` when torch could not be imported.
    """
    # NOTE: CPython fixes hash randomisation at interpreter start-up, so this
    # only influences child processes started after this call, not the
    # current interpreter's ``hash()`` of str/bytes.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch_seed: int | None = None
    # ``torch`` is always bound at module level (to the module or to None by
    # the optional import), so a None check is sufficient.
    if torch is not None:  # pragma: no branch
        torch.manual_seed(seed)
        if torch.cuda.is_available():  # pragma: no cover - GPU specific
            torch.cuda.manual_seed_all(seed)
        if hasattr(torch.backends, "cudnn"):
            # Trade cuDNN autotuning speed for run-to-run reproducibility.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        torch_seed = seed
    return SeedState(python=seed, numpy=seed, torch=torch_seed)