Spaces:
Sleeping
Sleeping
| import dash | |
| import dash_bootstrap_components as dbc | |
| from dash import dcc, html, Input, Output, State, callback_context | |
| import plotly.graph_objects as go | |
| from src.execution_model import ScheduleConfig, Schedule | |
| from src.strategies import ( | |
| generate_1f1b_schedule, | |
| generate_zero_bubble_1p_schedule, | |
| generate_1f1b_overlap_schedule, | |
| generate_1f1b_interleave_schedule, | |
| generate_1f1b_interleave_overlap_schedule, | |
| generate_dualpipe_schedule | |
| ) | |
| from src.visualizer import convert_schedule_to_visualization_format, create_pipeline_figure | |
# Registry of the strategy keys offered in the UI checklist, each mapped to the
# generator that builds a schedule from a ScheduleConfig. Insertion order is the
# order the options appear in the checklist.
STRATEGIES = dict(
    [
        ("1f1b", generate_1f1b_schedule),
        ("zb1p", generate_zero_bubble_1p_schedule),
        ("1f1b_overlap", generate_1f1b_overlap_schedule),
        ("1f1b_interleave", generate_1f1b_interleave_schedule),
        ("1f1b_interleave_overlap", generate_1f1b_interleave_overlap_schedule),
        ("dualpipe", generate_dualpipe_schedule),
    ]
)
# Dash application with the Bootstrap theme. Callback exceptions are suppressed so
# callbacks may reference components that are not in the initial layout.
app = dash.Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    suppress_callback_exceptions=True,
)
app.title = "Pipeline Parallelism Schedule Visualizer"
# Initial values pre-filled into the input widgets (times in ms, as labelled in the UI).
# A value of None leaves the optional overlapped-time field empty.
default_values = dict(
    num_devices=4,
    num_stages=8,
    num_batches=16,
    p2p_latency=0.0,
    op_time_forward=1.0,
    op_time_backward_d=1.0,
    op_time_backward_w=1.0,
    op_time_backward=2.0,
    strategy="1f1b_interleave",
    op_time_overlapped_fwd_bwd=None,
)
# Define input groups using dbc components.
def _number_field(label, input_id, hint=None, **input_kwargs):
    """Return a labelled numeric ``dbc.Input`` wrapped in an ``mb-3`` spaced ``html.Div``.

    ``hint``, when given, is rendered as a ``dbc.FormText`` below the input.
    All other keyword arguments (value, min, step, placeholder, ...) are
    forwarded unchanged to ``dbc.Input``.
    """
    children = [
        dbc.Label(label),
        dbc.Input(id=input_id, type='number', **input_kwargs),
    ]
    if hint is not None:
        children.append(dbc.FormText(hint))
    return html.Div(children, className="mb-3")


# Card 1: workload size (devices, stages, microbatches) and P2P latency.
basic_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Basic Parameters", className="card-title"),
        _number_field("Number of Devices (GPUs):", 'num_devices',
                      value=default_values["num_devices"], min=1, step=1),
        _number_field("Number of Stages (Model Chunks):", 'num_stages',
                      value=default_values["num_stages"], min=1, step=1),
        _number_field("Number of Microbatches:", 'num_batches',
                      value=default_values["num_batches"], min=1, step=1),
        _number_field("P2P Latency (ms):", 'p2p_latency',
                      value=default_values["p2p_latency"], min=0, step=0.01),
    ])
)

# Card 2: which strategies to generate; every registered strategy starts selected.
scheduling_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Scheduling Parameters", className="card-title"),
        html.Div(
            [
                dbc.Label("Scheduling Strategies:"),
                dbc.Checklist(
                    id='strategy-checklist',
                    options=[{'label': name, 'value': name} for name in STRATEGIES],
                    value=list(STRATEGIES),
                    inline=False,
                ),
            ],
            className="mb-3",
        ),
    ])
)

