TrainL / app.py
Allex21's picture
Upload 24 files
eac965b verified
#!/usr/bin/env python3
"""
LoRA Trainer Funcional para Hugging Face
Baseado no kohya-ss sd-scripts
"""
import gradio as gr
import os
import sys
import json
import subprocess
import shutil
import zipfile
import tempfile
import toml
import logging
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any
import time
import threading
import queue
# Put the bundled sd-scripts directory at the front of the import path.
_scripts_on_path = Path(__file__).parent.joinpath("sd-scripts")
sys.path.insert(0, str(_scripts_on_path))
# Configure root logging once for the whole application.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class LoRATrainerHF:
    """Back end of the Gradio app: prepares data and runs LoRA training.

    Working files live under ``/tmp/lora_training``: ``models/`` holds the
    downloaded base checkpoints and ``projects/<name>/`` holds each project's
    dataset, configuration files, logs and outputs.  The public methods catch
    their own errors and return a user-facing status string (or list), so
    they can be bound directly to UI outputs.
    """

    # Status messages shared by every method that receives a project name.
    _NAME_REQUIRED_MSG = "❌ Nome do projeto é obrigatório"
    _NAME_INVALID_MSG = "❌ Nome do projeto inválido: não use '/', '\\' ou '..'"

    def __init__(self):
        """Create the working directories and initialise the training state."""
        self.base_dir = Path("/tmp/lora_training")
        self.models_dir = self.base_dir / "models"
        self.projects_dir = self.base_dir / "projects"
        for directory in (self.base_dir, self.models_dir, self.projects_dir):
            directory.mkdir(parents=True, exist_ok=True)
        self.sd_scripts_dir = Path(__file__).parent / "sd-scripts"
        # Download URLs of the predefined base models, keyed by display name.
        self.model_urls = {
            "Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
            "AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
            "Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors",
            "Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt"
        }
        # Popen handle of the running training process, or None when idle.
        self.training_process = None
        # Lines produced by the training process, waiting to be displayed.
        self.training_output_queue = queue.Queue()

    @classmethod
    def _parse_project_name(cls, project_name) -> Tuple[str, str]:
        """Normalise a project name and reject names that escape the projects dir.

        Returns ``(name, "")`` on success or ``("", error_message)`` on failure.
        """
        name = (project_name or "").strip().replace(" ", "_")
        if not name:
            return "", cls._NAME_REQUIRED_MSG
        # The name becomes a directory component that is later created and
        # partly deleted, so it must not contain separators or be a dot entry.
        if name in (".", "..") or any(ch in name for ch in ("/", "\\", "\x00")):
            return "", cls._NAME_INVALID_MSG
        return name, ""

    @staticmethod
    def _model_filename(model_url: str) -> str:
        """Return the file name in the URL path, ignoring query and fragment."""
        from urllib.parse import urlsplit

        return urlsplit(model_url).path.rstrip("/").rsplit("/", 1)[-1]

    def install_dependencies(self) -> str:
        """Install the required packages with pip, one at a time.

        Returns a success message when every package installed, a warning
        listing the packages whose installation failed otherwise, or an error
        message if something unexpected happened.
        """
        try:
            logger.info("Instalando dependências...")
            packages = [
                "torch>=2.0.0",
                "torchvision>=0.15.0",
                "diffusers>=0.21.0",
                "transformers>=4.25.0",
                "accelerate>=0.20.0",
                "safetensors>=0.3.0",
                "huggingface-hub>=0.16.0",
                "xformers>=0.0.20",
                "bitsandbytes>=0.41.0",
                "opencv-python>=4.7.0",
                "Pillow>=9.0.0",
                "numpy>=1.21.0",
                "tqdm>=4.64.0",
                "toml>=0.10.0",
                "tensorboard>=2.13.0",
                "wandb>=0.15.0",
                "scipy>=1.9.0",
                "matplotlib>=3.5.0",
                "datasets>=2.14.0",
                "peft>=0.5.0",
                "omegaconf>=2.3.0"
            ]
            failed = []
            for package in packages:
                try:
                    subprocess.run([
                        sys.executable, "-m", "pip", "install", package, "--quiet"
                    ], check=True, capture_output=True, text=True)
                except subprocess.CalledProcessError as e:
                    # Keep going so one broken package does not block the rest,
                    # but remember it so the final status is truthful.
                    logger.warning(f"⚠ Erro ao instalar {package}: {e}")
                    failed.append(package)
                else:
                    logger.info(f"✓ {package} instalado")
            if failed:
                return "⚠️ Dependências instaladas, mas estes pacotes falharam: " + ", ".join(failed)
            return "✅ Dependências instaladas com sucesso!"
        except Exception as e:
            logger.error(f"Erro ao instalar dependências: {e}")
            return f"❌ Erro ao instalar dependências: {e}"

    def download_model(self, model_choice: str, custom_url: str = "") -> str:
        """Download a base model into the models directory using ``wget``.

        Args:
            model_choice: Key of ``self.model_urls``; used when no custom URL
                is given.
            custom_url: Optional direct http(s) URL that overrides the choice.

        Returns:
            A status message describing the outcome.
        """
        try:
            from urllib.parse import urlsplit

            custom = (custom_url or "").strip()
            if custom:
                model_url = custom
            elif model_choice in self.model_urls:
                model_url = self.model_urls[model_choice]
            else:
                return f"❌ Modelo '{model_choice}' não encontrado"

            # The URL is user input handed to wget: only plain web URLs are
            # accepted, which also keeps it from being parsed as an option.
            if urlsplit(model_url).scheme not in ("http", "https"):
                return "❌ URL inválida: use um endereço http(s)://"
            model_name = self._model_filename(model_url)
            if not model_name:
                return "❌ URL inválida: não foi possível determinar o nome do arquivo"

            model_path = self.models_dir / model_name
            if model_path.is_file():
                return f"✅ Modelo já existe: {model_name}"

            logger.info(f"Baixando modelo: {model_url}")
            # Download to a temporary name so a failed or interrupted transfer
            # is never mistaken for a complete model on the next call.
            partial_path = model_path.with_name(model_name + ".part")
            result = subprocess.run([
                "wget", "-O", str(partial_path), model_url, "--progress=bar:force"
            ], capture_output=True, text=True)
            if result.returncode != 0:
                if partial_path.exists():
                    partial_path.unlink()
                return f"❌ Erro no download: {result.stderr}"
            partial_path.replace(model_path)
            return f"✅ Modelo baixado: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)"
        except Exception as e:
            logger.error(f"Erro ao baixar modelo: {e}")
            return f"❌ Erro ao baixar modelo: {e}"

    def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]:
        """Extract an uploaded dataset ZIP into the project and summarise it.

        Args:
            dataset_zip: Path to the ZIP file (``str`` or path-like), or an
                object whose ``name`` attribute is that path.
            project_name: Project the dataset belongs to.

        Returns:
            ``(status_message, dataset_dir)``; ``dataset_dir`` is an empty
            string when processing failed.
        """
        try:
            if not dataset_zip:
                return "❌ Nenhum dataset foi enviado", ""
            project_name, name_error = self._parse_project_name(project_name)
            if name_error:
                return name_error, ""

            if isinstance(dataset_zip, (str, os.PathLike)):
                zip_path = dataset_zip
            else:
                zip_path = dataset_zip.name

            dataset_dir = self.projects_dir / project_name / "dataset"
            # Start from an empty directory so files of a previous upload do
            # not leak into the new dataset.
            if dataset_dir.exists():
                shutil.rmtree(dataset_dir)
            dataset_dir.mkdir(parents=True)

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(dataset_dir)

            # Count images and the caption files that sit next to them.
            image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'}
            images = [
                file_path for file_path in dataset_dir.rglob("*")
                if file_path.is_file() and file_path.suffix.lower() in image_extensions
            ]
            captions = [
                image.with_suffix('.txt') for image in images
                if image.with_suffix('.txt').exists()
            ]

            info = f"✅ Dataset processado!\n"
            info += f"📁 Projeto: {project_name}\n"
            info += f"🖼️ Imagens: {len(images)}\n"
            info += f"📝 Captions: {len(captions)}\n"
            info += f"📂 Diretório: {dataset_dir}"
            return info, str(dataset_dir)
        except Exception as e:
            logger.error(f"Erro ao processar dataset: {e}")
            return f"❌ Erro ao processar dataset: {e}", ""

    def create_training_config(self,
                               project_name: str,
                               dataset_dir: str,
                               model_choice: str,
                               custom_model_url: str,
                               resolution: int,
                               batch_size: int,
                               epochs: int,
                               learning_rate: float,
                               text_encoder_lr: float,
                               network_dim: int,
                               network_alpha: int,
                               lora_type: str,
                               optimizer: str,
                               scheduler: str,
                               flip_aug: bool,
                               shuffle_caption: bool,
                               keep_tokens: int,
                               clip_skip: int,
                               mixed_precision: str,
                               save_every_n_epochs: int,
                               max_train_steps: int) -> str:
        """Write ``dataset_config.toml`` and ``training_config.toml`` for a project.

        The base model must already be present in the models directory and
        ``dataset_dir`` must be an existing directory.  When
        ``max_train_steps`` is greater than zero it limits the run; otherwise
        ``epochs`` does.

        Returns:
            A status message with the paths of the written files, or an error
            message.
        """
        try:
            project_name, name_error = self._parse_project_name(project_name)
            if name_error:
                return name_error
            if not dataset_dir or not Path(dataset_dir).is_dir():
                return "❌ Dataset não encontrado. Processe o dataset primeiro."

            # Callers may hand over whole numbers as floats (or None for an
            # empty field); the configuration must contain integers.
            resolution = int(resolution)
            batch_size = int(batch_size)
            epochs = int(epochs)
            network_dim = int(network_dim)
            network_alpha = int(network_alpha)
            keep_tokens = int(keep_tokens)
            clip_skip = int(clip_skip)
            save_every_n_epochs = int(save_every_n_epochs)
            max_train_steps = int(max_train_steps or 0)

            # Resolve the base model exactly as download_model names it.
            custom = (custom_model_url or "").strip()
            if custom:
                model_url = custom
            elif model_choice in self.model_urls:
                model_url = self.model_urls[model_choice]
            else:
                return f"❌ Modelo '{model_choice}' não encontrado"
            model_name = self._model_filename(model_url)
            model_path = self.models_dir / model_name
            if not model_name or not model_path.is_file():
                return f"❌ Modelo não encontrado: {model_name}. Faça o download primeiro."

            project_dir = self.projects_dir / project_name
            output_dir = project_dir / "output"
            log_dir = project_dir / "logs"
            for directory in (project_dir, output_dir, log_dir):
                directory.mkdir(parents=True, exist_ok=True)
            dataset_config_path = project_dir / "dataset_config.toml"
            training_config_path = project_dir / "training_config.toml"

            # Dataset configuration.  Options without a value are left out
            # instead of being passed to the TOML writer as None.
            dataset_config = {
                "general": {
                    "shuffle_caption": shuffle_caption,
                    "caption_extension": ".txt",
                    "keep_tokens": keep_tokens,
                    "flip_aug": flip_aug,
                    "color_aug": False,
                    "random_crop": False,
                    "debug_dataset": False
                },
                "datasets": [{
                    "resolution": resolution,
                    "batch_size": batch_size,
                    "subsets": [{
                        "image_dir": str(dataset_dir),
                        "num_repeats": 1
                    }]
                }]
            }

            training_arguments = {
                "output_dir": str(output_dir),
                "output_name": project_name,
                "save_precision": "fp16",
                "save_every_n_epochs": save_every_n_epochs,
                "train_batch_size": batch_size,
                "gradient_accumulation_steps": 1,
                "learning_rate": learning_rate,
                "text_encoder_lr": text_encoder_lr,
                "lr_scheduler": scheduler,
                "lr_warmup_steps": 0,
                "optimizer_type": optimizer,
                "mixed_precision": mixed_precision,
                "save_model_as": "safetensors",
                "seed": 42,
                "max_data_loader_n_workers": 2,
                "persistent_data_loader_workers": True,
                "gradient_checkpointing": True,
                "xformers": True,
                "lowram": True,
                "cache_latents": True,
                "cache_latents_to_disk": True,
                "logging_dir": str(log_dir),
                "log_with": "tensorboard"
            }
            # Exactly one stop criterion is written: steps when requested,
            # epochs otherwise.
            if max_train_steps > 0:
                training_arguments["max_train_steps"] = max_train_steps
            else:
                training_arguments["max_train_epochs"] = epochs

            network_arguments = {
                "network_module": "networks.lora" if lora_type in ("LoRA", "LoCon") else "networks.dylora",
                "network_dim": network_dim,
                "network_alpha": network_alpha,
                "network_train_unet_only": False,
                "network_train_text_encoder_only": False
            }
            if lora_type == "LoCon":
                # The convolution rank/alpha are options of the network
                # module, so they are passed through network_args.
                network_arguments["network_args"] = [
                    f"conv_dim={max(1, network_dim // 2)}",
                    f"conv_alpha={max(1, network_alpha // 2)}"
                ]

            training_config = {
                "model_arguments": {
                    "pretrained_model_name_or_path": str(model_path),
                    "v2": False,
                    "v_parameterization": False,
                    "clip_skip": clip_skip
                },
                "dataset_arguments": {
                    "dataset_config": str(dataset_config_path)
                },
                "training_arguments": training_arguments,
                "network_arguments": network_arguments
            }

            with open(dataset_config_path, 'w', encoding="utf-8") as f:
                toml.dump(dataset_config, f)
            with open(training_config_path, 'w', encoding="utf-8") as f:
                toml.dump(training_config, f)
            return f"✅ Configuração criada!\n📁 Dataset: {dataset_config_path}\n⚙️ Treinamento: {training_config_path}"
        except Exception as e:
            logger.error(f"Erro ao criar configuração: {e}")
            return f"❌ Erro ao criar configuração: {e}"

    def start_training(self, project_name: str) -> str:
        """Launch ``train_network.py`` for the project in the background.

        The process is started here, so launch errors are returned to the
        caller; a daemon thread then forwards its output to
        ``self.training_output_queue``.  Only one training may run at a time.

        Returns:
            A status message.
        """
        try:
            project_name, name_error = self._parse_project_name(project_name)
            if name_error:
                return name_error
            training_config_path = self.projects_dir / project_name / "training_config.toml"
            if not training_config_path.exists():
                return "❌ Configuração não encontrada. Crie a configuração primeiro."
            train_script = self.sd_scripts_dir / "train_network.py"
            if not train_script.exists():
                return "❌ Script de treinamento não encontrado"

            running = self.training_process
            if running is not None and running.poll() is None:
                return "⚠️ Já existe um treinamento em andamento. Pare-o antes de iniciar outro."

            cmd = [
                sys.executable,
                str(train_script),
                "--config_file", str(training_config_path)
            ]
            logger.info(f"Iniciando treinamento: {' '.join(cmd)}")
            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
                cwd=str(self.sd_scripts_dir)
            )
            # Publish the handle before returning so an immediate stop or
            # status request already sees the running process.
            self.training_process = process
            training_thread = threading.Thread(
                target=self._pump_training_output, args=(process,)
            )
            training_thread.daemon = True
            training_thread.start()
            return "🚀 Treinamento iniciado! Acompanhe o progresso abaixo."
        except Exception as e:
            logger.error(f"Erro ao iniciar treinamento: {e}")
            return f"❌ Erro ao iniciar treinamento: {e}"

    def _pump_training_output(self, process) -> None:
        """Forward the process output to the queue and report its exit status."""
        try:
            with process.stdout:
                for line in process.stdout:
                    text = line.strip()
                    self.training_output_queue.put(text)
                    logger.info(text)
            process.wait()
            if process.returncode == 0:
                self.training_output_queue.put("✅ TREINAMENTO CONCLUÍDO COM SUCESSO!")
            else:
                self.training_output_queue.put(f"❌ TREINAMENTO FALHOU (código {process.returncode})")
        except Exception as e:
            self.training_output_queue.put(f"❌ ERRO NO TREINAMENTO: {e}")
        finally:
            # Clear the handle only if it still refers to this process.
            if self.training_process is process:
                self.training_process = None

    def get_training_output(self) -> str:
        """Return the queued output lines, or a status line when there are none.

        The returned lines are removed from the queue.
        """
        output_lines = []
        while True:
            try:
                output_lines.append(self.training_output_queue.get_nowait())
            except queue.Empty:
                break
        if output_lines:
            return "\n".join(output_lines)
        # Read the attribute once: the reader thread may reset it to None at
        # any moment.
        process = self.training_process
        if process is not None and process.poll() is None:
            return "🔄 Treinamento em andamento..."
        return "⏸️ Nenhum treinamento ativo"

    def stop_training(self) -> str:
        """Terminate the running training process, killing it if necessary.

        Returns:
            A status message.
        """
        try:
            # Read the attribute once: the reader thread may reset it to None.
            process = self.training_process
            if process is None or process.poll() is not None:
                return "ℹ️ Nenhum treinamento ativo para parar"
            process.terminate()
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # The process ignored the termination request: force it.
                process.kill()
                process.wait()
            return "⏹️ Treinamento interrompido"
        except Exception as e:
            return f"❌ Erro ao parar treinamento: {e}"

    def list_output_files(self, project_name: str) -> List[str]:
        """List the project's ``.safetensors`` outputs, most recent first.

        Returns:
            Entries formatted as ``"<file name> (<size> MB)"`` ordered by
            modification time (newest first); an empty list when the project
            name is empty or invalid, there is no output directory, or an
            error occurs.
        """
        try:
            project_name, name_error = self._parse_project_name(project_name)
            if name_error:
                return []
            output_dir = self.projects_dir / project_name / "output"
            if not output_dir.exists():
                return []
            entries = []
            for file_path in output_dir.rglob("*.safetensors"):
                stats = file_path.stat()
                entries.append((stats.st_mtime, file_path.name, stats.st_size // (1024 * 1024)))
            # Newest first; the name breaks ties between equal timestamps.
            entries.sort(reverse=True)
            return [f"{name} ({size_mb} MB)" for _, name, size_mb in entries]
        except Exception as e:
            logger.error(f"Erro ao listar arquivos: {e}")
            return []
# Module-level trainer instance used by the Gradio event handlers below.
trainer = LoRATrainerHF()
def create_interface():
    """Build the Gradio Blocks UI and bind it to the module-level ``trainer``.

    The UI has five tabs: dependency installation, project setup (model
    download and dataset upload), training parameters, training control and
    result listing.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """
    with gr.Blocks(title="LoRA Trainer Funcional - Hugging Face", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
# 🎨 LoRA Trainer Funcional para Hugging Face
**Treine seus próprios modelos LoRA para Stable Diffusion de forma profissional!**
Esta ferramenta é baseada no kohya-ss sd-scripts e oferece treinamento real e funcional de modelos LoRA.
""")
        # Holds the extracted dataset directory between the "process dataset"
        # and "create configuration" steps.
        dataset_dir_state = gr.State("")

        with gr.Tab("🔧 Instalação"):
            gr.Markdown("### Primeiro, instale as dependências necessárias:")
            install_btn = gr.Button("📦 Instalar Dependências", variant="primary", size="lg")
            install_status = gr.Textbox(label="Status da Instalação", lines=3, interactive=False)
            install_btn.click(
                fn=trainer.install_dependencies,
                outputs=install_status
            )

        with gr.Tab("📁 Configuração do Projeto"):
            with gr.Row():
                project_name = gr.Textbox(
                    label="Nome do Projeto",
                    placeholder="meu_lora_anime",
                    info="Nome único para seu projeto (sem espaços especiais)"
                )
            gr.Markdown("### 📥 Download do Modelo Base")
            with gr.Row():
                model_choice = gr.Dropdown(
                    choices=list(trainer.model_urls.keys()),
                    label="Modelo Base Pré-definido",
                    value="Anime (animefull-final-pruned)",
                    info="Escolha um modelo base ou use URL personalizada"
                )
                custom_model_url = gr.Textbox(
                    label="URL Personalizada (opcional)",
                    placeholder="https://huggingface.co/...",
                    info="URL direta para download de modelo personalizado"
                )
            download_btn = gr.Button("📥 Baixar Modelo", variant="primary")
            download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False)
            gr.Markdown("### 📊 Upload do Dataset")
            gr.Markdown("""
**Formato do Dataset:**
- Crie um arquivo ZIP contendo suas imagens
- Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições
- Exemplo: `imagem1.jpg` + `imagem1.txt`
""")
            dataset_upload = gr.File(
                label="Upload do Dataset (ZIP)",
                file_types=[".zip"]
            )
            process_btn = gr.Button("📊 Processar Dataset", variant="primary")
            dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False)

        with gr.Tab("⚙️ Parâmetros de Treinamento"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 🖼️ Configurações de Imagem")
                    resolution = gr.Slider(
                        minimum=512, maximum=1024, step=64, value=512,
                        label="Resolução",
                        info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)"
                    )
                    batch_size = gr.Slider(
                        minimum=1, maximum=8, step=1, value=1,
                        label="Batch Size",
                        info="Imagens por lote (aumente se tiver GPU potente)"
                    )
                    flip_aug = gr.Checkbox(
                        label="Flip Augmentation",
                        info="Espelhar imagens para aumentar dataset"
                    )
                    shuffle_caption = gr.Checkbox(
                        value=True,
                        label="Shuffle Caption",
                        info="Embaralhar ordem das tags"
                    )
                    keep_tokens = gr.Slider(
                        minimum=0, maximum=5, step=1, value=1,
                        label="Keep Tokens",
                        info="Número de tokens iniciais que não serão embaralhados"
                    )
                with gr.Column():
                    gr.Markdown("#### 🎯 Configurações de Treinamento")
                    epochs = gr.Slider(
                        minimum=1, maximum=100, step=1, value=10,
                        label="Épocas",
                        info="Número de épocas de treinamento"
                    )
                    # precision=0 makes the widget deliver an integer, which
                    # is what a step count has to be.
                    max_train_steps = gr.Number(
                        value=0,
                        precision=0,
                        label="Max Train Steps (0 = usar épocas)",
                        info="Número máximo de steps (deixe 0 para usar épocas)"
                    )
                    save_every_n_epochs = gr.Slider(
                        minimum=1, maximum=10, step=1, value=1,
                        label="Salvar a cada N épocas",
                        info="Frequência de salvamento dos checkpoints"
                    )
                    mixed_precision = gr.Dropdown(
                        choices=["fp16", "bf16", "no"],
                        value="fp16",
                        label="Mixed Precision",
                        info="fp16 = mais rápido, bf16 = mais estável"
                    )
                    clip_skip = gr.Slider(
                        minimum=1, maximum=12, step=1, value=2,
                        label="CLIP Skip",
                        info="Camadas CLIP a pular (2 para anime, 1 para realista)"
                    )
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 📚 Learning Rate")
                    learning_rate = gr.Number(
                        value=1e-4,
                        label="Learning Rate (UNet)",
                        info="Taxa de aprendizado principal"
                    )
                    text_encoder_lr = gr.Number(
                        value=5e-5,
                        label="Learning Rate (Text Encoder)",
                        info="Taxa de aprendizado do text encoder"
                    )
                    scheduler = gr.Dropdown(
                        choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"],
                        value="cosine_with_restarts",
                        label="LR Scheduler",
                        info="Algoritmo de ajuste da learning rate"
                    )
                    optimizer = gr.Dropdown(
                        choices=["AdamW8bit", "AdamW", "Lion", "SGD"],
                        value="AdamW8bit",
                        label="Otimizador",
                        info="AdamW8bit = menos memória"
                    )
                with gr.Column():
                    gr.Markdown("#### 🧠 Arquitetura LoRA")
                    lora_type = gr.Radio(
                        choices=["LoRA", "LoCon"],
                        value="LoRA",
                        label="Tipo de LoRA",
                        info="LoRA = geral, LoCon = estilos artísticos"
                    )
                    network_dim = gr.Slider(
                        minimum=4, maximum=128, step=4, value=32,
                        label="Network Dimension",
                        info="Dimensão da rede (maior = mais detalhes, mais memória)"
                    )
                    network_alpha = gr.Slider(
                        minimum=1, maximum=128, step=1, value=16,
                        label="Network Alpha",
                        info="Controla a força do LoRA (geralmente dim/2)"
                    )

        with gr.Tab("🚀 Treinamento"):
            create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg")
            config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False)
            with gr.Row():
                start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg")
                stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop")
            training_output = gr.Textbox(
                label="Output do Treinamento",
                lines=15,
                interactive=False,
                info="Acompanhe o progresso do treinamento em tempo real"
            )
            # Without a trigger the queued training output would never reach
            # the textbox, so a refresh button pulls it on demand.
            refresh_output_btn = gr.Button("🔄 Atualizar Output", variant="secondary")

            def update_output():
                """Fetch the training output queued since the last refresh."""
                return trainer.get_training_output()

        with gr.Tab("📥 Download dos Resultados"):
            refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary")
            output_files = gr.Dropdown(
                label="Arquivos LoRA Gerados",
                choices=[],
                info="Selecione um arquivo para download"
            )
            download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento")

            def refresh_output_files(name):
                """Replace the dropdown choices with the project's output files."""
                files = trainer.list_output_files(name)
                # Returning the bare list would set the dropdown's value, not
                # its choices, so the update is spelled out explicitly.
                return gr.update(choices=files, value=files[0] if files else None)

        # Event handlers
        download_btn.click(
            fn=trainer.download_model,
            inputs=[model_choice, custom_model_url],
            outputs=download_status
        )
        process_btn.click(
            fn=trainer.process_dataset,
            inputs=[dataset_upload, project_name],
            outputs=[dataset_status, dataset_dir_state]
        )
        create_config_btn.click(
            fn=trainer.create_training_config,
            inputs=[
                project_name, dataset_dir_state, model_choice, custom_model_url,
                resolution, batch_size, epochs, learning_rate, text_encoder_lr,
                network_dim, network_alpha, lora_type, optimizer, scheduler,
                flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision,
                save_every_n_epochs, max_train_steps
            ],
            outputs=config_status
        )
        start_training_btn.click(
            fn=trainer.start_training,
            inputs=project_name,
            outputs=training_output
        )
        stop_training_btn.click(
            fn=trainer.stop_training,
            outputs=training_output
        )
        refresh_output_btn.click(
            fn=update_output,
            outputs=training_output
        )
        refresh_files_btn.click(
            fn=refresh_output_files,
            inputs=project_name,
            outputs=output_files
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces.
    print("🚀 Iniciando LoRA Trainer Funcional...")
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    interface = create_interface()
    interface.launch(**launch_options)