|
|
|
|
|
|
|
|
import argparse |
|
|
import datetime |
|
|
import numpy as np |
|
|
import time |
|
|
import torch |
|
|
import torch.backends.cudnn as cudnn |
|
|
import json |
|
|
from pathlib import Path |
|
|
from torch.cuda.amp import GradScaler, autocast |
|
|
from torchvision.utils import make_grid |
|
|
import importlib |
|
|
import logging |
|
|
import torch.nn as nn |
|
|
|
|
|
import sys |
|
|
from typing import Iterable, Optional |
|
|
import logging |
|
|
import torch.distributed as dist |
|
|
import os |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
from tqdm import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
from utils.util import show_params, visualize_features, init_distributed_mode, get_rank, get_world_size |
|
|
from dataset.build_dataset import CustomCocoDataset |
|
|
|
|
|
from omegaconf import OmegaConf |
|
|
from model import ControlLDM, Diffusion |
|
|
from utils.common import instantiate_from_config |
|
|
from utils.sampler import SpacedSampler |
|
|
|
|
|
|
|
|
def _str2bool(value):
    """Parse a command-line boolean (``true``/``false``, ``1``/``0``, ``yes``/``no``, ``on``/``off``).

    Used as an argparse ``type=`` so that ``--sd_locked False`` produces the bool
    ``False`` instead of the (truthy) string ``"False"``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the help-less argument parser holding every training option.

    The parser is created with ``add_help=False`` so it can be passed as a
    ``parents=[...]`` entry of the real command-line parser. Many options are
    inherited from the DeiT training script and are not referenced directly in
    this file; they may still be consumed by code that receives the whole
    ``args`` namespace (e.g. ``MetaArch(args)``) — verify before removing any.

    Returns:
        argparse.ArgumentParser: the populated parser.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--bce-loss', action='store_true')
    parser.add_argument('--unscale-lr', action='store_true')

    # Model parameters
    parser.add_argument('--model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--target_model', default='deit_base_patch16_224', type=str)
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # ControlNet / diffusion training parameters
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument('--total_train_steps', default=300000, type=int)
    parser.add_argument('--resume_path', default='', help='resume from checkpoint for controlnet')
    parser.add_argument('--cldm_learning_rate', default=1e-4, type=float, help='learning rate for controlnet')
    # Parsed with _str2bool: without a type, "--sd_locked False" would yield the truthy string "False".
    parser.add_argument('--sd_locked', default=True, type=_str2bool,
                        help='whether to lock the sd of controlnet')
    parser.add_argument('--only_mid_control', default=False, type=_str2bool,
                        help='only control the middle layers of controlnet')
    parser.add_argument('--cldm_yaml', default='./models/cldm_v15.yaml', help='yaml file for controlnet')
    parser.add_argument('--exp_dir', default='./exp', help='experiment directory')
    parser.add_argument('--image_floder', default='./data', help='training image floder')
    parser.add_argument("--log_every", default=20, type=int, help="log every n steps")
    parser.add_argument("--ckpt_every", default=1000, type=int, help="save checkpoint every n steps")
    parser.add_argument("--image_every", default=1000, type=int, help="log image every n steps")

    parser.add_argument('--global_step', default=0, type=int, help='global step')
    parser.add_argument('--ddim_steps', default=50, type=int, help='ddim steps')
    parser.add_argument('--eta', default=0.0, type=float, help='ddim eta')
    parser.add_argument('--unconditional_guidance_scale', default=1.25, type=float,
                        help='unconditional guidance scale')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=4e-4, metavar='LR',
                        help='learning rate (default: 4e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
                        help='Color jitter factor (default: 0.3)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    parser.add_argument('--train-mode', action='store_true')
    parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
    parser.set_defaults(train_mode=True)

    parser.add_argument('--ThreeAugment', action='store_true')

    parser.add_argument('--src', action='store_true')

    # Crop / patch geometry
    parser.add_argument('--global_crops_size', '--img_size', default=224, type=int,
                        help="this should be equal to image size")
    parser.add_argument('--patch_size', default=16, type=int,
                        help="patch size for vit patch embedding")

    # Masking parameters
    parser.add_argument('--mask_ratio', default=(0.1, 0.5), type=float, nargs='+',
                        help="mask ratio can be either a value or a range")
    parser.add_argument('--mask_probability', default=0., type=float,
                        help="how many samples with be applied with masking")
    parser.add_argument('--mask_first_n', action='store_true',
                        help="mask the first n sample to avoid shuffling. Needed for MAE-style encoder")
    parser.add_argument('--clone_batch', default=1, type=int,
                        help="how many times to clone the batch for masking (default: 1, not cloning)")

    # Random erase parameters
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # Mixup / cutmix parameters
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='base', type=str)
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
    parser.add_argument('--lambda_token', type=float, default=1.0)
    parser.add_argument('--lambda_fea', type=float, default=1.0)
    parser.add_argument('--lambda_patch', type=float, default=1.0)

    parser.add_argument('--cosub', action='store_true')

    # Finetuning parameters
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--attn-only', action='store_true')
    parser.add_argument('--weight_inherit', default='')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET',
                        choices=['CIFAR', 'IMNET', 'IMNET_ibot', 'IMNET_ibot_aug', 'IMNET_ibot_fast_aug',
                                 'INAT', 'INAT19', 'IMNET_L', 'IMNET_L_ibot'],
                        type=str, help='Image Net dataset path')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    # Runtime / output parameters
    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='/data1/qiyp/Proteus-pytorch/pretrain/log/DINOv2_training/log',
                        type=str, help='saving logging info every 20 iters')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=231, type=int)
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    return parser
|
|
|
|
|
def setup_logger(log_dir, rank=0):
    """Configure the root logger to write to a file and to stdout on the main process.

    Args:
        log_dir: Path of the log *file* (the value is handed straight to
            ``logging.FileHandler``, despite the parameter name).
        rank: Process rank. Only rank 0 configures logging; every other rank
            returns without touching the logging setup.
    """
    if rank == 0:
        shared_format = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
        root = logging.getLogger()
        root.setLevel(logging.INFO)

        # File first, then console — both share the same formatter.
        for handler in (logging.FileHandler(log_dir, encoding='utf-8'),
                        logging.StreamHandler(sys.stdout)):
            handler.setFormatter(shared_format)
            root.addHandler(handler)

        logging.info('Logging file is %s' % log_dir)
|
|
|
|
|
def _build_train_loader(args):
    """Build the training dataset, its sampler and the DataLoader.

    Returns:
        (sampler_train, dataloader). ``sampler_train`` is a ``DistributedSampler``
        when ``args.distributed`` is set, otherwise a ``RandomSampler``.
    """
    dataset = CustomCocoDataset(args.image_floder)
    logging.info(f"Loaded train dataset with {len(dataset)} samples")

    if args.distributed:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset, num_replicas=get_world_size(), rank=get_rank(), shuffle=True
        )
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset)
    logging.info("Sampler_train = %s" % str(sampler_train))

    dataloader = DataLoader(
        dataset, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers,
        pin_memory=args.pin_mem, drop_last=True)
    logging.info('Dataloader created')
    return sampler_train, dataloader


def _apply_attn_only(model):
    """Leave only ``.attn.`` params, the head/fc classifier and ``pos_embed`` trainable."""
    for name_p, p in model.named_parameters():
        p.requires_grad = '.attn.' in name_p
    try:
        model.head.weight.requires_grad = True
        model.head.bias.requires_grad = True
    except AttributeError:
        # No usable `head`; fall back to a `fc` classifier (raises if that is missing too).
        model.fc.weight.requires_grad = True
        model.fc.bias.requires_grad = True
    try:
        model.pos_embed.requires_grad = True
    except AttributeError:
        print('no position encoding')
    try:
        for p in model.patch_embed.parameters():
            p.requires_grad = False
    except AttributeError:
        print('no patch embed')


def _build_feature_extractor(args, device):
    """Instantiate ``MetaArch`` from module ``args.model``, optionally load weights, move to ``device``.

    The model is returned in eval mode; it is only used under ``torch.no_grad``
    in the training loop.
    """
    logging.info(f"# ===== Creating Feature Extractor: {args.model} ===== #")
    meta_arch_module = importlib.import_module(args.model)
    model = meta_arch_module.MetaArch(args)
    logging.info("Model = %s" % str(model))

    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')
        # Accept checkpoints that wrap the weights under 'state_dict' or 'model', or a bare state dict.
        if 'state_dict' in checkpoint:
            pretrained_dict = checkpoint['state_dict']
        elif 'model' in checkpoint:
            pretrained_dict = checkpoint['model']
        else:
            pretrained_dict = checkpoint
        missing_keys, unexpected_keys = model.load_state_dict(pretrained_dict, False)
        logging.info('Finetuning from %s' % args.finetune)
        logging.info('missing_keys: %s' % str(missing_keys))
        logging.info('unexpected_keys: %s' % str(unexpected_keys))

    if args.attn_only:
        _apply_attn_only(model)

    model.to(device)
    model.eval()
    return model


def _build_cldm(args, cfg, device):
    """Build ControlLDM, load SD weights, and initialize or resume its ControlNet.

    Returns:
        (cldm, cldm_without_ddp, resume_ckpt, global_step, epoch). ``cldm`` is
        DDP-wrapped when ``args.distributed`` is set; ``resume_ckpt`` is the
        loaded resume checkpoint dict, or ``None`` when ``cfg.train.resume`` is
        empty. Missing 'global_step'/'epoch' keys in the checkpoint default to 0.
    """
    cldm: ControlLDM = instantiate_from_config(cfg.model.cldm)
    sd = torch.load(cfg.train.sd_path, map_location="cpu")["state_dict"]
    unused = cldm.load_pretrained_sd(sd)
    logging.info(f"strictly load pretrained SD weight from {cfg.train.sd_path}\n"
                 f"unused weights: {unused}")

    global_step, epoch = 0, 0
    resume_ckpt = None
    if cfg.train.resume:
        # Loaded once here; optimizer/scaler state are restored from the same dict later.
        resume_ckpt = torch.load(cfg.train.resume, map_location="cpu")
        cldm.load_controlnet_from_ckpt(resume_ckpt['model_state_dict'])
        logging.info(f"strictly load controlnet weight from checkpoint: {cfg.train.resume}")
        global_step = resume_ckpt.get('global_step', 0)
        epoch = resume_ckpt.get('epoch', 0)
        logging.info(f"Resumed from global step {global_step}, epoch {epoch}")
    else:
        init_with_new_zero, init_with_scratch = cldm.load_controlnet_from_unet()
        logging.info(f"strictly load controlnet weight from pretrained SD\n"
                     f"weights initialized with newly added zeros: {init_with_new_zero}\n"
                     f"weights initialized from scratch: {init_with_scratch}")

    cldm = cldm.to(device)
    if args.distributed:
        cldm = torch.nn.parallel.DistributedDataParallel(cldm, device_ids=[args.gpu], find_unused_parameters=True)
        cldm_without_ddp = cldm.module
    else:
        cldm_without_ddp = cldm
    return cldm, cldm_without_ddp, resume_ckpt, global_step, epoch


def _extract_condition_features(model, hint):
    """Encode ``hint`` with the frozen backbone + info bottleneck into a normalized 4-D feature map."""
    _, _, h, w = hint.shape
    features_dict = model.student.backbone(hint, is_training=True)
    features = features_dict['x_norm_patchtokens']
    features, _ = model.info_bottleneck(features, is_training=False)
    # Token sequence -> (-1, h//14, w//14, C) grid. The stride 14 is hard-coded
    # (presumably a ViT-*/14 backbone — confirm; --patch_size is not consulted).
    features = features.view(-1, h // 14, w // 14, features.shape[2])
    features = features.permute(0, 3, 1, 2)
    # Standardize with statistics over the whole batch tensor, then clip outliers.
    features = (features - features.mean()) / features.std()
    return torch.clamp(features, -5, 5)


@torch.no_grad()
def _log_samples(writer, cldm, sampler, features, prompt, gt, hint, latent_shape,
                 device, steps, cfg_scale, global_step, n=1):
    """Sample the first ``n`` items with and without guidance and log images to TensorBoard.

    ``cldm`` is the unwrapped ControlLDM; it is put in eval mode for sampling and
    switched back to train mode afterwards.
    """
    cldm.eval()
    log_features = features[:n]
    log_cond = cldm.prepare_condition(log_features, prompt[:n])
    # All-zero features serve as the unconditional branch for guidance.
    log_uncond = cldm.prepare_condition(torch.zeros_like(log_features), prompt[:n])
    log_gt = gt[:n]
    log_hint = hint[:n]

    with autocast(dtype=torch.bfloat16):
        z = sampler.sample(
            model=cldm, device=device, steps=steps, batch_size=len(log_gt), x_size=latent_shape,
            cond=log_cond, uncond=None, cfg_scale=1.0, x_T=None
        )
        # Presumably maps decoder output from [-1, 1] to [0, 1] for display — confirm.
        x_samples = (cldm.vae_decode(z) + 1) / 2

        z_cfg = sampler.sample(
            model=cldm, device=device, steps=steps, batch_size=len(log_gt), x_size=latent_shape,
            cond=log_cond, uncond=log_uncond, cfg_scale=cfg_scale, x_T=None
        )
        x_samples_cfg = (cldm.vae_decode(z_cfg) + 1) / 2

    vis_features = visualize_features(log_features)
    for tag, image in [
        ("image/samples", x_samples),
        ("image/samples_cfg", x_samples_cfg),
        ("image/gt", log_gt), ("image/condition", vis_features),
        ("image/hint", log_hint),
    ]:
        writer.add_image(tag, make_grid(image.to(torch.float32), nrow=1), global_step)
    cldm.train()


def main(args):
    """Train a ControlNet (ControlLDM) conditioned on frozen backbone features.

    Sets up the (optionally distributed) runtime and logging, builds the data
    pipeline, the frozen feature extractor and the ControlLDM/diffusion pair,
    then optimizes the ControlNet until ``args.total_train_steps`` global steps.
    Rank 0 logs losses/images to TensorBoard and saves ControlNet checkpoints
    under ``<exp_dir>/checkpoints``.

    Args:
        args: Namespace produced by ``get_args_parser()``.

    Raises:
        ValueError: if the dataloader yields no batch (the loop could never advance).
    """
    init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)
    rank = dist.get_rank() if dist.is_initialized() else 0

    os.makedirs(args.log_dir, exist_ok=True)
    setup_logger(os.path.join(args.log_dir, time.strftime('%Y%m%d_%H%M%S') + '.log'), rank)
    logging.info('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    logging.info("{}".format(args).replace(', ', ',\n') + '\n')

    # Offset the seed by rank so each process draws a different RNG stream.
    seed = args.seed + get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    exp_dir = args.exp_dir
    ckpt_dir = os.path.join(exp_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)  # also creates exp_dir
    logging.info(f"Experiment directory created at {exp_dir}")

    sampler_train, dataloader = _build_train_loader(args)
    if len(dataloader) == 0:
        # With drop_last=True a dataset smaller than one batch yields nothing and
        # the step-based while-loop below would spin forever.
        raise ValueError(
            f"Training dataloader is empty ({len(dataloader.dataset)} samples, "
            f"batch_size={args.batch_size}, drop_last=True)")

    model = _build_feature_extractor(args, device)

    cfg = OmegaConf.load(args.config)
    cldm, cldm_without_ddp, resume_ckpt, global_step, epoch = _build_cldm(args, cfg, device)
    diffusion: Diffusion = instantiate_from_config(cfg.model.diffusion)

    # Only ControlNet parameters are optimized.
    optimizer = torch.optim.AdamW(cldm_without_ddp.controlnet.parameters(), lr=args.cldm_learning_rate)
    if resume_ckpt is not None and 'optimizer_state_dict' in resume_ckpt:
        optimizer.load_state_dict(resume_ckpt['optimizer_state_dict'])
        logging.info(f"Optimizer state loaded from checkpoint: {cfg.train.resume}")

    show_params(cldm_without_ddp.controlnet)

    ddim_sampler = SpacedSampler(diffusion.betas)
    logging.info("# ========== Building model done! ========== #")

    cldm_without_ddp.train()
    diffusion.to(device)

    scaler = GradScaler(enabled=True)
    if resume_ckpt is not None and 'scaler_state_dict' in resume_ckpt:
        scaler.load_state_dict(resume_ckpt['scaler_state_dict'])
        logging.info(f"Loss Scaler state loaded from checkpoint: {cfg.train.resume}")
    del resume_ckpt  # release the CPU copy of the resumed state

    max_steps = args.total_train_steps
    step_loss = []
    epoch_loss = []
    writer = SummaryWriter(exp_dir) if rank == 0 else None

    logging.info(f"Training for {max_steps} steps...")
    while global_step < max_steps:
        if args.distributed:
            # Must precede creating this epoch's iterator, otherwise the shuffle lags one epoch.
            sampler_train.set_epoch(epoch)
        pbar = tqdm(total=len(dataloader), unit="batch", disable=rank != 0)
        for batch in dataloader:
            gt = batch["jpg"].to(device)
            hint = batch["hint"].to(device)
            prompt = batch["txt"]

            with torch.no_grad():
                with autocast(dtype=torch.bfloat16):
                    # Presumably rescales gt from [0, 1] to [-1, 1] — confirm dataset range.
                    z_0 = cldm_without_ddp.vae_encode(2 * gt - 1)
                    features_not_zero = _extract_condition_features(model, hint)
                    features = features_not_zero
                    # With probability 0.1 the whole batch gets an all-zero condition,
                    # matching the zero-feature `uncond` used for guided sampling.
                    if torch.rand(1).item() < 0.1:
                        features = torch.zeros_like(features)
                    cond = cldm_without_ddp.prepare_condition(features, prompt)

            with autocast(dtype=torch.bfloat16):
                t = torch.randint(0, diffusion.num_timesteps, (z_0.shape[0],), device=device).long()
                # Forward through `cldm` (the DDP wrapper when distributed) so that
                # backward() all-reduces gradients across ranks.
                loss = diffusion.p_losses(cldm, z_0, t, cond)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            global_step += 1
            loss_value = loss.item()  # single host sync per step
            step_loss.append(loss_value)
            epoch_loss.append(loss_value)
            pbar.update(1)
            pbar.set_description(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {loss_value:.6f}")

            if global_step % args.log_every == 0:
                if rank == 0:
                    writer.add_scalar("loss/loss_simple_step", np.mean(step_loss), global_step)
                # Cleared on every rank so the buffer cannot grow unbounded off rank 0.
                step_loss.clear()

            if rank == 0 and global_step % args.ckpt_every == 0:
                checkpoint = {
                    'model_state_dict': cldm_without_ddp.controlnet.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'global_step': global_step,
                    'epoch': epoch,
                }
                torch.save(checkpoint, os.path.join(ckpt_dir, f"{global_step:07d}.pt"))

            if rank == 0 and (global_step % args.image_every == 0 or global_step <= 1):
                _log_samples(
                    writer, cldm_without_ddp, ddim_sampler, features_not_zero, prompt, gt, hint,
                    latent_shape=z_0.shape[1:], device=device, steps=args.ddim_steps,
                    cfg_scale=args.unconditional_guidance_scale, global_step=global_step,
                )

            if global_step == max_steps:
                break

        pbar.close()
        epoch += 1

        if rank == 0:
            avg_epoch_loss = np.mean(epoch_loss)
            writer.add_scalar("loss/loss_simple_epoch", avg_epoch_loss, global_step)
            logging.info(f"Epoch: {epoch:04d}, Global Step: {global_step:07d}, Loss: {avg_epoch_loss:.6f}")
        epoch_loss.clear()

    logging.info("done!")
    if writer is not None:
        writer.close()
    if args.distributed:
        dist.barrier()
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Top-level CLI parser: inherits every shared option and adds -h/--help.
    cli_parser = argparse.ArgumentParser('DeiT training and evaluation script',
                                         parents=[get_args_parser()])
    args = cli_parser.parse_args()

    # Pre-create the output directory only when one was requested.
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
|
|
|