Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| General utils | |
| Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com) | |
| Please cite our work if the code is helpful to you. | |
| """ | |
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import random | |
| import numpy as np | |
| import torch | |
| import torch.backends.cudnn as cudnn | |
| from datetime import datetime | |
def offset2bincount(offset):
    """Convert cumulative end offsets into per-segment element counts.

    A zero is placed in front of ``offset`` before differencing, so the first
    count equals ``offset[0]`` and every later count is the gap between two
    consecutive offsets. This undoes the running sum in ``bincount2offset``.
    """
    leading_zero = torch.zeros(1, device=offset.device, dtype=torch.long)
    return torch.diff(offset, prepend=leading_zero)
def bincount2offset(bincount):
    """Turn per-segment counts into cumulative end offsets (a running sum along dim 0)."""
    return bincount.cumsum(dim=0)
def offset2batch(offset):
    """Expand cumulative end offsets into a per-element segment index.

    Segment ``i`` contributes as many copies of ``i`` as it has elements,
    as computed by ``offset2bincount``. The result is a long tensor on
    ``offset``'s device.
    """
    counts = offset2bincount(offset)
    segment_ids = torch.arange(
        counts.shape[0], device=offset.device, dtype=torch.long
    )
    return segment_ids.repeat_interleave(counts)
def batch2offset(batch):
    """Collapse a per-element segment index into cumulative end offsets.

    Counts occurrences of each non-negative id with ``torch.bincount``, so ids
    that never occur still get a zero-length slot. It then takes the running
    sum and casts the result to long.
    """
    counts = torch.bincount(batch)
    return counts.cumsum(dim=0).long()
def get_random_seed():
    """Build a fresh, non-reproducible integer seed.

    Adds together three sources:
    - the current process id;
    - the wall-clock seconds and microseconds, as the digits of ``%S%f``;
    - 16 bits read from ``os.urandom``.
    """
    pid_part = os.getpid()
    clock_part = int(datetime.now().strftime("%S%f"))
    entropy_part = int.from_bytes(os.urandom(2), "big")
    return pid_part + clock_part + entropy_part
def set_seed(seed=None):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: Integer seed to apply everywhere. When ``None``, a fresh seed is
            drawn from ``get_random_seed()``. NumPy's legacy ``np.random.seed``
            only accepts values in ``[0, 2**32 - 1]`` and raises ``ValueError``
            otherwise.

    Returns:
        The seed that was actually applied. Returning it lets callers log a
        randomly chosen seed and reproduce the run later.
    """
    if seed is None:
        seed = get_random_seed()
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Turn off cuDNN autotuning and request deterministic kernels, so repeated
    # runs do not pick different convolution algorithms.
    cudnn.benchmark = False
    cudnn.deterministic = True
    # NOTE: PYTHONHASHSEED is read at interpreter start-up. Setting it here
    # does not change hash randomization in this process; it only reaches
    # child processes that inherit this environment.
    os.environ["PYTHONHASHSEED"] = str(seed)
    return seed