|
|
|
|
|
""" |
|
|
LoRA Trainer Funcional para Hugging Face |
|
|
Baseado no kohya-ss sd-scripts |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import subprocess |
|
|
import shutil |
|
|
import zipfile |
|
|
import tempfile |
|
|
import toml |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, List, Dict, Any |
|
|
import time |
|
|
import threading |
|
|
import queue |
|
|
|
|
|
|
|
|
# Make the kohya-ss sd-scripts checkout that lives next to this file importable.
_SD_SCRIPTS_PATH = Path(__file__).parent / "sd-scripts"
sys.path.insert(0, str(_SD_SCRIPTS_PATH))

# Application-wide logging: INFO level, "time - level - message" lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
class LoRATrainerHF:
    """Back-end for the Gradio LoRA trainer UI.

    Working files live under ``/tmp/lora_training``:

    * ``models/``: downloaded base checkpoints.
    * ``projects/<name>/``: extracted dataset, TOML configs, ``output/`` and ``logs/``.

    Training runs kohya-ss ``sd-scripts/train_network.py`` in a subprocess that a
    background thread drives. The public methods are wired directly to Gradio
    events. They therefore report problems as user-facing (Portuguese) status
    strings instead of raising.
    """

    # Maximum number of training-log lines kept in memory for display.
    MAX_LOG_LINES = 500

    # File extensions treated as training images when scanning a dataset.
    IMAGE_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'})

    def __init__(self):
        self.base_dir = Path("/tmp/lora_training")
        self.base_dir.mkdir(parents=True, exist_ok=True)

        self.models_dir = self.base_dir / "models"
        self.models_dir.mkdir(exist_ok=True)

        self.projects_dir = self.base_dir / "projects"
        self.projects_dir.mkdir(exist_ok=True)

        # kohya-ss sd-scripts checkout expected next to this file.
        self.sd_scripts_dir = Path(__file__).parent / "sd-scripts"

        # Display name -> direct download URL of the predefined base models.
        self.model_urls = {
            "Anime (animefull-final-pruned)": "https://huggingface.co/hollowstrawberry/stable-diffusion-guide/resolve/main/models/animefull-final-pruned-fp16.safetensors",
            "AnyLoRA": "https://huggingface.co/Lykon/AnyLoRA/resolve/main/AnyLoRA_noVae_fp16-pruned.ckpt",
            "Stable Diffusion 1.5": "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors",
            "Waifu Diffusion 1.4": "https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt"
        }

        # Popen handle of the running trainer, or None when idle.
        self.training_process = None
        # Thread that owns the trainer subprocess; used to reject a second start.
        self._training_thread = None
        # Lines produced by the training thread, drained by get_training_output().
        self.training_output_queue = queue.Queue()
        # Accumulated log shown in the UI, capped at MAX_LOG_LINES.
        self._training_log = []

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_project_name(project_name) -> Optional[str]:
        """Return the directory name used for *project_name*, or None if it is invalid.

        Spaces become underscores. Empty names, ``.``/``..``, and names that contain
        path separators are rejected. This guarantees a project directory always stays
        inside ``projects_dir``, which matters because ``process_dataset`` deletes
        ``<project>/dataset`` recursively.
        """
        name = (project_name or "").strip().replace(" ", "_")
        if not name or name in (".", "..") or "/" in name or "\\" in name or "\x00" in name:
            return None
        return name

    @staticmethod
    def _model_filename(model_url: str) -> str:
        """Return the local file name for *model_url*.

        This is the last path segment, ignoring any ``?query`` or ``#fragment``.
        For example, ``.../m.safetensors?download=true`` becomes ``m.safetensors``.
        """
        from urllib.parse import urlsplit

        path = urlsplit(model_url.strip()).path
        return path.rstrip("/").split("/")[-1]

    def _model_source(self, model_choice: str, custom_url: str) -> Tuple[str, str]:
        """Resolve the UI model selection to ``(download_url, local_file_name)``.

        A non-empty *custom_url* takes precedence over *model_choice*.

        Raises:
            ValueError: if the choice is unknown or the URL has no usable file
                name. The message is a ready-to-show status string.
        """
        custom_url = (custom_url or "").strip()
        if custom_url:
            model_url = custom_url
        elif model_choice in self.model_urls:
            model_url = self.model_urls[model_choice]
        else:
            raise ValueError(f"❌ Modelo '{model_choice}' não encontrado")

        model_name = self._model_filename(model_url)
        if not model_name or model_name in (".", ".."):
            raise ValueError(f"❌ URL de modelo inválida: {model_url}")
        return model_url, model_name

    def _drain_output_queue(self) -> List[str]:
        """Remove and return every line currently waiting in the output queue."""
        lines = []
        while True:
            try:
                lines.append(self.training_output_queue.get_nowait())
            except queue.Empty:
                return lines

    @staticmethod
    def _build_configs(*, project_name, dataset_dir, model_path, output_dir, log_dir,
                       dataset_config_path, resolution, batch_size, epochs, learning_rate,
                       text_encoder_lr, network_dim, network_alpha, lora_type, optimizer,
                       scheduler, flip_aug, shuffle_caption, keep_tokens, clip_skip,
                       mixed_precision, save_every_n_epochs,
                       max_train_steps) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Build the ``(dataset_config, training_config)`` dicts that are dumped as TOML.

        Gradio numbers and sliders may deliver floats such as ``32.0``, or None for
        an empty Number field. kohya copies config-file values onto its argparse
        namespace without type conversion, so integer options are cast here. That
        keeps float step counts and dimensions out of the trainer.
        """
        resolution = int(resolution)
        batch_size = int(batch_size)
        epochs = int(epochs)
        network_dim = int(network_dim)
        network_alpha = int(network_alpha)
        keep_tokens = int(keep_tokens)
        clip_skip = int(clip_skip)
        save_every_n_epochs = int(save_every_n_epochs)
        # 0, an empty field (None) or a negative value all mean "train by epochs".
        max_train_steps = max(0, int(max_train_steps or 0))

        dataset_config = {
            "general": {
                "shuffle_caption": shuffle_caption,
                "caption_extension": ".txt",
                "keep_tokens": keep_tokens,
                "flip_aug": flip_aug,
                "color_aug": False,
                # face_crop_aug_range=None was dropped because toml omits None values
                # anyway. debug_dataset was dropped because it is a trainer CLI flag,
                # not a dataset-config key; kohya's dataset schema may reject it
                # (verify against the sd-scripts version in use). False is its default.
                "random_crop": False,
            },
            "datasets": [{
                "resolution": resolution,
                "batch_size": batch_size,
                "subsets": [{
                    "image_dir": str(dataset_dir),
                    "num_repeats": 1
                }]
            }]
        }

        network_arguments = {
            # "LoRA" and "LoCon" both use networks.lora; LoCon only adds conv layers.
            "network_module": "networks.lora" if lora_type in ("LoRA", "LoCon") else "networks.dylora",
            "network_dim": network_dim,
            "network_alpha": network_alpha,
            "network_train_unet_only": False,
            "network_train_text_encoder_only": False
        }
        if lora_type == "LoCon":
            # networks.lora reads conv_dim/conv_alpha from the "key=value" network_args
            # list. Top-level conv_dim/conv_alpha keys are not read by it, so LoCon
            # would silently train as plain LoRA.
            network_arguments["network_args"] = [
                f"conv_dim={max(1, network_dim // 2)}",
                f"conv_alpha={max(1, network_alpha // 2)}",
            ]

        training_config = {
            "model_arguments": {
                "pretrained_model_name_or_path": str(model_path),
                "v2": False,
                "v_parameterization": False,
                "clip_skip": clip_skip
            },
            "dataset_arguments": {
                "dataset_config": str(dataset_config_path)
            },
            "training_arguments": {
                "output_dir": str(output_dir),
                "output_name": project_name,
                "save_precision": "fp16",
                "save_every_n_epochs": save_every_n_epochs,
                # Exactly one of these is set; toml drops the None one.
                "max_train_epochs": epochs if max_train_steps == 0 else None,
                "max_train_steps": max_train_steps if max_train_steps > 0 else None,
                "train_batch_size": batch_size,
                "gradient_accumulation_steps": 1,
                "learning_rate": learning_rate,
                "text_encoder_lr": text_encoder_lr,
                "lr_scheduler": scheduler,
                "lr_warmup_steps": 0,
                "optimizer_type": optimizer,
                "mixed_precision": mixed_precision,
                "save_model_as": "safetensors",
                "seed": 42,
                "max_data_loader_n_workers": 2,
                "persistent_data_loader_workers": True,
                "gradient_checkpointing": True,
                "xformers": True,
                "lowram": True,
                "cache_latents": True,
                "cache_latents_to_disk": True,
                "logging_dir": str(log_dir),
                "log_with": "tensorboard"
            },
            "network_arguments": network_arguments
        }
        return dataset_config, training_config

    # --------------------------------------------------------------- public API

    def install_dependencies(self) -> str:
        """Install the training dependencies with pip, one package at a time.

        Each package is installed separately so one failure does not abort the
        rest. Failed packages are collected and reported rather than hidden behind
        a success message.
        """
        try:
            logger.info("Instalando dependências...")

            packages = [
                "torch>=2.0.0",
                "torchvision>=0.15.0",
                "diffusers>=0.21.0",
                "transformers>=4.25.0",
                "accelerate>=0.20.0",
                "safetensors>=0.3.0",
                "huggingface-hub>=0.16.0",
                "xformers>=0.0.20",
                "bitsandbytes>=0.41.0",
                "opencv-python>=4.7.0",
                "Pillow>=9.0.0",
                "numpy>=1.21.0",
                "tqdm>=4.64.0",
                "toml>=0.10.0",
                "tensorboard>=2.13.0",
                "wandb>=0.15.0",
                "scipy>=1.9.0",
                "matplotlib>=3.5.0",
                "datasets>=2.14.0",
                "peft>=0.5.0",
                "omegaconf>=2.3.0"
            ]

            failed = []
            for package in packages:
                try:
                    subprocess.run([
                        sys.executable, "-m", "pip", "install", package, "--quiet"
                    ], check=True, capture_output=True, text=True)
                    logger.info(f"✓ {package} instalado")
                except subprocess.CalledProcessError as e:
                    logger.warning(f"⚠ Erro ao instalar {package}: {e}")
                    if e.stderr:
                        logger.warning(e.stderr.strip())
                    failed.append(package)

            if failed:
                return f"⚠️ Dependências instaladas com erros. Falharam: {', '.join(failed)}"
            return "✅ Dependências instaladas com sucesso!"

        except Exception as e:
            logger.error(f"Erro ao instalar dependências: {e}")
            return f"❌ Erro ao instalar dependências: {e}"

    def download_model(self, model_choice: str, custom_url: str = "") -> str:
        """Download the selected (or custom-URL) base model into ``models_dir`` with wget.

        Returns a status string. The download is skipped when the target file
        already exists.
        """
        try:
            try:
                model_url, model_name = self._model_source(model_choice, custom_url)
            except ValueError as e:
                return str(e)

            model_path = self.models_dir / model_name
            if model_path.exists():
                return f"✅ Modelo já existe: {model_name}"

            logger.info(f"Baixando modelo: {model_url}")

            # Download under a temporary name and rename only on success. wget -O
            # creates the file even when the download fails, and a truncated file at
            # the final path would later be reported as "already exists".
            partial_path = model_path.with_name(model_name + ".part")
            result = subprocess.run([
                "wget", "-O", str(partial_path), model_url, "--progress=bar:force"
            ], capture_output=True, text=True)

            if result.returncode != 0:
                partial_path.unlink(missing_ok=True)
                # stderr also holds the whole progress bar; the error is at the end.
                return f"❌ Erro no download: {result.stderr[-2000:]}"

            partial_path.replace(model_path)
            return f"✅ Modelo baixado: {model_name} ({model_path.stat().st_size // (1024*1024)} MB)"

        except Exception as e:
            logger.error(f"Erro ao baixar modelo: {e}")
            return f"❌ Erro ao baixar modelo: {e}"

    def process_dataset(self, dataset_zip, project_name: str) -> Tuple[str, str]:
        """Extract an uploaded dataset ZIP into ``<project>/dataset``.

        Args:
            dataset_zip: Value from a Gradio ``File`` component. This is either an
                object with a ``.name`` path (older Gradio) or a path string.
            project_name: Project name typed by the user.

        Returns:
            ``(status_message, image_dir)``. ``image_dir`` is "" on failure.
        """
        try:
            if not dataset_zip:
                return "❌ Nenhum dataset foi enviado", ""

            if not (project_name or "").strip():
                return "❌ Nome do projeto é obrigatório", ""

            safe_name = self._normalize_project_name(project_name)
            if safe_name is None:
                return "❌ Nome do projeto inválido", ""

            project_dir = self.projects_dir / safe_name
            project_dir.mkdir(exist_ok=True)

            dataset_dir = project_dir / "dataset"
            if dataset_dir.exists():
                shutil.rmtree(dataset_dir)
            dataset_dir.mkdir()

            zip_path = getattr(dataset_zip, "name", dataset_zip)
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(dataset_dir)

            # Skip macOS archive metadata (__MACOSX/, "._*" AppleDouble files).
            # These carry image extensions but are not images.
            images = [
                path for path in dataset_dir.rglob("*")
                if path.is_file()
                and path.suffix.lower() in self.IMAGE_EXTENSIONS
                and "__MACOSX" not in path.relative_to(dataset_dir).parts
                and not path.name.startswith("._")
            ]
            if not images:
                return f"❌ Nenhuma imagem encontrada no dataset ({dataset_dir})", ""

            captions = [img.with_suffix('.txt') for img in images if img.with_suffix('.txt').exists()]

            # ZIPs often wrap everything in one top-level folder. kohya's subsets
            # appear to read only files directly inside image_dir (no recursion;
            # verify against the sd-scripts version in use), so point at the folder
            # that actually holds the images.
            image_dirs = sorted({img.parent for img in images})
            image_dir = image_dirs[0] if len(image_dirs) == 1 else dataset_dir

            info = f"✅ Dataset processado!\n"
            info += f"📁 Projeto: {safe_name}\n"
            info += f"🖼️ Imagens: {len(images)}\n"
            info += f"📝 Captions: {len(captions)}\n"
            info += f"📂 Diretório: {image_dir}"
            if len(image_dirs) > 1:
                info += (f"\n⚠️ Imagens em {len(image_dirs)} pastas diferentes; apenas as que "
                         f"estão diretamente em {dataset_dir} serão usadas no treinamento.")

            return info, str(image_dir)

        except Exception as e:
            logger.error(f"Erro ao processar dataset: {e}")
            return f"❌ Erro ao processar dataset: {e}", ""

    def create_training_config(self,
                               project_name: str,
                               dataset_dir: str,
                               model_choice: str,
                               custom_model_url: str,
                               resolution: int,
                               batch_size: int,
                               epochs: int,
                               learning_rate: float,
                               text_encoder_lr: float,
                               network_dim: int,
                               network_alpha: int,
                               lora_type: str,
                               optimizer: str,
                               scheduler: str,
                               flip_aug: bool,
                               shuffle_caption: bool,
                               keep_tokens: int,
                               clip_skip: int,
                               mixed_precision: str,
                               save_every_n_epochs: int,
                               max_train_steps: int) -> str:
        """Write ``dataset_config.toml`` and ``training_config.toml`` for the project.

        The model must already be downloaded. If *dataset_dir* is empty (for
        example, the Gradio state was lost on page reload), the project's extracted
        ``dataset`` folder is used. Returns a status string.
        """
        try:
            if not (project_name or "").strip():
                return "❌ Nome do projeto é obrigatório"

            safe_name = self._normalize_project_name(project_name)
            if safe_name is None:
                return "❌ Nome do projeto inválido"

            project_dir = self.projects_dir / safe_name
            project_dir.mkdir(exist_ok=True)

            output_dir = project_dir / "output"
            output_dir.mkdir(exist_ok=True)

            log_dir = project_dir / "logs"
            log_dir.mkdir(exist_ok=True)

            try:
                _, model_name = self._model_source(model_choice, custom_model_url)
            except ValueError as e:
                return str(e)

            model_path = self.models_dir / model_name
            if not model_path.exists():
                return f"❌ Modelo não encontrado: {model_name}. Faça o download primeiro."

            if not dataset_dir:
                dataset_dir = project_dir / "dataset"
            if not Path(dataset_dir).is_dir():
                return "❌ Dataset não encontrado. Processe o dataset primeiro."

            dataset_config_path = project_dir / "dataset_config.toml"
            training_config_path = project_dir / "training_config.toml"

            dataset_config, training_config = self._build_configs(
                project_name=safe_name,
                dataset_dir=dataset_dir,
                model_path=model_path,
                output_dir=output_dir,
                log_dir=log_dir,
                dataset_config_path=dataset_config_path,
                resolution=resolution,
                batch_size=batch_size,
                epochs=epochs,
                learning_rate=learning_rate,
                text_encoder_lr=text_encoder_lr,
                network_dim=network_dim,
                network_alpha=network_alpha,
                lora_type=lora_type,
                optimizer=optimizer,
                scheduler=scheduler,
                flip_aug=flip_aug,
                shuffle_caption=shuffle_caption,
                keep_tokens=keep_tokens,
                clip_skip=clip_skip,
                mixed_precision=mixed_precision,
                save_every_n_epochs=save_every_n_epochs,
                max_train_steps=max_train_steps,
            )

            with open(dataset_config_path, 'w', encoding='utf-8') as f:
                toml.dump(dataset_config, f)

            with open(training_config_path, 'w', encoding='utf-8') as f:
                toml.dump(training_config, f)

            return f"✅ Configuração criada!\n📁 Dataset: {dataset_config_path}\n⚙️ Treinamento: {training_config_path}"

        except Exception as e:
            logger.error(f"Erro ao criar configuração: {e}")
            return f"❌ Erro ao criar configuração: {e}"

    def start_training(self, project_name: str) -> str:
        """Launch ``train_network.py`` for *project_name* in a background thread.

        Returns immediately with a status string. The subprocess output is pushed
        line by line onto ``training_output_queue`` and read back via
        ``get_training_output``. Only one run may be active at a time.
        """
        try:
            if not (project_name or "").strip():
                return "❌ Nome do projeto é obrigatório"

            safe_name = self._normalize_project_name(project_name)
            if safe_name is None:
                return "❌ Nome do projeto inválido"

            # A second run would overwrite training_process and share the output
            # queue with the first, leaving the first run unstoppable.
            running = self._training_thread
            if running is not None and running.is_alive():
                return "⚠️ Já existe um treinamento em andamento."

            project_dir = self.projects_dir / safe_name

            training_config_path = project_dir / "training_config.toml"
            if not training_config_path.exists():
                return "❌ Configuração não encontrada. Crie a configuração primeiro."

            train_script = self.sd_scripts_dir / "train_network.py"
            if not train_script.exists():
                return "❌ Script de treinamento não encontrado"

            cmd = [
                sys.executable,
                str(train_script),
                "--config_file", str(training_config_path)
            ]

            logger.info(f"Iniciando treinamento: {' '.join(cmd)}")

            # Each run starts with an empty log in the UI.
            self._drain_output_queue()
            self._training_log = []

            def run_training():
                try:
                    # The with-block closes the stdout pipe and waits for exit.
                    with subprocess.Popen(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        text=True,
                        encoding="utf-8",
                        errors="replace",  # an undecodable byte must not kill the reader
                        bufsize=1,
                        cwd=str(self.sd_scripts_dir)
                    ) as process:
                        self.training_process = process
                        for raw_line in process.stdout:
                            line = raw_line.strip()
                            if not line:
                                continue
                            self.training_output_queue.put(line)
                            logger.info(line)

                    if process.returncode == 0:
                        self.training_output_queue.put("✅ TREINAMENTO CONCLUÍDO COM SUCESSO!")
                    else:
                        self.training_output_queue.put(f"❌ TREINAMENTO FALHOU (código {process.returncode})")

                except Exception as e:
                    self.training_output_queue.put(f"❌ ERRO NO TREINAMENTO: {e}")
                finally:
                    self.training_process = None

            training_thread = threading.Thread(target=run_training, daemon=True)
            self._training_thread = training_thread
            training_thread.start()

            return "🚀 Treinamento iniciado! Acompanhe o progresso abaixo."

        except Exception as e:
            logger.error(f"Erro ao iniciar treinamento: {e}")
            return f"❌ Erro ao iniciar treinamento: {e}"

    def get_training_output(self) -> str:
        """Return the accumulated training log (the last ``MAX_LOG_LINES`` lines).

        New lines are moved from the queue into the log on each call. The UI
        therefore sees the whole recent history rather than only the lines that
        arrived since the previous refresh.
        """
        new_lines = self._drain_output_queue()
        if new_lines:
            self._training_log.extend(new_lines)
            del self._training_log[:-self.MAX_LOG_LINES]

        if self._training_log:
            return "\n".join(self._training_log)

        thread = self._training_thread
        if thread is not None and thread.is_alive():
            return "🔄 Treinamento em andamento..."
        return "⏸️ Nenhum treinamento ativo"

    def stop_training(self) -> str:
        """Terminate the running trainer, escalating to kill after a 10 s grace period."""
        # Read once: the training thread resets this attribute to None when the
        # process exits, which could otherwise happen between check and use.
        process = self.training_process
        try:
            if process is None or process.poll() is not None:
                return "ℹ️ Nenhum treinamento ativo para parar"

            process.terminate()
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
            return "⏹️ Treinamento interrompido"
        except Exception as e:
            return f"❌ Erro ao parar treinamento: {e}"

    def list_output_files(self, project_name: str) -> List[str]:
        """List the project's ``.safetensors`` outputs as ``"name (N MB)"`` strings.

        The strings are sorted in reverse order. Invalid or missing projects yield [].
        """
        try:
            safe_name = self._normalize_project_name(project_name)
            if safe_name is None:
                return []

            output_dir = self.projects_dir / safe_name / "output"
            if not output_dir.exists():
                return []

            files = []
            for file_path in output_dir.rglob("*.safetensors"):
                size_mb = file_path.stat().st_size // (1024 * 1024)
                files.append(f"{file_path.name} ({size_mb} MB)")

            return sorted(files, reverse=True)

        except Exception as e:
            logger.error(f"Erro ao listar arquivos: {e}")
            return []
