Spaces:
Running
Running
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.patches import Rectangle | |
| from typing import List, Dict, Literal | |
def visualize_pipeline_parallelism(
    schedule: Dict[int, List[Dict]],
    schedule_type: Literal["simple", "1f1b"] = "1f1b",
    output_file: str = "pipeline_visualization.png",
):
    """
    Draw a pipeline-parallelism schedule as a device-by-time grid and save it.

    Each device occupies one horizontal row; each task is drawn as a colored
    rectangle spanning ``[start_time, start_time + duration]`` with its batch
    number printed in the middle. Rows are assigned in the iteration order of
    ``schedule`` (first key at the top) and are labelled positionally as
    "Device 1", "Device 2", ... regardless of the actual key values.

    Args:
        schedule: Dictionary mapping device IDs to lists of tasks.
            Each task is a dictionary with keys:
            - 'type': 'forward', 'backward', or 'optimizer' (any value other
              than 'forward'/'backward' is drawn with the optimizer color)
            - 'batch': batch number (rendered with ``str()``)
            - 'start_time': start time of the task
            - 'duration': duration of the task
        schedule_type: Type of scheduling algorithm used ("simple" or "1f1b").
            Currently not read by this function; kept for interface
            compatibility.
        output_file: Path to save the visualization.

    Returns:
        None. The image is written to ``output_file`` and a confirmation
        message is printed.
    """
    # Colors for task types
    forward_color = "royalblue"
    backward_color = "sandybrown"
    optimizer_color = "#FFEFCF"  # Light beige for optimizer steps
    empty_color = "whitesmoke"  # Very light gray for empty cells

    # task type -> (face color, label color); unknown types use the fallback.
    task_styles = {
        "forward": (forward_color, "white"),
        "backward": (backward_color, "black"),
    }
    fallback_style = (optimizer_color, "black")

    # Number of stages (devices) == number of rows in the plot.
    num_stages = len(schedule)

    # Latest end time over all tasks, floored at 0 so an empty schedule (or
    # one with no tasks) still yields a valid axis range.
    max_time = max(
        [0]
        + [
            task["start_time"] + task["duration"]
            for device in schedule
            for task in schedule[device]
        ]
    )

    fig, ax = plt.subplots(figsize=(15, 4))
    try:
        # Background grid: one unit cell per (row, integer time step).
        # Rows are visited top-down (highest y first).
        for row in reversed(range(num_stages)):
            for t in range(int(max_time) + 1):
                ax.add_patch(
                    Rectangle(
                        (t, row),
                        1.0,
                        1.0,
                        edgecolor="lightgray",
                        facecolor=empty_color,
                        linewidth=0.5,
                    )
                )

        # Tasks, drawn on top of the grid.
        for device_idx, device in enumerate(schedule):
            # First device in the schedule goes to the top row.
            row = num_stages - device_idx - 1
            for task in schedule[device]:
                color, text_color = task_styles.get(task["type"], fallback_style)
                ax.add_patch(
                    Rectangle(
                        (task["start_time"], row),
                        task["duration"],
                        1.0,
                        edgecolor="black",
                        facecolor=color,
                        linewidth=0.5,
                    )
                )
                # Batch number centered inside the task rectangle.
                ax.text(
                    task["start_time"] + task["duration"] / 2,
                    row + 0.5,
                    str(task["batch"]),
                    ha="center",
                    va="center",
                    fontsize=10,
                    fontweight="bold",
                    color=text_color,
                )

        # Axis limits; the extra 0.5 below the rows leaves room for the
        # "Time" arrow.
        ax.set_xlim(0, max_time + 0.5)
        ax.set_ylim(-0.5, num_stages + 0.5)

        # One tick per row, centered; labels reversed so "Device 1" is on top.
        ax.set_yticks(np.arange(num_stages) + 0.5)
        ax.set_yticklabels([f"Device {i+1}" for i in reversed(range(num_stages))])

        # "Time" label and arrow at the bottom
        arrow_y = -0.4
        ax.text(0.5, arrow_y, "Time", ha="right", va="center", fontsize=10)
        ax.annotate(
            "",
            xy=(2, arrow_y),
            xytext=(1, arrow_y),
            arrowprops=dict(arrowstyle="->", lw=1),
        )

        # No x ticks, no outer frame, no grid lines.
        ax.set_xticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
        ax.grid(False)

        # Three-entry legend placed below the axes. The proxy patches are
        # never added to the axes; they only supply the legend swatches.
        ax.legend(
            [
                Rectangle((0, 0), 1, 1, facecolor=forward_color),
                Rectangle((0, 0), 1, 1, facecolor=backward_color),
                Rectangle((0, 0), 1, 1, facecolor=optimizer_color),
            ],
            ["Forward", "Backward", "Optimizer step"],
            loc="upper center",
            bbox_to_anchor=(0.5, -0.15),
            ncol=3,
            frameon=False,
        )

        # Operate on this figure explicitly instead of pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.savefig(output_file, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving failed, so it
        # does not linger in pyplot's global figure registry.
        plt.close(fig)

    print(f"Visualization saved to {output_file}")