import random

import numpy as np
import torch
import wandb
from datasets import load_dataset
from diffusers import DDIMScheduler
from PIL import Image
from torchvision import transforms

from utils.pipeline_controlnet import LightControlNetPipeline


def image_grid(imgs, rows, cols):
    """Paste `rows * cols` equally sized PIL images into a single grid image."""
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def log_validation(val_dataset, text_encoder, unet, controlnet, args, accelerator):
    """Run the ControlNet pipeline on a few validation samples and log the results to wandb."""
    pipeline = LightControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        controlnet=accelerator.unwrap_model(controlnet, keep_fp32_wrapper=True),
        unet=accelerator.unwrap_model(unet, keep_fp32_wrapper=True).model,
        text_encoder=accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True),
        safety_checker=None,
        revision=args.revision,
    )

    pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    image_logs = []

    for idx in range(args.num_validation_images):
        data = val_dataset[idx]
        validation_prompt = data["text"]
        validation_image = data["conditioning_pixel_values"]

        image = pipeline(
            validation_prompt,
            [validation_image],
            num_inference_steps=50,
            generator=generator,
        )[0][0]

        image_logs.append(
            {
                "validation_image": validation_image,
                "image": image,
                "validation_prompt": validation_prompt,
            }
        )

    for tracker in accelerator.trackers:
        formatted_images = []

        for log in image_logs:
            image = log["image"]
            validation_prompt = log["validation_prompt"]
            validation_image = log["validation_image"]

            formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
            formatted_images.append(wandb.Image(image, caption=validation_prompt))

        tracker.log({"validation": formatted_images})

    del pipeline
    torch.cuda.empty_cache()
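
# Usage sketch for `image_grid` (illustrative only; `image_logs` refers to the
# list built inside `log_validation` above, and the output filename is an
# assumption, not part of the training flow):
#
#     images = [log["image"] for log in image_logs]
#     grid = image_grid(images, rows=1, cols=len(images))
#     grid.save("validation_grid.png")
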
def make_dataset(args, tokenizer, accelerator, split="train"):
    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    elif args.train_data_dir is not None:
        dataset = load_dataset(
            args.train_data_dir,
            cache_dir=args.cache_dir,
        )
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
    else:
        raise ValueError("Either `--dataset_name` or `--train_data_dir` must be provided.")

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    column_names = dataset[split].column_names

    # Get the column names for input/target.
    if args.image_column is None:
        image_column = column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.caption_column is None:
        caption_column = column_names[1]
    else:
        caption_column = args.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    if args.conditioning_image_column is None:
        conditioning_image_column = column_names[2]
    else:
        conditioning_image_column = args.conditioning_image_column
        if conditioning_image_column not in column_names:
            raise ValueError(
                f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
            )

    def tokenize_captions(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if random.random() < args.proportion_empty_prompts:
                captions.append("")
            elif isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # Take a random caption if there are multiple.
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
        inputs = tokenizer(
            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
        )
        return inputs.input_ids

    image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    conditioning_image_transforms = transforms.Compose(
        [
            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(args.resolution),
            transforms.ToTensor(),
        ]
    )

    def preprocess_train(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        images = [image_transforms(image) for image in images]

        conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]]
        conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images]

        examples["pixel_values"] = images
        examples["conditioning_pixel_values"] = conditioning_images
        examples["input_ids"] = tokenize_captions(examples)

        return examples

    with accelerator.main_process_first():
        if args.max_train_samples is not None:
            dataset[split] = dataset[split].shuffle(seed=args.seed).select(range(args.max_train_samples))
        # Set the training transforms.
        split_dataset = dataset[split].with_transform(preprocess_train)

    return split_dataset


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
    conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.stack([example["input_ids"] for example in examples])

    return {
        "pixel_values": pixel_values,
        "conditioning_pixel_values": conditioning_pixel_values,
        "input_ids": input_ids,
    }
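

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training flow): collate two
    # synthetic examples and print the batched shapes. The tensor shapes used
    # here (3x64x64 images, 77 token ids) are illustrative assumptions, not
    # values taken from any training config.
    fake_examples = [
        {
            "pixel_values": torch.randn(3, 64, 64),
            "conditioning_pixel_values": torch.rand(3, 64, 64),
            "input_ids": torch.zeros(77, dtype=torch.long),
        }
        for _ in range(2)
    ]
    batch = collate_fn(fake_examples)
    # Expect: pixel_values (2, 3, 64, 64), conditioning_pixel_values (2, 3, 64, 64), input_ids (2, 77)
    print({k: tuple(v.shape) for k, v in batch.items()})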