# kandinsky/models/parallelize.py
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def parallelize_dit(model, tp_mesh):
    """Apply tensor parallelism to the Kandinsky DiT across ``tp_mesh``.

    Each submodule is annotated with a DTensor parallelization plan
    (column-/row-wise sharded linears, sequence-parallel norms).  When the
    mesh has a single rank the model is returned unchanged.
    """
    if tp_mesh.size() > 1:
        # Timestep embedding MLP: shard the first linear column-wise and the
        # second row-wise (Megatron-style), replicating the final output.
        plan = {
            "in_layer": ColwiseParallel(),
            "out_layer": RowwiseParallel(
                output_layouts=Replicate(),
            ),
        }
        parallelize_module(model.time_embeddings, tp_mesh, plan)

        # Single-linear embedding projections: shard column-wise but gather
        # the result so downstream modules see a replicated tensor.
        plan = {
            "in_layer": ColwiseParallel(
                output_layouts=Replicate(),
            )
        }
        parallelize_module(model.text_embeddings, tp_mesh, plan)
        parallelize_module(model.pooled_text_embeddings, tp_mesh, plan)
        parallelize_module(model.visual_embeddings, tp_mesh, plan)

        for visual_transformer_block in model.visual_transformer_blocks:
            plan = {
                # Modulation: prepare the module input and shard its output
                # projection column-wise, gathering the result on every rank.
                "visual_modulation": PrepareModuleInput(
                    input_layouts=(None),
                    desired_input_layouts=(Replicate()),
                ),
                "visual_modulation.out_layer": ColwiseParallel(
                    output_layouts=Replicate(),
                ),
                # Self-attention: Q/K/V projections column-wise (head-sharded),
                # output projection row-wise, norms sequence-parallel.
                "self_attention_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.to_query": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.to_key": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.to_value": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "self_attention.query_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.key_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "self_attention.out_layer": RowwiseParallel(
                    output_layouts=Replicate(),
                ),
                # Cross-attention: same layout as self-attention.
                "cross_attention_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.to_query": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.to_key": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.to_value": ColwiseParallel(
                    input_layouts=Replicate(),
                ),
                "cross_attention.query_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.key_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "cross_attention.out_layer": RowwiseParallel(
                    output_layouts=Replicate(),
                ),
                # Feed-forward: classic column-wise then row-wise split.
                "feed_forward_norm": SequenceParallel(
                    sequence_dim=0, use_local_output=True
                ),
                "feed_forward.in_layer": ColwiseParallel(),
                "feed_forward.out_layer": RowwiseParallel(),
            }
            # Column-wise sharding splits attention heads across ranks, so each
            # attention module must be told how many heads it holds locally.
            self_attn = visual_transformer_block.self_attention
            self_attn.num_heads = self_attn.num_heads // tp_mesh.size()
            cross_attn = visual_transformer_block.cross_attention
            cross_attn.num_heads = cross_attn.num_heads // tp_mesh.size()
            parallelize_module(visual_transformer_block, tp_mesh, plan)

        # Final projection: shard column-wise and replicate the output.
        plan = {
            "out_layer": ColwiseParallel(
                output_layouts=Replicate(),
            ),
        }
        parallelize_module(model.out_layer, tp_mesh, plan)
    return model
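

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how a
# 1-D tensor-parallel device mesh is typically built and handed to
# ``parallelize_dit`` under torchrun.  ``build_dit_model`` is a hypothetical
# constructor standing in for however the Kandinsky DiT is instantiated.
# ---------------------------------------------------------------------------
#
#   import torch.distributed as dist
#   from torch.distributed.device_mesh import init_device_mesh
#
#   dist.init_process_group("nccl")
#   tp_mesh = init_device_mesh(
#       "cuda", (dist.get_world_size(),), mesh_dim_names=("tp",)
#   )
#   model = build_dit_model(...)          # hypothetical constructor
#   model = parallelize_dit(model, tp_mesh)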