Spaces:
Running
Running
Update args logic.
Browse files- pipeline.py +11 -21
pipeline.py
CHANGED
|
@@ -280,12 +280,12 @@ def parse_args():
|
|
| 280 |
"--num-stages",
|
| 281 |
"-s",
|
| 282 |
type=int,
|
| 283 |
-
default=
|
| 284 |
help="Number of pipeline stages (devices)",
|
| 285 |
)
|
| 286 |
|
| 287 |
parser.add_argument(
|
| 288 |
-
"--num-batches", "-b", type=int, default=
|
| 289 |
)
|
| 290 |
|
| 291 |
# Forward and backward times
|
|
@@ -369,6 +369,15 @@ def main():
|
|
| 369 |
backward_times = None
|
| 370 |
output_file = "pipeline_1f1b.png"
|
| 371 |
p2p_time = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
# Read from config file if provided
|
| 373 |
if args.config:
|
| 374 |
try:
|
|
@@ -387,25 +396,6 @@ def main():
|
|
| 387 |
print(f"Error reading config file: {str(e)}")
|
| 388 |
print("Falling back to command line arguments or defaults")
|
| 389 |
|
| 390 |
-
# Command line arguments override config file
|
| 391 |
-
if args.num_stages:
|
| 392 |
-
num_stages = args.num_stages
|
| 393 |
-
|
| 394 |
-
if args.num_batches:
|
| 395 |
-
num_batches = args.num_batches
|
| 396 |
-
|
| 397 |
-
if args.forward_times:
|
| 398 |
-
forward_times = args.forward_times
|
| 399 |
-
|
| 400 |
-
if args.backward_times:
|
| 401 |
-
backward_times = args.backward_times
|
| 402 |
-
|
| 403 |
-
if args.output:
|
| 404 |
-
output_file = args.output
|
| 405 |
-
|
| 406 |
-
if args.p2p_time:
|
| 407 |
-
p2p_time = args.p2p_time
|
| 408 |
-
|
| 409 |
# Validate inputs
|
| 410 |
if forward_times is None:
|
| 411 |
forward_times = [1.0] * num_stages
|
|
|
|
| 280 |
"--num-stages",
|
| 281 |
"-s",
|
| 282 |
type=int,
|
| 283 |
+
default=0,
|
| 284 |
help="Number of pipeline stages (devices)",
|
| 285 |
)
|
| 286 |
|
| 287 |
parser.add_argument(
|
| 288 |
+
"--num-batches", "-b", type=int, default=0, help="Number of micro-batches"
|
| 289 |
)
|
| 290 |
|
| 291 |
# Forward and backward times
|
|
|
|
| 369 |
backward_times = None
|
| 370 |
output_file = "pipeline_1f1b.png"
|
| 371 |
p2p_time = 0.0
|
| 372 |
+
|
| 373 |
+
# Command line arguments override config file
|
| 374 |
+
num_stages = args.num_stages
|
| 375 |
+
num_batches = args.num_batches
|
| 376 |
+
forward_times = args.forward_times
|
| 377 |
+
backward_times = args.backward_times
|
| 378 |
+
output_file = args.output
|
| 379 |
+
p2p_time = args.p2p_time
|
| 380 |
+
|
| 381 |
# Read from config file if provided
|
| 382 |
if args.config:
|
| 383 |
try:
|
|
|
|
| 396 |
print(f"Error reading config file: {str(e)}")
|
| 397 |
print("Falling back to command line arguments or defaults")
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
# Validate inputs
|
| 400 |
if forward_times is None:
|
| 401 |
forward_times = [1.0] * num_stages
|