Spaces:
Running
Running
Add 1F1B-overlap implementation.
Browse files- .gitignore +1 -0
- README.md +7 -0
- assets/1f1b.png +2 -2
- assets/1f1b_overlap.png +3 -0
- main.py +43 -10
- src/strategies.py +36 -0
- src/visualizer.py +225 -165
.gitignore
CHANGED
|
@@ -2,6 +2,7 @@
|
|
| 2 |
./venv
|
| 3 |
uv.lock
|
| 4 |
outputs/
|
|
|
|
| 5 |
|
| 6 |
# Uncomment below if you want to include these files
|
| 7 |
# !assets/*.png
|
|
|
|
| 2 |
./venv
|
| 3 |
uv.lock
|
| 4 |
outputs/
|
| 5 |
+
.cursor/*
|
| 6 |
|
| 7 |
# Uncomment below if you want to include these files
|
| 8 |
# !assets/*.png
|
README.md
CHANGED
|
@@ -50,6 +50,13 @@ uv run python main.py strategy=zb1p num_devices=4 num_stages=4 num_batches=8
|
|
| 50 |
```
|
| 51 |

|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
## Configuration
|
| 54 |
|
| 55 |
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
|
|
|
| 50 |
```
|
| 51 |

|
| 52 |
|
| 53 |
+
|
| 54 |
+
Running for 1F1B-batch-overlap strategy:
|
| 55 |
+
```bash
|
| 56 |
+
uv run python main.py strategy=1f1b_overlap num_devices=4 num_stages=4 num_batches=8
|
| 57 |
+
```
|
| 58 |
+

|
| 59 |
+
|
| 60 |
## Configuration
|
| 61 |
|
| 62 |
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
assets/1f1b.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
assets/1f1b_overlap.png
ADDED
|
Git LFS Details
|
main.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
| 1 |
from src.execution_model import ScheduleConfig
|
| 2 |
-
from src.strategies import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from src.visualizer import visualize_pipeline_parallelism_dash
|
| 4 |
import hydra
|
| 5 |
from omegaconf import DictConfig, OmegaConf
|
|
@@ -16,6 +21,8 @@ def main(cfg: DictConfig) -> None:
|
|
| 16 |
run_interleave(cfg)
|
| 17 |
elif cfg.strategy == "zb1p":
|
| 18 |
run_zero_bubble_1p(cfg)
|
|
|
|
|
|
|
| 19 |
else:
|
| 20 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
| 21 |
|
|
@@ -23,7 +30,9 @@ def main(cfg: DictConfig) -> None:
|
|
| 23 |
def run_1f1b(cfg: DictConfig) -> None:
|
| 24 |
"""Run 1F1B pipeline parallelism simulation."""
|
| 25 |
# Convert OmegaConf to dict for op_times if it exists
|
| 26 |
-
op_times =
|
|
|
|
|
|
|
| 27 |
|
| 28 |
schedule_config = ScheduleConfig(
|
| 29 |
num_devices=cfg.num_devices,
|
|
@@ -31,7 +40,7 @@ def run_1f1b(cfg: DictConfig) -> None:
|
|
| 31 |
num_batches=cfg.num_batches,
|
| 32 |
p2p_latency=cfg.p2p_latency,
|
| 33 |
op_times=op_times,
|
| 34 |
-
placement_strategy="standard"
|
| 35 |
)
|
| 36 |
schedule = generate_1f1b_schedule(schedule_config)
|
| 37 |
schedule.execute()
|
|
@@ -42,15 +51,17 @@ def run_1f1b(cfg: DictConfig) -> None:
|
|
| 42 |
def run_interleave(cfg: DictConfig) -> None:
|
| 43 |
"""Run interleaved pipeline parallelism simulation."""
|
| 44 |
# Convert OmegaConf to dict for op_times if it exists
|
| 45 |
-
op_times =
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
schedule_config = ScheduleConfig(
|
| 48 |
num_devices=cfg.num_devices,
|
| 49 |
num_stages=cfg.num_stages,
|
| 50 |
num_batches=cfg.num_batches,
|
| 51 |
p2p_latency=cfg.p2p_latency,
|
| 52 |
placement_strategy="interleave",
|
| 53 |
-
op_times=op_times
|
| 54 |
)
|
| 55 |
schedule = generate_1f1b_interleave_schedule(schedule_config)
|
| 56 |
schedule.execute()
|
|
@@ -60,20 +71,42 @@ def run_interleave(cfg: DictConfig) -> None:
|
|
| 60 |
def run_zero_bubble_1p(cfg: DictConfig) -> None:
|
| 61 |
"""Run zero bubble 1P pipeline parallelism simulation."""
|
| 62 |
# Convert OmegaConf to dict for op_times if it exists
|
| 63 |
-
op_times =
|
| 64 |
-
|
|
|
|
|
|
|
| 65 |
schedule_config = ScheduleConfig(
|
| 66 |
num_devices=cfg.num_devices,
|
| 67 |
num_stages=cfg.num_stages,
|
| 68 |
num_batches=cfg.num_batches,
|
| 69 |
p2p_latency=cfg.p2p_latency,
|
| 70 |
op_times=op_times,
|
| 71 |
-
split_backward=True
|
| 72 |
)
|
| 73 |
schedule = generate_zero_bubble_1p_schedule(schedule_config)
|
| 74 |
schedule.execute()
|
| 75 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 76 |
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if __name__ == "__main__":
|
| 79 |
-
main()
|
|
|
|
| 1 |
from src.execution_model import ScheduleConfig
|
| 2 |
+
from src.strategies import (
|
| 3 |
+
generate_1f1b_interleave_schedule,
|
| 4 |
+
generate_1f1b_overlap_schedule,
|
| 5 |
+
generate_1f1b_schedule,
|
| 6 |
+
generate_zero_bubble_1p_schedule,
|
| 7 |
+
)
|
| 8 |
from src.visualizer import visualize_pipeline_parallelism_dash
|
| 9 |
import hydra
|
| 10 |
from omegaconf import DictConfig, OmegaConf
|
|
|
|
| 21 |
run_interleave(cfg)
|
| 22 |
elif cfg.strategy == "zb1p":
|
| 23 |
run_zero_bubble_1p(cfg)
|
| 24 |
+
elif cfg.strategy == "1f1b_overlap":
|
| 25 |
+
run_1f1b_overlap(cfg)
|
| 26 |
else:
|
| 27 |
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
| 28 |
|
|
|
|
| 30 |
def run_1f1b(cfg: DictConfig) -> None:
|
| 31 |
"""Run 1F1B pipeline parallelism simulation."""
|
| 32 |
# Convert OmegaConf to dict for op_times if it exists
|
| 33 |
+
op_times = (
|
| 34 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
| 35 |
+
)
|
| 36 |
|
| 37 |
schedule_config = ScheduleConfig(
|
| 38 |
num_devices=cfg.num_devices,
|
|
|
|
| 40 |
num_batches=cfg.num_batches,
|
| 41 |
p2p_latency=cfg.p2p_latency,
|
| 42 |
op_times=op_times,
|
| 43 |
+
placement_strategy="standard",
|
| 44 |
)
|
| 45 |
schedule = generate_1f1b_schedule(schedule_config)
|
| 46 |
schedule.execute()
|
|
|
|
| 51 |
def run_interleave(cfg: DictConfig) -> None:
|
| 52 |
"""Run interleaved pipeline parallelism simulation."""
|
| 53 |
# Convert OmegaConf to dict for op_times if it exists
|
| 54 |
+
op_times = (
|
| 55 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
schedule_config = ScheduleConfig(
|
| 59 |
num_devices=cfg.num_devices,
|
| 60 |
num_stages=cfg.num_stages,
|
| 61 |
num_batches=cfg.num_batches,
|
| 62 |
p2p_latency=cfg.p2p_latency,
|
| 63 |
placement_strategy="interleave",
|
| 64 |
+
op_times=op_times,
|
| 65 |
)
|
| 66 |
schedule = generate_1f1b_interleave_schedule(schedule_config)
|
| 67 |
schedule.execute()
|
|
|
|
| 71 |
def run_zero_bubble_1p(cfg: DictConfig) -> None:
|
| 72 |
"""Run zero bubble 1P pipeline parallelism simulation."""
|
| 73 |
# Convert OmegaConf to dict for op_times if it exists
|
| 74 |
+
op_times = (
|
| 75 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
schedule_config = ScheduleConfig(
|
| 79 |
num_devices=cfg.num_devices,
|
| 80 |
num_stages=cfg.num_stages,
|
| 81 |
num_batches=cfg.num_batches,
|
| 82 |
p2p_latency=cfg.p2p_latency,
|
| 83 |
op_times=op_times,
|
| 84 |
+
split_backward=True,
|
| 85 |
)
|
| 86 |
schedule = generate_zero_bubble_1p_schedule(schedule_config)
|
| 87 |
schedule.execute()
|
| 88 |
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 89 |
|
| 90 |
|
| 91 |
+
def run_1f1b_overlap(cfg: DictConfig) -> None:
|
| 92 |
+
"""Run 1F1B overlap pipeline parallelism simulation."""
|
| 93 |
+
# Convert OmegaConf to dict for op_times if it exists
|
| 94 |
+
op_times = (
|
| 95 |
+
OmegaConf.to_container(cfg.op_times) if hasattr(cfg, "op_times") else None
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
schedule_config = ScheduleConfig(
|
| 99 |
+
num_devices=cfg.num_devices,
|
| 100 |
+
num_stages=cfg.num_stages,
|
| 101 |
+
num_batches=cfg.num_batches,
|
| 102 |
+
p2p_latency=cfg.p2p_latency,
|
| 103 |
+
op_times=op_times,
|
| 104 |
+
split_backward=False,
|
| 105 |
+
)
|
| 106 |
+
schedule = generate_1f1b_overlap_schedule(schedule_config)
|
| 107 |
+
schedule.execute()
|
| 108 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
if __name__ == "__main__":
|
| 112 |
+
main()
|
src/strategies.py
CHANGED
|
@@ -94,6 +94,42 @@ def generate_zero_bubble_1p_schedule(config: ScheduleConfig):
|
|
| 94 |
return schedule
|
| 95 |
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
# Some code is copied from Megatron-LM
|
| 98 |
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
| 99 |
schedule = Schedule(config)
|
|
|
|
| 94 |
return schedule
|
| 95 |
|
| 96 |
|
| 97 |
+
def generate_1f1b_overlap_schedule(config: ScheduleConfig):
|
| 98 |
+
schedule = Schedule(config)
|
| 99 |
+
|
| 100 |
+
assert config.num_devices == config.num_stages, "num_devices must be equal to num_stages for 1F1B"
|
| 101 |
+
|
| 102 |
+
for i in range(config.num_devices):
|
| 103 |
+
fwd_batch_id = 0
|
| 104 |
+
bwd_batch_id = 0
|
| 105 |
+
cooldown_batches = warmup_batches = 2 * (config.num_devices - i - 1) + 1
|
| 106 |
+
steady_batches = config.num_batches - warmup_batches
|
| 107 |
+
|
| 108 |
+
for _ in range(warmup_batches):
|
| 109 |
+
schedule.dev_queues[i].add_operation(
|
| 110 |
+
schedule.get_op(fwd_batch_id, i, "forward")
|
| 111 |
+
)
|
| 112 |
+
fwd_batch_id += 1
|
| 113 |
+
|
| 114 |
+
for _ in range(steady_batches):
|
| 115 |
+
schedule.dev_queues[i].add_operation(
|
| 116 |
+
schedule.get_op(fwd_batch_id, i, "forward")
|
| 117 |
+
)
|
| 118 |
+
fwd_batch_id += 1
|
| 119 |
+
schedule.dev_queues[i].add_operation(
|
| 120 |
+
schedule.get_op(bwd_batch_id, i, "backward")
|
| 121 |
+
)
|
| 122 |
+
bwd_batch_id += 1
|
| 123 |
+
|
| 124 |
+
for _ in range(cooldown_batches):
|
| 125 |
+
schedule.dev_queues[i].add_operation(
|
| 126 |
+
schedule.get_op(bwd_batch_id, i, "backward")
|
| 127 |
+
)
|
| 128 |
+
bwd_batch_id += 1
|
| 129 |
+
|
| 130 |
+
return schedule
|
| 131 |
+
|
| 132 |
+
|
| 133 |
# Some code is copied from Megatron-LM
|
| 134 |
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
|
| 135 |
schedule = Schedule(config)
|
src/visualizer.py
CHANGED
|
@@ -12,30 +12,34 @@ from src.execution_model import Schedule
|
|
| 12 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
| 13 |
"""
|
| 14 |
Converts a Schedule object to the format needed for visualization.
|
| 15 |
-
|
| 16 |
Returns:
|
| 17 |
Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
|
| 18 |
"""
|
| 19 |
# Make sure all operations have start and end times
|
| 20 |
for op in schedule.ops.values():
|
| 21 |
if op.start_time is None or op.end_time is None:
|
| 22 |
-
raise ValueError(
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
visualization_data = {}
|
| 25 |
-
|
| 26 |
# Organize operations by device
|
| 27 |
for device_id, device_queue in enumerate(schedule.dev_queues):
|
| 28 |
visualization_data[device_id] = []
|
| 29 |
-
|
| 30 |
for op in device_queue.ops:
|
| 31 |
-
visualization_data[device_id].append(
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
return visualization_data
|
| 40 |
|
| 41 |
|
|
@@ -44,58 +48,58 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
|
|
| 44 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
| 45 |
# A more harmonious blue palette with better progression for forward operations
|
| 46 |
forward_colors = [
|
| 47 |
-
"#5c88f2",
|
| 48 |
-
"#1a53ff",
|
| 49 |
-
"#b3c6ff",
|
| 50 |
-
"#4d79ff",
|
| 51 |
-
"#809fff",
|
| 52 |
-
"#0039e6",
|
| 53 |
-
"#002db3",
|
| 54 |
-
"#264db3",
|
| 55 |
-
"#7094db",
|
| 56 |
-
"#99b3e6"
|
| 57 |
]
|
| 58 |
-
|
| 59 |
# Orange palette for backward operations
|
| 60 |
backward_colors = [
|
| 61 |
-
"#ff9933",
|
| 62 |
-
"#ffad5c",
|
| 63 |
-
"#ffc285",
|
| 64 |
-
"#ffd6ad",
|
| 65 |
-
"#ff8000",
|
| 66 |
-
"#cc6600",
|
| 67 |
-
"#ff9933",
|
| 68 |
-
"#ffb366",
|
| 69 |
-
"#cc9966",
|
| 70 |
-
"#ffd699"
|
| 71 |
]
|
| 72 |
-
|
| 73 |
# Improved teal/turquoise palette with better progression for backward_D operations
|
| 74 |
backward_d_colors = [
|
| 75 |
-
"#80ffff",
|
| 76 |
-
"#00cccc",
|
| 77 |
-
"#00e6e6",
|
| 78 |
-
"#33ffff",
|
| 79 |
-
"#00b3b3",
|
| 80 |
-
"#008080",
|
| 81 |
-
"#00e6cc",
|
| 82 |
-
"#4ddbbd",
|
| 83 |
-
"#80d4c8",
|
| 84 |
-
"#b3e6e0"
|
| 85 |
]
|
| 86 |
-
|
| 87 |
# Improved green palette with better progression for backward_W operations
|
| 88 |
backward_w_colors = [
|
| 89 |
-
"#00cc66",
|
| 90 |
-
"#00e673",
|
| 91 |
-
"#33ff99",
|
| 92 |
-
"#80ffbf",
|
| 93 |
-
"#009933",
|
| 94 |
-
"#006622",
|
| 95 |
-
"#33cc33",
|
| 96 |
-
"#66cc66",
|
| 97 |
-
"#99cc99",
|
| 98 |
-
"#c6e6c6"
|
| 99 |
]
|
| 100 |
|
| 101 |
virtual_stage = stage_id // num_devices
|
|
@@ -115,7 +119,9 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
|
|
| 115 |
raise ValueError(f"Invalid operation type: {op_type}")
|
| 116 |
|
| 117 |
|
| 118 |
-
def create_pipeline_figure(
|
|
|
|
|
|
|
| 119 |
"""
|
| 120 |
Create a Plotly figure for pipeline parallelism scheduling.
|
| 121 |
|
|
@@ -126,9 +132,9 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 126 |
"""
|
| 127 |
# Find the number of devices
|
| 128 |
num_devices = len(schedule_data)
|
| 129 |
-
|
| 130 |
empty_color = "whitesmoke"
|
| 131 |
-
|
| 132 |
# Find the maximum time in the schedule if not provided
|
| 133 |
if max_time is None:
|
| 134 |
max_time = 0
|
|
@@ -146,7 +152,9 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 146 |
tasks_processed = 0
|
| 147 |
|
| 148 |
if show_progress:
|
| 149 |
-
progress_bar = tqdm(
|
|
|
|
|
|
|
| 150 |
|
| 151 |
# Create a custom y-axis with no gaps between devices
|
| 152 |
y_spacing = 1.0 # Use 1.0 for no gaps
|
|
@@ -159,7 +167,7 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 159 |
# Add rectangles for each task
|
| 160 |
for device_idx, device in enumerate(schedule_data):
|
| 161 |
device_idx_reversed = num_devices - device_idx - 1
|
| 162 |
-
|
| 163 |
# Sort tasks by start time to ensure correct rendering
|
| 164 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
| 165 |
|
|
@@ -189,44 +197,50 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 189 |
# Add rectangle for the task
|
| 190 |
start_time = task["start_time"]
|
| 191 |
duration = task["duration"]
|
| 192 |
-
|
| 193 |
# Calculate y positions with no gaps
|
| 194 |
y_pos = device_idx_reversed * y_spacing
|
| 195 |
-
|
| 196 |
# Create rectangle using shape (batch-add later)
|
| 197 |
-
shapes.append(
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
| 208 |
# Add batch number text (batch-add later)
|
| 209 |
-
annotations.append(
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
| 217 |
# Prepare hover data (add traces in batches later)
|
| 218 |
hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
|
| 219 |
-
|
| 220 |
-
hover_traces.append(
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
# Update progress
|
| 231 |
if show_progress:
|
| 232 |
tasks_processed += 1
|
|
@@ -234,63 +248,83 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 234 |
|
| 235 |
# Add all shapes at once for better performance
|
| 236 |
fig.update_layout(shapes=shapes)
|
| 237 |
-
|
| 238 |
# Add all annotations at once
|
| 239 |
fig.update_layout(annotations=annotations)
|
| 240 |
-
|
| 241 |
# Add all hover traces at once
|
| 242 |
for trace in hover_traces:
|
| 243 |
fig.add_trace(go.Scatter(**trace))
|
| 244 |
|
| 245 |
# Add custom legend
|
| 246 |
legend_items = []
|
| 247 |
-
|
| 248 |
# Find the maximum virtual stage in the data
|
| 249 |
max_virtual_stage = 0
|
| 250 |
for device in schedule_data:
|
| 251 |
for task in schedule_data[device]:
|
| 252 |
virtual_stage = task["stage"] // num_devices
|
| 253 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
| 254 |
-
|
| 255 |
# Add forward and backward items for each virtual stage
|
| 256 |
for vs in range(max_virtual_stage + 1):
|
| 257 |
-
legend_items.append(
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
# Add entries for split backward operations if this is a zb1p schedule
|
| 266 |
-
if any(
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
legend_items.append(
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
# If no tasks found, add default legend items
|
| 277 |
if not legend_items:
|
| 278 |
legend_items = [
|
| 279 |
dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
|
| 280 |
dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
|
| 281 |
-
dict(
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
]
|
| 284 |
-
|
| 285 |
for i, item in enumerate(legend_items):
|
| 286 |
-
fig.add_trace(
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
if show_progress and i < len(legend_items) - 1:
|
| 295 |
progress_bar.update(1)
|
| 296 |
|
|
@@ -299,11 +333,15 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 299 |
# Modify the ordering to put Device 1 at the top, then Device 0, then the rest
|
| 300 |
if num_devices >= 2:
|
| 301 |
# Move Device 1 to the top, followed by Device 0
|
| 302 |
-
device_labels =
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
# Calculate tick positions with no gaps
|
| 305 |
tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
|
| 306 |
-
|
| 307 |
# Adjust the range to ensure there are no empty spaces at the end
|
| 308 |
x_end = max_time * 1.05 # Add a small margin
|
| 309 |
|
|
@@ -323,17 +361,17 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 323 |
text=title_text,
|
| 324 |
x=0.5,
|
| 325 |
y=0.98, # Move title position closer to the top
|
| 326 |
-
font=dict(size=20)
|
| 327 |
),
|
| 328 |
legend=dict(
|
| 329 |
orientation="v", # Changed from horizontal to vertical
|
| 330 |
yanchor="top",
|
| 331 |
y=1.02, # Position at the top
|
| 332 |
xanchor="right",
|
| 333 |
-
x=1.20,
|
| 334 |
title=dict(text="<b>Operation Types:</b>"),
|
| 335 |
itemsizing="constant",
|
| 336 |
-
tracegroupgap=0
|
| 337 |
),
|
| 338 |
width=2000, # Increase width to accommodate the expanded legend
|
| 339 |
height=400, # Maintain current height
|
|
@@ -351,10 +389,13 @@ def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None,
|
|
| 351 |
# Cache for storing processed schedule data
|
| 352 |
_schedule_data_cache = {}
|
| 353 |
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
| 355 |
"""
|
| 356 |
Create a Dash app to visualize the pipeline schedule.
|
| 357 |
-
|
| 358 |
Args:
|
| 359 |
schedule: Schedule object to visualize
|
| 360 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
|
@@ -363,7 +404,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
| 363 |
# Process schedule data only once and cache it
|
| 364 |
global _schedule_data_cache
|
| 365 |
cache_key = id(schedule)
|
| 366 |
-
|
| 367 |
if enable_caching and cache_key in _schedule_data_cache:
|
| 368 |
schedule_data = _schedule_data_cache[cache_key]
|
| 369 |
print("Using cached schedule data")
|
|
@@ -372,7 +413,7 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
| 372 |
if enable_caching:
|
| 373 |
_schedule_data_cache[cache_key] = schedule_data
|
| 374 |
print("Cached schedule data")
|
| 375 |
-
|
| 376 |
total_tasks = sum(len(tasks) for tasks in schedule_data.values())
|
| 377 |
print(f"Total tasks in schedule: {total_tasks}")
|
| 378 |
|
|
@@ -380,31 +421,48 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
| 380 |
app.title = f"Pipeline Parallelism Visualization - {schedule_type}"
|
| 381 |
|
| 382 |
# Create a more informative layout with data size information
|
| 383 |
-
app.layout = html.Div(
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
# Cache for storing figure to avoid regenerating it
|
| 406 |
figure_cache = {}
|
| 407 |
-
|
| 408 |
@app.callback(
|
| 409 |
Output("pipeline-graph", "figure"),
|
| 410 |
Input("graph-container", "children"),
|
|
@@ -416,15 +474,15 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b", enable_caching: bo
|
|
| 416 |
if enable_caching and cache_key in figure_cache:
|
| 417 |
print("Using cached figure")
|
| 418 |
return figure_cache[cache_key]
|
| 419 |
-
|
| 420 |
# Create the figure
|
| 421 |
figure = create_pipeline_figure(schedule_data, show_progress=True)
|
| 422 |
-
|
| 423 |
# Cache the figure
|
| 424 |
if enable_caching:
|
| 425 |
figure_cache[cache_key] = figure
|
| 426 |
print("Cached figure")
|
| 427 |
-
|
| 428 |
return figure
|
| 429 |
|
| 430 |
return app
|
|
@@ -435,11 +493,11 @@ def visualize_pipeline_parallelism_dash(
|
|
| 435 |
port: int = 8050,
|
| 436 |
debug: bool = False,
|
| 437 |
enable_caching: bool = True,
|
| 438 |
-
schedule_type="1f1b"
|
| 439 |
):
|
| 440 |
"""
|
| 441 |
Launch a Dash app to visualize the pipeline schedule interactively.
|
| 442 |
-
|
| 443 |
Args:
|
| 444 |
schedule: Schedule object to visualize
|
| 445 |
port: Port to run the Dash app on
|
|
@@ -447,6 +505,8 @@ def visualize_pipeline_parallelism_dash(
|
|
| 447 |
enable_caching: Whether to cache schedule data and figures
|
| 448 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
| 449 |
"""
|
| 450 |
-
app = create_dash_app(
|
|
|
|
|
|
|
| 451 |
print(f"Starting Dash app on http://localhost:{port}/")
|
| 452 |
app.run_server(debug=debug, port=port)
|
|
|
|
| 12 |
def convert_schedule_to_visualization_format(schedule: Schedule):
|
| 13 |
"""
|
| 14 |
Converts a Schedule object to the format needed for visualization.
|
| 15 |
+
|
| 16 |
Returns:
|
| 17 |
Dict[int, List[Dict]]: Dictionary mapping device_id to a list of operation dictionaries
|
| 18 |
"""
|
| 19 |
# Make sure all operations have start and end times
|
| 20 |
for op in schedule.ops.values():
|
| 21 |
if op.start_time is None or op.end_time is None:
|
| 22 |
+
raise ValueError(
|
| 23 |
+
"Operations must have start and end times. Run ScheduleExecutor.execute() first."
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
visualization_data = {}
|
| 27 |
+
|
| 28 |
# Organize operations by device
|
| 29 |
for device_id, device_queue in enumerate(schedule.dev_queues):
|
| 30 |
visualization_data[device_id] = []
|
| 31 |
+
|
| 32 |
for op in device_queue.ops:
|
| 33 |
+
visualization_data[device_id].append(
|
| 34 |
+
{
|
| 35 |
+
"type": op.op_type,
|
| 36 |
+
"batch": op.batch_id + 1, # +1 because batch_id is 0-indexed
|
| 37 |
+
"stage": op.stage_id,
|
| 38 |
+
"start_time": op.start_time,
|
| 39 |
+
"duration": op.end_time - op.start_time,
|
| 40 |
+
}
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
return visualization_data
|
| 44 |
|
| 45 |
|
|
|
|
| 48 |
def get_color(op_type: str, stage_id: int, num_devices: int):
|
| 49 |
# A more harmonious blue palette with better progression for forward operations
|
| 50 |
forward_colors = [
|
| 51 |
+
"#5c88f2", # Periwinkle blue
|
| 52 |
+
"#1a53ff", # Deep blue
|
| 53 |
+
"#b3c6ff", # Light blue
|
| 54 |
+
"#4d79ff", # Strong blue
|
| 55 |
+
"#809fff", # Medium blue
|
| 56 |
+
"#0039e6", # Rich navy
|
| 57 |
+
"#002db3", # Dark navy
|
| 58 |
+
"#264db3", # Royal blue
|
| 59 |
+
"#7094db", # Steel blue
|
| 60 |
+
"#99b3e6", # Pale blue
|
| 61 |
]
|
| 62 |
+
|
| 63 |
# Orange palette for backward operations
|
| 64 |
backward_colors = [
|
| 65 |
+
"#ff9933", # Bright orange
|
| 66 |
+
"#ffad5c", # Medium orange
|
| 67 |
+
"#ffc285", # Light orange
|
| 68 |
+
"#ffd6ad", # Pale orange
|
| 69 |
+
"#ff8000", # Deep orange
|
| 70 |
+
"#cc6600", # Dark orange
|
| 71 |
+
"#ff9933", # Vivid orange
|
| 72 |
+
"#ffb366", # Soft orange
|
| 73 |
+
"#cc9966", # Muted orange
|
| 74 |
+
"#ffd699", # Light amber
|
| 75 |
]
|
| 76 |
+
|
| 77 |
# Improved teal/turquoise palette with better progression for backward_D operations
|
| 78 |
backward_d_colors = [
|
| 79 |
+
"#80ffff", # Light cyan
|
| 80 |
+
"#00cccc", # Teal
|
| 81 |
+
"#00e6e6", # Bright teal
|
| 82 |
+
"#33ffff", # Cyan
|
| 83 |
+
"#00b3b3", # Medium teal
|
| 84 |
+
"#008080", # Dark teal
|
| 85 |
+
"#00e6cc", # Turquoise
|
| 86 |
+
"#4ddbbd", # Aqua
|
| 87 |
+
"#80d4c8", # Pale teal
|
| 88 |
+
"#b3e6e0", # Ice
|
| 89 |
]
|
| 90 |
+
|
| 91 |
# Improved green palette with better progression for backward_W operations
|
| 92 |
backward_w_colors = [
|
| 93 |
+
"#00cc66", # Medium green
|
| 94 |
+
"#00e673", # Bright green
|
| 95 |
+
"#33ff99", # Mint green
|
| 96 |
+
"#80ffbf", # Light green
|
| 97 |
+
"#009933", # Forest green
|
| 98 |
+
"#006622", # Dark green
|
| 99 |
+
"#33cc33", # True green
|
| 100 |
+
"#66cc66", # Sage green
|
| 101 |
+
"#99cc99", # Pale green
|
| 102 |
+
"#c6e6c6", # Pastel green
|
| 103 |
]
|
| 104 |
|
| 105 |
virtual_stage = stage_id // num_devices
|
|
|
|
| 119 |
raise ValueError(f"Invalid operation type: {op_type}")
|
| 120 |
|
| 121 |
|
| 122 |
+
def create_pipeline_figure(
|
| 123 |
+
schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True
|
| 124 |
+
):
|
| 125 |
"""
|
| 126 |
Create a Plotly figure for pipeline parallelism scheduling.
|
| 127 |
|
|
|
|
| 132 |
"""
|
| 133 |
# Find the number of devices
|
| 134 |
num_devices = len(schedule_data)
|
| 135 |
+
|
| 136 |
empty_color = "whitesmoke"
|
| 137 |
+
|
| 138 |
# Find the maximum time in the schedule if not provided
|
| 139 |
if max_time is None:
|
| 140 |
max_time = 0
|
|
|
|
| 152 |
tasks_processed = 0
|
| 153 |
|
| 154 |
if show_progress:
|
| 155 |
+
progress_bar = tqdm(
|
| 156 |
+
total=total_tasks + num_devices + 3, desc="Creating visualization"
|
| 157 |
+
)
|
| 158 |
|
| 159 |
# Create a custom y-axis with no gaps between devices
|
| 160 |
y_spacing = 1.0 # Use 1.0 for no gaps
|
|
|
|
| 167 |
# Add rectangles for each task
|
| 168 |
for device_idx, device in enumerate(schedule_data):
|
| 169 |
device_idx_reversed = num_devices - device_idx - 1
|
| 170 |
+
|
| 171 |
# Sort tasks by start time to ensure correct rendering
|
| 172 |
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
| 173 |
|
|
|
|
| 197 |
# Add rectangle for the task
|
| 198 |
start_time = task["start_time"]
|
| 199 |
duration = task["duration"]
|
| 200 |
+
|
| 201 |
# Calculate y positions with no gaps
|
| 202 |
y_pos = device_idx_reversed * y_spacing
|
| 203 |
+
|
| 204 |
# Create rectangle using shape (batch-add later)
|
| 205 |
+
shapes.append(
|
| 206 |
+
dict(
|
| 207 |
+
type="rect",
|
| 208 |
+
x0=start_time,
|
| 209 |
+
y0=y_pos - 0.5,
|
| 210 |
+
x1=start_time + duration,
|
| 211 |
+
y1=y_pos + 0.5,
|
| 212 |
+
line=dict(color="black", width=0.5),
|
| 213 |
+
fillcolor=color,
|
| 214 |
+
layer="above",
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
# Add batch number text (batch-add later)
|
| 219 |
+
annotations.append(
|
| 220 |
+
dict(
|
| 221 |
+
x=start_time + duration / 2,
|
| 222 |
+
y=y_pos,
|
| 223 |
+
text=f"{task['batch']}",
|
| 224 |
+
showarrow=False,
|
| 225 |
+
font=dict(color=text_color, size=12, family="Arial, bold"),
|
| 226 |
+
)
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
# Prepare hover data (add traces in batches later)
|
| 230 |
hover_text = f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}"
|
| 231 |
+
|
| 232 |
+
hover_traces.append(
|
| 233 |
+
dict(
|
| 234 |
+
x=[start_time + duration / 2],
|
| 235 |
+
y=[y_pos],
|
| 236 |
+
mode="markers",
|
| 237 |
+
marker=dict(opacity=0), # Invisible marker
|
| 238 |
+
hoverinfo="text",
|
| 239 |
+
text=hover_text,
|
| 240 |
+
showlegend=False,
|
| 241 |
+
)
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
# Update progress
|
| 245 |
if show_progress:
|
| 246 |
tasks_processed += 1
|
|
|
|
| 248 |
|
| 249 |
# Add all shapes at once for better performance
|
| 250 |
fig.update_layout(shapes=shapes)
|
| 251 |
+
|
| 252 |
# Add all annotations at once
|
| 253 |
fig.update_layout(annotations=annotations)
|
| 254 |
+
|
| 255 |
# Add all hover traces at once
|
| 256 |
for trace in hover_traces:
|
| 257 |
fig.add_trace(go.Scatter(**trace))
|
| 258 |
|
| 259 |
# Add custom legend
|
| 260 |
legend_items = []
|
| 261 |
+
|
| 262 |
# Find the maximum virtual stage in the data
|
| 263 |
max_virtual_stage = 0
|
| 264 |
for device in schedule_data:
|
| 265 |
for task in schedule_data[device]:
|
| 266 |
virtual_stage = task["stage"] // num_devices
|
| 267 |
max_virtual_stage = max(max_virtual_stage, virtual_stage)
|
| 268 |
+
|
| 269 |
# Add forward and backward items for each virtual stage
|
| 270 |
for vs in range(max_virtual_stage + 1):
|
| 271 |
+
legend_items.append(
|
| 272 |
+
dict(
|
| 273 |
+
name=f"Forward (VS {vs})",
|
| 274 |
+
color=get_color("forward", vs * num_devices, num_devices),
|
| 275 |
+
)
|
| 276 |
+
)
|
| 277 |
+
legend_items.append(
|
| 278 |
+
dict(
|
| 279 |
+
name=f"Backward (VS {vs})",
|
| 280 |
+
color=get_color("backward", vs * num_devices, num_devices),
|
| 281 |
+
)
|
| 282 |
+
)
|
| 283 |
# Add entries for split backward operations if this is a zb1p schedule
|
| 284 |
+
if any(
|
| 285 |
+
task["type"] in ["backward_D", "backward_W"]
|
| 286 |
+
for device in schedule_data
|
| 287 |
+
for task in schedule_data[device]
|
| 288 |
+
):
|
| 289 |
+
legend_items.append(
|
| 290 |
+
dict(
|
| 291 |
+
name=f"Backward Grad (VS {vs})",
|
| 292 |
+
color=get_color("backward_D", vs * num_devices, num_devices),
|
| 293 |
+
)
|
| 294 |
+
)
|
| 295 |
+
legend_items.append(
|
| 296 |
+
dict(
|
| 297 |
+
name=f"Backward Weight (VS {vs})",
|
| 298 |
+
color=get_color("backward_W", vs * num_devices, num_devices),
|
| 299 |
+
)
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
# If no tasks found, add default legend items
|
| 303 |
if not legend_items:
|
| 304 |
legend_items = [
|
| 305 |
dict(name="Forward (VS 0)", color=get_color("forward", 0, num_devices)),
|
| 306 |
dict(name="Backward (VS 0)", color=get_color("backward", 0, num_devices)),
|
| 307 |
+
dict(
|
| 308 |
+
name="Backward Grad (VS 0)",
|
| 309 |
+
color=get_color("backward_D", 0, num_devices),
|
| 310 |
+
),
|
| 311 |
+
dict(
|
| 312 |
+
name="Backward Weight (VS 0)",
|
| 313 |
+
color=get_color("backward_W", 0, num_devices),
|
| 314 |
+
),
|
| 315 |
]
|
| 316 |
+
|
| 317 |
for i, item in enumerate(legend_items):
|
| 318 |
+
fig.add_trace(
|
| 319 |
+
go.Scatter(
|
| 320 |
+
x=[None],
|
| 321 |
+
y=[None],
|
| 322 |
+
mode="markers",
|
| 323 |
+
marker=dict(size=10, color=item["color"]),
|
| 324 |
+
name=item["name"],
|
| 325 |
+
showlegend=True,
|
| 326 |
+
)
|
| 327 |
+
)
|
| 328 |
if show_progress and i < len(legend_items) - 1:
|
| 329 |
progress_bar.update(1)
|
| 330 |
|
|
|
|
| 333 |
# Modify the ordering to put Device 1 at the top, then Device 0, then the rest
|
| 334 |
if num_devices >= 2:
|
| 335 |
# Move Device 1 to the top, followed by Device 0
|
| 336 |
+
device_labels = (
|
| 337 |
+
[device_labels[1], device_labels[0]] + device_labels[2:]
|
| 338 |
+
if num_devices > 1
|
| 339 |
+
else device_labels
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
# Calculate tick positions with no gaps
|
| 343 |
tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
|
| 344 |
+
|
| 345 |
# Adjust the range to ensure there are no empty spaces at the end
|
| 346 |
x_end = max_time * 1.05 # Add a small margin
|
| 347 |
|
|
|
|
| 361 |
text=title_text,
|
| 362 |
x=0.5,
|
| 363 |
y=0.98, # Move title position closer to the top
|
| 364 |
+
font=dict(size=20),
|
| 365 |
),
|
| 366 |
legend=dict(
|
| 367 |
orientation="v", # Changed from horizontal to vertical
|
| 368 |
yanchor="top",
|
| 369 |
y=1.02, # Position at the top
|
| 370 |
xanchor="right",
|
| 371 |
+
x=1.20, # Position further to the right to accommodate more items
|
| 372 |
title=dict(text="<b>Operation Types:</b>"),
|
| 373 |
itemsizing="constant",
|
| 374 |
+
tracegroupgap=0,
|
| 375 |
),
|
| 376 |
width=2000, # Increase width to accommodate the expanded legend
|
| 377 |
height=400, # Maintain current height
|
|
|
|
| 389 |
# Cache for storing processed schedule data
|
| 390 |
_schedule_data_cache = {}
|
| 391 |
|
| 392 |
+
|
| 393 |
+
def create_dash_app(
|
| 394 |
+
schedule: Schedule, schedule_type="1f1b", enable_caching: bool = True
|
| 395 |
+
):
|
| 396 |
"""
|
| 397 |
Create a Dash app to visualize the pipeline schedule.
|
| 398 |
+
|
| 399 |
Args:
|
| 400 |
schedule: Schedule object to visualize
|
| 401 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
|
|
|
| 404 |
# Process schedule data only once and cache it
|
| 405 |
global _schedule_data_cache
|
| 406 |
cache_key = id(schedule)
|
| 407 |
+
|
| 408 |
if enable_caching and cache_key in _schedule_data_cache:
|
| 409 |
schedule_data = _schedule_data_cache[cache_key]
|
| 410 |
print("Using cached schedule data")
|
|
|
|
| 413 |
if enable_caching:
|
| 414 |
_schedule_data_cache[cache_key] = schedule_data
|
| 415 |
print("Cached schedule data")
|
| 416 |
+
|
| 417 |
total_tasks = sum(len(tasks) for tasks in schedule_data.values())
|
| 418 |
print(f"Total tasks in schedule: {total_tasks}")
|
| 419 |
|
|
|
|
| 421 |
app.title = f"Pipeline Parallelism Visualization - {schedule_type}"
|
| 422 |
|
| 423 |
# Create a more informative layout with data size information
|
| 424 |
+
app.layout = html.Div(
|
| 425 |
+
[
|
| 426 |
+
html.H1(
|
| 427 |
+
f"Pipeline Parallelism Visualization - {schedule_type}",
|
| 428 |
+
style={"textAlign": "center"},
|
| 429 |
+
),
|
| 430 |
+
html.Div(
|
| 431 |
+
[
|
| 432 |
+
html.P(
|
| 433 |
+
f"Number of devices: {len(schedule_data)}",
|
| 434 |
+
style={"display": "inline-block", "marginRight": "20px"},
|
| 435 |
+
),
|
| 436 |
+
html.P(
|
| 437 |
+
f"Total tasks: {total_tasks}",
|
| 438 |
+
style={"display": "inline-block", "marginRight": "20px"},
|
| 439 |
+
),
|
| 440 |
+
],
|
| 441 |
+
style={"marginBottom": "20px"},
|
| 442 |
+
),
|
| 443 |
+
html.Div(id="graph-container", children=[]),
|
| 444 |
+
dcc.Loading(
|
| 445 |
+
id="loading-graph",
|
| 446 |
+
type="circle",
|
| 447 |
+
children=[
|
| 448 |
+
dcc.Graph(
|
| 449 |
+
id="pipeline-graph",
|
| 450 |
+
config={
|
| 451 |
+
"displayModeBar": True,
|
| 452 |
+
"toImageButtonOptions": {
|
| 453 |
+
"format": "png",
|
| 454 |
+
"filename": "pipeline_visualization",
|
| 455 |
+
},
|
| 456 |
+
},
|
| 457 |
+
),
|
| 458 |
+
],
|
| 459 |
+
),
|
| 460 |
+
]
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
# Cache for storing figure to avoid regenerating it
|
| 464 |
figure_cache = {}
|
| 465 |
+
|
| 466 |
@app.callback(
|
| 467 |
Output("pipeline-graph", "figure"),
|
| 468 |
Input("graph-container", "children"),
|
|
|
|
| 474 |
if enable_caching and cache_key in figure_cache:
|
| 475 |
print("Using cached figure")
|
| 476 |
return figure_cache[cache_key]
|
| 477 |
+
|
| 478 |
# Create the figure
|
| 479 |
figure = create_pipeline_figure(schedule_data, show_progress=True)
|
| 480 |
+
|
| 481 |
# Cache the figure
|
| 482 |
if enable_caching:
|
| 483 |
figure_cache[cache_key] = figure
|
| 484 |
print("Cached figure")
|
| 485 |
+
|
| 486 |
return figure
|
| 487 |
|
| 488 |
return app
|
|
|
|
| 493 |
port: int = 8050,
|
| 494 |
debug: bool = False,
|
| 495 |
enable_caching: bool = True,
|
| 496 |
+
schedule_type="1f1b",
|
| 497 |
):
|
| 498 |
"""
|
| 499 |
Launch a Dash app to visualize the pipeline schedule interactively.
|
| 500 |
+
|
| 501 |
Args:
|
| 502 |
schedule: Schedule object to visualize
|
| 503 |
port: Port to run the Dash app on
|
|
|
|
| 505 |
enable_caching: Whether to cache schedule data and figures
|
| 506 |
schedule_type: Type of schedule ("1f1b", "zb1p", or custom description)
|
| 507 |
"""
|
| 508 |
+
app = create_dash_app(
|
| 509 |
+
schedule, schedule_type=schedule_type, enable_caching=enable_caching
|
| 510 |
+
)
|
| 511 |
print(f"Starting Dash app on http://localhost:{port}/")
|
| 512 |
app.run_server(debug=debug, port=port)
|