Spaces:
Running
Running
Add table results.
Browse files
app.py
CHANGED
|
@@ -371,11 +371,12 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 371 |
|
| 372 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
| 373 |
|
| 374 |
-
graph_components = []
|
| 375 |
-
toast_components = []
|
| 376 |
valid_results = []
|
| 377 |
error_messages = []
|
| 378 |
automatic_adjustments = []
|
|
|
|
| 379 |
|
| 380 |
# Use a variable to track if initial validation fails
|
| 381 |
initial_validation_error = None
|
|
@@ -489,6 +490,8 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 489 |
schedule.execute()
|
| 490 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
| 491 |
valid_results.append((strategy, schedule, vis_data))
|
|
|
|
|
|
|
| 492 |
|
| 493 |
except (AssertionError, ValueError, TypeError) as e:
|
| 494 |
error_message = f"Error for '{strategy}': {e}"
|
|
@@ -530,19 +533,69 @@ def update_graph(n_clicks, num_devices, num_stages, num_batches, p2p_latency,
|
|
| 530 |
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
| 531 |
sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
|
| 532 |
|
|
|
|
|
|
|
| 533 |
for strategy, _, vis_data in sorted_valid_results:
|
| 534 |
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
| 535 |
margin = max_execution_time * 0.05
|
| 536 |
fig.update_layout(
|
| 537 |
xaxis=dict(range=[0, max_execution_time + margin])
|
| 538 |
)
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
|
|
|
|
|
|
|
|
|
| 543 |
|
| 544 |
-
|
| 545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
|
| 547 |
# For Hugging Face Spaces deployment
|
| 548 |
server = app.server
|
|
|
|
| 371 |
|
| 372 |
strategy_display_order = ["1f1b", "1f1b_interleave", "1f1b_overlap", "1f1b_interleave_overlap", "dualpipe", "zb1p"]
|
| 373 |
|
| 374 |
+
graph_components = []
|
| 375 |
+
toast_components = []
|
| 376 |
valid_results = []
|
| 377 |
error_messages = []
|
| 378 |
automatic_adjustments = []
|
| 379 |
+
execution_times = [] # Add list to store execution times
|
| 380 |
|
| 381 |
# Use a variable to track if initial validation fails
|
| 382 |
initial_validation_error = None
|
|
|
|
| 490 |
schedule.execute()
|
| 491 |
vis_data = convert_schedule_to_visualization_format(schedule)
|
| 492 |
valid_results.append((strategy, schedule, vis_data))
|
| 493 |
+
# Store execution time
|
| 494 |
+
execution_times.append((strategy, schedule.get_total_execution_time()))
|
| 495 |
|
| 496 |
except (AssertionError, ValueError, TypeError) as e:
|
| 497 |
error_message = f"Error for '{strategy}': {e}"
|
|
|
|
| 533 |
max_execution_time = max(schedule.get_total_execution_time() for _, schedule, _ in valid_results)
|
| 534 |
sorted_valid_results = sorted(valid_results, key=lambda x: strategy_display_order.index(x[0]) if x[0] in strategy_display_order else float('inf'))
|
| 535 |
|
| 536 |
+
# Prepare graphs for single-column layout
|
| 537 |
+
graph_components = [] # Use graph_components again
|
| 538 |
for strategy, _, vis_data in sorted_valid_results:
|
| 539 |
fig = create_pipeline_figure(vis_data, max_time=max_execution_time, show_progress=False)
|
| 540 |
margin = max_execution_time * 0.05
|
| 541 |
fig.update_layout(
|
| 542 |
xaxis=dict(range=[0, max_execution_time + margin])
|
| 543 |
)
|
| 544 |
+
# Append the Div directly for vertical stacking
|
| 545 |
+
graph_components.append(
|
| 546 |
+
html.Div([
|
| 547 |
+
html.H4(f"Schedule: {strategy}", className="text-center mt-3 mb-2"),
|
| 548 |
+
dcc.Graph(figure=fig)
|
| 549 |
+
])
|
| 550 |
+
)
|
| 551 |
|
| 552 |
+
# No grid arrangement needed for single column
|
| 553 |
+
# rows = [] ... removed ...
|
| 554 |
+
|
| 555 |
+
# If there are graphs, use the component list, otherwise show a message
|
| 556 |
+
output_content = []
|
| 557 |
+
if graph_components: # Check if graph_components list is populated
|
| 558 |
+
output_content = graph_components # Assign the list of components
|
| 559 |
+
elif not toast_components: # Only show 'no results' if no errors/adjustments either
|
| 560 |
+
output_content = dbc.Alert("Click 'Generate Schedule' to see results.", color="info", className="mt-3")
|
| 561 |
+
|
| 562 |
+
# Add the execution time table if there are results
|
| 563 |
+
if execution_times:
|
| 564 |
+
# Sort times based on execution time (ascending)
|
| 565 |
+
sorted_times = sorted(execution_times, key=lambda x: x[1])
|
| 566 |
+
min_time = sorted_times[0][1] if sorted_times else None
|
| 567 |
+
|
| 568 |
+
table_header = [html.Thead(html.Tr([html.Th("Strategy"), html.Th("Total Execution Time (ms)")]))]
|
| 569 |
+
table_rows = []
|
| 570 |
+
for strategy, time in sorted_times:
|
| 571 |
+
row_class = "table-success" if time == min_time else ""
|
| 572 |
+
table_rows.append(html.Tr([html.Td(strategy), html.Td(f"{time:.2f}")], className=row_class))
|
| 573 |
+
|
| 574 |
+
table_body = [html.Tbody(table_rows)]
|
| 575 |
+
summary_table = dbc.Table(
|
| 576 |
+
table_header + table_body,
|
| 577 |
+
bordered=True,
|
| 578 |
+
striped=True,
|
| 579 |
+
hover=True,
|
| 580 |
+
responsive=True,
|
| 581 |
+
color="light", # Apply a light theme color
|
| 582 |
+
className="mt-5" # Add margin top
|
| 583 |
+
)
|
| 584 |
+
# Prepend title to the table
|
| 585 |
+
table_component = html.Div([
|
| 586 |
+
html.H4("Execution Time Summary", className="text-center mt-4 mb-3"),
|
| 587 |
+
summary_table
|
| 588 |
+
])
|
| 589 |
+
|
| 590 |
+
# Append the table component to the output content
|
| 591 |
+
# If output_content is just the alert, replace it. Otherwise, append.
|
| 592 |
+
if isinstance(output_content, list):
|
| 593 |
+
output_content.append(table_component)
|
| 594 |
+
else: # It must be the Alert
|
| 595 |
+
output_content = [output_content, table_component] # Replace Alert with list
|
| 596 |
+
|
| 597 |
+
# Return graph components (single column list or message) and toast components
|
| 598 |
+
return output_content, toast_components
|
| 599 |
|
| 600 |
# For Hugging Face Spaces deployment
|
| 601 |
server = app.server
|