Julian Bilcke
committed on
Commit
·
4e9670a
1
Parent(s):
b06a009
fix
Browse files
vms/ui/project/tabs/train_tab.py
CHANGED
|
@@ -580,7 +580,9 @@ class TrainTab(BaseTab):
|
|
| 580 |
def handle_training_start(
|
| 581 |
self, preset, model_type, model_version, training_type,
|
| 582 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 583 |
-
save_iterations, repo_id,
|
|
|
|
|
|
|
| 584 |
):
|
| 585 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
| 586 |
# Safely reset log parser if it exists
|
|
@@ -594,14 +596,14 @@ class TrainTab(BaseTab):
|
|
| 594 |
# Check for latest checkpoint
|
| 595 |
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
|
| 596 |
has_checkpoints = len(checkpoints) > 0
|
| 597 |
-
resume_from =
|
| 598 |
|
| 599 |
-
if checkpoints:
|
| 600 |
# Find the latest checkpoint
|
| 601 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
| 602 |
resume_from = str(latest_checkpoint)
|
| 603 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
| 604 |
-
|
| 605 |
# Convert model_type display name to internal name
|
| 606 |
model_internal_type = MODEL_TYPES.get(model_type)
|
| 607 |
|
|
|
|
| 580 |
def handle_training_start(
|
| 581 |
self, preset, model_type, model_version, training_type,
|
| 582 |
lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
|
| 583 |
+
save_iterations, repo_id,
|
| 584 |
+
progress=gr.Progress(),
|
| 585 |
+
resume_from_checkpoint=None,
|
| 586 |
):
|
| 587 |
"""Handle training start with proper log parser reset and checkpoint detection"""
|
| 588 |
# Safely reset log parser if it exists
|
|
|
|
| 596 |
# Check for latest checkpoint
|
| 597 |
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
|
| 598 |
has_checkpoints = len(checkpoints) > 0
|
| 599 |
+
resume_from = resume_from_checkpoint # Use the passed parameter
|
| 600 |
|
| 601 |
+
if resume_from == "latest" and checkpoints:
|
| 602 |
# Find the latest checkpoint
|
| 603 |
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
| 604 |
resume_from = str(latest_checkpoint)
|
| 605 |
logger.info(f"Found checkpoint at {resume_from}, will resume training")
|
| 606 |
+
|
| 607 |
# Convert model_type display name to internal name
|
| 608 |
model_internal_type = MODEL_TYPES.get(model_type)
|
| 609 |
|