# Tensor-parallel (TP) sharding plan for the DiT model, built on PyTorch's
# tensor parallel API (parallelize_module plus the parallel styles below).
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def parallelize_dit(model, tp_mesh):
    """Shard the DiT's linear layers and norms across the tensor-parallel mesh.

    Returns the model unchanged when the mesh has a single rank.
    """
    if tp_mesh.size() > 1:
        # Time-embedding MLP: first linear column-parallel, second row-parallel,
        # with the result gathered back to a replicated tensor.
        plan = {
            "in_layer": ColwiseParallel(),
            "out_layer": RowwiseParallel(
                output_layouts=Replicate(),
            ),
        }
        parallelize_module(model.time_embeddings, tp_mesh, plan)

        # Embedding projections: a column-parallel in_layer whose sharded output
        # is gathered back to a replicated tensor.
        plan = {
            "in_layer": ColwiseParallel(
                output_layouts=Replicate(),
            )
        }
        parallelize_module(model.text_embeddings, tp_mesh, plan)
        parallelize_module(model.pooled_text_embeddings, tp_mesh, plan)
        parallelize_module(model.visual_embeddings, tp_mesh, plan)

        for visual_transformer_block in model.visual_transformer_blocks:
            # Per-block plan: modulation and attention projections are sharded
            # across the TP mesh, while the norms run sequence-parallel over dim 0.
            plan = {
                # The modulation input keeps an unspecified layout and is passed
                # through as-is; its output projection is gathered back to a
                # replicated tensor.
                "visual_modulation": PrepareModuleInput(
                    input_layouts=(None,),
                    desired_input_layouts=(Replicate(),),
                ),
                "visual_modulation.out_layer": ColwiseParallel(
                    output_layouts=Replicate(),
                ),
                # Self-attention: Q/K/V are column-parallel and the output
                # projection is row-parallel, gathered back to replicated.
                "self_attention_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.to_query": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.to_key": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.to_value": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.query_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.key_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.out_layer": RowwiseParallel(
                    output_layouts=Replicate(),
                ),
                "cross_attention_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.to_query": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.to_key": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.to_value": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.query_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.key_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.out_layer": RowwiseParallel(
                    output_layouts=Replicate(),
                ),
                "feed_forward_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "feed_forward.in_layer": ColwiseParallel(),
                "feed_forward.out_layer": RowwiseParallel(),
            }
            # Q/K/V are sharded over heads, so each rank keeps only its local
            # share of attention heads.
            self_attn = visual_transformer_block.self_attention
            self_attn.num_heads = self_attn.num_heads // tp_mesh.size()

            cross_attn = visual_transformer_block.cross_attention
            cross_attn.num_heads = cross_attn.num_heads // tp_mesh.size()

            parallelize_module(visual_transformer_block, tp_mesh, plan)

        # Final output projection, gathered back to a replicated tensor.
        plan = {
            "out_layer": ColwiseParallel(
                output_layouts=Replicate(),
            ),
        }
        parallelize_module(model.out_layer, tp_mesh, plan)

    return model
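

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It assumes the
    # script is launched with torchrun (so the process-group env vars are set)
    # and that `build_dit_model()` is a hypothetical placeholder for however the
    # DiT is actually constructed in this codebase.
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # One-dimensional tensor-parallel mesh spanning every rank in the job.
    tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    model = build_dit_model().cuda()  # hypothetical helper, not defined here
    model = parallelize_dit(model, tp_mesh)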