Spaces:
Sleeping
Sleeping
File size: 1,169 Bytes
5f58699 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
"""Random seed utilities."""
from __future__ import annotations
import os
import random
from dataclasses import dataclass
import numpy as np
try: # pragma: no cover - optional dependency
import torch
except ImportError: # pragma: no cover - optional dependency
torch = None # type: ignore
@dataclass(slots=True)
class SeedState:
"""Record of RNG seeds applied across libraries."""
python: int
numpy: int
torch: int | None = None
def set_global_seeds(seed: int) -> SeedState:
    """Seed ``random``, ``numpy`` and ``torch`` (if available).

    The seed is validated before anything is modified, so a rejected value
    leaves the environment and every generator untouched.

    Args:
        seed: Integer in ``[0, 2**32 - 1]``, the range accepted by
            ``numpy.random.seed``.

    Returns:
        A ``SeedState`` recording the seed applied to each library. Its
        ``torch`` field is ``None`` when torch is not importable.

    Raises:
        ValueError: If ``seed`` is outside ``[0, 2**32 - 1]``.
    """
    # Fail fast: numpy would reject an out-of-range seed anyway, but only
    # after PYTHONHASHSEED and ``random`` had already been changed.
    if not 0 <= seed <= 0xFFFFFFFF:
        raise ValueError(f"seed must be between 0 and 2**32 - 1, got {seed!r}")

    # NOTE: hash randomization is fixed at interpreter start-up, so this only
    # influences Python processes launched after this point (e.g. workers).
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)

    torch_seed: int | None = None
    # One lookup covers both "name missing" and "import failed -> None".
    torch_module = globals().get("torch")
    if torch_module is not None:  # pragma: no branch
        torch_module.manual_seed(seed)
        if torch_module.cuda.is_available():  # pragma: no cover - GPU specific
            torch_module.cuda.manual_seed_all(seed)
        if hasattr(torch_module.backends, "cudnn"):
            # Trade cuDNN autotuning for repeatable kernel selection.
            torch_module.backends.cudnn.deterministic = True
            torch_module.backends.cudnn.benchmark = False
        torch_seed = seed

    return SeedState(python=seed, numpy=seed, torch=torch_seed)
|