Spaces:
Running
Running
update UI.
Browse files- src/server.py +124 -27
src/server.py
CHANGED
|
@@ -36,14 +36,13 @@ default_values = {
|
|
| 36 |
"num_devices": 4,
|
| 37 |
"num_stages": 8,
|
| 38 |
"num_batches": 16,
|
| 39 |
-
"p2p_latency": 0.
|
| 40 |
"op_time_forward": 1.0,
|
| 41 |
"op_time_backward_d": 1.0,
|
| 42 |
"op_time_backward_w": 1.0,
|
| 43 |
"op_time_backward": 2.0,
|
| 44 |
"strategy": "1f1b_interleave",
|
| 45 |
-
"
|
| 46 |
-
"placement_strategy": "interleave"
|
| 47 |
}
|
| 48 |
|
| 49 |
# Define input groups using dbc components
|
|
@@ -77,7 +76,7 @@ scheduling_params_card = dbc.Card(
|
|
| 77 |
dbc.Checklist(
|
| 78 |
id='strategy-checklist',
|
| 79 |
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
|
| 80 |
-
value=
|
| 81 |
inline=False,
|
| 82 |
),
|
| 83 |
], className="mb-3"),
|
|
@@ -106,6 +105,11 @@ timing_params_card = dbc.Card(
|
|
| 106 |
dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
|
| 107 |
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
| 108 |
], className="mb-3"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
])
|
| 110 |
)
|
| 111 |
|
|
@@ -147,14 +151,22 @@ app.layout = dbc.Container([
|
|
| 147 |
State('op_time_backward', 'value'),
|
| 148 |
State('op_time_backward_d', 'value'),
|
| 149 |
State('op_time_backward_w', 'value'),
|
|
|
|
| 150 |
State('strategy-checklist', 'value'),
|
| 151 |
prevent_initial_call=True
|
| 152 |
)
|
| 153 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
| 154 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
|
|
|
| 155 |
selected_strategies):
|
| 156 |
|
|
|
|
|
|
|
|
|
|
| 157 |
output_components = []
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
if not selected_strategies:
|
| 160 |
return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
|
|
@@ -164,8 +176,25 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 164 |
|
| 165 |
for strategy in selected_strategies:
|
| 166 |
error_message = ""
|
| 167 |
-
fig = go.Figure()
|
| 168 |
placement_strategy = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
split_backward = strategy in ["zb1p", "dualpipe"]
|
| 171 |
|
|
@@ -177,32 +206,57 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 177 |
if not error_message:
|
| 178 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
|
| 179 |
placement_strategy = "standard"
|
| 180 |
-
|
| 181 |
-
error_message = f"Strategy '{strategy}': Requires Number of Stages == Number of Devices."
|
| 182 |
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
|
| 183 |
placement_strategy = "interleave"
|
| 184 |
-
if
|
| 185 |
error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
|
| 186 |
elif strategy == "dualpipe":
|
| 187 |
placement_strategy = "dualpipe"
|
| 188 |
-
if
|
| 189 |
error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
|
| 190 |
-
elif num_stages != num_devices:
|
| 191 |
-
error_message = f"Strategy '{strategy}' (DualPipe): Requires Number of Stages == Number of Devices."
|
| 192 |
|
|
|
|
| 193 |
if not error_message:
|
| 194 |
try:
|
| 195 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
if split_backward:
|
| 197 |
-
op_times["backward_D"] = float(op_time_backward_d)
|
| 198 |
-
op_times["backward_W"] = float(op_time_backward_w)
|
| 199 |
-
|
|
|
|
| 200 |
else:
|
| 201 |
-
op_times["backward"] = float(op_time_backward)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
| 203 |
config = ScheduleConfig(
|
| 204 |
-
num_devices=int(
|
| 205 |
-
num_stages=int(
|
| 206 |
num_batches=int(num_batches),
|
| 207 |
p2p_latency=float(p2p_latency),
|
| 208 |
placement_strategy=placement_strategy,
|
|
@@ -217,13 +271,9 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 217 |
schedule = schedule_func(config)
|
| 218 |
schedule.execute()
|
| 219 |
|
|
|
|
| 220 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
output_components.append(html.Div([
|
| 224 |
-
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
| 225 |
-
dcc.Graph(figure=fig)
|
| 226 |
-
]))
|
| 227 |
|
| 228 |
except (AssertionError, ValueError, TypeError) as e:
|
| 229 |
error_message = f"Error generating schedule for '{strategy}': {e}"
|
|
@@ -235,9 +285,56 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 235 |
traceback.print_exc()
|
| 236 |
|
| 237 |
if error_message:
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
return output_components
|
| 243 |
|
|
|
|
| 36 |
"num_devices": 4,
|
| 37 |
"num_stages": 8,
|
| 38 |
"num_batches": 16,
|
| 39 |
+
"p2p_latency": 0.0,
|
| 40 |
"op_time_forward": 1.0,
|
| 41 |
"op_time_backward_d": 1.0,
|
| 42 |
"op_time_backward_w": 1.0,
|
| 43 |
"op_time_backward": 2.0,
|
| 44 |
"strategy": "1f1b_interleave",
|
| 45 |
+
"op_time_overlapped_fwd_bwd": None,
|
|
|
|
| 46 |
}
|
| 47 |
|
| 48 |
# Define input groups using dbc components
|
|
|
|
| 76 |
dbc.Checklist(
|
| 77 |
id='strategy-checklist',
|
| 78 |
options=[{'label': k, 'value': k} for k in STRATEGIES.keys()],
|
| 79 |
+
value=list(STRATEGIES.keys()),
|
| 80 |
inline=False,
|
| 81 |
),
|
| 82 |
], className="mb-3"),
|
|
|
|
| 105 |
dbc.Input(id='op_time_backward_w', type='number', value=default_values["op_time_backward_w"], min=0.01, step=0.01),
|
| 106 |
dbc.FormText("Used when strategy requires split backward (e.g., ZB-1P, DualPipe)."),
|
| 107 |
], className="mb-3"),
|
| 108 |
+
html.Div([
|
| 109 |
+
dbc.Label("Overlapped Forward+Backward:"),
|
| 110 |
+
dbc.Input(id='op_time_overlapped_fwd_bwd', type='number', placeholder="Optional: Defaults to Fwd + Bwd times", min=0.01, step=0.01, value=default_values["op_time_overlapped_fwd_bwd"]),
|
| 111 |
+
dbc.FormText("Specify a custom duration if Forward and Backward ops overlap completely."),
|
| 112 |
+
], className="mb-3"),
|
| 113 |
])
|
| 114 |
)
|
| 115 |
|
|
|
|
| 151 |
State('op_time_backward', 'value'),
|
| 152 |
State('op_time_backward_d', 'value'),
|
| 153 |
State('op_time_backward_w', 'value'),
|
| 154 |
+
State('op_time_overlapped_fwd_bwd', 'value'),
|
| 155 |
State('strategy-checklist', 'value'),
|
| 156 |
prevent_initial_call=True
|
| 157 |
)
|
| 158 |
def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
| 159 |
op_time_forward, op_time_backward, op_time_backward_d, op_time_backward_w,
|
| 160 |
+
op_time_overlapped_fwd_bwd,
|
| 161 |
selected_strategies):
|
| 162 |
|
| 163 |
+
# Define the desired display order for strategies
|
| 164 |
+
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
| 165 |
+
|
| 166 |
output_components = []
|
| 167 |
+
valid_results = [] # Store (strategy_name, schedule, vis_data) for valid schedules
|
| 168 |
+
error_messages = [] # Store (strategy_name, error_message) for errors
|
| 169 |
+
automatic_adjustments = [] # Store messages about automatic parameter adjustments
|
| 170 |
|
| 171 |
if not selected_strategies:
|
| 172 |
return [dbc.Alert("Please select at least one scheduling strategy.", color="warning")]
|
|
|
|
| 176 |
|
| 177 |
for strategy in selected_strategies:
|
| 178 |
error_message = ""
|
|
|
|
| 179 |
placement_strategy = ""
|
| 180 |
+
|
| 181 |
+
# Use local copies of params that might be adjusted for this strategy
|
| 182 |
+
current_num_stages = num_stages
|
| 183 |
+
current_num_devices = num_devices
|
| 184 |
+
|
| 185 |
+
# Apply automatic adjustments for dualpipe
|
| 186 |
+
if strategy == "dualpipe" and num_stages != num_devices:
|
| 187 |
+
current_num_stages = num_devices # Force num_stages = num_devices for dualpipe
|
| 188 |
+
automatic_adjustments.append(
|
| 189 |
+
f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# Apply automatic adjustments for strategies that require num_stages == num_devices
|
| 193 |
+
if strategy in ["1f1b", "1f1b_overlap", "zb1p"] and num_stages != num_devices:
|
| 194 |
+
current_num_stages = num_devices
|
| 195 |
+
automatic_adjustments.append(
|
| 196 |
+
f"Strategy '{strategy}': Number of Stages automatically adjusted to {num_devices} to match Number of Devices."
|
| 197 |
+
)
|
| 198 |
|
| 199 |
split_backward = strategy in ["zb1p", "dualpipe"]
|
| 200 |
|
|
|
|
| 206 |
if not error_message:
|
| 207 |
if strategy in ["1f1b", "1f1b_overlap", "zb1p"]:
|
| 208 |
placement_strategy = "standard"
|
| 209 |
+
# No need to check num_stages == num_devices as we've enforced it above
|
|
|
|
| 210 |
elif strategy in ["1f1b_interleave", "1f1b_interleave_overlap"]:
|
| 211 |
placement_strategy = "interleave"
|
| 212 |
+
if current_num_stages % current_num_devices != 0:
|
| 213 |
error_message = f"Strategy '{strategy}': Requires Number of Stages to be divisible by Number of Devices."
|
| 214 |
elif strategy == "dualpipe":
|
| 215 |
placement_strategy = "dualpipe"
|
| 216 |
+
if current_num_stages % 2 != 0:
|
| 217 |
error_message = f"Strategy '{strategy}' (DualPipe): Requires an even number of stages."
|
|
|
|
|
|
|
| 218 |
|
| 219 |
+
# Create adjusted operation times based on placement strategy
|
| 220 |
if not error_message:
|
| 221 |
try:
|
| 222 |
+
# Calculate number of stages per device for time adjustment
|
| 223 |
+
stages_per_device = current_num_stages // current_num_devices
|
| 224 |
+
|
| 225 |
+
# Calculate scaling factor - this normalizes operation time by stages per device
|
| 226 |
+
# For standard placement (1:1 stage:device mapping), this remains 1.0
|
| 227 |
+
# For interleaved, this scales down the time proportionally
|
| 228 |
+
time_scale_factor = 1.0 / stages_per_device if stages_per_device > 0 else 1.0
|
| 229 |
+
|
| 230 |
+
if stages_per_device > 1:
|
| 231 |
+
automatic_adjustments.append(
|
| 232 |
+
f"Strategy '{strategy}': Operation times scaled by 1/{stages_per_device} to account for {stages_per_device} stages per device."
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
# Apply scaling to operation times
|
| 236 |
+
op_times = {
|
| 237 |
+
"forward": float(op_time_forward) * time_scale_factor
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
if split_backward:
|
| 241 |
+
op_times["backward_D"] = float(op_time_backward_d) * time_scale_factor
|
| 242 |
+
op_times["backward_W"] = float(op_time_backward_w) * time_scale_factor
|
| 243 |
+
# Keep combined for compatibility
|
| 244 |
+
op_times["backward"] = (float(op_time_backward_d) + float(op_time_backward_w)) * time_scale_factor
|
| 245 |
else:
|
| 246 |
+
op_times["backward"] = float(op_time_backward) * time_scale_factor
|
| 247 |
+
|
| 248 |
+
if op_time_overlapped_fwd_bwd is not None:
|
| 249 |
+
try:
|
| 250 |
+
overlapped_val = float(op_time_overlapped_fwd_bwd)
|
| 251 |
+
if overlapped_val > 0:
|
| 252 |
+
# Scale overlapped time too
|
| 253 |
+
op_times["overlapped_forward_backward"] = overlapped_val * time_scale_factor
|
| 254 |
+
except (ValueError, TypeError):
|
| 255 |
+
pass
|
| 256 |
|
| 257 |
config = ScheduleConfig(
|
| 258 |
+
num_devices=int(current_num_devices),
|
| 259 |
+
num_stages=int(current_num_stages), # Use adjusted value
|
| 260 |
num_batches=int(num_batches),
|
| 261 |
p2p_latency=float(p2p_latency),
|
| 262 |
placement_strategy=placement_strategy,
|
|
|
|
| 271 |
schedule = schedule_func(config)
|
| 272 |
schedule.execute()
|
| 273 |
|
| 274 |
+
# Store valid results instead of creating figure immediately
|
| 275 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
| 276 |
+
valid_results.append((strategy, schedule, vis_data))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
|
| 278 |
except (AssertionError, ValueError, TypeError) as e:
|
| 279 |
error_message = f"Error generating schedule for '{strategy}': {e}"
|
|
|
|
| 285 |
traceback.print_exc()
|
| 286 |
|
| 287 |
if error_message:
|
| 288 |
+
error_messages.append((strategy, error_message))
|
| 289 |
+
|
| 290 |
+
# Add alerts for any automatic parameter adjustments
|
| 291 |
+
for adjustment in automatic_adjustments:
|
| 292 |
+
output_components.append(
|
| 293 |
+
dbc.Alert(adjustment, color="info", dismissable=True)
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
# If we have valid results, calculate the maximum execution time across all schedules
|
| 297 |
+
if valid_results:
|
| 298 |
+
# Find global maximum execution time
|
| 299 |
+
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
| 300 |
+
|
| 301 |
+
# Sort valid results according to the display order
|
| 302 |
+
sorted_valid_results = []
|
| 303 |
+
|
| 304 |
+
# First add strategies in the predefined order
|
| 305 |
+
for strategy_name in strategy_display_order:
|
| 306 |
+
for result in valid_results:
|
| 307 |
+
if result[0] == strategy_name:
|
| 308 |
+
sorted_valid_results.append(result)
|
| 309 |
+
|
| 310 |
+
# Then add any remaining strategies that might not be in the predefined order
|
| 311 |
+
for result in valid_results:
|
| 312 |
+
if result[0] not in strategy_display_order:
|
| 313 |
+
sorted_valid_results.append(result)
|
| 314 |
+
|
| 315 |
+
# Create figures with aligned x-axis, using the sorted results
|
| 316 |
+
for strategy, _, vis_data in sorted_valid_results:
|
| 317 |
+
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
| 318 |
+
|
| 319 |
+
# Force the x-axis range to be the same for all figures
|
| 320 |
+
# Add a small margin (5%) for better visualization
|
| 321 |
+
margin = max_execution_time * 0.05
|
| 322 |
+
fig.update_layout(
|
| 323 |
+
xaxis=dict(
|
| 324 |
+
range=[0, max_execution_time + margin]
|
| 325 |
+
)
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
output_components.append(html.Div([
|
| 329 |
+
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
| 330 |
+
dcc.Graph(figure=fig)
|
| 331 |
+
]))
|
| 332 |
+
|
| 333 |
+
# Add error messages to output
|
| 334 |
+
for strategy, msg in error_messages:
|
| 335 |
+
output_components.append(
|
| 336 |
+
dbc.Alert(msg, color="danger", className="mt-3")
|
| 337 |
+
)
|
| 338 |
|
| 339 |
return output_components
|
| 340 |
|