# Card 3: per-operation durations in ms. The combined backward time and the split
# D/W times are alternatives, depending on whether a strategy splits backward.
_split_backward_hint = "Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."
timing_params_card = dbc.Card(
    dbc.CardBody([
        html.H5("Operation Timing (ms)", className="card-title"),
        _number_field("Forward:", 'op_time_forward',
                      value=default_values["op_time_forward"], min=0.01, step=0.01),
        _number_field("Backward (Combined):", 'op_time_backward',
                      hint="Used when strategy does NOT require split backward.",
                      value=default_values["op_time_backward"], min=0.01, step=0.01),
        _number_field("Backward D (Data Grad):", 'op_time_backward_d',
                      hint=_split_backward_hint,
                      value=default_values["op_time_backward_d"], min=0.01, step=0.01),
        _number_field("Backward W (Weight Grad):", 'op_time_backward_w',
                      hint=_split_backward_hint,
                      value=default_values["op_time_backward_w"], min=0.01, step=0.01),
        _number_field("Overlapped Forward+Backward:", 'op_time_overlapped_fwd_bwd',
                      hint="Specify a custom duration if Forward and Backward ops overlap completely.",
                      placeholder="Optional: Defaults to Fwd + Bwd times",
                      min=0.01, step=0.01,
                      value=default_values["op_time_overlapped_fwd_bwd"]),
    ])
)
# Page layout, top to bottom: title; the three parameter cards side by side;
# the "Generate Schedule" button; and a loading-spinner-wrapped output
# container (id 'graph-output-container') for the generated schedules.
_page_title = html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center")

_parameter_row = dbc.Row([
    dbc.Col(card, md=4)
    for card in (basic_params_card, scheduling_params_card, timing_params_card)
])

