feat: add live training log streaming using TrainerCallback and background thread
Browse files- Added GradioLoggerCallback to capture trainer log events
- Ran trainer.train() in a background thread to avoid UI blocking
- Streamed logs from queue during training to Gradio UI using yield
- Replaced simulated progress loop with real-time progress and log updates
- Fixes issue where UI showed only progress bar and froze at 99%
- train_abuse_model.py +33 -19
train_abuse_model.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# # Install core packages
|
| 2 |
# !pip install -U transformers datasets accelerate
|
| 3 |
-
|
| 4 |
import logging
|
| 5 |
import io
|
| 6 |
import os
|
|
@@ -256,28 +256,42 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
| 256 |
)
|
| 257 |
|
| 258 |
logger.info("Training started with %d samples", len(train_dataset))
|
| 259 |
-
yield "π Training
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
# Start training!
|
| 276 |
trainer.train()
|
| 277 |
|
| 278 |
-
# Drain queue to UI
|
| 279 |
-
while not log_queue.empty():
|
| 280 |
-
|
| 281 |
|
| 282 |
progress(1.0)
|
| 283 |
yield "β
Progress: 100%\n"
|
|
|
|
| 1 |
# # Install core packages
|
| 2 |
# !pip install -U transformers datasets accelerate
|
| 3 |
+
import threading
|
| 4 |
import logging
|
| 5 |
import io
|
| 6 |
import os
|
|
|
|
| 256 |
)
|
| 257 |
|
| 258 |
logger.info("Training started with %d samples", len(train_dataset))
|
| 259 |
+
yield "π Training started...\n"
|
| 260 |
+
|
| 261 |
+
progress(0.01)
|
| 262 |
+
|
| 263 |
+
# Run training in background thread
|
| 264 |
+
trainer_training = [True]
|
| 265 |
+
|
| 266 |
+
def background_train():
|
| 267 |
+
trainer.train()
|
| 268 |
+
trainer_training[0] = False # Mark as done
|
| 269 |
+
|
| 270 |
+
train_thread = threading.Thread(target=background_train)
|
| 271 |
+
train_thread.start()
|
| 272 |
+
|
| 273 |
+
# Drain log queue live while training runs
|
| 274 |
+
percent = 0
|
| 275 |
+
while train_thread.is_alive() or not log_queue.empty():
|
| 276 |
+
while not log_queue.empty():
|
| 277 |
+
log_msg = log_queue.get()
|
| 278 |
+
yield log_msg
|
| 279 |
+
# Optional: update progress bar slowly toward 1.0
|
| 280 |
+
if percent < 98:
|
| 281 |
+
percent += 1
|
| 282 |
+
progress(percent / 100)
|
| 283 |
+
time.sleep(1)
|
| 284 |
+
|
| 285 |
+
progress(1.0)
|
| 286 |
+
yield "β
Progress: 100%\n"
|
| 287 |
+
|
| 288 |
|
| 289 |
# Start training!
|
| 290 |
trainer.train()
|
| 291 |
|
| 292 |
+
# # Drain queue to UI
|
| 293 |
+
# while not log_queue.empty():
|
| 294 |
+
# yield log_queue.get()
|
| 295 |
|
| 296 |
progress(1.0)
|
| 297 |
yield "β
Progress: 100%\n"
|