Spaces:
Running
Running
Update UI.
Browse files- app.py +355 -144
- assets/clientside.js +62 -0
- assets/custom.css +129 -0
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import dash
|
| 2 |
import dash_bootstrap_components as dbc
|
| 3 |
-
from dash import dcc, html, Input, Output, State, callback_context
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
|
| 6 |
from src.execution_model import ScheduleConfig, Schedule
|
|
@@ -23,7 +23,7 @@ STRATEGIES = {
|
|
| 23 |
"dualpipe": generate_dualpipe_schedule,
|
| 24 |
}
|
| 25 |
|
| 26 |
-
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP], suppress_callback_exceptions=True)
|
| 27 |
app.title = "Pipeline Parallelism Schedule Visualizer"
|
| 28 |
|
| 29 |
# Initial default values
|
|
@@ -36,107 +36,321 @@ default_values = {
|
|
| 36 |
"op_time_backward_d": 1.0,
|
| 37 |
"op_time_backward_w": 1.0,
|
| 38 |
"op_time_backward": 2.0,
|
| 39 |
-
"strategy": "1f1b_interleave",
|
| 40 |
"op_time_overlapped_fwd_bwd": None,
|
| 41 |
}
|
| 42 |
|
| 43 |
# Define input groups using dbc components
|
|
|
|
|
|
|
| 44 |
basic_params_card = dbc.Card(
|
| 45 |
dbc.CardBody([
|
| 46 |
-
html.H5("Basic Parameters", className="card-title"),
|
| 47 |
-
html.Div([
|
| 48 |
-
dbc.Label("Number of Devices (GPUs):"),
|
| 49 |
-
dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1),
|
| 50 |
-
], className="mb-3"),
|
| 51 |
html.Div([
|
| 52 |
-
dbc.Label("Number of
|
| 53 |
-
dbc.Input(id='
|
|
|
|
| 54 |
], className="mb-3"),
|
| 55 |
html.Div([
|
| 56 |
-
dbc.Label("Number of
|
| 57 |
-
dbc.Input(id='
|
|
|
|
| 58 |
], className="mb-3"),
|
| 59 |
html.Div([
|
| 60 |
-
dbc.Label("
|
| 61 |
-
dbc.Input(id='
|
|
|
|
| 62 |
], className="mb-3"),
|
| 63 |
-
])
|
|
|
|
| 64 |
)
|
| 65 |
|
| 66 |
scheduling_params_card = dbc.Card(
|
| 67 |
dbc.CardBody([
|
| 68 |
-
html.H5("Scheduling
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
timing_params_card = dbc.Card(
|
| 82 |
dbc.CardBody([
|
| 83 |
-
html.H5("Operation Timing (ms)", className="card-title"),
|
| 84 |
-
html.Div([
|
| 85 |
-
dbc.Label("Forward:"),
|
| 86 |
-
dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01),
|
| 87 |
-
], className="mb-3"),
|
| 88 |
html.Div([
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
dbc.
|
| 95 |
-
dbc.
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
| 97 |
], className="mb-3"),
|
| 98 |
html.Div([
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
], className="mb-3"),
|
| 103 |
html.Div([
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
], className="mb-3"),
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
)
|
| 110 |
|
| 111 |
# Updated app layout using dbc components and structure
|
| 112 |
app.layout = dbc.Container([
|
| 113 |
html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
|
| 114 |
|
|
|
|
| 115 |
dbc.Row([
|
| 116 |
-
|
| 117 |
-
dbc.Col(scheduling_params_card, md=4),
|
| 118 |
-
dbc.Col(timing_params_card, md=4),
|
| 119 |
-
]),
|
| 120 |
-
|
| 121 |
-
dbc.Row([
|
| 122 |
-
dbc.Col([
|
| 123 |
-
dbc.Button('Generate Schedule', id='generate-button', n_clicks=0, color="primary", className="mt-4"),
|
| 124 |
-
], className="text-center")
|
| 125 |
-
]),
|
| 126 |
-
|
| 127 |
-
dbc.Row([
|
| 128 |
dbc.Col([
|
|
|
|
| 129 |
dcc.Loading(
|
| 130 |
id="loading-graph-area",
|
| 131 |
type="circle",
|
| 132 |
-
children=html.Div(id='graph-output-container',
|
| 133 |
)
|
| 134 |
-
])
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
@app.callback(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
Output('graph-output-container', 'children'),
|
|
|
|
| 140 |
Input('generate-button', 'n_clicks'),
|
| 141 |
State('num_devices', 'value'),
|
| 142 |
State('num_stages', 'value'),
|
|
@@ -147,7 +361,7 @@ app.layout = dbc.Container([
|
|
| 147 |
State('op_time_backward_d', 'value'),
|
| 148 |
State('op_time_backward_w', 'value'),
|
| 149 |
State('op_time_overlapped_fwd_bwd', 'value'),
|
| 150 |
-
State('
|
| 151 |
prevent_initial_call=True
|
| 152 |
)
|
| 153 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
@@ -155,19 +369,39 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 155 |
op_time_overlapped_fwd_bwd,
|
| 156 |
selected_strategies):
|
| 157 |
|
| 158 |
-
# Define the desired display order for strategies
|
| 159 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
if not selected_strategies:
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
if
|
| 170 |
-
|
|
|
|
| 171 |
|
| 172 |
for strategy in selected_strategies:
|
| 173 |
error_message = ""
|
|
@@ -179,17 +413,15 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 179 |
|
| 180 |
# Apply automatic adjustments for dualpipe
|
| 181 |
if strategy == "dualpipe" and num_stages != num_devices:
|
| 182 |
-
current_num_stages = num_devices
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
)
|
| 186 |
|
| 187 |
# Apply automatic adjustments for strategies that require num_stages == num_devices
|
| 188 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
|
| 189 |
current_num_stages = num_devices
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
)
|
| 193 |
|
| 194 |
split_backward = strategy in ["zb1p", "dualpipe"]
|
| 195 |
|
|
@@ -201,41 +433,32 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 201 |
if not error_message:
|
| 202 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
|
| 203 |
placement_strategy = "standard"
|
| 204 |
-
# No need to check num_stages == num_devices as we've enforced it above
|
| 205 |
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
|
| 206 |
placement_strategy = "interleave"
|
| 207 |
if current_num_stages % current_num_devices != 0:
|
| 208 |
-
error_message = f"Strategy '{strategy}': Requires
|
| 209 |
elif strategy == "dualpipe":
|
| 210 |
placement_strategy = "dualpipe"
|
| 211 |
if current_num_stages % 2 != 0:
|
| 212 |
-
error_message = f"Strategy '{strategy}'
|
| 213 |
|
| 214 |
# Create adjusted operation times based on placement strategy
|
| 215 |
if not error_message:
|
| 216 |
try:
|
| 217 |
-
# Calculate number of stages per device for time adjustment
|
| 218 |
stages_per_device = current_num_stages // current_num_devices
|
| 219 |
-
|
| 220 |
-
# Calculate scaling factor - this normalizes operation time by stages per device
|
| 221 |
-
# For standard placement (1:1 stage:device mapping), this remains 1.0
|
| 222 |
-
# For interleaved, this scales down the time proportionally
|
| 223 |
time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
|
| 224 |
-
|
| 225 |
if stages_per_device > 1:
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
op_times = {
|
| 232 |
-
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
if split_backward:
|
| 236 |
op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
|
| 237 |
op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
|
| 238 |
-
# Keep combined for compatibility
|
| 239 |
op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
|
| 240 |
else:
|
| 241 |
op_times["backward"] = float(op_time_backward) * time_scale_factor
|
|
@@ -244,14 +467,13 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 244 |
try:
|
| 245 |
overlapped_val = float(op_time_overlapped_fwd_bwd)
|
| 246 |
if overlapped_val > 0:
|
| 247 |
-
# Scale overlapped time too
|
| 248 |
op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
|
| 249 |
except (ValueError, TypeError):
|
| 250 |
pass
|
| 251 |
|
| 252 |
config = ScheduleConfig(
|
| 253 |
num_devices=int(current_num_devices),
|
| 254 |
-
num_stages=int(current_num_stages),
|
| 255 |
num_batches=int(num_batches),
|
| 256 |
p2p_latency=float(p2p_latency),
|
| 257 |
placement_strategy=placement_strategy,
|
|
@@ -265,73 +487,62 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 265 |
|
| 266 |
schedule = schedule_func(config)
|
| 267 |
schedule.execute()
|
| 268 |
-
|
| 269 |
-
# Store valid results instead of creating figure immediately
|
| 270 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
| 271 |
valid_results.append((strategy, schedule, vis_data))
|
| 272 |
|
| 273 |
except (AssertionError, ValueError, TypeError) as e:
|
| 274 |
-
error_message = f"Error
|
| 275 |
-
import traceback
|
| 276 |
-
traceback.print_exc()
|
| 277 |
except Exception as e:
|
| 278 |
-
error_message = f"
|
| 279 |
-
import traceback
|
| 280 |
-
traceback.print_exc()
|
| 281 |
|
| 282 |
if error_message:
|
| 283 |
error_messages.append((strategy, error_message))
|
| 284 |
|
| 285 |
-
#
|
|
|
|
| 286 |
for adjustment in automatic_adjustments:
|
| 287 |
-
|
| 288 |
-
dbc.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
)
|
| 290 |
|
| 291 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
if valid_results:
|
| 293 |
-
# Find global maximum execution time
|
| 294 |
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
sorted_valid_results = []
|
| 298 |
-
|
| 299 |
-
# First add strategies in the predefined order
|
| 300 |
-
for strategy_name in strategy_display_order:
|
| 301 |
-
for result in valid_results:
|
| 302 |
-
if result[0] == strategy_name:
|
| 303 |
-
sorted_valid_results.append(result)
|
| 304 |
-
|
| 305 |
-
# Then add any remaining strategies that might not be in the predefined order
|
| 306 |
-
for result in valid_results:
|
| 307 |
-
if result[0] not in strategy_display_order:
|
| 308 |
-
sorted_valid_results.append(result)
|
| 309 |
-
|
| 310 |
-
# Create figures with aligned x-axis, using the sorted results
|
| 311 |
for strategy, _, vis_data in sorted_valid_results:
|
| 312 |
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
| 313 |
-
|
| 314 |
-
# Force the x-axis range to be the same for all figures
|
| 315 |
-
# Add a small margin (5%) for better visualization
|
| 316 |
margin = max_execution_time * 0.05
|
| 317 |
fig.update_layout(
|
| 318 |
-
xaxis=dict(
|
| 319 |
-
range=[0, max_execution_time + margin]
|
| 320 |
-
)
|
| 321 |
)
|
| 322 |
-
|
| 323 |
-
output_components.append(html.Div([
|
| 324 |
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
| 325 |
dcc.Graph(figure=fig)
|
| 326 |
]))
|
| 327 |
-
|
| 328 |
-
# Add error messages to output
|
| 329 |
-
for strategy, msg in error_messages:
|
| 330 |
-
output_components.append(
|
| 331 |
-
dbc.Alert(msg, color="danger", className="mt-3")
|
| 332 |
-
)
|
| 333 |
|
| 334 |
-
|
|
|
|
| 335 |
|
| 336 |
# For Hugging Face Spaces deployment
|
| 337 |
server = app.server
|
|
|
|
| 1 |
import dash
|
| 2 |
import dash_bootstrap_components as dbc
|
| 3 |
+
from dash import dcc, html, Input, Output, State, callback_context, ALL, ClientsideFunction
|
| 4 |
import plotly.graph_objects as go
|
| 5 |
|
| 6 |
from src.execution_model import ScheduleConfig, Schedule
|
|
|
|
| 23 |
"dualpipe": generate_dualpipe_schedule,
|
| 24 |
}
|
| 25 |
|
| 26 |
+
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP, dbc.icons.BOOTSTRAP], suppress_callback_exceptions=True)
|
| 27 |
app.title = "Pipeline Parallelism Schedule Visualizer"
|
| 28 |
|
| 29 |
# Initial default values
|
|
|
|
| 36 |
"op_time_backward_d": 1.0,
|
| 37 |
"op_time_backward_w": 1.0,
|
| 38 |
"op_time_backward": 2.0,
|
| 39 |
+
"strategy": ["1f1b_interleave"],
|
| 40 |
"op_time_overlapped_fwd_bwd": None,
|
| 41 |
}
|
| 42 |
|
| 43 |
# Define input groups using dbc components
|
| 44 |
+
card_style = {"marginBottom": "24px"}
|
| 45 |
+
|
| 46 |
basic_params_card = dbc.Card(
|
| 47 |
dbc.CardBody([
|
| 48 |
+
html.H5("Basic Parameters", className="card-title mb-4"),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
html.Div([
|
| 50 |
+
dbc.Label("Number of Devices (GPUs)", html_for='num_devices', className="form-label"),
|
| 51 |
+
dbc.Input(id='num_devices', type='number', value=default_values["num_devices"], min=1, step=1, required=True),
|
| 52 |
+
dbc.FormFeedback("Please provide a positive integer for the number of devices.", type="invalid", id="feedback-num_devices"),
|
| 53 |
], className="mb-3"),
|
| 54 |
html.Div([
|
| 55 |
+
dbc.Label("Number of Stages (Model Chunks)", html_for='num_stages', className="form-label"),
|
| 56 |
+
dbc.Input(id='num_stages', type='number', value=default_values["num_stages"], min=1, step=1, required=True),
|
| 57 |
+
dbc.FormFeedback("Please provide a positive integer for the number of stages.", type="invalid", id="feedback-num_stages"),
|
| 58 |
], className="mb-3"),
|
| 59 |
html.Div([
|
| 60 |
+
dbc.Label("Number of Microbatches", html_for='num_batches', className="form-label"),
|
| 61 |
+
dbc.Input(id='num_batches', type='number', value=default_values["num_batches"], min=1, step=1, required=True),
|
| 62 |
+
dbc.FormFeedback("Please provide a positive integer for the number of microbatches.", type="invalid", id="feedback-num_batches"),
|
| 63 |
], className="mb-3"),
|
| 64 |
+
]),
|
| 65 |
+
style=card_style
|
| 66 |
)
|
| 67 |
|
| 68 |
scheduling_params_card = dbc.Card(
|
| 69 |
dbc.CardBody([
|
| 70 |
+
html.H5("Scheduling Strategy", className="card-title mb-4"),
|
| 71 |
+
dbc.ButtonGroup(
|
| 72 |
+
[
|
| 73 |
+
dbc.Button(
|
| 74 |
+
strategy,
|
| 75 |
+
id={"type": "strategy-button", "index": strategy},
|
| 76 |
+
color="secondary",
|
| 77 |
+
outline=True,
|
| 78 |
+
active=strategy in default_values["strategy"],
|
| 79 |
+
className="me-1"
|
| 80 |
+
)
|
| 81 |
+
for strategy in STRATEGIES.keys()
|
| 82 |
+
],
|
| 83 |
+
className="d-flex flex-wrap"
|
| 84 |
+
),
|
| 85 |
+
dcc.Store(id='selected-strategies-store', data=default_values["strategy"]),
|
| 86 |
+
html.Div(id='strategy-selection-feedback', className='invalid-feedback d-block mt-2')
|
| 87 |
+
]),
|
| 88 |
+
style=card_style
|
| 89 |
)
|
| 90 |
|
| 91 |
timing_params_card = dbc.Card(
|
| 92 |
dbc.CardBody([
|
| 93 |
+
html.H5("Operation Timing (ms)", className="card-title mb-4"),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
html.Div([
|
| 95 |
+
html.Div([
|
| 96 |
+
dbc.Label("P2P Latency", html_for='p2p_latency', className="form-label d-inline-block me-1"),
|
| 97 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-p2p", style={"cursor": "pointer"})
|
| 98 |
+
]),
|
| 99 |
+
dbc.Input(id='p2p_latency', type='number', value=default_values["p2p_latency"], min=0, step=0.01, required=True),
|
| 100 |
+
dbc.FormFeedback("P2P latency must be a number >= 0.", type="invalid", id="feedback-p2p_latency"),
|
| 101 |
+
dbc.Tooltip(
|
| 102 |
+
"Time (ms) for point-to-point communication between adjacent devices.",
|
| 103 |
+
target="tooltip-target-p2p",
|
| 104 |
+
placement="right"
|
| 105 |
+
)
|
| 106 |
], className="mb-3"),
|
| 107 |
html.Div([
|
| 108 |
+
html.Div([
|
| 109 |
+
dbc.Label("Forward Operation Time", html_for='op_time_forward', className="form-label d-inline-block me-1"),
|
| 110 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-fwd", style={"cursor": "pointer"})
|
| 111 |
+
]),
|
| 112 |
+
dbc.Input(id='op_time_forward', type='number', value=default_values["op_time_forward"], min=0.01, step=0.01, required=True),
|
| 113 |
+
dbc.FormFeedback("Forward time must be a number > 0.", type="invalid", id="feedback-op_time_forward"),
|
| 114 |
+
dbc.Tooltip(
|
| 115 |
+
"Time (ms) for a single forward pass of one microbatch through one stage.",
|
| 116 |
+
target="tooltip-target-fwd",
|
| 117 |
+
placement="right"
|
| 118 |
+
)
|
| 119 |
], className="mb-3"),
|
| 120 |
html.Div([
|
| 121 |
+
html.Div([
|
| 122 |
+
dbc.Label("Backward (Combined)", html_for='op_time_backward', className="form-label d-inline-block me-1"),
|
| 123 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-bwd", style={"cursor": "pointer"})
|
| 124 |
+
]),
|
| 125 |
+
dbc.Input(id='op_time_backward', type='number', value=default_values["op_time_backward"], min=0.01, step=0.01),
|
| 126 |
+
dbc.FormText("Used when strategy does NOT require split backward."),
|
| 127 |
+
dbc.FormFeedback("Backward time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward"),
|
| 128 |
+
dbc.Tooltip(
|
| 129 |
+
"Time (ms) for a combined backward pass (data gradient + weight gradient) of one microbatch through one stage.",
|
| 130 |
+
target="tooltip-target-bwd",
|
| 131 |
+
placement="right"
|
| 132 |
+
)
|
| 133 |
], className="mb-3"),
|
| 134 |
+
|
| 135 |
+
# --- Collapsible Advanced Options (Item 3) ---
|
| 136 |
+
html.Hr(className="my-3"),
|
| 137 |
+
dbc.Switch(
|
| 138 |
+
id="advanced-timing-switch",
|
| 139 |
+
label="Show Advanced Timing Options",
|
| 140 |
+
value=False,
|
| 141 |
+
className="mb-3"
|
| 142 |
+
),
|
| 143 |
+
dbc.Collapse(
|
| 144 |
+
id="advanced-timing-collapse",
|
| 145 |
+
is_open=False,
|
| 146 |
+
children=[
|
| 147 |
+
html.Div([
|
| 148 |
+
html.Div([
|
| 149 |
+
dbc.Label("Backward D (Data Grad)", html_for='op_time_backward_d', className="form-label d-inline-block me-1"),
|
| 150 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-bwd-d", style={"cursor": "pointer"})
|
| 151 |
+
]),
|
| 152 |
+
dbc.Input(id='op_time_backward_d', type='number', value=default_values["op_time_backward_d"], min=0.01, step=0.01),
|
| 153 |
+
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
| 154 |
+
dbc.FormFeedback("Backward D time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward_d"),
|
| 155 |
+
dbc.Tooltip(
|
| 156 |
+
"Time (ms) for the data gradient part of the backward pass.",
|
| 157 |
+
target="tooltip-target-bwd-d",
|
| 158 |
+
placement="right"
|
| 159 |
+
)
|
| 160 |
+
], className="mb-3"),
|
| 161 |
+
html.Div([
|
| 162 |
+
html.Div([
|
| 163 |
+
dbc.Label("Backward W (Weight Grad)", html_for='op_time_backward_w', className="form-label d-inline-block me-1"),
|
| 164 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-bwd-w", style={"cursor": "pointer"})
|
| 165 |
+
]),
|
| 166 |
+
dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
|
| 167 |
+
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
| 168 |
+
dbc.FormFeedback("Backward W time must be > 0 if specified.", type="invalid", id="feedback-op_time_backward_w"),
|
| 169 |
+
dbc.Tooltip(
|
| 170 |
+
"Time (ms) for the weight gradient part of the backward pass.",
|
| 171 |
+
target="tooltip-target-bwd-w",
|
| 172 |
+
placement="right"
|
| 173 |
+
)
|
| 174 |
+
], className="mb-3"),
|
| 175 |
+
html.Div([
|
| 176 |
+
html.Div([
|
| 177 |
+
dbc.Label("Overlapped Forward+Backward", html_for='op_time_overlapped_fwd_bwd', className="form-label d-inline-block me-1"),
|
| 178 |
+
html.I(className="bi bi-info-circle", id="tooltip-target-overlap", style={"cursor": "pointer"})
|
| 179 |
+
]),
|
| 180 |
+
dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Defaults to Fwd + Bwd", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
|
| 181 |
+
dbc.FormText("Specify if Forward and Backward ops overlap completely."),
|
| 182 |
+
dbc.FormFeedback("Overlapped time must be > 0 if specified.", type="invalid", id="feedback-op_time_overlapped_fwd_bwd"),
|
| 183 |
+
dbc.Tooltip(
|
| 184 |
+
"Optional: Specify a single time (ms) if the forward and backward passes for a microbatch can be fully overlapped within the same stage execution slot.",
|
| 185 |
+
target="tooltip-target-overlap",
|
| 186 |
+
placement="right"
|
| 187 |
+
)
|
| 188 |
+
], className="mb-3"),
|
| 189 |
+
]
|
| 190 |
+
)
|
| 191 |
+
]),
|
| 192 |
+
style=card_style
|
| 193 |
)
|
| 194 |
|
| 195 |
# Updated app layout using dbc components and structure
|
| 196 |
app.layout = dbc.Container([
|
| 197 |
html.H1("Pipeline Parallelism Schedule Visualizer", className="my-4 text-center"),
|
| 198 |
|
| 199 |
+
# Main Row with Left (Graphs) and Right (Controls) Columns
|
| 200 |
dbc.Row([
|
| 201 |
+
# --- Left Column (Graphs Area) ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
dbc.Col([
|
| 203 |
+
# Output Area for Graphs
|
| 204 |
dcc.Loading(
|
| 205 |
id="loading-graph-area",
|
| 206 |
type="circle",
|
| 207 |
+
children=html.Div(id='graph-output-container', style={"minHeight": "600px"})
|
| 208 |
)
|
| 209 |
+
], lg=8, md=7, sm=12, className="mb-4 mb-lg-0"),
|
| 210 |
+
|
| 211 |
+
# --- Right Column (Controls Area) ---
|
| 212 |
+
dbc.Col([
|
| 213 |
+
# Parameter Cards Stacked Vertically
|
| 214 |
+
basic_params_card,
|
| 215 |
+
scheduling_params_card,
|
| 216 |
+
timing_params_card,
|
| 217 |
+
|
| 218 |
+
# Generate Button below the cards in the right column
|
| 219 |
+
dbc.Row([
|
| 220 |
+
dbc.Col(
|
| 221 |
+
dbc.Button(
|
| 222 |
+
'Generate Schedule',
|
| 223 |
+
id='generate-button',
|
| 224 |
+
n_clicks=0,
|
| 225 |
+
color="primary",
|
| 226 |
+
className="w-100",
|
| 227 |
+
disabled=False
|
| 228 |
+
),
|
| 229 |
+
)
|
| 230 |
+
], className="mt-3")
|
| 231 |
+
], lg=4, md=5, sm=12)
|
| 232 |
+
]),
|
| 233 |
+
|
| 234 |
+
# --- Toast Container (Positioned Fixed) ---
|
| 235 |
+
html.Div(id="toast-container", style={"position": "fixed", "top": 20, "right": 20, "zIndex": 1050})
|
| 236 |
|
| 237 |
+
], fluid=True, className="py-4")
|
| 238 |
+
|
| 239 |
+
# --- Callback for Input Validation and Generate Button State ---
|
| 240 |
+
@app.callback(
|
| 241 |
+
Output('generate-button', 'disabled'),
|
| 242 |
+
# Outputs to control the 'invalid' state of Inputs
|
| 243 |
+
Output('num_devices', 'invalid'),
|
| 244 |
+
Output('num_stages', 'invalid'),
|
| 245 |
+
Output('num_batches', 'invalid'),
|
| 246 |
+
Output('p2p_latency', 'invalid'),
|
| 247 |
+
Output('op_time_forward', 'invalid'),
|
| 248 |
+
Output('op_time_backward', 'invalid'),
|
| 249 |
+
Output('op_time_backward_d', 'invalid'),
|
| 250 |
+
Output('op_time_backward_w', 'invalid'),
|
| 251 |
+
Output('op_time_overlapped_fwd_bwd', 'invalid'),
|
| 252 |
+
# Outputs to control the visibility/content of FormFeedback (can also just control Input's invalid state)
|
| 253 |
+
# We are primarily using the Input's `invalid` prop which automatically shows/hides associated FormFeedback
|
| 254 |
+
# Output('feedback-num_devices', 'children'), ... (Add if more specific messages needed per validation type)
|
| 255 |
+
Output('strategy-selection-feedback', 'children', allow_duplicate=True), # Update feedback from validation callback too
|
| 256 |
+
# Inputs: Trigger validation whenever any relevant input changes
|
| 257 |
+
Input('num_devices', 'value'),
|
| 258 |
+
Input('num_stages', 'value'),
|
| 259 |
+
Input('num_batches', 'value'),
|
| 260 |
+
Input('p2p_latency', 'value'),
|
| 261 |
+
Input('op_time_forward', 'value'),
|
| 262 |
+
Input('op_time_backward', 'value'),
|
| 263 |
+
Input('op_time_backward_d', 'value'),
|
| 264 |
+
Input('op_time_backward_w', 'value'),
|
| 265 |
+
Input('op_time_overlapped_fwd_bwd', 'value'),
|
| 266 |
+
Input('selected-strategies-store', 'data'), # Validate strategy selection
|
| 267 |
+
prevent_initial_call=True # Prevent callback running on page load before user interaction
|
| 268 |
+
)
|
| 269 |
+
def validate_inputs(num_devices, num_stages, num_batches, p2p_latency,
|
| 270 |
+
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
| 271 |
+
op_time_overlapped_fwd_bwd, selected_strategies):
|
| 272 |
+
is_invalid = {
|
| 273 |
+
"num_devices": num_devices is None or num_devices < 1,
|
| 274 |
+
"num_stages": num_stages is None or num_stages < 1,
|
| 275 |
+
"num_batches": num_batches is None or num_batches < 1,
|
| 276 |
+
"p2p_latency": p2p_latency is None or p2p_latency < 0,
|
| 277 |
+
"op_time_forward": op_time_forward is None or op_time_forward <= 0,
|
| 278 |
+
"op_time_backward": op_time_backward is not None and op_time_backward <= 0,
|
| 279 |
+
"op_time_backward_d": op_time_backward_d is not None and op_time_backward_d <= 0,
|
| 280 |
+
"op_time_backward_w": op_time_backward_w is not None and op_time_backward_w <= 0,
|
| 281 |
+
"op_time_overlapped_fwd_bwd": op_time_overlapped_fwd_bwd is not None and op_time_overlapped_fwd_bwd <= 0,
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
# Validate strategy selection
|
| 285 |
+
strategy_feedback = "" # Default empty feedback
|
| 286 |
+
if not selected_strategies or len(selected_strategies) == 0:
|
| 287 |
+
is_invalid["strategies"] = True
|
| 288 |
+
strategy_feedback = "Please select at least one strategy."
|
| 289 |
+
else:
|
| 290 |
+
is_invalid["strategies"] = False
|
| 291 |
+
# Additional validation: Check if required timings are provided for selected strategies
|
| 292 |
+
needs_split_backward = any(s in ["zb1p", "dualpipe"] for s in selected_strategies)
|
| 293 |
+
needs_combined_backward = any(s not in ["zb1p", "dualpipe"] for s in selected_strategies)
|
| 294 |
+
|
| 295 |
+
if needs_split_backward and (op_time_backward_d is None or op_time_backward_w is None):
|
| 296 |
+
is_invalid["op_time_backward_d"] = op_time_backward_d is None or op_time_backward_d <= 0
|
| 297 |
+
is_invalid["op_time_backward_w"] = op_time_backward_w is None or op_time_backward_w <= 0
|
| 298 |
+
# We might want specific feedback here, but setting invalid=True is often enough
|
| 299 |
+
|
| 300 |
+
if needs_combined_backward and op_time_backward is None:
|
| 301 |
+
is_invalid["op_time_backward"] = op_time_backward is None or op_time_backward <= 0
|
| 302 |
+
|
| 303 |
+
# Check if any input is invalid
|
| 304 |
+
overall_invalid = any(is_invalid.values())
|
| 305 |
+
|
| 306 |
+
# Disable button if any validation fails
|
| 307 |
+
disable_button = overall_invalid
|
| 308 |
+
|
| 309 |
+
# Return button state and invalid states for each input
|
| 310 |
+
return (
|
| 311 |
+
disable_button,
|
| 312 |
+
is_invalid["num_devices"],
|
| 313 |
+
is_invalid["num_stages"],
|
| 314 |
+
is_invalid["num_batches"],
|
| 315 |
+
is_invalid["p2p_latency"],
|
| 316 |
+
is_invalid["op_time_forward"],
|
| 317 |
+
is_invalid["op_time_backward"],
|
| 318 |
+
is_invalid["op_time_backward_d"],
|
| 319 |
+
is_invalid["op_time_backward_w"],
|
| 320 |
+
is_invalid["op_time_overlapped_fwd_bwd"],
|
| 321 |
+
strategy_feedback # Update strategy feedback based on validation
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
# --- Callback to toggle Advanced Options Collapse ---
|
| 325 |
@app.callback(
|
| 326 |
+
Output("advanced-timing-collapse", "is_open"),
|
| 327 |
+
Input("advanced-timing-switch", "value"),
|
| 328 |
+
prevent_initial_call=True,
|
| 329 |
+
)
|
| 330 |
+
def toggle_advanced_options(switch_value):
|
| 331 |
+
return switch_value
|
| 332 |
+
|
| 333 |
+
# --- Client-side Callback for Strategy ButtonGroup ---
|
| 334 |
+
app.clientside_callback(
|
| 335 |
+
ClientsideFunction(
|
| 336 |
+
namespace='clientside',
|
| 337 |
+
function_name='update_strategy_selection'
|
| 338 |
+
),
|
| 339 |
+
Output('selected-strategies-store', 'data'),
|
| 340 |
+
Output({'type': 'strategy-button', 'index': ALL}, 'active'),
|
| 341 |
+
Output({'type': 'strategy-button', 'index': ALL}, 'color'),
|
| 342 |
+
Output({'type': 'strategy-button', 'index': ALL}, 'outline'),
|
| 343 |
+
Output('strategy-selection-feedback', 'children'),
|
| 344 |
+
Input({'type': 'strategy-button', 'index': ALL}, 'n_clicks'),
|
| 345 |
+
State('selected-strategies-store', 'data'),
|
| 346 |
+
prevent_initial_call=True
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# --- Main Graph Update Callback ---
|
| 350 |
+
@app.callback(
|
| 351 |
+
# Output graph container and toast container separately
|
| 352 |
Output('graph-output-container', 'children'),
|
| 353 |
+
Output('toast-container', 'children'), # Output for toasts
|
| 354 |
Input('generate-button', 'n_clicks'),
|
| 355 |
State('num_devices', 'value'),
|
| 356 |
State('num_stages', 'value'),
|
|
|
|
| 361 |
State('op_time_backward_d', 'value'),
|
| 362 |
State('op_time_backward_w', 'value'),
|
| 363 |
State('op_time_overlapped_fwd_bwd', 'value'),
|
| 364 |
+
State('selected-strategies-store', 'data'),
|
| 365 |
prevent_initial_call=True
|
| 366 |
)
|
| 367 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
|
|
| 369 |
op_time_overlapped_fwd_bwd,
|
| 370 |
selected_strategies):
|
| 371 |
|
|
|
|
| 372 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
| 373 |
+
|
| 374 |
+
graph_components = [] # Renamed from output_components
|
| 375 |
+
toast_components = [] # New list for toasts
|
| 376 |
+
valid_results = []
|
| 377 |
+
error_messages = []
|
| 378 |
+
automatic_adjustments = []
|
| 379 |
+
|
| 380 |
+
# Use a variable to track if initial validation fails
|
| 381 |
+
initial_validation_error = None
|
| 382 |
|
| 383 |
if not selected_strategies:
|
| 384 |
+
initial_validation_error = dbc.Toast(
|
| 385 |
+
"Please select at least one scheduling strategy.",
|
| 386 |
+
header="Input Error",
|
| 387 |
+
icon="warning",
|
| 388 |
+
duration=4000,
|
| 389 |
+
is_open=True,
|
| 390 |
+
className="border-warning"
|
| 391 |
+
)
|
| 392 |
+
elif not all([num_devices, num_stages, num_batches, op_time_forward]):
|
| 393 |
+
initial_validation_error = dbc.Toast(
|
| 394 |
+
"Missing required basic input values (Devices, Stages, Batches, Forward Time).",
|
| 395 |
+
header="Input Error",
|
| 396 |
+
icon="danger",
|
| 397 |
+
duration=4000,
|
| 398 |
+
is_open=True,
|
| 399 |
+
className="border-danger"
|
| 400 |
+
)
|
| 401 |
|
| 402 |
+
if initial_validation_error:
|
| 403 |
+
# Return empty graph list and the validation error toast
|
| 404 |
+
return [], [initial_validation_error]
|
| 405 |
|
| 406 |
for strategy in selected_strategies:
|
| 407 |
error_message = ""
|
|
|
|
| 413 |
|
| 414 |
# Apply automatic adjustments for dualpipe
|
| 415 |
if strategy == "dualpipe" and num_stages != num_devices:
|
| 416 |
+
current_num_stages = num_devices
|
| 417 |
+
adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
|
| 418 |
+
automatic_adjustments.append(adjustment_msg)
|
|
|
|
| 419 |
|
| 420 |
# Apply automatic adjustments for strategies that require num_stages == num_devices
|
| 421 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
|
| 422 |
current_num_stages = num_devices
|
| 423 |
+
adjustment_msg = f"Strategy '{strategy}': Number of Stages auto-adjusted to {num_devices} to match Devices."
|
| 424 |
+
automatic_adjustments.append(adjustment_msg)
|
|
|
|
| 425 |
|
| 426 |
split_backward = strategy in ["zb1p", "dualpipe"]
|
| 427 |
|
|
|
|
| 433 |
if not error_message:
|
| 434 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
|
| 435 |
placement_strategy = "standard"
|
|
|
|
| 436 |
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
|
| 437 |
placement_strategy = "interleave"
|
| 438 |
if current_num_stages % current_num_devices != 0:
|
| 439 |
+
error_message = f"Strategy '{strategy}': Requires Stages divisible by Devices."
|
| 440 |
elif strategy == "dualpipe":
|
| 441 |
placement_strategy = "dualpipe"
|
| 442 |
if current_num_stages % 2 != 0:
|
| 443 |
+
error_message = f"Strategy '{strategy}': Requires an even number of stages."
|
| 444 |
|
| 445 |
# Create adjusted operation times based on placement strategy
|
| 446 |
if not error_message:
|
| 447 |
try:
|
|
|
|
| 448 |
stages_per_device = current_num_stages // current_num_devices
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
|
| 450 |
+
|
| 451 |
if stages_per_device > 1:
|
| 452 |
+
adjustment_msg = f"Strategy '{strategy}': Op times scaled by 1/{stages_per_device} ({stages_per_device} stages/device)."
|
| 453 |
+
# Avoid adding duplicate adjustment messages if already added above
|
| 454 |
+
if adjustment_msg not in automatic_adjustments:
|
| 455 |
+
automatic_adjustments.append(adjustment_msg)
|
| 456 |
+
|
| 457 |
+
op_times = { "forward": float(op_time_forward) * time_scale_factor }
|
| 458 |
+
|
|
|
|
|
|
|
| 459 |
if split_backward:
|
| 460 |
op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
|
| 461 |
op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
|
|
|
|
| 462 |
op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
|
| 463 |
else:
|
| 464 |
op_times["backward"] = float(op_time_backward) * time_scale_factor
|
|
|
|
| 467 |
try:
|
| 468 |
overlapped_val = float(op_time_overlapped_fwd_bwd)
|
| 469 |
if overlapped_val > 0:
|
|
|
|
| 470 |
op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
|
| 471 |
except (ValueError, TypeError):
|
| 472 |
pass
|
| 473 |
|
| 474 |
config = ScheduleConfig(
|
| 475 |
num_devices=int(current_num_devices),
|
| 476 |
+
num_stages=int(current_num_stages),
|
| 477 |
num_batches=int(num_batches),
|
| 478 |
p2p_latency=float(p2p_latency),
|
| 479 |
placement_strategy=placement_strategy,
|
|
|
|
| 487 |
|
| 488 |
schedule = schedule_func(config)
|
| 489 |
schedule.execute()
|
|
|
|
|
|
|
| 490 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
| 491 |
valid_results.append((strategy, schedule, vis_data))
|
| 492 |
|
| 493 |
except (AssertionError, ValueError, TypeError) as e:
|
| 494 |
+
error_message = f"Error for '{strategy}': {e}"
|
|
|
|
|
|
|
| 495 |
except Exception as e:
|
| 496 |
+
error_message = f"Unexpected error for '{strategy}': {e}"
|
|
|
|
|
|
|
| 497 |
|
| 498 |
if error_message:
|
| 499 |
error_messages.append((strategy, error_message))
|
| 500 |
|
| 501 |
+
# --- Generate Toasts ---
|
| 502 |
+
# Add toasts for automatic adjustments
|
| 503 |
for adjustment in automatic_adjustments:
|
| 504 |
+
toast_components.append(
|
| 505 |
+
dbc.Toast(
|
| 506 |
+
adjustment,
|
| 507 |
+
header="Parameter Adjustment",
|
| 508 |
+
icon="info",
|
| 509 |
+
duration=5000, # Slightly longer duration for info
|
| 510 |
+
is_open=True,
|
| 511 |
+
className="border-info"
|
| 512 |
+
)
|
| 513 |
)
|
| 514 |
|
| 515 |
+
# Add toasts for errors
|
| 516 |
+
for strategy, msg in error_messages:
|
| 517 |
+
toast_components.append(
|
| 518 |
+
dbc.Toast(
|
| 519 |
+
msg,
|
| 520 |
+
header=f"Error: {strategy}",
|
| 521 |
+
icon="danger",
|
| 522 |
+
duration=8000, # Longer duration for errors
|
| 523 |
+
is_open=True,
|
| 524 |
+
className="border-danger"
|
| 525 |
+
)
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
# --- Generate Graphs ---
|
| 529 |
if valid_results:
|
|
|
|
| 530 |
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
| 531 |
+
sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
|
| 532 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
for strategy, _, vis_data in sorted_valid_results:
|
| 534 |
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
|
|
|
|
|
|
|
|
|
| 535 |
margin = max_execution_time * 0.05
|
| 536 |
fig.update_layout(
|
| 537 |
+
xaxis=dict(range=[0, max_execution_time + margin])
|
|
|
|
|
|
|
| 538 |
)
|
| 539 |
+
graph_components.append(html.Div([
|
|
|
|
| 540 |
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
| 541 |
dcc.Graph(figure=fig)
|
| 542 |
]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
|
| 544 |
+
# Return graph components and toast components
|
| 545 |
+
return graph_components, toast_components
|
| 546 |
|
| 547 |
# For Hugging Face Spaces deployment
|
| 548 |
server = app.server
|
assets/clientside.js
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// assets/clientside.js
|
| 2 |
+
|
| 3 |
+
// Make sure the assets folder is configured correctly in Dash for this to be loaded.
|
| 4 |
+
// Dash automatically serves files from a folder named 'assets' in the root directory.
|
| 5 |
+
|
| 6 |
+
if (!window.dash_clientside) { window.dash_clientside = {}; }
|
| 7 |
+
|
| 8 |
+
window.dash_clientside.clientside = {
|
| 9 |
+
update_strategy_selection: function(n_clicks_all, current_selection) {
|
| 10 |
+
// Determine which button triggered the callback
|
| 11 |
+
const ctx = dash_clientside.callback_context;
|
| 12 |
+
if (!ctx.triggered || ctx.triggered.length === 0) {
|
| 13 |
+
// Should not happen with prevent_initial_call=True, but handle defensively
|
| 14 |
+
return dash_clientside.no_update;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
const triggered_id_str = ctx.triggered[0].prop_id.split('.')[0];
|
| 18 |
+
if (!triggered_id_str) {
|
| 19 |
+
// If we can't parse the ID, don't update
|
| 20 |
+
return dash_clientside.no_update;
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
// Parse the JSON ID string to get the actual index (strategy name)
|
| 24 |
+
let triggered_index;
|
| 25 |
+
try {
|
| 26 |
+
const triggered_id_obj = JSON.parse(triggered_id_str);
|
| 27 |
+
triggered_index = triggered_id_obj.index;
|
| 28 |
+
} catch (e) {
|
| 29 |
+
console.error("Error parsing callback context ID:", e);
|
| 30 |
+
return dash_clientside.no_update; // Don't update if ID parsing fails
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
// --- Update Selection Logic ---
|
| 34 |
+
// Initialize new_selection as a copy of the current selection
|
| 35 |
+
let new_selection = current_selection ? [...current_selection] : [];
|
| 36 |
+
|
| 37 |
+
// Toggle the selected state
|
| 38 |
+
const index_in_selection = new_selection.indexOf(triggered_index);
|
| 39 |
+
if (index_in_selection > -1) {
|
| 40 |
+
// If already selected, remove it (allow deselecting all for now)
|
| 41 |
+
new_selection.splice(index_in_selection, 1);
|
| 42 |
+
} else {
|
| 43 |
+
// If not selected, add it
|
| 44 |
+
new_selection.push(triggered_index);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
// --- Prepare Outputs ---
|
| 48 |
+
const all_indices = ctx.inputs_list[0].map(input => input.id.index); // Get all strategy names from the Input IDs
|
| 49 |
+
|
| 50 |
+
// Generate active states, colors, and outlines for ALL buttons
|
| 51 |
+
const active_states = all_indices.map(index => new_selection.includes(index));
|
| 52 |
+
const colors = active_states.map(active => active ? 'primary' : 'secondary'); // 'primary' for active, 'secondary' for inactive
|
| 53 |
+
const outlines = active_states.map(active => !active); // Outline=true for inactive, false for active
|
| 54 |
+
|
| 55 |
+
// Generate validation message
|
| 56 |
+
const feedback = new_selection.length === 0 ? "Please select at least one strategy." : "";
|
| 57 |
+
|
| 58 |
+
// Return updated store data, button states, and feedback
|
| 59 |
+
return [new_selection, active_states, colors, outlines, feedback];
|
| 60 |
+
}
|
| 61 |
+
// Add other clientside functions here if needed
|
| 62 |
+
};
|
assets/custom.css
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/* assets/custom.css */
|
| 2 |
+
|
| 3 |
+
/* --- General & Typography (Item 7, 11) --- */
|
| 4 |
+
body {
|
| 5 |
+
background-color: #F7F9FC; /* Neutral background */
|
| 6 |
+
color: #212B36; /* Dark text */
|
| 7 |
+
font-family: -apple-system, BlinkMacSystemFont, \"Segoe UI\", Roboto, \"Helvetica Neue\", Arial, sans-serif;
|
| 8 |
+
font-size: 14px;
|
| 9 |
+
}
|
| 10 |
+
|
| 11 |
+
/* Use H1 from dbc.Container/app.layout directly */
|
| 12 |
+
.h1, h1 {
|
| 13 |
+
font-size: 24px; /* H2 equivalent in request */
|
| 14 |
+
font-weight: 600;
|
| 15 |
+
}
|
| 16 |
+
|
| 17 |
+
/* Card titles */
|
| 18 |
+
.card-title.h5, .h5.card-title {
|
| 19 |
+
font-size: 18px; /* H3 equivalent */
|
| 20 |
+
font-weight: 600;
|
| 21 |
+
margin-bottom: 1rem; /* Add space below title */
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
/* Form labels (Item 2) */
|
| 25 |
+
.form-label {
|
| 26 |
+
font-size: 14px;
|
| 27 |
+
font-weight: 500;
|
| 28 |
+
margin-bottom: 0.3rem; /* Space between label and input */
|
| 29 |
+
display: block; /* Ensure it takes full width */
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
/* Form inputs (Item 2) */
|
| 33 |
+
.form-control,
|
| 34 |
+
.form-select {
|
| 35 |
+
font-size: 14px;
|
| 36 |
+
/* width: 100%; Ensure inputs take full width - Bootstrap usually handles this in columns */
|
| 37 |
+
padding: 0.5rem 0.75rem;
|
| 38 |
+
border-radius: 0.375rem; /* Softer corners */
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
/* Form help text */
|
| 42 |
+
.form-text {
|
| 43 |
+
font-size: 12px;
|
| 44 |
+
color: #6c757d; /* Muted color */
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
/* --- Layout & Spacing (Item 1, 7) --- */
|
| 48 |
+
.container-fluid {
|
| 49 |
+
padding-top: 2rem;
|
| 50 |
+
padding-bottom: 2rem;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
/* Spacing between form rows inside cards */
|
| 54 |
+
.card-body .mb-3 {
|
| 55 |
+
margin-bottom: 1rem !important; /* Default is 1rem, ensure consistency */
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
/* Spacing between cards */
|
| 59 |
+
.card {
|
| 60 |
+
margin-bottom: 24px;
|
| 61 |
+
border: 1px solid #dee2e6; /* Subtle border */
|
| 62 |
+
border-radius: 0.5rem; /* Consistent radius */
|
| 63 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.05); /* Subtle shadow */
|
| 64 |
+
/* Padding is handled by card-body */
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
/* --- Button Styling (Item 4, 5, 11) --- */
|
| 68 |
+
|
| 69 |
+
/* Primary Action Button (Generate Schedule) */
|
| 70 |
+
#generate-button.btn-primary {
|
| 71 |
+
background-color: #0A74DA; /* Accent color */
|
| 72 |
+
border-color: #0A74DA;
|
| 73 |
+
font-weight: 500;
|
| 74 |
+
padding: 0.6rem 1.2rem; /* Slightly larger padding */
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
#generate-button.btn-primary:hover,
|
| 78 |
+
#generate-button.btn-primary:focus {
|
| 79 |
+
background-color: #085ead; /* Darker accent on hover/focus */
|
| 80 |
+
border-color: #085ead;
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
#generate-button.btn-primary:disabled {
|
| 84 |
+
background-color: #a0cff7; /* Lighter, muted accent when disabled */
|
| 85 |
+
border-color: #a0cff7;
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
/* Strategy Toggle Buttons */
|
| 89 |
+
.btn-group .btn {
|
| 90 |
+
margin-right: 0.5rem; /* Space between buttons */
|
| 91 |
+
margin-bottom: 0.5rem; /* Space for wrapping */
|
| 92 |
+
border-radius: 1rem; /* Pill shape */
|
| 93 |
+
padding: 0.4rem 0.8rem;
|
| 94 |
+
font-size: 13px;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
/* Active strategy button */
|
| 98 |
+
.btn-group .btn.btn-primary:not(.disabled):not(:disabled).active,
|
| 99 |
+
.btn-group .btn.btn-primary:not(.disabled):not(:disabled):active {
|
| 100 |
+
background-color: #0A74DA; /* Accent color */
|
| 101 |
+
border-color: #0A74DA;
|
| 102 |
+
color: white;
|
| 103 |
+
box-shadow: none; /* Remove default active shadow if needed */
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
/* Inactive strategy button (using outline secondary) */
|
| 107 |
+
.btn-group .btn.btn-outline-secondary {
|
| 108 |
+
border-color: #ced4da;
|
| 109 |
+
color: #495057;
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
.btn-group .btn.btn-outline-secondary:hover {
|
| 113 |
+
background-color: #e9ecef;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
/* --- Validation Feedback --- */
|
| 117 |
+
.invalid-feedback {
|
| 118 |
+
font-size: 12px;
|
| 119 |
+
margin-top: 0.25rem;
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
/* --- Responsive Adjustments (Item 10) --- */
|
| 123 |
+
/* Bootstrap handles column stacking. We might need more specific rules later */
|
| 124 |
+
/* e.g., adjust chart container width/scrolling on smaller screens */
|
| 125 |
+
|
| 126 |
+
/* Chart Container - Add basic styles, will be refined (Item 9) */
|
| 127 |
+
#graph-output-container .plotly.graph-div {
|
| 128 |
+
/* Add styles for the chart itself if needed */
|
| 129 |
+
}
|