_button_row = dbc.Row([
    dbc.Col(
        [dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4")],
        className="text-center",
    )
])

_output_row = dbc.Row([
    dbc.Col([
        dcc.Loading(
            id="loading-graph-area",
            type="circle",
            children=html.Div(id='graph-output-container', className="mt-4"),
        )
    ])
])

app.layout = dbc.Container([_page_title, _parameter_row, _button_row, _output_row], fluid=True)
# Register the callback. Without this decorator the "Generate Schedule" button
# is never wired to this function and no schedule is ever rendered. Dash also
# fires it once on page load (no prevent_initial_call), so the defaults render
# immediately.
@app.callback(
    Output('graph-output-container', 'children'),
    Input('generate-button', 'n_clicks'),
    State('num_devices', 'value'),
    State('num_stages', 'value'),
    State('num_batches', 'value'),
    State('p2p_latency', 'value'),
    State('op_time_forward', 'value'),
    State('op_time_backward', 'value'),
    State('op_time_backward_d', 'value'),
    State('op_time_backward_w', 'value'),
    State('op_time_overlapped_fwd_bwd', 'value'),
    State('strategy-checklist', 'value'),
)
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
                 op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
                 op_time_overlapped_fwd_bwd,
                 selected_strategies):
    """Generate and render a pipeline schedule for every selected strategy.

    ``n_clicks`` only triggers the callback. The other arguments are the
    current values of the parameter inputs and the strategy checklist.

    For each strategy the parameters are first normalised:

    - The stage count is forced to equal the device count for strategies that
      need a 1:1 mapping.
    - Op times are scaled by 1/stages-per-device.

    A ``ScheduleConfig`` is then built and the strategy's generator from
    ``STRATEGIES`` is run and executed. All successful schedules are drawn
    with the same x-axis range so their lengths are directly comparable.

    Returns:
        A list of Dash components, in this order:

        - info alerts for automatic parameter adjustments;
        - one titled graph per successful schedule, in a fixed display order;
        - one danger alert per failed strategy.

        When the inputs are unusable, a single-element list holding a
        warning/danger alert is returned instead.
    """
    import traceback  # used only to log schedule-generation failures to the server console

    # Top-to-bottom display order; strategies not listed go last, in selection order.
    strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
    # Strategies using "standard" placement, which require num_stages == num_devices.
    standard_strategies = ("1f1b", "1f1b_overlap", "zb1p")
    interleave_strategies = ("1f1b_interleave", "1f1b_interleave_overlap")
    # Strategies that model backward as separate D (data-grad) and W (weight-grad) ops.
    split_backward_strategies = ("zb1p", "dualpipe")

    output_components = []
    valid_results = []  # (strategy_name, schedule, vis_data) for schedules that built successfully
    error_messages = []  # (strategy_name, error_message) for strategies that failed
    automatic_adjustments = []  # notes about parameters changed on the user's behalf

    if not selected_strategies:
        return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
    if not all([num_devices, num_stages, num_batches, op_time_forward]):
        return [dbc.Alert("Missing required basic input values (Devices, Stages, Batches, Forward Time).", color="danger")]
    # 0 is a valid latency, so it cannot go through the truthiness check above.
    # A cleared field arrives as None and would otherwise make every strategy
    # fail with an opaque float() TypeError.
    if p2p_latency is None:
        return [dbc.Alert("Missing required P2P Latency value (use 0 for no latency).", color="danger")]

    for strategy in selected_strategies:
        error_message = ""
        placement_strategy = ""
        # Per-strategy copies, because some strategies adjust them.
        current_num_stages = num_stages
        current_num_devices = num_devices

        # Standard-placement strategies and DualPipe need exactly one stage per device.
        if (strategy in standard_strategies or strategy == "dualpipe") and num_stages != num_devices:
            current_num_stages = num_devices
            automatic_adjustments.append(
                f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
            )

        split_backward = strategy in split_backward_strategies
        if split_backward and not all([op_time_backward_d, op_time_backward_w]):
            error_message = f"Strategy '{strategy}': Backward D and Backward W times are required."
        elif not split_backward and not op_time_backward:
            error_message = f"Strategy '{strategy}': Combined Backward time is required."

        if not error_message:
            if strategy in standard_strategies:
                placement_strategy = "standard"  # stage count already forced equal to device count
            elif strategy in interleave_strategies:
                placement_strategy = "interleave"
                if current_num_stages % current_num_devices != 0:
                    error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
            elif strategy == "dualpipe":
                placement_strategy = "dualpipe"
                if current_num_stages % 2 != 0:
                    error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."

        if not error_message:
            try:
                # Normalise op times by the number of stages (model chunks) each device hosts:
                # the factor is 1.0 for one stage per device and 1/k for k stages per device.
                stages_per_device = current_num_stages // current_num_devices
                time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
                if stages_per_device > 1:
                    automatic_adjustments.append(
                        f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
                    )

                op_times = {"forward": float(op_time_forward) * time_scale_factor}
                if split_backward:
                    op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
                    op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
                    # The combined value is kept as well, for compatibility.
                    op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
                else:
                    op_times["backward"] = float(op_time_backward) * time_scale_factor

                if op_time_overlapped_fwd_bwd is not None:
                    # Optional field: ignore unparsable or non-positive values (best effort).
                    try:
                        overlapped_val = float(op_time_overlapped_fwd_bwd)
                        if overlapped_val > 0:
                            op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
                    except (ValueError, TypeError):
                        pass

                config = ScheduleConfig(
                    num_devices=int(current_num_devices),
                    num_stages=int(current_num_stages),  # possibly adjusted above
                    num_batches=int(num_batches),
                    p2p_latency=float(p2p_latency),
                    placement_strategy=placement_strategy,
                    split_backward=split_backward,
                    op_times=op_times,
                )

                schedule_func = STRATEGIES.get(strategy)
                if not schedule_func:
                    raise ValueError(f"Invalid strategy function for: {strategy}")
                schedule = schedule_func(config)
                schedule.execute()

                # Defer figure creation until the global maximum time is known.
                vis_data = convert_schedule_to_visualization_format(schedule)
                valid_results.append((strategy, schedule, vis_data))
            except (AssertionError, ValueError, TypeError) as e:
                error_message = f"Error generating schedule for '{strategy}': {e}"
                traceback.print_exc()
            except Exception as e:  # UI boundary: report the failure instead of breaking the callback
                error_message = f"An unexpected error occurred for '{strategy}': {e}"
                traceback.print_exc()

        if error_message:
            error_messages.append((strategy, error_message))

    for adjustment in automatic_adjustments:
        output_components.append(dbc.Alert(adjustment, color="info", dismissable=True))

    if valid_results:
        # A common time range across all figures keeps schedule lengths visually comparable.
        max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)

        # Stable sort: listed strategies first in display order, unlisted ones after,
        # keeping their original relative order.
        display_rank = {name: rank for rank, name in enumerate(strategy_display_order)}
        sorted_valid_results = sorted(
            valid_results, key=lambda result: display_rank.get(result[0], len(display_rank))
        )

        margin = max_execution_time * 0.05  # 5% headroom on the right edge
        for strategy, _, vis_data in sorted_valid_results:
            fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
            fig.update_layout(xaxis=dict(range=[0, max_execution_time + margin]))
            output_components.append(html.Div([
                html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
                dcc.Graph(figure=fig),
            ]))

    for strategy, msg in error_messages:
        output_components.append(dbc.Alert(msg, color="danger", className="mt-3"))

    return output_components
# For Hugging Face Spaces deployment: expose the underlying server object so a
# WSGI host can serve the app.
server = app.server

if __name__ == '__main__':
    # `Dash.run_server` was deprecated and then removed in Dash 3.0 in favour of
    # `Dash.run`. Prefer `run`, and fall back to `run_server` on older Dash
    # releases that lack it.
    launch = getattr(app, "run", None) or app.run_server
    launch(debug=False, host='0.0.0.0', port=7860)