"""Tensor-parallel sharding plan for the DiT denoiser, built on PyTorch DTensor."""
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def parallelize_dit(model, tp_mesh):
    if tp_mesh.size() > 1:
        # Time embeddings: shard the first projection column-wise; the
        # row-wise output projection reduces back to a replicated tensor.
        plan = {
            "in_layer": ColwiseParallel(),
            "out_layer": RowwiseParallel(
                output_layouts=Replicate(),
            ),
        }
        parallelize_module(model.time_embeddings, tp_mesh, plan)

        # Text, pooled-text and visual embeddings: a single column-wise
        # projection whose output is gathered to a replicated layout.
        plan = {
            "in_layer": ColwiseParallel(
                output_layouts=Replicate(),
            )
        }
        parallelize_module(model.text_embeddings, tp_mesh, plan)
        parallelize_module(model.pooled_text_embeddings, tp_mesh, plan)
        parallelize_module(model.visual_embeddings, tp_mesh, plan)

        # Per-block plan: attention projections are sharded across heads
        # (column-wise for Q/K/V, row-wise for the output projection), the
        # norms run sequence-parallel over dim 0, and the feed-forward uses
        # the standard column-wise -> row-wise split.
        for visual_transformer_block in model.visual_transformer_blocks:
            plan = {
                "visual_modulation": PrepareModuleInput(
                    input_layouts=(None),
                    desired_input_layouts=(Replicate()),
                ),
                "visual_modulation.out_layer": ColwiseParallel(
                    output_layouts=Replicate(),
                ),
                "self_attention_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.to_query": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.to_key": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.to_value": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.query_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.key_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.out_layer": RowwiseParallel(
                    output_layouts=Replicate(),
                ),
                "cross_attention_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.to_query": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.to_key": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.to_value": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.query_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.key_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.out_layer": RowwiseParallel(
                    output_layouts=Replicate(),
                ),
                "feed_forward_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "feed_forward.in_layer": ColwiseParallel(),
                "feed_forward.out_layer": RowwiseParallel(),
            }
            # Q/K/V weights are sharded across the TP group, so each rank only
            # computes its local slice of the attention heads.
            self_attn = visual_transformer_block.self_attention
            self_attn.num_heads = self_attn.num_heads // tp_mesh.size()
            cross_attn = visual_transformer_block.cross_attention
            cross_attn.num_heads = cross_attn.num_heads // tp_mesh.size()
            parallelize_module(visual_transformer_block, tp_mesh, plan)

        # Final output projection: shard column-wise and gather the result so
        # every rank holds the full output.
        plan = {
            "out_layer": ColwiseParallel(
                output_layouts=Replicate(),
            ),
        }
        parallelize_module(model.out_layer, tp_mesh, plan)
    return model
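
# Usage sketch (illustrative, not part of the original file): one way this
# function might be driven under `torchrun`, assuming the DiT model itself is
# constructed elsewhere. `build_dit_model` below is a hypothetical constructor.
#
#     import os
#     from torch.distributed.device_mesh import init_device_mesh
#
#     world_size = int(os.environ.get("WORLD_SIZE", "1"))
#     # One-dimensional mesh spanning all ranks, used only for tensor parallelism.
#     tp_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
#     model = build_dit_model()              # hypothetical
#     model = parallelize_dit(model, tp_mesh)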