import { NextRequest, NextResponse } from 'next/server';
import { spawn } from 'child_process';
import { writeFile } from 'fs/promises';
import path from 'path';
import { tmpdir } from 'os';
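// API route for the Hugging Face Jobs integration: generates a self-contained
// UV training script for ai-toolkit, submits it with the `hf jobs` CLI, and
// inspects submitted jobs.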
export async function POST(request: NextRequest) {
try {
const body = await request.json();
const { action, token, hardware, namespace, jobConfig, datasetRepo } = body;
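    // Supported actions: 'checkStatus', 'generateScript', 'submitJob'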
switch (action) {
case 'checkStatus':
try {
if (!token || !jobConfig?.hf_job_id) {
return NextResponse.json({ error: 'Token and job ID required' }, { status: 400 });
}
const jobStatus = await checkHFJobStatus(token, jobConfig.hf_job_id);
return NextResponse.json({ status: jobStatus });
} catch (error: any) {
console.error('Job status check error:', error);
return NextResponse.json({ error: error.message }, { status: 500 });
}
case 'generateScript':
try {
const uvScript = generateUVScript({
jobConfig,
datasetRepo,
namespace,
token: token || 'YOUR_HF_TOKEN',
});
return NextResponse.json({
script: uvScript,
filename: `train_${jobConfig.config.name.replace(/[^a-zA-Z0-9]/g, '_')}.py`
});
} catch (error: any) {
return NextResponse.json({ error: error.message }, { status: 500 });
}
case 'submitJob':
try {
if (!token || !hardware) {
return NextResponse.json({ error: 'Token and hardware required' }, { status: 400 });
}
// Generate UV script
const uvScript = generateUVScript({
jobConfig,
datasetRepo,
namespace,
token,
});
// Write script to temporary file
const scriptPath = path.join(tmpdir(), `train_${Date.now()}.py`);
await writeFile(scriptPath, uvScript);
// Submit HF job using uv run
const jobId = await submitHFJobUV(token, hardware, scriptPath);
return NextResponse.json({
success: true,
jobId,
message: `Job submitted successfully with ID: ${jobId}`
});
} catch (error: any) {
console.error('Job submission error:', error);
return NextResponse.json({ error: error.message }, { status: 500 });
}
default:
return NextResponse.json({ error: 'Invalid action' }, { status: 400 });
}
} catch (error: any) {
console.error('HF Jobs API error:', error);
return NextResponse.json({ error: error.message }, { status: 500 });
}
}
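// Example client call (a sketch; assumes this route file lives at
// app/api/hf-jobs/route.ts, so adjust the fetch path to match the project):
//
//   await fetch('/api/hf-jobs', {
//     method: 'POST',
//     headers: { 'Content-Type': 'application/json' },
//     body: JSON.stringify({ action: 'generateScript', jobConfig, datasetRepo, namespace }),
//   });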
function generateUVScript({ jobConfig, datasetRepo, namespace, token }: {
jobConfig: any;
datasetRepo: string;
namespace: string;
token: string;
}) {
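  // Builds a self-contained Python script: the "# /// script" header is inline
  // dependency metadata (PEP 723) that uv resolves before running the script.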
return `# /// script
# dependencies = [
# "torch>=2.0.0",
# "torchvision",
# "torchao==0.10.0",
# "safetensors",
# "diffusers @ git+https://github.com/huggingface/diffusers@7a2b78bf0f788d311cc96b61e660a8e13e3b1e63",
# "transformers==4.52.4",
# "lycoris-lora==1.8.3",
# "flatten_json",
# "pyyaml",
# "oyaml",
# "tensorboard",
# "kornia",
# "invisible-watermark",
# "einops",
# "accelerate",
# "toml",
# "albumentations==1.4.15",
# "albucore==0.0.16",
# "pydantic",
# "omegaconf",
# "k-diffusion",
# "open_clip_torch",
# "timm",
# "prodigyopt",
# "controlnet_aux==0.0.10",
# "python-dotenv",
# "bitsandbytes",
# "hf_transfer",
# "lpips",
# "pytorch_fid",
# "optimum-quanto==0.2.4",
# "sentencepiece",
# "huggingface_hub",
# "peft",
# "python-slugify",
# "opencv-python-headless",
# "pytorch-wavelets==1.3.0",
# "matplotlib==3.10.1",
# "setuptools==69.5.1",
# "datasets==4.0.0",
# "pyarrow==20.0.0",
# "pillow",
# "ftfy",
# ]
# ///
import os
import sys
import subprocess
import argparse
import oyaml as yaml
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo, upload_folder, snapshot_download
import tempfile
import shutil
import glob
from PIL import Image
def setup_ai_toolkit():
"""Clone and setup ai-toolkit repository"""
repo_dir = "ai-toolkit"
if not os.path.exists(repo_dir):
print("Cloning ai-toolkit repository...")
subprocess.run(
["git", "clone", "https://github.com/ostris/ai-toolkit.git", repo_dir],
check=True
)
sys.path.insert(0, os.path.abspath(repo_dir))
return repo_dir
def download_dataset(dataset_repo: str, local_path: str):
"""Download dataset from HF Hub as files"""
print(f"Downloading dataset from {dataset_repo}...")
# Create local dataset directory
os.makedirs(local_path, exist_ok=True)
try:
# First try to download as a structured dataset
dataset = load_dataset(dataset_repo, split="train")
# Download images and captions from structured dataset
for i, item in enumerate(dataset):
# Save image
if "image" in item:
image_path = os.path.join(local_path, f"image_{i:06d}.jpg")
image = item["image"]
# Convert RGBA to RGB if necessary (for JPEG compatibility)
if image.mode == 'RGBA':
# Create a white background and paste the RGBA image on it
background = Image.new('RGB', image.size, (255, 255, 255))
background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask
image = background
elif image.mode not in ['RGB', 'L']:
# Convert any other mode to RGB
image = image.convert('RGB')
image.save(image_path, 'JPEG')
# Save caption
if "text" in item:
caption_path = os.path.join(local_path, f"image_{i:06d}.txt")
with open(caption_path, "w", encoding="utf-8") as f:
f.write(item["text"])
print(f"Downloaded {len(dataset)} items to {local_path}")
except Exception as e:
print(f"Failed to load as structured dataset: {e}")
print("Attempting to download raw files...")
# Download the dataset repository as files
temp_repo_path = snapshot_download(repo_id=dataset_repo, repo_type="dataset")
# Copy all image and text files to the local path
print(f"Downloaded repo to: {temp_repo_path}")
print(f"Contents: {os.listdir(temp_repo_path)}")
# Find all image files
image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.webp', '*.bmp', '*.JPG', '*.JPEG', '*.PNG']
image_files = []
for ext in image_extensions:
pattern = os.path.join(temp_repo_path, "**", ext)
found_files = glob.glob(pattern, recursive=True)
image_files.extend(found_files)
print(f"Pattern {pattern} found {len(found_files)} files")
# Find all text files
text_files = glob.glob(os.path.join(temp_repo_path, "**", "*.txt"), recursive=True)
print(f"Found {len(image_files)} image files and {len(text_files)} text files")
        # Copy image files, pairing each image with the caption that shares its
        # filename stem (pairing two independently globbed lists by index can
        # attach captions to the wrong images)
        image_files.sort()
        captions_copied = 0
        for i, img_file in enumerate(image_files):
            dest_path = os.path.join(local_path, f"image_{i:06d}.jpg")
            # Load and convert image if needed
            try:
                with Image.open(img_file) as image:
                    if image.mode == 'RGBA':
                        background = Image.new('RGB', image.size, (255, 255, 255))
                        background.paste(image, mask=image.split()[-1])
                        image = background
                    elif image.mode not in ['RGB', 'L']:
                        image = image.convert('RGB')
                    image.save(dest_path, 'JPEG')
            except Exception as img_error:
                print(f"Error processing image {img_file}: {img_error}")
                continue
            # Copy the matching caption, if one sits next to the image
            txt_file = os.path.splitext(img_file)[0] + ".txt"
            if os.path.exists(txt_file):
                try:
                    shutil.copy2(txt_file, os.path.join(local_path, f"image_{i:06d}.txt"))
                    captions_copied += 1
                except Exception as txt_error:
                    print(f"Error copying text file {txt_file}: {txt_error}")
        print(f"Downloaded {len(image_files)} images and {captions_copied} captions to {local_path}")
def create_config(dataset_path: str, output_path: str):
"""Create training configuration"""
    import json
    # Parse the config embedded at script-generation time. json.loads handles
    # true/false/null natively; the previous eval-plus-string-replacement
    # approach could corrupt any prompt or path containing those words.
    config_str = r"""${JSON.stringify(jobConfig, null, 2)}"""
    config = json.loads(config_str)
# Update paths for cloud environment
config["config"]["process"][0]["datasets"][0]["folder_path"] = dataset_path
config["config"]["process"][0]["training_folder"] = output_path
# Remove sqlite_db_path as it's not needed for cloud training
if "sqlite_db_path" in config["config"]["process"][0]:
del config["config"]["process"][0]["sqlite_db_path"]
# Also change trainer type from ui_trainer to standard trainer to avoid UI dependencies
if config["config"]["process"][0]["type"] == "ui_trainer":
config["config"]["process"][0]["type"] = "sd_trainer"
return config
def upload_results(output_path: str, model_name: str, namespace: str, token: str, config: dict):
"""Upload trained model to HF Hub with README generation and proper file organization"""
import tempfile
import shutil
import glob
import re
import yaml
from datetime import datetime
from huggingface_hub import create_repo, upload_file, HfApi
try:
repo_id = f"{namespace}/{model_name}"
# Create repository
create_repo(repo_id=repo_id, token=token, exist_ok=True)
print(f"Uploading model to {repo_id}...")
# Create temporary directory for organized upload
with tempfile.TemporaryDirectory() as temp_upload_dir:
api = HfApi()
# 1. Find and upload model files to root directory
safetensors_files = glob.glob(os.path.join(output_path, "**", "*.safetensors"), recursive=True)
json_files = glob.glob(os.path.join(output_path, "**", "*.json"), recursive=True)
txt_files = glob.glob(os.path.join(output_path, "**", "*.txt"), recursive=True)
uploaded_files = []
# Upload .safetensors files to root
for file_path in safetensors_files:
filename = os.path.basename(file_path)
print(f"Uploading {filename} to repository root...")
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=filename,
repo_id=repo_id,
token=token
)
uploaded_files.append(filename)
# Upload relevant JSON config files to root (skip metadata.json and other internal files)
config_files_uploaded = []
for file_path in json_files:
filename = os.path.basename(file_path)
# Only upload important config files, skip internal metadata
if any(keyword in filename.lower() for keyword in ['config', 'adapter', 'lora', 'model']):
print(f"Uploading {filename} to repository root...")
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=filename,
repo_id=repo_id,
token=token
)
uploaded_files.append(filename)
config_files_uploaded.append(filename)
# 2. Handle sample images
samples_uploaded = []
samples_dir = os.path.join(output_path, "samples")
if os.path.isdir(samples_dir):
print("Uploading sample images...")
# Create samples directory in repo
for filename in os.listdir(samples_dir):
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
file_path = os.path.join(samples_dir, filename)
repo_path = f"samples/{filename}"
api.upload_file(
path_or_fileobj=file_path,
path_in_repo=repo_path,
repo_id=repo_id,
token=token
)
samples_uploaded.append(repo_path)
# 3. Generate and upload README.md
readme_content = generate_model_card_readme(
repo_id=repo_id,
config=config,
model_name=model_name,
samples_dir=samples_dir if os.path.isdir(samples_dir) else None,
uploaded_files=uploaded_files
)
# Create README.md file and upload to root
readme_path = os.path.join(temp_upload_dir, "README.md")
with open(readme_path, "w", encoding="utf-8") as f:
f.write(readme_content)
print("Uploading README.md to repository root...")
api.upload_file(
path_or_fileobj=readme_path,
path_in_repo="README.md",
repo_id=repo_id,
token=token
)
print(f"Model uploaded successfully to https://huggingface.co/{repo_id}")
print(f"Files uploaded: {len(uploaded_files)} model files, {len(samples_uploaded)} samples, README.md")
except Exception as e:
print(f"Failed to upload model: {e}")
raise e
def generate_model_card_readme(repo_id: str, config: dict, model_name: str, samples_dir: str = None, uploaded_files: list = None) -> str:
"""Generate README.md content for the model card based on AI Toolkit's implementation"""
import re
import yaml
import os
try:
# Extract configuration details
process_config = config.get("config", {}).get("process", [{}])[0]
model_config = process_config.get("model", {})
train_config = process_config.get("train", {})
sample_config = process_config.get("sample", {})
# Gather model info
base_model = model_config.get("name_or_path", "unknown")
trigger_word = process_config.get("trigger_word")
arch = model_config.get("arch", "")
# Determine license based on base model
if "FLUX.1-schnell" in base_model:
license_info = {"license": "apache-2.0"}
elif "FLUX.1-dev" in base_model:
license_info = {
"license": "other",
"license_name": "flux-1-dev-non-commercial-license",
"license_link": "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md"
}
else:
license_info = {"license": "creativeml-openrail-m"}
# Generate tags based on model architecture
tags = ["text-to-image"]
if "xl" in arch.lower():
tags.append("stable-diffusion-xl")
if "flux" in arch.lower():
tags.append("flux")
if "lumina" in arch.lower():
tags.append("lumina2")
if "sd3" in arch.lower() or "v3" in arch.lower():
tags.append("sd3")
# Add LoRA-specific tags
tags.extend(["lora", "diffusers", "template:sd-lora", "ai-toolkit"])
# Generate widgets from sample images and prompts
widgets = []
if samples_dir and os.path.isdir(samples_dir):
sample_prompts = sample_config.get("samples", [])
if not sample_prompts:
# Fallback to old format
sample_prompts = [{"prompt": p} for p in sample_config.get("prompts", [])]
# Get sample image files
sample_files = []
if os.path.isdir(samples_dir):
for filename in os.listdir(samples_dir):
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
                        # Parse filename pattern: timestamp__steps_index.ext
                        match = re.search(r"__(\d+)_(\d+)\.(?:jpg|jpeg|png|webp)$", filename.lower())
if match:
steps, index = int(match.group(1)), int(match.group(2))
# Only use samples from final training step
final_steps = train_config.get("steps", 1000)
if steps == final_steps:
sample_files.append((index, f"samples/{filename}"))
# Sort by index and create widgets
sample_files.sort(key=lambda x: x[0])
for i, prompt_obj in enumerate(sample_prompts):
prompt = prompt_obj.get("prompt", "") if isinstance(prompt_obj, dict) else str(prompt_obj)
if i < len(sample_files):
_, image_path = sample_files[i]
widgets.append({
"text": prompt,
"output": {"url": image_path}
})
# Determine torch dtype based on model
dtype = "torch.bfloat16" if "flux" in arch.lower() else "torch.float16"
# Find the main safetensors file for usage example
main_safetensors = f"{model_name}.safetensors"
if uploaded_files:
safetensors_files = [f for f in uploaded_files if f.endswith('.safetensors')]
if safetensors_files:
main_safetensors = safetensors_files[0]
# Construct YAML frontmatter
frontmatter = {
"tags": tags,
"base_model": base_model,
**license_info
}
if widgets:
frontmatter["widget"] = widgets
if trigger_word:
frontmatter["instance_prompt"] = trigger_word
# Get first prompt for usage example
usage_prompt = trigger_word or "a beautiful landscape"
if widgets:
usage_prompt = widgets[0]["text"]
elif trigger_word:
usage_prompt = trigger_word
# Construct README content
trigger_section = f"You should use \`{trigger_word}\` to trigger the image generation." if trigger_word else "No trigger words defined."
# Build YAML frontmatter string
frontmatter_yaml = yaml.dump(frontmatter, default_flow_style=False, allow_unicode=True, sort_keys=False).strip()
readme_content = f"""---
{frontmatter_yaml}
---
# {model_name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
<Gallery />
## Trigger words
{trigger_section}
## Download model and use it with ComfyUI, AUTOMATIC1111, SD.Next, Invoke AI, etc.
Weights for this model are available in Safetensors format.
[Download](/{repo_id}/tree/main) them in the Files & versions tab.
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
\`\`\`py
from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained('{base_model}', torch_dtype={dtype}).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='{main_safetensors}')
image = pipeline('{usage_prompt}').images[0]
image.save("my_image.png")
\`\`\`
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
"""
return readme_content
except Exception as e:
print(f"Error generating README: {e}")
# Fallback simple README
return f"""# {model_name}
Model trained with [AI Toolkit by Ostris](https://github.com/ostris/ai-toolkit)
## Download model
Weights for this model are available in Safetensors format.
[Download](/{repo_id}/tree/main) them in the Files & versions tab.
"""
def main():
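    """Entry point on the HF Jobs machine: install deps, train, then upload."""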
# Setup environment - token comes from HF Jobs secrets
if "HF_TOKEN" not in os.environ:
raise ValueError("HF_TOKEN environment variable not set")
# Install system dependencies for headless operation
print("Installing system dependencies...")
try:
subprocess.run(["apt-get", "update"], check=True, capture_output=True)
subprocess.run([
"apt-get", "install", "-y",
"libgl1-mesa-glx",
"libglib2.0-0",
"libsm6",
"libxext6",
"libxrender-dev",
"libgomp1",
"ffmpeg"
], check=True, capture_output=True)
print("System dependencies installed successfully")
except subprocess.CalledProcessError as e:
print(f"Failed to install system dependencies: {e}")
print("Continuing without system dependencies...")
# Setup ai-toolkit
toolkit_dir = setup_ai_toolkit()
# Create temporary directories
with tempfile.TemporaryDirectory() as temp_dir:
dataset_path = os.path.join(temp_dir, "dataset")
output_path = os.path.join(temp_dir, "output")
# Download dataset
download_dataset("${datasetRepo}", dataset_path)
# Create config
config = create_config(dataset_path, output_path)
config_path = os.path.join(temp_dir, "config.yaml")
with open(config_path, "w") as f:
yaml.dump(config, f, default_flow_style=False)
# Run training
print("Starting training...")
os.chdir(toolkit_dir)
subprocess.run([
sys.executable, "run.py",
config_path
], check=True)
print("Training completed!")
# Upload results
        model_name = "${jobConfig.config.name}-lora"
upload_results(output_path, model_name, "${namespace}", os.environ["HF_TOKEN"], config)
if __name__ == "__main__":
main()
`;
}
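// Submits the generated script as a detached HF Job via the `hf` CLI and
// resolves with the job ID parsed from the CLI output.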
async function submitHFJobUV(token: string, hardware: string, scriptPath: string): Promise<string> {
return new Promise((resolve, reject) => {
// Ensure token is available
if (!token) {
reject(new Error('HF_TOKEN is required'));
return;
}
console.log('Setting up environment with HF_TOKEN for job submission');
console.log(`Command: hf jobs uv run --flavor ${hardware} --timeout 5h --secrets HF_TOKEN --detach ${scriptPath}`);
// Use hf jobs uv run command with timeout and detach to get job ID
const childProcess = spawn('hf', [
'jobs', 'uv', 'run',
'--flavor', hardware,
'--timeout', '5h',
'--secrets', 'HF_TOKEN',
'--detach',
scriptPath
], {
env: {
...process.env,
HF_TOKEN: token
}
});
let output = '';
let error = '';
childProcess.stdout.on('data', (data) => {
const text = data.toString();
output += text;
console.log('HF Jobs stdout:', text);
});
childProcess.stderr.on('data', (data) => {
const text = data.toString();
error += text;
console.log('HF Jobs stderr:', text);
});
childProcess.on('close', (code) => {
console.log('HF Jobs process closed with code:', code);
console.log('Full output:', output);
console.log('Full error:', error);
if (code === 0) {
// With --detach flag, the output should be just the job ID
const fullText = (output + ' ' + error).trim();
// Updated patterns to handle variable-length hex job IDs (16-24+ characters)
const jobIdPatterns = [
/Job started with ID:\s*([a-f0-9]{16,})/i, // "Job started with ID: 68b26b73767540db9fc726ac"
/job\s+([a-f0-9]{16,})/i, // "job 68b26b73767540db9fc726ac"
/Job ID:\s*([a-f0-9]{16,})/i, // "Job ID: 68b26b73767540db9fc726ac"
/created\s+job\s+([a-f0-9]{16,})/i, // "created job 68b26b73767540db9fc726ac"
/submitted.*?job\s+([a-f0-9]{16,})/i, // "submitted ... job 68b26b73767540db9fc726ac"
/https:\/\/huggingface\.co\/jobs\/[^\/]+\/([a-f0-9]{16,})/i, // URL pattern
/([a-f0-9]{20,})/i, // Fallback: any 20+ char hex string
];
let jobId = 'unknown';
for (const pattern of jobIdPatterns) {
const match = fullText.match(pattern);
if (match && match[1] && match[1] !== 'started') {
jobId = match[1];
console.log(`Extracted job ID using pattern: ${pattern.toString()} -> ${jobId}`);
break;
}
}
resolve(jobId);
} else {
reject(new Error(error || output || 'Failed to submit job'));
}
});
childProcess.on('error', (err) => {
console.error('HF Jobs process error:', err);
reject(new Error(`Process error: ${err.message}`));
});
});
}
async function checkHFJobStatus(token: string, jobId: string): Promise<any> {
return new Promise((resolve, reject) => {
console.log(`Checking HF Job status for: ${jobId}`);
const childProcess = spawn('hf', [
'jobs', 'inspect', jobId
], {
env: {
...process.env,
HF_TOKEN: token
}
});
let output = '';
let error = '';
childProcess.stdout.on('data', (data) => {
const text = data.toString();
output += text;
});
childProcess.stderr.on('data', (data) => {
const text = data.toString();
error += text;
});
childProcess.on('close', (code) => {
if (code === 0) {
try {
// Parse the JSON output from hf jobs inspect
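          // (Shape assumed from the fields read below:
          //  [{ id, status: { stage, message }, created_at, flavor, url }])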
const jobInfo = JSON.parse(output);
if (Array.isArray(jobInfo) && jobInfo.length > 0) {
const job = jobInfo[0];
resolve({
id: job.id,
status: job.status?.stage || 'UNKNOWN',
message: job.status?.message,
created_at: job.created_at,
flavor: job.flavor,
url: job.url,
});
} else {
reject(new Error('Invalid job info response'));
}
} catch (parseError: any) {
console.error('Failed to parse job status:', parseError, output);
reject(new Error('Failed to parse job status'));
}
} else {
reject(new Error(error || output || 'Failed to check job status'));
}
});
childProcess.on('error', (err) => {
console.error('HF Jobs inspect process error:', err);
reject(new Error(`Process error: ${err.message}`));
});
});
}