|
|
""" |
|
|
Train tab for Video Model Studio UI with improved task progress display |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import logging |
|
|
import os |
|
|
import json |
|
|
import shutil |
|
|
from typing import Dict, Any, List, Optional, Tuple |
|
|
from pathlib import Path |
|
|
|
|
|
from vms.utils import BaseTab |
|
|
from vms.config import ( |
|
|
OUTPUT_PATH, ASK_USER_TO_DUPLICATE_SPACE, |
|
|
SMALL_TRAINING_BUCKETS, |
|
|
TRAINING_PRESETS, TRAINING_TYPES, MODEL_TYPES, MODEL_VERSIONS, |
|
|
DEFAULT_NB_TRAINING_STEPS, DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS, |
|
|
DEFAULT_BATCH_SIZE, DEFAULT_CAPTION_DROPOUT_P, |
|
|
DEFAULT_LEARNING_RATE, |
|
|
DEFAULT_LORA_RANK, DEFAULT_LORA_ALPHA, |
|
|
DEFAULT_LORA_RANK_STR, DEFAULT_LORA_ALPHA_STR, |
|
|
DEFAULT_SEED, |
|
|
DEFAULT_NUM_GPUS, |
|
|
DEFAULT_MAX_GPUS, |
|
|
DEFAULT_PRECOMPUTATION_ITEMS, |
|
|
DEFAULT_NB_TRAINING_STEPS, |
|
|
DEFAULT_NB_LR_WARMUP_STEPS, |
|
|
DEFAULT_AUTO_RESUME |
|
|
) |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class TrainTab(BaseTab): |
|
|
"""Train tab for model training""" |
|
|
|
|
|
def __init__(self, app_state):
    """Set up the tab on top of ``BaseTab`` and record its identifier and label.

    Args:
        app_state: Object forwarded unchanged to ``BaseTab.__init__``.
    """
    super().__init__(app_state)
    # Label shown on the tab header, then the id handed to gr.TabItem in create().
    self.title = "3️⃣ Train"
    self.id = "train_tab"
|
|
|
|
|
def create(self, parent=None) -> gr.TabItem: |
|
|
"""Create the Train tab UI components""" |
|
|
# NOTE(review): the nesting of the layout contexts below cannot be read from
# this text (indentation was lost) - confirm against the rendered UI.
# Every widget created here is registered in self.components under a string key.
with gr.TabItem(self.title, id=self.id) as tab: |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
with gr.Row(): |
|
|
# Headline; its initial text reports 0 files in the dataset.
self.components["train_title"] = gr.Markdown("## 0 files in the training dataset") |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
# Preset selector: choices are the TRAINING_PRESETS keys, the first key is preselected.
self.components["training_preset"] = gr.Dropdown( |
|
|
choices=list(TRAINING_PRESETS.keys()), |
|
|
label="Training Preset", |
|
|
value=list(TRAINING_PRESETS.keys())[0] |
|
|
) |
|
|
# Empty Markdown area created without initial content.
self.components["preset_info"] = gr.Markdown() |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
|
|
|
# The first key of MODEL_TYPES is used as the initial model type.
default_model_type = list(MODEL_TYPES.keys())[0] |
|
|
|
|
|
self.components["model_type"] = gr.Dropdown( |
|
|
choices=list(MODEL_TYPES.keys()), |
|
|
label="Model Type", |
|
|
value=default_model_type, |
|
|
interactive=True |
|
|
) |
|
|
|
|
|
|
|
|
|
|
# Ask the helper methods for the version list and default version of that type.
default_model_versions = self.get_model_version_choices(default_model_type) |
|
|
default_model_version = self.get_default_model_version(default_model_type) |
|
|
|
|
|
|
|
|
|
|
# Fallback path when the helper returned no versions.
if not default_model_versions: |
|
|
|
|
|
internal_type = MODEL_TYPES.get(default_model_type) |
|
|
if internal_type in MODEL_VERSIONS: |
|
|
default_model_versions = list(MODEL_VERSIONS[internal_type].keys()) |
|
|
else: |
|
|
|
|
|
# Collect the version keys of every entry in MODEL_VERSIONS.
default_model_versions = [] |
|
|
for model_versions in MODEL_VERSIONS.values(): |
|
|
default_model_versions.extend(list(model_versions.keys())) |
|
|
|
|
|
|
|
|
|
|
# Still nothing: use a single placeholder entry so the list is never empty.
if not default_model_versions: |
|
|
default_model_versions = ["-- No versions available --"] |
|
|
|
|
|
|
|
|
|
|
# NOTE(review): this reassigns default_model_version, replacing the value obtained
# from get_default_model_version() above wherever this statement is reached.
if default_model_versions: |
|
|
default_model_version = default_model_versions[0] |
|
|
else: |
|
|
default_model_version = "" |
|
|
|
|
|
# Version selector; allow_custom_value=True is passed to the dropdown.
self.components["model_version"] = gr.Dropdown( |
|
|
choices=default_model_versions, |
|
|
label="Model Version", |
|
|
value=default_model_version, |
|
|
interactive=True, |
|
|
allow_custom_value=True |
|
|
) |
|
|
|
|
|
# Training type selector: TRAINING_TYPES keys, first key preselected.
self.components["training_type"] = gr.Dropdown( |
|
|
choices=list(TRAINING_TYPES.keys()), |
|
|
label="Training Type", |
|
|
value=list(TRAINING_TYPES.keys())[0] |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
# Info text computed for the first model type and first training type.
self.components["model_info"] = gr.Markdown( |
|
|
value=self.get_model_info(list(MODEL_TYPES.keys())[0], list(TRAINING_TYPES.keys())[0]) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
# The row object itself is stored so other handlers can return visibility updates for it.
with gr.Row(visible=True) as lora_params_row: |
|
|
self.components["lora_params_row"] = lora_params_row |
|
|
# Rank and alpha are offered as strings, matching the *_STR default constants.
self.components["lora_rank"] = gr.Dropdown( |
|
|
label="LoRA Rank", |
|
|
choices=["16", "32", "64", "128", "256", "512", "1024"], |
|
|
value=DEFAULT_LORA_RANK_STR, |
|
|
type="value" |
|
|
) |
|
|
self.components["lora_alpha"] = gr.Dropdown( |
|
|
label="LoRA Alpha", |
|
|
choices=["16", "32", "64", "128", "256", "512", "1024"], |
|
|
value=DEFAULT_LORA_ALPHA_STR, |
|
|
type="value" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
self.components["train_steps"] = gr.Number( |
|
|
label="Number of Training Steps", |
|
|
value=DEFAULT_NB_TRAINING_STEPS, |
|
|
minimum=1, |
|
|
precision=0 |
|
|
) |
|
|
# NOTE(review): the initial batch size is the literal 1, not DEFAULT_BATCH_SIZE.
self.components["batch_size"] = gr.Number( |
|
|
label="Batch Size", |
|
|
value=1, |
|
|
minimum=1, |
|
|
precision=0 |
|
|
) |
|
|
with gr.Row(): |
|
|
self.components["learning_rate"] = gr.Number( |
|
|
label="Learning Rate", |
|
|
value=DEFAULT_LEARNING_RATE, |
|
|
minimum=1e-8 |
|
|
) |
|
|
self.components["save_iterations"] = gr.Number( |
|
|
label="Save checkpoint every N iterations", |
|
|
value=DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS, |
|
|
minimum=1, |
|
|
precision=0, |
|
|
info="Model will be saved periodically after these many steps" |
|
|
) |
|
|
with gr.Row(): |
|
|
# Slider bounded by 1 and DEFAULT_MAX_GPUS, in steps of 1.
self.components["num_gpus"] = gr.Slider( |
|
|
label="Number of GPUs to use", |
|
|
value=DEFAULT_NUM_GPUS, |
|
|
minimum=1, |
|
|
maximum=DEFAULT_MAX_GPUS, |
|
|
step=1, |
|
|
info="Number of GPUs to use for training" |
|
|
) |
|
|
self.components["precomputation_items"] = gr.Number( |
|
|
label="Precomputation Items", |
|
|
value=DEFAULT_PRECOMPUTATION_ITEMS, |
|
|
minimum=1, |
|
|
precision=0, |
|
|
info="Should be more or less the number of total items (ex: 200 videos), divided by the number of GPUs" |
|
|
) |
|
|
with gr.Row(): |
|
|
self.components["lr_warmup_steps"] = gr.Number( |
|
|
label="Learning Rate Warmup Steps", |
|
|
value=DEFAULT_NB_LR_WARMUP_STEPS, |
|
|
minimum=0, |
|
|
precision=0, |
|
|
info="Number of warmup steps (typically 20-40% of total training steps). This helps reducing the impact of early training examples as well as giving time to optimizers to compute accurate statistics." |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
|
|
|
# Static help text describing the two start buttons.
self.components["training_buttons_info"] = gr.Markdown(""" |
|
|
## ⚗️ Train your model on your dataset |
|
|
- **🚀 Start new training**: Begins training from scratch (clears previous checkpoints) |
|
|
- **🛸 Start from latest checkpoint**: Continues training from the most recent checkpoint |
|
|
""") |
|
|
|
|
|
with gr.Row(): |
|
|
|
|
|
# Paths matching finetrainers_step_* under OUTPUT_PATH are looked up once, while the UI is built.
checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*")) |
|
|
has_checkpoints = len(checkpoints) > 0 |
|
|
|
|
|
# Not clickable when ASK_USER_TO_DUPLICATE_SPACE is truthy.
self.components["start_btn"] = gr.Button( |
|
|
"🚀 Start new training", |
|
|
variant="primary", |
|
|
interactive=not ASK_USER_TO_DUPLICATE_SPACE |
|
|
) |
|
|
|
|
|
|
|
|
|
|
# Clickable only if a checkpoint path was found and ASK_USER_TO_DUPLICATE_SPACE is falsy.
self.components["resume_btn"] = gr.Button( |
|
|
"🛸 Start from latest checkpoint", |
|
|
variant="primary", |
|
|
interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
|
|
|
# Created non-interactive.
self.components["stop_btn"] = gr.Button( |
|
|
"Stop at Last Checkpoint", |
|
|
variant="primary", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
# Created hidden and non-interactive; connect_events still attaches a click handler to it.
self.components["pause_resume_btn"] = gr.Button( |
|
|
"Resume Training", |
|
|
variant="secondary", |
|
|
interactive=False, |
|
|
visible=False |
|
|
) |
|
|
|
|
|
|
|
|
|
|
# Clickable only when a checkpoint path was found at build time.
self.components["delete_checkpoints_btn"] = gr.Button( |
|
|
"Delete All Checkpoints", |
|
|
variant="stop", |
|
|
interactive=has_checkpoints |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
self.components["auto_resume_checkbox"] = gr.Checkbox( |
|
|
label="Automatically continue training in case of server reboot.", |
|
|
value=DEFAULT_AUTO_RESUME, |
|
|
info="When enabled, training will automatically resume from the latest checkpoint after app restart" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
# Read-only text outputs: status message, current task, and (below) the raw log.
self.components["status_box"] = gr.Textbox( |
|
|
label="Training Status", |
|
|
interactive=False, |
|
|
lines=4 |
|
|
) |
|
|
|
|
|
|
|
|
|
|
self.components["current_task_box"] = gr.Textbox( |
|
|
label="Current Task Progress", |
|
|
interactive=False, |
|
|
lines=3, |
|
|
elem_id="current_task_display" |
|
|
) |
|
|
|
|
|
# The accordion is created closed (open=False).
with gr.Accordion("Finetrainers output (or see app logs for more details)", open=False): |
|
|
self.components["log_box"] = gr.TextArea( |
|
|
|
|
|
interactive=False, |
|
|
lines=60, |
|
|
max_lines=600, |
|
|
autoscroll=True |
|
|
) |
|
|
|
|
|
# The TabItem bound by the outermost "with" statement is returned.
return tab |
|
|
|
|
|
def update_model_type_and_version(self, model_type: str, model_version: str):
    """Persist model type and version as a pair so the two stay consistent.

    If the version is empty or is not registered for the selected model type,
    the first registered version of that type is stored instead. When the
    model type is unknown, the version is stored as received.

    Args:
        model_type: Display name of the model type (a key of ``MODEL_TYPES``).
        model_version: Version currently selected in the UI.

    Returns:
        Always ``None``.
    """
    known_versions = []
    type_key = MODEL_TYPES.get(model_type)
    if type_key and type_key in MODEL_VERSIONS:
        known_versions = list(MODEL_VERSIONS[type_key].keys())

    # Only substitute when there is something to substitute with.
    version_is_stale = not model_version or model_version not in known_versions
    if known_versions and version_is_stale:
        model_version = known_versions[0]

    self.app.update_ui_state(model_type=model_type, model_version=model_version)
    return None
|
|
|
|
|
def save_model_version(self, model_type: str, model_version: str):
    """Persist the selected version unless it is foreign to the model type.

    A version that is not registered for a known model type is ignored (nothing
    is saved). For a model type without registered versions, the pair is saved
    as received.

    Returns:
        Always ``None``.
    """
    type_key = MODEL_TYPES.get(model_type)
    type_is_known = bool(type_key) and type_key in MODEL_VERSIONS

    # Reject a version that does not belong to the selected model type.
    if type_is_known and model_version not in MODEL_VERSIONS[type_key]:
        return None

    self.app.update_ui_state(model_type=model_type, model_version=model_version)
    return None
|
|
|
|
|
def handle_new_training_start(
    self, preset, model_type, model_version, training_type,
    lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
    save_iterations, repo_id, progress=gr.Progress()
):
    """Wipe earlier checkpoints and session data, then start a fresh training run.

    Every ``finetrainers_step_*`` directory under ``OUTPUT_PATH`` is removed, as
    is ``session.json`` when it exists. All arguments are then forwarded
    unchanged to ``handle_training_start``, whose result is returned.
    """
    for candidate in OUTPUT_PATH.glob("finetrainers_step_*"):
        # Only directories are removed; other matching entries are left alone.
        if not candidate.is_dir():
            continue
        shutil.rmtree(candidate)

    previous_session = OUTPUT_PATH / "session.json"
    if previous_session.exists():
        previous_session.unlink()

    self.app.training.append_log("Cleared previous checkpoints for new training session")

    forwarded = (
        preset, model_type, model_version, training_type,
        lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
        save_iterations, repo_id, progress,
    )
    return self.handle_training_start(*forwarded)
|
|
|
|
|
def handle_resume_training(
    self, preset, model_type, model_version, training_type,
    lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
    save_iterations, repo_id, progress=gr.Progress()
):
    """Continue training from the newest checkpoint, if any checkpoint exists.

    Returns:
        A pair of messages when no ``finetrainers_step_*`` entry exists under
        ``OUTPUT_PATH``; otherwise whatever ``handle_training_start`` returns
        when called with ``resume_from_checkpoint="latest"``.
    """
    first_checkpoint = next(OUTPUT_PATH.glob("finetrainers_step_*"), None)
    if first_checkpoint is None:
        return "No checkpoints found to resume from", "Please start a new training session instead"

    self.app.training.append_log("Resuming training from latest checkpoint")

    forwarded = (
        preset, model_type, model_version, training_type,
        lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
        save_iterations, repo_id, progress,
    )
    return self.handle_training_start(*forwarded, resume_from_checkpoint="latest")
|
|
|
|
|
def connect_events(self) -> None:
    """Attach all Train-tab event handlers to the components built by ``create``.

    Registration order: model type chain, model version, training type chain,
    the simple "persist on change" fields, the preset chain, and finally the
    start / resume / pause / stop / delete buttons.
    """
    ui = self.components

    def remember(state_key):
        """Return a change handler that stores the new value under ``state_key``."""
        return lambda v: self.app.update_ui_state(**{state_key: v})

    def widgets(*keys):
        """Return a fresh list of the components registered under ``keys``."""
        return [ui[key] for key in keys]

    # Model type: refresh the version choices, persist the pair, refresh the info text.
    versions_refreshed = ui["model_type"].change(
        fn=self.update_model_versions,
        inputs=widgets("model_type"),
        outputs=widgets("model_version"),
    )
    pair_saved = versions_refreshed.then(
        fn=self.update_model_type_and_version,
        inputs=widgets("model_type", "model_version"),
        outputs=[],
    )
    pair_saved.then(
        fn=self.get_model_info,
        inputs=widgets("model_type", "training_type"),
        outputs=widgets("model_info"),
    )

    # Model version: persist the selection.
    ui["model_version"].change(
        fn=self.save_model_version,
        inputs=widgets("model_type", "model_version"),
        outputs=[],
    )

    # Training type: persist, then refresh the dependent defaults and LoRA row.
    training_type_saved = ui["training_type"].change(
        fn=remember("training_type"),
        inputs=widgets("training_type"),
        outputs=[],
    )
    training_type_saved.then(
        fn=self.update_model_info,
        inputs=widgets("model_type", "training_type"),
        outputs=widgets(
            "model_info",
            "train_steps",
            "batch_size",
            "learning_rate",
            "save_iterations",
            "lora_params_row",
        ),
    )

    # Fields whose only reaction to a change is to be written to the UI state:
    # (component key, UI-state key).
    persisted_fields = (
        ("auto_resume_checkbox", "auto_resume"),
        ("num_gpus", "num_gpus"),
        ("precomputation_items", "precomputation_items"),
        ("lr_warmup_steps", "lr_warmup_steps"),
        ("lora_rank", "lora_rank"),
        ("lora_alpha", "lora_alpha"),
        ("train_steps", "train_steps"),
        ("batch_size", "batch_size"),
        ("learning_rate", "learning_rate"),
        ("save_iterations", "save_iterations"),
    )
    for component_key, state_key in persisted_fields:
        ui[component_key].change(
            fn=remember(state_key),
            inputs=widgets(component_key),
            outputs=[],
        )

    # Preset: persist, then push the resulting parameters into the form.
    preset_saved = ui["training_preset"].change(
        fn=remember("training_preset"),
        inputs=widgets("training_preset"),
        outputs=[],
    )
    preset_saved.then(
        fn=self.update_training_params,
        inputs=widgets("training_preset"),
        outputs=widgets(
            "model_type",
            "training_type",
            "lora_rank",
            "lora_alpha",
            "train_steps",
            "batch_size",
            "learning_rate",
            "save_iterations",
            "preset_info",
            "lora_params_row",
            "num_gpus",
            "precomputation_items",
            "lr_warmup_steps",
            "model_version",
        ),
    )

    def training_form():
        """Inputs shared by both start buttons; repo_id belongs to the manage tab."""
        fields = widgets(
            "training_preset",
            "model_type",
            "model_version",
            "training_type",
            "lora_rank",
            "lora_alpha",
            "train_steps",
            "batch_size",
            "learning_rate",
            "save_iterations",
        )
        fields.append(self.app.tabs["manage_tab"].components["repo_id"])
        return fields

    ui["start_btn"].click(
        fn=self.handle_new_training_start,
        inputs=training_form(),
        outputs=widgets("status_box", "log_box"),
    )

    ui["resume_btn"].click(
        fn=self.handle_resume_training,
        inputs=training_form(),
        outputs=widgets("status_box", "log_box"),
    )

    # Third button updated by pause/stop: the delete button when it exists,
    # otherwise the pause/resume button.
    if "delete_checkpoints_btn" in ui:
        third_btn = ui["delete_checkpoints_btn"]
    else:
        third_btn = ui["pause_resume_btn"]

    def control_outputs():
        """Components refreshed by the pause/resume and stop handlers."""
        return widgets("status_box", "log_box", "current_task_box", "start_btn", "stop_btn") + [third_btn]

    ui["pause_resume_btn"].click(
        fn=self.handle_pause_resume,
        outputs=control_outputs(),
    )

    ui["stop_btn"].click(
        fn=self.handle_stop,
        outputs=control_outputs(),
    )

    ui["delete_checkpoints_btn"].click(
        fn=lambda: self.app.training.delete_all_checkpoints(),
        outputs=widgets("status_box"),
    )
|
|
|
|
|
def update_model_versions(self, model_type: str) -> Dict:
    """Build the Model Version dropdown update for the selected model type.

    The model type is also written to the UI state. Any exception is logged
    and answered with an empty dropdown.

    Returns:
        A ``gr.Dropdown`` carrying the new choices and selected value.
    """
    try:
        available = self.get_model_version_choices(model_type)
        selected = self.get_default_model_version(model_type)
        logger.info(f"update_model_versions({model_type}): default_version = {selected}, available versions: {available}")

        self.app.update_ui_state(model_type=model_type)

        # The dropdown is given plain strings.
        available = list(map(str, available))

        if len(available) == 0:
            logger.warning(f"No model versions available for {model_type}, using empty list")
            return gr.Dropdown(choices=[], value=None)

        # The list is known to be non-empty here, so the first entry is a safe fallback.
        if selected not in available:
            selected = available[0]
            logger.info(f"Default version not in choices, using first available: {selected}")

        logger.info(f"Returning dropdown with {len(available)} choices")
        return gr.Dropdown(choices=available, value=selected)
    except Exception as e:
        logger.error(f"Error in update_model_versions: {str(e)}")
        return gr.Dropdown(choices=[], value=None)
|
|
|
|
|
def handle_training_start(
    self, preset, model_type, model_version, training_type,
    lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
    save_iterations, repo_id,
    progress=gr.Progress(),
    resume_from_checkpoint=None,
):
    """Validate the form values and launch training through ``self.app.training``.

    Steps: reset (or lazily create) the log parser, resolve ``"latest"`` to the
    most recently modified ``finetrainers_step_*`` path, translate the display
    names of model and training type to their internal identifiers, then call
    ``start_training``.

    Args:
        preset: Name of the selected training preset.
        model_type: Display name of the model type (key of ``MODEL_TYPES``).
        model_version: Selected model version.
        training_type: Display name of the training type (key of ``TRAINING_TYPES``).
        lora_rank, lora_alpha, train_steps, batch_size, learning_rate,
        save_iterations, repo_id: Forwarded unchanged to ``start_training``.
        progress: Gradio progress tracker, forwarded to ``start_training``.
        resume_from_checkpoint: ``None``, ``"latest"``, or an explicit value that
            is forwarded as is. ``"latest"`` is forwarded unchanged when no
            checkpoint exists.

    Returns:
        The result of ``start_training``, or a pair of error strings when the
        model/training type is unknown or starting fails.
    """
    if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
        self.app.log_parser.reset()
    else:
        logger.warning("Log parser not initialized, creating a new one")
        from ..utils import TrainingLogParser
        self.app.log_parser = TrainingLogParser()

    resume_from = resume_from_checkpoint
    if resume_from == "latest":
        checkpoints = list(OUTPUT_PATH.glob("finetrainers_step_*"))
        if checkpoints:
            # "Latest" means most recently modified on disk.
            latest_checkpoint = max(checkpoints, key=os.path.getmtime)
            resume_from = str(latest_checkpoint)
            logger.info(f"Found checkpoint at {resume_from}, will resume training")

    model_internal_type = MODEL_TYPES.get(model_type)
    if not model_internal_type:
        logger.error(f"Invalid model type: {model_type}")
        return f"Error: Invalid model type '{model_type}'", "Model type not recognized"

    training_internal_type = TRAINING_TYPES.get(training_type)
    if not training_internal_type:
        logger.error(f"Invalid training type: {training_type}")
        return f"Error: Invalid training type '{training_type}'", "Training type not recognized"

    # A component's ``.value`` is the value it was constructed with, not what
    # the user has currently selected. These three fields are not passed in as
    # event inputs, so take them from the persisted UI state (their change
    # handlers write to it) and only fall back to the component default.
    # NOTE(review): assumes load_ui_values() returns the mapping written by
    # update_ui_state() - confirm.
    try:
        saved_state = self.app.load_ui_values() or {}
    except Exception:
        logger.exception("Could not load saved UI values, using component defaults")
        saved_state = {}

    def current_int(key):
        """Return the saved value for ``key`` as int, else the component default."""
        value = saved_state.get(key)
        if value is None:
            value = self.components[key].value
        return int(value)

    try:
        return self.app.training.start_training(
            model_internal_type,
            lora_rank,
            lora_alpha,
            train_steps,
            batch_size,
            learning_rate,
            save_iterations,
            repo_id,
            preset_name=preset,
            training_type=training_internal_type,
            model_version=model_version,
            resume_from_checkpoint=resume_from,
            num_gpus=current_int("num_gpus"),
            precomputation_items=current_int("precomputation_items"),
            lr_warmup_steps=current_int("lr_warmup_steps"),
            progress=progress
        )
    except Exception as e:
        logger.exception("Error starting training")
        return f"Error starting training: {str(e)}", f"Exception: {str(e)}\n\nCheck the logs for more details."
|
|
|
|
|
def get_model_version_choices(self, model_type: str) -> List[str]:
    """List the version identifiers registered for a model type, as strings.

    Args:
        model_type: Display name of the model type (key of ``MODEL_TYPES``).

    Returns:
        The keys of ``MODEL_VERSIONS`` for that type converted with ``str``, or
        an empty list when the type is unknown or has no versions entry.
    """
    internal_type = MODEL_TYPES.get(model_type)

    if internal_type and internal_type in MODEL_VERSIONS:
        version_ids = list(MODEL_VERSIONS[internal_type])
        logger.info(f"Found {len(version_ids)} versions for {model_type}: {version_ids}")
        return list(map(str, version_ids))

    logger.warning(f"No model versions found for {model_type} (internal type: {internal_type})")
    return []
|
|
|
|
|
def get_default_model_version(self, model_type: str) -> str:
    """Return the first version registered for a model type.

    Args:
        model_type: Display name of the model type (key of ``MODEL_TYPES``).

    Returns:
        The first key of the type's ``MODEL_VERSIONS`` entry, or ``""`` when the
        type is unknown or has no versions.
    """
    internal_type = MODEL_TYPES.get(model_type)
    logger.debug(f"get_default_model_version({model_type}) = {internal_type}")

    type_has_versions = bool(internal_type) and internal_type in MODEL_VERSIONS
    if not type_has_versions:
        logger.warning(f"No valid model versions found for {model_type}")
        return ""

    # The loop body runs at most once: it returns on the first key.
    for default_version in MODEL_VERSIONS[internal_type]:
        logger.debug(f"Default version for {model_type}: {default_version}")
        return default_version
    return ""
|
|
|
|
|
def update_model_info(self, model_type: str, training_type: str) -> Dict:
    """Compute the UI updates that follow a model type / training type change.

    Returns:
        A dict keyed by component: the info text, the default values for
        training steps, batch size, learning rate and save interval, and a
        ``gr.Row`` update that shows the LoRA row only for "LoRA Finetune".
    """
    ui = self.components
    info_text = self.get_model_info(model_type, training_type)
    defaults = self.get_default_params(
        MODEL_TYPES.get(model_type),
        TRAINING_TYPES.get(training_type),
    )

    updates = {ui["model_info"]: info_text}
    for field in ("train_steps", "batch_size", "learning_rate", "save_iterations"):
        updates[ui[field]] = defaults[field]

    lora_selected = training_type == "LoRA Finetune"
    updates[ui["lora_params_row"]] = gr.Row(visible=lora_selected)
    return updates
|
|
|
|
|
def get_model_info(self, model_type: str, training_type: str) -> str:
    """Return a Markdown summary for a model type and training method.

    Each summary reports at most one "Required VRAM" figure, the one that
    applies to the selected training method. The HunyuanVideo text used to
    carry a 48GB line in its common part and a second VRAM line per method,
    giving contradictory (LoRA) or duplicated (full finetune) output.

    Args:
        model_type: Display name; "HunyuanVideo", "LTX-Video" and "Wan" have
            dedicated texts, anything else gets a generic pointer to the docs.
        training_type: Display name; "LoRA Finetune" selects the LoRA notes,
            any other value selects the full-finetune notes.

    Returns:
        The Markdown text.
    """
    is_lora = training_type == "LoRA Finetune"

    # model type -> (common lines, LoRA-only lines, full-finetune-only lines)
    catalogue = {
        "HunyuanVideo": (
            [
                "### HunyuanVideo",
                "- Recommended batch size: 1-2",
                "- Typical training time: 2-4 hours",
                "- Default resolution: 49x512x768",
            ],
            [
                "- Required VRAM: ~18GB minimum",
                "- Default LoRA rank: 128 (~400 MB)",
            ],
            [
                "- Required VRAM: ~48GB minimum",
                "- **Full finetune not recommended due to VRAM requirements**",
            ],
        ),
        "LTX-Video": (
            [
                "### LTX-Video",
                "- Recommended batch size: 1-4",
                "- Typical training time: 1-3 hours",
                "- Default resolution: 49x512x768",
            ],
            [
                "- Required VRAM: ~18GB minimum",
                "- Default LoRA rank: 128 (~400 MB)",
            ],
            [
                "- Required VRAM: ~21GB minimum",
                "- Full model size: ~8GB",
            ],
        ),
        "Wan": (
            [
                "### Wan",
                "- Recommended batch size: 1-4",
                "- Typical training time: 1-3 hours",
                "- Default resolution: 49x512x768",
            ],
            [
                "- Required VRAM: ~16GB minimum",
                "- Default LoRA rank: 32 (~120 MB)",
            ],
            [
                "- **Full finetune not recommended due to VRAM requirements**",
            ],
        ),
    }

    entry = catalogue.get(model_type)
    if entry is None:
        return f"### {model_type}\nPlease check documentation for VRAM requirements and recommended settings."

    common, lora_lines, full_lines = entry
    return "\n".join(common + (lora_lines if is_lora else full_lines))
|
|
|
|
|
def get_default_params(self, model_type: str, training_type: str) -> Dict[str, Any]:
    """Return default training parameters for an internal model/training type.

    The first preset in ``TRAINING_PRESETS`` whose ``model_type`` and
    ``training_type`` both match supplies the values, with the module defaults
    filling any key it lacks. Without a matching preset, the module defaults
    are used with per-model overrides for "hunyuan_video" and "wan".

    Returns:
        Dict with keys ``train_steps``, ``batch_size``, ``learning_rate``,
        ``save_iterations``, ``lora_rank`` and ``lora_alpha``.
    """
    # (key, module-level default), in the order the keys appear in the result.
    fallbacks = (
        ("train_steps", DEFAULT_NB_TRAINING_STEPS),
        ("batch_size", DEFAULT_BATCH_SIZE),
        ("learning_rate", DEFAULT_LEARNING_RATE),
        ("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS),
        ("lora_rank", DEFAULT_LORA_RANK_STR),
        ("lora_alpha", DEFAULT_LORA_ALPHA_STR),
    )

    candidates = [
        candidate for candidate in TRAINING_PRESETS.values()
        if candidate["model_type"] == model_type and candidate["training_type"] == training_type
    ]
    if candidates:
        chosen = candidates[0]
        return {key: chosen.get(key, fallback) for key, fallback in fallbacks}

    # No preset matched: start from the module defaults and apply model overrides.
    params = dict(fallbacks)
    if model_type == "hunyuan_video":
        params["learning_rate"] = 2e-5
    elif model_type == "wan":
        params.update(learning_rate=5e-5, lora_rank="32", lora_alpha="32")
    return params
|
|
|
|
|
def update_training_params(self, preset_name: str) -> Tuple:
    """Compute the form updates that follow the selection of a training preset.

    For every tunable field the value saved in the UI state wins, so the user's
    custom settings survive a preset change. When the saved state has no value
    for a field, the preset's value is used, and the module default after that.
    (The previous expression returned ``None`` in that situation, which blanked
    the corresponding form field.)

    Args:
        preset_name: Key of ``TRAINING_PRESETS``.

    Returns:
        A 14-tuple in the order expected by the preset ``change`` chain: model
        type display name, training type display name, LoRA rank, LoRA alpha,
        training steps, batch size, learning rate, save interval, preset info
        text, LoRA row visibility update, number of GPUs, precomputation items,
        LR warmup steps, model version dropdown update.

    Raises:
        KeyError: If ``preset_name`` is not a key of ``TRAINING_PRESETS``.
        StopIteration: If the preset's model or training type has no display
            name in ``MODEL_TYPES`` / ``TRAINING_TYPES``.
    """
    preset = TRAINING_PRESETS[preset_name]

    current_state = self.app.load_ui_values() or {}

    def pick(key, default):
        """Saved value for ``key`` if any, else the preset's value, else ``default``."""
        saved = current_state.get(key)
        if saved is not None:
            return saved
        return preset.get(key, default)

    # Map the preset's internal identifiers back to the names shown in the UI.
    model_display_name = next(
        key for key, value in MODEL_TYPES.items()
        if value == preset["model_type"]
    )
    training_display_name = next(
        key for key, value in TRAINING_TYPES.items()
        if value == preset["training_type"]
    )

    description = preset.get("description", "")

    # Buckets are (frames, height, width) triples; report the largest of each.
    # A preset without buckets simply gets no size line instead of raising.
    buckets = preset.get("training_buckets") or []
    if buckets:
        max_frames = max(frames for frames, _, _ in buckets)
        max_height = max(height for _, height, _ in buckets)
        max_width = max(width for _, _, width in buckets)
        bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
    else:
        bucket_info = ""

    info_text = f"{description}{bucket_info}"

    show_lora_params = preset["training_type"] == "lora"

    lora_rank_val = pick("lora_rank", DEFAULT_LORA_RANK_STR)
    lora_alpha_val = pick("lora_alpha", DEFAULT_LORA_ALPHA_STR)
    train_steps_val = pick("train_steps", DEFAULT_NB_TRAINING_STEPS)
    batch_size_val = pick("batch_size", DEFAULT_BATCH_SIZE)
    learning_rate_val = pick("learning_rate", DEFAULT_LEARNING_RATE)
    save_iterations_val = pick("save_iterations", DEFAULT_SAVE_CHECKPOINT_EVERY_N_STEPS)
    num_gpus_val = pick("num_gpus", DEFAULT_NUM_GPUS)
    precomputation_items_val = pick("precomputation_items", DEFAULT_PRECOMPUTATION_ITEMS)
    lr_warmup_steps_val = pick("lr_warmup_steps", DEFAULT_NB_LR_WARMUP_STEPS)

    # Version choices follow the preset's model type.
    model_versions = self.get_model_version_choices(model_display_name)
    default_model_version = self.get_default_model_version(model_display_name)

    if not model_versions:
        logger.warning(f"No versions found for {model_display_name}, using empty list")
        model_versions = []
        default_model_version = None
    elif default_model_version not in model_versions:
        default_model_version = model_versions[0]
        logger.info(f"Reset default version to first available: {default_model_version}")

    # The dropdown is given plain strings.
    model_versions = [str(version) for version in model_versions]

    model_version_update = gr.Dropdown(choices=model_versions, value=default_model_version)

    return (
        model_display_name,
        training_display_name,
        lora_rank_val,
        lora_alpha_val,
        train_steps_val,
        batch_size_val,
        learning_rate_val,
        save_iterations_val,
        info_text,
        gr.Row(visible=show_lora_params),
        num_gpus_val,
        precomputation_items_val,
        lr_warmup_steps_val,
        model_version_update,
    )
