Spaces:
Running
Running
Add DualPipe-V support.
Browse files
- README.md +6 -0
- assets/dualpipe_v.png +3 -0
- main.py +22 -0
- src/execution_model.py +7 -0
- src/strategies.py +340 -78
README.md
CHANGED
|
@@ -84,6 +84,12 @@ uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=2
|
|
| 84 |
```
|
| 85 |

|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
### Running for 1F1B-batch-overlap strategy:
|
| 88 |
```bash
|
| 89 |
uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
|
|
|
|
| 84 |
```
|
| 85 |

|
| 86 |
|
| 87 |
+
### Running for DualPipe-V strategy:
|
| 88 |
+
```bash
|
| 89 |
+
uv run python main.py strategy=dualpipe_v num_devices=4 num_stages=8 num_batches=10
|
| 90 |
+
```
|
| 91 |
+

|
| 92 |
+
|
| 93 |
### Running for 1F1B-batch-overlap strategy:
|
| 94 |
```bash
|
| 95 |
uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
|
assets/dualpipe_v.png
ADDED
|
Git LFS Details
|
main.py
CHANGED
|
@@ -4,6 +4,7 @@ from src.strategies import (
|
|
| 4 |
generate_1f1b_interleave_schedule,
|
| 5 |
generate_1f1b_overlap_schedule,
|
| 6 |
generate_1f1b_schedule,
|
|
|
|
| 7 |
generate_zero_bubble_1p_schedule,
|
| 8 |
generate_dualpipe_schedule,
|
| 9 |
)
|
|
@@ -29,6 +30,8 @@ def main(cfg: DictConfig) -> None:
|
|
| 29 |
run_1f1b_interleave_overlap(cfg)
|
| 30 |
elif cfg.strategy == "dualpipe":
|
| 31 |
run_dualpipe(cfg)
|
|
|
|
|
|
|
| 32 |
else:
|
| 33 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
| 34 |
|
|
@@ -152,5 +155,24 @@ def run_dualpipe(cfg: DictConfig) -> None:
|
|
| 152 |
schedule.execute()
|
| 153 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
if __name__ == "__main__":
|
| 156 |
main()
|
|
|
|
| 4 |
generate_1f1b_interleave_schedule,
|
| 5 |
generate_1f1b_overlap_schedule,
|
| 6 |
generate_1f1b_schedule,
|
| 7 |
+
generate_dualpipe_v_schedule,
|
| 8 |
generate_zero_bubble_1p_schedule,
|
| 9 |
generate_dualpipe_schedule,
|
| 10 |
)
|
|
|
|
| 30 |
run_1f1b_interleave_overlap(cfg)
|
| 31 |
elif cfg.strategy == "dualpipe":
|
| 32 |
run_dualpipe(cfg)
|
| 33 |
+
elif cfg.strategy == "dualpipe_v":
|
| 34 |
+
run_dualpipe_v(cfg)
|
| 35 |
else:
|
| 36 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
| 37 |
|
|
|
|
| 155 |
schedule.execute()
|
| 156 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 157 |
|
| 158 |
+
def run_dualpipe_v(cfg: DictConfig) -> None:
|
| 159 |
+
"""Run DualPipe-V pipeline parallelism simulation."""
|
| 160 |
+
# Convert OmegaConf to dict for op_times if it exists
|
| 161 |
+
op_times = (
|
| 162 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
| 163 |
+
)
|
| 164 |
+
schedule_config = ScheduleConfig(
|
| 165 |
+
num_devices=cfg.num_devices,
|
| 166 |
+
num_stages=cfg.num_stages,
|
| 167 |
+
num_batches=cfg.num_batches,
|
| 168 |
+
p2p_latency=cfg.p2p_latency,
|
| 169 |
+
op_times=op_times,
|
| 170 |
+
split_backward=True,
|
| 171 |
+
placement_strategy="dualpipe_v",
|
| 172 |
+
)
|
| 173 |
+
schedule = generate_dualpipe_v_schedule(schedule_config)
|
| 174 |
+
schedule.execute()
|
| 175 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 176 |
+
|
| 177 |
if __name__ == "__main__":
|
| 178 |
main()
|
src/execution_model.py
CHANGED
|
@@ -158,6 +158,13 @@ class ScheduleConfig:
|
|
| 158 |
self.device_to_stages = defaultdict(list)
|
| 159 |
for i in range(self.num_stages):
|
| 160 |
self.device_to_stages[i] = [i, self.num_stages - i - 1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
else:
|
| 162 |
raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
|
| 163 |
|
|
|
|
| 158 |
self.device_to_stages = defaultdict(list)
|
| 159 |
for i in range(self.num_stages):
|
| 160 |
self.device_to_stages[i] = [i, self.num_stages - i - 1]
|
| 161 |
+
elif self.placement_strategy == "dualpipe_v":
|
| 162 |
+
assert self.num_devices % 2 == 0, "DualPipe-V requires an even number of devices"
|
| 163 |
+
assert self.num_stages == self.num_devices * 2, "DualPipe-V requires num_stages == num_devices * 2"
|
| 164 |
+
assert self.split_backward, "DualPipe-V requires split_backward=True"
|
| 165 |
+
self.device_to_stages = defaultdict(list)
|
| 166 |
+
for i in range(self.num_devices):
|
| 167 |
+
self.device_to_stages[i] = [i, self.num_stages - i - 1]
|
| 168 |
else:
|
| 169 |
raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
|
| 170 |
|
src/strategies.py
CHANGED
|
@@ -5,7 +5,9 @@ from src.execution_model import OverlappedOperation, Operation, Schedule, Schedu
|
|
| 5 |
def generate_1f1b_schedule(config: ScheduleConfig):
|
| 6 |
schedule = Schedule(config)
|
| 7 |
|
| 8 |
-
assert
|
|
|
|
|
|
|
| 9 |
|
| 10 |
for i in range(config.num_devices):
|
| 11 |
fwd_batch_id = 0
|
|
@@ -42,7 +44,9 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
|
|
| 42 |
# Create a new schedule with split_backward=True to support backward_D and backward_W operations
|
| 43 |
schedule = Schedule(config)
|
| 44 |
total_batches = config.num_batches
|
| 45 |
-
assert
|
|
|
|
|
|
|
| 46 |
assert config.split_backward, "ZB-1P requires split_backward=True"
|
| 47 |
|
| 48 |
for i in range(config.num_devices):
|
|
@@ -73,7 +77,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
|
|
| 73 |
bwd_w_batch_id += 1
|
| 74 |
bwd_d_batch_id += 1
|
| 75 |
fwd_batch_id += 1
|
| 76 |
-
|
| 77 |
for _ in range(cooldown_batches):
|
| 78 |
schedule.device_queues[i].add_operation(
|
| 79 |
schedule.get_op(bwd_d_batch_id, i, "backward_D")
|
|
@@ -85,7 +89,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
|
|
| 85 |
|
| 86 |
bwd_w_batch_id += 1
|
| 87 |
bwd_d_batch_id += 1
|
| 88 |
-
|
| 89 |
while bwd_w_batch_id < total_batches:
|
| 90 |
schedule.device_queues[i].add_operation(
|
| 91 |
schedule.get_op(bwd_w_batch_id, i, "backward_W")
|
|
@@ -98,7 +102,9 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
|
|
| 98 |
def generate_1f1b_overlap_schedule(config: ScheduleConfig):
|
| 99 |
schedule = Schedule(config)
|
| 100 |
|
| 101 |
-
assert
|
|
|
|
|
|
|
| 102 |
|
| 103 |
for i in range(config.num_devices):
|
| 104 |
fwd_batch_id = 0
|
|
@@ -132,11 +138,11 @@ def generate_1f1b_overlap_schedule(config: ScheduleConfig):
|
|
| 132 |
|
| 133 |
|
| 134 |
def _get_pp_rank_microbatches(
|
| 135 |
-
num_microbatches,
|
| 136 |
num_devices,
|
| 137 |
device_id,
|
| 138 |
-
num_stages_per_device,
|
| 139 |
-
microbatch_group_size_per_vp_stage,
|
| 140 |
):
|
| 141 |
"""Get the number of warmup microbatches for this device in PP scheduling."""
|
| 142 |
total_num_microbatches = num_microbatches * num_stages_per_device
|
|
@@ -147,7 +153,9 @@ def _get_pp_rank_microbatches(
|
|
| 147 |
# stage ID (more forward passes for earlier stages, later stages can
|
| 148 |
# immediately start with 1F1B).
|
| 149 |
num_warmup_microbatches = (num_devices - device_id - 1) * 2
|
| 150 |
-
num_warmup_microbatches += (
|
|
|
|
|
|
|
| 151 |
else:
|
| 152 |
# forward_backward_no_pipelining
|
| 153 |
num_warmup_microbatches = 1
|
|
@@ -158,27 +166,34 @@ def _get_pp_rank_microbatches(
|
|
| 158 |
return num_warmup_microbatches
|
| 159 |
|
| 160 |
|
| 161 |
-
def _get_schedule_table(
|
|
|
|
|
|
|
| 162 |
"""Get the schedule table for PP scheduling.
|
| 163 |
|
| 164 |
Create a tunable schedule lookup table.
|
| 165 |
-
The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
|
| 166 |
For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
|
| 167 |
virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
|
| 168 |
microbatch_id | 0 1 2 0 1 2 3 4 3 4
|
| 169 |
-
model_chunk_id | 0 0 0 1 1 1 0 0 1 1
|
| 170 |
"""
|
| 171 |
schedule_table = []
|
| 172 |
for min_microbatch_id_in_group in range(
|
| 173 |
0, num_microbatches, microbatch_group_size_per_vp_stage
|
| 174 |
):
|
| 175 |
-
if
|
|
|
|
|
|
|
|
|
|
| 176 |
# Construct schedule for the last microbatch group
|
| 177 |
schedule_table.extend(
|
| 178 |
[
|
| 179 |
(microbatch_id, model_chunk_id)
|
| 180 |
for model_chunk_id in range(num_model_chunks)
|
| 181 |
-
for microbatch_id in range(
|
|
|
|
|
|
|
| 182 |
]
|
| 183 |
)
|
| 184 |
else:
|
|
@@ -196,7 +211,9 @@ def _get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_siz
|
|
| 196 |
return schedule_table
|
| 197 |
|
| 198 |
|
| 199 |
-
def _convert_schedule_table_to_order(
|
|
|
|
|
|
|
| 200 |
"""Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
|
| 201 |
order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
|
| 202 |
virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
|
|
@@ -225,7 +242,7 @@ def _convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks,
|
|
| 225 |
# Some of this code is copied from Megatron-LM
|
| 226 |
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
| 227 |
schedule = Schedule(config)
|
| 228 |
-
|
| 229 |
for device_id in range(config.num_devices):
|
| 230 |
microbatch_group_size_per_vp_stage = config.num_devices
|
| 231 |
num_warmup_microbatches = _get_pp_rank_microbatches(
|
|
@@ -244,25 +261,29 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
|
| 244 |
|
| 245 |
order = _convert_schedule_table_to_order(
|
| 246 |
num_warmup_microbatches,
|
| 247 |
-
num_model_chunks=config.num_stages_per_device,
|
| 248 |
schedule_table=schedule_table,
|
| 249 |
)
|
| 250 |
|
| 251 |
cur_stage_microbatch_id = {}
|
| 252 |
-
for i in range(1, config.num_stages_per_device+1):
|
| 253 |
cur_stage_microbatch_id[i] = 0
|
| 254 |
cur_stage_microbatch_id[-i] = 0
|
| 255 |
for order_item in order:
|
| 256 |
-
stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
|
| 257 |
|
| 258 |
if order_item > 0:
|
| 259 |
op_type = "forward"
|
| 260 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 261 |
-
cur_stage_microbatch_id[order_item] =
|
|
|
|
|
|
|
| 262 |
elif order_item < 0:
|
| 263 |
op_type = "backward"
|
| 264 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 265 |
-
cur_stage_microbatch_id[order_item] =
|
|
|
|
|
|
|
| 266 |
else:
|
| 267 |
raise ValueError(f"Invalid order item: {order_item}")
|
| 268 |
schedule.device_queues[device_id].add_operation(
|
|
@@ -270,6 +291,7 @@ def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
|
| 270 |
)
|
| 271 |
return schedule
|
| 272 |
|
|
|
|
| 273 |
def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
|
| 274 |
schedule = Schedule(config)
|
| 275 |
|
|
@@ -290,15 +312,15 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
|
|
| 290 |
)
|
| 291 |
|
| 292 |
# NOTE: Add one more warmup microbatch for overlapped operations!
|
| 293 |
-
num_warmup_microbatches += 1
|
| 294 |
order = _convert_schedule_table_to_order(
|
| 295 |
num_warmup_microbatches,
|
| 296 |
-
num_model_chunks=config.num_stages_per_device,
|
| 297 |
schedule_table=schedule_table,
|
| 298 |
)
|
| 299 |
|
| 300 |
cur_stage_microbatch_id = {}
|
| 301 |
-
for i in range(1, config.num_stages_per_device+1):
|
| 302 |
cur_stage_microbatch_id[i] = 0
|
| 303 |
cur_stage_microbatch_id[-i] = 0
|
| 304 |
i = 0
|
|
@@ -310,27 +332,40 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
|
|
| 310 |
assert order_item > 0
|
| 311 |
op_type = "forward"
|
| 312 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 313 |
-
cur_stage_microbatch_id[order_item] =
|
|
|
|
|
|
|
| 314 |
|
| 315 |
-
stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
|
| 316 |
schedule.device_queues[device_id].add_operation(
|
| 317 |
schedule.get_op(micro_batch_id, stage_id, op_type)
|
| 318 |
)
|
| 319 |
i += 1
|
| 320 |
-
elif
|
|
|
|
|
|
|
|
|
|
| 321 |
order_item_a = order[i]
|
| 322 |
-
order_item_b = order[i+1]
|
| 323 |
|
| 324 |
op_type_a = "forward" if order_item_a > 0 else "backward"
|
| 325 |
micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
|
| 326 |
-
cur_stage_microbatch_id[order_item_a] =
|
|
|
|
|
|
|
| 327 |
|
| 328 |
op_type_b = "forward" if order_item_b > 0 else "backward"
|
| 329 |
micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
|
| 330 |
-
cur_stage_microbatch_id[order_item_b] =
|
|
|
|
|
|
|
| 331 |
|
| 332 |
-
stage_id_a = schedule.device_queues[device_id].stages[
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
|
| 336 |
op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
|
|
@@ -345,14 +380,15 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
|
|
| 345 |
assert order_item < 0
|
| 346 |
op_type = "backward"
|
| 347 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 348 |
-
cur_stage_microbatch_id[order_item] =
|
|
|
|
|
|
|
| 349 |
|
| 350 |
-
stage_id = schedule.device_queues[device_id].stages[abs(order_item)-1]
|
| 351 |
schedule.device_queues[device_id].add_operation(
|
| 352 |
schedule.get_op(micro_batch_id, stage_id, op_type)
|
| 353 |
)
|
| 354 |
i += 1
|
| 355 |
-
|
| 356 |
|
| 357 |
return schedule
|
| 358 |
|
|
@@ -365,23 +401,23 @@ def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2
|
|
| 365 |
# Get the operations from the schedule
|
| 366 |
op1 = schedule.ops[(batch_id1, stage_id, type1)]
|
| 367 |
op2 = schedule.ops[(batch_id2, stage_id, type2)]
|
| 368 |
-
|
| 369 |
# Create the overlapped operation
|
| 370 |
overlapped_op = OverlappedOperation([op1, op2])
|
| 371 |
-
|
| 372 |
# Register in the schedule to ensure proper tracking
|
| 373 |
schedule.register_overlapped_operation(overlapped_op)
|
| 374 |
-
|
| 375 |
return overlapped_op
|
| 376 |
|
| 377 |
|
| 378 |
def generate_dualpipe_schedule(config: ScheduleConfig):
|
| 379 |
"""
|
| 380 |
Implements the DualPipe scheduling strategy.
|
| 381 |
-
|
| 382 |
DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
|
| 383 |
and backward computation-communication phases and reduces pipeline bubbles.
|
| 384 |
-
|
| 385 |
The DualPipe strategy has the following characteristics:
|
| 386 |
1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
|
| 387 |
2. Each device handles both a forward stage and a reverse stage
|
|
@@ -396,15 +432,27 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 396 |
A Schedule object with the DualPipe scheduling
|
| 397 |
"""
|
| 398 |
# Ensure placement strategy is set for Schedule initialization
|
| 399 |
-
assert
|
|
|
|
|
|
|
| 400 |
# Assertions based on DualPipe requirements
|
| 401 |
-
assert
|
| 402 |
-
|
| 403 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
# Assertion based on original implementation: num_chunks >= num_ranks * 2
|
| 405 |
# Here, M (config.num_batches) corresponds to half_num_chunks
|
| 406 |
-
assert
|
| 407 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
schedule = Schedule(config, init_ops=False)
|
| 410 |
|
|
@@ -414,10 +462,12 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 414 |
half_num_chunks = config.num_batches // 2
|
| 415 |
num_half_ranks = num_devices // 2
|
| 416 |
|
| 417 |
-
fwd_batch_ids = defaultdict(int)
|
| 418 |
-
bwd_d_batch_ids = defaultdict(int)
|
| 419 |
|
| 420 |
-
waited_weight_grad = [
|
|
|
|
|
|
|
| 421 |
|
| 422 |
for device_id in range(num_devices):
|
| 423 |
is_in_second_half = device_id >= num_half_ranks
|
|
@@ -431,16 +481,18 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 431 |
fwd_batch_ids[device_id, 1] = config.num_batches // 2
|
| 432 |
bwd_d_batch_ids[device_id, 0] = 0
|
| 433 |
bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
|
|
|
|
| 434 |
def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
|
| 435 |
-
stage_fwd_dir = device_id
|
| 436 |
-
stage_rev_dir =
|
|
|
|
|
|
|
| 437 |
if not is_in_second_half:
|
| 438 |
# First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
|
| 439 |
return stage_fwd_dir if phase == 0 else stage_rev_dir
|
| 440 |
else:
|
| 441 |
# Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
|
| 442 |
return stage_rev_dir if phase == 0 else stage_fwd_dir
|
| 443 |
-
|
| 444 |
|
| 445 |
def add_op_to_queue(device_id, stage_id, op_type, batch_id):
|
| 446 |
# Retrieve the correct pre-initialized Operation object
|
|
@@ -462,7 +514,7 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 462 |
batch_id = bwd_d_batch_ids[device_id, phase]
|
| 463 |
add_op_to_queue(device_id, stage_id, "backward", batch_id)
|
| 464 |
bwd_d_batch_ids[device_id, phase] += 1
|
| 465 |
-
|
| 466 |
def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
|
| 467 |
"""Schedules a backward_D compute operation."""
|
| 468 |
stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
|
|
@@ -476,11 +528,17 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 476 |
stage_id, batch_id = waited_weight_grad[device_id].popleft()
|
| 477 |
add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
|
| 478 |
|
| 479 |
-
def _schedule_forward_backward_chunk(
|
|
|
|
|
|
|
| 480 |
"""Schedules an overlapped forward and backward_D compute operation."""
|
| 481 |
-
fwd_stage_id = get_stage_for_phase(
|
| 482 |
-
|
| 483 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
|
| 485 |
|
| 486 |
fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
|
|
@@ -493,58 +551,67 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 493 |
bwd_d_batch_ids[device_id, bwd_phase] += 1
|
| 494 |
|
| 495 |
# Create and register the overlapped operation
|
| 496 |
-
overlapped_op = OverlappedOperation([fwd_op, bwd_op])
|
| 497 |
schedule.register_overlapped_operation(overlapped_op)
|
| 498 |
-
|
| 499 |
# Add the overlapped operation to the queue
|
| 500 |
schedule.device_queues[device_id].add_operation(overlapped_op)
|
| 501 |
|
| 502 |
-
|
| 503 |
# Process each device (rank in original code)
|
| 504 |
for device_id in range(num_devices):
|
| 505 |
half_rank = min(device_id, num_devices - 1 - device_id)
|
| 506 |
is_in_second_half = device_id >= num_half_ranks
|
| 507 |
-
is_middle_rank = (device_id == num_half_ranks - 1) or (
|
|
|
|
|
|
|
| 508 |
|
| 509 |
# Map original steps to operation additions
|
| 510 |
# Step 1: nF0
|
| 511 |
step_1_count = (num_half_ranks - half_rank - 1) * 2
|
| 512 |
for _ in range(step_1_count):
|
| 513 |
-
_schedule_forward_chunk(device_id, 0, is_in_second_half)
|
| 514 |
|
| 515 |
# Step 2: nF0F1
|
| 516 |
step_2_count = half_rank + 1
|
| 517 |
for i in range(step_2_count):
|
| 518 |
-
_schedule_forward_chunk(device_id, 0, is_in_second_half)
|
| 519 |
-
_schedule_forward_chunk(device_id, 1, is_in_second_half)
|
| 520 |
|
| 521 |
# Step 3: nB1W1F1
|
| 522 |
step_3_count = num_half_ranks - half_rank - 1
|
| 523 |
for _ in range(step_3_count):
|
| 524 |
-
_schedule_backward_input_chunk(device_id, 1, is_in_second_half)
|
| 525 |
-
_schedule_backward_weight_chunk(
|
|
|
|
|
|
|
| 526 |
_schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
|
| 527 |
|
| 528 |
# Step 4 (Main step): nF0B1F1B0
|
| 529 |
step_4_count = half_num_chunks - num_devices + half_rank + 1
|
| 530 |
for i in range(step_4_count):
|
| 531 |
# if i == 0 and is_middle_rank:
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
# else:
|
| 537 |
# Overlap F0 and B1_D, then schedule W1
|
| 538 |
-
_schedule_forward_backward_chunk(
|
| 539 |
-
|
|
|
|
|
|
|
| 540 |
# Overlap F1 and B0_D, then schedule W0
|
| 541 |
-
_schedule_forward_backward_chunk(
|
|
|
|
|
|
|
| 542 |
|
| 543 |
# Step 5: nB1F1B0
|
| 544 |
step_5_count = num_half_ranks - half_rank - 1
|
| 545 |
for _ in range(step_5_count):
|
| 546 |
-
_schedule_backward_chunk(device_id, 1, is_in_second_half)
|
| 547 |
-
_schedule_forward_backward_chunk(
|
|
|
|
|
|
|
| 548 |
|
| 549 |
# Step 6: nB1B0
|
| 550 |
step_6_count = half_rank + 1
|
|
@@ -566,8 +633,10 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 566 |
# Step 7: nWB0
|
| 567 |
step_7_count = num_half_ranks - half_rank - 1
|
| 568 |
for _ in range(step_7_count):
|
| 569 |
-
_schedule_backward_weight_chunk(
|
| 570 |
-
|
|
|
|
|
|
|
| 571 |
|
| 572 |
# Step 8: nW
|
| 573 |
step_8_count = half_rank + 1
|
|
@@ -575,7 +644,200 @@ def generate_dualpipe_schedule(config: ScheduleConfig):
|
|
| 575 |
# W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
|
| 576 |
# W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
|
| 577 |
# The last W0 gradients correspond to B0_D from step 6 or 7.
|
| 578 |
-
_schedule_backward_weight_chunk(
|
|
|
|
|
|
|
| 579 |
|
| 580 |
return schedule
|
| 581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
def generate_1f1b_schedule(config: ScheduleConfig):
|
| 6 |
schedule = Schedule(config)
|
| 7 |
|
| 8 |
+
assert (
|
| 9 |
+
config.num_devices == config.num_stages
|
| 10 |
+
), "num_devices must be equal to num_stages for 1F1B"
|
| 11 |
|
| 12 |
for i in range(config.num_devices):
|
| 13 |
fwd_batch_id = 0
|
|
|
|
| 44 |
# Create a new schedule with split_backward=True to support backward_D and backward_W operations
|
| 45 |
schedule = Schedule(config)
|
| 46 |
total_batches = config.num_batches
|
| 47 |
+
assert (
|
| 48 |
+
config.num_devices == config.num_stages
|
| 49 |
+
), "num_devices must be equal to num_stages for ZB-1P"
|
| 50 |
assert config.split_backward, "ZB-1P requires split_backward=True"
|
| 51 |
|
| 52 |
for i in range(config.num_devices):
|
|
|
|
| 77 |
bwd_w_batch_id += 1
|
| 78 |
bwd_d_batch_id += 1
|
| 79 |
fwd_batch_id += 1
|
| 80 |
+
|
| 81 |
for _ in range(cooldown_batches):
|
| 82 |
schedule.device_queues[i].add_operation(
|
| 83 |
schedule.get_op(bwd_d_batch_id, i, "backward_D")
|
|
|
|
| 89 |
|
| 90 |
bwd_w_batch_id += 1
|
| 91 |
bwd_d_batch_id += 1
|
| 92 |
+
|
| 93 |
while bwd_w_batch_id < total_batches:
|
| 94 |
schedule.device_queues[i].add_operation(
|
| 95 |
schedule.get_op(bwd_w_batch_id, i, "backward_W")
|
|
|
|
| 102 |
def generate_1f1b_overlap_schedule(config: ScheduleConfig):
|
| 103 |
schedule = Schedule(config)
|
| 104 |
|
| 105 |
+
assert (
|
| 106 |
+
config.num_devices == config.num_stages
|
| 107 |
+
), "num_devices must be equal to num_stages for 1F1B"
|
| 108 |
|
| 109 |
for i in range(config.num_devices):
|
| 110 |
fwd_batch_id = 0
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
def _get_pp_rank_microbatches(
|
| 141 |
+
num_microbatches,
|
| 142 |
num_devices,
|
| 143 |
device_id,
|
| 144 |
+
num_stages_per_device,
|
| 145 |
+
microbatch_group_size_per_vp_stage,
|
| 146 |
):
|
| 147 |
"""Get the number of warmup microbatches for this device in PP scheduling."""
|
| 148 |
total_num_microbatches = num_microbatches * num_stages_per_device
|
|
|
|
| 153 |
# stage ID (more forward passes for earlier stages, later stages can
|
| 154 |
# immediately start with 1F1B).
|
| 155 |
num_warmup_microbatches = (num_devices - device_id - 1) * 2
|
| 156 |
+
num_warmup_microbatches += (
|
| 157 |
+
num_stages_per_device - 1
|
| 158 |
+
) * microbatch_group_size_per_vp_stage
|
| 159 |
else:
|
| 160 |
# forward_backward_no_pipelining
|
| 161 |
num_warmup_microbatches = 1
|
|
|
|
| 166 |
return num_warmup_microbatches
|
| 167 |
|
| 168 |
|
| 169 |
+
def _get_schedule_table(
|
| 170 |
+
num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage
|
| 171 |
+
):
|
| 172 |
"""Get the schedule table for PP scheduling.
|
| 173 |
|
| 174 |
Create a tunable schedule lookup table.
|
| 175 |
+
The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
|
| 176 |
For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
|
| 177 |
virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
|
| 178 |
microbatch_id | 0 1 2 0 1 2 3 4 3 4
|
| 179 |
+
model_chunk_id | 0 0 0 1 1 1 0 0 1 1
|
| 180 |
"""
|
| 181 |
schedule_table = []
|
| 182 |
for min_microbatch_id_in_group in range(
|
| 183 |
0, num_microbatches, microbatch_group_size_per_vp_stage
|
| 184 |
):
|
| 185 |
+
if (
|
| 186 |
+
min_microbatch_id_in_group + microbatch_group_size_per_vp_stage
|
| 187 |
+
>= num_microbatches
|
| 188 |
+
):
|
| 189 |
# Construct schedule for the last microbatch group
|
| 190 |
schedule_table.extend(
|
| 191 |
[
|
| 192 |
(microbatch_id, model_chunk_id)
|
| 193 |
for model_chunk_id in range(num_model_chunks)
|
| 194 |
+
for microbatch_id in range(
|
| 195 |
+
min_microbatch_id_in_group, num_microbatches
|
| 196 |
+
)
|
| 197 |
]
|
| 198 |
)
|
| 199 |
else:
|
|
|
|
| 211 |
return schedule_table
|
| 212 |
|
| 213 |
|
| 214 |
+
def _convert_schedule_table_to_order(
|
| 215 |
+
num_warmup_microbatches, num_model_chunks, schedule_table
|
| 216 |
+
):
|
| 217 |
"""Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
|
| 218 |
order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
|
| 219 |
virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
|
|
|
|
| 242 |
# Some of this code is copied from Megatron-LM
|
| 243 |
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
| 244 |
schedule = Schedule(config)
|
| 245 |
+
|
| 246 |
for device_id in range(config.num_devices):
|
| 247 |
microbatch_group_size_per_vp_stage = config.num_devices
|
| 248 |
num_warmup_microbatches = _get_pp_rank_microbatches(
|
|
|
|
| 261 |
|
| 262 |
order = _convert_schedule_table_to_order(
|
| 263 |
num_warmup_microbatches,
|
| 264 |
+
num_model_chunks=config.num_stages_per_device,
|
| 265 |
schedule_table=schedule_table,
|
| 266 |
)
|
| 267 |
|
| 268 |
cur_stage_microbatch_id = {}
|
| 269 |
+
for i in range(1, config.num_stages_per_device + 1):
|
| 270 |
cur_stage_microbatch_id[i] = 0
|
| 271 |
cur_stage_microbatch_id[-i] = 0
|
| 272 |
for order_item in order:
|
| 273 |
+
stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
|
| 274 |
|
| 275 |
if order_item > 0:
|
| 276 |
op_type = "forward"
|
| 277 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 278 |
+
cur_stage_microbatch_id[order_item] = (
|
| 279 |
+
cur_stage_microbatch_id[order_item] + 1
|
| 280 |
+
)
|
| 281 |
elif order_item < 0:
|
| 282 |
op_type = "backward"
|
| 283 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 284 |
+
cur_stage_microbatch_id[order_item] = (
|
| 285 |
+
cur_stage_microbatch_id[order_item] + 1
|
| 286 |
+
)
|
| 287 |
else:
|
| 288 |
raise ValueError(f"Invalid order item: {order_item}")
|
| 289 |
schedule.device_queues[device_id].add_operation(
|
|
|
|
| 291 |
)
|
| 292 |
return schedule
|
| 293 |
|
| 294 |
+
|
| 295 |
def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
|
| 296 |
schedule = Schedule(config)
|
| 297 |
|
|
|
|
| 312 |
)
|
| 313 |
|
| 314 |
# NOTE: Add one more warmup microbatch for overlapped operations!
|
| 315 |
+
num_warmup_microbatches += 1
|
| 316 |
order = _convert_schedule_table_to_order(
|
| 317 |
num_warmup_microbatches,
|
| 318 |
+
num_model_chunks=config.num_stages_per_device,
|
| 319 |
schedule_table=schedule_table,
|
| 320 |
)
|
| 321 |
|
| 322 |
cur_stage_microbatch_id = {}
|
| 323 |
+
for i in range(1, config.num_stages_per_device + 1):
|
| 324 |
cur_stage_microbatch_id[i] = 0
|
| 325 |
cur_stage_microbatch_id[-i] = 0
|
| 326 |
i = 0
|
|
|
|
| 332 |
assert order_item > 0
|
| 333 |
op_type = "forward"
|
| 334 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 335 |
+
cur_stage_microbatch_id[order_item] = (
|
| 336 |
+
cur_stage_microbatch_id[order_item] + 1
|
| 337 |
+
)
|
| 338 |
|
| 339 |
+
stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
|
| 340 |
schedule.device_queues[device_id].add_operation(
|
| 341 |
schedule.get_op(micro_batch_id, stage_id, op_type)
|
| 342 |
)
|
| 343 |
i += 1
|
| 344 |
+
elif (
|
| 345 |
+
i >= num_warmup_microbatches
|
| 346 |
+
and i < num_warmup_microbatches + num_overlapped_batches - 1
|
| 347 |
+
):
|
| 348 |
order_item_a = order[i]
|
| 349 |
+
order_item_b = order[i + 1]
|
| 350 |
|
| 351 |
op_type_a = "forward" if order_item_a > 0 else "backward"
|
| 352 |
micro_batch_id_a = cur_stage_microbatch_id[order_item_a]
|
| 353 |
+
cur_stage_microbatch_id[order_item_a] = (
|
| 354 |
+
cur_stage_microbatch_id[order_item_a] + 1
|
| 355 |
+
)
|
| 356 |
|
| 357 |
op_type_b = "forward" if order_item_b > 0 else "backward"
|
| 358 |
micro_batch_id_b = cur_stage_microbatch_id[order_item_b]
|
| 359 |
+
cur_stage_microbatch_id[order_item_b] = (
|
| 360 |
+
cur_stage_microbatch_id[order_item_b] + 1
|
| 361 |
+
)
|
| 362 |
|
| 363 |
+
stage_id_a = schedule.device_queues[device_id].stages[
|
| 364 |
+
abs(order_item_a) - 1
|
| 365 |
+
]
|
| 366 |
+
stage_id_b = schedule.device_queues[device_id].stages[
|
| 367 |
+
abs(order_item_b) - 1
|
| 368 |
+
]
|
| 369 |
|
| 370 |
op_a = schedule.get_op(micro_batch_id_a, stage_id_a, op_type_a)
|
| 371 |
op_b = schedule.get_op(micro_batch_id_b, stage_id_b, op_type_b)
|
|
|
|
| 380 |
assert order_item < 0
|
| 381 |
op_type = "backward"
|
| 382 |
micro_batch_id = cur_stage_microbatch_id[order_item]
|
| 383 |
+
cur_stage_microbatch_id[order_item] = (
|
| 384 |
+
cur_stage_microbatch_id[order_item] + 1
|
| 385 |
+
)
|
| 386 |
|
| 387 |
+
stage_id = schedule.device_queues[device_id].stages[abs(order_item) - 1]
|
| 388 |
schedule.device_queues[device_id].add_operation(
|
| 389 |
schedule.get_op(micro_batch_id, stage_id, op_type)
|
| 390 |
)
|
| 391 |
i += 1
|
|
|
|
| 392 |
|
| 393 |
return schedule
|
| 394 |
|
|
|
|
| 401 |
# Get the operations from the schedule
|
| 402 |
op1 = schedule.ops[(batch_id1, stage_id, type1)]
|
| 403 |
op2 = schedule.ops[(batch_id2, stage_id, type2)]
|
| 404 |
+
|
| 405 |
# Create the overlapped operation
|
| 406 |
overlapped_op = OverlappedOperation([op1, op2])
|
| 407 |
+
|
| 408 |
# Register in the schedule to ensure proper tracking
|
| 409 |
schedule.register_overlapped_operation(overlapped_op)
|
| 410 |
+
|
| 411 |
return overlapped_op
|
| 412 |
|
| 413 |
|
| 414 |
def generate_dualpipe_schedule(config: ScheduleConfig):
|
| 415 |
"""
|
| 416 |
Implements the DualPipe scheduling strategy.
|
| 417 |
+
|
| 418 |
DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
|
| 419 |
and backward computation-communication phases and reduces pipeline bubbles.
|
| 420 |
+
|
| 421 |
The DualPipe strategy has the following characteristics:
|
| 422 |
1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
|
| 423 |
2. Each device handles both a forward stage and a reverse stage
|
|
|
|
| 432 |
A Schedule object with the DualPipe scheduling
|
| 433 |
"""
|
| 434 |
# Ensure placement strategy is set for Schedule initialization
|
| 435 |
+
assert (
|
| 436 |
+
config.placement_strategy == "dualpipe"
|
| 437 |
+
), "DualPipe schedule currently only supports placement_strategy='dualpipe'"
|
| 438 |
# Assertions based on DualPipe requirements
|
| 439 |
+
assert (
|
| 440 |
+
config.num_stages % 2 == 0
|
| 441 |
+
), "DualPipe requires an even number of stages (and devices)"
|
| 442 |
+
assert (
|
| 443 |
+
config.num_devices == config.num_stages
|
| 444 |
+
), "DualPipe requires num_devices == num_stages"
|
| 445 |
+
assert (
|
| 446 |
+
config.num_batches % 2 == 0
|
| 447 |
+
), "DualPipe requires an even number of microbatches (config.num_batches)"
|
| 448 |
# Assertion based on original implementation: num_chunks >= num_ranks * 2
|
| 449 |
# Here, M (config.num_batches) corresponds to half_num_chunks
|
| 450 |
+
assert (
|
| 451 |
+
config.num_batches >= config.num_devices
|
| 452 |
+
), "DualPipe requires config.num_batches >= config.num_devices"
|
| 453 |
+
assert (
|
| 454 |
+
config.split_backward
|
| 455 |
+
), "DualPipe schedule currently only supports split_backward=True"
|
| 456 |
|
| 457 |
schedule = Schedule(config, init_ops=False)
|
| 458 |
|
|
|
|
| 462 |
half_num_chunks = config.num_batches // 2
|
| 463 |
num_half_ranks = num_devices // 2
|
| 464 |
|
| 465 |
+
fwd_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
|
| 466 |
+
bwd_d_batch_ids = defaultdict(int) # (device_id, phase) -> batch_id
|
| 467 |
|
| 468 |
+
waited_weight_grad = [
|
| 469 |
+
deque() for _ in range(num_devices)
|
| 470 |
+
] # (device_id, ) -> List[(stage_id, batch_id)]
|
| 471 |
|
| 472 |
for device_id in range(num_devices):
|
| 473 |
is_in_second_half = device_id >= num_half_ranks
|
|
|
|
| 481 |
fwd_batch_ids[device_id, 1] = config.num_batches // 2
|
| 482 |
bwd_d_batch_ids[device_id, 0] = 0
|
| 483 |
bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
|
| 484 |
+
|
| 485 |
def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
|
| 486 |
+
stage_fwd_dir = device_id # Stage handled when moving forward (0 to N-1)
|
| 487 |
+
stage_rev_dir = (
|
| 488 |
+
num_stages - 1 - device_id
|
| 489 |
+
) # Stage handled when moving backward (N-1 to 0)
|
| 490 |
if not is_in_second_half:
|
| 491 |
# First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
|
| 492 |
return stage_fwd_dir if phase == 0 else stage_rev_dir
|
| 493 |
else:
|
| 494 |
# Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
|
| 495 |
return stage_rev_dir if phase == 0 else stage_fwd_dir
|
|
|
|
| 496 |
|
| 497 |
def add_op_to_queue(device_id, stage_id, op_type, batch_id):
|
| 498 |
# Retrieve the correct pre-initialized Operation object
|
|
|
|
| 514 |
batch_id = bwd_d_batch_ids[device_id, phase]
|
| 515 |
add_op_to_queue(device_id, stage_id, "backward", batch_id)
|
| 516 |
bwd_d_batch_ids[device_id, phase] += 1
|
| 517 |
+
|
| 518 |
def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
|
| 519 |
"""Schedules a backward_D compute operation."""
|
| 520 |
stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
|
|
|
|
| 528 |
stage_id, batch_id = waited_weight_grad[device_id].popleft()
|
| 529 |
add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
|
| 530 |
|
| 531 |
+
def _schedule_forward_backward_chunk(
|
| 532 |
+
device_id, fwd_phase, bwd_phase, is_in_second_half
|
| 533 |
+
):
|
| 534 |
"""Schedules an overlapped forward and backward_D compute operation."""
|
| 535 |
+
fwd_stage_id = get_stage_for_phase(
|
| 536 |
+
device_id, fwd_phase, num_stages, is_in_second_half
|
| 537 |
+
)
|
| 538 |
+
bwd_stage_id = get_stage_for_phase(
|
| 539 |
+
device_id, bwd_phase, num_stages, is_in_second_half
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
|
| 543 |
|
| 544 |
fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
|
|
|
|
| 551 |
bwd_d_batch_ids[device_id, bwd_phase] += 1
|
| 552 |
|
| 553 |
# Create and register the overlapped operation
|
| 554 |
+
overlapped_op = OverlappedOperation([fwd_op, bwd_op])
|
| 555 |
schedule.register_overlapped_operation(overlapped_op)
|
| 556 |
+
|
| 557 |
# Add the overlapped operation to the queue
|
| 558 |
schedule.device_queues[device_id].add_operation(overlapped_op)
|
| 559 |
|
|
|
|
| 560 |
# Process each device (rank in original code)
|
| 561 |
for device_id in range(num_devices):
|
| 562 |
half_rank = min(device_id, num_devices - 1 - device_id)
|
| 563 |
is_in_second_half = device_id >= num_half_ranks
|
| 564 |
+
is_middle_rank = (device_id == num_half_ranks - 1) or (
|
| 565 |
+
device_id == num_half_ranks
|
| 566 |
+
)
|
| 567 |
|
| 568 |
# Map original steps to operation additions
|
| 569 |
# Step 1: nF0
|
| 570 |
step_1_count = (num_half_ranks - half_rank - 1) * 2
|
| 571 |
for _ in range(step_1_count):
|
| 572 |
+
_schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
|
| 573 |
|
| 574 |
# Step 2: nF0F1
|
| 575 |
step_2_count = half_rank + 1
|
| 576 |
for i in range(step_2_count):
|
| 577 |
+
_schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
|
| 578 |
+
_schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
|
| 579 |
|
| 580 |
# Step 3: nB1W1F1
|
| 581 |
step_3_count = num_half_ranks - half_rank - 1
|
| 582 |
for _ in range(step_3_count):
|
| 583 |
+
_schedule_backward_input_chunk(device_id, 1, is_in_second_half) # B1_D
|
| 584 |
+
_schedule_backward_weight_chunk(
|
| 585 |
+
device_id,
|
| 586 |
+
) # W1
|
| 587 |
_schedule_forward_chunk(device_id, 1, is_in_second_half) # F1
|
| 588 |
|
| 589 |
# Step 4 (Main step): nF0B1F1B0
|
| 590 |
step_4_count = half_num_chunks - num_devices + half_rank + 1
|
| 591 |
for i in range(step_4_count):
|
| 592 |
# if i == 0 and is_middle_rank:
|
| 593 |
+
# Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
|
| 594 |
+
# _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
|
| 595 |
+
# _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
|
| 596 |
+
# _schedule_backward_weight_chunk(device_id, 1, is_in_second_half) # W1
|
| 597 |
# else:
|
| 598 |
# Overlap F0 and B1_D, then schedule W1
|
| 599 |
+
_schedule_forward_backward_chunk(
|
| 600 |
+
device_id, 0, 1, is_in_second_half
|
| 601 |
+
) # F0+B1
|
| 602 |
+
|
| 603 |
# Overlap F1 and B0_D, then schedule W0
|
| 604 |
+
_schedule_forward_backward_chunk(
|
| 605 |
+
device_id, 1, 0, is_in_second_half
|
| 606 |
+
) # F1+B0
|
| 607 |
|
| 608 |
# Step 5: nB1F1B0
|
| 609 |
step_5_count = num_half_ranks - half_rank - 1
|
| 610 |
for _ in range(step_5_count):
|
| 611 |
+
_schedule_backward_chunk(device_id, 1, is_in_second_half) # B1_D + B1_W
|
| 612 |
+
_schedule_forward_backward_chunk(
|
| 613 |
+
device_id, 1, 0, is_in_second_half
|
| 614 |
+
) # F1+B0
|
| 615 |
|
| 616 |
# Step 6: nB1B0
|
| 617 |
step_6_count = half_rank + 1
|
|
|
|
| 633 |
# Step 7: nWB0
|
| 634 |
step_7_count = num_half_ranks - half_rank - 1
|
| 635 |
for _ in range(step_7_count):
|
| 636 |
+
_schedule_backward_weight_chunk(
|
| 637 |
+
device_id
|
| 638 |
+
) # W1 (use gradient from B1_D scheduled previously)
|
| 639 |
+
_schedule_backward_input_chunk(device_id, 0, is_in_second_half) # B0_D
|
| 640 |
|
| 641 |
# Step 8: nW
|
| 642 |
step_8_count = half_rank + 1
|
|
|
|
| 644 |
# W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
|
| 645 |
# W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
|
| 646 |
# The last W0 gradients correspond to B0_D from step 6 or 7.
|
| 647 |
+
_schedule_backward_weight_chunk(
|
| 648 |
+
device_id
|
| 649 |
+
) # W0 (use gradient from B0_D scheduled previously)
|
| 650 |
|
| 651 |
return schedule
|
| 652 |
|
| 653 |
+
|
| 654 |
+
def generate_dualpipe_v_schedule(config: ScheduleConfig):
    """
    Build the DualPipe-V schedule (an 8-step, per-device operation order).

    Every device hosts two model chunks arranged in a "V":
      * chunk 0 -> stage ``device_id``
      * chunk 1 -> stage ``num_stages - 1 - device_id``

    For each device the operations are queued in eight steps:
      1. nF0        warm-up forwards of chunk 0
      2. nF0F1      alternating forwards of chunk 0 and chunk 1
      3. nB1W1F1    zero-bubble backward_D of chunk 1, its backward_W, then F1
      4. nF0B1F1B0  main loop of overlapped forward + full backward pairs
      5. nB1F1B0    full backward of chunk 1, then overlapped F1 + B0
      6. nB1B0      backwards; from the midpoint on they are split (zero bubble)
      7. nWB0       deferred backward_W followed by zero-bubble backward_D of chunk 0
      8. nW         drain the remaining deferred backward_W operations

    NOTE(review): this function does not check ``config.placement_strategy``;
    the stage mapping above is hard-coded here, so the caller presumably has to
    configure a matching placement -- confirm against the caller.

    Args:
        config: The scheduling configuration. Must satisfy
            ``num_stages == 2 * num_devices``, ``split_backward`` truthy and
            ``num_batches >= 2 * num_devices``.

    Returns:
        A Schedule object whose device queues hold the DualPipe-V ordering.

    Raises:
        AssertionError: If the configuration violates one of the requirements
            above, or if deferred backward_W work is left over on a device.
    """
    num_stages = config.num_stages
    num_devices = config.num_devices
    num_batches = config.num_batches

    # Validate before building anything, mirroring generate_dualpipe_schedule.
    assert config.num_stages == config.num_devices * 2, "num_stages must be equal to num_devices * 2 for DualPipe-V"
    assert config.split_backward, "DualPipe-V requires split_backward=True"
    # Steps 1-2 alone queue (2 * num_devices - device_id - 1) forwards of chunk 0.
    # With fewer micro-batches the step-4 count below goes negative, the main
    # loop is skipped, and micro-batch ids beyond num_batches would be emitted.
    assert (
        num_batches >= num_devices * 2
    ), "DualPipe-V requires config.num_batches >= config.num_devices * 2"

    schedule = Schedule(config, init_ops=False)

    # Next micro-batch id per (device_id, chunk_id); defaultdict(int) starts at 0.
    fwd_batch_ids = defaultdict(int)
    bwd_d_batch_ids = defaultdict(int)

    # Per device: FIFO of (stage_id, batch_id) whose backward_W is still pending.
    waited_weight_grad = [deque() for _ in range(num_devices)]

    def add_op_to_queue(device_id, stage_id, op_type, batch_id):
        """Create an operation, register it with the schedule and queue it."""
        op = Operation(batch_id, stage_id, op_type)
        schedule.register_operation(op)
        schedule.device_queues[device_id].add_operation(op)

    def get_stage_for_chunk(device_id, chunk_id):
        """Map a device-local chunk (0 or 1) to its global stage id."""
        if chunk_id == 0:
            # Descending arm of the V: stages 0 .. num_devices - 1.
            return device_id
        # Ascending arm of the V: stages num_stages - 1 .. num_devices.
        return num_stages - 1 - device_id

    def _schedule_forward_chunk(device_id, chunk_id):
        """Queue the next forward operation of the given chunk."""
        stage_id = get_stage_for_chunk(device_id, chunk_id)
        batch_id = fwd_batch_ids[device_id, chunk_id]
        add_op_to_queue(device_id, stage_id, "forward", batch_id)
        fwd_batch_ids[device_id, chunk_id] += 1

    def _schedule_backward_chunk(device_id, chunk_id, enable_zb=False):
        """Queue the next backward operation of the given chunk.

        With ``enable_zb`` only "backward_D" is queued and the matching
        backward_W is deferred via ``waited_weight_grad``; otherwise a single
        full "backward" operation is queued.
        """
        stage_id = get_stage_for_chunk(device_id, chunk_id)
        batch_id = bwd_d_batch_ids[device_id, chunk_id]
        if enable_zb:
            add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
            waited_weight_grad[device_id].append((stage_id, batch_id))
        else:
            add_op_to_queue(device_id, stage_id, "backward", batch_id)
        bwd_d_batch_ids[device_id, chunk_id] += 1

    def _schedule_backward_weight_chunk(device_id):
        """Queue the oldest deferred backward_W operation of the device."""
        assert waited_weight_grad[device_id], f"Device {device_id} has no waited weight grads to schedule"
        stage_id, batch_id = waited_weight_grad[device_id].popleft()
        add_op_to_queue(device_id, stage_id, "backward_W", batch_id)

    def _schedule_forward_backward_chunk(device_id, fwd_chunk_id, bwd_chunk_id):
        """Queue one overlapped operation: a forward plus a full backward."""
        fwd_stage_id = get_stage_for_chunk(device_id, fwd_chunk_id)
        bwd_stage_id = get_stage_for_chunk(device_id, bwd_chunk_id)

        fwd_batch_id = fwd_batch_ids[device_id, fwd_chunk_id]
        fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
        schedule.register_operation(fwd_op)
        fwd_batch_ids[device_id, fwd_chunk_id] += 1

        # The overlapped backward is a full "backward" (not split into D / W).
        bwd_batch_id = bwd_d_batch_ids[device_id, bwd_chunk_id]
        bwd_op = Operation(bwd_batch_id, bwd_stage_id, "backward")
        schedule.register_operation(bwd_op)
        bwd_d_batch_ids[device_id, bwd_chunk_id] += 1

        # Only the combined operation goes into the device queue.
        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
        schedule.register_overlapped_operation(overlapped_op)
        schedule.device_queues[device_id].add_operation(overlapped_op)

    # Process each device (a "rank" in the reference implementation).
    for device_id in range(num_devices):
        is_last_rank = device_id == num_devices - 1

        # Step 1: nF0
        step_1_count = (num_devices - device_id - 1) * 2
        for _ in range(step_1_count):
            _schedule_forward_chunk(device_id, 0)  # F0

        # Step 2: nF0F1
        step_2_count = device_id + 1
        for _ in range(step_2_count):
            _schedule_forward_chunk(device_id, 0)  # F0
            _schedule_forward_chunk(device_id, 1)  # F1

        # Step 3: nB1W1F1 (zero bubble: B1 is split into B1_D and W1)
        step_3_count = num_devices - device_id - 1
        for _ in range(step_3_count):
            _schedule_backward_chunk(device_id, 1, enable_zb=True)  # B1_D
            _schedule_backward_weight_chunk(device_id)  # W1
            _schedule_forward_chunk(device_id, 1)  # F1

        # Step 4 (main step): nF0B1F1B0
        step_4_count = num_batches - num_devices * 2 + device_id + 1
        for i in range(step_4_count):
            if i == 0 and is_last_rank:
                # First iteration on the last rank: F0 and a full B1 are
                # queued back to back instead of as one overlapped operation.
                _schedule_forward_chunk(device_id, 0)  # F0
                _schedule_backward_chunk(device_id, 1, enable_zb=False)  # B1
            else:
                _schedule_forward_backward_chunk(device_id, 0, 1)  # F0 + B1
            _schedule_forward_backward_chunk(device_id, 1, 0)  # F1 + B0

        # Step 5: nB1F1B0
        step_5_count = num_devices - device_id - 1
        for _ in range(step_5_count):
            _schedule_backward_chunk(device_id, 1, enable_zb=False)  # B1 (full)
            _schedule_forward_backward_chunk(device_id, 1, 0)  # F1 + B0

        # Step 6: nB1B0 (the second half of the chunks use zero bubble)
        step_6_count = device_id + 1
        enable_zb = False
        for i in range(step_6_count):
            # Odd ranks switch to zero bubble starting with B1 of the midpoint
            # iteration, even ranks starting with B0 of the midpoint iteration.
            # Once switched on, the flag stays on for the rest of the step.
            if i == step_6_count // 2 and device_id % 2 == 1:
                enable_zb = True
            _schedule_backward_chunk(device_id, 1, enable_zb=enable_zb)  # B1
            if i == step_6_count // 2 and device_id % 2 == 0:
                enable_zb = True
            _schedule_backward_chunk(device_id, 0, enable_zb=enable_zb)  # B0

        # Step 7: nWB0 (zero bubble for B0)
        step_7_count = num_devices - device_id - 1
        for _ in range(step_7_count):
            _schedule_backward_weight_chunk(device_id)  # oldest deferred W
            _schedule_backward_chunk(device_id, 0, enable_zb=True)  # B0_D

        # Step 8: nW -- drain whatever backward_W work is still deferred.
        step_8_count = device_id + 1
        for _ in range(step_8_count):
            _schedule_backward_weight_chunk(device_id)

        # Every deferred weight gradient must have been scheduled by now.
        assert not waited_weight_grad[device_id], f"Device {device_id} has remaining waited weight grads: {waited_weight_grad[device_id]}"

    return schedule
|