Julian Bilcke
commited on
Commit
·
adc5756
1
Parent(s):
892fa67
working on a fix
Browse files- vms/services/trainer.py +21 -12
- vms/tabs/train_tab.py +62 -22
- vms/ui/video_trainer_ui.py +17 -7
vms/services/trainer.py
CHANGED
|
@@ -353,7 +353,7 @@ class TrainingService:
|
|
| 353 |
resume_from_checkpoint: Optional[str] = None,
|
| 354 |
) -> Tuple[str, str]:
|
| 355 |
"""Start training with finetrainers"""
|
| 356 |
-
|
| 357 |
self.clear_logs()
|
| 358 |
|
| 359 |
if not model_type:
|
|
@@ -365,22 +365,31 @@ class TrainingService:
|
|
| 365 |
is_resuming = resume_from_checkpoint is not None
|
| 366 |
log_prefix = "Resuming" if is_resuming else "Initializing"
|
| 367 |
logger.info(f"{log_prefix} training with model_type={model_type}")
|
| 368 |
-
self.append_log(f"{log_prefix} training with model_type={model_type}")
|
| 369 |
-
|
| 370 |
-
if is_resuming:
|
| 371 |
-
self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
|
| 372 |
|
| 373 |
try:
|
| 374 |
-
# Get absolute paths
|
| 375 |
-
current_dir = Path(__file__).parent.absolute()
|
| 376 |
-
train_script = current_dir
|
| 377 |
-
|
| 378 |
|
| 379 |
if not train_script.exists():
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
# Log paths for debugging
|
| 385 |
logger.info("Current working directory: %s", current_dir)
|
| 386 |
logger.info("Training script path: %s", train_script)
|
|
|
|
| 353 |
resume_from_checkpoint: Optional[str] = None,
|
| 354 |
) -> Tuple[str, str]:
|
| 355 |
"""Start training with finetrainers"""
|
| 356 |
+
|
| 357 |
self.clear_logs()
|
| 358 |
|
| 359 |
if not model_type:
|
|
|
|
| 365 |
is_resuming = resume_from_checkpoint is not None
|
| 366 |
log_prefix = "Resuming" if is_resuming else "Initializing"
|
| 367 |
logger.info(f"{log_prefix} training with model_type={model_type}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
try:
|
| 370 |
+
# Get absolute paths - FIXED to look in project root instead of within vms directory
|
| 371 |
+
current_dir = Path(__file__).parent.parent.parent.absolute() # Go up to project root
|
| 372 |
+
train_script = current_dir / "train.py"
|
|
|
|
| 373 |
|
| 374 |
if not train_script.exists():
|
| 375 |
+
# Try alternative locations
|
| 376 |
+
alt_locations = [
|
| 377 |
+
current_dir.parent / "train.py", # One level up from project root
|
| 378 |
+
Path("/home/user/app/train.py"), # Absolute path
|
| 379 |
+
Path("train.py") # Current working directory
|
| 380 |
+
]
|
| 381 |
+
|
| 382 |
+
for alt_path in alt_locations:
|
| 383 |
+
if alt_path.exists():
|
| 384 |
+
train_script = alt_path
|
| 385 |
+
logger.info(f"Found train.py at alternative location: {train_script}")
|
| 386 |
+
break
|
| 387 |
|
| 388 |
+
if not train_script.exists():
|
| 389 |
+
error_msg = f"Training script not found at {train_script} or any alternative locations"
|
| 390 |
+
logger.error(error_msg)
|
| 391 |
+
return error_msg, "Training script not found"
|
| 392 |
+
|
| 393 |
# Log paths for debugging
|
| 394 |
logger.info("Current working directory: %s", current_dir)
|
| 395 |
logger.info("Training script path: %s", train_script)
|
vms/tabs/train_tab.py
CHANGED
|
@@ -91,20 +91,35 @@ class TrainTab(BaseTab):
|
|
| 91 |
|
| 92 |
with gr.Column():
|
| 93 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
self.components["start_btn"] = gr.Button(
|
| 95 |
-
|
| 96 |
variant="primary",
|
| 97 |
interactive=not ASK_USER_TO_DUPLICATE_SPACE
|
| 98 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
self.components["pause_resume_btn"] = gr.Button(
|
| 100 |
"Resume Training",
|
| 101 |
variant="secondary",
|
| 102 |
-
interactive=False
|
|
|
|
| 103 |
)
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
| 106 |
variant="stop",
|
| 107 |
-
interactive=
|
| 108 |
)
|
| 109 |
|
| 110 |
with gr.Row():
|
|
@@ -468,31 +483,56 @@ class TrainTab(BaseTab):
|
|
| 468 |
|
| 469 |
return (state["status"], state["message"], logs)
|
| 470 |
|
| 471 |
-
def get_latest_status_message_logs_and_button_labels(self) -> Tuple
|
|
|
|
| 472 |
status, message, logs = self.get_latest_status_message_and_logs()
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
)
|
| 478 |
-
|
| 479 |
-
|
|
|
|
|
|
|
|
|
|
| 480 |
"""Update training control buttons based on state"""
|
|
|
|
|
|
|
|
|
|
| 481 |
is_training = status in ["training", "initializing"]
|
| 482 |
-
is_paused = status == "paused"
|
| 483 |
is_completed = status in ["completed", "error", "stopped"]
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
"start_btn": gr.Button(
|
| 486 |
-
|
|
|
|
| 487 |
variant="primary" if not is_training else "secondary",
|
| 488 |
),
|
| 489 |
"stop_btn": gr.Button(
|
| 490 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
variant="stop",
|
| 492 |
-
)
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
|
|
|
|
|
|
| 496 |
variant="secondary",
|
|
|
|
| 497 |
)
|
| 498 |
-
|
|
|
|
|
|
| 91 |
|
| 92 |
with gr.Column():
|
| 93 |
with gr.Row():
|
| 94 |
+
# Check for existing checkpoints to determine button text
|
| 95 |
+
has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
|
| 96 |
+
start_text = "Continue Training" if has_checkpoints else "Start Training"
|
| 97 |
+
|
| 98 |
self.components["start_btn"] = gr.Button(
|
| 99 |
+
start_text,
|
| 100 |
variant="primary",
|
| 101 |
interactive=not ASK_USER_TO_DUPLICATE_SPACE
|
| 102 |
)
|
| 103 |
+
|
| 104 |
+
# Just use stop and pause buttons for now to ensure compatibility
|
| 105 |
+
self.components["stop_btn"] = gr.Button(
|
| 106 |
+
"Stop at Last Checkpoint",
|
| 107 |
+
variant="primary",
|
| 108 |
+
interactive=False
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
self.components["pause_resume_btn"] = gr.Button(
|
| 112 |
"Resume Training",
|
| 113 |
variant="secondary",
|
| 114 |
+
interactive=False,
|
| 115 |
+
visible=False
|
| 116 |
)
|
| 117 |
+
|
| 118 |
+
# Add delete checkpoints button - THIS IS THE KEY FIX
|
| 119 |
+
self.components["delete_checkpoints_btn"] = gr.Button(
|
| 120 |
+
"Delete All Checkpoints",
|
| 121 |
variant="stop",
|
| 122 |
+
interactive=True
|
| 123 |
)
|
| 124 |
|
| 125 |
with gr.Row():
|
|
|
|
| 483 |
|
| 484 |
return (state["status"], state["message"], logs)
|
| 485 |
|
| 486 |
+
def get_latest_status_message_logs_and_button_labels(self) -> Tuple:
|
| 487 |
+
"""Get latest status message, logs and button states"""
|
| 488 |
status, message, logs = self.get_latest_status_message_and_logs()
|
| 489 |
+
|
| 490 |
+
# Add checkpoints detection
|
| 491 |
+
has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
|
| 492 |
+
|
| 493 |
+
button_updates = self.update_training_buttons(status, has_checkpoints).values()
|
| 494 |
+
|
| 495 |
+
# Return in order expected by timer
|
| 496 |
+
return (message, logs, *button_updates)
|
| 497 |
+
|
| 498 |
+
def update_training_buttons(self, status: str, has_checkpoints: bool = None) -> Dict:
|
| 499 |
"""Update training control buttons based on state"""
|
| 500 |
+
if has_checkpoints is None:
|
| 501 |
+
has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
|
| 502 |
+
|
| 503 |
is_training = status in ["training", "initializing"]
|
|
|
|
| 504 |
is_completed = status in ["completed", "error", "stopped"]
|
| 505 |
+
|
| 506 |
+
start_text = "Continue Training" if has_checkpoints else "Start Training"
|
| 507 |
+
|
| 508 |
+
# Only include buttons that we know exist in components
|
| 509 |
+
result = {
|
| 510 |
"start_btn": gr.Button(
|
| 511 |
+
value=start_text,
|
| 512 |
+
interactive=not is_training,
|
| 513 |
variant="primary" if not is_training else "secondary",
|
| 514 |
),
|
| 515 |
"stop_btn": gr.Button(
|
| 516 |
+
value="Stop at Last Checkpoint",
|
| 517 |
+
interactive=is_training,
|
| 518 |
+
variant="primary" if is_training else "secondary",
|
| 519 |
+
)
|
| 520 |
+
}
|
| 521 |
+
|
| 522 |
+
# Add delete_checkpoints_btn only if it exists in components
|
| 523 |
+
if "delete_checkpoints_btn" in self.components:
|
| 524 |
+
result["delete_checkpoints_btn"] = gr.Button(
|
| 525 |
+
value="Delete All Checkpoints",
|
| 526 |
+
interactive=has_checkpoints and not is_training,
|
| 527 |
variant="stop",
|
| 528 |
+
)
|
| 529 |
+
else:
|
| 530 |
+
# Add pause_resume_btn as fallback
|
| 531 |
+
result["pause_resume_btn"] = gr.Button(
|
| 532 |
+
value="Resume Training" if status == "paused" else "Pause Training",
|
| 533 |
+
interactive=(is_training or status == "paused") and not is_completed,
|
| 534 |
variant="secondary",
|
| 535 |
+
visible=False
|
| 536 |
)
|
| 537 |
+
|
| 538 |
+
return result
|
vms/ui/video_trainer_ui.py
CHANGED
|
@@ -100,15 +100,25 @@ class VideoTrainerUI:
|
|
| 100 |
"""Add auto-refresh timers to the UI"""
|
| 101 |
# Status update timer (every 1 second)
|
| 102 |
status_timer = gr.Timer(value=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
status_timer.tick(
|
| 104 |
fn=self.tabs["train_tab"].get_latest_status_message_logs_and_button_labels,
|
| 105 |
-
outputs=
|
| 106 |
-
self.tabs["train_tab"].components["status_box"],
|
| 107 |
-
self.tabs["train_tab"].components["log_box"],
|
| 108 |
-
self.tabs["train_tab"].components["start_btn"],
|
| 109 |
-
self.tabs["train_tab"].components["stop_btn"],
|
| 110 |
-
self.tabs["train_tab"].components["delete_checkpoints_btn"] # Replace pause_resume_btn
|
| 111 |
-
]
|
| 112 |
)
|
| 113 |
|
| 114 |
# Dataset refresh timer (every 5 seconds)
|
|
|
|
| 100 |
"""Add auto-refresh timers to the UI"""
|
| 101 |
# Status update timer (every 1 second)
|
| 102 |
status_timer = gr.Timer(value=1)
|
| 103 |
+
|
| 104 |
+
# Use a safer approach - check if the component exists before using it
|
| 105 |
+
outputs = [
|
| 106 |
+
self.tabs["train_tab"].components["status_box"],
|
| 107 |
+
self.tabs["train_tab"].components["log_box"],
|
| 108 |
+
self.tabs["train_tab"].components["start_btn"],
|
| 109 |
+
self.tabs["train_tab"].components["stop_btn"]
|
| 110 |
+
]
|
| 111 |
+
|
| 112 |
+
# Add delete_checkpoints_btn only if it exists
|
| 113 |
+
if "delete_checkpoints_btn" in self.tabs["train_tab"].components:
|
| 114 |
+
outputs.append(self.tabs["train_tab"].components["delete_checkpoints_btn"])
|
| 115 |
+
else:
|
| 116 |
+
# Add pause_resume_btn as fallback
|
| 117 |
+
outputs.append(self.tabs["train_tab"].components["pause_resume_btn"])
|
| 118 |
+
|
| 119 |
status_timer.tick(
|
| 120 |
fn=self.tabs["train_tab"].get_latest_status_message_logs_and_button_labels,
|
| 121 |
+
outputs=outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
)
|
| 123 |
|
| 124 |
# Dataset refresh timer (every 5 seconds)
|