|
|
|
|
|
|
|
|
def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
    """Return the current training status, status message and raw logs.

    The status dict and the logs are fetched from ``self.app.training``.
    When the status says "training" but the process is no longer running,
    the status is switched to "error" and the last error-looking log line
    (if any) is appended to the message. Otherwise every log line is fed to
    ``self.app.log_parser`` and the last parsed state is rendered through
    ``update_training_ui``. Finally the status is re-derived from keywords
    found in the message text.

    Returns:
        A ``(status, message, logs)`` tuple. ``logs`` is returned exactly as
        ``get_logs()`` produced it.
    """
    state = self.app.training.get_status()
    logs = self.app.training.get_logs()

    # Split once and tolerate get_logs() returning None or an empty string.
    log_lines = logs.splitlines() if logs else []

    # Detect a training process that disappeared without updating its status.
    training_died = False
    if state["status"] == "training" and not self.app.training.is_training_running():
        state["status"] = "error"
        state["message"] = "Training process terminated unexpectedly."
        training_died = True

        # NOTE(review): this scan is assumed to belong to the "died" branch
        # (it extends the message set just above) - confirm against the
        # properly indented file.
        error_markers = ("Error:", "Exception:", "Traceback")
        error_lines = [
            line for line in log_lines
            if any(marker in line for marker in error_markers)
        ]
        if error_lines:
            state["message"] += f"\n\nPossible error: {error_lines[-1]}"

    # Lazily create the log parser if the app does not carry one yet.
    if getattr(self.app, 'log_parser', None) is None:
        from ..utils import TrainingLogParser
        self.app.log_parser = TrainingLogParser()
        logger.info("Initialized missing log parser")

    # Replay the logs through the parser; only the last parsed state matters.
    if log_lines and not training_died:
        last_state = None
        for line in log_lines:
            try:
                state_update = self.app.log_parser.parse_line(line)
            except Exception as e:
                # Best effort: one unparsable line must not break the refresh.
                logger.error("Error parsing log line: %s", e)
                continue
            if state_update:
                last_state = state_update

        if last_state:
            ui_updates = self.update_training_ui(last_state)
            state["message"] = ui_updates.get("status_box", state["message"])

    # Derive the final status from keywords in the message text.
    lowered_message = state["message"].lower()
    if "completed" in lowered_message:
        state["status"] = "completed"
    elif "error" in lowered_message or "failed" in lowered_message:
        state["status"] = "error"
    elif "stopped" in lowered_message:
        state["status"] = "stopped"

    # Recorded on the state dict only; it is not part of the returned tuple.
    if getattr(self.app, 'log_parser', None) is not None:
        state["current_task"] = self.app.log_parser.get_current_task_display()

    return (state["status"], state["message"], logs)
