|
|
import os
import sys
import json
import time
import shutil
import subprocess
import signal
import tempfile
import zipfile
import logging
import traceback
import threading
import select
from pathlib import Path
from datetime import datetime
from typing import Any, Optional, Dict, List, Union, Tuple

import gradio as gr
import psutil
from huggingface_hub import upload_folder, create_repo

from config import TrainingConfig, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
from utils import make_archive, parse_training_log, is_image_file, is_video_file
from finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class TrainingService: |
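    """Manage the lifecycle of a finetrainers training run: session/status/PID files,
    logging, starting/stopping the training subprocess, and uploading results to the Hub."""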
|
|
def __init__(self): |
|
|
|
|
|
self.session_file = OUTPUT_PATH / "session.json" |
|
|
self.status_file = OUTPUT_PATH / "status.json" |
|
|
self.pid_file = OUTPUT_PATH / "training.pid" |
|
|
self.log_file = OUTPUT_PATH / "training.log" |
|
|
|
|
|
self.file_handler = None |
|
|
self.setup_logging() |
|
|
|
|
|
logger.info("Training service initialized") |
|
|
|
|
|
def setup_logging(self): |
|
|
"""Set up logging with proper handler management""" |
|
|
global logger |
|
|
logger = logging.getLogger(__name__) |
|
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
|
|
|
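        # Drop any existing handlers to avoid duplicated log lines on re-initialization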
logger.handlers.clear() |
|
|
|
|
|
|
|
|
stdout_handler = logging.StreamHandler(sys.stdout) |
|
|
stdout_handler.setFormatter(logging.Formatter( |
|
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
|
|
)) |
|
|
logger.addHandler(stdout_handler) |
|
|
|
|
|
|
|
|
try: |
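            # Swap in a fresh file handler pointing at LOG_FILE_PATH (stdout logging stays active)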
|
|
|
|
|
if self.file_handler: |
|
|
self.file_handler.close() |
|
|
logger.removeHandler(self.file_handler) |
|
|
|
|
|
self.file_handler = logging.FileHandler(str(LOG_FILE_PATH)) |
|
|
self.file_handler.setFormatter(logging.Formatter( |
|
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
|
|
)) |
|
|
logger.addHandler(self.file_handler) |
|
|
except Exception as e: |
|
|
logger.warning(f"Could not set up log file: {e}") |
|
|
|
|
|
def clear_logs(self) -> None: |
|
|
"""Clear log file with proper handler cleanup""" |
|
|
try: |
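            # Detach and close the file handler before deleting the log file it writes to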
|
|
|
|
|
if self.file_handler: |
|
|
logger.removeHandler(self.file_handler) |
|
|
self.file_handler.close() |
|
|
self.file_handler = None |
|
|
|
|
|
|
|
|
            if LOG_FILE_PATH.exists():
                LOG_FILE_PATH.unlink()
            if self.log_file.exists():
                self.log_file.unlink()
|
|
|
|
|
|
|
|
self.setup_logging() |
|
|
self.append_log("Log file cleared and recreated") |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error clearing logs: {e}") |
|
|
raise |
|
|
|
|
|
def __del__(self): |
|
|
"""Cleanup when the service is destroyed""" |
|
|
if self.file_handler: |
|
|
self.file_handler.close() |
|
|
|
|
|
def save_session(self, params: Dict) -> None: |
|
|
"""Save training session parameters""" |
|
|
session_data = { |
|
|
"timestamp": datetime.now().isoformat(), |
|
|
"params": params, |
|
|
"status": self.get_status() |
|
|
} |
|
|
with open(self.session_file, 'w') as f: |
|
|
json.dump(session_data, f, indent=2) |
|
|
|
|
|
def load_session(self) -> Optional[Dict]: |
|
|
"""Load saved training session""" |
|
|
if self.session_file.exists(): |
|
|
try: |
|
|
with open(self.session_file, 'r') as f: |
|
|
return json.load(f) |
|
|
except json.JSONDecodeError: |
|
|
return None |
|
|
return None |
|
|
|
|
|
def get_status(self) -> Dict: |
|
|
"""Get current training status""" |
|
|
default_status = {'status': 'stopped', 'message': 'No training in progress'} |
|
|
|
|
|
if not self.status_file.exists(): |
|
|
return default_status |
|
|
|
|
|
try: |
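            # Read the persisted status, then sanity-check it against the recorded PID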
|
|
with open(self.status_file, 'r') as f: |
|
|
status = json.load(f) |
|
|
|
|
|
|
|
|
|
|
|
if self.pid_file.exists(): |
|
|
with open(self.pid_file, 'r') as f: |
|
|
pid = int(f.read().strip()) |
|
|
if not psutil.pid_exists(pid): |
|
|
|
|
|
if status['status'] == 'training': |
|
|
status['status'] = 'error' |
|
|
status['message'] = 'Training process terminated unexpectedly' |
|
|
self.append_log("Training process terminated unexpectedly") |
|
|
else: |
|
|
status['status'] = 'stopped' |
|
|
status['message'] = 'Training process not found' |
|
|
return status |
|
|
|
|
|
except (json.JSONDecodeError, ValueError): |
|
|
return default_status |
|
|
|
|
|
def get_logs(self, max_lines: int = 100) -> str: |
|
|
"""Get training logs with line limit""" |
|
|
if self.log_file.exists(): |
|
|
with open(self.log_file, 'r') as f: |
|
|
lines = f.readlines() |
|
|
return ''.join(lines[-max_lines:]) |
|
|
return "" |
|
|
|
|
|
def append_log(self, message: str) -> None: |
|
|
"""Append message to log file and logger""" |
|
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") |
|
|
with open(self.log_file, 'a') as f: |
|
|
f.write(f"[{timestamp}] {message}\n") |
|
|
logger.info(message) |
|
|
|
|
|
|
def validate_training_config(self, config: TrainingConfig, model_type: str) -> Optional[str]: |
|
|
"""Validate training configuration""" |
|
|
logger.info(f"Validating config for {model_type}") |
|
|
|
|
|
try: |
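            # Basic path checks before inspecting the dataset list files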
|
|
|
|
|
if not config.data_root or not Path(config.data_root).exists(): |
|
|
return f"Invalid data root path: {config.data_root}" |
|
|
|
|
|
if not config.output_dir: |
|
|
return "Output directory not specified" |
|
|
|
|
|
|
|
|
videos_file = Path(config.data_root) / "videos.txt" |
|
|
prompts_file = Path(config.data_root) / "prompts.txt" |
|
|
|
|
|
if not videos_file.exists(): |
|
|
return f"Missing videos list file: {videos_file}" |
|
|
if not prompts_file.exists(): |
|
|
return f"Missing prompts list file: {prompts_file}" |
|
|
|
|
|
|
|
|
            video_lines = [l.strip() for l in videos_file.read_text().splitlines() if l.strip()]
            prompt_lines = [l.strip() for l in prompts_file.read_text().splitlines() if l.strip()]
|
|
|
|
|
if not video_lines: |
|
|
return "No training files found" |
|
|
if len(video_lines) != len(prompt_lines): |
|
|
return f"Mismatch between video count ({len(video_lines)}) and prompt count ({len(prompt_lines)})" |
|
|
|
|
|
|
|
|
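            # Model-specific sanity checks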
if model_type == "hunyuan_video": |
|
|
if config.batch_size > 2: |
|
|
return "Hunyuan model recommended batch size is 1-2" |
|
|
if not config.gradient_checkpointing: |
|
|
return "Gradient checkpointing is required for Hunyuan model" |
|
|
elif model_type == "ltx_video": |
|
|
if config.batch_size > 4: |
|
|
return "LTX model recommended batch size is 1-4" |
|
|
|
|
|
logger.info(f"Config validation passed with {len(video_lines)} training files") |
|
|
return None |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error during config validation: {str(e)}") |
|
|
return f"Configuration validation failed: {str(e)}" |
|
|
|
|
|
|
|
|
def start_training(self, model_type: str, lora_rank: str, lora_alpha: str, num_epochs: int, batch_size: int, |
|
|
learning_rate: float, save_iterations: int, repo_id: str) -> Tuple[str, str]: |
|
|
"""Start training with finetrainers""" |
|
|
|
|
|
self.clear_logs() |
|
|
|
|
|
if not model_type: |
|
|
raise ValueError("model_type cannot be empty") |
|
|
if model_type not in MODEL_TYPES.values(): |
|
|
raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}") |
|
|
|
|
|
|
|
|
logger.info(f"Initializing training with model_type={model_type}") |
|
|
|
|
|
try: |
|
|
|
|
|
current_dir = Path(__file__).parent.absolute() |
|
|
train_script = current_dir / "train.py" |
|
|
|
|
|
if not train_script.exists(): |
|
|
error_msg = f"Training script not found at {train_script}" |
|
|
logger.error(error_msg) |
|
|
return error_msg, "Training script not found" |
|
|
|
|
|
|
|
|
logger.info("Current working directory: %s", current_dir) |
|
|
logger.info("Training script path: %s", train_script) |
|
|
logger.info("Training data path: %s", TRAINING_PATH) |
|
|
|
|
|
videos_file, prompts_file = prepare_finetrainers_dataset() |
|
|
if videos_file is None or prompts_file is None: |
|
|
error_msg = "Failed to generate training lists" |
|
|
logger.error(error_msg) |
|
|
return error_msg, "Training preparation failed" |
|
|
|
|
|
            with open(videos_file, 'r') as f:
                video_count = sum(1 for line in f if line.strip())
|
|
logger.info(f"Generated training lists with {video_count} files") |
|
|
|
|
|
if video_count == 0: |
|
|
error_msg = "No training files found" |
|
|
logger.error(error_msg) |
|
|
return error_msg, "No training data available" |
|
|
|
|
|
|
|
|
if model_type == "hunyuan_video": |
|
|
config = TrainingConfig.hunyuan_video_lora( |
|
|
data_path=str(TRAINING_PATH), |
|
|
output_path=str(OUTPUT_PATH) |
|
|
) |
|
|
else: |
|
|
config = TrainingConfig.ltx_video_lora( |
|
|
data_path=str(TRAINING_PATH), |
|
|
output_path=str(OUTPUT_PATH) |
|
|
) |
|
|
|
|
|
|
|
|
config.train_epochs = int(num_epochs) |
|
|
config.lora_rank = int(lora_rank) |
|
|
config.lora_alpha = int(lora_alpha) |
|
|
config.batch_size = int(batch_size) |
|
|
config.lr = float(learning_rate) |
|
|
config.checkpointing_steps = int(save_iterations) |
|
|
|
|
|
|
|
|
config.mixed_precision = "bf16" |
|
|
config.seed = 42 |
|
|
config.gradient_checkpointing = True |
|
|
config.enable_slicing = True |
|
|
config.enable_tiling = True |
|
|
config.caption_dropout_p = 0.05 |
|
|
|
|
|
validation_error = self.validate_training_config(config, model_type) |
|
|
if validation_error: |
|
|
error_msg = f"Configuration validation failed: {validation_error}" |
|
|
logger.error(error_msg) |
|
|
return "Error: Invalid configuration", error_msg |
|
|
|
|
|
|
|
|
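            # Base accelerate launch arguments: single process, bf16, no dynamo backend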
accelerate_args = [ |
|
|
"accelerate", "launch", |
|
|
"--mixed_precision=bf16", |
|
|
"--num_processes=1", |
|
|
"--num_machines=1", |
|
|
"--dynamo_backend=no" |
|
|
] |
|
|
|
|
|
accelerate_args.append(str(train_script)) |
|
|
|
|
|
|
|
|
config_args = config.to_args_list() |
|
|
|
|
|
|
|
|
logger.debug("Generated args list: %s", config_args) |
|
|
|
|
|
|
|
|
command_str = ' '.join(accelerate_args + config_args) |
|
|
self.append_log(f"Command: {command_str}") |
|
|
logger.info(f"Executing command: {command_str}") |
|
|
|
|
|
|
|
|
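            # Environment for the training subprocess (offline wandb, HF token, verbose finetrainers logging)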
env = os.environ.copy() |
|
|
env["NCCL_P2P_DISABLE"] = "1" |
|
|
env["TORCH_NCCL_ENABLE_MONITORING"] = "0" |
|
|
env["WANDB_MODE"] = "offline" |
|
|
env["HF_API_TOKEN"] = HF_API_TOKEN |
|
|
env["FINETRAINERS_LOG_LEVEL"] = "DEBUG" |
|
|
|
|
|
|
|
|
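            # Launch training in its own session so the whole process group can be signalled on stop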
process = subprocess.Popen( |
|
|
accelerate_args + config_args, |
|
|
stdout=subprocess.PIPE, |
|
|
stderr=subprocess.PIPE, |
|
|
start_new_session=True, |
|
|
env=env, |
|
|
cwd=str(current_dir), |
|
|
bufsize=1, |
|
|
universal_newlines=True |
|
|
) |
|
|
|
|
|
logger.info(f"Started process with PID: {process.pid}") |
|
|
|
|
|
with open(self.pid_file, 'w') as f: |
|
|
f.write(str(process.pid)) |
|
|
|
|
|
|
|
|
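            # Persist the session parameters; repo_id is read back after training for the Hub upload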
self.save_session({ |
|
|
"model_type": model_type, |
|
|
"lora_rank": lora_rank, |
|
|
"lora_alpha": lora_alpha, |
|
|
"num_epochs": num_epochs, |
|
|
"batch_size": batch_size, |
|
|
"learning_rate": learning_rate, |
|
|
"save_iterations": save_iterations, |
|
|
"repo_id": repo_id, |
|
|
"start_time": datetime.now().isoformat() |
|
|
}) |
|
|
|
|
|
|
|
|
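            # Rough estimate of total optimizer steps, used for progress reporting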
            total_steps = num_epochs * max(1, video_count // batch_size)
|
|
self.save_status( |
|
|
state='training', |
|
|
epoch=0, |
|
|
step=0, |
|
|
total_steps=total_steps, |
|
|
loss=0.0, |
|
|
total_epochs=num_epochs, |
|
|
message='Training started', |
|
|
repo_id=repo_id, |
|
|
model_type=model_type |
|
|
) |
|
|
|
|
|
|
|
|
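            # Start the background thread that tails the training process output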
self._start_log_monitor(process) |
|
|
|
|
|
success_msg = f"Started training {model_type} model" |
|
|
self.append_log(success_msg) |
|
|
logger.info(success_msg) |
|
|
|
|
|
return success_msg, self.get_logs() |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Error starting training: {str(e)}" |
|
|
self.append_log(error_msg) |
|
|
logger.exception("Training startup failed") |
|
|
traceback.print_exc() |
|
|
return "Error starting training", error_msg |
|
|
|
|
|
|
|
|
def stop_training(self) -> Tuple[str, str]: |
|
|
"""Stop training process""" |
|
|
if not self.pid_file.exists(): |
|
|
return "No training process found", self.get_logs() |
|
|
|
|
|
try: |
|
|
with open(self.pid_file, 'r') as f: |
|
|
pid = int(f.read().strip()) |
|
|
|
|
|
if psutil.pid_exists(pid): |
|
|
os.killpg(os.getpgid(pid), signal.SIGTERM) |
|
|
|
|
|
if self.pid_file.exists(): |
|
|
self.pid_file.unlink() |
|
|
|
|
|
self.append_log("Training process stopped") |
|
|
self.save_status(state='stopped', message='Training stopped') |
|
|
|
|
|
return "Training stopped successfully", self.get_logs() |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Error stopping training: {str(e)}" |
|
|
self.append_log(error_msg) |
|
|
if self.pid_file.exists(): |
|
|
self.pid_file.unlink() |
|
|
return "Error stopping training", error_msg |
|
|
|
|
|
def pause_training(self) -> Tuple[str, str]: |
|
|
"""Pause training process by sending SIGUSR1""" |
|
|
if not self.is_training_running(): |
|
|
return "No training process found", self.get_logs() |
|
|
|
|
|
try: |
|
|
with open(self.pid_file, 'r') as f: |
|
|
pid = int(f.read().strip()) |
|
|
|
|
|
if psutil.pid_exists(pid): |
|
|
os.kill(pid, signal.SIGUSR1) |
|
|
self.save_status(state='paused', message='Training paused') |
|
|
self.append_log("Training paused") |
|
|
|
|
|
return "Training paused", self.get_logs() |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Error pausing training: {str(e)}" |
|
|
self.append_log(error_msg) |
|
|
return "Error pausing training", error_msg |
|
|
|
|
|
def resume_training(self) -> Tuple[str, str]: |
|
|
"""Resume training process by sending SIGUSR2""" |
|
|
if not self.is_training_running(): |
|
|
return "No training process found", self.get_logs() |
|
|
|
|
|
try: |
|
|
with open(self.pid_file, 'r') as f: |
|
|
pid = int(f.read().strip()) |
|
|
|
|
|
if psutil.pid_exists(pid): |
|
|
os.kill(pid, signal.SIGUSR2) |
|
|
self.save_status(state='training', message='Training resumed') |
|
|
self.append_log("Training resumed") |
|
|
|
|
|
return "Training resumed", self.get_logs() |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Error resuming training: {str(e)}" |
|
|
self.append_log(error_msg) |
|
|
return "Error resuming training", error_msg |
|
|
|
|
|
def is_training_running(self) -> bool: |
|
|
"""Check if training is currently running""" |
|
|
if not self.pid_file.exists(): |
|
|
return False |
|
|
|
|
|
try: |
|
|
with open(self.pid_file, 'r') as f: |
|
|
pid = int(f.read().strip()) |
|
|
return psutil.pid_exists(pid) |
|
|
        except (OSError, ValueError, psutil.Error):
|
|
return False |
|
|
|
|
|
def clear_training_data(self) -> str: |
|
|
"""Clear all training data""" |
|
|
if self.is_training_running(): |
|
|
            raise gr.Error("Cannot clear data while training is running")
|
|
|
|
|
try: |
|
|
for file in TRAINING_VIDEOS_PATH.glob("*.*"): |
|
|
file.unlink() |
|
|
for file in TRAINING_PATH.glob("*.*"): |
|
|
file.unlink() |
|
|
|
|
|
self.append_log("Cleared all training data") |
|
|
return "Training data cleared successfully" |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Error clearing training data: {str(e)}" |
|
|
self.append_log(error_msg) |
|
|
return error_msg |
|
|
|
|
|
def save_status(self, state: str, **kwargs) -> None: |
|
|
"""Save current training status""" |
|
|
status = { |
|
|
'status': state, |
|
|
'timestamp': datetime.now().isoformat(), |
|
|
**kwargs |
|
|
} |
|
|
        if state in ("Training started", "initializing"):
|
|
gr.Info("Initializing model and dataset..") |
|
|
elif state == "training": |
|
|
gr.Info("Training started!") |
|
|
elif state == "completed": |
|
|
gr.Info("Training completed!") |
|
|
|
|
|
with open(self.status_file, 'w') as f: |
|
|
json.dump(status, f, indent=2) |
|
|
|
|
|
def _start_log_monitor(self, process: subprocess.Popen) -> None: |
|
|
"""Start monitoring process output for logs""" |
|
|
|
|
|
|
|
|
def monitor(): |
|
|
self.append_log("Starting log monitor thread") |
|
|
|
|
|
def read_stream(stream, is_error=False): |
|
|
if stream: |
|
|
output = stream.readline() |
|
|
if output: |
|
|
|
|
|
line = output.strip() |
|
|
if is_error: |
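                            # stderr lines are logged as-is; metrics are only parsed from stdout below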
|
|
|
|
|
|
|
|
|
|
|
self.append_log(line) |
|
|
else: |
|
|
self.append_log(line) |
|
|
|
|
|
                            metrics = parse_training_log(line)
                            if metrics:
                                status = self.get_status()
                                status.update(metrics)
                                # save_status() expects the state separately and sets its own timestamp
                                state = status.pop('status', 'training')
                                status.pop('timestamp', None)
                                self.save_status(state=state, **status)
|
|
return True |
|
|
return False |
|
|
|
|
|
|
|
|
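            # Poll stdout and stderr until the training process exits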
while process.poll() is None: |
|
|
outputs = [process.stdout, process.stderr] |
|
|
readable, _, _ = select.select(outputs, [], [], 1.0) |
|
|
|
|
|
for stream in readable: |
|
|
is_error = (stream == process.stderr) |
|
|
read_stream(stream, is_error) |
|
|
|
|
|
|
|
|
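            # Drain any output still buffered after the process has exited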
while read_stream(process.stdout): |
|
|
pass |
|
|
while read_stream(process.stderr, True): |
|
|
pass |
|
|
|
|
|
|
|
|
return_code = process.poll() |
|
|
if return_code == 0: |
|
|
success_msg = "Training completed successfully" |
|
|
self.append_log(success_msg) |
|
|
gr.Info(success_msg) |
|
|
self.save_status(state='completed', message=success_msg) |
|
|
|
|
|
|
|
|
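                # If a repo_id was saved with the session, push the latest run to the Hugging Face Hub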
session = self.load_session() |
|
|
if session and session['params'].get('repo_id'): |
|
|
repo_id = session['params']['repo_id'] |
|
|
latest_run = max(Path(OUTPUT_PATH).glob('*'), key=os.path.getmtime) |
|
|
if self.upload_to_hub(latest_run, repo_id): |
|
|
self.append_log(f"Model uploaded to {repo_id}") |
|
|
else: |
|
|
self.append_log("Failed to upload model to hub") |
|
|
else: |
|
|
error_msg = f"Training failed with return code {return_code}" |
|
|
self.append_log(error_msg) |
|
|
logger.error(error_msg) |
|
|
self.save_status(state='error', message=error_msg) |
|
|
|
|
|
|
|
|
if self.pid_file.exists(): |
|
|
self.pid_file.unlink() |
|
|
|
|
|
monitor_thread = threading.Thread(target=monitor) |
|
|
monitor_thread.daemon = True |
|
|
monitor_thread.start() |
|
|
|
|
|
def upload_to_hub(self, model_path: Path, repo_id: str) -> bool: |
|
|
"""Upload model to Hugging Face Hub |
|
|
|
|
|
Args: |
|
|
model_path: Path to model files |
|
|
repo_id: Repository ID (username/model-name) |
|
|
|
|
|
Returns: |
|
|
bool: Whether upload was successful |
|
|
""" |
|
|
try: |
|
|
token = os.getenv("HF_API_TOKEN") |
|
|
if not token: |
|
|
self.append_log("Error: HF_API_TOKEN not set") |
|
|
return False |
|
|
|
|
|
|
|
|
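            # Make sure the target repo exists before uploading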
create_repo(repo_id, token=token, repo_type="model", exist_ok=True) |
|
|
|
|
|
|
|
|
upload_folder( |
|
|
                folder_path=str(model_path),
|
|
repo_id=repo_id, |
|
|
repo_type="model", |
|
|
commit_message="Training completed" |
|
|
) |
|
|
|
|
|
return True |
|
|
except Exception as e: |
|
|
self.append_log(f"Error uploading to hub: {str(e)}") |
|
|
return False |
|
|
|
|
|
def get_model_output_safetensors(self) -> str: |
|
|
"""Return the path to the model safetensors |
|
|
|
|
|
|
|
|
Returns: |
|
|
Path to created ZIP file |
|
|
""" |
|
|
|
|
|
model_output_safetensors_path = OUTPUT_PATH / "pytorch_lora_weights.safetensors" |
|
|
return str(model_output_safetensors_path) |
|
|
|
|
|
def create_training_dataset_zip(self) -> str: |
|
|
"""Create a ZIP file containing all training data |
|
|
|
|
|
|
|
|
Returns: |
|
|
Path to created ZIP file |
|
|
""" |
|
|
|
|
|
with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as temp_zip: |
|
|
temp_zip_path = str(temp_zip.name) |
|
|
print(f"Creating zip file for {TRAINING_PATH}..") |
|
|
try: |
|
|
make_archive(TRAINING_PATH, temp_zip_path) |
|
|
print(f"Zip file created!") |
|
|
return temp_zip_path |
|
|
except Exception as e: |
|
|
print(f"Failed to create zip: {str(e)}") |
|
|
raise gr.Error(f"Failed to create zip: {str(e)}") |
|
|
|
|
|
|