Spaces:
Running
Running
Add VPP support and refactor project.
Browse files- .gitignore +9 -77
- LICENSE +21 -0
- README-dash-visualizer.md +0 -91
- README.md +80 -62
- conf/config.yaml +22 -0
- configs/standard.json +0 -8
- main.py +62 -0
- pipeline.py +0 -491
- pipeline_1f1b.png +0 -3
- pyproject.toml +67 -0
- requirements-dash.txt +0 -5
- src/__init__.py +3 -0
- src/execution_model.py +219 -0
- src/strategies.py +192 -0
- dash_visualizer.py → src/visualizer.py +195 -157
- visualizer.py +0 -141
.gitignore
CHANGED
|
@@ -1,78 +1,10 @@
|
|
| 1 |
# Python
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
eggs/
|
| 12 |
-
.eggs/
|
| 13 |
-
lib/
|
| 14 |
-
lib64/
|
| 15 |
-
parts/
|
| 16 |
-
sdist/
|
| 17 |
-
var/
|
| 18 |
-
wheels/
|
| 19 |
-
*.egg-info/
|
| 20 |
-
.installed.cfg
|
| 21 |
-
*.egg
|
| 22 |
-
|
| 23 |
-
# Virtual Environment
|
| 24 |
-
venv/
|
| 25 |
-
env/
|
| 26 |
-
ENV/
|
| 27 |
-
.env
|
| 28 |
-
|
| 29 |
-
# IDE specific files
|
| 30 |
-
.idea/
|
| 31 |
-
.vscode/
|
| 32 |
-
*.swp
|
| 33 |
-
*.swo
|
| 34 |
-
.DS_Store
|
| 35 |
-
|
| 36 |
-
# Jupyter Notebook
|
| 37 |
-
.ipynb_checkpoints
|
| 38 |
-
|
| 39 |
-
# Distribution / packaging
|
| 40 |
-
.Python
|
| 41 |
-
env/
|
| 42 |
-
build/
|
| 43 |
-
develop-eggs/
|
| 44 |
-
dist/
|
| 45 |
-
downloads/
|
| 46 |
-
eggs/
|
| 47 |
-
.eggs/
|
| 48 |
-
lib/
|
| 49 |
-
lib64/
|
| 50 |
-
parts/
|
| 51 |
-
sdist/
|
| 52 |
-
var/
|
| 53 |
-
wheels/
|
| 54 |
-
*.egg-info/
|
| 55 |
-
.installed.cfg
|
| 56 |
-
*.egg
|
| 57 |
-
|
| 58 |
-
# Unit test / coverage reports
|
| 59 |
-
htmlcov/
|
| 60 |
-
.tox/
|
| 61 |
-
.coverage
|
| 62 |
-
.coverage.*
|
| 63 |
-
.cache
|
| 64 |
-
nosetests.xml
|
| 65 |
-
coverage.xml
|
| 66 |
-
*.cover
|
| 67 |
-
.hypothesis/
|
| 68 |
-
|
| 69 |
-
# Pipeline visualization outputs
|
| 70 |
-
*.png
|
| 71 |
-
*.jpg
|
| 72 |
-
*.jpeg
|
| 73 |
-
*.pdf
|
| 74 |
-
*.svg
|
| 75 |
-
|
| 76 |
-
# Local configuration
|
| 77 |
-
config.ini
|
| 78 |
-
secrets.json
|
|
|
|
| 1 |
# Python
|
| 2 |
+
./venv
|
| 3 |
+
uv.lock
|
| 4 |
+
outputs/
|
| 5 |
+
|
| 6 |
+
# Uncomment below if you want to include these files
|
| 7 |
+
# !assets/*.png
|
| 8 |
+
# !assets/*.jpg
|
| 9 |
+
# !docs/*.png
|
| 10 |
+
# !docs/*.jpg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README-dash-visualizer.md
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
# Pipeline Parallelism Dash Visualizer
|
| 2 |
-
|
| 3 |
-
This is an interactive Dash-based visualizer for pipeline parallelism scheduling, complementing the existing Matplotlib-based visualization.
|
| 4 |
-
|
| 5 |
-
## Features
|
| 6 |
-
|
| 7 |
-
- **Static image generation** similar to the Matplotlib version
|
| 8 |
-
- **Interactive web-based visualization** with Dash
|
| 9 |
-
- **Download functionality** to save the visualization as PNG
|
| 10 |
-
- **Progress indication** during figure creation and image generation
|
| 11 |
-
- **Compatible API** with the existing visualizer
|
| 12 |
-
|
| 13 |
-
## Installation
|
| 14 |
-
|
| 15 |
-
Install the required dependencies:
|
| 16 |
-
|
| 17 |
-
```bash
|
| 18 |
-
pip install -r requirements-dash.txt
|
| 19 |
-
```
|
| 20 |
-
|
| 21 |
-
## Usage
|
| 22 |
-
|
| 23 |
-
### From Python
|
| 24 |
-
|
| 25 |
-
```python
|
| 26 |
-
from pipeline import create_1f1b_schedule
|
| 27 |
-
from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
|
| 28 |
-
|
| 29 |
-
# Create a schedule
|
| 30 |
-
schedule = create_1f1b_schedule(
|
| 31 |
-
num_stages=4,
|
| 32 |
-
num_batches=8,
|
| 33 |
-
forward_times=[1.0, 1.0, 1.0, 1.0],
|
| 34 |
-
backward_times=[2.0, 2.0, 2.0, 2.0],
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
# Generate a static image
|
| 38 |
-
save_pipeline_visualization_plotly(
|
| 39 |
-
schedule=schedule,
|
| 40 |
-
schedule_type="1f1b",
|
| 41 |
-
output_file="pipeline_plotly.png"
|
| 42 |
-
)
|
| 43 |
-
|
| 44 |
-
# OR launch an interactive Dash app
|
| 45 |
-
visualize_pipeline_parallelism_dash(
|
| 46 |
-
schedule=schedule,
|
| 47 |
-
schedule_type="1f1b",
|
| 48 |
-
port=8050,
|
| 49 |
-
debug=False
|
| 50 |
-
)
|
| 51 |
-
```
|
| 52 |
-
|
| 53 |
-
### Using the Command Line
|
| 54 |
-
|
| 55 |
-
You can use the updated command line interface:
|
| 56 |
-
|
| 57 |
-
```bash
|
| 58 |
-
# Generate a static image with Dash/Plotly
|
| 59 |
-
python pipeline.py --visualizer dash --output-file pipeline_viz.png
|
| 60 |
-
|
| 61 |
-
# Launch an interactive Dash app
|
| 62 |
-
python pipeline.py --visualizer dash-interactive
|
| 63 |
-
|
| 64 |
-
# Use the original Matplotlib visualizer
|
| 65 |
-
python pipeline.py --visualizer matplotlib
|
| 66 |
-
```
|
| 67 |
-
|
| 68 |
-
You can also use the dash_visualizer.py script directly for testing:
|
| 69 |
-
|
| 70 |
-
```bash
|
| 71 |
-
# Generate a static image
|
| 72 |
-
python dash_visualizer.py --output test_viz.png
|
| 73 |
-
|
| 74 |
-
# Launch an interactive app
|
| 75 |
-
python dash_visualizer.py --interactive
|
| 76 |
-
```
|
| 77 |
-
|
| 78 |
-
## Differences from Matplotlib Visualizer
|
| 79 |
-
|
| 80 |
-
The Dash-based visualizer provides all the same visual elements as the Matplotlib version:
|
| 81 |
-
- Color-coded rectangles for forward, backward, and optimizer operations
|
| 82 |
-
- Batch numbers displayed inside each rectangle
|
| 83 |
-
- Device labels on the y-axis
|
| 84 |
-
- Clear legend
|
| 85 |
-
|
| 86 |
-
Additional features:
|
| 87 |
-
- Interactive web interface
|
| 88 |
-
- Hovering over elements to see details
|
| 89 |
-
- Download button to save the visualization
|
| 90 |
-
- Progress bars for tracking visualization creation
|
| 91 |
-
- Responsive layout that works well on different screen sizes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
|
@@ -1,77 +1,95 @@
|
|
| 1 |
-
# Pipeline Parallelism
|
| 2 |
|
| 3 |
-
This
|
| 4 |
|
| 5 |
-
##
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
|
|
|
| 9 |
```bash
|
| 10 |
-
python
|
| 11 |
```
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
| Option | Short | Description |
|
| 17 |
-
|--------|-------|-------------|
|
| 18 |
-
| `--config` | `-c` | Path to config file (JSON or YAML) |
|
| 19 |
-
| `--num-stages` | `-s` | Number of pipeline stages (devices) |
|
| 20 |
-
| `--num-batches` | `-b` | Number of micro-batches |
|
| 21 |
-
| `--forward-times` | `-f` | Time for forward pass at each stage (space-separated list) |
|
| 22 |
-
| `--backward-times` | `-bw` | Time for backward pass at each stage (space-separated list) |
|
| 23 |
-
| `--output` | `-o` | Output file path for visualization |
|
| 24 |
-
| `--no-visualization` | | Skip visualization generation |
|
| 25 |
-
| `--p2p-time`| | P2P communication time of PP |
|
| 26 |
-
|
| 27 |
-
### Using Configuration Files
|
| 28 |
-
|
| 29 |
-
You can use either JSON or YAML configuration files:
|
| 30 |
-
|
| 31 |
-
Example JSON configuration (sample_config.json):
|
| 32 |
-
```json
|
| 33 |
-
{
|
| 34 |
-
"num_stages": 6,
|
| 35 |
-
"num_batches": 12,
|
| 36 |
-
"forward_times": [0.8, 1.0, 1.2, 1.0, 0.9, 1.1],
|
| 37 |
-
"backward_times": [1.6, 2.0, 2.4, 2.0, 1.8, 2.2],
|
| 38 |
-
"output_file": "pipeline_1f1b_custom.png"
|
| 39 |
-
}
|
| 40 |
```
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
```
|
| 61 |
|
| 62 |
-
##
|
| 63 |
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
1. **Warmup Phase**: Forward passes for the first several micro-batches
|
| 68 |
-
2. **Steady State**: Each device alternates between forward and backward passes
|
| 69 |
-
3. **Cooldown Phase**: Backward passes to complete the computation for remaining micro-batches
|
| 70 |
|
| 71 |
-
|
| 72 |
|
| 73 |
-
##
|
| 74 |
|
| 75 |
-
|
| 76 |
-
- GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism (NeurIPS'19)
|
| 77 |
-
- Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism
|
|
|
|
| 1 |
+
# Pipeline Parallelism Emulation
|
| 2 |
|
| 3 |
+
This project provides tools for emulating and visualizing pipeline parallelism strategies used in large language model training.
|
| 4 |
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
Pipeline parallelism is a technique used to train large models by partitioning the model across multiple devices and processing data in a pipelined fashion. This project allows you to:
|
| 8 |
+
|
| 9 |
+
- Simulate different pipeline parallelism strategies (1F1B, Interleaved)
|
| 10 |
+
- Visualize the execution schedule on multiple devices
|
| 11 |
+
- Compare different strategies for efficiency
|
| 12 |
+
|
| 13 |
+
## Features
|
| 14 |
+
- Supported Pipeline Strategies:
|
| 15 |
+
- 1F1B
|
| 16 |
+
- Interleaved 1F1B
|
| 17 |
+
- Visualization:
|
| 18 |
+
- Interactive visualization dashboard using Plotly/Dash
|
| 19 |
+
- Config:
|
| 20 |
+
- Configurable simulation parameters through Hydra
|
| 21 |
+
- Each stage
|
| 22 |
+
|
| 23 |
+
## Installation
|
| 24 |
+
|
| 25 |
+
This project uses [uv](https://github.com/astral-sh/uv) for dependency management.
|
| 26 |
|
| 27 |
+
Setup `uv` if not installed in your computer:
|
| 28 |
+
```
|
| 29 |
+
# On macOS and Linux.
|
| 30 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Usage
|
| 34 |
|
| 35 |
+
Running for 1F1B strategy:
|
| 36 |
```bash
|
| 37 |
+
uv run python main.py strategy=1f1b num_devices=4 num_stages=4 num_batches=8
|
| 38 |
```
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
uv run python main.py strategy=interleave num_devices=4 num_stages=8 num_batches=8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
```
|
| 43 |
|
| 44 |
+
## Configuration
|
| 45 |
+
|
| 46 |
+
The default configuration is in `conf/config.yaml`. You can override any parameter on the command line or create configuration groups for different scenarios.
|
| 47 |
+
|
| 48 |
+
### Using Different Configuration Files
|
| 49 |
+
|
| 50 |
+
You can use different configuration files with Hydra in several ways:
|
| 51 |
+
|
| 52 |
+
#### Recommended Approach
|
| 53 |
+
|
| 54 |
+
1. Create multiple configuration files in the `conf` directory for different use cases:
|
| 55 |
+
```
|
| 56 |
+
conf/
|
| 57 |
+
├── config.yaml # Default configuration
|
| 58 |
+
└── model_A.yaml # Create your own config with stage-specific latency for performance projection.
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
2. Run with your desired configuration using the `--config-name` flag:
|
| 62 |
+
```bash
|
| 63 |
+
uv run python main.py --config-name=model_A
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
#### Override Specific Parameters
|
| 67 |
+
|
| 68 |
+
You can also override specific parameters at runtime:
|
| 69 |
+
```bash
|
| 70 |
+
uv run python main.py op_times.forward=0.5 op_times.backward=1.0 num_batches=6
|
| 71 |
```
|
| 72 |
|
| 73 |
+
## Project Structure
|
| 74 |
|
| 75 |
+
```
|
| 76 |
+
PP-Emulation/
|
| 77 |
+
├── conf/ # Hydra configuration files
|
| 78 |
+
│ └── config.yaml # Default configuration
|
| 79 |
+
├── src/ # Source code
|
| 80 |
+
│ ├── __init__.py # Package initialization
|
| 81 |
+
│ ├── execution_model.py # Schedule execution models
|
| 82 |
+
│ ├── strategies.py # Pipeline parallelism strategies
|
| 83 |
+
│ └── visualizer.py # Visualization utilities
|
| 84 |
+
├── main.py # Main entry point
|
| 85 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 86 |
+
└── README.md # This file
|
| 87 |
+
```
|
| 88 |
|
| 89 |
+
## License
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
| 92 |
|
| 93 |
+
## Contributing
|
| 94 |
|
| 95 |
+
Contributions are welcome! Please feel free to submit a Pull Request.
|
|
|
|
|
|
conf/config.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default configuration for Pipeline Parallelism Emulation
|
| 2 |
+
num_devices: 4
|
| 3 |
+
num_stages: 4
|
| 4 |
+
num_batches: 12
|
| 5 |
+
visualization_port: 8050
|
| 6 |
+
strategy: "1f1b" # Options: "1f1b", "interleave"
|
| 7 |
+
p2p_latency: 0.0
|
| 8 |
+
|
| 9 |
+
# Operation time configurations
|
| 10 |
+
op_times:
|
| 11 |
+
# Option 1: Simple configuration (same time for all stages)
|
| 12 |
+
forward: 1.0
|
| 13 |
+
backward: 2.0
|
| 14 |
+
|
| 15 |
+
# Option 2: Commented example of stage-specific configuration
|
| 16 |
+
# forward:
|
| 17 |
+
# 0: 0.8 # Stage 0 forward time
|
| 18 |
+
# 1: 1.2 # Stage 1 forward time
|
| 19 |
+
# 2: 1.5 # Stage 2 forward time
|
| 20 |
+
# 3: 0.9 # Stage 3 forward time
|
| 21 |
+
# backward:
|
| 22 |
+
# 0: 1.0 # Stage 0 backward time
|
configs/standard.json
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"num_stages": 4,
|
| 3 |
-
"num_batches": 8,
|
| 4 |
-
"forward_times": [1.0, 1.0, 1.0, 1.0],
|
| 5 |
-
"backward_times": [2.0, 2.0, 2.0, 2.0],
|
| 6 |
-
"output_file": "pipeline_1f1b.png",
|
| 7 |
-
"p2p_time": 0.0
|
| 8 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.execution_model import ScheduleConfig, ScheduleExecutor
|
| 2 |
+
from src.strategies import generate_1f1b_interleave_schedule, generate_1f1b_schedule
|
| 3 |
+
from src.visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
|
| 4 |
+
import hydra
|
| 5 |
+
from omegaconf import DictConfig, OmegaConf
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@hydra.main(config_path="conf", config_name="config", version_base=None)
|
| 9 |
+
def main(cfg: DictConfig) -> None:
|
| 10 |
+
"""Run pipeline parallelism simulation with the specified configuration."""
|
| 11 |
+
print(f"Running with configuration: {cfg}")
|
| 12 |
+
|
| 13 |
+
if cfg.strategy == "1f1b":
|
| 14 |
+
run_1f1b(cfg)
|
| 15 |
+
elif cfg.strategy == "interleave":
|
| 16 |
+
run_interleave(cfg)
|
| 17 |
+
else:
|
| 18 |
+
raise ValueError(f"Unknown strategy: {cfg.strategy}")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def run_1f1b(cfg: DictConfig) -> None:
|
| 22 |
+
"""Run 1F1B pipeline parallelism simulation."""
|
| 23 |
+
# Convert OmegaConf to dict for op_times if it exists
|
| 24 |
+
op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
|
| 25 |
+
|
| 26 |
+
schedule_config = ScheduleConfig(
|
| 27 |
+
num_devices=cfg.num_devices,
|
| 28 |
+
num_stages=cfg.num_stages,
|
| 29 |
+
num_batches=cfg.num_batches,
|
| 30 |
+
p2p_latency=cfg.p2p_latency,
|
| 31 |
+
op_times=op_times,
|
| 32 |
+
placement_strategy="1f1b"
|
| 33 |
+
)
|
| 34 |
+
schedule = generate_1f1b_schedule(schedule_config)
|
| 35 |
+
executor = ScheduleExecutor(schedule)
|
| 36 |
+
executor.execute()
|
| 37 |
+
|
| 38 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def run_interleave(cfg: DictConfig) -> None:
|
| 42 |
+
"""Run interleaved pipeline parallelism simulation."""
|
| 43 |
+
# Convert OmegaConf to dict for op_times if it exists
|
| 44 |
+
op_times = OmegaConf.to_container(cfg.op_times) if hasattr(cfg, 'op_times') else None
|
| 45 |
+
|
| 46 |
+
schedule_config = ScheduleConfig(
|
| 47 |
+
num_devices=cfg.num_devices,
|
| 48 |
+
num_stages=cfg.num_stages,
|
| 49 |
+
num_batches=cfg.num_batches,
|
| 50 |
+
p2p_latency=cfg.p2p_latency,
|
| 51 |
+
placement_strategy="interleave",
|
| 52 |
+
op_times=op_times
|
| 53 |
+
)
|
| 54 |
+
schedule = generate_1f1b_interleave_schedule(schedule_config)
|
| 55 |
+
executor = ScheduleExecutor(schedule)
|
| 56 |
+
executor.execute()
|
| 57 |
+
|
| 58 |
+
visualize_pipeline_parallelism_dash(schedule, port=cfg.visualization_port)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
|
| 62 |
+
main()
|
pipeline.py
DELETED
|
@@ -1,491 +0,0 @@
|
|
| 1 |
-
import argparse
|
| 2 |
-
import json
|
| 3 |
-
import yaml
|
| 4 |
-
import os
|
| 5 |
-
from typing import List, Dict
|
| 6 |
-
|
| 7 |
-
# Import visualization function from the new module
|
| 8 |
-
from visualizer import visualize_pipeline_parallelism
|
| 9 |
-
try:
|
| 10 |
-
from dash_visualizer import visualize_pipeline_parallelism_dash, save_pipeline_visualization_plotly
|
| 11 |
-
DASH_AVAILABLE = True
|
| 12 |
-
except ImportError:
|
| 13 |
-
DASH_AVAILABLE = False
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def create_1f1b_schedule(
|
| 17 |
-
num_stages: int,
|
| 18 |
-
num_batches: int,
|
| 19 |
-
forward_times: List[float],
|
| 20 |
-
backward_times: List[float],
|
| 21 |
-
p2p_time: float = 0.0,
|
| 22 |
-
) -> Dict[int, List[Dict]]:
|
| 23 |
-
"""
|
| 24 |
-
Create a 1F1B (One-Forward-One-Backward) schedule for pipeline parallelism.
|
| 25 |
-
|
| 26 |
-
This implementation takes a data-centric approach:
|
| 27 |
-
1. First determine the operation sequence for each pipeline stage (which microbatch to process when)
|
| 28 |
-
2. Then calculate timing based on dependencies between operations
|
| 29 |
-
|
| 30 |
-
The 1F1B pattern has three phases:
|
| 31 |
-
- Warmup: Forward passes for first num_stages microbatches
|
| 32 |
-
- Steady state: Alternating between forward and backward passes
|
| 33 |
-
- Cooldown: Backward passes for remaining microbatches
|
| 34 |
-
|
| 35 |
-
Returns:
|
| 36 |
-
A dictionary mapping device IDs to lists of tasks.
|
| 37 |
-
Each task is a dictionary with keys:
|
| 38 |
-
- 'type': 'forward' or 'backward'
|
| 39 |
-
- 'batch': batch number
|
| 40 |
-
- 'start_time': start time of the task
|
| 41 |
-
- 'duration': duration of the task
|
| 42 |
-
"""
|
| 43 |
-
# Initialize empty schedule
|
| 44 |
-
schedule = {stage: [] for stage in range(num_stages)}
|
| 45 |
-
|
| 46 |
-
# Step 1: Determine operation sequence for each stage
|
| 47 |
-
# This will generate the sequence of operations (forward/backward on which microbatch)
|
| 48 |
-
# that each stage should perform, without timing information yet
|
| 49 |
-
operation_sequence = determine_1f1b_operation_sequence(num_stages, num_batches)
|
| 50 |
-
|
| 51 |
-
# Step 2: Convert operation sequence to schedule with timing
|
| 52 |
-
# Taking into account dependencies between operations
|
| 53 |
-
schedule = calculate_operation_timing(
|
| 54 |
-
operation_sequence, num_stages, forward_times, backward_times, p2p_time
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
return schedule
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
def determine_1f1b_operation_sequence(
|
| 61 |
-
num_stages: int, num_batches: int
|
| 62 |
-
) -> Dict[int, List[Dict]]:
|
| 63 |
-
"""
|
| 64 |
-
Determine the sequence of operations (forward/backward) for each stage in 1F1B scheduling.
|
| 65 |
-
|
| 66 |
-
Args:
|
| 67 |
-
num_stages: Number of pipeline stages
|
| 68 |
-
num_batches: Number of micro-batches
|
| 69 |
-
|
| 70 |
-
Returns:
|
| 71 |
-
Dictionary mapping stage ID to a list of operations in sequence.
|
| 72 |
-
Each operation is a dict with keys 'type' ('forward' or 'backward') and 'batch'.
|
| 73 |
-
"""
|
| 74 |
-
operation_sequence = {i: [] for i in range(num_stages)}
|
| 75 |
-
for current_stage in range(num_stages):
|
| 76 |
-
warmup_batches = num_stages - current_stage
|
| 77 |
-
for j in range(1, warmup_batches + 1):
|
| 78 |
-
operation_sequence[current_stage].append({"type": "forward", "batch": j})
|
| 79 |
-
steady_batches = num_batches - warmup_batches
|
| 80 |
-
for j in range(warmup_batches + 1, warmup_batches + steady_batches + 1):
|
| 81 |
-
operation_sequence[current_stage].append(
|
| 82 |
-
{"type": "backward", "batch": j - warmup_batches}
|
| 83 |
-
)
|
| 84 |
-
operation_sequence[current_stage].append({"type": "forward", "batch": j})
|
| 85 |
-
for j in range(warmup_batches):
|
| 86 |
-
operation_sequence[current_stage].append(
|
| 87 |
-
{"type": "backward", "batch": j + steady_batches + 1}
|
| 88 |
-
)
|
| 89 |
-
|
| 90 |
-
return operation_sequence
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def calculate_operation_timing(
|
| 94 |
-
operation_sequence: Dict[int, List[Dict]],
|
| 95 |
-
num_stages: int,
|
| 96 |
-
forward_times: List[float],
|
| 97 |
-
backward_times: List[float],
|
| 98 |
-
p2p_time: float = 0.0,
|
| 99 |
-
) -> Dict[int, List[Dict]]:
|
| 100 |
-
"""
|
| 101 |
-
Recursively calculate the specific timing of each operation in a 1F1B schedule.
|
| 102 |
-
|
| 103 |
-
When encountering an operation that depends on a previous operation that hasn't been calculated yet,
|
| 104 |
-
it will recursively calculate the timing of those operations.
|
| 105 |
-
|
| 106 |
-
Args:
|
| 107 |
-
operation_sequence: Operation sequence for each stage
|
| 108 |
-
num_stages: Number of pipeline stages
|
| 109 |
-
forward_times: Forward propagation time for each stage
|
| 110 |
-
backward_times: Backward propagation time for each stage
|
| 111 |
-
p2p_time: Point-to-point communication time between stages
|
| 112 |
-
|
| 113 |
-
Returns:
|
| 114 |
-
Complete schedule with timing information, each operation includes start_time and duration
|
| 115 |
-
"""
|
| 116 |
-
# Initialize schedule with timing information
|
| 117 |
-
schedule = {i: [] for i in range(num_stages)}
|
| 118 |
-
|
| 119 |
-
# For recording already computed operation end times
|
| 120 |
-
# Format: {(stage, batch, op_type): (start_time, end_time)}
|
| 121 |
-
computed_ops = {}
|
| 122 |
-
|
| 123 |
-
# For recording the end time of the last operation for each stage
|
| 124 |
-
stage_last_end_time = [0.0] * num_stages
|
| 125 |
-
|
| 126 |
-
# Helper function: recursively calculate the time for an operation
|
| 127 |
-
def compute_op_time(stage, batch, op_type):
|
| 128 |
-
# Check if this operation has already been calculated
|
| 129 |
-
key = (stage, batch, op_type)
|
| 130 |
-
if key in computed_ops:
|
| 131 |
-
return computed_ops[key]
|
| 132 |
-
|
| 133 |
-
# Get operation duration
|
| 134 |
-
duration = (
|
| 135 |
-
forward_times[stage] if op_type == "forward" else backward_times[stage]
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
# Determine start time (dependent on other operations)
|
| 139 |
-
# 1. Consider sequential dependencies on the stage (must wait for previous operation to complete)
|
| 140 |
-
start_time = stage_last_end_time[stage]
|
| 141 |
-
|
| 142 |
-
# 2. Forward pass also depends on forward pass of previous stage (if not the first stage)
|
| 143 |
-
if op_type == "forward" and stage > 0:
|
| 144 |
-
# Recursively calculate the time for the forward pass of the previous stage (if not calculated yet)
|
| 145 |
-
prev_stage_key = (stage - 1, batch, "forward")
|
| 146 |
-
if prev_stage_key not in computed_ops:
|
| 147 |
-
prev_start, prev_end = compute_op_time(stage - 1, batch, "forward")
|
| 148 |
-
else:
|
| 149 |
-
_, prev_end = computed_ops[prev_stage_key]
|
| 150 |
-
# Update start time
|
| 151 |
-
start_time = max(start_time, prev_end + p2p_time)
|
| 152 |
-
|
| 153 |
-
# 3. Backward pass depends on:
|
| 154 |
-
elif op_type == "backward":
|
| 155 |
-
# a. Forward pass of the same stage
|
| 156 |
-
same_stage_forward_key = (stage, batch, "forward")
|
| 157 |
-
if same_stage_forward_key not in computed_ops:
|
| 158 |
-
_, forward_end = compute_op_time(stage, batch, "forward")
|
| 159 |
-
else:
|
| 160 |
-
_, forward_end = computed_ops[same_stage_forward_key]
|
| 161 |
-
|
| 162 |
-
start_time = max(start_time, forward_end)
|
| 163 |
-
|
| 164 |
-
# b. Backward pass of the next stage (if not the last stage)
|
| 165 |
-
if stage < num_stages - 1:
|
| 166 |
-
next_stage_backward_key = (stage + 1, batch, "backward")
|
| 167 |
-
if next_stage_backward_key not in computed_ops:
|
| 168 |
-
_, next_backward_end = compute_op_time(stage + 1, batch, "backward")
|
| 169 |
-
else:
|
| 170 |
-
_, next_backward_end = computed_ops[next_stage_backward_key]
|
| 171 |
-
|
| 172 |
-
start_time = max(start_time, next_backward_end + p2p_time)
|
| 173 |
-
|
| 174 |
-
# Calculate end time
|
| 175 |
-
end_time = start_time + duration
|
| 176 |
-
|
| 177 |
-
# Store calculation results
|
| 178 |
-
computed_ops[key] = (start_time, end_time)
|
| 179 |
-
|
| 180 |
-
# Update the end time of the last operation for this stage
|
| 181 |
-
stage_last_end_time[stage] = end_time
|
| 182 |
-
|
| 183 |
-
return start_time, end_time
|
| 184 |
-
|
| 185 |
-
# Calculate time for each operation in the operation_sequence
|
| 186 |
-
for i in range(len(operation_sequence[0])):
|
| 187 |
-
for stage in range(num_stages):
|
| 188 |
-
batch = operation_sequence[stage][i]["batch"]
|
| 189 |
-
op_type = operation_sequence[stage][i]["type"]
|
| 190 |
-
|
| 191 |
-
# Recursively calculate the time for this operation
|
| 192 |
-
start_time, _ = compute_op_time(stage, batch, op_type)
|
| 193 |
-
|
| 194 |
-
# Fill in scheduling information
|
| 195 |
-
op_with_timing = operation_sequence[stage][i].copy()
|
| 196 |
-
op_with_timing["start_time"] = start_time
|
| 197 |
-
op_with_timing["duration"] = (
|
| 198 |
-
forward_times[stage] if op_type == "forward" else backward_times[stage]
|
| 199 |
-
)
|
| 200 |
-
schedule[stage].append(op_with_timing)
|
| 201 |
-
|
| 202 |
-
return schedule
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def get_schedule_info(schedule: Dict[int, List[Dict]]):
|
| 206 |
-
num_stages = len(schedule)
|
| 207 |
-
|
| 208 |
-
max_time = 0
|
| 209 |
-
for device in schedule:
|
| 210 |
-
for task in schedule[device]:
|
| 211 |
-
end_time = task["start_time"] + task["duration"]
|
| 212 |
-
if end_time > max_time:
|
| 213 |
-
max_time = end_time
|
| 214 |
-
|
| 215 |
-
total_execution_time = max_time * num_stages
|
| 216 |
-
|
| 217 |
-
total_computation_time = 0
|
| 218 |
-
device_computation_times = {}
|
| 219 |
-
|
| 220 |
-
for device in schedule:
|
| 221 |
-
device_computation_time = 0
|
| 222 |
-
for task in schedule[device]:
|
| 223 |
-
device_computation_time += task["duration"]
|
| 224 |
-
device_computation_times[device] = device_computation_time
|
| 225 |
-
total_computation_time += device_computation_time
|
| 226 |
-
|
| 227 |
-
bubble_rate = (
|
| 228 |
-
total_execution_time - total_computation_time
|
| 229 |
-
) / total_computation_time
|
| 230 |
-
|
| 231 |
-
return {
|
| 232 |
-
"bubble_rate": f"{bubble_rate*100:.2f}%",
|
| 233 |
-
"execution_time": f"{max_time / 1000:.2f} s",
|
| 234 |
-
}
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
def read_config_file(config_path):
|
| 238 |
-
"""
|
| 239 |
-
Read configuration from a JSON or YAML file.
|
| 240 |
-
|
| 241 |
-
Args:
|
| 242 |
-
config_path: Path to the config file (JSON or YAML)
|
| 243 |
-
|
| 244 |
-
Returns:
|
| 245 |
-
Dictionary containing configuration parameters
|
| 246 |
-
"""
|
| 247 |
-
if not os.path.exists(config_path):
|
| 248 |
-
raise FileNotFoundError(f"Config file not found: {config_path}")
|
| 249 |
-
|
| 250 |
-
file_ext = os.path.splitext(config_path)[1].lower()
|
| 251 |
-
|
| 252 |
-
try:
|
| 253 |
-
with open(config_path, "r") as f:
|
| 254 |
-
if file_ext == ".json":
|
| 255 |
-
config = json.load(f)
|
| 256 |
-
elif file_ext in (".yaml", ".yml"):
|
| 257 |
-
config = yaml.safe_load(f)
|
| 258 |
-
else:
|
| 259 |
-
raise ValueError(
|
| 260 |
-
f"Unsupported config file format: {file_ext}. Use .json, .yaml, or .yml"
|
| 261 |
-
)
|
| 262 |
-
return config
|
| 263 |
-
except Exception as e:
|
| 264 |
-
raise ValueError(f"Error reading config file: {str(e)}")
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
def parse_args():
|
| 268 |
-
"""
|
| 269 |
-
Parse command-line arguments for the pipeline parallelism tool.
|
| 270 |
-
|
| 271 |
-
Returns:
|
| 272 |
-
Parsed arguments namespace
|
| 273 |
-
"""
|
| 274 |
-
parser = argparse.ArgumentParser(
|
| 275 |
-
description="Pipeline Parallelism Scheduler and Visualizer",
|
| 276 |
-
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
| 277 |
-
)
|
| 278 |
-
|
| 279 |
-
# Config file option
|
| 280 |
-
parser.add_argument(
|
| 281 |
-
"--config", "-c", type=str, help="Path to config file (JSON or YAML)"
|
| 282 |
-
)
|
| 283 |
-
|
| 284 |
-
# Main parameters
|
| 285 |
-
parser.add_argument(
|
| 286 |
-
"--num-stages",
|
| 287 |
-
"-s",
|
| 288 |
-
type=int,
|
| 289 |
-
default=0,
|
| 290 |
-
help="Number of pipeline stages (devices)",
|
| 291 |
-
)
|
| 292 |
-
|
| 293 |
-
parser.add_argument(
|
| 294 |
-
"--num-batches", "-b", type=int, default=0, help="Number of micro-batches"
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
# Forward and backward times
|
| 298 |
-
parser.add_argument(
|
| 299 |
-
"--forward-times",
|
| 300 |
-
"-f",
|
| 301 |
-
type=float,
|
| 302 |
-
nargs="+",
|
| 303 |
-
help="Time for forward pass at each stage (space-separated list)",
|
| 304 |
-
)
|
| 305 |
-
|
| 306 |
-
parser.add_argument(
|
| 307 |
-
"--backward-times",
|
| 308 |
-
"-bw",
|
| 309 |
-
type=float,
|
| 310 |
-
nargs="+",
|
| 311 |
-
help="Time for backward pass at each stage (space-separated list)",
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
# Output options
|
| 315 |
-
parser.add_argument(
|
| 316 |
-
"--output",
|
| 317 |
-
"-o",
|
| 318 |
-
type=str,
|
| 319 |
-
default="pipeline_1f1b.png",
|
| 320 |
-
help="Output file path for visualization",
|
| 321 |
-
)
|
| 322 |
-
|
| 323 |
-
parser.add_argument(
|
| 324 |
-
"--no-visualization", action="store_true", help="Skip visualization generation"
|
| 325 |
-
)
|
| 326 |
-
|
| 327 |
-
parser.add_argument(
|
| 328 |
-
"--p2p-time",
|
| 329 |
-
type=float,
|
| 330 |
-
default=0.0,
|
| 331 |
-
help="Time for point-to-point communication between stages",
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
parser.add_argument("--visualizer", choices=["matplotlib", "dash", "dash-interactive"],
|
| 335 |
-
default="matplotlib", help="Visualization library to use")
|
| 336 |
-
|
| 337 |
-
return parser.parse_args()
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
def example_usage():
    """Demo: build a 1F1B schedule with fixed parameters, plot it, and print its stats."""
    # Fixed demo configuration: 4 pipeline stages, 10 micro-batches.
    num_stages = 4
    num_batches = 10

    # Uniform per-stage costs: forward 1.0, backward 2.0.
    forward_times = [1.0] * num_stages
    backward_times = [2.0] * num_stages

    # Build the 1F1B schedule for these parameters.
    schedule = create_1f1b_schedule(
        num_stages=num_stages,
        num_batches=num_batches,
        forward_times=forward_times,
        backward_times=backward_times,
    )

    # Render the schedule to the default output file.
    visualize_pipeline_parallelism(
        schedule=schedule, schedule_type="1f1b", output_file="pipeline_1f1b.png"
    )

    # Print summary statistics for the schedule.
    print(get_schedule_info(schedule))
| 366 |
-
|
| 367 |
-
|
| 368 |
-
def main():
    """
    Parse arguments (and optionally a config file), build a 1F1B pipeline
    schedule, optionally visualize it, and print summary statistics.

    Returns:
        dict with keys "schedule", "schedule_info", "num_stages", "num_batches".
    """
    args = parse_args()

    # Command-line arguments provide the initial values.
    num_stages = args.num_stages
    num_batches = args.num_batches
    forward_times = args.forward_times
    backward_times = args.backward_times
    output_file = args.output
    p2p_time = args.p2p_time

    # A config file, when given, overrides the command line for any key it
    # defines (NOTE: this is config-over-CLI precedence, matching the code).
    if args.config:
        try:
            print(f"Reading configuration from {args.config}")
            config = read_config_file(args.config)

            num_stages = config.get("num_stages", num_stages)
            num_batches = config.get("num_batches", num_batches)
            # Fall back to the CLI values when the config omits these keys;
            # previously the CLI lists were silently discarded here.
            forward_times = config.get("forward_times", forward_times)
            backward_times = config.get("backward_times", backward_times)
            output_file = config.get("output_file", output_file)
            # Preserve the CLI p2p value instead of resetting to 0.0.
            p2p_time = config.get("p2p_time", p2p_time)

        except Exception as e:
            print(f"Error reading config file: {str(e)}")
            print("Falling back to command line arguments or defaults")

    # Normalize the per-stage time lists to exactly num_stages entries.
    forward_times = _fit_times(forward_times, num_stages, 1.0, "forward_times")
    backward_times = _fit_times(backward_times, num_stages, 2.0, "backward_times")

    print(f"Running with parameters:")
    print(f"  num_stages: {num_stages}")
    print(f"  num_batches: {num_batches}")
    print(f"  forward_times: {forward_times}")
    print(f"  backward_times: {backward_times}")
    print(f"  output_file: {output_file}")

    # Create 1F1B schedule
    schedule = create_1f1b_schedule(
        num_stages=num_stages,
        num_batches=num_batches,
        forward_times=forward_times,
        backward_times=backward_times,
        p2p_time=p2p_time,
    )

    # Create visualization unless --no-visualization is specified
    if not args.no_visualization:
        _render_schedule(schedule, args.visualizer, output_file)

    # Analyze the schedule
    schedule_info = get_schedule_info(schedule)
    print(schedule_info)

    return {
        "schedule": schedule,
        "schedule_info": schedule_info,
        "num_stages": num_stages,
        "num_batches": num_batches,
    }


def _fit_times(times, num_stages, default, label):
    """Return a list of exactly ``num_stages`` per-stage times.

    ``None`` becomes ``[default] * num_stages``. Too-short lists are padded
    by repeating their last value; too-long lists are truncated. A warning
    is printed on any length mismatch.
    """
    if times is None:
        return [default] * num_stages
    if len(times) != num_stages:
        print(
            f"Warning: {label} length ({len(times)}) doesn't match num_stages ({num_stages})"
        )
        if len(times) < num_stages:
            # Extend with repeats of the last value.
            times = list(times) + [times[-1]] * (num_stages - len(times))
        else:
            # Truncate.
            times = times[:num_stages]
        print(f"Adjusted {label}: {times}")
    return list(times)


def _render_schedule(schedule, visualizer, output_file):
    """Dispatch the schedule to the selected visualization backend."""
    if visualizer == "matplotlib" or not DASH_AVAILABLE:
        if not DASH_AVAILABLE and visualizer in ["dash", "dash-interactive"]:
            print("Warning: Dash not available. Falling back to matplotlib.")
        visualize_pipeline_parallelism(
            schedule=schedule, schedule_type="1f1b", output_file=output_file
        )
    elif visualizer == "dash":
        # Use a distinct suffix so the plotly export doesn't clobber the
        # matplotlib output file.
        output_base = os.path.splitext(output_file)[0]
        save_pipeline_visualization_plotly(
            schedule=schedule,
            schedule_type="1f1b",
            output_file=f"{output_base}_plotly.png",
        )
    elif visualizer == "dash-interactive":
        print("Using Dash interactive visualization")
        visualize_pipeline_parallelism_dash(
            schedule=schedule, schedule_type="1f1b", port=8050, debug=False
        )
| 488 |
-
|
| 489 |
-
|
| 490 |
-
# Script entry point: run the full CLI workflow.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pipeline_1f1b.png
DELETED
Git LFS Details
|
pyproject.toml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "pp-emulation"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Pipeline Parallelism Emulation and Visualization"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
authors = [
|
| 12 |
+
{name = "Project Author"}
|
| 13 |
+
]
|
| 14 |
+
classifiers = [
|
| 15 |
+
"Programming Language :: Python :: 3",
|
| 16 |
+
"License :: OSI Approved :: MIT License",
|
| 17 |
+
"Operating System :: OS Independent",
|
| 18 |
+
]
|
| 19 |
+
dependencies = [
|
| 20 |
+
"dash>=2.14.0",
|
| 21 |
+
"hydra-core>=1.3.2",
|
| 22 |
+
"omegaconf>=2.3.0",
|
| 23 |
+
"plotly>=5.18.0",
|
| 24 |
+
"pandas>=2.1.0",
|
| 25 |
+
"numpy>=1.26.0",
|
| 26 |
+
"tqdm>=4.67.0",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[project.optional-dependencies]
|
| 30 |
+
dev = [
|
| 31 |
+
"pytest>=7.4.0",
|
| 32 |
+
"black>=23.7.0",
|
| 33 |
+
"isort>=5.12.0",
|
| 34 |
+
"mypy>=1.5.1",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
# Add Hatch configuration to explicitly define where source code is located
|
| 38 |
+
[tool.hatch.build.targets.wheel]
|
| 39 |
+
packages = ["src"]
|
| 40 |
+
|
| 41 |
+
[tool.hatch.build.targets.sdist]
|
| 42 |
+
include = [
|
| 43 |
+
"src",
|
| 44 |
+
"main.py",
|
| 45 |
+
"conf",
|
| 46 |
+
"LICENSE",
|
| 47 |
+
"README.md",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
[tool.black]
|
| 51 |
+
line-length = 88
|
| 52 |
+
target-version = ["py310"]
|
| 53 |
+
|
| 54 |
+
[tool.isort]
|
| 55 |
+
profile = "black"
|
| 56 |
+
line_length = 88
|
| 57 |
+
|
| 58 |
+
[tool.mypy]
|
| 59 |
+
python_version = "3.10"
|
| 60 |
+
warn_return_any = true
|
| 61 |
+
warn_unused_configs = true
|
| 62 |
+
disallow_untyped_defs = true
|
| 63 |
+
disallow_incomplete_defs = true
|
| 64 |
+
|
| 65 |
+
[tool.pytest.ini_options]
|
| 66 |
+
testpaths = ["tests"]
|
| 67 |
+
pythonpath = ["."]
|
requirements-dash.txt
DELETED
|
@@ -1,5 +0,0 @@
|
|
| 1 |
-
dash==2.13.0
|
| 2 |
-
plotly==5.18.0
|
| 3 |
-
numpy
|
| 4 |
-
kaleido # For static image export
|
| 5 |
-
tqdm # For progress bars
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pipeline Parallelism Emulation and Visualization package."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
src/execution_model.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from typing import Dict, List, Optional, Union
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Operation:
    """A single forward or backward pass of one micro-batch at one stage."""

    def __init__(self, batch_id: int, stage_id: int, op_type: str):
        # Identity of this unit of work.
        self.batch_id = batch_id
        self.stage_id = stage_id
        self.op_type = op_type
        # Set once the op is assigned to a DeviceQueue.
        self.device_id = None
        # Set by the executor once the op is placed on the timeline.
        self.start_time = None
        self.end_time = None
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DeviceQueue:
    """Ordered queue of operations assigned to one device.

    A device may host several pipeline stages (e.g. with interleaved
    placement); ``stages`` lists the stage ids it owns.
    """

    def __init__(self, stages: List[int], device_id: int):
        self.stages = stages
        self.device_id = device_id
        self.ops = []  # List of operations, in execution order

    def add_operation(self, op: Operation):
        """Append ``op`` to this device's queue and claim ownership of it.

        Raises:
            AssertionError: if the op's stage is not hosted on this device,
                or the op is already owned by another device.
        """
        # Validate BOTH preconditions before mutating anything, so a failed
        # assertion cannot leave the op half-registered in the queue
        # (previously the op was appended before the ownership check).
        assert op.stage_id in self.stages
        assert op.device_id is None
        self.ops.append(op)
        op.device_id = self.device_id
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ScheduleConfig:
    """Static description of a pipeline-parallel run.

    Holds device/stage/batch counts, communication latency, stage placement,
    and per-operation durations. A duration may be a single float (uniform
    across stages) or a dict mapping stage_id -> time.
    """

    def __init__(
        self,
        num_devices: int,
        num_stages: int,
        num_batches: int,
        p2p_latency: float = 0.0,
        placement_strategy: str = "normal",
        op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
    ):
        self.num_devices = num_devices
        self.num_stages = num_stages
        self.num_batches = num_batches
        self.p2p_latency = p2p_latency
        self.placement_strategy = placement_strategy

        # Default operation times; a plain float applies to every stage.
        self.op_times = {
            "forward": 1.0,
            "backward": 2.0,
        }

        # Merge user-provided operation times over the defaults.
        if op_times:
            for op_type, times in op_times.items():
                if isinstance(times, dict):
                    # Stage-specific overrides: ensure the stored entry is a
                    # dict, seeding it from the uniform default if needed.
                    # (The redundant re-conversion that previously ran inside
                    # the per-stage loop was dead code and has been removed.)
                    current = self.op_times.get(op_type)
                    if not isinstance(current, dict):
                        current = (
                            {}
                            if current is None
                            else {i: current for i in range(num_stages)}
                        )
                        self.op_times[op_type] = current
                    current.update(times)
                else:
                    # A float replaces the time uniformly for all stages.
                    self.op_times[op_type] = times

        assert num_stages % num_devices == 0, "num_stages must be divisible by num_devices"
        self.num_stages_per_device = num_stages // num_devices

        self.init_device_to_stages()
        # Every stage must be placed exactly once.
        assert (
            sum(len(stages) for stages in self.device_to_stages.values()) == num_stages
        )

    def init_device_to_stages(self):
        """Compute ``self.device_to_stages`` from the placement strategy."""
        if self.placement_strategy == "normal":
            # Contiguous blocks of stages per device: device 0 gets the first
            # num_stages/num_devices stages, and so on.
            stages_per_device = self.num_stages // self.num_devices
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i // stages_per_device].append(i)
        elif self.placement_strategy == "interleave":
            # Round-robin: stage i lives on device i % num_devices.
            self.device_to_stages = defaultdict(list)
            for i in range(self.num_stages):
                self.device_to_stages[i % self.num_devices].append(i)
        else:
            raise ValueError(f"Invalid placement strategy: {self.placement_strategy}")

    def get_op_time(self, op_type: str, stage_id: int):
        """Return the duration of ``op_type`` at ``stage_id``.

        Raises:
            ValueError: for an unknown op type, or a stage with no time when
                stage-specific times are in use.
        """
        if op_type not in self.op_times:
            raise ValueError(f"Invalid operation type: {op_type}")

        times = self.op_times[op_type]
        if isinstance(times, dict):
            if stage_id not in times:
                raise ValueError(f"No time specified for operation {op_type} at stage {stage_id}")
            return times[stage_id]
        # Uniform time for all stages.
        return times
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Schedule:
    """All operations of a run plus their per-device execution queues.

    One Operation exists per (batch, stage, op_type) triple. The per-device
    execution ORDER is not decided here — a scheduling strategy fills the
    DeviceQueues; the executor later assigns wall-clock times.
    """

    def __init__(self, config: ScheduleConfig):
        self.ops = {}  # (batch_id, stage_id, op_type) -> Operation
        # One queue per device, seeded with the stages that device hosts.
        self.dev_queues: List[DeviceQueue] = []
        for dev_id in range(config.num_devices):
            self.dev_queues.append(DeviceQueue(config.device_to_stages[dev_id], dev_id))
        self.config = config

        self.init_operations()

    def init_operations(self, op_types: Optional[List[str]] = None):
        """Create one Operation per (batch, stage, op_type) combination."""
        if op_types is None:
            op_types = ["forward", "backward"]
        for batch_id in range(self.config.num_batches):
            for stage_id in range(self.config.num_stages):
                for op_type in op_types:
                    self.ops[(batch_id, stage_id, op_type)] = Operation(
                        batch_id, stage_id, op_type
                    )

    def get_op(self, batch_id: int, stage_id: int, op_type: str):
        """Look up the Operation for the given (batch, stage, op_type)."""
        return self.ops[(batch_id, stage_id, op_type)]

    def get_dependencies(self, op: Operation):
        """Return ``op``'s predecessors as (dep_op, gap) pairs.

        Data deps: a forward waits on the previous stage's forward; a backward
        waits on the next stage's backward — both with p2p latency as the gap.
        Execution dep: the previous op queued on the same device (zero gap).
        """
        deps = []
        if op.op_type == "forward":
            if op.stage_id > 0:
                deps.append(
                    (
                        self.get_op(op.batch_id, op.stage_id - 1, "forward"),
                        self.config.p2p_latency,
                    )
                )
        elif op.op_type == "backward":
            if op.stage_id < self.config.num_stages - 1:
                deps.append(
                    (
                        self.get_op(op.batch_id, op.stage_id + 1, "backward"),
                        self.config.p2p_latency,
                    )
                )

        # NOTE(review): list.index is a linear scan per call — O(n^2) across
        # a full schedule; acceptable for emulation-sized inputs.
        device_index = self.dev_queues[op.device_id].ops.index(op)
        if device_index > 0:
            deps.append((self.dev_queues[op.device_id].ops[device_index - 1], 0.0))
        return deps

    def show(self):
        """Display detailed information about the schedule for debugging purposes."""
        print("\n=== SCHEDULE DETAILS ===")
        print(f"Devices: {self.config.num_devices}, Stages: {self.config.num_stages}, Batches: {self.config.num_batches}")
        print(f"Placement Strategy: {self.config.placement_strategy}")
        print("\n=== DEVICE QUEUES ===")

        for dev_id in range(self.config.num_devices):
            print(f"\nDEVICE {dev_id} (Stages: {self.dev_queues[dev_id].stages}):")
            print("-" * 80)
            print(f"{'Batch':^6} | {'Stage':^6} | {'Type':^10} | {'Start':^10} | {'End':^10} | {'Duration':^10}")
            print("-" * 80)

            for op in self.dev_queues[dev_id].ops:
                op_type = "Forward" if op.op_type == "forward" else "Backward"
                # Timing fields are None until the executor has run.
                start = f"{op.start_time:.2f}" if op.start_time is not None else "N/A"
                end = f"{op.end_time:.2f}" if op.end_time is not None else "N/A"

                duration = "N/A"
                if op.start_time is not None and op.end_time is not None:
                    duration = f"{op.end_time - op.start_time:.2f}"

                print(f"{op.batch_id:^6} | {op.stage_id:^6} | {op_type:^10} | {start:^10} | {end:^10} | {duration:^10}")

        # Find the total execution time (if timing info is available)
        if all(op.end_time is not None for op in self.ops.values()):
            total_time = max(op.end_time for op in self.ops.values())
            print(f"\nTotal execution time: {total_time:.2f}")
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class ScheduleExecutor:
    """Assigns start/end times to every operation in a Schedule.

    Times are resolved recursively from each op's dependencies, producing the
    simulated timeline of the pipeline.
    """

    def __init__(self, schedule: Schedule):
        self.schedule = schedule

    def execute(self):
        """Resolve ``start_time``/``end_time`` for every operation."""
        def execute_op(op: Operation):
            # Time an op after (recursively) timing any untimed dependency.
            deps = self.schedule.get_dependencies(op)
            if len(deps) == 0:
                op.start_time = 0.0
            else:
                for dep, gap in deps:
                    if dep.end_time is None or dep.start_time is None:
                        execute_op(dep)
                # Start when the latest dependency (plus its comm gap) ends.
                op.start_time = max(dep.end_time + gap for dep, gap in deps)
            op.end_time = op.start_time + self.schedule.config.get_op_time(
                op.op_type, op.stage_id
            )

        # Walk all device queues in lockstep so most dependencies are already
        # resolved when an op is visited.
        op_num = len(self.schedule.dev_queues[0].ops)
        for i in range(op_num):
            for dev_id in range(self.schedule.config.num_devices):
                op = self.schedule.dev_queues[dev_id].ops[i]
                execute_op(op)

        # Sanity check: every operation must have been placed on the timeline.
        for op in self.schedule.ops.values():
            assert (
                op.start_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no start time"
            assert (
                op.end_time is not None
            ), f"op {op.batch_id}, {op.stage_id}, {op.op_type} has no end time"
src/strategies.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from src.execution_model import Schedule, ScheduleConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def generate_1f1b_schedule(config: ScheduleConfig):
    """Build a classic 1F1B (one-forward-one-backward) pipeline schedule.

    Each device runs a warmup of forwards, then alternates forward/backward
    in steady state, and finally drains the remaining backwards (cooldown).
    """
    schedule = Schedule(config)

    for dev_id in range(config.num_devices):
        queue = schedule.dev_queues[dev_id]
        stages_fwd = list(queue.stages)
        stages_bwd = list(reversed(queue.stages))

        fwd_batch = 0
        bwd_batch = 0
        # Later devices need fewer warmup forwards before 1F1B kicks in.
        warmup = config.num_devices - dev_id - 1
        cooldown = warmup
        steady = config.num_batches - warmup

        # Warmup phase: forwards only.
        for _ in range(warmup):
            for stage in stages_fwd:
                queue.add_operation(schedule.get_op(fwd_batch, stage, "forward"))
            fwd_batch += 1

        # Steady state: one forward followed by one backward per micro-batch.
        for _ in range(steady):
            for stage in stages_fwd:
                queue.add_operation(schedule.get_op(fwd_batch, stage, "forward"))
            fwd_batch += 1
            for stage in stages_bwd:
                queue.add_operation(schedule.get_op(bwd_batch, stage, "backward"))
            bwd_batch += 1

        # Cooldown phase: drain the remaining backwards.
        for _ in range(cooldown):
            for stage in stages_bwd:
                queue.add_operation(schedule.get_op(bwd_batch, stage, "backward"))
            bwd_batch += 1

    return schedule
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Some of this code is copied from Megatron-LM.
def generate_1f1b_interleave_schedule(config: ScheduleConfig):
    """Generate an interleaved (virtual-pipeline) 1F1B schedule.

    Mirrors Megatron-LM's forward_backward_pipelining_with_interleaving: each
    device hosts ``config.num_stages_per_device`` model chunks, and the
    per-device operation order is derived from a schedule lookup table.
    """
    schedule = Schedule(config)

    def get_pp_rank_microbatches(
        num_microbatches,
        num_devices,
        device_id,
        num_stages_per_device,
        microbatch_group_size_per_vp_stage,
    ):
        """Get the number of total, warmup, and remaining microbatches in PP scheduling."""
        total_num_microbatches = num_microbatches * num_stages_per_device
        are_all_microbatches_in_warmup = False

        if num_devices > 1:
            if num_stages_per_device is None:
                # forward_backward_pipelining_without_interleaving
                num_warmup_microbatches = num_devices - device_id - 1
            else:
                # forward_backward_pipelining_with_interleaving
                # Run (num_model_chunks-1)*microbatch_group_size_per_vp_stage on
                # all workers, followed by more microbatches after depending on
                # stage ID (more forward passes for earlier stages, later stages can
                # immediately start with 1F1B).
                num_warmup_microbatches = (num_devices - device_id - 1) * 2
                num_warmup_microbatches += (num_stages_per_device - 1) * microbatch_group_size_per_vp_stage
        else:
            # forward_backward_no_pipelining
            num_warmup_microbatches = 1

        # Clamp: with few microbatches, everything fits in warmup.
        if num_warmup_microbatches >= total_num_microbatches:
            num_warmup_microbatches = total_num_microbatches
            are_all_microbatches_in_warmup = True
        num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches

        return (
            total_num_microbatches,
            are_all_microbatches_in_warmup,
            num_warmup_microbatches,
            num_microbatches_remaining,
        )

    def get_schedule_table(num_microbatches, num_model_chunks, microbatch_group_size_per_vp_stage):
        """Get the schedule table for PP scheduling.

        Create a tunable schedule lookup table.
        The schedule lookup table uses the virtual_microbatch_id to find the corresponding microbatch_id and model_chunk_id.
        For example, the tunable schedule table for PP2 N3M5 with VP2 is constructed as below:
        virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
        microbatch_id         | 0 1 2 0 1 2 3 4 3 4
        model_chunk_id        | 0 0 0 1 1 1 0 0 1 1
        """
        schedule_table = []
        for min_microbatch_id_in_group in range(
            0, num_microbatches, microbatch_group_size_per_vp_stage
        ):
            if min_microbatch_id_in_group + microbatch_group_size_per_vp_stage >= num_microbatches:
                # Construct schedule for the last microbatch group
                schedule_table.extend(
                    [
                        (microbatch_id, model_chunk_id)
                        for model_chunk_id in range(num_model_chunks)
                        for microbatch_id in range(min_microbatch_id_in_group, num_microbatches)
                    ]
                )
            else:
                # Construct schedule for other microbatch groups
                schedule_table.extend(
                    [
                        (microbatch_id, model_chunk_id)
                        for model_chunk_id in range(num_model_chunks)
                        for microbatch_id in range(
                            min_microbatch_id_in_group,
                            min_microbatch_id_in_group + microbatch_group_size_per_vp_stage,
                        )
                    ]
                )
        return schedule_table

    def convert_schedule_table_to_order(num_warmup_microbatches, num_model_chunks, schedule_table):
        """Convert a tunable schedule lookup table to the te.make_graphed_callables() accepted
        order format. For example, the tunable schedule table for PP2 N3M5 with VP2 is as below:
        virtual_microbatch_id | 0 1 2 3 4 5 6 7 8 9
        microbatch_id         | 0 1 2 0 1 2 3 4 3 4
        model_chunk_id        | 0 0 0 1 1 1 0 0 1 1

        Then the forward backward separated order is:
        forward               | 1 1 1 2 2 2 1 1 2 2
        backward              | -2 -2 -2 -1 -1 -1 -2 -2 -1 -1

        If num_warmup_microbatches is 5, the output order is:
        1 1 1 2 2 2 -2 1 -2 1 -2 2 -1 2 -1 -1 -2 -2 -1 -1
        """
        _, model_chunk_id_table = zip(*schedule_table)
        # Positive entries encode forward of chunk k as +k (1-based);
        # negative entries encode backward of chunk k as k - num_model_chunks.
        forward_order = [chunk_id + 1 for chunk_id in model_chunk_id_table]
        backward_order = [chunk_id - num_model_chunks for chunk_id in model_chunk_id_table]
        order = forward_order[:num_warmup_microbatches]
        # Steady state: interleave one forward with one (lagged) backward.
        for i in range(num_warmup_microbatches, len(forward_order)):
            order.append(forward_order[i])
            order.append(backward_order[i - num_warmup_microbatches])
        # Cooldown: the backwards deferred during warmup.
        if num_warmup_microbatches > 0:
            order.extend(backward_order[-num_warmup_microbatches:])
        return order

    for device_id in range(config.num_devices):
        microbatch_group_size_per_vp_stage = config.num_devices
        total_num_microbatches, are_all_microbatches_in_warmup, num_warmup_microbatches, num_microbatches_remaining = get_pp_rank_microbatches(
            config.num_batches,
            config.num_devices,
            device_id,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        schedule_table = get_schedule_table(
            config.num_batches,
            config.num_stages_per_device,
            microbatch_group_size_per_vp_stage,
        )

        order = convert_schedule_table_to_order(
            num_warmup_microbatches,
            num_model_chunks=config.num_stages_per_device,
            schedule_table=schedule_table,
        )

        # Next microbatch id per chunk: key +i tracks forwards of chunk i,
        # key -i tracks backwards of chunk i.
        cur_stage_microbatch_id = {}
        for i in range(1, config.num_stages_per_device+1):
            cur_stage_microbatch_id[i] = 0
            cur_stage_microbatch_id[-i] = 0
        for order_item in order:
            # Map the order entry's chunk index to the actual stage id hosted
            # on this device.
            stage_id = schedule.dev_queues[device_id].stages[abs(order_item)-1]

            if order_item > 0:
                op_type = "forward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
            elif order_item < 0:
                op_type = "backward"
                micro_batch_id = cur_stage_microbatch_id[order_item]
                cur_stage_microbatch_id[order_item] = cur_stage_microbatch_id[order_item] + 1
            else:
                raise ValueError(f"Invalid order item: {order_item}")
            schedule.dev_queues[device_id].add_operation(
                schedule.get_op(micro_batch_id, stage_id, op_type)
            )
    return schedule
|
dash_visualizer.py → src/visualizer.py
RENAMED
|
@@ -1,41 +1,86 @@
|
|
| 1 |
import dash
|
| 2 |
from dash import dcc, html
|
| 3 |
-
from dash.dependencies import Input, Output
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
-
import
|
| 6 |
-
from typing import List, Dict, Literal
|
| 7 |
from tqdm import tqdm
|
| 8 |
-
import
|
| 9 |
|
|
|
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
Create a Plotly figure for pipeline parallelism scheduling.
|
| 14 |
|
| 15 |
Args:
|
| 16 |
-
|
| 17 |
-
Each task is a dictionary with keys:
|
| 18 |
-
- 'type': 'forward', 'backward', or 'optimizer'
|
| 19 |
-
- 'batch': batch number
|
| 20 |
-
- 'start_time': start time of the task
|
| 21 |
-
- 'duration': duration of the task
|
| 22 |
max_time: Optional maximum time to display
|
| 23 |
show_progress: Whether to show a progress bar
|
| 24 |
"""
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
optimizer_color = "#FFEFCF"
|
| 29 |
empty_color = "whitesmoke"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# Find the maximum time in the schedule if not provided
|
| 35 |
if max_time is None:
|
| 36 |
max_time = 0
|
| 37 |
-
for device in
|
| 38 |
-
for task in
|
| 39 |
end_time = task["start_time"] + task["duration"]
|
| 40 |
if end_time > max_time:
|
| 41 |
max_time = end_time
|
|
@@ -44,56 +89,51 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
| 44 |
fig = go.Figure()
|
| 45 |
|
| 46 |
# Initialize progress tracking
|
| 47 |
-
total_tasks = sum(len(tasks) for tasks in
|
| 48 |
tasks_processed = 0
|
| 49 |
|
| 50 |
if show_progress:
|
| 51 |
-
progress_bar = tqdm(total=total_tasks +
|
| 52 |
|
| 53 |
-
#
|
| 54 |
-
for
|
| 55 |
-
device_idx_reversed = num_stages - device_idx - 1 # Reverse for plotting
|
| 56 |
-
fig.add_trace(go.Scatter(
|
| 57 |
-
x=[0, max_time],
|
| 58 |
-
y=[device_idx_reversed, device_idx_reversed],
|
| 59 |
-
mode='lines',
|
| 60 |
-
line=dict(color='lightgray', width=0.5),
|
| 61 |
-
showlegend=False,
|
| 62 |
-
hoverinfo='none'
|
| 63 |
-
))
|
| 64 |
-
if show_progress:
|
| 65 |
-
progress_bar.update(1)
|
| 66 |
|
| 67 |
# Add rectangles for each task
|
| 68 |
-
for device_idx, device in enumerate(
|
| 69 |
-
device_idx_reversed =
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
for task in
|
| 72 |
# Determine task color and text color
|
| 73 |
if task["type"] == "forward":
|
| 74 |
-
color =
|
| 75 |
text_color = "white"
|
| 76 |
name = "Forward"
|
| 77 |
elif task["type"] == "backward":
|
| 78 |
-
color =
|
| 79 |
text_color = "black"
|
| 80 |
name = "Backward"
|
| 81 |
-
else:
|
| 82 |
-
color =
|
| 83 |
text_color = "black"
|
| 84 |
-
name = "
|
| 85 |
-
|
| 86 |
# Add rectangle for the task
|
| 87 |
start_time = task["start_time"]
|
| 88 |
duration = task["duration"]
|
| 89 |
|
|
|
|
|
|
|
|
|
|
| 90 |
# Create rectangle using shape
|
| 91 |
fig.add_shape(
|
| 92 |
type="rect",
|
| 93 |
x0=start_time,
|
| 94 |
-
y0=
|
| 95 |
x1=start_time + duration,
|
| 96 |
-
y1=
|
| 97 |
line=dict(color="black", width=0.5),
|
| 98 |
fillcolor=color,
|
| 99 |
layer="above",
|
|
@@ -102,12 +142,23 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
| 102 |
# Add batch number text
|
| 103 |
fig.add_annotation(
|
| 104 |
x=start_time + duration / 2,
|
| 105 |
-
y=
|
| 106 |
-
text=
|
| 107 |
showarrow=False,
|
| 108 |
-
font=dict(color=text_color, size=
|
| 109 |
)
|
| 110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
# Update progress
|
| 112 |
if show_progress:
|
| 113 |
tasks_processed += 1
|
|
@@ -115,9 +166,8 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
| 115 |
|
| 116 |
# Add custom legend
|
| 117 |
legend_items = [
|
| 118 |
-
dict(name="Forward", color=
|
| 119 |
-
dict(name="Backward", color=
|
| 120 |
-
dict(name="Optimizer step", color=optimizer_color)
|
| 121 |
]
|
| 122 |
|
| 123 |
for i, item in enumerate(legend_items):
|
|
@@ -133,77 +183,98 @@ def create_pipeline_figure(schedule: Dict[int, List[Dict]], max_time=None, show_
|
|
| 133 |
progress_bar.update(1)
|
| 134 |
|
| 135 |
# Set axis properties
|
| 136 |
-
device_labels = [f"Device {i
|
| 137 |
-
device_labels.reverse() # Reverse to put Device
|
|
|
|
|
|
|
|
|
|
| 138 |
|
|
|
|
|
|
|
|
|
|
| 139 |
fig.update_layout(
|
| 140 |
-
xaxis=dict(
|
| 141 |
-
showticklabels=False,
|
| 142 |
-
showgrid=False,
|
| 143 |
-
zeroline=False,
|
| 144 |
-
title="Time →",
|
| 145 |
-
range=[0, max_time + 0.5]
|
| 146 |
-
),
|
| 147 |
yaxis=dict(
|
| 148 |
tickmode="array",
|
| 149 |
-
tickvals=
|
| 150 |
ticktext=device_labels,
|
| 151 |
showgrid=False,
|
| 152 |
zeroline=False,
|
| 153 |
-
range=[-0.5, num_stages - 0.5]
|
| 154 |
),
|
| 155 |
-
margin=dict(l=50, r=
|
| 156 |
plot_bgcolor="white",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
legend=dict(
|
| 158 |
orientation="h",
|
| 159 |
-
yanchor="
|
| 160 |
-
y=-0.
|
| 161 |
xanchor="center",
|
| 162 |
x=0.5
|
| 163 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
)
|
| 165 |
|
| 166 |
if show_progress:
|
| 167 |
-
progress_bar.update(1)
|
| 168 |
progress_bar.close()
|
| 169 |
|
| 170 |
return fig
|
| 171 |
|
| 172 |
|
| 173 |
-
def create_dash_app(schedule:
|
| 174 |
"""
|
| 175 |
-
Create a Dash app
|
| 176 |
-
|
| 177 |
Args:
|
| 178 |
-
schedule:
|
| 179 |
-
schedule_type: Type of
|
| 180 |
"""
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
app.layout = html.Div([
|
| 184 |
-
html.H1(f"Pipeline Parallelism
|
| 185 |
-
style={'textAlign': 'center'}),
|
| 186 |
-
|
| 187 |
-
html.Div(id="loading-container", children=[
|
| 188 |
-
dcc.Loading(
|
| 189 |
-
id="loading-graph",
|
| 190 |
-
type="circle",
|
| 191 |
-
children=[
|
| 192 |
-
html.Div(id="graph-container", children=[
|
| 193 |
-
dcc.Graph(
|
| 194 |
-
id='pipeline-graph',
|
| 195 |
-
style={'height': '600px'}
|
| 196 |
-
)
|
| 197 |
-
])
|
| 198 |
-
]
|
| 199 |
-
)
|
| 200 |
-
]),
|
| 201 |
|
| 202 |
html.Div([
|
| 203 |
-
html.
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
])
|
| 208 |
|
| 209 |
@app.callback(
|
|
@@ -213,98 +284,65 @@ def create_dash_app(schedule: Dict[int, List[Dict]], schedule_type="1f1b"):
|
|
| 213 |
)
|
| 214 |
def load_graph(_):
|
| 215 |
# Create the figure when the app loads
|
| 216 |
-
return create_pipeline_figure(
|
| 217 |
-
|
| 218 |
@app.callback(
|
| 219 |
Output("download-image", "data"),
|
| 220 |
Input("btn-download", "n_clicks"),
|
| 221 |
prevent_initial_call=True,
|
| 222 |
)
|
| 223 |
def download_image(n_clicks):
|
| 224 |
-
#
|
| 225 |
-
fig = create_pipeline_figure(
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
return dict(
|
| 228 |
-
content=
|
| 229 |
-
filename="
|
|
|
|
|
|
|
| 230 |
)
|
| 231 |
|
| 232 |
return app
|
| 233 |
|
| 234 |
|
| 235 |
def visualize_pipeline_parallelism_dash(
|
| 236 |
-
schedule:
|
| 237 |
-
schedule_type: Literal["simple", "1f1b"] = "1f1b",
|
| 238 |
port: int = 8050,
|
| 239 |
debug: bool = False
|
| 240 |
):
|
| 241 |
"""
|
| 242 |
-
|
| 243 |
-
|
| 244 |
Args:
|
| 245 |
-
schedule:
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
debug: Whether to run the app in debug mode
|
| 249 |
"""
|
| 250 |
-
app = create_dash_app(schedule
|
| 251 |
print(f"Starting Dash app on http://localhost:{port}/")
|
| 252 |
app.run_server(debug=debug, port=port)
|
| 253 |
|
| 254 |
|
| 255 |
def save_pipeline_visualization_plotly(
|
| 256 |
-
schedule:
|
| 257 |
-
schedule_type: Literal["simple", "1f1b"] = "1f1b",
|
| 258 |
output_file: str = "pipeline_visualization_plotly.png",
|
| 259 |
):
|
| 260 |
"""
|
| 261 |
-
Save a static
|
| 262 |
-
|
| 263 |
Args:
|
| 264 |
-
schedule:
|
| 265 |
-
|
| 266 |
-
output_file: Path to save the visualization
|
| 267 |
"""
|
| 268 |
-
|
| 269 |
-
fig = create_pipeline_figure(
|
| 270 |
-
|
| 271 |
-
# Update layout for static image
|
| 272 |
-
fig.update_layout(
|
| 273 |
-
title=f"Pipeline Parallelism Visualization ({schedule_type.upper()})",
|
| 274 |
-
title_x=0.5
|
| 275 |
-
)
|
| 276 |
|
| 277 |
-
print(f"Saving
|
| 278 |
-
|
| 279 |
-
fig.write_image(output_file, scale=3)
|
| 280 |
print(f"Visualization saved to {output_file}")
|
| 281 |
|
| 282 |
-
|
| 283 |
-
if __name__ == "__main__":
|
| 284 |
-
# Example usage
|
| 285 |
-
import argparse
|
| 286 |
-
from pipeline import create_1f1b_schedule
|
| 287 |
-
|
| 288 |
-
parser = argparse.ArgumentParser(description="Pipeline Parallelism Visualizer")
|
| 289 |
-
parser.add_argument("--num-stages", type=int, default=4, help="Number of pipeline stages")
|
| 290 |
-
parser.add_argument("--num-batches", type=int, default=8, help="Number of microbatches")
|
| 291 |
-
parser.add_argument("--interactive", action="store_true", help="Run interactive Dash app")
|
| 292 |
-
parser.add_argument("--port", type=int, default=8050, help="Port for Dash app")
|
| 293 |
-
parser.add_argument("--output", type=str, default="pipeline_visualization_plotly.png", help="Output file for static image")
|
| 294 |
-
args = parser.parse_args()
|
| 295 |
-
|
| 296 |
-
# Create an example schedule
|
| 297 |
-
forward_times = [1.0] * args.num_stages
|
| 298 |
-
backward_times = [2.0] * args.num_stages
|
| 299 |
-
|
| 300 |
-
schedule = create_1f1b_schedule(
|
| 301 |
-
num_stages=args.num_stages,
|
| 302 |
-
num_batches=args.num_batches,
|
| 303 |
-
forward_times=forward_times,
|
| 304 |
-
backward_times=backward_times,
|
| 305 |
-
)
|
| 306 |
-
|
| 307 |
-
if args.interactive:
|
| 308 |
-
visualize_pipeline_parallelism_dash(schedule, port=args.port)
|
| 309 |
-
else:
|
| 310 |
-
save_pipeline_visualization_plotly(schedule, output_file=args.output)
|
|
|
|
| 1 |
import dash
|
| 2 |
from dash import dcc, html
|
| 3 |
+
from dash.dependencies import Input, Output
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
+
import argparse
|
| 6 |
+
from typing import List, Dict, Literal, Optional
|
| 7 |
from tqdm import tqdm
|
| 8 |
+
import base64
|
| 9 |
|
| 10 |
+
from src.execution_model import Schedule
|
| 11 |
|
| 12 |
+
|
| 13 |
+
def convert_schedule_to_visualization_format(schedule: Schedule):
    """Convert a Schedule object into the structure the plotting code expects.

    Returns:
        Dict[int, List[Dict]]: mapping of device_id to a list of task dicts,
        each with "type", "batch", "stage", "start_time" and "duration" keys.

    Raises:
        ValueError: if any operation has no timing information yet.
    """
    # Timing must already have been filled in by ScheduleExecutor.execute().
    if any(op.start_time is None or op.end_time is None for op in schedule.ops.values()):
        raise ValueError("Operations must have start and end times. Run ScheduleExecutor.execute() first.")

    # One entry per device, in device-queue order.
    return {
        device_id: [
            {
                "type": op.op_type,
                "batch": op.batch_id + 1,  # batch_id is 0-indexed; display is 1-indexed
                "stage": op.stage_id,
                "start_time": op.start_time,
                "duration": op.end_time - op.start_time,
            }
            for op in device_queue.ops
        ]
        for device_id, device_queue in enumerate(schedule.dev_queues)
    }
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def create_pipeline_figure(schedule_data: Dict[int, List[Dict]], max_time=None, show_progress=True):
|
| 44 |
"""
|
| 45 |
Create a Plotly figure for pipeline parallelism scheduling.
|
| 46 |
|
| 47 |
Args:
|
| 48 |
+
schedule_data: Dictionary mapping device IDs to lists of tasks (converted from Schedule)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
max_time: Optional maximum time to display
|
| 50 |
show_progress: Whether to show a progress bar
|
| 51 |
"""
|
| 52 |
+
# Find the number of devices
|
| 53 |
+
num_devices = len(schedule_data)
|
| 54 |
+
|
|
|
|
| 55 |
empty_color = "whitesmoke"
|
| 56 |
+
# Colors for task types
|
| 57 |
+
def get_color(op_type: str, stage_id: int):
|
| 58 |
+
# Base colors
|
| 59 |
+
forward_base_color = "royalblue"
|
| 60 |
+
backward_base_color = "lightgreen" # Changed from sandybrown to match your visualization
|
| 61 |
+
|
| 62 |
+
virtual_stage = stage_id // num_devices
|
| 63 |
|
| 64 |
+
if op_type == "forward":
|
| 65 |
+
if virtual_stage == 0:
|
| 66 |
+
return forward_base_color
|
| 67 |
+
else:
|
| 68 |
+
# Lighter shade for virtual_stage > 0
|
| 69 |
+
return "lightskyblue"
|
| 70 |
+
elif op_type == "backward":
|
| 71 |
+
if virtual_stage == 0:
|
| 72 |
+
return backward_base_color
|
| 73 |
+
else:
|
| 74 |
+
# Lighter shade for virtual_stage > 0
|
| 75 |
+
return "lightseagreen"
|
| 76 |
+
else:
|
| 77 |
+
raise ValueError(f"Invalid operation type: {op_type}")
|
| 78 |
|
| 79 |
# Find the maximum time in the schedule if not provided
|
| 80 |
if max_time is None:
|
| 81 |
max_time = 0
|
| 82 |
+
for device in schedule_data:
|
| 83 |
+
for task in schedule_data[device]:
|
| 84 |
end_time = task["start_time"] + task["duration"]
|
| 85 |
if end_time > max_time:
|
| 86 |
max_time = end_time
|
|
|
|
| 89 |
fig = go.Figure()
|
| 90 |
|
| 91 |
# Initialize progress tracking
|
| 92 |
+
total_tasks = sum(len(tasks) for tasks in schedule_data.values())
|
| 93 |
tasks_processed = 0
|
| 94 |
|
| 95 |
if show_progress:
|
| 96 |
+
progress_bar = tqdm(total=total_tasks + num_devices + 3, desc="Creating visualization")
|
| 97 |
|
| 98 |
+
# Create a custom y-axis with no gaps between devices
|
| 99 |
+
y_spacing = 1.0 # Use 1.0 for no gaps
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
# Add rectangles for each task
|
| 102 |
+
for device_idx, device in enumerate(schedule_data):
|
| 103 |
+
device_idx_reversed = num_devices - device_idx - 1
|
| 104 |
+
|
| 105 |
+
# Sort tasks by start time to ensure correct rendering
|
| 106 |
+
sorted_tasks = sorted(schedule_data[device], key=lambda t: t["start_time"])
|
| 107 |
|
| 108 |
+
for task in sorted_tasks:
|
| 109 |
# Determine task color and text color
|
| 110 |
if task["type"] == "forward":
|
| 111 |
+
color = get_color(task["type"], task["stage"])
|
| 112 |
text_color = "white"
|
| 113 |
name = "Forward"
|
| 114 |
elif task["type"] == "backward":
|
| 115 |
+
color = get_color(task["type"], task["stage"])
|
| 116 |
text_color = "black"
|
| 117 |
name = "Backward"
|
| 118 |
+
else:
|
| 119 |
+
color = empty_color
|
| 120 |
text_color = "black"
|
| 121 |
+
name = "Unknown"
|
| 122 |
+
|
| 123 |
# Add rectangle for the task
|
| 124 |
start_time = task["start_time"]
|
| 125 |
duration = task["duration"]
|
| 126 |
|
| 127 |
+
# Calculate y positions with no gaps
|
| 128 |
+
y_pos = device_idx_reversed * y_spacing
|
| 129 |
+
|
| 130 |
# Create rectangle using shape
|
| 131 |
fig.add_shape(
|
| 132 |
type="rect",
|
| 133 |
x0=start_time,
|
| 134 |
+
y0=y_pos - 0.5,
|
| 135 |
x1=start_time + duration,
|
| 136 |
+
y1=y_pos + 0.5,
|
| 137 |
line=dict(color="black", width=0.5),
|
| 138 |
fillcolor=color,
|
| 139 |
layer="above",
|
|
|
|
| 142 |
# Add batch number text
|
| 143 |
fig.add_annotation(
|
| 144 |
x=start_time + duration / 2,
|
| 145 |
+
y=y_pos,
|
| 146 |
+
text=f"{task['batch']}", # Only show batch ID
|
| 147 |
showarrow=False,
|
| 148 |
+
font=dict(color=text_color, size=12, family="Arial, bold"), # Increased font size
|
| 149 |
)
|
| 150 |
|
| 151 |
+
# Add hover data with additional details
|
| 152 |
+
fig.add_trace(go.Scatter(
|
| 153 |
+
x=[start_time + duration / 2],
|
| 154 |
+
y=[y_pos],
|
| 155 |
+
mode='markers',
|
| 156 |
+
marker=dict(opacity=0), # Invisible marker
|
| 157 |
+
hoverinfo='text',
|
| 158 |
+
text=f"Batch: {task['batch']}<br>Stage: {task['stage']}<br>Type: {name}<br>Start: {task['start_time']:.2f}<br>End: {task['start_time'] + task['duration']:.2f}<br>Duration: {task['duration']:.2f}",
|
| 159 |
+
showlegend=False
|
| 160 |
+
))
|
| 161 |
+
|
| 162 |
# Update progress
|
| 163 |
if show_progress:
|
| 164 |
tasks_processed += 1
|
|
|
|
| 166 |
|
| 167 |
# Add custom legend
|
| 168 |
legend_items = [
|
| 169 |
+
dict(name="Forward", color=get_color("forward", 0)),
|
| 170 |
+
dict(name="Backward", color=get_color("backward", 0)),
|
|
|
|
| 171 |
]
|
| 172 |
|
| 173 |
for i, item in enumerate(legend_items):
|
|
|
|
| 183 |
progress_bar.update(1)
|
| 184 |
|
| 185 |
# Set axis properties
|
| 186 |
+
device_labels = [f"Device {i}" for i in range(num_devices)]
|
| 187 |
+
device_labels.reverse() # Reverse to put Device 0 at the top
|
| 188 |
+
|
| 189 |
+
# Calculate tick positions with no gaps
|
| 190 |
+
tick_positions = [(num_devices - i - 1) * y_spacing for i in range(num_devices)]
|
| 191 |
|
| 192 |
+
# Adjust the range to ensure there are no empty spaces at the end
|
| 193 |
+
x_end = max_time * 1.05 # Add a small margin
|
| 194 |
+
|
| 195 |
fig.update_layout(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
yaxis=dict(
|
| 197 |
tickmode="array",
|
| 198 |
+
tickvals=tick_positions,
|
| 199 |
ticktext=device_labels,
|
| 200 |
showgrid=False,
|
| 201 |
zeroline=False,
|
|
|
|
| 202 |
),
|
| 203 |
+
margin=dict(l=50, r=20, t=40, b=40),
|
| 204 |
plot_bgcolor="white",
|
| 205 |
+
title=dict(
|
| 206 |
+
text="Pipeline Parallelism Schedule",
|
| 207 |
+
x=0.5,
|
| 208 |
+
y=0.98, # Move title position closer to the top
|
| 209 |
+
font=dict(size=20)
|
| 210 |
+
),
|
| 211 |
legend=dict(
|
| 212 |
orientation="h",
|
| 213 |
+
yanchor="top",
|
| 214 |
+
y=-0.1, # Position below the plot
|
| 215 |
xanchor="center",
|
| 216 |
x=0.5
|
| 217 |
+
),
|
| 218 |
+
width=1600,
|
| 219 |
+
height=400, # Reduce height to make the visualization more compact
|
| 220 |
+
bargap=0,
|
| 221 |
+
bargroupgap=0,
|
| 222 |
)
|
| 223 |
|
| 224 |
if show_progress:
|
| 225 |
+
progress_bar.update(1)
|
| 226 |
progress_bar.close()
|
| 227 |
|
| 228 |
return fig
|
| 229 |
|
| 230 |
|
| 231 |
+
def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
|
| 232 |
"""
|
| 233 |
+
Create a Dash app to visualize the pipeline schedule.
|
| 234 |
+
|
| 235 |
Args:
|
| 236 |
+
schedule: Schedule object to visualize
|
| 237 |
+
schedule_type: Type of schedule ("1f1b" or other)
|
| 238 |
"""
|
| 239 |
+
# Convert schedule to visualization format
|
| 240 |
+
schedule_data = convert_schedule_to_visualization_format(schedule)
|
| 241 |
+
|
| 242 |
+
# Create the app
|
| 243 |
+
app = dash.Dash(__name__, title=f"Pipeline Parallelism Visualizer - {schedule_type}")
|
| 244 |
|
| 245 |
app.layout = html.Div([
|
| 246 |
+
html.H1(f"Pipeline Parallelism Visualizer - {schedule_type}", style={'textAlign': 'center'}),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
html.Div([
|
| 249 |
+
html.Div([
|
| 250 |
+
html.H3("Schedule Configuration:"),
|
| 251 |
+
html.Ul([
|
| 252 |
+
html.Li(f"Number of devices: {schedule.config.num_devices}"),
|
| 253 |
+
html.Li(f"Number of stages: {schedule.config.num_stages}"),
|
| 254 |
+
html.Li(f"Number of batches: {schedule.config.num_batches}"),
|
| 255 |
+
]),
|
| 256 |
+
], className="config-section"),
|
| 257 |
+
|
| 258 |
+
html.Button("Download Image", id="btn-download",
|
| 259 |
+
style={
|
| 260 |
+
'marginTop': '20px',
|
| 261 |
+
'padding': '10px',
|
| 262 |
+
'backgroundColor': '#007BFF',
|
| 263 |
+
'color': 'white',
|
| 264 |
+
'border': 'none',
|
| 265 |
+
'borderRadius': '5px',
|
| 266 |
+
'cursor': 'pointer'
|
| 267 |
+
}),
|
| 268 |
+
|
| 269 |
+
dcc.Download(id="download-image"),
|
| 270 |
+
], style={'margin': '20px'}),
|
| 271 |
+
|
| 272 |
+
html.Div(id="graph-container", children=[]),
|
| 273 |
+
|
| 274 |
+
dcc.Graph(
|
| 275 |
+
id="pipeline-graph",
|
| 276 |
+
config={'displayModeBar': True, 'toImageButtonOptions': {'format': 'png', 'filename': 'pipeline_visualization'}}
|
| 277 |
+
),
|
| 278 |
])
|
| 279 |
|
| 280 |
@app.callback(
|
|
|
|
| 284 |
)
|
| 285 |
def load_graph(_):
|
| 286 |
# Create the figure when the app loads
|
| 287 |
+
return create_pipeline_figure(schedule_data, show_progress=True)
|
| 288 |
+
|
| 289 |
@app.callback(
|
| 290 |
Output("download-image", "data"),
|
| 291 |
Input("btn-download", "n_clicks"),
|
| 292 |
prevent_initial_call=True,
|
| 293 |
)
|
| 294 |
def download_image(n_clicks):
|
| 295 |
+
# Generate the figure for download
|
| 296 |
+
fig = create_pipeline_figure(schedule_data, show_progress=True)
|
| 297 |
+
|
| 298 |
+
# Convert to base64 image
|
| 299 |
+
img_bytes = fig.to_image(format="png", width=1600, height=1000, scale=2)
|
| 300 |
+
img_base64 = base64.b64encode(img_bytes).decode('ascii')
|
| 301 |
+
|
| 302 |
+
# Return the download data
|
| 303 |
return dict(
|
| 304 |
+
content=img_base64,
|
| 305 |
+
filename=f"pipeline_visualization_{schedule_type}.png",
|
| 306 |
+
type="image/png",
|
| 307 |
+
base64=True
|
| 308 |
)
|
| 309 |
|
| 310 |
return app
|
| 311 |
|
| 312 |
|
| 313 |
def visualize_pipeline_parallelism_dash(
    schedule: Schedule,
    port: int = 8050,
    debug: bool = False
):
    """Serve an interactive Dash visualization of the pipeline schedule.

    Args:
        schedule: Schedule object whose timed operations will be drawn.
        port: TCP port the local Dash server listens on.
        debug: Forwarded to Dash's development-server debug mode.
    """
    dash_app = create_dash_app(schedule)
    # Announce the URL before blocking in the server loop.
    print(f"Starting Dash app on http://localhost:{port}/")
    dash_app.run_server(debug=debug, port=port)
|
| 329 |
|
| 330 |
|
| 331 |
def save_pipeline_visualization_plotly(
    schedule: Schedule,
    output_file: str = "pipeline_visualization_plotly.png",
):
    """Render the pipeline schedule to a static image file via Plotly.

    Args:
        schedule: Schedule object whose timed operations will be drawn.
        output_file: Destination path for the rendered image.
    """
    viz_data = convert_schedule_to_visualization_format(schedule)
    figure = create_pipeline_figure(viz_data, show_progress=True)

    print(f"Saving visualization to {output_file}...")
    # 1600x400 matches the interactive figure's layout dimensions; scale=2 for resolution.
    figure.write_image(output_file, width=1600, height=400, scale=2)
    print(f"Visualization saved to {output_file}")
|
| 348 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
visualizer.py
DELETED
|
@@ -1,141 +0,0 @@
|
|
| 1 |
-
import matplotlib.pyplot as plt
|
| 2 |
-
import numpy as np
|
| 3 |
-
from matplotlib.patches import Rectangle
|
| 4 |
-
from typing import List, Dict, Literal
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
def visualize_pipeline_parallelism(
    schedule: Dict[int, List[Dict]],
    schedule_type: Literal["simple", "1f1b"] = "1f1b",
    output_file: str = "pipeline_visualization.png",
):
    """
    Visualize pipeline parallelism scheduling as a static matplotlib chart.

    Draws one row per device (Device 1 on top) with colored rectangles for
    each task, the batch number centered in each rectangle, and a
    Forward/Backward/Optimizer legend below the plot. The figure is written
    to ``output_file`` and the matplotlib figure is closed afterwards.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks.
            Each task is a dictionary with keys:
            - 'type': 'forward', 'backward', or 'optimizer'
            - 'batch': batch number
            - 'start_time': start time of the task
            - 'duration': duration of the task
        schedule_type: Type of scheduling algorithm used ("simple" or "1f1b").
            NOTE(review): currently unused in the body — kept for interface
            compatibility with callers; confirm before removing.
        output_file: Path to save the visualization
    """
    # Colors for task types
    forward_color = "royalblue"
    backward_color = "sandybrown"  # Changed to match the reference image
    optimizer_color = "#FFEFCF"  # Light beige for optimizer steps
    empty_color = "whitesmoke"  # Very light gray for empty cells

    # Find the number of stages (devices)
    num_stages = len(schedule)

    # Find the maximum time in the schedule
    max_time = 0
    for device in schedule:
        for task in schedule[device]:
            end_time = task["start_time"] + task["duration"]
            if end_time > max_time:
                max_time = end_time

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(15, 4))

    # Create an empty grid with light gray color.
    # Drawn first so task rectangles added later paint over it (patch z-order
    # follows insertion order here).
    for device_idx in range(num_stages):
        device_idx_reversed = num_stages - device_idx - 1  # Reverse the device index for plotting
        for t in range(int(max_time) + 1):
            rect = Rectangle(
                (t, device_idx_reversed),
                1.0,
                1.0,
                edgecolor="lightgray",
                facecolor=empty_color,
                linewidth=0.5,
            )
            ax.add_patch(rect)

    # Plot the schedule: one filled rectangle plus a batch-number label per task.
    for device_idx, device in enumerate(schedule):
        device_idx_reversed = num_stages - device_idx - 1  # Reverse the device index for plotting
        for task in schedule[device]:
            # Determine task color; anything that is not forward/backward
            # (e.g. 'optimizer') falls through to the optimizer color.
            if task["type"] == "forward":
                color = forward_color
                text_color = "white"
            elif task["type"] == "backward":
                color = backward_color
                text_color = "black"
            else:  # optimizer or any other type
                color = optimizer_color
                text_color = "black"

            rect = Rectangle(
                (task["start_time"], device_idx_reversed),
                task["duration"],
                1.0,
                edgecolor="black",
                facecolor=color,
                linewidth=0.5,
            )
            ax.add_patch(rect)

            # Add text (batch number) centered inside the task rectangle
            ax.text(
                task["start_time"] + task["duration"] / 2,
                device_idx_reversed + 0.5,
                str(task["batch"]),
                ha="center",
                va="center",
                fontsize=10,
                fontweight="bold",
                color=text_color,
            )

    # Set axis limits and labels
    ax.set_xlim(0, max_time + 0.5)
    ax.set_ylim(-0.5, num_stages + 0.5)
    ax.set_yticks(np.arange(num_stages) + 0.5)

    # Reverse the order: Device 1 at the top, highest number at the bottom
    device_labels = [f"Device {i+1}" for i in range(num_stages)]
    device_labels.reverse()  # Reverse to put Device 1 at the top
    ax.set_yticklabels(device_labels)

    # Add "Time" label and arrow at the bottom (drawn just below the axis)
    arrow_y = -0.4
    ax.text(0.5, arrow_y, "Time", ha="right", va="center", fontsize=10)
    ax.annotate("", xy=(2, arrow_y), xytext=(1, arrow_y),
                arrowprops=dict(arrowstyle="->", lw=1))

    # Remove the x-axis ticks
    ax.set_xticks([])

    # Remove the outer frame/border
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Add a legend - using 3 parts like in the reference image.
    # Proxy Rectangle patches stand in for the plotted task rectangles.
    forward_patch = Rectangle((0, 0), 1, 1, facecolor=forward_color)
    backward_patch = Rectangle((0, 0), 1, 1, facecolor=backward_color)
    optimizer_patch = Rectangle((0, 0), 1, 1, facecolor=optimizer_color)

    # NOTE(review): the returned legend handle is unused; assignment kept as-is.
    legend = ax.legend(
        [forward_patch, backward_patch, optimizer_patch],
        ["Forward", "Backward", "Optimizer step"],
        loc="upper center",
        bbox_to_anchor=(0.5, -0.15),
        ncol=3,
        frameon=False,
    )

    # Turn off grid
    ax.grid(False)

    # Save the figure and release its memory
    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"Visualization saved to {output_file}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|