|
|
|
|
|
def get_status_updates(self):
    """Return ``(message, logs, current_task)`` for the text components.

    The message and logs come from ``get_latest_status_message_and_logs``;
    the current task is read from ``self.app.log_parser`` when the app has
    one, otherwise it is an empty string.
    """
    _status, message, logs = self.get_latest_status_message_and_logs()

    parser = getattr(self.app, 'log_parser', None)
    if parser is None:
        current_task = ""
    else:
        current_task = parser.get_current_task_display()

    return message, logs, current_task
|
|
|
|
|
def get_button_updates(self):
    """Build the button updates (including the ``variant`` property).

    Button state depends on two facts: whether a run is active (status is
    "training" or "initializing") and whether any ``finetrainers_step_*``
    entry exists under ``OUTPUT_PATH``.

    Returns:
        A ``(start_btn, resume_btn, stop_btn, delete_checkpoints_btn)`` tuple
        of ``gr.Button`` objects.
    """
    status, _, _ = self.get_latest_status_message_and_logs()

    # any() stops at the first match instead of listing every checkpoint.
    has_checkpoints = any(OUTPUT_PATH.glob("finetrainers_step_*"))
    is_training = status in ("training", "initializing")
    is_idle = not is_training

    # Start/resume are only offered while nothing is running.
    start_btn = gr.Button(
        value="🚀 Start new training",
        interactive=is_idle,
        variant="primary" if is_idle else "secondary",
    )

    resume_btn = gr.Button(
        value="🛸 Start from latest checkpoint",
        interactive=has_checkpoints and is_idle,
        variant="primary" if is_idle else "secondary",
    )

    # Stop is the only action available during a run.
    stop_btn = gr.Button(
        value="Stop at Last Checkpoint",
        interactive=is_training,
        variant="primary" if is_training else "secondary",
    )

    # Deleting checkpoints is blocked while a run is active.
    delete_checkpoints_btn = gr.Button(
        value="Delete All Checkpoints",
        interactive=has_checkpoints and is_idle,
        variant="stop",
    )

    return start_btn, resume_btn, stop_btn, delete_checkpoints_btn