|
|
|
|
|
|
|
|
# Module-level trainer instance whose methods back the Gradio callbacks in
# create_interface(). Constructing it creates the /tmp/lora_training,
# models/ and projects/ directories.
trainer = LoRATrainerHF()
|
|
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI wired to the module-level ``trainer``.

    Tabs: dependency installation, project/model/dataset setup, training
    hyper-parameters, training control (with a log refresh button), and
    download of the generated ``.safetensors`` files.
    """

    with gr.Blocks(title="LoRA Trainer Funcional - Hugging Face", theme=gr.themes.Soft()) as interface:

        gr.Markdown("""
        # 🎨 LoRA Trainer Funcional para Hugging Face

        **Treine seus próprios modelos LoRA para Stable Diffusion de forma profissional!**

        Esta ferramenta é baseada no kohya-ss sd-scripts e oferece treinamento real e funcional de modelos LoRA.
        """)

        # Per-session dataset directory: set by process_dataset, read by
        # create_training_config.
        dataset_dir_state = gr.State("")

        with gr.Tab("🔧 Instalação"):
            gr.Markdown("### Primeiro, instale as dependências necessárias:")
            install_btn = gr.Button("📦 Instalar Dependências", variant="primary", size="lg")
            install_status = gr.Textbox(label="Status da Instalação", lines=3, interactive=False)

            install_btn.click(
                fn=trainer.install_dependencies,
                outputs=install_status
            )

        with gr.Tab("📁 Configuração do Projeto"):
            with gr.Row():
                project_name = gr.Textbox(
                    label="Nome do Projeto",
                    placeholder="meu_lora_anime",
                    info="Nome único para seu projeto (sem espaços especiais)"
                )

            gr.Markdown("### 📥 Download do Modelo Base")
            with gr.Row():
                model_choice = gr.Dropdown(
                    choices=list(trainer.model_urls.keys()),
                    label="Modelo Base Pré-definido",
                    value="Anime (animefull-final-pruned)",
                    info="Escolha um modelo base ou use URL personalizada"
                )
                custom_model_url = gr.Textbox(
                    label="URL Personalizada (opcional)",
                    placeholder="https://huggingface.co/...",
                    info="URL direta para download de modelo personalizado"
                )

            download_btn = gr.Button("📥 Baixar Modelo", variant="primary")
            download_status = gr.Textbox(label="Status do Download", lines=2, interactive=False)

            gr.Markdown("### 📊 Upload do Dataset")
            gr.Markdown("""
            **Formato do Dataset:**

            - Crie um arquivo ZIP contendo suas imagens

            - Para cada imagem, inclua um arquivo .txt com o mesmo nome contendo as tags/descrições

            - Exemplo: `imagem1.jpg` + `imagem1.txt`
            """)

            dataset_upload = gr.File(
                label="Upload do Dataset (ZIP)",
                file_types=[".zip"]
            )

            process_btn = gr.Button("📊 Processar Dataset", variant="primary")
            dataset_status = gr.Textbox(label="Status do Dataset", lines=4, interactive=False)

        with gr.Tab("⚙️ Parâmetros de Treinamento"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 🖼️ Configurações de Imagem")
                    resolution = gr.Slider(
                        minimum=512, maximum=1024, step=64, value=512,
                        label="Resolução",
                        info="Resolução das imagens (512 = mais rápido, 1024 = melhor qualidade)"
                    )
                    batch_size = gr.Slider(
                        minimum=1, maximum=8, step=1, value=1,
                        label="Batch Size",
                        info="Imagens por lote (aumente se tiver GPU potente)"
                    )
                    flip_aug = gr.Checkbox(
                        label="Flip Augmentation",
                        info="Espelhar imagens para aumentar dataset"
                    )
                    shuffle_caption = gr.Checkbox(
                        value=True,
                        label="Shuffle Caption",
                        info="Embaralhar ordem das tags"
                    )
                    keep_tokens = gr.Slider(
                        minimum=0, maximum=5, step=1, value=1,
                        label="Keep Tokens",
                        info="Número de tokens iniciais que não serão embaralhados"
                    )

                with gr.Column():
                    gr.Markdown("#### 🎯 Configurações de Treinamento")
                    epochs = gr.Slider(
                        minimum=1, maximum=100, step=1, value=10,
                        label="Épocas",
                        info="Número de épocas de treinamento"
                    )
                    max_train_steps = gr.Number(
                        value=0,
                        label="Max Train Steps (0 = usar épocas)",
                        info="Número máximo de steps (deixe 0 para usar épocas)"
                    )
                    save_every_n_epochs = gr.Slider(
                        minimum=1, maximum=10, step=1, value=1,
                        label="Salvar a cada N épocas",
                        info="Frequência de salvamento dos checkpoints"
                    )
                    mixed_precision = gr.Dropdown(
                        choices=["fp16", "bf16", "no"],
                        value="fp16",
                        label="Mixed Precision",
                        info="fp16 = mais rápido, bf16 = mais estável"
                    )
                    clip_skip = gr.Slider(
                        minimum=1, maximum=12, step=1, value=2,
                        label="CLIP Skip",
                        info="Camadas CLIP a pular (2 para anime, 1 para realista)"
                    )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### 📚 Learning Rate")
                    learning_rate = gr.Number(
                        value=1e-4,
                        label="Learning Rate (UNet)",
                        info="Taxa de aprendizado principal"
                    )
                    text_encoder_lr = gr.Number(
                        value=5e-5,
                        label="Learning Rate (Text Encoder)",
                        info="Taxa de aprendizado do text encoder"
                    )
                    scheduler = gr.Dropdown(
                        choices=["cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "linear"],
                        value="cosine_with_restarts",
                        label="LR Scheduler",
                        info="Algoritmo de ajuste da learning rate"
                    )
                    optimizer = gr.Dropdown(
                        choices=["AdamW8bit", "AdamW", "Lion", "SGD"],
                        value="AdamW8bit",
                        label="Otimizador",
                        info="AdamW8bit = menos memória"
                    )

                with gr.Column():
                    gr.Markdown("#### 🧠 Arquitetura LoRA")
                    lora_type = gr.Radio(
                        choices=["LoRA", "LoCon"],
                        value="LoRA",
                        label="Tipo de LoRA",
                        info="LoRA = geral, LoCon = estilos artísticos"
                    )
                    network_dim = gr.Slider(
                        minimum=4, maximum=128, step=4, value=32,
                        label="Network Dimension",
                        info="Dimensão da rede (maior = mais detalhes, mais memória)"
                    )
                    network_alpha = gr.Slider(
                        minimum=1, maximum=128, step=1, value=16,
                        label="Network Alpha",
                        info="Controla a força do LoRA (geralmente dim/2)"
                    )

        with gr.Tab("🚀 Treinamento"):
            create_config_btn = gr.Button("📝 Criar Configuração de Treinamento", variant="primary", size="lg")
            config_status = gr.Textbox(label="Status da Configuração", lines=3, interactive=False)

            with gr.Row():
                start_training_btn = gr.Button("🎯 Iniciar Treinamento", variant="primary", size="lg")
                stop_training_btn = gr.Button("⏹️ Parar Treinamento", variant="stop")
                # Training runs in a background thread, so its log must be pulled
                # explicitly; nothing else refreshes training_output.
                refresh_output_btn = gr.Button("🔄 Atualizar Output", variant="secondary")

            training_output = gr.Textbox(
                label="Output do Treinamento",
                lines=15,
                interactive=False,
                info="Acompanhe o progresso do treinamento em tempo real"
            )

        def update_output():
            """Return the trainer's current log text for the output box."""
            return trainer.get_training_output()

        with gr.Tab("📥 Download dos Resultados"):
            refresh_files_btn = gr.Button("🔄 Atualizar Lista de Arquivos", variant="secondary")

            output_files = gr.Dropdown(
                label="Arquivos LoRA Gerados",
                choices=[],
                info="Selecione um arquivo para download"
            )

            download_file = gr.File(label="Arquivo para Download", interactive=False)

            download_info = gr.Markdown("ℹ️ Os arquivos LoRA estarão disponíveis após o treinamento")

        def refresh_output_files(name):
            """Reload the dropdown's choices from the project's output directory."""
            # list_output_files returns plain strings. Returning them directly to a
            # Dropdown output would set its *value*, so send them as new choices.
            return gr.update(choices=trainer.list_output_files(name), value=None)

        def selected_output_path(name, selection):
            """Map a dropdown entry like 'x.safetensors (12 MB)' to its file path, or None."""
            if not name or not selection:
                return None
            project = name.strip().replace(" ", "_")
            # Never resolve outside the projects directory.
            if not project or project in (".", "..") or "/" in project or "\\" in project:
                return None
            file_name = selection.rsplit(" (", 1)[0]
            output_dir = trainer.projects_dir / project / "output"
            if not output_dir.is_dir():
                return None
            for candidate in output_dir.rglob("*.safetensors"):
                if candidate.name == file_name:
                    return str(candidate)
            return None

        download_btn.click(
            fn=trainer.download_model,
            inputs=[model_choice, custom_model_url],
            outputs=download_status
        )

        process_btn.click(
            fn=trainer.process_dataset,
            inputs=[dataset_upload, project_name],
            outputs=[dataset_status, dataset_dir_state]
        )

        create_config_btn.click(
            fn=trainer.create_training_config,
            inputs=[
                project_name, dataset_dir_state, model_choice, custom_model_url,
                resolution, batch_size, epochs, learning_rate, text_encoder_lr,
                network_dim, network_alpha, lora_type, optimizer, scheduler,
                flip_aug, shuffle_caption, keep_tokens, clip_skip, mixed_precision,
                save_every_n_epochs, max_train_steps
            ],
            outputs=config_status
        )

        start_training_btn.click(
            fn=trainer.start_training,
            inputs=project_name,
            outputs=training_output
        )

        stop_training_btn.click(
            fn=trainer.stop_training,
            outputs=training_output
        )

        refresh_output_btn.click(
            fn=update_output,
            outputs=training_output
        )

        refresh_files_btn.click(
            fn=refresh_output_files,
            inputs=project_name,
            outputs=output_files
        )

        output_files.change(
            fn=selected_output_path,
            inputs=[project_name, output_files],
            outputs=download_file
        )

    return interface
|
|
|
|
|
if __name__ == "__main__":
    print("🚀 Iniciando LoRA Trainer Funcional...")

    # Server options passed to Blocks.launch().
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,            # no public share link
        "show_error": True,        # surface callback errors in the browser
    }

    interface = create_interface()
    interface.launch(**launch_options)
|
|
|
|
|
|