In [1]:
# Download the InstructPix2Pix checkpoint once; later runs reuse the local copy.
# Pure-Python download (instead of the IPython-only `!wget`) so this also works
# outside Jupyter and on machines without wget installed.
import os
import urllib.request

CKPT_PATH = "instruct-pix2pix-00-22000.ckpt"
CKPT_URL = "https://huggingface.co/timbrooks/instruct-pix2pix/resolve/main/instruct-pix2pix-00-22000.ckpt"

if os.path.exists(CKPT_PATH):
    print("Model already downloaded.")
else:
    # Download to a temporary name and only rename on success, so an interrupted
    # download is never mistaken for a complete checkpoint on the next run.
    partial_path = CKPT_PATH + ".part"
    try:
        urllib.request.urlretrieve(CKPT_URL, partial_path)
    except BaseException:
        # Clean up the incomplete file (also on KeyboardInterrupt), then re-raise.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
    os.replace(partial_path, CKPT_PATH)
    print(f"Downloaded {CKPT_PATH}.")

Model already downloaded.


In [2]:
import load_model

# Build the models from the base InstructPix2Pix checkpoint and load the
# mask-free fine-tuned weights from the .safetensors file; the inference cell
# below passes the resulting `models` object into the pipeline.
models = load_model.preload_models_from_standard_weights(
    ckpt_path="instruct-pix2pix-00-22000.ckpt",
    device="cuda",
    finetune_weights_path="maskfree_finetuned_weights.safetensors",
)

Loaded finetuned weights from maskfree_finetuned_weights.safetensors
Loading 0.in_proj.weight
Loading 0.out_proj.weight
Loading 0.out_proj.bias
Loading 8.in_proj.weight
Loading 8.out_proj.weight
Loading 8.out_proj.bias
Loading 16.in_proj.weight
Loading 16.out_proj.weight
Loading 16.out_proj.bias
Loading 24.in_proj.weight
Loading 24.out_proj.weight
Loading 24.out_proj.bias
Loading 32.in_proj.weight
Loading 32.out_proj.weight
Loading 32.out_proj.bias
Loading 40.in_proj.weight
Loading 40.out_proj.weight
Loading 40.out_proj.bias
Loading 48.in_proj.weight
Loading 48.out_proj.weight
Loading 48.out_proj.bias
Loading 56.in_proj.weight
Loading 56.out_proj.weight
Loading 56.out_proj.bias
Loading 64.in_proj.weight
Loading 64.out_proj.weight
Loading 64.out_proj.bias
Loading 72.in_proj.weight
Loading 72.out_proj.weight
Loading 72.out_proj.bias
Loading 80.in_proj.weight
Loading 80.out_proj.weight
Loading 80.out_proj.bias
Loading 88.in_proj.weight
Loading 88.out_proj.weight
Loading 88.out_proj.bias
L

In [3]:
import os
import torch
import argparse
from torch.utils.data import DataLoader
from VITON_Dataset import VITONHDTestDataset
from tqdm import tqdm
from PIL import Image
from CatVTON_model import CatVTONPix2PixPipeline

from utils import to_pil_image

def _concat_side_by_side(person_image, cloth_image, result):
    """Return a new RGB image with person, cloth and result pasted left to right.

    The canvas is three times the width of ``result`` and as tall as ``result``;
    each input is pasted at an x-offset that is a multiple of ``result``'s width.
    """
    w, h = result.size
    canvas = Image.new('RGB', (w * 3, h))
    canvas.paste(person_image, (0, 0))
    canvas.paste(cloth_image, (w, 0))
    canvas.paste(result, (w * 2, 0))
    return canvas


@torch.no_grad()
def main():
    """Run mask-free CatVTON inference over the sample VITON-HD dataset and save results.

    Builds a fixed configuration and wraps the global ``models`` in a
    ``CatVTONPix2PixPipeline``. It then runs the pipeline on every (person, cloth)
    pair yielded by ``VITONHDTestDataset`` and writes one image per person to
    ``<output_dir>/<dataset_name>-<height>/<paired|unpaired>/<person_name>``.
    When ``concat_eval_results`` is set, each saved image is the person, cloth and
    result pasted side by side.

    Relies on the module-level ``models`` created by the model-loading cell, which
    must therefore be executed first.

    Raises:
        ValueError: if ``dataset_name`` is not a supported dataset.
    """
    # NOTE(review): the last recorded run of this cell failed inside the dataset
    # with "File ./sample_dataset/samples_pairs.txt does not exist." -- confirm the
    # file layout VITONHDTestDataset expects under data_root_path.
    args = argparse.Namespace(
        dataset_name="vitonhd",
        data_root_path="./sample_dataset",
        output_dir="./mask-free-output",
        seed=555,
        batch_size=1,
        num_inference_steps=50,
        guidance_scale=2.5,
        width=384,
        height=512,
        eval_pair=False,
        concat_eval_results=True,
        allow_tf32=True,
        dataloader_num_workers=4,
        mixed_precision='no',
        concat_axis='y',
        enable_condition_noise=True,
        is_train=False,
    )

    # Pipeline
    pipeline = CatVTONPix2PixPipeline(
        weight_dtype={
            "no": torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[args.mixed_precision],
        device="cuda",
        skip_safety_check=True,
        models=models,  # global from the model-loading cell
    )

    # Dataset
    if args.dataset_name == "vitonhd":
        dataset = VITONHDTestDataset(args)
    else:
        # Was `args.dataset`, which does not exist and raised AttributeError instead.
        raise ValueError(f"Invalid dataset name {args.dataset_name}.")
    print(f"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.")
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.dataloader_num_workers
    )

    # Inference
    # One seeded generator is shared across batches, so results are reproducible
    # only when batches are processed in the same order (shuffle=False above).
    generator = torch.Generator(device='cuda').manual_seed(args.seed)
    args.output_dir = os.path.join(
        args.output_dir,
        f"{args.dataset_name}-{args.height}",
        "paired" if args.eval_pair else "unpaired",
    )
    os.makedirs(args.output_dir, exist_ok=True)

    for batch in tqdm(dataloader):
        person_images = batch['person']
        cloth_images = batch['cloth']

        results = pipeline(
            person_images,
            cloth_images,
            num_inference_steps=args.num_inference_steps,
            guidance_scale=args.guidance_scale,
            height=args.height,
            width=args.width,
            generator=generator,
        )

        if args.concat_eval_results:
            # Convert the batch inputs once so they can be pasted next to each result.
            person_images = to_pil_image(person_images)
            cloth_images = to_pil_image(cloth_images)
        for i, result in enumerate(results):
            person_name = batch['person_name'][i]
            output_path = os.path.join(args.output_dir, person_name)
            # person_name may contain sub-directories; create them as needed.
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            if args.concat_eval_results:
                result = _concat_side_by_side(person_images[i], cloth_images[i], result)
            result.save(output_path)

# Inside a notebook kernel the module name is always "__main__", so executing
# this cell runs main(); the guard only matters if the code is exported to a
# .py module and imported elsewhere.
if "__main__" == __name__:
    main()

  from .autonotebook import tqdm as notebook_tqdm


AssertionError: File ./sample_dataset/samples_pairs.txt does not exist.