Add support for DualPipe.

Files changed:
- .gitignore +1 -0
- README.md +21 -6
- assets/dualpipe.png +3 -0
- conf/config.yaml +3 -0
- main.py +23 -0
- src/execution_model.py +81 -19
- src/strategies.py +227 -2
- src/visualizer.py +2 -12
    	
.gitignore CHANGED

@@ -3,6 +3,7 @@
 uv.lock
 outputs/
 .cursor/*
+*.json
 
 # Uncomment below if you want to include these files
 # !assets/*.png
    	
README.md CHANGED

@@ -18,6 +18,7 @@ Pipeline parallelism is a technique used to train large models by partitioning t
   - Zero-Bubble 1F1B (ZB-1P)
   - 1F1B with computation-communication overlap
   - Interleaved 1F1B with computation-communication overlap
+  - DualPipe (bidirectional pipeline parallelism with full forward-backward overlap)
 
 - **Visualization**:
   - Interactive visualization dashboard using Plotly/Dash
@@ -56,6 +57,12 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
 ```
 
+### Running for DualPipe strategy:
+```bash
+uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=20
+```
+
 ### Running for 1F1B-batch-overlap strategy:
 ```bash
 uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
@@ -68,10 +75,24 @@ uv run python main.py strategy=1f1b_interleave_overlap num_devices=4 num_stages=
 ```
 
 ## Configuration
 
 The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
 
+#### Override Specific Parameters
+
+You can override specific parameters at runtime:
+```bash
+uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
+```
+
+Using DualPipe as an example, you can manually set different times for forward/backward/backward_D/backward_W/overlapped_forward_backward:
+```bash
+uv run python main.py strategy=dualpipe num_devices=8 num_stages=8 num_batches=32 op_times.forward=1.0 op_times.backward=2.0 op_times.backward_D=1.0 op_times.backward_W=1.0 op_times.overlapped_forward_backward=2.5
+```
+
 ### Using Different Configuration Files
 
 You can use different configuration files with Hydra in several ways:
@@ -90,12 +111,6 @@ You can use different configuration files with Hydra in several ways:
    uv run python main.py --config-name=model_A
    ```
 
-#### Override Specific Parameters
-
-You can also override specific parameters at runtime:
-```bash
-uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
-```
 
 ## Project Structure
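A quick arithmetic aside (not from the repository): in the DualPipe override example above, the fused block's 2.5 time units replace a sequential forward-plus-backward pair of 1.0 + 2.0, which is where the overlap saves time. A one-line check in Python:

```python
# Numbers taken from the README example above; the 0.5 difference is the
# per-pair saving from running forward and backward as one fused block.
forward, backward, overlapped = 1.0, 2.0, 2.5
print(forward + backward - overlapped)  # 0.5
```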
    	
assets/dualpipe.png ADDED (binary image, stored via Git LFS)
    	
conf/config.yaml CHANGED

@@ -11,6 +11,9 @@ op_times:
   # Option 1: Simple configuration (same time for all stages)
   forward: 1.0
   backward: 2.0
+  backward_D: 1.0
+  backward_W: 1.0
+  overlapped_forward_backward: 2.0
 
   # Option 2: Commented example of stage-specific configuration
   # forward:
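For orientation, a minimal sketch (not part of the commit) of the mapping this YAML resolves to under the simple, stage-independent configuration; the values are the abstract time units the simulator assigns to each operation type.

```python
# Sketch of the resolved op_times mapping (simple, stage-independent config).
# Keys mirror conf/config.yaml after this change.
op_times = {
    "forward": 1.0,
    "backward": 2.0,                     # full (non-split) backward
    "backward_D": 1.0,                   # input-gradient half of a split backward
    "backward_W": 1.0,                   # weight-gradient half of a split backward
    "overlapped_forward_backward": 2.0,  # fused forward+backward block used by DualPipe
}
print(op_times["overlapped_forward_backward"])  # 2.0
```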
    	
main.py CHANGED

@@ -5,6 +5,7 @@ from src.strategies import (
     generate_1f1b_overlap_schedule,
     generate_1f1b_schedule,
     generate_zero_bubble_1p_schedule,
+    generate_dualpipe_schedule,
 )
 from src.visualizer import visualize_pipeline_parallelism_dash
 import hydra
@@ -26,6 +27,8 @@ def main(cfg: DictConfig) -> None:
         run_1f1b_overlap(cfg)
     elif cfg.strategy == "1f1b_interleave_overlap":
         run_1f1b_interleave_overlap(cfg)
+    elif cfg.strategy == "dualpipe":
+        run_dualpipe(cfg)
     else:
         raise ValueError(f"Unknown strategy: {cfg.strategy}")
 
@@ -129,5 +132,25 @@ def run_1f1b_interleave_overlap(cfg: DictConfig) -> None:
     schedule.execute()
     visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
 
+def run_dualpipe(cfg: DictConfig) -> None:
+    """Run DualPipe pipeline parallelism simulation."""
+    # Convert OmegaConf to dict for op_times if it exists
+    op_times = (
+        OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
+    )
+
+    schedule_config = ScheduleConfig(
+        num_devices=cfg.num_devices,
+        num_stages=cfg.num_stages,
+        num_batches=cfg.num_batches,
+        p2p_latency=cfg.p2p_latency,
+        op_times=op_times,
+        split_backward=True,
+        placement_strategy="dualpipe",
+    )
+    schedule = generate_dualpipe_schedule(schedule_config)
+    schedule.execute()
+    visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
+
 if __name__ == "__main__":
     main()
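As a usage sketch (not part of the commit), the same DualPipe run can be driven programmatically without Hydra; the constructor arguments mirror run_dualpipe above, the p2p_latency value is an assumed placeholder, and the Dash visualization step is omitted.

```python
# Hypothetical programmatic equivalent of `main.py strategy=dualpipe`,
# mirroring run_dualpipe above but without Hydra or the Dash visualizer.
from src.execution_model import ScheduleConfig
from src.strategies import generate_dualpipe_schedule

config = ScheduleConfig(
    num_devices=8,
    num_stages=8,
    num_batches=20,
    p2p_latency=0.0,             # assumed value; the project default lives in conf/config.yaml
    op_times=None,               # fall back to the built-in split-backward timings
    split_backward=True,         # DualPipe schedules backward_D / backward_W separately
    placement_strategy="dualpipe",
)
schedule = generate_dualpipe_schedule(config)
schedule.execute()               # resolve dependencies and assign start/end times
```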
    	
src/execution_model.py CHANGED

@@ -69,7 +69,7 @@ class DeviceQueue:
     def add_operation(self, op: Operation):
         assert op.stage_id in self.stages
         self.ops.append(op)
-        assert op.device_id is None
+        assert op.device_id is None, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already has a device id on {op.device_id}"
         op.device_id = self.device_id
 
 
@@ -97,6 +97,7 @@ class ScheduleConfig:
                 "forward": 1.0,
                 "backward_D": 1.0,
                 "backward_W": 1.0,
+                "backward": 2.0,
             }
         else:
             self.op_times = {
@@ -128,9 +129,14 @@
         self.num_stages_per_device = num_stages // num_devices
 
         self.init_device_to_stages()
-
-
-
+        if self.placement_strategy == "dualpipe":
+            assert (
+                sum(len(stages) for stages in self.device_to_stages.values()) == num_stages * 2
+            )
+        else:
+            assert (
+                sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
+            )
 
     def init_device_to_stages(self):
         if self.placement_strategy == "standard":
@@ -145,14 +151,27 @@
             for i in range(self.num_stages):
                 device_to_put = i % self.num_devices
                 self.device_to_stages[device_to_put].append(i)
+        elif self.placement_strategy == "dualpipe":
+            # For DualPipe, each device has two stages
+            assert self.num_devices == self.num_stages, "DualPipe requires num_devices == num_stages"
+            assert self.num_devices % 2 == 0, "DualPipe requires an even number of devices"
+            self.device_to_stages = defaultdict(list)
+            for i in range(self.num_stages):
+                self.device_to_stages[i] = [i, self.num_stages - i - 1]
         else:
             raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")
 
     def get_op_time(self, op_type: str, stage_id: int):
         # For overlapped operations, extract the original operation types
         if op_type.startswith("overlapped_"):
-            if op_type in self.op_times
-
+            if op_type in self.op_times:
+                if isinstance(self.op_times[op_type], dict):
+                    if stage_id in self.op_times[op_type]:
+                        return self.op_times[op_type][stage_id]
+                    else:
+                        raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
+                else:
+                    return self.op_times[op_type]
             else:
                 op_parts = op_type.split("_")[1:]
                 if len(op_parts) >= 2:
@@ -173,20 +192,25 @@
 
 
 class Schedule:
-    def __init__(self, config: ScheduleConfig):
+    def __init__(self, config: ScheduleConfig, init_ops: bool = True):
         self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
         self.device_queues: List[DeviceQueue] = []
         for dev_id in range(config.num_devices):
             self.device_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
         self.config = config
 
-        self.init_operations()
+        if init_ops:
+            self.init_operations()
         self.op_to_overlapped = {}
 
     def register_overlapped_operation(self, overlapped_op: OverlappedOperation):
         for op in overlapped_op.operations:
             self.op_to_overlapped[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
             self.ops[(op.batch_id, op.stage_id, op.op_type)] = overlapped_op
+
+    def register_operation(self, op: Operation):
+        assert (op.batch_id, op.stage_id, op.op_type) not in self.ops, f"Operation {op.batch_id}, {op.stage_id}, {op.op_type} already registered"
+        self.ops[(op.batch_id, op.stage_id, op.op_type)] = op
 
     def init_operations(self):
         op_types = ["forward", "backward"]
@@ -199,9 +223,12 @@
                     batch_id, stage_id, op_type
                 )
 
-    def get_op(self, batch_id: int, stage_id: int, op_type: str):
+    def get_op(self, batch_id: int, stage_id: int, op_type: str, allow_none=False):
         if (batch_id, stage_id, op_type) in self.op_to_overlapped:
             return self.op_to_overlapped[(batch_id, stage_id, op_type)]
+        if allow_none:
+            if (batch_id, stage_id, op_type) not in self.ops:
+                return None
         return self.ops[(batch_id, stage_id, op_type)]
 
     def get_dependencies(self, op: Operation, include_device_dependency=True):
@@ -226,20 +253,55 @@
         if self.config.split_backward:
             if op.op_type == "backward_D":
                 if op.stage_id < self.config.num_stages - 1:
-
-
-
-
-
+                    op_bwd_d = self.get_op(op.batch_id, op.stage_id + 1, "backward_D", allow_none=True)
+                    if op_bwd_d is not None:
+                        deps.append(
+                            (
+                                op_bwd_d,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id + 1, "backward"),
+                                self.config.p2p_latency,
+                            )
+                        )
             elif op.op_type == "backward_W":
                 if op.stage_id < self.config.num_stages - 1:
-
-
-
-
-
+                    op_bwd_d = self.get_op(op.batch_id, op.stage_id, "backward_D", allow_none=True)
+                    if op_bwd_d is not None:
+                        deps.append(
+                            (
+                                op_bwd_d,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id, "backward"),
+                                self.config.p2p_latency,
+                            )
+                        )
+            elif op.op_type == "backward":
+                if op.stage_id < self.config.num_stages - 1:
+                    op_bwd = self.get_op(op.batch_id, op.stage_id + 1, "backward", allow_none=True)
+                    if op_bwd is not None:
+                        deps.append(
+                            (
+                                op_bwd,
+                                self.config.p2p_latency,
+                            )
+                        )
+                    else:
+                        deps.append(
+                            (
+                                self.get_op(op.batch_id, op.stage_id + 1, "backward_D"),
+                                self.config.p2p_latency,
+                            )
+                        )
         else:
             if op.op_type == "backward":
                 if op.stage_id < self.config.num_stages - 1:
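To make the bidirectional placement concrete, here is a small illustrative sketch (not part of the commit) of the device-to-stage mapping the dualpipe branch above produces for 8 stages on 8 devices: each device owns one stage from each direction of the pipeline, so the total number of stage assignments is num_stages * 2, matching the assertion added above.

```python
# Illustration of the "dualpipe" placement: device i hosts stage i (forward
# direction) and stage num_stages - 1 - i (reverse direction).
num_stages = 8
device_to_stages = {i: [i, num_stages - 1 - i] for i in range(num_stages)}
print(device_to_stages)
# {0: [0, 7], 1: [1, 6], 2: [2, 5], 3: [3, 4],
#  4: [4, 3], 5: [5, 2], 6: [6, 1], 7: [7, 0]}
assert sum(len(s) for s in device_to_stages.values()) == num_stages * 2
```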
    	
src/strategies.py CHANGED

@@ -1,5 +1,5 @@
-from collections import defaultdict
-from src.execution_model import OverlappedOperation, Schedule, ScheduleConfig
+from collections import defaultdict, deque
+from src.execution_model import OverlappedOperation, Operation, Schedule, ScheduleConfig
 
 
 def generate_1f1b_schedule(config: ScheduleConfig):
@@ -43,6 +43,7 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
     schedule = Schedule(config)
     total_batches = config.num_batches
     assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for ZB-1P"
+    assert config.split_backward, "ZB-1P requires split_backward=True"
 
     for i in range(config.num_devices):
         fwd_batch_id = 0
@@ -354,3 +355,227 @@ def generate_1f1b_interleave_overlap_schedule(config: ScheduleConfig):
 
 
     return schedule
+
+
+def create_overlapped_ops(schedule, batch_id1, batch_id2, stage_id, type1, type2):
+    """
+    Helper function to create overlapped operations correctly.
+    This handles the underlying operation creation and registration to avoid device_id issues.
+    """
+    # Get the operations from the schedule
+    op1 = schedule.ops[(batch_id1, stage_id, type1)]
+    op2 = schedule.ops[(batch_id2, stage_id, type2)]
+
+    # Create the overlapped operation
+    overlapped_op = OverlappedOperation([op1, op2])
+
+    # Register in the schedule to ensure proper tracking
+    schedule.register_overlapped_operation(overlapped_op)
+
+    return overlapped_op
+
+
+def generate_dualpipe_schedule(config: ScheduleConfig):
+    """
+    Implements the DualPipe scheduling strategy.
+
+    DualPipe is a bidirectional pipeline parallelism algorithm that achieves full overlap of forward
+    and backward computation-communication phases and reduces pipeline bubbles.
+
+    The DualPipe strategy has the following characteristics:
+    1. Requires placement_strategy="dualpipe" in ScheduleConfig (set automatically)
+    2. Each device handles both a forward stage and a reverse stage
+    3. Overlaps forward and backward operations to reduce bubble size
+    4. Assumes config.num_batches corresponds to half the total microbatches in the original paper (M).
+    5. Currently only supports split_backward=True.
+
+    Args:
+        config: The scheduling configuration
+
+    Returns:
+        A Schedule object with the DualPipe scheduling
+    """
+    # Ensure placement strategy is set for Schedule initialization
+    assert config.placement_strategy == "dualpipe", "DualPipe schedule currently only supports placement_strategy='dualpipe'"
+    # Assertions based on DualPipe requirements
+    assert config.num_stages % 2 == 0, "DualPipe requires an even number of stages (and devices)"
+    assert config.num_devices == config.num_stages, "DualPipe requires num_devices == num_stages"
+    assert config.num_batches % 2 == 0, "DualPipe requires an even number of microbatches (config.num_batches)"
+    # Assertion based on original implementation: num_chunks >= num_ranks * 2
+    # Here, M (config.num_batches) corresponds to half_num_chunks
+    assert config.num_batches >= config.num_devices, "DualPipe requires config.num_batches >= config.num_devices"
+    assert config.split_backward, "DualPipe schedule currently only supports split_backward=True"
+
+    schedule = Schedule(config, init_ops=False)
+
+    num_stages = config.num_stages
+    num_devices = config.num_devices
+    # config.num_batches is M in the original paper, which corresponds to half_num_chunks
+    half_num_chunks = config.num_batches // 2
+    num_half_ranks = num_devices // 2
+
+    fwd_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
+    bwd_d_batch_ids = defaultdict(int)  # (device_id, phase) -> batch_id
+
+    waited_weight_grad = [deque() for _ in range(num_devices)]  # device_id -> deque of (stage_id, batch_id)
+
+    for device_id in range(num_devices):
+        is_in_second_half = device_id >= num_half_ranks
+        if is_in_second_half:
+            fwd_batch_ids[device_id, 1] = 0
+            fwd_batch_ids[device_id, 0] = config.num_batches // 2
+            bwd_d_batch_ids[device_id, 1] = 0
+            bwd_d_batch_ids[device_id, 0] = config.num_batches // 2
+        else:
+            fwd_batch_ids[device_id, 0] = 0
+            fwd_batch_ids[device_id, 1] = config.num_batches // 2
+            bwd_d_batch_ids[device_id, 0] = 0
+            bwd_d_batch_ids[device_id, 1] = config.num_batches // 2
+
+    def get_stage_for_phase(device_id, phase, num_stages, is_in_second_half):
+        stage_fwd_dir = device_id  # Stage handled when moving forward (0 to N-1)
+        stage_rev_dir = num_stages - 1 - device_id  # Stage handled when moving backward (N-1 to 0)
+        if not is_in_second_half:
+            # First half: phase 0 -> fwd_dir, phase 1 -> rev_dir
+            return stage_fwd_dir if phase == 0 else stage_rev_dir
+        else:
+            # Second half: phase 0 -> rev_dir, phase 1 -> fwd_dir
+            return stage_rev_dir if phase == 0 else stage_fwd_dir
+
+    def add_op_to_queue(device_id, stage_id, op_type, batch_id):
+        # Create the Operation, register it with the schedule, then enqueue it on the device
+        op = Operation(batch_id, stage_id, op_type)
+        schedule.register_operation(op)
+        # Add to the device queue
+        schedule.device_queues[device_id].add_operation(op)
+
+    def _schedule_forward_chunk(device_id, phase, is_in_second_half):
+        """Schedules a forward compute operation."""
+        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
+        batch_id = fwd_batch_ids[device_id, phase]
+        add_op_to_queue(device_id, stage_id, "forward", batch_id)
+        fwd_batch_ids[device_id, phase] += 1
+
+    def _schedule_backward_chunk(device_id, phase, is_in_second_half):
+        """Schedules a full backward (backward_D plus backward_W) compute operation."""
+        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
+        batch_id = bwd_d_batch_ids[device_id, phase]
+        add_op_to_queue(device_id, stage_id, "backward", batch_id)
+        bwd_d_batch_ids[device_id, phase] += 1
+
+    def _schedule_backward_input_chunk(device_id, phase, is_in_second_half):
+        """Schedules a backward_D compute operation."""
+        stage_id = get_stage_for_phase(device_id, phase, num_stages, is_in_second_half)
+        batch_id = bwd_d_batch_ids[device_id, phase]
+        add_op_to_queue(device_id, stage_id, "backward_D", batch_id)
+        bwd_d_batch_ids[device_id, phase] += 1
+        waited_weight_grad[device_id].append((stage_id, batch_id))
+
+    def _schedule_backward_weight_chunk(device_id):
+        """Schedules a backward_W compute operation."""
+        stage_id, batch_id = waited_weight_grad[device_id].popleft()
+        add_op_to_queue(device_id, stage_id, "backward_W", batch_id)
+
+    def _schedule_forward_backward_chunk(device_id, fwd_phase, bwd_phase, is_in_second_half):
+        """Schedules an overlapped forward and backward compute operation."""
+        fwd_stage_id = get_stage_for_phase(device_id, fwd_phase, num_stages, is_in_second_half)
+        bwd_stage_id = get_stage_for_phase(device_id, bwd_phase, num_stages, is_in_second_half)
+
+        fwd_batch_id = fwd_batch_ids[device_id, fwd_phase]
+
+        fwd_op = Operation(fwd_batch_id, fwd_stage_id, "forward")
+        schedule.register_operation(fwd_op)
+        fwd_batch_ids[device_id, fwd_phase] += 1
+
+        bwd_batch_id_d = bwd_d_batch_ids[device_id, bwd_phase]
+        bwd_op = Operation(bwd_batch_id_d, bwd_stage_id, "backward")
+        schedule.register_operation(bwd_op)
+        bwd_d_batch_ids[device_id, bwd_phase] += 1
+
+        # Create and register the overlapped operation
+        overlapped_op = OverlappedOperation([fwd_op, bwd_op])
+        schedule.register_overlapped_operation(overlapped_op)
+
+        # Add the overlapped operation to the queue
+        schedule.device_queues[device_id].add_operation(overlapped_op)
+
+    # Process each device (rank in original code)
+    for device_id in range(num_devices):
+        half_rank = min(device_id, num_devices - 1 - device_id)
+        is_in_second_half = device_id >= num_half_ranks
+        is_middle_rank = (device_id == num_half_ranks - 1) or (device_id == num_half_ranks)
+
+        # Map original steps to operation additions
+        # Step 1: nF0
+        step_1_count = (num_half_ranks - half_rank - 1) * 2
+        for _ in range(step_1_count):
+            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
+
+        # Step 2: nF0F1
+        step_2_count = half_rank + 1
+        for i in range(step_2_count):
+            _schedule_forward_chunk(device_id, 0, is_in_second_half)  # F0
+            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1
+
+        # Step 3: nB1W1F1
+        step_3_count = num_half_ranks - half_rank - 1
+        for _ in range(step_3_count):
+            _schedule_backward_input_chunk(device_id, 1, is_in_second_half)  # B1_D
+            _schedule_backward_weight_chunk(device_id)  # W1
+            _schedule_forward_chunk(device_id, 1, is_in_second_half)  # F1
+
+        # Step 4 (Main step): nF0B1F1B0
+        step_4_count = half_num_chunks - num_devices + half_rank + 1
+        for i in range(step_4_count):
+            # if i == 0 and is_middle_rank:
+                # Schedule F0, B1_D, W1 sequentially for middle ranks on first iteration
+                # _schedule_forward_chunk(device_id, 0, is_in_second_half) # F0
+                # _schedule_backward_chunk(device_id, 1, is_in_second_half)# B1
+                # _schedule_backward_weight_chunk(device_id, 1, is_in_second_half)  # W1
+            # else:
+            # Overlap F0 and B1
+            _schedule_forward_backward_chunk(device_id, 0, 1, is_in_second_half)  # F0+B1
+
+            # Overlap F1 and B0
+            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1+B0
+
+        # Step 5: nB1F1B0
+        step_5_count = num_half_ranks - half_rank - 1
+        for _ in range(step_5_count):
+            _schedule_backward_chunk(device_id, 1, is_in_second_half)  # B1_D + B1_W
+            _schedule_forward_backward_chunk(device_id, 1, 0, is_in_second_half)  # F1+B0
+
+        # Step 6: nB1B0
+        step_6_count = half_rank + 1
+        enable_zb = False
+        for i in range(step_6_count):
+            if i == step_6_count // 2 and half_rank % 2 == 1:
+                enable_zb = True
+            if enable_zb:
+                _schedule_backward_input_chunk(device_id, 1, is_in_second_half)
+            else:
+                _schedule_backward_chunk(device_id, 1, is_in_second_half)
+            if i == step_6_count // 2 and half_rank % 2 == 0:
+                enable_zb = True
+            if enable_zb:
+                _schedule_backward_input_chunk(device_id, 0, is_in_second_half)
+            else:
+                _schedule_backward_chunk(device_id, 0, is_in_second_half)
+
+        # Step 7: nWB0
+        step_7_count = num_half_ranks - half_rank - 1
+        for _ in range(step_7_count):
+            _schedule_backward_weight_chunk(device_id)  # W1 (use gradient from B1_D scheduled previously)
+            _schedule_backward_input_chunk(device_id, 0, is_in_second_half)  # B0_D
+
+        # Step 8: nW
+        step_8_count = half_rank + 1
+        for _ in range(step_8_count):
+            # W0 uses gradients from B0_D scheduled in steps 4, 5, 6.
+            # W1 uses gradients from B1_D scheduled in steps 3, 4, 5, 6.
+            # The last W0 gradients correspond to B0_D from step 6 or 7.
+            _schedule_backward_weight_chunk(device_id)  # W0 (use gradient from B0_D scheduled previously)
+
+    return schedule
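As a rough sanity check on the schedule shape (not part of the commit), the per-rank counts for steps 1-8 can be reproduced with the same formulas as above; the numbers below assume 8 devices and 20 microbatches, the README's DualPipe example.

```python
# Per-rank step counts for the DualPipe phases, computed with the same
# formulas as generate_dualpipe_schedule. Assumes num_devices = 8 and
# num_batches = 20 (so half_num_chunks = 10), as in the README example.
num_devices = 8
half_num_chunks = 20 // 2
num_half_ranks = num_devices // 2

for device_id in range(num_devices):
    half_rank = min(device_id, num_devices - 1 - device_id)
    counts = {
        "1_nF0": (num_half_ranks - half_rank - 1) * 2,
        "2_nF0F1": half_rank + 1,
        "3_nB1W1F1": num_half_ranks - half_rank - 1,
        "4_main": half_num_chunks - num_devices + half_rank + 1,
        "5_nB1F1B0": num_half_ranks - half_rank - 1,
        "6_nB1B0": half_rank + 1,
        "7_nWB0": num_half_ranks - half_rank - 1,
        "8_nW": half_rank + 1,
    }
    print(device_id, counts)
```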
    	
src/visualizer.py CHANGED

@@ -89,11 +89,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
 
     # Improved teal/turquoise palette with low saturation and high brightness
     backward_d_colors = [
-        "#ccffff",  # Very light cyan
-        "#b3ffff",  # Pale cyan
-        "#99ffff",  # Light cyan
-        "#80ffff",  # Cyan
-        "#66e6e6",  # Soft teal
         "#4dcccc",  # Light teal
         "#33b3b3",  # Teal
         "#009999",  # Medium teal
@@ -102,12 +97,6 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
 
     # Improved green palette with low saturation and high brightness
     backward_w_colors = [
-        "#ccffe6",  # Very light mint
-        "#b3ffd9",  # Pale mint
-        "#99ffcc",  # Light mint
-        "#80ffbf",  # Mint green
-        "#66e6a6",  # Soft green
-        "#4dcc8c",  # Light green
         "#33b373",  # Medium green
        "#009959",  # Forest green
         "#008040",  # Dark green
@@ -162,7 +151,8 @@ def create_pipeline_figure(
             max_batch = max(max_batch, task["batch"])
 
     # Flag to determine whether to show text labels
-
+    num_operations_per_device = len(schedule_data[0])
+    show_text_labels = num_operations_per_device <= 64
 
     # Create a figure
     fig = go.Figure()