Spaces: Running on Zero
```python
import os

import torch
import torch.distributed as dist

# Process group used for sequence parallelism; set by initialize_sequence_parallelism().
_SEQUENCE_PARALLEL_GROUP = None


def get_global_rank() -> int:
    """Get the global rank, the global index of the GPU."""
    return int(os.environ.get("RANK", "0"))


def get_local_rank() -> int:
    """Get the local rank, the index of the GPU on its node."""
    return int(os.environ.get("LOCAL_RANK", "0"))


def get_world_size() -> int:
    """Get the world size, the total number of GPUs."""
    return int(os.environ.get("WORLD_SIZE", "1"))


def get_device() -> torch.device:
    """Get the device assigned to the current rank."""
    return torch.device("cuda", get_local_rank())


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP


def initialize_sequence_parallelism(sequence_parallel_size: int) -> None:
    """Split the world into groups of `sequence_parallel_size` consecutive ranks."""
    global _SEQUENCE_PARALLEL_GROUP
    assert get_world_size() % sequence_parallel_size == 0
    sequence_parallel_num_groups = get_world_size() // sequence_parallel_size
    for i in range(sequence_parallel_num_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        # new_group() must be called by every rank, even for groups it does not join.
        group = dist.new_group(ranks)
        if get_global_rank() in ranks:
            print(f"Rank {get_global_rank()} joined group with ranks {list(ranks)}")
            _SEQUENCE_PARALLEL_GROUP = group
```
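
As a minimal usage sketch (not part of the original listing): a launcher such as `torchrun` exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` for each process, the default process group is initialized, and the sequence-parallel groups are created before any collective is restricted to them. The module name `parallel_state`, the choice of `sequence_parallel_size=2`, and the `all_reduce` call are illustrative assumptions.

```python
# Hypothetical driver, launched with e.g.: torchrun --nproc-per-node=4 main.py
import torch
import torch.distributed as dist

import parallel_state as ps  # assumed module name for the helpers above


def main():
    # The launcher sets RANK, LOCAL_RANK and WORLD_SIZE for each process.
    torch.cuda.set_device(ps.get_device())
    dist.init_process_group(backend="nccl")

    # Illustrative choice: groups of 2 consecutive ranks share one sequence.
    ps.initialize_sequence_parallelism(sequence_parallel_size=2)

    # Collectives can now be limited to the caller's sequence-parallel group.
    x = torch.ones(1, device=ps.get_device())
    dist.all_reduce(x, group=ps.get_sequence_parallel_group())
    print(f"rank {ps.get_global_rank()}: reduced value = {x.item()}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```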