import os

import torch
import torch.distributed as dist

# Sequence parallel process group that this rank belongs to; populated by
# initialize_sequence_parallelism().
_SEQUENCE_PARALLEL_GROUP = None


def get_global_rank() -> int:
    """Get the global rank, the global index of the GPU."""
    return int(os.environ.get("RANK", "0"))


def get_local_rank() -> int:
    """Get the local rank, the index of the GPU within the current node."""
    return int(os.environ.get("LOCAL_RANK", "0"))


def get_world_size() -> int:
    """Get the world size, the total number of GPUs across all nodes."""
    return int(os.environ.get("WORLD_SIZE", "1"))


def get_device() -> torch.device:
    """Get the CUDA device for the current local rank."""
    return torch.device("cuda", get_local_rank())


def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP


def initialize_sequence_parallelism(sequence_parallel_size: int) -> None:
    """Partition all ranks into groups of `sequence_parallel_size` consecutive ranks."""
    world_size = get_world_size()
    assert world_size % sequence_parallel_size == 0, (
        f"world size {world_size} is not divisible by "
        f"sequence parallel size {sequence_parallel_size}"
    )
    sequence_parallel_num_groups = world_size // sequence_parallel_size

    global _SEQUENCE_PARALLEL_GROUP
    for i in range(sequence_parallel_num_groups):
        ranks = range(i * sequence_parallel_size,
                      (i + 1) * sequence_parallel_size)
        # new_group() is a collective call: every rank must invoke it for every
        # group, even ranks that are not members of that group.
        group = dist.new_group(ranks)
        if get_global_rank() in ranks:
            print(f"Rank {get_global_rank()} joined group with ranks {list(ranks)}")
            _SEQUENCE_PARALLEL_GROUP = group
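

# --- Minimal usage sketch (illustration only, not part of the original module). ---
# Assumes a torchrun launch, which sets RANK, LOCAL_RANK, and WORLD_SIZE, and that the
# default process group is initialized before the sequence parallel groups are built.
if __name__ == "__main__":
    torch.cuda.set_device(get_local_rank())
    dist.init_process_group(backend="nccl")

    # Hypothetical group size; requires WORLD_SIZE to be divisible by 2.
    initialize_sequence_parallelism(sequence_parallel_size=2)

    sp_group = get_sequence_parallel_group()
    print(
        f"Rank {get_global_rank()}/{get_world_size()}: "
        f"sequence parallel group size = {dist.get_world_size(sp_group)}"
    )

    dist.destroy_process_group()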