Julian Bilcke
committed on
Commit
·
54a2a4e
1
Parent(s):
9545589
working on training job failure recovery
Browse files
- app.py +124 -4
- vms/training_log_parser.py +33 -34
- vms/training_service.py +188 -4
app.py
CHANGED
|
@@ -59,7 +59,43 @@ class VideoTrainerUI:
|
|
| 59 |
self.captioner = CaptioningService()
|
| 60 |
self._should_stop_captioning = False
|
| 61 |
self.log_parser = TrainingLogParser()
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
def update_captioning_buttons_start(self):
|
| 64 |
"""Return individual button values instead of a dictionary"""
|
| 65 |
return (
|
|
@@ -1120,12 +1156,55 @@ class VideoTrainerUI:
|
|
| 1120 |
return gr.update(value=repo_id, error=None)
|
| 1121 |
|
| 1122 |
# Connect events
|
|
|
|
|
|
|
| 1123 |
model_type.change(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1124 |
fn=update_model_info,
|
| 1125 |
inputs=[model_type],
|
| 1126 |
outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
|
| 1127 |
)
|
| 1128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1129 |
async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
| 1130 |
videos = self.list_unprocessed_videos()
|
| 1131 |
# If scene detection isn't already running and there are videos to process,
|
|
@@ -1243,8 +1322,13 @@ class VideoTrainerUI:
|
|
| 1243 |
fn=self.list_training_files_to_caption,
|
| 1244 |
outputs=[training_dataset]
|
| 1245 |
)
|
| 1246 |
-
|
|
|
|
| 1247 |
training_preset.change(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1248 |
fn=self.update_training_params,
|
| 1249 |
inputs=[training_preset],
|
| 1250 |
outputs=[
|
|
@@ -1337,13 +1421,49 @@ class VideoTrainerUI:
|
|
| 1337 |
]
|
| 1338 |
)
|
| 1339 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1340 |
# Auto-refresh timers
|
| 1341 |
app.load(
|
| 1342 |
fn=lambda: (
|
| 1343 |
-
self.refresh_dataset()
|
|
|
|
|
|
|
|
|
|
| 1344 |
),
|
| 1345 |
outputs=[
|
| 1346 |
-
video_list, training_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1347 |
]
|
| 1348 |
)
|
| 1349 |
|
|
|
|
| 59 |
self.captioner = CaptioningService()
|
| 60 |
self._should_stop_captioning = False
|
| 61 |
self.log_parser = TrainingLogParser()
|
| 62 |
+
|
| 63 |
+
# Try to recover any interrupted training sessions
|
| 64 |
+
recovery_result = self.trainer.recover_interrupted_training()
|
| 65 |
+
|
| 66 |
+
self.recovery_status = recovery_result.get("status", "unknown")
|
| 67 |
+
self.ui_updates = recovery_result.get("ui_updates", {})
|
| 68 |
+
|
| 69 |
+
if recovery_result["status"] == "recovered":
|
| 70 |
+
logger.info(f"Training recovery: {recovery_result['message']}")
|
| 71 |
+
# No need to do anything else - the training is already running
|
| 72 |
+
elif recovery_result["status"] == "running":
|
| 73 |
+
logger.info("Training process is already running")
|
| 74 |
+
# No need to do anything - the process is still alive
|
| 75 |
+
elif recovery_result["status"] in ["error", "idle"]:
|
| 76 |
+
logger.warning(f"Training status: {recovery_result['message']}")
|
| 77 |
+
# UI will be in ready-to-start mode
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def update_ui_state(self, **kwargs):
|
| 81 |
+
"""Update UI state with new values"""
|
| 82 |
+
current_state = self.trainer.load_ui_state()
|
| 83 |
+
current_state.update(kwargs)
|
| 84 |
+
self.trainer.save_ui_state(current_state)
|
| 85 |
+
return current_state
|
| 86 |
+
|
| 87 |
+
def load_ui_values(self):
|
| 88 |
+
"""Load UI state values for initializing form fields"""
|
| 89 |
+
ui_state = self.trainer.load_ui_state()
|
| 90 |
+
|
| 91 |
+
# Convert types as needed since JSON stores everything as strings
|
| 92 |
+
ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
|
| 93 |
+
ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
|
| 94 |
+
ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
|
| 95 |
+
ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
|
| 96 |
+
|
| 97 |
+
return ui_state
|
| 98 |
+
|
| 99 |
def update_captioning_buttons_start(self):
|
| 100 |
"""Return individual button values instead of a dictionary"""
|
| 101 |
return (
|
|
|
|
| 1156 |
return gr.update(value=repo_id, error=None)
|
| 1157 |
|
| 1158 |
# Connect events
|
| 1159 |
+
|
| 1160 |
+
# Save state when model type changes
|
| 1161 |
model_type.change(
|
| 1162 |
+
fn=lambda v: self.update_ui_state(model_type=v),
|
| 1163 |
+
inputs=[model_type],
|
| 1164 |
+
outputs=[] # No UI update needed
|
| 1165 |
+
).then(
|
| 1166 |
fn=update_model_info,
|
| 1167 |
inputs=[model_type],
|
| 1168 |
outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
|
| 1169 |
)
|
| 1170 |
|
| 1171 |
+
# the following change listeners are used for UI persistence
|
| 1172 |
+
lora_rank.change(
|
| 1173 |
+
fn=lambda v: self.update_ui_state(lora_rank=v),
|
| 1174 |
+
inputs=[lora_rank],
|
| 1175 |
+
outputs=[]
|
| 1176 |
+
)
|
| 1177 |
+
|
| 1178 |
+
lora_alpha.change(
|
| 1179 |
+
fn=lambda v: self.update_ui_state(lora_alpha=v),
|
| 1180 |
+
inputs=[lora_alpha],
|
| 1181 |
+
outputs=[]
|
| 1182 |
+
)
|
| 1183 |
+
|
| 1184 |
+
num_epochs.change(
|
| 1185 |
+
fn=lambda v: self.update_ui_state(num_epochs=v),
|
| 1186 |
+
inputs=[num_epochs],
|
| 1187 |
+
outputs=[]
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
batch_size.change(
|
| 1191 |
+
fn=lambda v: self.update_ui_state(batch_size=v),
|
| 1192 |
+
inputs=[batch_size],
|
| 1193 |
+
outputs=[]
|
| 1194 |
+
)
|
| 1195 |
+
|
| 1196 |
+
learning_rate.change(
|
| 1197 |
+
fn=lambda v: self.update_ui_state(learning_rate=v),
|
| 1198 |
+
inputs=[learning_rate],
|
| 1199 |
+
outputs=[]
|
| 1200 |
+
)
|
| 1201 |
+
|
| 1202 |
+
save_iterations.change(
|
| 1203 |
+
fn=lambda v: self.update_ui_state(save_iterations=v),
|
| 1204 |
+
inputs=[save_iterations],
|
| 1205 |
+
outputs=[]
|
| 1206 |
+
)
|
| 1207 |
+
|
| 1208 |
async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
|
| 1209 |
videos = self.list_unprocessed_videos()
|
| 1210 |
# If scene detection isn't already running and there are videos to process,
|
|
|
|
| 1322 |
fn=self.list_training_files_to_caption,
|
| 1323 |
outputs=[training_dataset]
|
| 1324 |
)
|
| 1325 |
+
|
| 1326 |
+
# Save state when training preset changes
|
| 1327 |
training_preset.change(
|
| 1328 |
+
fn=lambda v: self.update_ui_state(training_preset=v),
|
| 1329 |
+
inputs=[training_preset],
|
| 1330 |
+
outputs=[] # No UI update needed
|
| 1331 |
+
).then(
|
| 1332 |
fn=self.update_training_params,
|
| 1333 |
inputs=[training_preset],
|
| 1334 |
outputs=[
|
|
|
|
| 1421 |
]
|
| 1422 |
)
|
| 1423 |
|
| 1424 |
+
# Add this new method to get initial button states:
|
| 1425 |
+
def get_initial_button_states(self):
|
| 1426 |
+
"""Get the initial states for training buttons based on recovery status"""
|
| 1427 |
+
recovery_result = self.trainer.recover_interrupted_training()
|
| 1428 |
+
ui_updates = recovery_result.get("ui_updates", {})
|
| 1429 |
+
|
| 1430 |
+
# Return button states in the correct order
|
| 1431 |
+
return (
|
| 1432 |
+
gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
|
| 1433 |
+
gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
|
| 1434 |
+
gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
|
| 1435 |
+
)
|
| 1436 |
+
|
| 1437 |
+
def initialize_ui_from_state(self):
|
| 1438 |
+
"""Initialize UI components from saved state"""
|
| 1439 |
+
ui_state = self.load_ui_values()
|
| 1440 |
+
|
| 1441 |
+
# Return values in order matching the outputs in app.load
|
| 1442 |
+
return (
|
| 1443 |
+
ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
|
| 1444 |
+
ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
|
| 1445 |
+
ui_state.get("lora_rank", "128"),
|
| 1446 |
+
ui_state.get("lora_alpha", "128"),
|
| 1447 |
+
ui_state.get("num_epochs", 70),
|
| 1448 |
+
ui_state.get("batch_size", 1),
|
| 1449 |
+
ui_state.get("learning_rate", 3e-5),
|
| 1450 |
+
ui_state.get("save_iterations", 500)
|
| 1451 |
+
)
|
| 1452 |
+
|
| 1453 |
# Auto-refresh timers
|
| 1454 |
app.load(
|
| 1455 |
fn=lambda: (
|
| 1456 |
+
self.refresh_dataset(),
|
| 1457 |
+
*self.get_initial_button_states(),
|
| 1458 |
+
# Load saved UI state values
|
| 1459 |
+
*self.initialize_ui_from_state()
|
| 1460 |
),
|
| 1461 |
outputs=[
|
| 1462 |
+
video_list, training_dataset,
|
| 1463 |
+
start_btn, stop_btn, pause_resume_btn,
|
| 1464 |
+
# Add outputs for UI fields
|
| 1465 |
+
training_preset, model_type, lora_rank, lora_alpha,
|
| 1466 |
+
num_epochs, batch_size, learning_rate, save_iterations
|
| 1467 |
]
|
| 1468 |
)
|
| 1469 |
|
vms/training_log_parser.py
CHANGED
|
@@ -34,7 +34,14 @@ class TrainingState:
|
|
| 34 |
|
| 35 |
def to_dict(self) -> Dict[str, Any]:
|
| 36 |
"""Convert state to dictionary for UI updates"""
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."
|
| 39 |
|
| 40 |
return {
|
|
@@ -74,10 +81,11 @@ class TrainingLogParser:
|
|
| 74 |
if ("Started training" in line) or ("Starting training" in line):
|
| 75 |
self.state.status = "training"
|
| 76 |
|
|
|
|
| 77 |
if "Training steps:" in line:
|
| 78 |
# Set status to training if we see this
|
| 79 |
self.state.status = "training"
|
| 80 |
-
|
| 81 |
if not self.state.start_time:
|
| 82 |
self.state.start_time = datetime.now()
|
| 83 |
|
|
@@ -97,36 +105,23 @@ class TrainingLogParser:
|
|
| 97 |
if match:
|
| 98 |
setattr(self.state, attr, float(match.group(1)))
|
| 99 |
|
| 100 |
-
#
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
# Create formatted timedelta
|
| 118 |
-
if days > 0:
|
| 119 |
-
formatted_time = f"{days}d {hours}h {minutes}m {seconds}s"
|
| 120 |
-
elif hours > 0:
|
| 121 |
-
formatted_time = f"{hours}h {minutes}m {seconds}s"
|
| 122 |
-
elif minutes > 0:
|
| 123 |
-
formatted_time = f"{minutes}m {seconds}s"
|
| 124 |
-
else:
|
| 125 |
-
formatted_time = f"{seconds}s"
|
| 126 |
-
|
| 127 |
-
self.state.estimated_remaining = formatted_time
|
| 128 |
-
self.state.last_step_time = now
|
| 129 |
-
|
| 130 |
logger.info(f"Updated training state: step={self.state.current_step}/{self.state.total_steps}, loss={self.state.step_loss}")
|
| 131 |
return self.state.to_dict()
|
| 132 |
|
|
@@ -162,12 +157,16 @@ class TrainingLogParser:
|
|
| 162 |
|
| 163 |
# Completion states
|
| 164 |
if "Training completed successfully" in line:
|
| 165 |
-
self.
|
|
|
|
|
|
|
| 166 |
logger.info("Training completed")
|
| 167 |
return self.state.to_dict()
|
| 168 |
|
| 169 |
if any(x in line for x in ["Training process stopped", "Training stopped"]):
|
| 170 |
-
self.
|
|
|
|
|
|
|
| 171 |
logger.info("Training stopped")
|
| 172 |
return self.state.to_dict()
|
| 173 |
|
|
|
|
| 34 |
|
| 35 |
def to_dict(self) -> Dict[str, Any]:
|
| 36 |
"""Convert state to dictionary for UI updates"""
|
| 37 |
+
# Calculate elapsed time only if training is active and we have a start time
|
| 38 |
+
if self.start_time and self.status in ["training", "initializing"]:
|
| 39 |
+
elapsed = str(datetime.now() - self.start_time)
|
| 40 |
+
else:
|
| 41 |
+
# Use the last known elapsed time or show 0
|
| 42 |
+
elapsed = "0:00:00" if not self.last_step_time else str(self.last_step_time - self.start_time if self.start_time else "0:00:00")
|
| 43 |
+
|
| 44 |
+
# Use precomputed remaining time from logs if available
|
| 45 |
remaining = str(self.estimated_remaining) if self.estimated_remaining else "calculating..."
|
| 46 |
|
| 47 |
return {
|
|
|
|
| 81 |
if ("Started training" in line) or ("Starting training" in line):
|
| 82 |
self.state.status = "training"
|
| 83 |
|
| 84 |
+
# Check for "Training steps:" which contains the progress information
|
| 85 |
if "Training steps:" in line:
|
| 86 |
# Set status to training if we see this
|
| 87 |
self.state.status = "training"
|
| 88 |
+
|
| 89 |
if not self.state.start_time:
|
| 90 |
self.state.start_time = datetime.now()
|
| 91 |
|
|
|
|
| 105 |
if match:
|
| 106 |
setattr(self.state, attr, float(match.group(1)))
|
| 107 |
|
| 108 |
+
# Extract time remaining directly from the log
|
| 109 |
+
# Format: [MM:SS<M:SS:SS, SS.SSs/it]
|
| 110 |
+
time_remaining_match = re.search(r"<(\d+:\d+:\d+)", line)
|
| 111 |
+
if time_remaining_match:
|
| 112 |
+
remaining_str = time_remaining_match.group(1)
|
| 113 |
+
# Store the string directly - no need to parse it
|
| 114 |
+
self.state.estimated_remaining = remaining_str
|
| 115 |
+
|
| 116 |
+
# If no direct time estimate, look for hour:min format
|
| 117 |
+
if not time_remaining_match:
|
| 118 |
+
hour_min_match = re.search(r"<(\d+h\s*\d+m)", line)
|
| 119 |
+
if hour_min_match:
|
| 120 |
+
self.state.estimated_remaining = hour_min_match.group(1)
|
| 121 |
+
|
| 122 |
+
# Update last processing time
|
| 123 |
+
self.state.last_step_time = datetime.now()
|
| 124 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
logger.info(f"Updated training state: step={self.state.current_step}/{self.state.total_steps}, loss={self.state.step_loss}")
|
| 126 |
return self.state.to_dict()
|
| 127 |
|
|
|
|
| 157 |
|
| 158 |
# Completion states
|
| 159 |
if "Training completed successfully" in line:
|
| 160 |
+
self.status = "completed"
|
| 161 |
+
# Store final elapsed time
|
| 162 |
+
self.last_step_time = datetime.now()
|
| 163 |
logger.info("Training completed")
|
| 164 |
return self.state.to_dict()
|
| 165 |
|
| 166 |
if any(x in line for x in ["Training process stopped", "Training stopped"]):
|
| 167 |
+
self.status = "stopped"
|
| 168 |
+
# Store final elapsed time
|
| 169 |
+
self.last_step_time = datetime.now()
|
| 170 |
logger.info("Training stopped")
|
| 171 |
return self.state.to_dict()
|
| 172 |
|
vms/training_service.py
CHANGED
|
@@ -38,7 +38,7 @@ class TrainingService:
|
|
| 38 |
self.setup_logging()
|
| 39 |
|
| 40 |
logger.info("Training service initialized")
|
| 41 |
-
|
| 42 |
def setup_logging(self):
|
| 43 |
"""Set up logging with proper handler management"""
|
| 44 |
global logger
|
|
@@ -96,16 +96,58 @@ class TrainingService:
|
|
| 96 |
if self.file_handler:
|
| 97 |
self.file_handler.close()
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
def save_session(self, params: Dict) -> None:
|
| 100 |
"""Save training session parameters"""
|
| 101 |
session_data = {
|
| 102 |
"timestamp": datetime.now().isoformat(),
|
| 103 |
"params": params,
|
| 104 |
-
"status": self.get_status()
|
|
|
|
|
|
|
| 105 |
}
|
| 106 |
with open(self.session_file, 'w') as f:
|
| 107 |
json.dump(session_data, f, indent=2)
|
| 108 |
-
|
| 109 |
def load_session(self) -> Optional[Dict]:
|
| 110 |
"""Load saved training session"""
|
| 111 |
if self.session_file.exists():
|
|
@@ -225,6 +267,7 @@ class TrainingService:
|
|
| 225 |
save_iterations: int,
|
| 226 |
repo_id: str,
|
| 227 |
preset_name: str,
|
|
|
|
| 228 |
) -> Tuple[str, str]:
|
| 229 |
"""Start training with finetrainers"""
|
| 230 |
|
|
@@ -295,6 +338,11 @@ class TrainingService:
|
|
| 295 |
config.lr = float(learning_rate)
|
| 296 |
config.checkpointing_steps = int(save_iterations)
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
# Common settings for both models
|
| 299 |
config.mixed_precision = "bf16"
|
| 300 |
config.seed = 42
|
|
@@ -477,10 +525,146 @@ class TrainingService:
|
|
| 477 |
try:
|
| 478 |
with open(self.pid_file, 'r') as f:
|
| 479 |
pid = int(f.read().strip())
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
except:
|
| 482 |
return False
|
| 483 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
def clear_training_data(self) -> str:
|
| 485 |
"""Clear all training data"""
|
| 486 |
if self.is_training_running():
|
|
|
|
| 38 |
self.setup_logging()
|
| 39 |
|
| 40 |
logger.info("Training service initialized")
|
| 41 |
+
|
| 42 |
def setup_logging(self):
|
| 43 |
"""Set up logging with proper handler management"""
|
| 44 |
global logger
|
|
|
|
| 96 |
if self.file_handler:
|
| 97 |
self.file_handler.close()
|
| 98 |
|
| 99 |
+
|
| 100 |
+
def save_ui_state(self, values: Dict[str, Any]) -> None:
|
| 101 |
+
"""Save current UI state to file"""
|
| 102 |
+
ui_state_file = OUTPUT_PATH / "ui_state.json"
|
| 103 |
+
try:
|
| 104 |
+
with open(ui_state_file, 'w') as f:
|
| 105 |
+
json.dump(values, f, indent=2)
|
| 106 |
+
logger.debug(f"UI state saved: {values}")
|
| 107 |
+
except Exception as e:
|
| 108 |
+
logger.error(f"Error saving UI state: {str(e)}")
|
| 109 |
+
|
| 110 |
+
def load_ui_state(self) -> Dict[str, Any]:
|
| 111 |
+
"""Load saved UI state"""
|
| 112 |
+
ui_state_file = OUTPUT_PATH / "ui_state.json"
|
| 113 |
+
default_state = {
|
| 114 |
+
"model_type": list(MODEL_TYPES.keys())[0],
|
| 115 |
+
"lora_rank": "128",
|
| 116 |
+
"lora_alpha": "128",
|
| 117 |
+
"num_epochs": 70,
|
| 118 |
+
"batch_size": 1,
|
| 119 |
+
"learning_rate": 3e-5,
|
| 120 |
+
"save_iterations": 500,
|
| 121 |
+
"training_preset": list(TRAINING_PRESETS.keys())[0]
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
if not ui_state_file.exists():
|
| 125 |
+
return default_state
|
| 126 |
+
|
| 127 |
+
try:
|
| 128 |
+
with open(ui_state_file, 'r') as f:
|
| 129 |
+
saved_state = json.load(f)
|
| 130 |
+
# Make sure we have all keys (in case structure changed)
|
| 131 |
+
merged_state = default_state.copy()
|
| 132 |
+
merged_state.update(saved_state)
|
| 133 |
+
return merged_state
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"Error loading UI state: {str(e)}")
|
| 136 |
+
return default_state
|
| 137 |
+
|
| 138 |
+
# Modify save_session to also store the UI state at training start
|
| 139 |
def save_session(self, params: Dict) -> None:
|
| 140 |
"""Save training session parameters"""
|
| 141 |
session_data = {
|
| 142 |
"timestamp": datetime.now().isoformat(),
|
| 143 |
"params": params,
|
| 144 |
+
"status": self.get_status(),
|
| 145 |
+
# Add UI state at the time training started
|
| 146 |
+
"initial_ui_state": self.load_ui_state()
|
| 147 |
}
|
| 148 |
with open(self.session_file, 'w') as f:
|
| 149 |
json.dump(session_data, f, indent=2)
|
| 150 |
+
|
| 151 |
def load_session(self) -> Optional[Dict]:
|
| 152 |
"""Load saved training session"""
|
| 153 |
if self.session_file.exists():
|
|
|
|
| 267 |
save_iterations: int,
|
| 268 |
repo_id: str,
|
| 269 |
preset_name: str,
|
| 270 |
+
resume_from_checkpoint: Optional[str] = None,
|
| 271 |
) -> Tuple[str, str]:
|
| 272 |
"""Start training with finetrainers"""
|
| 273 |
|
|
|
|
| 338 |
config.lr = float(learning_rate)
|
| 339 |
config.checkpointing_steps = int(save_iterations)
|
| 340 |
|
| 341 |
+
# Update with resume_from_checkpoint if provided
|
| 342 |
+
if resume_from_checkpoint:
|
| 343 |
+
config.resume_from_checkpoint = resume_from_checkpoint
|
| 344 |
+
self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
|
| 345 |
+
|
| 346 |
# Common settings for both models
|
| 347 |
config.mixed_precision = "bf16"
|
| 348 |
config.seed = 42
|
|
|
|
| 525 |
try:
|
| 526 |
with open(self.pid_file, 'r') as f:
|
| 527 |
pid = int(f.read().strip())
|
| 528 |
+
|
| 529 |
+
# Check if process exists AND is a Python process running train.py
|
| 530 |
+
if psutil.pid_exists(pid):
|
| 531 |
+
try:
|
| 532 |
+
process = psutil.Process(pid)
|
| 533 |
+
cmdline = process.cmdline()
|
| 534 |
+
# Check if it's a Python process running train.py
|
| 535 |
+
return any('train.py' in cmd for cmd in cmdline)
|
| 536 |
+
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
| 537 |
+
return False
|
| 538 |
+
return False
|
| 539 |
except:
|
| 540 |
return False
|
| 541 |
|
| 542 |
+
def recover_interrupted_training(self) -> Dict[str, Any]:
|
| 543 |
+
"""Attempt to recover interrupted training
|
| 544 |
+
|
| 545 |
+
Returns:
|
| 546 |
+
Dict with recovery status and UI updates
|
| 547 |
+
"""
|
| 548 |
+
status = self.get_status()
|
| 549 |
+
ui_updates = {}
|
| 550 |
+
|
| 551 |
+
# If status indicates training but process isn't running, try to recover
|
| 552 |
+
if status.get('status') == 'training' and not self.is_training_running():
|
| 553 |
+
logger.info("Detected interrupted training session, attempting to recover...")
|
| 554 |
+
|
| 555 |
+
# Get the latest checkpoint
|
| 556 |
+
last_session = self.load_session()
|
| 557 |
+
if not last_session:
|
| 558 |
+
logger.warning("No session data found for recovery")
|
| 559 |
+
# Set buttons for no active training
|
| 560 |
+
ui_updates = {
|
| 561 |
+
"start_btn": {"interactive": True, "variant": "primary"},
|
| 562 |
+
"stop_btn": {"interactive": False, "variant": "secondary"},
|
| 563 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary"}
|
| 564 |
+
}
|
| 565 |
+
return {"status": "error", "message": "No session data found", "ui_updates": ui_updates}
|
| 566 |
+
|
| 567 |
+
# Find the latest checkpoint
|
| 568 |
+
checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
|
| 569 |
+
if not checkpoints:
|
| 570 |
+
logger.warning("No checkpoints found for recovery")
|
| 571 |
+
# Set buttons for no active training
|
| 572 |
+
ui_updates = {
|
| 573 |
+
"start_btn": {"interactive": True, "variant": "primary"},
|
| 574 |
+
"stop_btn": {"interactive": False, "variant": "secondary"},
|
| 575 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary"}
|
| 576 |
+
}
|
| 577 |
+
return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
|
| 578 |
+
|
| 579 |
+
latest_checkpoint = max(checkpoints, key=os.path.getmtime)
|
| 580 |
+
checkpoint_step = int(latest_checkpoint.name.split("-")[1])
|
| 581 |
+
|
| 582 |
+
logger.info(f"Found checkpoint at step {checkpoint_step}, attempting to resume")
|
| 583 |
+
|
| 584 |
+
# Extract parameters from the saved session (not current UI state)
|
| 585 |
+
# This ensures we use the original training parameters
|
| 586 |
+
params = last_session.get('params', {})
|
| 587 |
+
initial_ui_state = last_session.get('initial_ui_state', {})
|
| 588 |
+
|
| 589 |
+
# Add UI updates to restore the training parameters in the UI
|
| 590 |
+
# This shows the user what values are being used for the resumed training
|
| 591 |
+
ui_updates.update({
|
| 592 |
+
"model_type": gr.update(value=params.get('model_type', list(MODEL_TYPES.keys())[0])),
|
| 593 |
+
"lora_rank": gr.update(value=params.get('lora_rank', "128")),
|
| 594 |
+
"lora_alpha": gr.update(value=params.get('lora_alpha', "128")),
|
| 595 |
+
"num_epochs": gr.update(value=params.get('num_epochs', 70)),
|
| 596 |
+
"batch_size": gr.update(value=params.get('batch_size', 1)),
|
| 597 |
+
"learning_rate": gr.update(value=params.get('learning_rate', 3e-5)),
|
| 598 |
+
"save_iterations": gr.update(value=params.get('save_iterations', 500)),
|
| 599 |
+
"training_preset": gr.update(value=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]))
|
| 600 |
+
})
|
| 601 |
+
|
| 602 |
+
# Attempt to resume training using the ORIGINAL parameters
|
| 603 |
+
try:
|
| 604 |
+
# Extract required parameters from the session
|
| 605 |
+
model_type = params.get('model_type')
|
| 606 |
+
lora_rank = params.get('lora_rank')
|
| 607 |
+
lora_alpha = params.get('lora_alpha')
|
| 608 |
+
num_epochs = params.get('num_epochs')
|
| 609 |
+
batch_size = params.get('batch_size')
|
| 610 |
+
learning_rate = params.get('learning_rate')
|
| 611 |
+
save_iterations = params.get('save_iterations')
|
| 612 |
+
repo_id = params.get('repo_id')
|
| 613 |
+
preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
|
| 614 |
+
|
| 615 |
+
# Attempt to resume training
|
| 616 |
+
result = self.start_training(
|
| 617 |
+
model_type=model_type,
|
| 618 |
+
lora_rank=lora_rank,
|
| 619 |
+
lora_alpha=lora_alpha,
|
| 620 |
+
num_epochs=num_epochs,
|
| 621 |
+
batch_size=batch_size,
|
| 622 |
+
learning_rate=learning_rate,
|
| 623 |
+
save_iterations=save_iterations,
|
| 624 |
+
repo_id=repo_id,
|
| 625 |
+
preset_name=preset_name,
|
| 626 |
+
resume_from_checkpoint=str(latest_checkpoint)
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# Set buttons for active training
|
| 630 |
+
ui_updates.update({
|
| 631 |
+
"start_btn": {"interactive": False, "variant": "secondary"},
|
| 632 |
+
"stop_btn": {"interactive": True, "variant": "stop"},
|
| 633 |
+
"pause_resume_btn": {"interactive": True, "variant": "secondary"}
|
| 634 |
+
})
|
| 635 |
+
|
| 636 |
+
return {
|
| 637 |
+
"status": "recovered",
|
| 638 |
+
"message": f"Training resumed from checkpoint {checkpoint_step}",
|
| 639 |
+
"result": result,
|
| 640 |
+
"ui_updates": ui_updates
|
| 641 |
+
}
|
| 642 |
+
except Exception as e:
|
| 643 |
+
logger.error(f"Failed to resume training: {str(e)}")
|
| 644 |
+
# Set buttons for no active training
|
| 645 |
+
ui_updates.update({
|
| 646 |
+
"start_btn": {"interactive": True, "variant": "primary"},
|
| 647 |
+
"stop_btn": {"interactive": False, "variant": "secondary"},
|
| 648 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary"}
|
| 649 |
+
})
|
| 650 |
+
return {"status": "error", "message": f"Failed to resume: {str(e)}", "ui_updates": ui_updates}
|
| 651 |
+
elif self.is_training_running():
|
| 652 |
+
# Process is still running, set buttons accordingly
|
| 653 |
+
ui_updates = {
|
| 654 |
+
"start_btn": {"interactive": False, "variant": "secondary"},
|
| 655 |
+
"stop_btn": {"interactive": True, "variant": "stop"},
|
| 656 |
+
"pause_resume_btn": {"interactive": True, "variant": "secondary"}
|
| 657 |
+
}
|
| 658 |
+
return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
|
| 659 |
+
else:
|
| 660 |
+
# No training process, set buttons to default state
|
| 661 |
+
ui_updates = {
|
| 662 |
+
"start_btn": {"interactive": True, "variant": "primary"},
|
| 663 |
+
"stop_btn": {"interactive": False, "variant": "secondary"},
|
| 664 |
+
"pause_resume_btn": {"interactive": False, "variant": "secondary"}
|
| 665 |
+
}
|
| 666 |
+
return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
|
| 667 |
+
|
| 668 |
def clear_training_data(self) -> str:
|
| 669 |
"""Clear all training data"""
|
| 670 |
if self.is_training_running():
|