|
|
|
|
|
def update_training_ui(self, training_state: Dict[str, Any]):
    """Render a parsed training state into UI text values.

    Args:
        training_state: Dict describing the training run. ``status`` is
            required. When the status is not "idle" the progress keys
            (``progress``, ``current_step``, ``total_steps``, ``elapsed``,
            ``remaining``, ``step_loss``, ``learning_rate``, ``grad_norm``,
            ``memory``) are required too. ``error_message`` and
            ``current_task`` are optional.

    Returns:
        Dict with ``status_box`` (multi-line status text, empty when idle
        and error-free) and ``current_task_box`` (task text or a placeholder).
    """
    updates = {}
    status = training_state["status"]

    status_text = []
    if status != "idle":
        status_text.extend([
            f"Status: {status}",
            f"Progress: {training_state['progress']}",
            f"Step: {training_state['current_step']}/{training_state['total_steps']}",
            f"Time elapsed: {training_state['elapsed']}",
            f"Estimated remaining: {training_state['remaining']}",
            "",
            f"Current loss: {training_state['step_loss']}",
            f"Learning rate: {training_state['learning_rate']}",
            f"Gradient norm: {training_state['grad_norm']}",
            f"Memory usage: {training_state['memory']}",
        ])

    # .get(): a state without an error entry simply has no error to show,
    # mirroring how current_task is read below.
    error_message = training_state.get("error_message")
    if error_message:
        status_text.append(f"\nError: {error_message}")

    updates["status_box"] = "\n".join(status_text)

    current_task = training_state.get("current_task")
    if current_task:
        updates["current_task_box"] = current_task
    elif status != "training":
        updates["current_task_box"] = "No active task"
    else:
        # Training is running but no task line has been seen yet.
        updates["current_task_box"] = "Waiting for task information..."

    return updates
|
|
|
|
|
def handle_pause_resume(self):
    """Toggle between pausing and resuming training on button click.

    Resumes when the latest status is "paused", pauses otherwise, then
    returns the status updates followed by the button updates as one tuple.
    """
    current_status, _message, _logs = self.get_latest_status_message_and_logs()

    trainer = self.app.training
    toggle = trainer.resume_training if current_status == "paused" else trainer.pause_training
    toggle()

    # Status outputs come first, button outputs second.
    return tuple(self.get_status_updates()) + tuple(self.get_button_updates())
|
|
|
|
|
def handle_stop(self):
    """Stop training on button click and return the refreshed UI values.

    Calls ``stop_training()`` on ``self.app.training``, then returns the
    status updates followed by the button updates as one tuple.
    """
    self.app.training.stop_training()

    # Status outputs come first, button outputs second.
    return tuple(self.get_status_updates()) + tuple(self.get_button_updates())