Julian Bilcke
commited on
Commit
·
446e79f
1
Parent(s):
54a2a4e
working on fixes
Browse files- app.py +72 -37
- vms/training_service.py +10 -5
app.py
CHANGED
|
@@ -77,12 +77,68 @@ class VideoTrainerUI:
|
|
| 77 |
# UI will be in ready-to-start mode
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
def update_ui_state(self, **kwargs):
|
| 81 |
"""Update UI state with new values"""
|
| 82 |
current_state = self.trainer.load_ui_state()
|
| 83 |
current_state.update(kwargs)
|
| 84 |
self.trainer.save_ui_state(current_state)
|
| 85 |
-
return
|
|
|
|
| 86 |
|
| 87 |
def load_ui_values(self):
|
| 88 |
"""Load UI state values for initializing form fields"""
|
|
@@ -130,6 +186,19 @@ class VideoTrainerUI:
|
|
| 130 |
)
|
| 131 |
)
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def show_refreshing_status(self) -> List[List[str]]:
|
| 134 |
"""Show a 'Refreshing...' status in the dataframe"""
|
| 135 |
return [["Refreshing...", "please wait"]]
|
|
@@ -1421,52 +1490,18 @@ class VideoTrainerUI:
|
|
| 1421 |
]
|
| 1422 |
)
|
| 1423 |
|
| 1424 |
-
# Add this new method to get initial button states:
|
| 1425 |
-
def get_initial_button_states(self):
|
| 1426 |
-
"""Get the initial states for training buttons based on recovery status"""
|
| 1427 |
-
recovery_result = self.trainer.recover_interrupted_training()
|
| 1428 |
-
ui_updates = recovery_result.get("ui_updates", {})
|
| 1429 |
-
|
| 1430 |
-
# Return button states in the correct order
|
| 1431 |
-
return (
|
| 1432 |
-
gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
|
| 1433 |
-
gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
|
| 1434 |
-
gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
|
| 1435 |
-
)
|
| 1436 |
-
|
| 1437 |
-
def initialize_ui_from_state(self):
|
| 1438 |
-
"""Initialize UI components from saved state"""
|
| 1439 |
-
ui_state = self.load_ui_values()
|
| 1440 |
-
|
| 1441 |
-
# Return values in order matching the outputs in app.load
|
| 1442 |
-
return (
|
| 1443 |
-
ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
|
| 1444 |
-
ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
|
| 1445 |
-
ui_state.get("lora_rank", "128"),
|
| 1446 |
-
ui_state.get("lora_alpha", "128"),
|
| 1447 |
-
ui_state.get("num_epochs", 70),
|
| 1448 |
-
ui_state.get("batch_size", 1),
|
| 1449 |
-
ui_state.get("learning_rate", 3e-5),
|
| 1450 |
-
ui_state.get("save_iterations", 500)
|
| 1451 |
-
)
|
| 1452 |
|
| 1453 |
-
# Auto-refresh timers
|
| 1454 |
app.load(
|
| 1455 |
-
fn=
|
| 1456 |
-
self.refresh_dataset(),
|
| 1457 |
-
*self.get_initial_button_states(),
|
| 1458 |
-
# Load saved UI state values
|
| 1459 |
-
*self.initialize_ui_from_state()
|
| 1460 |
-
),
|
| 1461 |
outputs=[
|
| 1462 |
video_list, training_dataset,
|
| 1463 |
start_btn, stop_btn, pause_resume_btn,
|
| 1464 |
-
# Add outputs for UI fields
|
| 1465 |
training_preset, model_type, lora_rank, lora_alpha,
|
| 1466 |
num_epochs, batch_size, learning_rate, save_iterations
|
| 1467 |
]
|
| 1468 |
)
|
| 1469 |
|
|
|
|
| 1470 |
timer = gr.Timer(value=1)
|
| 1471 |
timer.tick(
|
| 1472 |
fn=lambda: (
|
|
|
|
| 77 |
# UI will be in ready-to-start mode
|
| 78 |
|
| 79 |
|
| 80 |
+
def initialize_app_state(self):
|
| 81 |
+
"""Initialize all app state in one function to ensure correct output count"""
|
| 82 |
+
# Get dataset info
|
| 83 |
+
video_list, training_dataset = self.refresh_dataset()
|
| 84 |
+
|
| 85 |
+
# Get button states
|
| 86 |
+
button_states = self.get_initial_button_states()
|
| 87 |
+
start_btn = button_states[0]
|
| 88 |
+
stop_btn = button_states[1]
|
| 89 |
+
pause_resume_btn = button_states[2]
|
| 90 |
+
|
| 91 |
+
# Get UI form values
|
| 92 |
+
ui_state = self.load_ui_values()
|
| 93 |
+
training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
|
| 94 |
+
model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
|
| 95 |
+
lora_rank_val = ui_state.get("lora_rank", "128")
|
| 96 |
+
lora_alpha_val = ui_state.get("lora_alpha", "128")
|
| 97 |
+
num_epochs_val = int(ui_state.get("num_epochs", 70))
|
| 98 |
+
batch_size_val = int(ui_state.get("batch_size", 1))
|
| 99 |
+
learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
|
| 100 |
+
save_iterations_val = int(ui_state.get("save_iterations", 500))
|
| 101 |
+
|
| 102 |
+
# Return all values in the exact order expected by outputs
|
| 103 |
+
return (
|
| 104 |
+
video_list,
|
| 105 |
+
training_dataset,
|
| 106 |
+
start_btn,
|
| 107 |
+
stop_btn,
|
| 108 |
+
pause_resume_btn,
|
| 109 |
+
training_preset,
|
| 110 |
+
model_type_val,
|
| 111 |
+
lora_rank_val,
|
| 112 |
+
lora_alpha_val,
|
| 113 |
+
num_epochs_val,
|
| 114 |
+
batch_size_val,
|
| 115 |
+
learning_rate_val,
|
| 116 |
+
save_iterations_val
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
def initialize_ui_from_state(self):
|
| 120 |
+
"""Initialize UI components from saved state"""
|
| 121 |
+
ui_state = self.load_ui_values()
|
| 122 |
+
|
| 123 |
+
# Return values in order matching the outputs in app.load
|
| 124 |
+
return (
|
| 125 |
+
ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
|
| 126 |
+
ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
|
| 127 |
+
ui_state.get("lora_rank", "128"),
|
| 128 |
+
ui_state.get("lora_alpha", "128"),
|
| 129 |
+
ui_state.get("num_epochs", 70),
|
| 130 |
+
ui_state.get("batch_size", 1),
|
| 131 |
+
ui_state.get("learning_rate", 3e-5),
|
| 132 |
+
ui_state.get("save_iterations", 500)
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
def update_ui_state(self, **kwargs):
|
| 136 |
"""Update UI state with new values"""
|
| 137 |
current_state = self.trainer.load_ui_state()
|
| 138 |
current_state.update(kwargs)
|
| 139 |
self.trainer.save_ui_state(current_state)
|
| 140 |
+
# Don't return anything to avoid Gradio warnings
|
| 141 |
+
return None
|
| 142 |
|
| 143 |
def load_ui_values(self):
|
| 144 |
"""Load UI state values for initializing form fields"""
|
|
|
|
| 186 |
)
|
| 187 |
)
|
| 188 |
|
| 189 |
+
# Add this new method to get initial button states:
|
| 190 |
+
def get_initial_button_states(self):
|
| 191 |
+
"""Get the initial states for training buttons based on recovery status"""
|
| 192 |
+
recovery_result = self.trainer.recover_interrupted_training()
|
| 193 |
+
ui_updates = recovery_result.get("ui_updates", {})
|
| 194 |
+
|
| 195 |
+
# Return button states in the correct order
|
| 196 |
+
return (
|
| 197 |
+
gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
|
| 198 |
+
gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
|
| 199 |
+
gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
def show_refreshing_status(self) -> List[List[str]]:
|
| 203 |
"""Show a 'Refreshing...' status in the dataframe"""
|
| 204 |
return [["Refreshing...", "please wait"]]
|
|
|
|
| 1490 |
]
|
| 1491 |
)
|
| 1492 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1493 |
|
|
|
|
| 1494 |
app.load(
|
| 1495 |
+
fn=self.initialize_app_state,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1496 |
outputs=[
|
| 1497 |
video_list, training_dataset,
|
| 1498 |
start_btn, stop_btn, pause_resume_btn,
|
|
|
|
| 1499 |
training_preset, model_type, lora_rank, lora_alpha,
|
| 1500 |
num_epochs, batch_size, learning_rate, save_iterations
|
| 1501 |
]
|
| 1502 |
)
|
| 1503 |
|
| 1504 |
+
# Auto-refresh timers
|
| 1505 |
timer = gr.Timer(value=1)
|
| 1506 |
timer.tick(
|
| 1507 |
fn=lambda: (
|
vms/training_service.py
CHANGED
|
@@ -164,12 +164,11 @@ class TrainingService:
|
|
| 164 |
|
| 165 |
if not self.status_file.exists():
|
| 166 |
return default_status
|
| 167 |
-
|
| 168 |
try:
|
| 169 |
with open(self.status_file, 'r') as f:
|
| 170 |
status = json.load(f)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
# Check if process is actually running
|
| 174 |
if self.pid_file.exists():
|
| 175 |
with open(self.pid_file, 'r') as f:
|
|
@@ -177,14 +176,20 @@ class TrainingService:
|
|
| 177 |
if not psutil.pid_exists(pid):
|
| 178 |
# Process died unexpectedly
|
| 179 |
if status['status'] == 'training':
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
status['status'] = 'error'
|
| 181 |
status['message'] = 'Training process terminated unexpectedly'
|
| 182 |
-
|
|
|
|
|
|
|
| 183 |
else:
|
| 184 |
status['status'] = 'stopped'
|
| 185 |
status['message'] = 'Training process not found'
|
| 186 |
return status
|
| 187 |
-
|
| 188 |
except (json.JSONDecodeError, ValueError):
|
| 189 |
return default_status
|
| 190 |
|
|
|
|
| 164 |
|
| 165 |
if not self.status_file.exists():
|
| 166 |
return default_status
|
| 167 |
+
|
| 168 |
try:
|
| 169 |
with open(self.status_file, 'r') as f:
|
| 170 |
status = json.load(f)
|
| 171 |
+
|
|
|
|
| 172 |
# Check if process is actually running
|
| 173 |
if self.pid_file.exists():
|
| 174 |
with open(self.pid_file, 'r') as f:
|
|
|
|
| 176 |
if not psutil.pid_exists(pid):
|
| 177 |
# Process died unexpectedly
|
| 178 |
if status['status'] == 'training':
|
| 179 |
+
# Only log this once by checking if we've already updated the status
|
| 180 |
+
if not hasattr(self, '_process_terminated_logged') or not self._process_terminated_logged:
|
| 181 |
+
self.append_log("Training process terminated unexpectedly")
|
| 182 |
+
self._process_terminated_logged = True
|
| 183 |
status['status'] = 'error'
|
| 184 |
status['message'] = 'Training process terminated unexpectedly'
|
| 185 |
+
# Update the status file to avoid repeated logging
|
| 186 |
+
with open(self.status_file, 'w') as f:
|
| 187 |
+
json.dump(status, f, indent=2)
|
| 188 |
else:
|
| 189 |
status['status'] = 'stopped'
|
| 190 |
status['message'] = 'Training process not found'
|
| 191 |
return status
|
| 192 |
+
|
| 193 |
except (json.JSONDecodeError, ValueError):
|
| 194 |
return default_status
|
| 195 |
|