Spaces:
Running
Running
model training script added, removed unnecessary code.
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +6 -0
- .vscode/launch.json +16 -0
- interface.py +0 -151
- load_model.py +75 -7
- logs.txt +29 -0
- output/vitonhd-512/unpaired/00654_00.jpg +0 -0
- output/vitonhd-512/unpaired/01265_00.jpg +0 -0
- output/vitonhd-512/unpaired/01985_00.jpg +0 -0
- output/vitonhd-512/unpaired/02023_00.jpg +0 -0
- output/vitonhd-512/unpaired/02532_00.jpg +0 -0
- output/vitonhd-512/unpaired/02944_00.jpg +0 -0
- output/vitonhd-512/unpaired/03191_00.jpg +0 -0
- output/vitonhd-512/unpaired/03921_00.jpg +0 -0
- output/vitonhd-512/unpaired/05006_00.jpg +0 -0
- output/vitonhd-512/unpaired/05378_00.jpg +0 -0
- output/vitonhd-512/unpaired/07342_00.jpg +0 -0
- output/vitonhd-512/unpaired/08088_00.jpg +0 -0
- output/vitonhd-512/unpaired/08239_00.jpg +0 -0
- output/vitonhd-512/unpaired/08650_00.jpg +0 -0
- output/vitonhd-512/unpaired/08839_00.jpg +0 -0
- output/vitonhd-512/unpaired/11085_00.jpg +0 -0
- output/vitonhd-512/unpaired/12345_00.jpg +0 -0
- output/vitonhd-512/unpaired/12419_00.jpg +0 -0
- output/vitonhd-512/unpaired/12562_00.jpg +0 -0
- output/vitonhd-512/unpaired/14651_00.jpg +0 -0
- pipeline.py +0 -314
- requirements.txt +36 -61
- sample_inference.ipynb +435 -0
- test.ipynb +0 -1430
- trained_output/vitonhd-384/unpaired/00654_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/01265_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/01985_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/02023_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/02532_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/02944_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/03191_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/03921_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/05006_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/05378_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/07342_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/08088_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/08239_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/08650_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/08839_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/11085_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/12345_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/12419_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/12562_00.jpg +0 -0
- trained_output/vitonhd-384/unpaired/14651_00.jpg +0 -0
- training.ipynb +326 -413
.gitignore
CHANGED
|
@@ -2,6 +2,12 @@
|
|
| 2 |
*sd-v1-5-inpainting.ckpt
|
| 3 |
*zalando-hd-resized.zip
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
# Byte-compiled / optimized / DLL files
|
| 6 |
__pycache__/
|
| 7 |
**/__pycache__/
|
|
|
|
| 2 |
*sd-v1-5-inpainting.ckpt
|
| 3 |
*zalando-hd-resized.zip
|
| 4 |
|
| 5 |
+
# *viton-hd-dataset.zip
|
| 6 |
+
viton-hd-dataset/
|
| 7 |
+
checkpoints/
|
| 8 |
+
|
| 9 |
+
*finetuned_weights.safetensors
|
| 10 |
+
|
| 11 |
# Byte-compiled / optimized / DLL files
|
| 12 |
__pycache__/
|
| 13 |
**/__pycache__/
|
.vscode/launch.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
// Use IntelliSense to learn about possible attributes.
|
| 3 |
+
// Hover to view descriptions of existing attributes.
|
| 4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
| 5 |
+
"version": "0.2.0",
|
| 6 |
+
"configurations": [
|
| 7 |
+
{
|
| 8 |
+
"name": "Python Debugger: Current File",
|
| 9 |
+
"type": "debugpy",
|
| 10 |
+
"request": "launch",
|
| 11 |
+
"program": "${file}",
|
| 12 |
+
"console": "integratedTerminal",
|
| 13 |
+
"subProcess": true
|
| 14 |
+
}
|
| 15 |
+
]
|
| 16 |
+
}
|
interface.py
DELETED
|
@@ -1,151 +0,0 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import torch
|
| 3 |
-
from PIL import Image
|
| 4 |
-
from transformers import CLIPTokenizer
|
| 5 |
-
|
| 6 |
-
# Import your existing model and pipeline modules
|
| 7 |
-
import load_model
|
| 8 |
-
import pipeline
|
| 9 |
-
|
| 10 |
-
# Device Configuration
|
| 11 |
-
ALLOW_CUDA = True
|
| 12 |
-
ALLOW_MPS = False
|
| 13 |
-
|
| 14 |
-
def determine_device():
|
| 15 |
-
if torch.cuda.is_available() and ALLOW_CUDA:
|
| 16 |
-
return "cuda"
|
| 17 |
-
elif (torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
|
| 18 |
-
return "mps"
|
| 19 |
-
return "cpu"
|
| 20 |
-
|
| 21 |
-
DEVICE = determine_device()
|
| 22 |
-
print(f"Using device: {DEVICE}")
|
| 23 |
-
|
| 24 |
-
# Load tokenizer and models
|
| 25 |
-
tokenizer = CLIPTokenizer("vocab.json", merges_file="merges.txt")
|
| 26 |
-
model_file = "inkpunk-diffusion-v1.ckpt"
|
| 27 |
-
models = load_model.preload_models_from_standard_weights(model_file, DEVICE)
|
| 28 |
-
# models=None
|
| 29 |
-
|
| 30 |
-
def generate_image(
|
| 31 |
-
prompt,
|
| 32 |
-
uncond_prompt="",
|
| 33 |
-
do_cfg=True,
|
| 34 |
-
cfg_scale=8,
|
| 35 |
-
sampler="ddpm",
|
| 36 |
-
num_inference_steps=50,
|
| 37 |
-
seed=42,
|
| 38 |
-
input_image=None,
|
| 39 |
-
strength=1.0
|
| 40 |
-
):
|
| 41 |
-
"""
|
| 42 |
-
Generate an image using the Stable Diffusion pipeline
|
| 43 |
-
|
| 44 |
-
Args:
|
| 45 |
-
- prompt (str): Text description of the image to generate
|
| 46 |
-
- uncond_prompt (str, optional): Negative prompt to guide generation
|
| 47 |
-
- do_cfg (bool): Whether to use classifier-free guidance
|
| 48 |
-
- cfg_scale (float): Classifier-free guidance scale
|
| 49 |
-
- sampler (str): Sampling method
|
| 50 |
-
- num_inference_steps (int): Number of denoising steps
|
| 51 |
-
- seed (int): Random seed for reproducibility
|
| 52 |
-
- input_image (PIL.Image, optional): Input image for image-to-image generation
|
| 53 |
-
- strength (float): Strength of image transformation (0-1)
|
| 54 |
-
|
| 55 |
-
Returns:
|
| 56 |
-
- PIL.Image: Generated image
|
| 57 |
-
"""
|
| 58 |
-
try:
|
| 59 |
-
# Ensure input_image is None if not provided
|
| 60 |
-
if input_image is None:
|
| 61 |
-
strength = 1.0
|
| 62 |
-
|
| 63 |
-
# Generate the image
|
| 64 |
-
output_image = pipeline.generate(
|
| 65 |
-
prompt=prompt,
|
| 66 |
-
uncond_prompt=uncond_prompt,
|
| 67 |
-
input_image=input_image,
|
| 68 |
-
strength=strength,
|
| 69 |
-
do_cfg=do_cfg,
|
| 70 |
-
cfg_scale=cfg_scale,
|
| 71 |
-
sampler_name=sampler,
|
| 72 |
-
n_inference_steps=num_inference_steps,
|
| 73 |
-
seed=seed,
|
| 74 |
-
models=models,
|
| 75 |
-
device=DEVICE,
|
| 76 |
-
idle_device="cuda",
|
| 77 |
-
tokenizer=tokenizer,
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
# Convert numpy array to PIL Image
|
| 81 |
-
return Image.fromarray(output_image)
|
| 82 |
-
|
| 83 |
-
except Exception as e:
|
| 84 |
-
print(f"Error generating image: {e}")
|
| 85 |
-
return None
|
| 86 |
-
|
| 87 |
-
def launch_gradio_interface():
|
| 88 |
-
"""
|
| 89 |
-
Create and launch Gradio interface for Stable Diffusion
|
| 90 |
-
"""
|
| 91 |
-
with gr.Blocks(title="Stable Diffusion Image Generator") as demo:
|
| 92 |
-
gr.Markdown("# 🎨 Stable Diffusion Image Generator")
|
| 93 |
-
|
| 94 |
-
with gr.Row():
|
| 95 |
-
with gr.Column():
|
| 96 |
-
# Text Inputs
|
| 97 |
-
prompt = gr.Textbox(label="Prompt",
|
| 98 |
-
placeholder="Describe the image you want to generate...")
|
| 99 |
-
uncond_prompt = gr.Textbox(label="Negative Prompt (Optional)",
|
| 100 |
-
placeholder="Describe what you don't want in the image...")
|
| 101 |
-
|
| 102 |
-
# Generation Parameters
|
| 103 |
-
with gr.Accordion("Advanced Settings", open=False):
|
| 104 |
-
do_cfg = gr.Checkbox(label="Use Classifier-Free Guidance", value=True)
|
| 105 |
-
cfg_scale = gr.Slider(minimum=1, maximum=14, value=8, label="CFG Scale")
|
| 106 |
-
sampler = gr.Dropdown(
|
| 107 |
-
choices=["ddpm", "ddim", "pndm"], # Add more samplers if available
|
| 108 |
-
value="ddpm",
|
| 109 |
-
label="Sampling Method"
|
| 110 |
-
)
|
| 111 |
-
num_inference_steps = gr.Slider(
|
| 112 |
-
minimum=10,
|
| 113 |
-
maximum=100,
|
| 114 |
-
value=50,
|
| 115 |
-
label="Number of Inference Steps"
|
| 116 |
-
)
|
| 117 |
-
seed = gr.Number(value=42, label="Random Seed")
|
| 118 |
-
|
| 119 |
-
# Image-to-Image Section
|
| 120 |
-
with gr.Accordion("Image-to-Image", open=False):
|
| 121 |
-
input_image = gr.Image(type="pil", label="Input Image (Optional)")
|
| 122 |
-
strength = gr.Slider(
|
| 123 |
-
minimum=0,
|
| 124 |
-
maximum=1,
|
| 125 |
-
value=0.8,
|
| 126 |
-
label="Image Transformation Strength"
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
# Generate Button
|
| 130 |
-
generate_btn = gr.Button("Generate Image", variant="primary")
|
| 131 |
-
|
| 132 |
-
with gr.Row():
|
| 133 |
-
# Output Image
|
| 134 |
-
output_image = gr.Image(label="Generated Image")
|
| 135 |
-
|
| 136 |
-
# Connect Button to Generation Function
|
| 137 |
-
generate_btn.click(
|
| 138 |
-
fn=generate_image,
|
| 139 |
-
inputs=[
|
| 140 |
-
prompt, uncond_prompt, do_cfg, cfg_scale,
|
| 141 |
-
sampler, num_inference_steps, seed,
|
| 142 |
-
input_image, strength
|
| 143 |
-
],
|
| 144 |
-
outputs=output_image
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
# Launch the interface
|
| 148 |
-
demo.launch(server_name="0.0.0.0", server_port=7860)
|
| 149 |
-
|
| 150 |
-
if __name__ == "__main__":
|
| 151 |
-
launch_gradio_interface()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
load_model.py
CHANGED
|
@@ -1,11 +1,81 @@
|
|
| 1 |
from clip import CLIP
|
| 2 |
from encoder import VAE_Encoder
|
| 3 |
from decoder import VAE_Decoder
|
| 4 |
-
from diffusion import Diffusion
|
| 5 |
-
|
| 6 |
import model_converter
|
| 7 |
import torch
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_path=None):
|
| 10 |
# CatVTON parameters
|
| 11 |
in_channels = 9
|
|
@@ -14,12 +84,10 @@ def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_pat
|
|
| 14 |
state_dict=model_converter.load_from_standard_weights(ckpt_path, device)
|
| 15 |
|
| 16 |
diffusion=Diffusion(in_channels=in_channels, out_channels=out_channels).to(device)
|
| 17 |
-
|
|
|
|
| 18 |
if finetune_weights_path != None:
|
| 19 |
-
|
| 20 |
-
diffusion.load_state_dict(checkpoint['diffusion_state_dict'], strict=True)
|
| 21 |
-
else:
|
| 22 |
-
diffusion.load_state_dict(state_dict['diffusion'], strict=True)
|
| 23 |
|
| 24 |
encoder=VAE_Encoder().to(device)
|
| 25 |
encoder.load_state_dict(state_dict['encoder'], strict=True)
|
|
|
|
| 1 |
from clip import CLIP
|
| 2 |
from encoder import VAE_Encoder
|
| 3 |
from decoder import VAE_Decoder
|
| 4 |
+
from diffusion import Diffusion, UNET_AttentionBlock
|
| 5 |
+
from safetensors.torch import load_file
|
| 6 |
import model_converter
|
| 7 |
import torch
|
| 8 |
|
| 9 |
+
def load_finetuned_attention_weights(finetune_weights_path, diffusion, device):
|
| 10 |
+
updated_loaded_data = load_file(finetune_weights_path, device=device)
|
| 11 |
+
print(f"Loaded finetuned weights from {finetune_weights_path}")
|
| 12 |
+
|
| 13 |
+
unet= diffusion.unet
|
| 14 |
+
idx = 0
|
| 15 |
+
# Iterate through the attention layers in the encoders
|
| 16 |
+
for layers in unet.encoders:
|
| 17 |
+
for layer in layers:
|
| 18 |
+
if isinstance(layer, UNET_AttentionBlock):
|
| 19 |
+
# Get the parameters from the loaded data for this block
|
| 20 |
+
in_proj_weight_key = f"{idx}.in_proj.weight"
|
| 21 |
+
out_proj_weight_key = f"{idx}.out_proj.weight"
|
| 22 |
+
out_proj_bias_key = f"{idx}.out_proj.bias"
|
| 23 |
+
|
| 24 |
+
# Load the weights if they exist in the loaded data
|
| 25 |
+
if in_proj_weight_key in updated_loaded_data:
|
| 26 |
+
print(f"Loading {in_proj_weight_key}")
|
| 27 |
+
layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key])
|
| 28 |
+
if out_proj_weight_key in updated_loaded_data:
|
| 29 |
+
print(f"Loading {out_proj_weight_key}")
|
| 30 |
+
layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key])
|
| 31 |
+
if out_proj_bias_key in updated_loaded_data:
|
| 32 |
+
print(f"Loading {out_proj_bias_key}")
|
| 33 |
+
layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key])
|
| 34 |
+
idx += 8
|
| 35 |
+
|
| 36 |
+
# Move to the next attention block index in the loaded data
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# Iterate through the attention layers in the decoders
|
| 40 |
+
for layers in unet.decoders:
|
| 41 |
+
for layer in layers:
|
| 42 |
+
if isinstance(layer, UNET_AttentionBlock):
|
| 43 |
+
in_proj_weight_key = f"{idx}.in_proj.weight"
|
| 44 |
+
out_proj_weight_key = f"{idx}.out_proj.weight"
|
| 45 |
+
out_proj_bias_key = f"{idx}.out_proj.bias"
|
| 46 |
+
|
| 47 |
+
if in_proj_weight_key in updated_loaded_data:
|
| 48 |
+
print(f"Loading {in_proj_weight_key}")
|
| 49 |
+
layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key])
|
| 50 |
+
if out_proj_weight_key in updated_loaded_data:
|
| 51 |
+
print(f"Loading {out_proj_weight_key}")
|
| 52 |
+
layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key])
|
| 53 |
+
if out_proj_bias_key in updated_loaded_data:
|
| 54 |
+
print(f"Loading {out_proj_bias_key}")
|
| 55 |
+
layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key])
|
| 56 |
+
idx += 8
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# Iterate through the attention layers in the bottleneck
|
| 60 |
+
for layer in unet.bottleneck:
|
| 61 |
+
if isinstance(layer, UNET_AttentionBlock):
|
| 62 |
+
in_proj_weight_key = f"{idx}.in_proj.weight"
|
| 63 |
+
out_proj_weight_key = f"{idx}.out_proj.weight"
|
| 64 |
+
out_proj_bias_key = f"{idx}.out_proj.bias"
|
| 65 |
+
|
| 66 |
+
if in_proj_weight_key in updated_loaded_data:
|
| 67 |
+
print(f"Loading {in_proj_weight_key}")
|
| 68 |
+
layer.attention_1.in_proj.weight.data.copy_(updated_loaded_data[in_proj_weight_key])
|
| 69 |
+
if out_proj_weight_key in updated_loaded_data:
|
| 70 |
+
print(f"Loading {out_proj_weight_key}")
|
| 71 |
+
layer.attention_1.out_proj.weight.data.copy_(updated_loaded_data[out_proj_weight_key])
|
| 72 |
+
if out_proj_bias_key in updated_loaded_data:
|
| 73 |
+
print(f"Loading {out_proj_bias_key}")
|
| 74 |
+
layer.attention_1.out_proj.bias.data.copy_(updated_loaded_data[out_proj_bias_key])
|
| 75 |
+
idx += 8
|
| 76 |
+
|
| 77 |
+
print("\nAttention module weights loaded from {finetune_weights_path} successfully.")
|
| 78 |
+
|
| 79 |
def preload_models_from_standard_weights(ckpt_path, device, finetune_weights_path=None):
|
| 80 |
# CatVTON parameters
|
| 81 |
in_channels = 9
|
|
|
|
| 84 |
state_dict=model_converter.load_from_standard_weights(ckpt_path, device)
|
| 85 |
|
| 86 |
diffusion=Diffusion(in_channels=in_channels, out_channels=out_channels).to(device)
|
| 87 |
+
diffusion.load_state_dict(state_dict['diffusion'], strict=True)
|
| 88 |
+
|
| 89 |
if finetune_weights_path != None:
|
| 90 |
+
load_finetuned_attention_weights(finetune_weights_path, diffusion, device)
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
encoder=VAE_Encoder().to(device)
|
| 93 |
encoder.load_state_dict(state_dict['encoder'], strict=True)
|
logs.txt
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/home/mahesh/harsh/stable-diffusion/training.py:84: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
|
| 2 |
+
self.scaler = torch.cuda.amp.GradScaler()
|
| 3 |
+
----------------------------------------------------------------------------------------------------
|
| 4 |
+
Loading pretrained models...
|
| 5 |
+
Models loaded successfully.
|
| 6 |
+
----------------------------------------------------------------------------------------------------
|
| 7 |
+
Creating dataloader...
|
| 8 |
+
Dataset vitonhd loaded, total 11647 pairs.
|
| 9 |
+
Training for 50 epochs
|
| 10 |
+
Batches per epoch: 5824
|
| 11 |
+
----------------------------------------------------------------------------------------------------
|
| 12 |
+
Initializing trainer...
|
| 13 |
+
Enabling PEFT training (self-attention layers only)
|
| 14 |
+
Total parameters: 899,226,667
|
| 15 |
+
Trainable parameters: 49,574,080 (5.51%)
|
| 16 |
+
Checkpoint loaded: ./checkpoints/checkpoint_step_50000.pth
|
| 17 |
+
Resuming from epoch 13, step 50000
|
| 18 |
+
Starting training...
|
| 19 |
+
Starting training for 50 epochs
|
| 20 |
+
Total training batches per epoch: 5824
|
| 21 |
+
Using DREAM with lambda = 0
|
| 22 |
+
Mixed precision: True
|
| 23 |
+
/home/mahesh/harsh/stable-diffusion/training.py:304: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
|
| 24 |
+
with torch.cuda.amp.autocast():
|
| 25 |
+
/home/mahesh/harsh/stable-diffusion/training.py:194: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
|
| 26 |
+
with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):
|
| 27 |
+
/home/mahesh/harsh/stable-diffusion/utils.py:491: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
|
| 28 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 29 |
+
|
output/vitonhd-512/unpaired/00654_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/01265_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/01985_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/02023_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/02532_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/02944_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/03191_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/03921_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/05006_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/05378_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/07342_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/08088_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/08239_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/08650_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/08839_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/11085_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/12345_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/12419_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/12562_00.jpg
CHANGED
|
|
output/vitonhd-512/unpaired/14651_00.jpg
CHANGED
|
|
pipeline.py
DELETED
|
@@ -1,314 +0,0 @@
|
|
| 1 |
-
import math
|
| 2 |
-
from typing import List, Union
|
| 3 |
-
import PIL
|
| 4 |
-
import torch
|
| 5 |
-
import numpy as np
|
| 6 |
-
from tqdm import tqdm
|
| 7 |
-
from ddpm import DDPMSampler
|
| 8 |
-
from PIL import Image
|
| 9 |
-
import load_model
|
| 10 |
-
from utils import check_inputs, prepare_image, prepare_mask_image
|
| 11 |
-
|
| 12 |
-
WIDTH = 512
|
| 13 |
-
HEIGHT = 512
|
| 14 |
-
LATENTS_WIDTH = WIDTH // 8
|
| 15 |
-
LATENTS_HEIGHT = HEIGHT // 8
|
| 16 |
-
|
| 17 |
-
def repaint_result(result, person_image, mask_image):
|
| 18 |
-
result, person, mask = np.array(result), np.array(person_image), np.array(mask_image)
|
| 19 |
-
# expand the mask to 3 channels & to 0~1
|
| 20 |
-
mask = np.expand_dims(mask, axis=2)
|
| 21 |
-
mask = mask / 255.0
|
| 22 |
-
# mask for result, ~mask for person
|
| 23 |
-
result_ = result * mask + person * (1 - mask)
|
| 24 |
-
return Image.fromarray(result_.astype(np.uint8))
|
| 25 |
-
|
| 26 |
-
def numpy_to_pil(images):
|
| 27 |
-
"""
|
| 28 |
-
Convert a numpy image or a batch of images to a PIL image.
|
| 29 |
-
"""
|
| 30 |
-
if images.ndim == 3:
|
| 31 |
-
images = images[None, ...]
|
| 32 |
-
images = (images * 255).round().astype("uint8")
|
| 33 |
-
if images.shape[-1] == 1:
|
| 34 |
-
# special case for grayscale (single channel) images
|
| 35 |
-
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
| 36 |
-
else:
|
| 37 |
-
pil_images = [Image.fromarray(image) for image in images]
|
| 38 |
-
|
| 39 |
-
return pil_images
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def tensor_to_image(tensor: torch.Tensor):
|
| 43 |
-
"""
|
| 44 |
-
Converts a torch tensor to PIL Image.
|
| 45 |
-
"""
|
| 46 |
-
assert tensor.dim() == 3, "Input tensor should be 3-dimensional."
|
| 47 |
-
assert tensor.dtype == torch.float32, "Input tensor should be float32."
|
| 48 |
-
assert (
|
| 49 |
-
tensor.min() >= 0 and tensor.max() <= 1
|
| 50 |
-
), "Input tensor should be in range [0, 1]."
|
| 51 |
-
tensor = tensor.cpu()
|
| 52 |
-
tensor = tensor * 255
|
| 53 |
-
tensor = tensor.permute(1, 2, 0)
|
| 54 |
-
tensor = tensor.numpy().astype(np.uint8)
|
| 55 |
-
image = Image.fromarray(tensor)
|
| 56 |
-
return image
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4):
|
| 60 |
-
"""
|
| 61 |
-
Concatenates images horizontally and with
|
| 62 |
-
"""
|
| 63 |
-
widths = [image.size[0] for image in images]
|
| 64 |
-
heights = [image.size[1] for image in images]
|
| 65 |
-
total_width = cols * max(widths)
|
| 66 |
-
total_width += divider * (cols - 1)
|
| 67 |
-
# `col` images each row
|
| 68 |
-
rows = math.ceil(len(images) / cols)
|
| 69 |
-
total_height = max(heights) * rows
|
| 70 |
-
# add divider between rows
|
| 71 |
-
total_height += divider * (len(heights) // cols - 1)
|
| 72 |
-
|
| 73 |
-
# all black image
|
| 74 |
-
concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0))
|
| 75 |
-
|
| 76 |
-
x_offset = 0
|
| 77 |
-
y_offset = 0
|
| 78 |
-
for i, image in enumerate(images):
|
| 79 |
-
concat_image.paste(image, (x_offset, y_offset))
|
| 80 |
-
x_offset += image.size[0] + divider
|
| 81 |
-
if (i + 1) % cols == 0:
|
| 82 |
-
x_offset = 0
|
| 83 |
-
y_offset += image.size[1] + divider
|
| 84 |
-
|
| 85 |
-
return concat_image
|
| 86 |
-
|
| 87 |
-
def compute_vae_encodings(image_tensor, encoder, device):
|
| 88 |
-
"""Encode image using VAE encoder"""
|
| 89 |
-
# Generate random noise for encoding
|
| 90 |
-
encoder_noise = torch.randn(
|
| 91 |
-
(image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),
|
| 92 |
-
device=device,
|
| 93 |
-
)
|
| 94 |
-
|
| 95 |
-
# Encode using your custom encoder
|
| 96 |
-
latent = encoder(image_tensor, encoder_noise)
|
| 97 |
-
return latent
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def generate(
|
| 101 |
-
image: Union[PIL.Image.Image, torch.Tensor],
|
| 102 |
-
condition_image: Union[PIL.Image.Image, torch.Tensor],
|
| 103 |
-
mask: Union[PIL.Image.Image, torch.Tensor],
|
| 104 |
-
num_inference_steps: int = 50,
|
| 105 |
-
guidance_scale: float = 2.5,
|
| 106 |
-
height: int = 1024,
|
| 107 |
-
width: int = 768,
|
| 108 |
-
models={},
|
| 109 |
-
sampler_name="ddpm",
|
| 110 |
-
seed=None,
|
| 111 |
-
device=None,
|
| 112 |
-
idle_device=None,
|
| 113 |
-
**kwargs
|
| 114 |
-
):
|
| 115 |
-
with torch.no_grad():
|
| 116 |
-
if idle_device:
|
| 117 |
-
to_idle = lambda x: x.to(idle_device)
|
| 118 |
-
else:
|
| 119 |
-
to_idle = lambda x: x
|
| 120 |
-
|
| 121 |
-
# Initialize random number generator according to the seed specified
|
| 122 |
-
generator = torch.Generator(device=device)
|
| 123 |
-
if seed is None:
|
| 124 |
-
generator.seed()
|
| 125 |
-
else:
|
| 126 |
-
generator.manual_seed(seed)
|
| 127 |
-
|
| 128 |
-
concat_dim = -1 # FIXME: y axis concat
|
| 129 |
-
|
| 130 |
-
# Prepare inputs to Tensor
|
| 131 |
-
image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)
|
| 132 |
-
# print(f"Input image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}")
|
| 133 |
-
image = prepare_image(image).to(device)
|
| 134 |
-
condition_image = prepare_image(condition_image).to(device)
|
| 135 |
-
mask = prepare_mask_image(mask).to(device)
|
| 136 |
-
|
| 137 |
-
print(f"Prepared image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}")
|
| 138 |
-
# Mask image
|
| 139 |
-
masked_image = image * (mask < 0.5)
|
| 140 |
-
|
| 141 |
-
print(f"Masked image shape: {masked_image.shape}")
|
| 142 |
-
|
| 143 |
-
# VAE encoding
|
| 144 |
-
encoder = models.get('encoder', None)
|
| 145 |
-
if encoder is None:
|
| 146 |
-
raise ValueError("Encoder model not found in models dictionary")
|
| 147 |
-
|
| 148 |
-
encoder.to(device)
|
| 149 |
-
masked_latent = compute_vae_encodings(masked_image, encoder, device)
|
| 150 |
-
condition_latent = compute_vae_encodings(condition_image, encoder, device)
|
| 151 |
-
to_idle(encoder)
|
| 152 |
-
|
| 153 |
-
print(f"Masked latent shape: {masked_latent.shape}, condition latent shape: {condition_latent.shape}")
|
| 154 |
-
|
| 155 |
-
# Concatenate latents
|
| 156 |
-
masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)
|
| 157 |
-
|
| 158 |
-
print(f"Masked Person latent + garment latent: {masked_latent_concat.shape}")
|
| 159 |
-
|
| 160 |
-
mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
|
| 161 |
-
del image, mask, condition_image
|
| 162 |
-
mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)
|
| 163 |
-
|
| 164 |
-
print(f"Mask latent concat shape: {mask_latent_concat.shape}")
|
| 165 |
-
|
| 166 |
-
# Initialize latents
|
| 167 |
-
latents = torch.randn(
|
| 168 |
-
masked_latent_concat.shape,
|
| 169 |
-
generator=generator,
|
| 170 |
-
device=masked_latent_concat.device,
|
| 171 |
-
dtype=masked_latent_concat.dtype
|
| 172 |
-
)
|
| 173 |
-
|
| 174 |
-
print(f"Latents shape: {latents.shape}")
|
| 175 |
-
|
| 176 |
-
# Prepare timesteps
|
| 177 |
-
if sampler_name == "ddpm":
|
| 178 |
-
sampler = DDPMSampler(generator)
|
| 179 |
-
sampler.set_inference_timesteps(num_inference_steps)
|
| 180 |
-
else:
|
| 181 |
-
raise ValueError("Unknown sampler value %s. " % sampler_name)
|
| 182 |
-
|
| 183 |
-
timesteps = sampler.timesteps
|
| 184 |
-
# latents = sampler.add_noise(latents, timesteps[0])
|
| 185 |
-
|
| 186 |
-
# Classifier-Free Guidance
|
| 187 |
-
do_classifier_free_guidance = guidance_scale > 1.0
|
| 188 |
-
if do_classifier_free_guidance:
|
| 189 |
-
masked_latent_concat = torch.cat(
|
| 190 |
-
[
|
| 191 |
-
torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
|
| 192 |
-
masked_latent_concat,
|
| 193 |
-
]
|
| 194 |
-
)
|
| 195 |
-
mask_latent_concat = torch.cat([mask_latent_concat] * 2)
|
| 196 |
-
|
| 197 |
-
print(f"Masked latent concat for classifier-free guidance: {masked_latent_concat.shape}, mask latent concat: {mask_latent_concat.shape}")
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
# Denoising loop - Fixed: removed self references and incorrect scheduler calls
|
| 201 |
-
num_warmup_steps = 0 # For simple DDPM, no warmup needed
|
| 202 |
-
|
| 203 |
-
with tqdm(total=num_inference_steps) as progress_bar:
|
| 204 |
-
for i, t in enumerate(timesteps):
|
| 205 |
-
# expand the latents if we are doing classifier free guidance
|
| 206 |
-
non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
|
| 207 |
-
|
| 208 |
-
# print(f"Non-inpainting latent model input shape: {non_inpainting_latent_model_input.shape}")
|
| 209 |
-
|
| 210 |
-
# prepare the input for the inpainting model
|
| 211 |
-
inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1)
|
| 212 |
-
|
| 213 |
-
# print(f"Inpainting latent model input shape: {inpainting_latent_model_input.shape}")
|
| 214 |
-
|
| 215 |
-
# predict the noise residual
|
| 216 |
-
diffusion = models.get('diffusion', None)
|
| 217 |
-
if diffusion is None:
|
| 218 |
-
raise ValueError("Diffusion model not found in models dictionary")
|
| 219 |
-
|
| 220 |
-
diffusion.to(device)
|
| 221 |
-
|
| 222 |
-
# Create time embedding for the current timestep
|
| 223 |
-
time_embedding = get_time_embedding(t.item()).to(device)
|
| 224 |
-
# print(f"Time embedding shape: {time_embedding.shape}")
|
| 225 |
-
|
| 226 |
-
if do_classifier_free_guidance:
|
| 227 |
-
time_embedding = torch.cat([time_embedding] * 2)
|
| 228 |
-
|
| 229 |
-
noise_pred = diffusion(
|
| 230 |
-
inpainting_latent_model_input,
|
| 231 |
-
time_embedding
|
| 232 |
-
)
|
| 233 |
-
|
| 234 |
-
to_idle(diffusion)
|
| 235 |
-
|
| 236 |
-
# perform guidance
|
| 237 |
-
if do_classifier_free_guidance:
|
| 238 |
-
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 239 |
-
noise_pred = noise_pred_uncond + guidance_scale * (
|
| 240 |
-
noise_pred_text - noise_pred_uncond
|
| 241 |
-
)
|
| 242 |
-
|
| 243 |
-
# compute the previous noisy sample x_t -> x_t-1
|
| 244 |
-
latents = sampler.step(t, latents, noise_pred)
|
| 245 |
-
|
| 246 |
-
# Update progress bar
|
| 247 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps):
|
| 248 |
-
progress_bar.update()
|
| 249 |
-
|
| 250 |
-
# Decode the final latents
|
| 251 |
-
latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
|
| 252 |
-
|
| 253 |
-
decoder = models.get('decoder', None)
|
| 254 |
-
if decoder is None:
|
| 255 |
-
raise ValueError("Decoder model not found in models dictionary")
|
| 256 |
-
|
| 257 |
-
decoder.to(device)
|
| 258 |
-
|
| 259 |
-
image = decoder(latents.to(device))
|
| 260 |
-
# image = rescale(image, (-1, 1), (0, 255), clamp=True)
|
| 261 |
-
image = (image / 2 + 0.5).clamp(0, 1)
|
| 262 |
-
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 263 |
-
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 264 |
-
image = numpy_to_pil(image)
|
| 265 |
-
|
| 266 |
-
to_idle(decoder)
|
| 267 |
-
|
| 268 |
-
return image
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
def rescale(x, old_range, new_range, clamp=False):
|
| 272 |
-
old_min, old_max = old_range
|
| 273 |
-
new_min, new_max = new_range
|
| 274 |
-
x -= old_min
|
| 275 |
-
x *= (new_max - new_min) / (old_max - old_min)
|
| 276 |
-
x += new_min
|
| 277 |
-
if clamp:
|
| 278 |
-
x = x.clamp(new_min, new_max)
|
| 279 |
-
return x
|
| 280 |
-
|
| 281 |
-
def get_time_embedding(timestep):
|
| 282 |
-
# Shape: (160,)
|
| 283 |
-
freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
|
| 284 |
-
# Shape: (1, 160)
|
| 285 |
-
x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
|
| 286 |
-
# Shape: (1, 160 * 2) -> (1, 320)
|
| 287 |
-
return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
|
| 288 |
-
|
| 289 |
-
if __name__ == "__main__":
|
| 290 |
-
# Example usage
|
| 291 |
-
image = Image.open("person.jpg").convert("RGB")
|
| 292 |
-
condition_image = Image.open("image.png").convert("RGB")
|
| 293 |
-
mask = Image.open("agnostic_mask.png").convert("L")
|
| 294 |
-
|
| 295 |
-
# Load models
|
| 296 |
-
models=load_model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")
|
| 297 |
-
|
| 298 |
-
# Generate image
|
| 299 |
-
generated_image = generate(
|
| 300 |
-
image=image,
|
| 301 |
-
condition_image=condition_image,
|
| 302 |
-
mask=mask,
|
| 303 |
-
num_inference_steps=50,
|
| 304 |
-
guidance_scale=2.5,
|
| 305 |
-
width=WIDTH,
|
| 306 |
-
height=HEIGHT,
|
| 307 |
-
models=models,
|
| 308 |
-
sampler_name="ddpm",
|
| 309 |
-
seed=42,
|
| 310 |
-
device="cuda" # or "cpu"
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
generated_image[0].save("generated_image.png")
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,53 +1,39 @@
|
|
| 1 |
-
|
| 2 |
-
aiohttp==3.11.18
|
| 3 |
-
aiosignal==1.3.2
|
| 4 |
-
annotated-types==0.7.0
|
| 5 |
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
|
| 6 |
-
|
| 7 |
-
certifi==2025.
|
| 8 |
charset-normalizer==3.4.2
|
| 9 |
-
click==8.2.0
|
| 10 |
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
|
| 11 |
contourpy==1.3.2
|
| 12 |
cycler==0.12.1
|
| 13 |
-
|
| 14 |
-
debugpy @ file:///home/conda/feedstock_root/build_artifacts/debugpy_1744321233760/work
|
| 15 |
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
|
| 16 |
-
|
| 17 |
-
docker-pycreds==0.4.0
|
| 18 |
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
|
| 19 |
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
|
| 20 |
filelock==3.18.0
|
| 21 |
-
fonttools==4.
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
hf-xet==1.1.0
|
| 27 |
-
huggingface-hub==0.31.1
|
| 28 |
idna==3.10
|
| 29 |
-
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/
|
| 30 |
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
|
| 31 |
-
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-
|
| 32 |
-
ipython_pygments_lexers @ file:///home/conda/feedstock_root/build_artifacts/ipython_pygments_lexers_1737123620466/work
|
| 33 |
-
ipywidgets==8.1.7
|
| 34 |
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
|
| 35 |
Jinja2==3.1.6
|
| 36 |
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
|
| 37 |
-
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/
|
| 38 |
-
|
| 39 |
kiwisolver==1.4.8
|
| 40 |
-
lightning==2.5.1.post0
|
| 41 |
-
lightning-utilities==0.14.3
|
| 42 |
MarkupSafe==3.0.2
|
| 43 |
matplotlib==3.10.3
|
| 44 |
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
|
| 45 |
mpmath==1.3.0
|
| 46 |
-
multidict==6.4.3
|
| 47 |
-
multiprocess==0.70.16
|
| 48 |
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
|
| 49 |
networkx==3.4.2
|
| 50 |
-
numpy==2.2.
|
| 51 |
nvidia-cublas-cu12==12.6.4.1
|
| 52 |
nvidia-cuda-cupti-cu12==12.6.80
|
| 53 |
nvidia-cuda-nvrtc-cu12==12.6.77
|
|
@@ -62,54 +48,43 @@ nvidia-cusparselt-cu12==0.6.3
|
|
| 62 |
nvidia-nccl-cu12==2.26.2
|
| 63 |
nvidia-nvjitlink-cu12==12.6.85
|
| 64 |
nvidia-nvtx-cu12==12.6.77
|
| 65 |
-
packaging
|
| 66 |
-
pandas==2.
|
| 67 |
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
|
| 68 |
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
|
| 69 |
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
|
| 70 |
-
pillow==11.
|
| 71 |
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
|
| 72 |
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
|
| 73 |
-
|
| 74 |
-
protobuf==6.30.2
|
| 75 |
-
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663149797/work
|
| 76 |
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
|
| 77 |
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
|
| 78 |
-
|
| 79 |
-
pydantic==2.11.4
|
| 80 |
-
pydantic_core==2.33.2
|
| 81 |
-
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1736243443484/work
|
| 82 |
pyparsing==3.2.3
|
| 83 |
-
|
| 84 |
-
|
| 85 |
pytz==2025.2
|
| 86 |
PyYAML==6.0.2
|
| 87 |
-
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/
|
| 88 |
regex==2024.11.6
|
| 89 |
-
requests==2.32.
|
| 90 |
safetensors==0.5.3
|
| 91 |
-
sentry-sdk==2.27.0
|
| 92 |
-
setproctitle==1.3.6
|
| 93 |
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
|
| 94 |
-
|
| 95 |
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
|
| 96 |
sympy==1.14.0
|
| 97 |
-
tokenizers==0.21.
|
| 98 |
-
torch==2.7.
|
| 99 |
-
|
| 100 |
-
torchvision==0.22.
|
| 101 |
-
tornado @ file:///home/conda/feedstock_root/build_artifacts/
|
| 102 |
tqdm==4.67.1
|
| 103 |
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
|
| 104 |
-
transformers==4.
|
| 105 |
-
triton==3.3.
|
| 106 |
-
|
| 107 |
-
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1744302253/work
|
| 108 |
tzdata==2025.2
|
| 109 |
-
|
| 110 |
-
|
| 111 |
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
|
| 112 |
-
|
| 113 |
-
xxhash==3.5.0
|
| 114 |
-
yarl==1.20.0
|
| 115 |
-
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1732827521216/work
|
|
|
|
| 1 |
+
accelerate==1.9.0
|
|
|
|
|
|
|
|
|
|
| 2 |
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1733250440834/work
|
| 3 |
+
beautifulsoup4==4.13.4
|
| 4 |
+
certifi==2025.7.14
|
| 5 |
charset-normalizer==3.4.2
|
|
|
|
| 6 |
comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1733502965406/work
|
| 7 |
contourpy==1.3.2
|
| 8 |
cycler==0.12.1
|
| 9 |
+
debugpy @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_debugpy_1752827112/work
|
|
|
|
| 10 |
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1740384970518/work
|
| 11 |
+
diffusers==0.34.0
|
|
|
|
| 12 |
exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1746947292760/work
|
| 13 |
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1745502089858/work
|
| 14 |
filelock==3.18.0
|
| 15 |
+
fonttools==4.59.0
|
| 16 |
+
fsspec==2025.7.0
|
| 17 |
+
gdown==5.2.0
|
| 18 |
+
hf-xet==1.1.5
|
| 19 |
+
huggingface-hub==0.33.4
|
|
|
|
|
|
|
| 20 |
idna==3.10
|
| 21 |
+
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work
|
| 22 |
ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1719845459717/work
|
| 23 |
+
ipython @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_ipython_1748711175/work
|
|
|
|
|
|
|
| 24 |
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1733300866624/work
|
| 25 |
Jinja2==3.1.6
|
| 26 |
jupyter_client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1733440914442/work
|
| 27 |
+
jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1748333051527/work
|
| 28 |
+
kagglehub==0.3.12
|
| 29 |
kiwisolver==1.4.8
|
|
|
|
|
|
|
| 30 |
MarkupSafe==3.0.2
|
| 31 |
matplotlib==3.10.3
|
| 32 |
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1733416936468/work
|
| 33 |
mpmath==1.3.0
|
|
|
|
|
|
|
| 34 |
nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1733325553580/work
|
| 35 |
networkx==3.4.2
|
| 36 |
+
numpy==2.2.6
|
| 37 |
nvidia-cublas-cu12==12.6.4.1
|
| 38 |
nvidia-cuda-cupti-cu12==12.6.80
|
| 39 |
nvidia-cuda-nvrtc-cu12==12.6.77
|
|
|
|
| 48 |
nvidia-nccl-cu12==2.26.2
|
| 49 |
nvidia-nvjitlink-cu12==12.6.85
|
| 50 |
nvidia-nvtx-cu12==12.6.77
|
| 51 |
+
packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1745345660/work
|
| 52 |
+
pandas==2.3.1
|
| 53 |
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1733271261340/work
|
| 54 |
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1733301927746/work
|
| 55 |
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1733327343728/work
|
| 56 |
+
pillow==11.3.0
|
| 57 |
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_platformdirs_1746710438/work
|
| 58 |
prompt_toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1744724089886/work
|
| 59 |
+
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1740663128538/work
|
|
|
|
|
|
|
| 60 |
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1733302279685/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl#sha256=92c32ff62b5fd8cf325bec5ab90d7be3d2a8ca8c8a3813ff487a8d2002630d1f
|
| 61 |
pure_eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1733569405015/work
|
| 62 |
+
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1750615794071/work
|
|
|
|
|
|
|
|
|
|
| 63 |
pyparsing==3.2.3
|
| 64 |
+
PySocks==1.7.1
|
| 65 |
+
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work
|
| 66 |
pytz==2025.2
|
| 67 |
PyYAML==6.0.2
|
| 68 |
+
pyzmq @ file:///home/conda/feedstock_root/build_artifacts/pyzmq_1749898457097/work
|
| 69 |
regex==2024.11.6
|
| 70 |
+
requests==2.32.4
|
| 71 |
safetensors==0.5.3
|
|
|
|
|
|
|
| 72 |
six @ file:///home/conda/feedstock_root/build_artifacts/six_1733380938961/work
|
| 73 |
+
soupsieve==2.7
|
| 74 |
stack_data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1733569443808/work
|
| 75 |
sympy==1.14.0
|
| 76 |
+
tokenizers==0.21.2
|
| 77 |
+
torch==2.7.1
|
| 78 |
+
torchsummary==1.5.1
|
| 79 |
+
torchvision==0.22.1
|
| 80 |
+
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1748003301700/work
|
| 81 |
tqdm==4.67.1
|
| 82 |
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1733367359838/work
|
| 83 |
+
transformers==4.53.2
|
| 84 |
+
triton==3.3.1
|
| 85 |
+
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1751643513/work
|
|
|
|
| 86 |
tzdata==2025.2
|
| 87 |
+
unzip==1.0.0
|
| 88 |
+
urllib3==2.5.0
|
| 89 |
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1733231326287/work
|
| 90 |
+
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1749421620841/work
|
|
|
|
|
|
|
|
|
sample_inference.ipynb
ADDED
|
@@ -0,0 +1,435 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "237f5cbf",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"Model already downloaded.\n"
|
| 14 |
+
]
|
| 15 |
+
}
|
| 16 |
+
],
|
| 17 |
+
"source": [
|
| 18 |
+
"# !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"# check if the model is downloaded, if not download it\n",
|
| 21 |
+
"import os\n",
|
| 22 |
+
"if not os.path.exists(\"sd-v1-5-inpainting.ckpt\"):\n",
|
| 23 |
+
" !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
|
| 24 |
+
"else:\n",
|
| 25 |
+
" print(\"Model already downloaded.\")"
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": null,
|
| 31 |
+
"id": "bab24c29",
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [
|
| 34 |
+
{
|
| 35 |
+
"name": "stderr",
|
| 36 |
+
"output_type": "stream",
|
| 37 |
+
"text": [
|
| 38 |
+
"/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 39 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 40 |
+
]
|
| 41 |
+
}
|
| 42 |
+
],
|
| 43 |
+
"source": [
|
| 44 |
+
"import inspect\n",
|
| 45 |
+
"import os\n",
|
| 46 |
+
"from typing import Union\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"import PIL\n",
|
| 49 |
+
"import numpy as np\n",
|
| 50 |
+
"import torch\n",
|
| 51 |
+
"import tqdm\n",
|
| 52 |
+
"from diffusers.utils.torch_utils import randn_tensor\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
|
| 55 |
+
" prepare_mask_image, compute_vae_encodings)\n",
|
| 56 |
+
"from ddpm import DDPMSampler\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"class CatVTONPipeline:\n",
|
| 59 |
+
" def __init__(\n",
|
| 60 |
+
" self, \n",
|
| 61 |
+
" weight_dtype=torch.float32,\n",
|
| 62 |
+
" device='cuda',\n",
|
| 63 |
+
" compile=False,\n",
|
| 64 |
+
" skip_safety_check=True,\n",
|
| 65 |
+
" use_tf32=True,\n",
|
| 66 |
+
" models={},\n",
|
| 67 |
+
" ):\n",
|
| 68 |
+
" self.device = device\n",
|
| 69 |
+
" self.weight_dtype = weight_dtype\n",
|
| 70 |
+
" self.skip_safety_check = skip_safety_check\n",
|
| 71 |
+
" self.models = models\n",
|
| 72 |
+
"\n",
|
| 73 |
+
" self.generator = torch.Generator(device=device)\n",
|
| 74 |
+
" self.noise_scheduler = DDPMSampler(generator=self.generator)\n",
|
| 75 |
+
" # self.vae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\").to(device, dtype=weight_dtype)\n",
|
| 76 |
+
" self.encoder= models.get('encoder', None)\n",
|
| 77 |
+
" self.decoder= models.get('decoder', None)\n",
|
| 78 |
+
" \n",
|
| 79 |
+
" self.unet=models.get('diffusion', None) \n",
|
| 80 |
+
" # # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).\n",
|
| 81 |
+
" if use_tf32:\n",
|
| 82 |
+
" torch.set_float32_matmul_precision(\"high\")\n",
|
| 83 |
+
" torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 84 |
+
"\n",
|
| 85 |
+
" @torch.no_grad()\n",
|
| 86 |
+
" def __call__(\n",
|
| 87 |
+
" self, \n",
|
| 88 |
+
" image: Union[PIL.Image.Image, torch.Tensor],\n",
|
| 89 |
+
" condition_image: Union[PIL.Image.Image, torch.Tensor],\n",
|
| 90 |
+
" mask: Union[PIL.Image.Image, torch.Tensor],\n",
|
| 91 |
+
" num_inference_steps: int = 50,\n",
|
| 92 |
+
" guidance_scale: float = 2.5,\n",
|
| 93 |
+
" height: int = 1024,\n",
|
| 94 |
+
" width: int = 768,\n",
|
| 95 |
+
" generator=None,\n",
|
| 96 |
+
" eta=1.0,\n",
|
| 97 |
+
" **kwargs\n",
|
| 98 |
+
" ):\n",
|
| 99 |
+
" concat_dim = -2 # FIXME: y axis concat\n",
|
| 100 |
+
" # Prepare inputs to Tensor\n",
|
| 101 |
+
" image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)\n",
|
| 102 |
+
" image = prepare_image(image).to(self.device, dtype=self.weight_dtype)\n",
|
| 103 |
+
" condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)\n",
|
| 104 |
+
" mask = prepare_mask_image(mask).to(self.device, dtype=self.weight_dtype)\n",
|
| 105 |
+
" # Mask image\n",
|
| 106 |
+
" masked_image = image * (mask < 0.5)\n",
|
| 107 |
+
" # VAE encoding\n",
|
| 108 |
+
" masked_latent = compute_vae_encodings(masked_image, self.encoder)\n",
|
| 109 |
+
" condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
|
| 110 |
+
" mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
|
| 111 |
+
" del image, mask, condition_image\n",
|
| 112 |
+
" # Concatenate latents\n",
|
| 113 |
+
" masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n",
|
| 114 |
+
" mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n",
|
| 115 |
+
" # Prepare noise\n",
|
| 116 |
+
" latents = randn_tensor(\n",
|
| 117 |
+
" masked_latent_concat.shape,\n",
|
| 118 |
+
" generator=generator,\n",
|
| 119 |
+
" device=masked_latent_concat.device,\n",
|
| 120 |
+
" dtype=self.weight_dtype,\n",
|
| 121 |
+
" )\n",
|
| 122 |
+
" # Prepare timesteps\n",
|
| 123 |
+
" self.noise_scheduler.set_inference_timesteps(num_inference_steps)\n",
|
| 124 |
+
" timesteps = self.noise_scheduler.timesteps\n",
|
| 125 |
+
" # latents = latents * self.noise_scheduler.init_noise_sigma\n",
|
| 126 |
+
" latents = self.noise_scheduler.add_noise(latents, timesteps[0])\n",
|
| 127 |
+
" \n",
|
| 128 |
+
" # Classifier-Free Guidance\n",
|
| 129 |
+
" if do_classifier_free_guidance := (guidance_scale > 1.0):\n",
|
| 130 |
+
" masked_latent_concat = torch.cat(\n",
|
| 131 |
+
" [\n",
|
| 132 |
+
" torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),\n",
|
| 133 |
+
" masked_latent_concat,\n",
|
| 134 |
+
" ]\n",
|
| 135 |
+
" )\n",
|
| 136 |
+
" mask_latent_concat = torch.cat([mask_latent_concat] * 2)\n",
|
| 137 |
+
"\n",
|
| 138 |
+
" num_warmup_steps = 0 # For simple DDPM, no warmup needed\n",
|
| 139 |
+
" with tqdm(total=num_inference_steps) as progress_bar:\n",
|
| 140 |
+
" for i, t in enumerate(timesteps):\n",
|
| 141 |
+
" # expand the latents if we are doing classifier free guidance\n",
|
| 142 |
+
" non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)\n",
|
| 143 |
+
" # non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(non_inpainting_latent_model_input, t)\n",
|
| 144 |
+
" # prepare the input for the inpainting model\n",
|
| 145 |
+
" inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1).to(self.device, dtype=self.weight_dtype)\n",
|
| 146 |
+
" # predict the noise residual\n",
|
| 147 |
+
" \n",
|
| 148 |
+
" timestep = t.repeat(inpainting_latent_model_input.shape[0])\n",
|
| 149 |
+
" time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)\n",
|
| 150 |
+
"\n",
|
| 151 |
+
" noise_pred = self.unet(\n",
|
| 152 |
+
" inpainting_latent_model_input,\n",
|
| 153 |
+
" time_embedding\n",
|
| 154 |
+
" )\n",
|
| 155 |
+
" # perform guidance\n",
|
| 156 |
+
" if do_classifier_free_guidance:\n",
|
| 157 |
+
" noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
|
| 158 |
+
" noise_pred = noise_pred_uncond + guidance_scale * (\n",
|
| 159 |
+
" noise_pred_text - noise_pred_uncond\n",
|
| 160 |
+
" )\n",
|
| 161 |
+
" # compute the previous noisy sample x_t -> x_t-1\n",
|
| 162 |
+
" latents = self.noise_scheduler.step(\n",
|
| 163 |
+
" t, latents, noise_pred\n",
|
| 164 |
+
" )\n",
|
| 165 |
+
" # call the callback, if provided\n",
|
| 166 |
+
" if i == len(timesteps) - 1 or (\n",
|
| 167 |
+
" (i + 1) > num_warmup_steps\n",
|
| 168 |
+
" ):\n",
|
| 169 |
+
" progress_bar.update()\n",
|
| 170 |
+
"\n",
|
| 171 |
+
" # Decode the final latents\n",
|
| 172 |
+
" latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]\n",
|
| 173 |
+
" # latents = 1 / self.vae.config.scaling_factor * latents\n",
|
| 174 |
+
" # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample\n",
|
| 175 |
+
" image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))\n",
|
| 176 |
+
" image = (image / 2 + 0.5).clamp(0, 1)\n",
|
| 177 |
+
" # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n",
|
| 178 |
+
" image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n",
|
| 179 |
+
" image = numpy_to_pil(image)\n",
|
| 180 |
+
" \n",
|
| 181 |
+
" return image\n"
|
| 182 |
+
]
|
| 183 |
+
},
|
| 184 |
+
{
|
| 185 |
+
"cell_type": "code",
|
| 186 |
+
"execution_count": 4,
|
| 187 |
+
"id": "a069151e",
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"outputs": [
|
| 190 |
+
{
|
| 191 |
+
"name": "stdout",
|
| 192 |
+
"output_type": "stream",
|
| 193 |
+
"text": [
|
| 194 |
+
"Loaded finetuned weights from finetuned_weights.safetensors\n",
|
| 195 |
+
"Loading 0.in_proj.weight\n",
|
| 196 |
+
"Loading 0.out_proj.weight\n",
|
| 197 |
+
"Loading 0.out_proj.bias\n",
|
| 198 |
+
"Loading 8.in_proj.weight\n",
|
| 199 |
+
"Loading 8.out_proj.weight\n",
|
| 200 |
+
"Loading 8.out_proj.bias\n",
|
| 201 |
+
"Loading 16.in_proj.weight\n",
|
| 202 |
+
"Loading 16.out_proj.weight\n",
|
| 203 |
+
"Loading 16.out_proj.bias\n",
|
| 204 |
+
"Loading 24.in_proj.weight\n",
|
| 205 |
+
"Loading 24.out_proj.weight\n",
|
| 206 |
+
"Loading 24.out_proj.bias\n",
|
| 207 |
+
"Loading 32.in_proj.weight\n",
|
| 208 |
+
"Loading 32.out_proj.weight\n",
|
| 209 |
+
"Loading 32.out_proj.bias\n",
|
| 210 |
+
"Loading 40.in_proj.weight\n",
|
| 211 |
+
"Loading 40.out_proj.weight\n",
|
| 212 |
+
"Loading 40.out_proj.bias\n",
|
| 213 |
+
"Loading 48.in_proj.weight\n",
|
| 214 |
+
"Loading 48.out_proj.weight\n",
|
| 215 |
+
"Loading 48.out_proj.bias\n",
|
| 216 |
+
"Loading 56.in_proj.weight\n",
|
| 217 |
+
"Loading 56.out_proj.weight\n",
|
| 218 |
+
"Loading 56.out_proj.bias\n",
|
| 219 |
+
"Loading 64.in_proj.weight\n",
|
| 220 |
+
"Loading 64.out_proj.weight\n",
|
| 221 |
+
"Loading 64.out_proj.bias\n",
|
| 222 |
+
"Loading 72.in_proj.weight\n",
|
| 223 |
+
"Loading 72.out_proj.weight\n",
|
| 224 |
+
"Loading 72.out_proj.bias\n",
|
| 225 |
+
"Loading 80.in_proj.weight\n",
|
| 226 |
+
"Loading 80.out_proj.weight\n",
|
| 227 |
+
"Loading 80.out_proj.bias\n",
|
| 228 |
+
"Loading 88.in_proj.weight\n",
|
| 229 |
+
"Loading 88.out_proj.weight\n",
|
| 230 |
+
"Loading 88.out_proj.bias\n",
|
| 231 |
+
"Loading 96.in_proj.weight\n",
|
| 232 |
+
"Loading 96.out_proj.weight\n",
|
| 233 |
+
"Loading 96.out_proj.bias\n",
|
| 234 |
+
"Loading 104.in_proj.weight\n",
|
| 235 |
+
"Loading 104.out_proj.weight\n",
|
| 236 |
+
"Loading 104.out_proj.bias\n",
|
| 237 |
+
"Loading 112.in_proj.weight\n",
|
| 238 |
+
"Loading 112.out_proj.weight\n",
|
| 239 |
+
"Loading 112.out_proj.bias\n",
|
| 240 |
+
"Loading 120.in_proj.weight\n",
|
| 241 |
+
"Loading 120.out_proj.weight\n",
|
| 242 |
+
"Loading 120.out_proj.bias\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"Attention module weights loaded from {finetune_weights_path} successfully.\n"
|
| 245 |
+
]
|
| 246 |
+
}
|
| 247 |
+
],
|
| 248 |
+
"source": [
|
| 249 |
+
"import load_model\n",
|
| 250 |
+
"\n",
|
| 251 |
+
"models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=\"finetuned_weights.safetensors\")"
|
| 252 |
+
]
|
| 253 |
+
},
|
| 254 |
+
{
|
| 255 |
+
"cell_type": "code",
|
| 256 |
+
"execution_count": 5,
|
| 257 |
+
"id": "a729bf46",
|
| 258 |
+
"metadata": {},
|
| 259 |
+
"outputs": [
|
| 260 |
+
{
|
| 261 |
+
"name": "stdout",
|
| 262 |
+
"output_type": "stream",
|
| 263 |
+
"text": [
|
| 264 |
+
"Dataset vitonhd loaded, total 20 pairs.\n"
|
| 265 |
+
]
|
| 266 |
+
},
|
| 267 |
+
{
|
| 268 |
+
"name": "stderr",
|
| 269 |
+
"output_type": "stream",
|
| 270 |
+
"text": [
|
| 271 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.39it/s]\n",
|
| 272 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.39it/s]\n",
|
| 273 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.42it/s]\n",
|
| 274 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.44it/s]\n",
|
| 275 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.44it/s]\n",
|
| 276 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.40it/s]\n",
|
| 277 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.43it/s]\n",
|
| 278 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.42it/s]\n",
|
| 279 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.41it/s]\n",
|
| 280 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.43it/s]\n",
|
| 281 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.43it/s]\n",
|
| 282 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.41it/s]\n",
|
| 283 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.40it/s]\n",
|
| 284 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.43it/s]\n",
|
| 285 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.29it/s]\n",
|
| 286 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.46it/s]\n",
|
| 287 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.45it/s]\n",
|
| 288 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.46it/s]\n",
|
| 289 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.47it/s]\n",
|
| 290 |
+
"100%|██████████| 50/50 [00:07<00:00, 6.45it/s]\n",
|
| 291 |
+
"100%|██████████| 20/20 [02:43<00:00, 8.15s/it]\n"
|
| 292 |
+
]
|
| 293 |
+
}
|
| 294 |
+
],
|
| 295 |
+
"source": [
|
| 296 |
+
"import os\n",
|
| 297 |
+
"import numpy as np\n",
|
| 298 |
+
"import torch\n",
|
| 299 |
+
"import argparse\n",
|
| 300 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 301 |
+
"from VITON_Dataset import VITONHDTestDataset\n",
|
| 302 |
+
"from diffusers.image_processor import VaeImageProcessor\n",
|
| 303 |
+
"from tqdm import tqdm\n",
|
| 304 |
+
"from PIL import Image, ImageFilter\n",
|
| 305 |
+
"\n",
|
| 306 |
+
"from utils import repaint, to_pil_image\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"@torch.no_grad()\n",
|
| 309 |
+
"def main():\n",
|
| 310 |
+
" args=argparse.Namespace()\n",
|
| 311 |
+
" args.__dict__= {\n",
|
| 312 |
+
" \"dataset_name\": \"vitonhd\",\n",
|
| 313 |
+
" \"data_root_path\": \"./sample_dataset\",\n",
|
| 314 |
+
" \"output_dir\": \"./trained_output\",\n",
|
| 315 |
+
" \"seed\": 555,\n",
|
| 316 |
+
" \"batch_size\": 1,\n",
|
| 317 |
+
" \"num_inference_steps\": 50,\n",
|
| 318 |
+
" \"guidance_scale\": 2.5,\n",
|
| 319 |
+
" \"width\": 384,\n",
|
| 320 |
+
" \"height\": 384,\n",
|
| 321 |
+
" \"repaint\": True,\n",
|
| 322 |
+
" \"eval_pair\": False,\n",
|
| 323 |
+
" \"concat_eval_results\": True,\n",
|
| 324 |
+
" \"allow_tf32\": True,\n",
|
| 325 |
+
" \"dataloader_num_workers\": 4,\n",
|
| 326 |
+
" \"mixed_precision\": 'no',\n",
|
| 327 |
+
" \"concat_axis\": 'y',\n",
|
| 328 |
+
" \"enable_condition_noise\": True,\n",
|
| 329 |
+
" \"is_train\": False\n",
|
| 330 |
+
" }\n",
|
| 331 |
+
"\n",
|
| 332 |
+
" # Pipeline\n",
|
| 333 |
+
" pipeline = CatVTONPipeline(\n",
|
| 334 |
+
" weight_dtype={\n",
|
| 335 |
+
" \"no\": torch.float32,\n",
|
| 336 |
+
" \"fp16\": torch.float16,\n",
|
| 337 |
+
" \"bf16\": torch.bfloat16,\n",
|
| 338 |
+
" }[args.mixed_precision],\n",
|
| 339 |
+
" device=\"cuda\",\n",
|
| 340 |
+
" skip_safety_check=True,\n",
|
| 341 |
+
" models=models,\n",
|
| 342 |
+
" )\n",
|
| 343 |
+
" # Dataset\n",
|
| 344 |
+
" if args.dataset_name == \"vitonhd\":\n",
|
| 345 |
+
" dataset = VITONHDTestDataset(args)\n",
|
| 346 |
+
" else:\n",
|
| 347 |
+
" raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n",
|
| 348 |
+
" print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n",
|
| 349 |
+
" dataloader = DataLoader(\n",
|
| 350 |
+
" dataset,\n",
|
| 351 |
+
" batch_size=args.batch_size,\n",
|
| 352 |
+
" shuffle=False,\n",
|
| 353 |
+
" num_workers=args.dataloader_num_workers\n",
|
| 354 |
+
" )\n",
|
| 355 |
+
" \n",
|
| 356 |
+
" # Inference\n",
|
| 357 |
+
" generator = torch.Generator(device='cuda').manual_seed(args.seed)\n",
|
| 358 |
+
" args.output_dir = os.path.join(args.output_dir, f\"{args.dataset_name}-{args.height}\", \"paired\" if args.eval_pair else \"unpaired\")\n",
|
| 359 |
+
" if not os.path.exists(args.output_dir):\n",
|
| 360 |
+
" os.makedirs(args.output_dir)\n",
|
| 361 |
+
" \n",
|
| 362 |
+
" for batch in tqdm(dataloader):\n",
|
| 363 |
+
" person_images = batch['person']\n",
|
| 364 |
+
" cloth_images = batch['cloth']\n",
|
| 365 |
+
" masks = batch['mask']\n",
|
| 366 |
+
"\n",
|
| 367 |
+
" results = pipeline(\n",
|
| 368 |
+
" person_images,\n",
|
| 369 |
+
" cloth_images,\n",
|
| 370 |
+
" masks,\n",
|
| 371 |
+
" num_inference_steps=args.num_inference_steps,\n",
|
| 372 |
+
" guidance_scale=args.guidance_scale,\n",
|
| 373 |
+
" height=args.height,\n",
|
| 374 |
+
" width=args.width,\n",
|
| 375 |
+
" generator=generator,\n",
|
| 376 |
+
" )\n",
|
| 377 |
+
" \n",
|
| 378 |
+
" if args.concat_eval_results or args.repaint:\n",
|
| 379 |
+
" person_images = to_pil_image(person_images)\n",
|
| 380 |
+
" cloth_images = to_pil_image(cloth_images)\n",
|
| 381 |
+
" masks = to_pil_image(masks)\n",
|
| 382 |
+
" for i, result in enumerate(results):\n",
|
| 383 |
+
" person_name = batch['person_name'][i]\n",
|
| 384 |
+
" output_path = os.path.join(args.output_dir, person_name)\n",
|
| 385 |
+
" if not os.path.exists(os.path.dirname(output_path)):\n",
|
| 386 |
+
" os.makedirs(os.path.dirname(output_path))\n",
|
| 387 |
+
" if args.repaint:\n",
|
| 388 |
+
" person_path, mask_path = dataset.data[batch['index'][i]]['person'], dataset.data[batch['index'][i]]['mask']\n",
|
| 389 |
+
" person_image= Image.open(person_path).resize(result.size, Image.LANCZOS)\n",
|
| 390 |
+
" mask = Image.open(mask_path).resize(result.size, Image.NEAREST)\n",
|
| 391 |
+
" result = repaint(person_image, mask, result)\n",
|
| 392 |
+
" if args.concat_eval_results:\n",
|
| 393 |
+
" w, h = result.size\n",
|
| 394 |
+
" concated_result = Image.new('RGB', (w*3, h))\n",
|
| 395 |
+
" concated_result.paste(person_images[i], (0, 0))\n",
|
| 396 |
+
" concated_result.paste(cloth_images[i], (w, 0)) \n",
|
| 397 |
+
" concated_result.paste(result, (w*2, 0))\n",
|
| 398 |
+
" result = concated_result\n",
|
| 399 |
+
" result.save(output_path)\n",
|
| 400 |
+
"\n",
|
| 401 |
+
"if __name__ == \"__main__\":\n",
|
| 402 |
+
" main()"
|
| 403 |
+
]
|
| 404 |
+
},
|
| 405 |
+
{
|
| 406 |
+
"cell_type": "code",
|
| 407 |
+
"execution_count": null,
|
| 408 |
+
"id": "55d88911",
|
| 409 |
+
"metadata": {},
|
| 410 |
+
"outputs": [],
|
| 411 |
+
"source": []
|
| 412 |
+
}
|
| 413 |
+
],
|
| 414 |
+
"metadata": {
|
| 415 |
+
"kernelspec": {
|
| 416 |
+
"display_name": "harsh",
|
| 417 |
+
"language": "python",
|
| 418 |
+
"name": "python3"
|
| 419 |
+
},
|
| 420 |
+
"language_info": {
|
| 421 |
+
"codemirror_mode": {
|
| 422 |
+
"name": "ipython",
|
| 423 |
+
"version": 3
|
| 424 |
+
},
|
| 425 |
+
"file_extension": ".py",
|
| 426 |
+
"mimetype": "text/x-python",
|
| 427 |
+
"name": "python",
|
| 428 |
+
"nbconvert_exporter": "python",
|
| 429 |
+
"pygments_lexer": "ipython3",
|
| 430 |
+
"version": "3.10.18"
|
| 431 |
+
}
|
| 432 |
+
},
|
| 433 |
+
"nbformat": 4,
|
| 434 |
+
"nbformat_minor": 5
|
| 435 |
+
}
|
test.ipynb
DELETED
|
@@ -1,1430 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": null,
|
| 6 |
-
"id": "6387c9e1",
|
| 7 |
-
"metadata": {},
|
| 8 |
-
"outputs": [],
|
| 9 |
-
"source": []
|
| 10 |
-
},
|
| 11 |
-
{
|
| 12 |
-
"cell_type": "code",
|
| 13 |
-
"execution_count": null,
|
| 14 |
-
"id": "ca9233f0",
|
| 15 |
-
"metadata": {},
|
| 16 |
-
"outputs": [
|
| 17 |
-
{
|
| 18 |
-
"data": {
|
| 19 |
-
"text/plain": [
|
| 20 |
-
"'/kaggle/working'"
|
| 21 |
-
]
|
| 22 |
-
},
|
| 23 |
-
"execution_count": 16,
|
| 24 |
-
"metadata": {},
|
| 25 |
-
"output_type": "execute_result"
|
| 26 |
-
}
|
| 27 |
-
],
|
| 28 |
-
"source": []
|
| 29 |
-
},
|
| 30 |
-
{
|
| 31 |
-
"cell_type": "code",
|
| 32 |
-
"execution_count": 1,
|
| 33 |
-
"id": "3d2f98af",
|
| 34 |
-
"metadata": {},
|
| 35 |
-
"outputs": [
|
| 36 |
-
{
|
| 37 |
-
"name": "stdout",
|
| 38 |
-
"output_type": "stream",
|
| 39 |
-
"text": [
|
| 40 |
-
"\u001b[0m\u001b[01;34mtest\u001b[0m/ test_pairs.txt \u001b[01;34mtrain\u001b[0m/ train_pairs.txt\n"
|
| 41 |
-
]
|
| 42 |
-
}
|
| 43 |
-
],
|
| 44 |
-
"source": [
|
| 45 |
-
"ls /kaggle/input/viton-hd-dataset"
|
| 46 |
-
]
|
| 47 |
-
},
|
| 48 |
-
{
|
| 49 |
-
"cell_type": "code",
|
| 50 |
-
"execution_count": 2,
|
| 51 |
-
"id": "dc0f36f4",
|
| 52 |
-
"metadata": {},
|
| 53 |
-
"outputs": [
|
| 54 |
-
{
|
| 55 |
-
"name": "stdout",
|
| 56 |
-
"output_type": "stream",
|
| 57 |
-
"text": [
|
| 58 |
-
"Cloning into 'stable-diffusion'...\n",
|
| 59 |
-
"remote: Enumerating objects: 156, done.\u001b[K\n",
|
| 60 |
-
"remote: Counting objects: 100% (156/156), done.\u001b[K\n",
|
| 61 |
-
"remote: Compressing objects: 100% (129/129), done.\u001b[K\n",
|
| 62 |
-
"remote: Total 156 (delta 41), reused 141 (delta 27), pack-reused 0 (from 0)\u001b[K\n",
|
| 63 |
-
"Receiving objects: 100% (156/156), 9.12 MiB | 41.53 MiB/s, done.\n",
|
| 64 |
-
"Resolving deltas: 100% (41/41), done.\n"
|
| 65 |
-
]
|
| 66 |
-
}
|
| 67 |
-
],
|
| 68 |
-
"source": [
|
| 69 |
-
"!git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git"
|
| 70 |
-
]
|
| 71 |
-
},
|
| 72 |
-
{
|
| 73 |
-
"cell_type": "code",
|
| 74 |
-
"execution_count": 3,
|
| 75 |
-
"id": "a0bf01ab",
|
| 76 |
-
"metadata": {},
|
| 77 |
-
"outputs": [
|
| 78 |
-
{
|
| 79 |
-
"name": "stdout",
|
| 80 |
-
"output_type": "stream",
|
| 81 |
-
"text": [
|
| 82 |
-
"/kaggle/working/stable-diffusion\n"
|
| 83 |
-
]
|
| 84 |
-
}
|
| 85 |
-
],
|
| 86 |
-
"source": [
|
| 87 |
-
"cd stable-diffusion/"
|
| 88 |
-
]
|
| 89 |
-
},
|
| 90 |
-
{
|
| 91 |
-
"cell_type": "code",
|
| 92 |
-
"execution_count": 4,
|
| 93 |
-
"id": "1401cd56",
|
| 94 |
-
"metadata": {},
|
| 95 |
-
"outputs": [
|
| 96 |
-
{
|
| 97 |
-
"name": "stdout",
|
| 98 |
-
"output_type": "stream",
|
| 99 |
-
"text": [
|
| 100 |
-
"--2025-06-15 18:33:59-- https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
|
| 101 |
-
"Resolving huggingface.co (huggingface.co)... 3.171.171.128, 3.171.171.6, 3.171.171.104, ...\n",
|
| 102 |
-
"Connecting to huggingface.co (huggingface.co)|3.171.171.128|:443... connected.\n",
|
| 103 |
-
"HTTP request sent, awaiting response... 307 Temporary Redirect\n",
|
| 104 |
-
"Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n",
|
| 105 |
-
"--2025-06-15 18:33:59-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
|
| 106 |
-
"Reusing existing connection to huggingface.co:443.\n",
|
| 107 |
-
"HTTP request sent, awaiting response... 302 Found\n",
|
| 108 |
-
"Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750014781&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDAxNDc4MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=J6qT-n9PY34qz09a9caWcpc8-GaaGi%7EUu6AemCTMk48YsJgF9fjibpdUC-yTeIAJxbF4KxuFDt-5T6tXkQXgDaNakqUiTcxxJKpVNQYG9MlJ%7E3xeXE-WfBpwE9BbXkksCDStzHYqWV5ni5q0t2gPUqfwbmEFdfvZbQPol1oKH1ldWgCa3XusvR%7EUfdcxtci8gCgLXIrbNu7AG2lepj0AqpxkO5hsIBIhqUOTDXG7okdVLhepoAwnmJkc4neFV5LcR1Tt70My-1jdSFExn6c3yMLmWprMm3UMv6h5MyMifZWw4RdBrBWDjm0TPDwVMuwhgKiT6F9WTnUZvl1F0KKXFQ__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
|
| 109 |
-
"--2025-06-15 18:33:59-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750014781&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDAxNDc4MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=J6qT-n9PY34qz09a9caWcpc8-GaaGi%7EUu6AemCTMk48YsJgF9fjibpdUC-yTeIAJxbF4KxuFDt-5T6tXkQXgDaNakqUiTcxxJKpVNQYG9MlJ%7E3xeXE-WfBpwE9BbXkksCDStzHYqWV5ni5q0t2gPUqfwbmEFdfvZbQPol1oKH1ldWgCa3XusvR%7EUfdcxtci8gCgLXIrbNu7AG2lepj0AqpxkO5hsIBIhqUOTDXG7okdVLhepoAwnmJkc4neFV5LcR1Tt70My-1jdSFExn6c3yMLmWprMm3UMv6h5MyMifZWw4RdBrBWDjm0TPDwVMuwhgKiT6F9WTnUZvl1F0KKXFQ__&Key-Pair-Id=K3RPWS32NSSJCE\n",
|
| 110 |
-
"Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.160.78.87, 18.160.78.43, 18.160.78.83, ...\n",
|
| 111 |
-
"Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.160.78.87|:443... connected.\n",
|
| 112 |
-
"HTTP request sent, awaiting response... 200 OK\n",
|
| 113 |
-
"Length: 4265437280 (4.0G) [binary/octet-stream]\n",
|
| 114 |
-
"Saving to: ‘sd-v1-5-inpainting.ckpt’\n",
|
| 115 |
-
"\n",
|
| 116 |
-
"sd-v1-5-inpainting. 100%[===================>] 3.97G 306MB/s in 13s \n",
|
| 117 |
-
"\n",
|
| 118 |
-
"2025-06-15 18:34:13 (302 MB/s) - ‘sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n",
|
| 119 |
-
"\n"
|
| 120 |
-
]
|
| 121 |
-
}
|
| 122 |
-
],
|
| 123 |
-
"source": [
|
| 124 |
-
"!wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt"
|
| 125 |
-
]
|
| 126 |
-
},
|
| 127 |
-
{
|
| 128 |
-
"cell_type": "code",
|
| 129 |
-
"execution_count": null,
|
| 130 |
-
"id": "f7450c55",
|
| 131 |
-
"metadata": {},
|
| 132 |
-
"outputs": [
|
| 133 |
-
{
|
| 134 |
-
"name": "stdout",
|
| 135 |
-
"output_type": "stream",
|
| 136 |
-
"text": [
|
| 137 |
-
"--2025-06-11 10:33:19-- https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true\n",
|
| 138 |
-
"Resolving huggingface.co (huggingface.co)... 3.163.189.37, 3.163.189.90, 3.163.189.114, ...\n",
|
| 139 |
-
"Connecting to huggingface.co (huggingface.co)|3.163.189.37|:443... connected.\n",
|
| 140 |
-
"HTTP request sent, awaiting response... 302 Found\n",
|
| 141 |
-
"Location: https://cdn-lfs-us-1.hf.co/repos/49/48/4948d897acaa287a14cc261fb60bfdb3ff0e6571ca16a0b5fa38cec3cfebdc34/915df7bf19a33bee36a28d5f9ceaef1e2267c47526f98ca9e4c49e90ae5f0fd0?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1749641599&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTY0MTU5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzQ5LzQ4LzQ5NDhkODk3YWNhYTI4N2ExNGNjMjYxZmI2MGJmZGIzZmYwZTY1NzFjYTE2YTBiNWZhMzhjZWMzY2ZlYmRjMzQvOTE1ZGY3YmYxOWEzM2JlZTM2YTI4ZDVmOWNlYWVmMWUyMjY3YzQ3NTI2Zjk4Y2E5ZTRjNDllOTBhZTVmMGZkMD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=iN3Lw7GVk22rlaKenmmcr3VTvG2wC9AFWTNHUmdS8DOVyKF2fUSnjW3QnGTm6P15luwwy2xs-43aiE22hmdjFm9AOV9v67mBvhUe3Gjp9k2DC-KIY%7ES6YuRPUUMLHSriK2bN6GfVpl6e-XN%7Ew6mEHiyUah9plAkKGidYjfaUXrODQr34siqAmTjDDD8wRyHAbLFiCMB-zUbllG4YjEO-rJkilkVtUEriayspO1uEKe%7EtAjW27n5Te68FqKTX%7Etj77fPDKGNV4p%7EUIvRtPx4jdtb1Mll7ga5C-YMwpNCKDX4bvWDMrnf2NNs9EIouNdjMZdBpPHUH2EpQGfEASUX0eg__&Key-Pair-Id=K24J24Z295AEI9 [following]\n",
|
| 142 |
-
"--2025-06-11 10:33:19-- https://cdn-lfs-us-1.hf.co/repos/49/48/4948d897acaa287a14cc261fb60bfdb3ff0e6571ca16a0b5fa38cec3cfebdc34/915df7bf19a33bee36a28d5f9ceaef1e2267c47526f98ca9e4c49e90ae5f0fd0?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model.safetensors%3B+filename%3D%22model.safetensors%22%3B&Expires=1749641599&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0OTY0MTU5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzQ5LzQ4LzQ5NDhkODk3YWNhYTI4N2ExNGNjMjYxZmI2MGJmZGIzZmYwZTY1NzFjYTE2YTBiNWZhMzhjZWMzY2ZlYmRjMzQvOTE1ZGY3YmYxOWEzM2JlZTM2YTI4ZDVmOWNlYWVmMWUyMjY3YzQ3NTI2Zjk4Y2E5ZTRjNDllOTBhZTVmMGZkMD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=iN3Lw7GVk22rlaKenmmcr3VTvG2wC9AFWTNHUmdS8DOVyKF2fUSnjW3QnGTm6P15luwwy2xs-43aiE22hmdjFm9AOV9v67mBvhUe3Gjp9k2DC-KIY%7ES6YuRPUUMLHSriK2bN6GfVpl6e-XN%7Ew6mEHiyUah9plAkKGidYjfaUXrODQr34siqAmTjDDD8wRyHAbLFiCMB-zUbllG4YjEO-rJkilkVtUEriayspO1uEKe%7EtAjW27n5Te68FqKTX%7Etj77fPDKGNV4p%7EUIvRtPx4jdtb1Mll7ga5C-YMwpNCKDX4bvWDMrnf2NNs9EIouNdjMZdBpPHUH2EpQGfEASUX0eg__&Key-Pair-Id=K24J24Z295AEI9\n",
|
| 143 |
-
"Resolving cdn-lfs-us-1.hf.co (cdn-lfs-us-1.hf.co)... 18.238.238.75, 18.238.238.106, 18.238.238.119, ...\n",
|
| 144 |
-
"Connecting to cdn-lfs-us-1.hf.co (cdn-lfs-us-1.hf.co)|18.238.238.75|:443... connected.\n",
|
| 145 |
-
"HTTP request sent, awaiting response... 200 OK\n",
|
| 146 |
-
"Length: 198303368 (189M) [binary/octet-stream]\n",
|
| 147 |
-
"Saving to: ‘model.safetensors?download=true’\n",
|
| 148 |
-
"\n",
|
| 149 |
-
"model.safetensors?d 100%[===================>] 189.12M 298MB/s in 0.6s \n",
|
| 150 |
-
"\n",
|
| 151 |
-
"2025-06-11 10:33:20 (298 MB/s) - ‘model.safetensors?download=true’ saved [198303368/198303368]\n",
|
| 152 |
-
"\n"
|
| 153 |
-
]
|
| 154 |
-
}
|
| 155 |
-
],
|
| 156 |
-
"source": [
|
| 157 |
-
"# !wget https://huggingface.co/zhengchong/CatVTON/resolve/main/vitonhd-16k-512/attention/model.safetensors?download=true "
|
| 158 |
-
]
|
| 159 |
-
},
|
| 160 |
-
{
|
| 161 |
-
"cell_type": "code",
|
| 162 |
-
"execution_count": null,
|
| 163 |
-
"id": "ca20c487",
|
| 164 |
-
"metadata": {},
|
| 165 |
-
"outputs": [],
|
| 166 |
-
"source": [
|
| 167 |
-
"# mv 'model.safetensors?download=true' model.safetensors"
|
| 168 |
-
]
|
| 169 |
-
},
|
| 170 |
-
{
|
| 171 |
-
"cell_type": "code",
|
| 172 |
-
"execution_count": 12,
|
| 173 |
-
"id": "6d0a1287",
|
| 174 |
-
"metadata": {},
|
| 175 |
-
"outputs": [
|
| 176 |
-
{
|
| 177 |
-
"name": "stdout",
|
| 178 |
-
"output_type": "stream",
|
| 179 |
-
"text": [
|
| 180 |
-
"attention.py encoder.py\t model.safetensors sd-v1-5-inpainting.ckpt\n",
|
| 181 |
-
"clip.py interface.py\t pipeline.py\t test.ipynb\n",
|
| 182 |
-
"ddpm.py merges.txt\t README.md\t vocab.json\n",
|
| 183 |
-
"decoder.py model_converter.py requirements.txt\n",
|
| 184 |
-
"diffusion.py model.py\t\t sample_dataset\n"
|
| 185 |
-
]
|
| 186 |
-
}
|
| 187 |
-
],
|
| 188 |
-
"source": [
|
| 189 |
-
"!ls"
|
| 190 |
-
]
|
| 191 |
-
},
|
| 192 |
-
{
|
| 193 |
-
"cell_type": "code",
|
| 194 |
-
"execution_count": 14,
|
| 195 |
-
"id": "8f11470e",
|
| 196 |
-
"metadata": {},
|
| 197 |
-
"outputs": [
|
| 198 |
-
{
|
| 199 |
-
"name": "stdout",
|
| 200 |
-
"output_type": "stream",
|
| 201 |
-
"text": [
|
| 202 |
-
"/kaggle/working/stable-diffusion/CatVTON\n"
|
| 203 |
-
]
|
| 204 |
-
}
|
| 205 |
-
],
|
| 206 |
-
"source": [
|
| 207 |
-
"cd .."
|
| 208 |
-
]
|
| 209 |
-
},
|
| 210 |
-
{
|
| 211 |
-
"cell_type": "code",
|
| 212 |
-
"execution_count": 15,
|
| 213 |
-
"id": "cb794cb3",
|
| 214 |
-
"metadata": {},
|
| 215 |
-
"outputs": [
|
| 216 |
-
{
|
| 217 |
-
"name": "stdout",
|
| 218 |
-
"output_type": "stream",
|
| 219 |
-
"text": [
|
| 220 |
-
"app_flux.py eval.py preprocess_agnostic_mask.py \u001b[0m\u001b[01;34mstable-diffusion\u001b[0m/\n",
|
| 221 |
-
"app_p2p.py index.html \u001b[01;34m__pycache__\u001b[0m/ utils.py\n",
|
| 222 |
-
"app.py inference.py README.md\n",
|
| 223 |
-
"\u001b[01;34mdensepose\u001b[0m/ LICENSE requirements.txt\n",
|
| 224 |
-
"\u001b[01;34mdetectron2\u001b[0m/ \u001b[01;34mmodel\u001b[0m/ \u001b[01;34mresource\u001b[0m/\n"
|
| 225 |
-
]
|
| 226 |
-
}
|
| 227 |
-
],
|
| 228 |
-
"source": [
|
| 229 |
-
"ls"
|
| 230 |
-
]
|
| 231 |
-
},
|
| 232 |
-
{
|
| 233 |
-
"cell_type": "code",
|
| 234 |
-
"execution_count": 16,
|
| 235 |
-
"id": "b6af145b",
|
| 236 |
-
"metadata": {},
|
| 237 |
-
"outputs": [],
|
| 238 |
-
"source": [
|
| 239 |
-
"import os\n",
|
| 240 |
-
"import shutil\n",
|
| 241 |
-
"\n",
|
| 242 |
-
"src_dir = \"./stable-diffusion\"\n",
|
| 243 |
-
"dst_dir = \".\"\n",
|
| 244 |
-
"\n",
|
| 245 |
-
"for filename in os.listdir(src_dir):\n",
|
| 246 |
-
" src_path = os.path.join(src_dir, filename)\n",
|
| 247 |
-
" dst_path = os.path.join(dst_dir, filename)\n",
|
| 248 |
-
" if os.path.isfile(src_path):\n",
|
| 249 |
-
" shutil.move(src_path, dst_path)\n",
|
| 250 |
-
" elif os.path.isdir(src_path):\n",
|
| 251 |
-
" shutil.move(src_path, dst_path)"
|
| 252 |
-
]
|
| 253 |
-
},
|
| 254 |
-
{
|
| 255 |
-
"cell_type": "code",
|
| 256 |
-
"execution_count": null,
|
| 257 |
-
"id": "63ee438c",
|
| 258 |
-
"metadata": {},
|
| 259 |
-
"outputs": [],
|
| 260 |
-
"source": []
|
| 261 |
-
},
|
| 262 |
-
{
|
| 263 |
-
"cell_type": "code",
|
| 264 |
-
"execution_count": null,
|
| 265 |
-
"id": "60598bd3",
|
| 266 |
-
"metadata": {},
|
| 267 |
-
"outputs": [],
|
| 268 |
-
"source": []
|
| 269 |
-
},
|
| 270 |
-
{
|
| 271 |
-
"cell_type": "code",
|
| 272 |
-
"execution_count": 229,
|
| 273 |
-
"id": "192a649c",
|
| 274 |
-
"metadata": {},
|
| 275 |
-
"outputs": [],
|
| 276 |
-
"source": [
|
| 277 |
-
"import torch\n",
|
| 278 |
-
"import gc\n",
|
| 279 |
-
"\n",
|
| 280 |
-
"# Clear CUDA cache and collect garbage\n",
|
| 281 |
-
"torch.cuda.empty_cache()\n",
|
| 282 |
-
"gc.collect()\n",
|
| 283 |
-
"\n",
|
| 284 |
-
"# Delete all user-defined variables except for built-ins and modules\n",
|
| 285 |
-
"for var in list(globals()):\n",
|
| 286 |
-
" if not var.startswith(\"__\") and var not in [\"torch\", \"gc\"]:\n",
|
| 287 |
-
" del globals()[var]\n",
|
| 288 |
-
"\n",
|
| 289 |
-
"gc.collect()\n",
|
| 290 |
-
"torch.cuda.empty_cache()"
|
| 291 |
-
]
|
| 292 |
-
},
|
| 293 |
-
{
|
| 294 |
-
"cell_type": "code",
|
| 295 |
-
"execution_count": 245,
|
| 296 |
-
"id": "a3a4a5dc",
|
| 297 |
-
"metadata": {},
|
| 298 |
-
"outputs": [],
|
| 299 |
-
"source": [
|
| 300 |
-
"import torch\n",
|
| 301 |
-
"import gc\n",
|
| 302 |
-
"\n",
|
| 303 |
-
"# Clear CUDA cache and collect garbage\n",
|
| 304 |
-
"torch.cuda.empty_cache()\n",
|
| 305 |
-
"gc.collect()\n",
|
| 306 |
-
"\n",
|
| 307 |
-
"# Delete all user-defined variables except for built-ins and modules\n",
|
| 308 |
-
"for var_name in list(globals()):\n",
|
| 309 |
-
" if not var_name.startswith(\"__\") and var_name not in [\"torch\", \"gc\"]:\n",
|
| 310 |
-
" del globals()[var_name]\n",
|
| 311 |
-
"\n",
|
| 312 |
-
"gc.collect()\n",
|
| 313 |
-
"torch.cuda.empty_cache()\n",
|
| 314 |
-
"\n",
|
| 315 |
-
"import tensorflow as tf\n",
|
| 316 |
-
"tf.keras.backend.clear_session()"
|
| 317 |
-
]
|
| 318 |
-
},
|
| 319 |
-
{
|
| 320 |
-
"cell_type": "code",
|
| 321 |
-
"execution_count": 4,
|
| 322 |
-
"id": "91ef7a4e",
|
| 323 |
-
"metadata": {},
|
| 324 |
-
"outputs": [
|
| 325 |
-
{
|
| 326 |
-
"data": {
|
| 327 |
-
"text/plain": [
|
| 328 |
-
"0"
|
| 329 |
-
]
|
| 330 |
-
},
|
| 331 |
-
"execution_count": 4,
|
| 332 |
-
"metadata": {},
|
| 333 |
-
"output_type": "execute_result"
|
| 334 |
-
}
|
| 335 |
-
],
|
| 336 |
-
"source": [
|
| 337 |
-
"import torch\n",
|
| 338 |
-
"import gc\n",
|
| 339 |
-
"\n",
|
| 340 |
-
"torch.cuda.empty_cache() # Release unused GPU memory\n",
|
| 341 |
-
"gc.collect() # Run Python garbage collector"
|
| 342 |
-
]
|
| 343 |
-
},
|
| 344 |
-
{
|
| 345 |
-
"cell_type": "code",
|
| 346 |
-
"execution_count": 9,
|
| 347 |
-
"id": "08f29055",
|
| 348 |
-
"metadata": {},
|
| 349 |
-
"outputs": [
|
| 350 |
-
{
|
| 351 |
-
"name": "stdout",
|
| 352 |
-
"output_type": "stream",
|
| 353 |
-
"text": [
|
| 354 |
-
"GPU memory used: 0.00 MB / 16269.25 MB\n"
|
| 355 |
-
]
|
| 356 |
-
}
|
| 357 |
-
],
|
| 358 |
-
"source": [
|
| 359 |
-
"import torch\n",
|
| 360 |
-
"\n",
|
| 361 |
-
"if torch.cuda.is_available():\n",
|
| 362 |
-
" used = torch.cuda.memory_allocated() / 1024 ** 2 # in MB\n",
|
| 363 |
-
" total = torch.cuda.get_device_properties(0).total_memory / 1024 ** 2 # in MB\n",
|
| 364 |
-
" print(f\"GPU memory used: {used:.2f} MB / {total:.2f} MB\")\n",
|
| 365 |
-
"else:\n",
|
| 366 |
-
" print(\"CUDA is not available.\")"
|
| 367 |
-
]
|
| 368 |
-
},
|
| 369 |
-
{
|
| 370 |
-
"cell_type": "code",
|
| 371 |
-
"execution_count": 197,
|
| 372 |
-
"id": "6fbde810",
|
| 373 |
-
"metadata": {},
|
| 374 |
-
"outputs": [],
|
| 375 |
-
"source": [
|
| 376 |
-
"# rm -rf output"
|
| 377 |
-
]
|
| 378 |
-
},
|
| 379 |
-
{
|
| 380 |
-
"cell_type": "code",
|
| 381 |
-
"execution_count": 5,
|
| 382 |
-
"id": "37335c1e",
|
| 383 |
-
"metadata": {},
|
| 384 |
-
"outputs": [],
|
| 385 |
-
"source": [
|
| 386 |
-
"def compute_vae_encodings(image_tensor, encoder, device=\"cuda\"):\n",
|
| 387 |
-
" \"\"\"Encode image using VAE encoder\"\"\"\n",
|
| 388 |
-
" # Generate random noise for encoding\n",
|
| 389 |
-
" encoder_noise = torch.randn(\n",
|
| 390 |
-
" (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),\n",
|
| 391 |
-
" device=device,\n",
|
| 392 |
-
" )\n",
|
| 393 |
-
" \n",
|
| 394 |
-
" # Encode using your custom encoder\n",
|
| 395 |
-
" latent = encoder(image_tensor, encoder_noise)\n",
|
| 396 |
-
" return latent"
|
| 397 |
-
]
|
| 398 |
-
},
|
| 399 |
-
{
|
| 400 |
-
"cell_type": "code",
|
| 401 |
-
"execution_count": 6,
|
| 402 |
-
"id": "35d98b83",
|
| 403 |
-
"metadata": {},
|
| 404 |
-
"outputs": [],
|
| 405 |
-
"source": [
|
| 406 |
-
"def get_trainable_module(unet, trainable_module_name):\n",
|
| 407 |
-
" if trainable_module_name == \"unet\":\n",
|
| 408 |
-
" return unet\n",
|
| 409 |
-
" elif trainable_module_name == \"transformer\":\n",
|
| 410 |
-
" trainable_modules = torch.nn.ModuleList()\n",
|
| 411 |
-
" for blocks in [unet.encoders, unet.bottleneck, unet.decoders]:\n",
|
| 412 |
-
" if hasattr(blocks, \"attentions\"):\n",
|
| 413 |
-
" trainable_modules.append(blocks.attentions)\n",
|
| 414 |
-
" else:\n",
|
| 415 |
-
" for block in blocks:\n",
|
| 416 |
-
" if hasattr(block, \"attentions\"):\n",
|
| 417 |
-
" trainable_modules.append(block.attentions)\n",
|
| 418 |
-
" return trainable_modules\n",
|
| 419 |
-
" elif trainable_module_name == \"attention\":\n",
|
| 420 |
-
" attn_blocks = torch.nn.ModuleList()\n",
|
| 421 |
-
" for name, param in unet.named_modules():\n",
|
| 422 |
-
" if \"attention_1\" in name:\n",
|
| 423 |
-
" attn_blocks.append(param)\n",
|
| 424 |
-
" return attn_blocks\n",
|
| 425 |
-
" else:\n",
|
| 426 |
-
" raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")"
|
| 427 |
-
]
|
| 428 |
-
},
|
| 429 |
-
{
|
| 430 |
-
"cell_type": "code",
|
| 431 |
-
"execution_count": 7,
|
| 432 |
-
"id": "d7ff094a",
|
| 433 |
-
"metadata": {},
|
| 434 |
-
"outputs": [],
|
| 435 |
-
"source": [
|
| 436 |
-
"from torch.nn import functional as F\n",
|
| 437 |
-
"import torch\n",
|
| 438 |
-
"# from flash_attn import flash_attn_func\n",
|
| 439 |
-
"\n",
|
| 440 |
-
"class SkipAttnProcessor(torch.nn.Module):\n",
|
| 441 |
-
" def __init__(self, *args, **kwargs) -> None:\n",
|
| 442 |
-
" super().__init__()\n",
|
| 443 |
-
"\n",
|
| 444 |
-
" def __call__(\n",
|
| 445 |
-
" self,\n",
|
| 446 |
-
" attn,\n",
|
| 447 |
-
" hidden_states,\n",
|
| 448 |
-
" encoder_hidden_states=None,\n",
|
| 449 |
-
" attention_mask=None,\n",
|
| 450 |
-
" temb=None,\n",
|
| 451 |
-
" ):\n",
|
| 452 |
-
" return hidden_states\n",
|
| 453 |
-
"\n",
|
| 454 |
-
"class AttnProcessor2_0(torch.nn.Module):\n",
|
| 455 |
-
" r\"\"\"\n",
|
| 456 |
-
" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n",
|
| 457 |
-
" \"\"\"\n",
|
| 458 |
-
"\n",
|
| 459 |
-
" def __init__(\n",
|
| 460 |
-
" self,\n",
|
| 461 |
-
" hidden_size=None,\n",
|
| 462 |
-
" cross_attention_dim=None,\n",
|
| 463 |
-
" **kwargs\n",
|
| 464 |
-
" ):\n",
|
| 465 |
-
" super().__init__()\n",
|
| 466 |
-
" if not hasattr(F, \"scaled_dot_product_attention\"):\n",
|
| 467 |
-
" raise ImportError(\"AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n",
|
| 468 |
-
"\n",
|
| 469 |
-
" def __call__(\n",
|
| 470 |
-
" self,\n",
|
| 471 |
-
" attn,\n",
|
| 472 |
-
" hidden_states,\n",
|
| 473 |
-
" encoder_hidden_states=None,\n",
|
| 474 |
-
" attention_mask=None,\n",
|
| 475 |
-
" temb=None,\n",
|
| 476 |
-
" *args,\n",
|
| 477 |
-
" **kwargs,\n",
|
| 478 |
-
" ):\n",
|
| 479 |
-
" residual = hidden_states\n",
|
| 480 |
-
"\n",
|
| 481 |
-
" if attn.spatial_norm is not None:\n",
|
| 482 |
-
" hidden_states = attn.spatial_norm(hidden_states, temb)\n",
|
| 483 |
-
"\n",
|
| 484 |
-
" input_ndim = hidden_states.ndim\n",
|
| 485 |
-
"\n",
|
| 486 |
-
" if input_ndim == 4:\n",
|
| 487 |
-
" batch_size, channel, height, width = hidden_states.shape\n",
|
| 488 |
-
" hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n",
|
| 489 |
-
"\n",
|
| 490 |
-
" batch_size, sequence_length, _ = (\n",
|
| 491 |
-
" hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n",
|
| 492 |
-
" )\n",
|
| 493 |
-
"\n",
|
| 494 |
-
" if attention_mask is not None:\n",
|
| 495 |
-
" attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n",
|
| 496 |
-
" # scaled_dot_product_attention expects attention_mask shape to be\n",
|
| 497 |
-
" # (batch, heads, source_length, target_length)\n",
|
| 498 |
-
" attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n",
|
| 499 |
-
"\n",
|
| 500 |
-
" if attn.group_norm is not None:\n",
|
| 501 |
-
" hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n",
|
| 502 |
-
"\n",
|
| 503 |
-
" query = attn.to_q(hidden_states)\n",
|
| 504 |
-
"\n",
|
| 505 |
-
" if encoder_hidden_states is None:\n",
|
| 506 |
-
" encoder_hidden_states = hidden_states\n",
|
| 507 |
-
" elif attn.norm_cross:\n",
|
| 508 |
-
" encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n",
|
| 509 |
-
"\n",
|
| 510 |
-
" key = attn.to_k(encoder_hidden_states)\n",
|
| 511 |
-
" value = attn.to_v(encoder_hidden_states)\n",
|
| 512 |
-
"\n",
|
| 513 |
-
" inner_dim = key.shape[-1]\n",
|
| 514 |
-
" head_dim = inner_dim // attn.heads\n",
|
| 515 |
-
"\n",
|
| 516 |
-
" query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
|
| 517 |
-
"\n",
|
| 518 |
-
" key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
|
| 519 |
-
" value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n",
|
| 520 |
-
"\n",
|
| 521 |
-
" # the output of sdp = (batch, num_heads, seq_len, head_dim)\n",
|
| 522 |
-
" # TODO: add support for attn.scale when we move to Torch 2.1\n",
|
| 523 |
-
" \n",
|
| 524 |
-
" hidden_states = F.scaled_dot_product_attention(\n",
|
| 525 |
-
" query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n",
|
| 526 |
-
" )\n",
|
| 527 |
-
" # hidden_states = flash_attn_func(\n",
|
| 528 |
-
" # query, key, value, dropout_p=0.0, causal=False\n",
|
| 529 |
-
" # )\n",
|
| 530 |
-
"\n",
|
| 531 |
-
" hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n",
|
| 532 |
-
" hidden_states = hidden_states.to(query.dtype)\n",
|
| 533 |
-
"\n",
|
| 534 |
-
" # linear proj\n",
|
| 535 |
-
" hidden_states = attn.to_out[0](hidden_states)\n",
|
| 536 |
-
" # dropout\n",
|
| 537 |
-
" hidden_states = attn.to_out[1](hidden_states)\n",
|
| 538 |
-
"\n",
|
| 539 |
-
" if input_ndim == 4:\n",
|
| 540 |
-
" hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n",
|
| 541 |
-
"\n",
|
| 542 |
-
" if attn.residual_connection:\n",
|
| 543 |
-
" hidden_states = hidden_states + residual\n",
|
| 544 |
-
"\n",
|
| 545 |
-
" hidden_states = hidden_states / attn.rescale_output_factor\n",
|
| 546 |
-
"\n",
|
| 547 |
-
" return hidden_states\n",
|
| 548 |
-
" "
|
| 549 |
-
]
|
| 550 |
-
},
|
| 551 |
-
{
|
| 552 |
-
"cell_type": "code",
|
| 553 |
-
"execution_count": 8,
|
| 554 |
-
"id": "84a7fa87",
|
| 555 |
-
"metadata": {},
|
| 556 |
-
"outputs": [],
|
| 557 |
-
"source": [
|
| 558 |
-
"import os\n",
|
| 559 |
-
"import json\n",
|
| 560 |
-
"import torch\n",
|
| 561 |
-
"\n",
|
| 562 |
-
"def init_adapter(unet, \n",
|
| 563 |
-
" cross_attn_cls=SkipAttnProcessor,\n",
|
| 564 |
-
" self_attn_cls=None,\n",
|
| 565 |
-
" cross_attn_dim=None, \n",
|
| 566 |
-
" **kwargs):\n",
|
| 567 |
-
" if cross_attn_dim is None:\n",
|
| 568 |
-
" cross_attn_dim = unet.config.cross_attention_dim\n",
|
| 569 |
-
" attn_procs = {}\n",
|
| 570 |
-
" for name in unet.attn_processors.keys():\n",
|
| 571 |
-
" cross_attention_dim = None if name.endswith(\"attn1.processor\") else cross_attn_dim\n",
|
| 572 |
-
" if name.startswith(\"mid_block\"):\n",
|
| 573 |
-
" hidden_size = unet.config.block_out_channels[-1]\n",
|
| 574 |
-
" elif name.startswith(\"up_blocks\"):\n",
|
| 575 |
-
" block_id = int(name[len(\"up_blocks.\")])\n",
|
| 576 |
-
" hidden_size = list(reversed(unet.config.block_out_channels))[block_id]\n",
|
| 577 |
-
" elif name.startswith(\"down_blocks\"):\n",
|
| 578 |
-
" block_id = int(name[len(\"down_blocks.\")])\n",
|
| 579 |
-
" hidden_size = unet.config.block_out_channels[block_id]\n",
|
| 580 |
-
" if cross_attention_dim is None:\n",
|
| 581 |
-
" if self_attn_cls is not None:\n",
|
| 582 |
-
" attn_procs[name] = self_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
|
| 583 |
-
" else:\n",
|
| 584 |
-
" # retain the original attn processor\n",
|
| 585 |
-
" attn_procs[name] = AttnProcessor2_0(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
|
| 586 |
-
" else:\n",
|
| 587 |
-
" attn_procs[name] = cross_attn_cls(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs)\n",
|
| 588 |
-
" \n",
|
| 589 |
-
" unet.set_attn_processor(attn_procs)\n",
|
| 590 |
-
" adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())\n",
|
| 591 |
-
" return adapter_modules\n",
|
| 592 |
-
"\n",
|
| 593 |
-
"def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):\n",
|
| 594 |
-
" from diffusers import AutoencoderKL\n",
|
| 595 |
-
" from transformers import CLIPTextModel, CLIPTokenizer\n",
|
| 596 |
-
"\n",
|
| 597 |
-
" text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder=\"text_encoder\")\n",
|
| 598 |
-
" vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder=\"vae\")\n",
|
| 599 |
-
" tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder=\"tokenizer\")\n",
|
| 600 |
-
" try:\n",
|
| 601 |
-
" unet_folder = os.path.join(diffusion_model_name_or_path, \"unet\")\n",
|
| 602 |
-
" unet_configs = json.load(open(os.path.join(unet_folder, \"config.json\"), \"r\"))\n",
|
| 603 |
-
" unet = unet_class(**unet_configs)\n",
|
| 604 |
-
" unet.load_state_dict(torch.load(os.path.join(unet_folder, \"diffusion_pytorch_model.bin\"), map_location=\"cpu\"), strict=True)\n",
|
| 605 |
-
" except:\n",
|
| 606 |
-
" unet = None\n",
|
| 607 |
-
" return text_encoder, vae, tokenizer, unet\n",
|
| 608 |
-
"\n",
|
| 609 |
-
"def attn_of_unet(unet):\n",
|
| 610 |
-
" attn_blocks = torch.nn.ModuleList()\n",
|
| 611 |
-
" for name, param in unet.named_modules():\n",
|
| 612 |
-
" if \"attn1\" in name:\n",
|
| 613 |
-
" attn_blocks.append(param)\n",
|
| 614 |
-
" return attn_blocks\n",
|
| 615 |
-
"\n",
|
| 616 |
-
"def get_trainable_module(unet, trainable_module_name):\n",
|
| 617 |
-
" if trainable_module_name == \"unet\":\n",
|
| 618 |
-
" return unet\n",
|
| 619 |
-
" elif trainable_module_name == \"transformer\":\n",
|
| 620 |
-
" trainable_modules = torch.nn.ModuleList()\n",
|
| 621 |
-
" for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:\n",
|
| 622 |
-
" if hasattr(blocks, \"attentions\"):\n",
|
| 623 |
-
" trainable_modules.append(blocks.attentions)\n",
|
| 624 |
-
" else:\n",
|
| 625 |
-
" for block in blocks:\n",
|
| 626 |
-
" if hasattr(block, \"attentions\"):\n",
|
| 627 |
-
" trainable_modules.append(block.attentions)\n",
|
| 628 |
-
" return trainable_modules\n",
|
| 629 |
-
" elif trainable_module_name == \"attention\":\n",
|
| 630 |
-
" attn_blocks = torch.nn.ModuleList()\n",
|
| 631 |
-
" for name, param in unet.named_modules():\n",
|
| 632 |
-
" if \"attn1\" in name:\n",
|
| 633 |
-
" attn_blocks.append(param)\n",
|
| 634 |
-
" return attn_blocks\n",
|
| 635 |
-
" else:\n",
|
| 636 |
-
" raise ValueError(f\"Unknown trainable_module_name: {trainable_module_name}\")\n",
|
| 637 |
-
"\n",
|
| 638 |
-
" \n",
|
| 639 |
-
" "
|
| 640 |
-
]
|
| 641 |
-
},
|
| 642 |
-
{
|
| 643 |
-
"cell_type": "code",
|
| 644 |
-
"execution_count": 9,
|
| 645 |
-
"id": "6028381d",
|
| 646 |
-
"metadata": {},
|
| 647 |
-
"outputs": [
|
| 648 |
-
{
|
| 649 |
-
"name": "stderr",
|
| 650 |
-
"output_type": "stream",
|
| 651 |
-
"text": [
|
| 652 |
-
"2025-06-15 18:35:15.189276: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
|
| 653 |
-
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
|
| 654 |
-
"E0000 00:00:1750012515.396602 73 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
|
| 655 |
-
"E0000 00:00:1750012515.456784 73 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
|
| 656 |
-
]
|
| 657 |
-
}
|
| 658 |
-
],
|
| 659 |
-
"source": [
|
| 660 |
-
"import inspect\n",
|
| 661 |
-
"import os\n",
|
| 662 |
-
"from typing import Union\n",
|
| 663 |
-
"\n",
|
| 664 |
-
"import PIL\n",
|
| 665 |
-
"import numpy as np\n",
|
| 666 |
-
"import torch\n",
|
| 667 |
-
"import tqdm\n",
|
| 668 |
-
"from accelerate import load_checkpoint_in_model\n",
|
| 669 |
-
"from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel\n",
|
| 670 |
-
"from diffusers.pipelines.stable_diffusion.safety_checker import \\\n",
|
| 671 |
-
" StableDiffusionSafetyChecker\n",
|
| 672 |
-
"from diffusers.utils.torch_utils import randn_tensor\n",
|
| 673 |
-
"from huggingface_hub import snapshot_download\n",
|
| 674 |
-
"from transformers import CLIPImageProcessor\n",
|
| 675 |
-
"\n",
|
| 676 |
-
"from utils import (check_inputs, get_time_embedding, numpy_to_pil, prepare_image,\n",
|
| 677 |
-
" prepare_mask_image, resize_and_crop, resize_and_padding)\n",
|
| 678 |
-
"from ddpm import DDPMSampler\n",
|
| 679 |
-
"\n",
|
| 680 |
-
"class CatVTONPipeline:\n",
|
| 681 |
-
" def __init__(\n",
|
| 682 |
-
" self, \n",
|
| 683 |
-
" base_ckpt, \n",
|
| 684 |
-
" attn_ckpt, \n",
|
| 685 |
-
" attn_ckpt_version=\"mix\",\n",
|
| 686 |
-
" weight_dtype=torch.float32,\n",
|
| 687 |
-
" device='cuda',\n",
|
| 688 |
-
" compile=False,\n",
|
| 689 |
-
" skip_safety_check=True,\n",
|
| 690 |
-
" use_tf32=True,\n",
|
| 691 |
-
" models={},\n",
|
| 692 |
-
" ):\n",
|
| 693 |
-
" self.device = device\n",
|
| 694 |
-
" self.weight_dtype = weight_dtype\n",
|
| 695 |
-
" self.skip_safety_check = skip_safety_check\n",
|
| 696 |
-
" self.models = models\n",
|
| 697 |
-
"\n",
|
| 698 |
-
" self.generator = torch.Generator(device=device)\n",
|
| 699 |
-
" self.noise_scheduler = DDPMSampler(generator=self.generator)\n",
|
| 700 |
-
" # self.vae = AutoencoderKL.from_pretrained(\"stabilityai/sd-vae-ft-mse\").to(device, dtype=weight_dtype)\n",
|
| 701 |
-
" self.encoder= models.get('encoder', None)\n",
|
| 702 |
-
" self.decoder= models.get('decoder', None)\n",
|
| 703 |
-
" if not skip_safety_check:\n",
|
| 704 |
-
" self.feature_extractor = CLIPImageProcessor.from_pretrained(base_ckpt, subfolder=\"feature_extractor\")\n",
|
| 705 |
-
" self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(base_ckpt, subfolder=\"safety_checker\").to(device, dtype=weight_dtype)\n",
|
| 706 |
-
" self.unet = UNet2DConditionModel.from_pretrained(base_ckpt, subfolder=\"unet\").to(device, dtype=weight_dtype)\n",
|
| 707 |
-
" # self.unet=models.get('diffusion', None)\n",
|
| 708 |
-
" init_adapter(self.unet, cross_attn_cls=SkipAttnProcessor) # Skip Cross-Attention\n",
|
| 709 |
-
" self.attn_modules = get_trainable_module(self.unet, \"attention\")\n",
|
| 710 |
-
" self.auto_attn_ckpt_load(attn_ckpt, attn_ckpt_version)\n",
|
| 711 |
-
" # Pytorch 2.0 Compile\n",
|
| 712 |
-
" # if compile:\n",
|
| 713 |
-
" # self.unet = torch.compile(self.unet)\n",
|
| 714 |
-
" # self.vae = torch.compile(self.vae, mode=\"reduce-overhead\")\n",
|
| 715 |
-
" \n",
|
| 716 |
-
" # # Enable TF32 for faster training on Ampere GPUs (A100 and RTX 30 series).\n",
|
| 717 |
-
" if use_tf32:\n",
|
| 718 |
-
" torch.set_float32_matmul_precision(\"high\")\n",
|
| 719 |
-
" torch.backends.cuda.matmul.allow_tf32 = True\n",
|
| 720 |
-
"\n",
|
| 721 |
-
" def auto_attn_ckpt_load(self, attn_ckpt, version):\n",
|
| 722 |
-
" sub_folder = {\n",
|
| 723 |
-
" \"mix\": \"mix-48k-1024\",\n",
|
| 724 |
-
" \"vitonhd\": \"vitonhd-16k-512\",\n",
|
| 725 |
-
" \"dresscode\": \"dresscode-16k-512\",\n",
|
| 726 |
-
" }[version]\n",
|
| 727 |
-
" if os.path.exists(attn_ckpt):\n",
|
| 728 |
-
" load_checkpoint_in_model(self.attn_modules, os.path.join(attn_ckpt, sub_folder, 'attention'))\n",
|
| 729 |
-
" else:\n",
|
| 730 |
-
" repo_path = snapshot_download(repo_id=attn_ckpt)\n",
|
| 731 |
-
" print(f\"Downloaded {attn_ckpt} to {repo_path}\")\n",
|
| 732 |
-
" load_checkpoint_in_model(self.attn_modules, os.path.join(repo_path, sub_folder, 'attention'))\n",
|
| 733 |
-
" \n",
|
| 734 |
-
" def run_safety_checker(self, image):\n",
|
| 735 |
-
" if self.safety_checker is None:\n",
|
| 736 |
-
" has_nsfw_concept = None\n",
|
| 737 |
-
" else:\n",
|
| 738 |
-
" safety_checker_input = self.feature_extractor(image, return_tensors=\"pt\").to(self.device)\n",
|
| 739 |
-
" image, has_nsfw_concept = self.safety_checker(\n",
|
| 740 |
-
" images=image, clip_input=safety_checker_input.pixel_values.to(self.weight_dtype)\n",
|
| 741 |
-
" )\n",
|
| 742 |
-
" return image, has_nsfw_concept\n",
|
| 743 |
-
" \n",
|
| 744 |
-
" def prepare_extra_step_kwargs(self, generator, eta):\n",
|
| 745 |
-
" # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n",
|
| 746 |
-
" # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n",
|
| 747 |
-
" # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n",
|
| 748 |
-
" # and should be between [0, 1]\n",
|
| 749 |
-
"\n",
|
| 750 |
-
" accepts_eta = \"eta\" in set(\n",
|
| 751 |
-
" inspect.signature(self.noise_scheduler.step).parameters.keys()\n",
|
| 752 |
-
" )\n",
|
| 753 |
-
" extra_step_kwargs = {}\n",
|
| 754 |
-
" if accepts_eta:\n",
|
| 755 |
-
" extra_step_kwargs[\"eta\"] = eta\n",
|
| 756 |
-
"\n",
|
| 757 |
-
" # check if the scheduler accepts generator\n",
|
| 758 |
-
" accepts_generator = \"generator\" in set(\n",
|
| 759 |
-
" inspect.signature(self.noise_scheduler.step).parameters.keys()\n",
|
| 760 |
-
" )\n",
|
| 761 |
-
" if accepts_generator:\n",
|
| 762 |
-
" extra_step_kwargs[\"generator\"] = generator\n",
|
| 763 |
-
" return extra_step_kwargs\n",
|
| 764 |
-
"\n",
|
| 765 |
-
" @torch.no_grad()\n",
|
| 766 |
-
" def __call__(\n",
|
| 767 |
-
" self, \n",
|
| 768 |
-
" image: Union[PIL.Image.Image, torch.Tensor],\n",
|
| 769 |
-
" condition_image: Union[PIL.Image.Image, torch.Tensor],\n",
|
| 770 |
-
" mask: Union[PIL.Image.Image, torch.Tensor],\n",
|
| 771 |
-
" num_inference_steps: int = 50,\n",
|
| 772 |
-
" guidance_scale: float = 2.5,\n",
|
| 773 |
-
" height: int = 1024,\n",
|
| 774 |
-
" width: int = 768,\n",
|
| 775 |
-
" generator=None,\n",
|
| 776 |
-
" eta=1.0,\n",
|
| 777 |
-
" **kwargs\n",
|
| 778 |
-
" ):\n",
|
| 779 |
-
" concat_dim = -2 # FIXME: y axis concat\n",
|
| 780 |
-
" # Prepare inputs to Tensor\n",
|
| 781 |
-
" image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)\n",
|
| 782 |
-
" image = prepare_image(image).to(self.device, dtype=self.weight_dtype)\n",
|
| 783 |
-
" condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)\n",
|
| 784 |
-
" mask = prepare_mask_image(mask).to(self.device, dtype=self.weight_dtype)\n",
|
| 785 |
-
" # Mask image\n",
|
| 786 |
-
" masked_image = image * (mask < 0.5)\n",
|
| 787 |
-
" # VAE encoding\n",
|
| 788 |
-
" masked_latent = compute_vae_encodings(masked_image, self.encoder)\n",
|
| 789 |
-
" condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
|
| 790 |
-
" mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
|
| 791 |
-
" del image, mask, condition_image\n",
|
| 792 |
-
" # Concatenate latents\n",
|
| 793 |
-
" masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n",
|
| 794 |
-
" mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n",
|
| 795 |
-
" # Prepare noise\n",
|
| 796 |
-
" latents = randn_tensor(\n",
|
| 797 |
-
" masked_latent_concat.shape,\n",
|
| 798 |
-
" generator=generator,\n",
|
| 799 |
-
" device=masked_latent_concat.device,\n",
|
| 800 |
-
" dtype=self.weight_dtype,\n",
|
| 801 |
-
" )\n",
|
| 802 |
-
" # Prepare timesteps\n",
|
| 803 |
-
" self.noise_scheduler.set_inference_timesteps(num_inference_steps)\n",
|
| 804 |
-
" timesteps = self.noise_scheduler.timesteps\n",
|
| 805 |
-
" # latents = latents * self.noise_scheduler.init_noise_sigma\n",
|
| 806 |
-
" latents = self.noise_scheduler.add_noise(latents, timesteps[0])\n",
|
| 807 |
-
" # Classifier-Free Guidance\n",
|
| 808 |
-
" if do_classifier_free_guidance := (guidance_scale > 1.0):\n",
|
| 809 |
-
" masked_latent_concat = torch.cat(\n",
|
| 810 |
-
" [\n",
|
| 811 |
-
" torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),\n",
|
| 812 |
-
" masked_latent_concat,\n",
|
| 813 |
-
" ]\n",
|
| 814 |
-
" )\n",
|
| 815 |
-
" mask_latent_concat = torch.cat([mask_latent_concat] * 2)\n",
|
| 816 |
-
"\n",
|
| 817 |
-
" # Denoising loop\n",
|
| 818 |
-
" # extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n",
|
| 819 |
-
" # num_warmup_steps = (len(timesteps) - num_inference_steps * self.noise_scheduler.order)\n",
|
| 820 |
-
" num_warmup_steps = 0 # For simple DDPM, no warmup needed\n",
|
| 821 |
-
" with tqdm(total=num_inference_steps) as progress_bar:\n",
|
| 822 |
-
" for i, t in enumerate(timesteps):\n",
|
| 823 |
-
" # expand the latents if we are doing classifier free guidance\n",
|
| 824 |
-
" non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)\n",
|
| 825 |
-
" # non_inpainting_latent_model_input = self.noise_scheduler.scale_model_input(non_inpainting_latent_model_input, t)\n",
|
| 826 |
-
" # prepare the input for the inpainting model\n",
|
| 827 |
-
" inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1).to(self.device, dtype=self.weight_dtype)\n",
|
| 828 |
-
" # predict the noise residual\n",
|
| 829 |
-
" # time_embedding = get_time_embedding(t.item())\n",
|
| 830 |
-
" # time_embedding = time_embedding.repeat(inpainting_latent_model_input.shape[0], 1).to(self.device, dtype=self.weight_dtype)\n",
|
| 831 |
-
" noise_pred= self.unet(\n",
|
| 832 |
-
" inpainting_latent_model_input,\n",
|
| 833 |
-
" # time_embedding\n",
|
| 834 |
-
" t.to(self.device),\n",
|
| 835 |
-
" encoder_hidden_states=None, # FIXME\n",
|
| 836 |
-
" return_dict=False,\n",
|
| 837 |
-
" )[0]\n",
|
| 838 |
-
" # perform guidance\n",
|
| 839 |
-
" if do_classifier_free_guidance:\n",
|
| 840 |
-
" noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n",
|
| 841 |
-
" noise_pred = noise_pred_uncond + guidance_scale * (\n",
|
| 842 |
-
" noise_pred_text - noise_pred_uncond\n",
|
| 843 |
-
" )\n",
|
| 844 |
-
" # compute the previous noisy sample x_t -> x_t-1\n",
|
| 845 |
-
" latents = self.noise_scheduler.step(\n",
|
| 846 |
-
" t, latents, noise_pred\n",
|
| 847 |
-
" )\n",
|
| 848 |
-
" # call the callback, if provided\n",
|
| 849 |
-
" if i == len(timesteps) - 1 or (\n",
|
| 850 |
-
" (i + 1) > num_warmup_steps\n",
|
| 851 |
-
" ):\n",
|
| 852 |
-
" progress_bar.update()\n",
|
| 853 |
-
"\n",
|
| 854 |
-
" # Decode the final latents\n",
|
| 855 |
-
" latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]\n",
|
| 856 |
-
" # latents = 1 / self.vae.config.scaling_factor * latents\n",
|
| 857 |
-
" # image = self.vae.decode(latents.to(self.device, dtype=self.weight_dtype)).sample\n",
|
| 858 |
-
" image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))\n",
|
| 859 |
-
" image = (image / 2 + 0.5).clamp(0, 1)\n",
|
| 860 |
-
" # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n",
|
| 861 |
-
" image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n",
|
| 862 |
-
" image = numpy_to_pil(image)\n",
|
| 863 |
-
" \n",
|
| 864 |
-
" # Safety Check\n",
|
| 865 |
-
" if not self.skip_safety_check:\n",
|
| 866 |
-
" current_script_directory = os.path.dirname(os.path.realpath(__file__))\n",
|
| 867 |
-
" nsfw_image = os.path.join(os.path.dirname(current_script_directory), 'resource', 'img', 'NSFW.jpg')\n",
|
| 868 |
-
" nsfw_image = PIL.Image.open(nsfw_image).resize(image[0].size)\n",
|
| 869 |
-
" image_np = np.array(image)\n",
|
| 870 |
-
" _, has_nsfw_concept = self.run_safety_checker(image=image_np)\n",
|
| 871 |
-
" for i, not_safe in enumerate(has_nsfw_concept):\n",
|
| 872 |
-
" if not_safe:\n",
|
| 873 |
-
" image[i] = nsfw_image\n",
|
| 874 |
-
" return image\n"
|
| 875 |
-
]
|
| 876 |
-
},
|
| 877 |
-
{
|
| 878 |
-
"cell_type": "code",
|
| 879 |
-
"execution_count": 10,
|
| 880 |
-
"id": "94e19198",
|
| 881 |
-
"metadata": {},
|
| 882 |
-
"outputs": [
|
| 883 |
-
{
|
| 884 |
-
"data": {
|
| 885 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 886 |
-
"model_id": "594e5184ce094185bf75cb38118c1867",
|
| 887 |
-
"version_major": 2,
|
| 888 |
-
"version_minor": 0
|
| 889 |
-
},
|
| 890 |
-
"text/plain": [
|
| 891 |
-
"config.json: 0%| | 0.00/748 [00:00<?, ?B/s]"
|
| 892 |
-
]
|
| 893 |
-
},
|
| 894 |
-
"metadata": {},
|
| 895 |
-
"output_type": "display_data"
|
| 896 |
-
},
|
| 897 |
-
{
|
| 898 |
-
"name": "stderr",
|
| 899 |
-
"output_type": "stream",
|
| 900 |
-
"text": [
|
| 901 |
-
"An error occurred while trying to fetch booksforcharlie/stable-diffusion-inpainting: booksforcharlie/stable-diffusion-inpainting does not appear to have a file named diffusion_pytorch_model.safetensors.\n",
|
| 902 |
-
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.\n"
|
| 903 |
-
]
|
| 904 |
-
},
|
| 905 |
-
{
|
| 906 |
-
"data": {
|
| 907 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 908 |
-
"model_id": "85377ab853e2479484e9b7449d1d2351",
|
| 909 |
-
"version_major": 2,
|
| 910 |
-
"version_minor": 0
|
| 911 |
-
},
|
| 912 |
-
"text/plain": [
|
| 913 |
-
"diffusion_pytorch_model.bin: 0%| | 0.00/3.44G [00:00<?, ?B/s]"
|
| 914 |
-
]
|
| 915 |
-
},
|
| 916 |
-
"metadata": {},
|
| 917 |
-
"output_type": "display_data"
|
| 918 |
-
},
|
| 919 |
-
{
|
| 920 |
-
"data": {
|
| 921 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 922 |
-
"model_id": "23724cc21c594658bf4a861e8c45784c",
|
| 923 |
-
"version_major": 2,
|
| 924 |
-
"version_minor": 0
|
| 925 |
-
},
|
| 926 |
-
"text/plain": [
|
| 927 |
-
"Fetching 12 files: 0%| | 0/12 [00:00<?, ?it/s]"
|
| 928 |
-
]
|
| 929 |
-
},
|
| 930 |
-
"metadata": {},
|
| 931 |
-
"output_type": "display_data"
|
| 932 |
-
},
|
| 933 |
-
{
|
| 934 |
-
"data": {
|
| 935 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 936 |
-
"model_id": "b6c7c4967a8746c494b0f782e2703e65",
|
| 937 |
-
"version_major": 2,
|
| 938 |
-
"version_minor": 0
|
| 939 |
-
},
|
| 940 |
-
"text/plain": [
|
| 941 |
-
"densepose_rcnn_R_50_FPN_s1x.yaml: 0%| | 0.00/182 [00:00<?, ?B/s]"
|
| 942 |
-
]
|
| 943 |
-
},
|
| 944 |
-
"metadata": {},
|
| 945 |
-
"output_type": "display_data"
|
| 946 |
-
},
|
| 947 |
-
{
|
| 948 |
-
"data": {
|
| 949 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 950 |
-
"model_id": "0b3a37c06f3d4b2cb1a3c1be21eccd0d",
|
| 951 |
-
"version_major": 2,
|
| 952 |
-
"version_minor": 0
|
| 953 |
-
},
|
| 954 |
-
"text/plain": [
|
| 955 |
-
"model_final_162be9.pkl: 0%| | 0.00/256M [00:00<?, ?B/s]"
|
| 956 |
-
]
|
| 957 |
-
},
|
| 958 |
-
"metadata": {},
|
| 959 |
-
"output_type": "display_data"
|
| 960 |
-
},
|
| 961 |
-
{
|
| 962 |
-
"data": {
|
| 963 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 964 |
-
"model_id": "5b70abffe67043a4b574be0fdad5b9f3",
|
| 965 |
-
"version_major": 2,
|
| 966 |
-
"version_minor": 0
|
| 967 |
-
},
|
| 968 |
-
"text/plain": [
|
| 969 |
-
".gitattributes: 0%| | 0.00/1.52k [00:00<?, ?B/s]"
|
| 970 |
-
]
|
| 971 |
-
},
|
| 972 |
-
"metadata": {},
|
| 973 |
-
"output_type": "display_data"
|
| 974 |
-
},
|
| 975 |
-
{
|
| 976 |
-
"data": {
|
| 977 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 978 |
-
"model_id": "f9a17d6446ff4faeaf232440c49fca43",
|
| 979 |
-
"version_major": 2,
|
| 980 |
-
"version_minor": 0
|
| 981 |
-
},
|
| 982 |
-
"text/plain": [
|
| 983 |
-
"exp-schp-201908261155-lip.pth: 0%| | 0.00/267M [00:00<?, ?B/s]"
|
| 984 |
-
]
|
| 985 |
-
},
|
| 986 |
-
"metadata": {},
|
| 987 |
-
"output_type": "display_data"
|
| 988 |
-
},
|
| 989 |
-
{
|
| 990 |
-
"data": {
|
| 991 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 992 |
-
"model_id": "285628a5b093405694737f77fdd21e3d",
|
| 993 |
-
"version_major": 2,
|
| 994 |
-
"version_minor": 0
|
| 995 |
-
},
|
| 996 |
-
"text/plain": [
|
| 997 |
-
"README.md: 0%| | 0.00/9.66k [00:00<?, ?B/s]"
|
| 998 |
-
]
|
| 999 |
-
},
|
| 1000 |
-
"metadata": {},
|
| 1001 |
-
"output_type": "display_data"
|
| 1002 |
-
},
|
| 1003 |
-
{
|
| 1004 |
-
"data": {
|
| 1005 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 1006 |
-
"model_id": "438a1dec4d8340bf971700c8cd12c534",
|
| 1007 |
-
"version_major": 2,
|
| 1008 |
-
"version_minor": 0
|
| 1009 |
-
},
|
| 1010 |
-
"text/plain": [
|
| 1011 |
-
"model.safetensors: 0%| | 0.00/198M [00:00<?, ?B/s]"
|
| 1012 |
-
]
|
| 1013 |
-
},
|
| 1014 |
-
"metadata": {},
|
| 1015 |
-
"output_type": "display_data"
|
| 1016 |
-
},
|
| 1017 |
-
{
|
| 1018 |
-
"data": {
|
| 1019 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 1020 |
-
"model_id": "8f21e713409c4b858348c800551866c3",
|
| 1021 |
-
"version_major": 2,
|
| 1022 |
-
"version_minor": 0
|
| 1023 |
-
},
|
| 1024 |
-
"text/plain": [
|
| 1025 |
-
"exp-schp-201908301523-atr.pth: 0%| | 0.00/267M [00:00<?, ?B/s]"
|
| 1026 |
-
]
|
| 1027 |
-
},
|
| 1028 |
-
"metadata": {},
|
| 1029 |
-
"output_type": "display_data"
|
| 1030 |
-
},
|
| 1031 |
-
{
|
| 1032 |
-
"data": {
|
| 1033 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 1034 |
-
"model_id": "401a1d41102d49fb81f25f97f418aa8e",
|
| 1035 |
-
"version_major": 2,
|
| 1036 |
-
"version_minor": 0
|
| 1037 |
-
},
|
| 1038 |
-
"text/plain": [
|
| 1039 |
-
"Base-DensePose-RCNN-FPN.yaml: 0%| | 0.00/1.52k [00:00<?, ?B/s]"
|
| 1040 |
-
]
|
| 1041 |
-
},
|
| 1042 |
-
"metadata": {},
|
| 1043 |
-
"output_type": "display_data"
|
| 1044 |
-
},
|
| 1045 |
-
{
|
| 1046 |
-
"data": {
|
| 1047 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 1048 |
-
"model_id": "f0b7926b07f34d84943fa2ee475df375",
|
| 1049 |
-
"version_major": 2,
|
| 1050 |
-
"version_minor": 0
|
| 1051 |
-
},
|
| 1052 |
-
"text/plain": [
|
| 1053 |
-
"pytorch_lora_weights.safetensors: 0%| | 0.00/37.4M [00:00<?, ?B/s]"
|
| 1054 |
-
]
|
| 1055 |
-
},
|
| 1056 |
-
"metadata": {},
|
| 1057 |
-
"output_type": "display_data"
|
| 1058 |
-
},
|
| 1059 |
-
{
|
| 1060 |
-
"data": {
|
| 1061 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 1062 |
-
"model_id": "813b7619918f4295b017329f700a1475",
|
| 1063 |
-
"version_major": 2,
|
| 1064 |
-
"version_minor": 0
|
| 1065 |
-
},
|
| 1066 |
-
"text/plain": [
|
| 1067 |
-
"model.safetensors: 0%| | 0.00/198M [00:00<?, ?B/s]"
|
| 1068 |
-
]
|
| 1069 |
-
},
|
| 1070 |
-
"metadata": {},
|
| 1071 |
-
"output_type": "display_data"
|
| 1072 |
-
},
|
| 1073 |
-
{
|
| 1074 |
-
"data": {
|
| 1075 |
-
"application/vnd.jupyter.widget-view+json": {
|
| 1076 |
-
"model_id": "fbb928425e694855bb8da9b71a72fd75",
|
| 1077 |
-
"version_major": 2,
|
| 1078 |
-
"version_minor": 0
|
| 1079 |
-
},
|
| 1080 |
-
"text/plain": [
|
| 1081 |
-
"model.safetensors: 0%| | 0.00/198M [00:00<?, ?B/s]"
|
| 1082 |
-
]
|
| 1083 |
-
},
|
| 1084 |
-
"metadata": {},
|
| 1085 |
-
"output_type": "display_data"
|
| 1086 |
-
},
|
| 1087 |
-
{
|
| 1088 |
-
"name": "stdout",
|
| 1089 |
-
"output_type": "stream",
|
| 1090 |
-
"text": [
|
| 1091 |
-
"Downloaded zhengchong/CatVTON to /root/.cache/huggingface/hub/models--zhengchong--CatVTON/snapshots/2969fcf85fe62f2036605716f0b56f0b81d01d79\n",
|
| 1092 |
-
"Dataset vitonhd loaded, total 20 pairs.\n"
|
| 1093 |
-
]
|
| 1094 |
-
},
|
| 1095 |
-
{
|
| 1096 |
-
"name": "stderr",
|
| 1097 |
-
"output_type": "stream",
|
| 1098 |
-
"text": [
|
| 1099 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.34it/s]\n",
|
| 1100 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.34it/s]\n",
|
| 1101 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.33it/s]\n",
|
| 1102 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.34it/s]\n",
|
| 1103 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.33it/s]\n",
|
| 1104 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.33it/s]\n",
|
| 1105 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.34it/s]\n",
|
| 1106 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.34it/s]\n",
|
| 1107 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.33it/s]\n",
|
| 1108 |
-
"100%|██████████| 50/50 [00:37<00:00, 1.34it/s]\n",
|
| 1109 |
-
"100%|██████████| 10/10 [06:35<00:00, 39.50s/it]\n"
|
| 1110 |
-
]
|
| 1111 |
-
}
|
| 1112 |
-
],
|
| 1113 |
-
"source": [
|
| 1114 |
-
"import os\n",
|
| 1115 |
-
"import numpy as np\n",
|
| 1116 |
-
"import torch\n",
|
| 1117 |
-
"import argparse\n",
|
| 1118 |
-
"from torch.utils.data import Dataset, DataLoader\n",
|
| 1119 |
-
"from VITON_Dataset import VITONHDTestDataset\n",
|
| 1120 |
-
"from diffusers.image_processor import VaeImageProcessor\n",
|
| 1121 |
-
"from tqdm import tqdm\n",
|
| 1122 |
-
"from PIL import Image, ImageFilter\n",
|
| 1123 |
-
"import load_model\n",
|
| 1124 |
-
"\n",
|
| 1125 |
-
"from utils import repaint, to_pil_image\n",
|
| 1126 |
-
" \n",
|
| 1127 |
-
"def parse_args():\n",
|
| 1128 |
-
" parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n",
|
| 1129 |
-
" parser.add_argument(\n",
|
| 1130 |
-
" \"--base_model_path\",\n",
|
| 1131 |
-
" type=str,\n",
|
| 1132 |
-
" default=\"booksforcharlie/stable-diffusion-inpainting\", # Change to a copy repo as runawayml delete original repo\n",
|
| 1133 |
-
" help=(\n",
|
| 1134 |
-
" \"The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub.\"\n",
|
| 1135 |
-
" ),\n",
|
| 1136 |
-
" )\n",
|
| 1137 |
-
" parser.add_argument(\n",
|
| 1138 |
-
" \"--resume_path\",\n",
|
| 1139 |
-
" type=str,\n",
|
| 1140 |
-
" default=\"zhengchong/CatVTON\",\n",
|
| 1141 |
-
" help=(\n",
|
| 1142 |
-
" \"The Path to the checkpoint of trained tryon model.\"\n",
|
| 1143 |
-
" ),\n",
|
| 1144 |
-
" )\n",
|
| 1145 |
-
" parser.add_argument(\n",
|
| 1146 |
-
" \"--dataset_name\",\n",
|
| 1147 |
-
" type=str,\n",
|
| 1148 |
-
" required=True,\n",
|
| 1149 |
-
" help=\"The datasets to use for evaluation.\",\n",
|
| 1150 |
-
" )\n",
|
| 1151 |
-
" parser.add_argument(\n",
|
| 1152 |
-
" \"--data_root_path\", \n",
|
| 1153 |
-
" type=str, \n",
|
| 1154 |
-
" required=True,\n",
|
| 1155 |
-
" help=\"Path to the dataset to evaluate.\"\n",
|
| 1156 |
-
" )\n",
|
| 1157 |
-
" parser.add_argument(\n",
|
| 1158 |
-
" \"--output_dir\",\n",
|
| 1159 |
-
" type=str,\n",
|
| 1160 |
-
" default=\"output\",\n",
|
| 1161 |
-
" help=\"The output directory where the model predictions will be written.\",\n",
|
| 1162 |
-
" )\n",
|
| 1163 |
-
"\n",
|
| 1164 |
-
" parser.add_argument(\n",
|
| 1165 |
-
" \"--seed\", type=int, default=555, help=\"A seed for reproducible evaluation.\"\n",
|
| 1166 |
-
" )\n",
|
| 1167 |
-
" parser.add_argument(\n",
|
| 1168 |
-
" \"--batch_size\", type=int, default=8, help=\"The batch size for evaluation.\"\n",
|
| 1169 |
-
" )\n",
|
| 1170 |
-
" \n",
|
| 1171 |
-
" parser.add_argument(\n",
|
| 1172 |
-
" \"--num_inference_steps\",\n",
|
| 1173 |
-
" type=int,\n",
|
| 1174 |
-
" default=50,\n",
|
| 1175 |
-
" help=\"Number of inference steps to perform.\",\n",
|
| 1176 |
-
" )\n",
|
| 1177 |
-
" parser.add_argument(\n",
|
| 1178 |
-
" \"--guidance_scale\",\n",
|
| 1179 |
-
" type=float,\n",
|
| 1180 |
-
" default=2.5,\n",
|
| 1181 |
-
" help=\"The scale of classifier-free guidance for inference.\",\n",
|
| 1182 |
-
" )\n",
|
| 1183 |
-
"\n",
|
| 1184 |
-
" parser.add_argument(\n",
|
| 1185 |
-
" \"--width\",\n",
|
| 1186 |
-
" type=int,\n",
|
| 1187 |
-
" default=384,\n",
|
| 1188 |
-
" help=(\n",
|
| 1189 |
-
" \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n",
|
| 1190 |
-
" \" resolution\"\n",
|
| 1191 |
-
" ),\n",
|
| 1192 |
-
" )\n",
|
| 1193 |
-
" parser.add_argument(\n",
|
| 1194 |
-
" \"--height\",\n",
|
| 1195 |
-
" type=int,\n",
|
| 1196 |
-
" default=512,\n",
|
| 1197 |
-
" help=(\n",
|
| 1198 |
-
" \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n",
|
| 1199 |
-
" \" resolution\"\n",
|
| 1200 |
-
" ),\n",
|
| 1201 |
-
" )\n",
|
| 1202 |
-
" parser.add_argument(\n",
|
| 1203 |
-
" \"--repaint\", \n",
|
| 1204 |
-
" action=\"store_true\", \n",
|
| 1205 |
-
" help=\"Whether to repaint the result image with the original background.\"\n",
|
| 1206 |
-
" )\n",
|
| 1207 |
-
" parser.add_argument(\n",
|
| 1208 |
-
" \"--eval_pair\",\n",
|
| 1209 |
-
" action=\"store_true\",\n",
|
| 1210 |
-
" help=\"Whether or not to evaluate the pair.\",\n",
|
| 1211 |
-
" )\n",
|
| 1212 |
-
" parser.add_argument(\n",
|
| 1213 |
-
" \"--concat_eval_results\",\n",
|
| 1214 |
-
" action=\"store_true\",\n",
|
| 1215 |
-
" help=\"Whether or not to concatenate the all conditions into one image.\",\n",
|
| 1216 |
-
" )\n",
|
| 1217 |
-
" parser.add_argument(\n",
|
| 1218 |
-
" \"--allow_tf32\",\n",
|
| 1219 |
-
" action=\"store_true\",\n",
|
| 1220 |
-
" default=True,\n",
|
| 1221 |
-
" help=(\n",
|
| 1222 |
-
" \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n",
|
| 1223 |
-
" \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n",
|
| 1224 |
-
" ),\n",
|
| 1225 |
-
" )\n",
|
| 1226 |
-
" parser.add_argument(\n",
|
| 1227 |
-
" \"--dataloader_num_workers\",\n",
|
| 1228 |
-
" type=int,\n",
|
| 1229 |
-
" default=8,\n",
|
| 1230 |
-
" help=(\n",
|
| 1231 |
-
" \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n",
|
| 1232 |
-
" ),\n",
|
| 1233 |
-
" )\n",
|
| 1234 |
-
" parser.add_argument(\n",
|
| 1235 |
-
" \"--mixed_precision\",\n",
|
| 1236 |
-
" type=str,\n",
|
| 1237 |
-
" default=\"bf16\",\n",
|
| 1238 |
-
" choices=[\"no\", \"fp16\", \"bf16\"],\n",
|
| 1239 |
-
" help=(\n",
|
| 1240 |
-
" \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n",
|
| 1241 |
-
" \" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the\"\n",
|
| 1242 |
-
" \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n",
|
| 1243 |
-
" ),\n",
|
| 1244 |
-
" )\n",
|
| 1245 |
-
"\n",
|
| 1246 |
-
" parser.add_argument(\n",
|
| 1247 |
-
" \"--concat_axis\",\n",
|
| 1248 |
-
" type=str,\n",
|
| 1249 |
-
" choices=[\"x\", \"y\", 'random'],\n",
|
| 1250 |
-
" default=\"y\",\n",
|
| 1251 |
-
" help=\"The axis to concat the cloth feature, select from ['x', 'y', 'random'].\",\n",
|
| 1252 |
-
" )\n",
|
| 1253 |
-
" parser.add_argument(\n",
|
| 1254 |
-
" \"--enable_condition_noise\",\n",
|
| 1255 |
-
" action=\"store_true\",\n",
|
| 1256 |
-
" default=True,\n",
|
| 1257 |
-
" help=\"Whether or not to enable condition noise.\",\n",
|
| 1258 |
-
" )\n",
|
| 1259 |
-
" \n",
|
| 1260 |
-
" args = parser.parse_args()\n",
|
| 1261 |
-
" env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n",
|
| 1262 |
-
" if env_local_rank != -1 and env_local_rank != args.local_rank:\n",
|
| 1263 |
-
" args.local_rank = env_local_rank\n",
|
| 1264 |
-
"\n",
|
| 1265 |
-
" return args\n",
|
| 1266 |
-
"\n",
|
| 1267 |
-
"@torch.no_grad()\n",
|
| 1268 |
-
"def main():\n",
|
| 1269 |
-
" # args = parse_args()\n",
|
| 1270 |
-
"\n",
|
| 1271 |
-
" # Replace <path> with your actual data root and output directory paths\n",
|
| 1272 |
-
" # !CUDA_VISIBLE_DEVICES=0 python inference.py \\\n",
|
| 1273 |
-
" # --dataset vitonhd \\\n",
|
| 1274 |
-
" # --data_root_path /kaggle/input/viton-hd-dataset \\\n",
|
| 1275 |
-
" # --output_dir ./output \\\n",
|
| 1276 |
-
" # --dataloader_num_workers 8 \\\n",
|
| 1277 |
-
" # --batch_size 8 \\\n",
|
| 1278 |
-
" # --seed 555 \\\n",
|
| 1279 |
-
" # --mixed_precision no \\\n",
|
| 1280 |
-
" # --allow_tf32 \\\n",
|
| 1281 |
-
" # --repaint \\\n",
|
| 1282 |
-
" # --eval_pair\n",
|
| 1283 |
-
" \n",
|
| 1284 |
-
" args=argparse.Namespace()\n",
|
| 1285 |
-
" args.__dict__= {\n",
|
| 1286 |
-
" \"base_model_path\": \"booksforcharlie/stable-diffusion-inpainting\",\n",
|
| 1287 |
-
" \"resume_path\": \"zhengchong/CatVTON\",\n",
|
| 1288 |
-
" \"dataset_name\": \"vitonhd\",\n",
|
| 1289 |
-
" # \"data_root_path\": \"/kaggle/input/viton-hd-dataset\",\n",
|
| 1290 |
-
" \"data_root_path\": \"/kaggle/working/stable-diffusion/sample_dataset\",\n",
|
| 1291 |
-
" \"output_dir\": \"./output\",\n",
|
| 1292 |
-
" \"seed\": 555,\n",
|
| 1293 |
-
" \"batch_size\": 2,\n",
|
| 1294 |
-
" \"num_inference_steps\": 50,\n",
|
| 1295 |
-
" \"guidance_scale\": 2.5,\n",
|
| 1296 |
-
" \"width\": 384,\n",
|
| 1297 |
-
" \"height\": 512,\n",
|
| 1298 |
-
" \"repaint\": True,\n",
|
| 1299 |
-
" \"eval_pair\": False,\n",
|
| 1300 |
-
" \"concat_eval_results\": True,\n",
|
| 1301 |
-
" \"allow_tf32\": True,\n",
|
| 1302 |
-
" \"dataloader_num_workers\": 4,\n",
|
| 1303 |
-
" \"mixed_precision\": 'no',\n",
|
| 1304 |
-
" \"concat_axis\": 'y',\n",
|
| 1305 |
-
" \"enable_condition_noise\": True,\n",
|
| 1306 |
-
" \"is_train\": False\n",
|
| 1307 |
-
" }\n",
|
| 1308 |
-
"\n",
|
| 1309 |
-
" models=load_model.preload_models_from_standard_weights(ckpt_path=\"sd-v1-5-inpainting.ckpt\", device=\"cuda\", finetune_weights_path=None)\n",
|
| 1310 |
-
"\n",
|
| 1311 |
-
" # Pipeline\n",
|
| 1312 |
-
" pipeline = CatVTONPipeline(\n",
|
| 1313 |
-
" attn_ckpt_version=args.dataset_name,\n",
|
| 1314 |
-
" attn_ckpt=args.resume_path,\n",
|
| 1315 |
-
" base_ckpt=args.base_model_path,\n",
|
| 1316 |
-
" weight_dtype={\n",
|
| 1317 |
-
" \"no\": torch.float32,\n",
|
| 1318 |
-
" \"fp16\": torch.float16,\n",
|
| 1319 |
-
" \"bf16\": torch.bfloat16,\n",
|
| 1320 |
-
" }[args.mixed_precision],\n",
|
| 1321 |
-
" device=\"cuda\",\n",
|
| 1322 |
-
" skip_safety_check=True,\n",
|
| 1323 |
-
" models=models,\n",
|
| 1324 |
-
" )\n",
|
| 1325 |
-
" # Dataset\n",
|
| 1326 |
-
" if args.dataset_name == \"vitonhd\":\n",
|
| 1327 |
-
" dataset = VITONHDTestDataset(args)\n",
|
| 1328 |
-
" else:\n",
|
| 1329 |
-
" raise ValueError(f\"Invalid dataset name {args.dataset}.\")\n",
|
| 1330 |
-
" print(f\"Dataset {args.dataset_name} loaded, total {len(dataset)} pairs.\")\n",
|
| 1331 |
-
" dataloader = DataLoader(\n",
|
| 1332 |
-
" dataset,\n",
|
| 1333 |
-
" batch_size=args.batch_size,\n",
|
| 1334 |
-
" shuffle=False,\n",
|
| 1335 |
-
" num_workers=args.dataloader_num_workers\n",
|
| 1336 |
-
" )\n",
|
| 1337 |
-
" # Inference\n",
|
| 1338 |
-
" generator = torch.Generator(device='cuda').manual_seed(args.seed)\n",
|
| 1339 |
-
" args.output_dir = os.path.join(args.output_dir, f\"{args.dataset_name}-{args.height}\", \"paired\" if args.eval_pair else \"unpaired\")\n",
|
| 1340 |
-
" if not os.path.exists(args.output_dir):\n",
|
| 1341 |
-
" os.makedirs(args.output_dir)\n",
|
| 1342 |
-
" \n",
|
| 1343 |
-
" for batch in tqdm(dataloader):\n",
|
| 1344 |
-
" person_images = batch['person']\n",
|
| 1345 |
-
" cloth_images = batch['cloth']\n",
|
| 1346 |
-
" masks = batch['mask']\n",
|
| 1347 |
-
"\n",
|
| 1348 |
-
" results = pipeline(\n",
|
| 1349 |
-
" person_images,\n",
|
| 1350 |
-
" cloth_images,\n",
|
| 1351 |
-
" masks,\n",
|
| 1352 |
-
" num_inference_steps=args.num_inference_steps,\n",
|
| 1353 |
-
" guidance_scale=args.guidance_scale,\n",
|
| 1354 |
-
" height=args.height,\n",
|
| 1355 |
-
" width=args.width,\n",
|
| 1356 |
-
" generator=generator,\n",
|
| 1357 |
-
" )\n",
|
| 1358 |
-
" \n",
|
| 1359 |
-
" if args.concat_eval_results or args.repaint:\n",
|
| 1360 |
-
" person_images = to_pil_image(person_images)\n",
|
| 1361 |
-
" cloth_images = to_pil_image(cloth_images)\n",
|
| 1362 |
-
" masks = to_pil_image(masks)\n",
|
| 1363 |
-
" for i, result in enumerate(results):\n",
|
| 1364 |
-
" person_name = batch['person_name'][i]\n",
|
| 1365 |
-
" output_path = os.path.join(args.output_dir, person_name)\n",
|
| 1366 |
-
" if not os.path.exists(os.path.dirname(output_path)):\n",
|
| 1367 |
-
" os.makedirs(os.path.dirname(output_path))\n",
|
| 1368 |
-
" if args.repaint:\n",
|
| 1369 |
-
" person_path, mask_path = dataset.data[batch['index'][i]]['person'], dataset.data[batch['index'][i]]['mask']\n",
|
| 1370 |
-
" person_image= Image.open(person_path).resize(result.size, Image.LANCZOS)\n",
|
| 1371 |
-
" mask = Image.open(mask_path).resize(result.size, Image.NEAREST)\n",
|
| 1372 |
-
" result = repaint(person_image, mask, result)\n",
|
| 1373 |
-
" if args.concat_eval_results:\n",
|
| 1374 |
-
" w, h = result.size\n",
|
| 1375 |
-
" concated_result = Image.new('RGB', (w*3, h))\n",
|
| 1376 |
-
" concated_result.paste(person_images[i], (0, 0))\n",
|
| 1377 |
-
" concated_result.paste(cloth_images[i], (w, 0)) \n",
|
| 1378 |
-
" concated_result.paste(result, (w*2, 0))\n",
|
| 1379 |
-
" result = concated_result\n",
|
| 1380 |
-
" result.save(output_path)\n",
|
| 1381 |
-
"\n",
|
| 1382 |
-
"if __name__ == \"__main__\":\n",
|
| 1383 |
-
" main()"
|
| 1384 |
-
]
|
| 1385 |
-
},
|
| 1386 |
-
{
|
| 1387 |
-
"cell_type": "code",
|
| 1388 |
-
"execution_count": null,
|
| 1389 |
-
"id": "5c2d9f98",
|
| 1390 |
-
"metadata": {},
|
| 1391 |
-
"outputs": [],
|
| 1392 |
-
"source": []
|
| 1393 |
-
},
|
| 1394 |
-
{
|
| 1395 |
-
"cell_type": "code",
|
| 1396 |
-
"execution_count": null,
|
| 1397 |
-
"id": "143d0ef9",
|
| 1398 |
-
"metadata": {},
|
| 1399 |
-
"outputs": [],
|
| 1400 |
-
"source": [
|
| 1401 |
-
"# rm -rf output"
|
| 1402 |
-
]
|
| 1403 |
-
},
|
| 1404 |
-
{
|
| 1405 |
-
"cell_type": "code",
|
| 1406 |
-
"execution_count": null,
|
| 1407 |
-
"id": "e417edb7",
|
| 1408 |
-
"metadata": {},
|
| 1409 |
-
"outputs": [],
|
| 1410 |
-
"source": []
|
| 1411 |
-
},
|
| 1412 |
-
{
|
| 1413 |
-
"cell_type": "code",
|
| 1414 |
-
"execution_count": null,
|
| 1415 |
-
"id": "1c86c58d",
|
| 1416 |
-
"metadata": {},
|
| 1417 |
-
"outputs": [],
|
| 1418 |
-
"source": []
|
| 1419 |
-
}
|
| 1420 |
-
],
|
| 1421 |
-
"metadata": {
|
| 1422 |
-
"kernelspec": {
|
| 1423 |
-
"display_name": "Python 3 (ipykernel)",
|
| 1424 |
-
"language": "python",
|
| 1425 |
-
"name": "python3"
|
| 1426 |
-
}
|
| 1427 |
-
},
|
| 1428 |
-
"nbformat": 4,
|
| 1429 |
-
"nbformat_minor": 5
|
| 1430 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
trained_output/vitonhd-384/unpaired/00654_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/01265_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/01985_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/02023_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/02532_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/02944_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/03191_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/03921_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/05006_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/05378_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/07342_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/08088_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/08239_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/08650_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/08839_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/11085_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/12345_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/12419_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/12562_00.jpg
ADDED
|
trained_output/vitonhd-384/unpaired/14651_00.jpg
ADDED
|
training.ipynb
CHANGED
|
@@ -5,23 +5,9 @@
|
|
| 5 |
"execution_count": 1,
|
| 6 |
"id": "81e4a1db",
|
| 7 |
"metadata": {},
|
| 8 |
-
"outputs": [
|
| 9 |
-
{
|
| 10 |
-
"name": "stdout",
|
| 11 |
-
"output_type": "stream",
|
| 12 |
-
"text": [
|
| 13 |
-
"Cloning into 'stable-diffusion'...\n",
|
| 14 |
-
"remote: Enumerating objects: 184, done.\u001b[K\n",
|
| 15 |
-
"remote: Counting objects: 100% (184/184), done.\u001b[K\n",
|
| 16 |
-
"remote: Compressing objects: 100% (156/156), done.\u001b[K\n",
|
| 17 |
-
"remote: Total 184 (delta 44), reused 165 (delta 26), pack-reused 0 (from 0)\u001b[K\n",
|
| 18 |
-
"Receiving objects: 100% (184/184), 9.94 MiB | 37.02 MiB/s, done.\n",
|
| 19 |
-
"Resolving deltas: 100% (44/44), done.\n"
|
| 20 |
-
]
|
| 21 |
-
}
|
| 22 |
-
],
|
| 23 |
"source": [
|
| 24 |
-
"!git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git"
|
| 25 |
]
|
| 26 |
},
|
| 27 |
{
|
|
@@ -29,443 +15,370 @@
|
|
| 29 |
"execution_count": 2,
|
| 30 |
"id": "9c89e320",
|
| 31 |
"metadata": {},
|
| 32 |
-
"outputs": [
|
| 33 |
-
{
|
| 34 |
-
"name": "stdout",
|
| 35 |
-
"output_type": "stream",
|
| 36 |
-
"text": [
|
| 37 |
-
"/kaggle/working/stable-diffusion\n"
|
| 38 |
-
]
|
| 39 |
-
}
|
| 40 |
-
],
|
| 41 |
"source": [
|
| 42 |
-
"cd stable-diffusion/"
|
| 43 |
]
|
| 44 |
},
|
| 45 |
{
|
| 46 |
"cell_type": "code",
|
| 47 |
"execution_count": 3,
|
| 48 |
-
"id": "
|
| 49 |
"metadata": {},
|
| 50 |
"outputs": [
|
| 51 |
{
|
| 52 |
"name": "stdout",
|
| 53 |
"output_type": "stream",
|
| 54 |
"text": [
|
| 55 |
-
"
|
| 56 |
]
|
| 57 |
}
|
| 58 |
],
|
| 59 |
"source": [
|
| 60 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
]
|
| 62 |
},
|
| 63 |
{
|
| 64 |
"cell_type": "code",
|
| 65 |
"execution_count": 4,
|
| 66 |
-
"id": "
|
| 67 |
"metadata": {},
|
| 68 |
"outputs": [
|
| 69 |
{
|
| 70 |
"name": "stdout",
|
| 71 |
"output_type": "stream",
|
| 72 |
"text": [
|
| 73 |
-
"
|
| 74 |
-
"Resolving huggingface.co (huggingface.co)... 3.171.171.104, 3.171.171.128, 3.171.171.6, ...\n",
|
| 75 |
-
"Connecting to huggingface.co (huggingface.co)|3.171.171.104|:443... connected.\n",
|
| 76 |
-
"HTTP request sent, awaiting response... 307 Temporary Redirect\n",
|
| 77 |
-
"Location: /stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt [following]\n",
|
| 78 |
-
"--2025-06-17 08:50:15-- https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
|
| 79 |
-
"Reusing existing connection to huggingface.co:443.\n",
|
| 80 |
-
"HTTP request sent, awaiting response... 302 Found\n",
|
| 81 |
-
"Location: https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750153142&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDE1MzE0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=kAea10Cu%7EhNLABWiXI0i%7E5gAtwsQUUM6CIZczAEWsswZur-XllSQvXEoKksmPdojVE654r7s-CxII8r%7EQ52to%7EQMLbjsjw-JmXq4duiq91qz6U5aenByAXSpOO1ihAoCmCkP02e7L5Wcbs%7EhaV26W9Q%7EAfbwyQ1mn9ta%7EHIDiE7AuNuHgkEEA2IP45ao25b9zsaFw6fIUlBy93Meuf82zwzsw8CJPWV9QEwj-oPVeSDyv3ZhfxS3iCgGSYS320Vs7NcK%7EqJxPfttpTHG9m6zAnfxOpWjYVQfre6HnHUt3VHOy4QdDvpyfljgEQoH4LxRBWI%7Ev72YjOJZDEgSPoTi1Q__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n",
|
| 82 |
-
"--2025-06-17 08:50:15-- https://cdn-lfs.hf.co/repos/f6/56/f656f0fa3b8a40ac76d297fa2a4b00f981e8eb1261963460764e7dd3b35ec97f/c6bbc15e3224e6973459ba78de4998b80b50112b0ae5b5c67113d56b4e366b19?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27sd-v1-5-inpainting.ckpt%3B+filename%3D%22sd-v1-5-inpainting.ckpt%22%3B&Expires=1750153142&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc1MDE1MzE0Mn19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy9mNi81Ni9mNjU2ZjBmYTNiOGE0MGFjNzZkMjk3ZmEyYTRiMDBmOTgxZThlYjEyNjE5NjM0NjA3NjRlN2RkM2IzNWVjOTdmL2M2YmJjMTVlMzIyNGU2OTczNDU5YmE3OGRlNDk5OGI4MGI1MDExMmIwYWU1YjVjNjcxMTNkNTZiNGUzNjZiMTk%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=kAea10Cu%7EhNLABWiXI0i%7E5gAtwsQUUM6CIZczAEWsswZur-XllSQvXEoKksmPdojVE654r7s-CxII8r%7EQ52to%7EQMLbjsjw-JmXq4duiq91qz6U5aenByAXSpOO1ihAoCmCkP02e7L5Wcbs%7EhaV26W9Q%7EAfbwyQ1mn9ta%7EHIDiE7AuNuHgkEEA2IP45ao25b9zsaFw6fIUlBy93Meuf82zwzsw8CJPWV9QEwj-oPVeSDyv3ZhfxS3iCgGSYS320Vs7NcK%7EqJxPfttpTHG9m6zAnfxOpWjYVQfre6HnHUt3VHOy4QdDvpyfljgEQoH4LxRBWI%7Ev72YjOJZDEgSPoTi1Q__&Key-Pair-Id=K3RPWS32NSSJCE\n",
|
| 83 |
-
"Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.160.78.83, 18.160.78.87, 18.160.78.43, ...\n",
|
| 84 |
-
"Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.160.78.83|:443... connected.\n",
|
| 85 |
-
"HTTP request sent, awaiting response... 200 OK\n",
|
| 86 |
-
"Length: 4265437280 (4.0G) [binary/octet-stream]\n",
|
| 87 |
-
"Saving to: ‘sd-v1-5-inpainting.ckpt’\n",
|
| 88 |
-
"\n",
|
| 89 |
-
"sd-v1-5-inpainting. 100%[===================>] 3.97G 324MB/s in 12s \n",
|
| 90 |
-
"\n",
|
| 91 |
-
"2025-06-17 08:50:27 (341 MB/s) - ‘sd-v1-5-inpainting.ckpt’ saved [4265437280/4265437280]\n",
|
| 92 |
-
"\n"
|
| 93 |
]
|
| 94 |
}
|
| 95 |
],
|
| 96 |
"source": [
|
| 97 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
]
|
| 99 |
},
|
| 100 |
{
|
| 101 |
"cell_type": "code",
|
| 102 |
"execution_count": 5,
|
| 103 |
-
"id": "
|
| 104 |
"metadata": {},
|
| 105 |
"outputs": [
|
| 106 |
{
|
| 107 |
"name": "stdout",
|
| 108 |
"output_type": "stream",
|
| 109 |
"text": [
|
| 110 |
-
"
|
| 111 |
-
"
|
| 112 |
-
"ddpm.py merges.txt\t sample_dataset\t vocab.json\n",
|
| 113 |
-
"decoder.py model_converter.py sd-v1-5-inpainting.ckpt\n",
|
| 114 |
-
"diffusion.py output\t\t test.ipynb\n",
|
| 115 |
-
"encoder.py pipeline.py\t training.ipynb\n"
|
| 116 |
]
|
| 117 |
}
|
| 118 |
],
|
| 119 |
"source": [
|
| 120 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
]
|
| 122 |
},
|
| 123 |
{
|
| 124 |
"cell_type": "code",
|
| 125 |
"execution_count": null,
|
| 126 |
-
"id": "
|
| 127 |
"metadata": {},
|
| 128 |
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
{
|
| 130 |
"name": "stdout",
|
| 131 |
"output_type": "stream",
|
| 132 |
"text": [
|
| 133 |
-
"
|
| 134 |
-
"
|
| 135 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
]
|
| 137 |
-
}
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
"id": "a9c7b968",
|
| 147 |
-
"metadata": {},
|
| 148 |
-
"outputs": [
|
| 149 |
{
|
| 150 |
"name": "stdout",
|
| 151 |
"output_type": "stream",
|
| 152 |
"text": [
|
| 153 |
-
"
|
| 154 |
-
"
|
| 155 |
-
"
|
| 156 |
-
"
|
| 157 |
-
"
|
| 158 |
-
"
|
| 159 |
-
"
|
| 160 |
-
"Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->gdown) (4.13.1)\n",
|
| 161 |
-
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (3.4.1)\n",
|
| 162 |
-
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (3.10)\n",
|
| 163 |
-
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (2.3.0)\n",
|
| 164 |
-
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (2025.1.31)\n",
|
| 165 |
-
"Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /usr/local/lib/python3.11/dist-packages (from requests[socks]->gdown) (1.7.1)\n",
|
| 166 |
-
"Downloading gdown-5.2.0-py3-none-any.whl (18 kB)\n",
|
| 167 |
-
"Installing collected packages: gdown\n",
|
| 168 |
-
"Successfully installed gdown-5.2.0\n"
|
| 169 |
]
|
| 170 |
-
}
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
| 182 |
{
|
| 183 |
"name": "stdout",
|
| 184 |
"output_type": "stream",
|
| 185 |
"text": [
|
| 186 |
-
"
|
| 187 |
]
|
| 188 |
-
}
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
]
|
| 197 |
-
},
|
| 198 |
-
{
|
| 199 |
-
"cell_type": "code",
|
| 200 |
-
"execution_count": 56,
|
| 201 |
-
"id": "a5d54cb4",
|
| 202 |
-
"metadata": {},
|
| 203 |
-
"outputs": [
|
| 204 |
{
|
| 205 |
"name": "stdout",
|
| 206 |
"output_type": "stream",
|
| 207 |
"text": [
|
| 208 |
-
"
|
| 209 |
-
"attention.py\t dog.jpg\t model_converter.py sd-v1-5-inpainting.ckpt\n",
|
| 210 |
-
"clip.py\t\t encoder.py\t model.py\t test.ipynb\n",
|
| 211 |
-
"data\t\t garment.jpg\t person.jpg\t vocab.json\n",
|
| 212 |
-
"ddpm.py\t\t image.png\t pipeline.py\t zalando-hd-resized.zip\n",
|
| 213 |
-
"decoder.py\t interface.py README.md\n"
|
| 214 |
]
|
| 215 |
-
}
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
"execution_count": 57,
|
| 224 |
-
"id": "f379e29c",
|
| 225 |
-
"metadata": {},
|
| 226 |
-
"outputs": [],
|
| 227 |
-
"source": [
|
| 228 |
-
"# cat data/train"
|
| 229 |
-
]
|
| 230 |
-
},
|
| 231 |
-
{
|
| 232 |
-
"cell_type": "code",
|
| 233 |
-
"execution_count": 59,
|
| 234 |
-
"id": "34cda0aa",
|
| 235 |
-
"metadata": {},
|
| 236 |
-
"outputs": [],
|
| 237 |
-
"source": [
|
| 238 |
-
"# !cat data/train_pairs.txt"
|
| 239 |
-
]
|
| 240 |
-
},
|
| 241 |
-
{
|
| 242 |
-
"cell_type": "code",
|
| 243 |
-
"execution_count": 6,
|
| 244 |
-
"id": "53095103",
|
| 245 |
-
"metadata": {},
|
| 246 |
-
"outputs": [
|
| 247 |
{
|
| 248 |
"name": "stdout",
|
| 249 |
"output_type": "stream",
|
| 250 |
"text": [
|
| 251 |
-
"
|
| 252 |
]
|
| 253 |
-
}
|
| 254 |
-
],
|
| 255 |
-
"source": [
|
| 256 |
-
"!mkdir output\n",
|
| 257 |
-
"!mkdir checkpoints"
|
| 258 |
-
]
|
| 259 |
-
},
|
| 260 |
-
{
|
| 261 |
-
"cell_type": "code",
|
| 262 |
-
"execution_count": 34,
|
| 263 |
-
"id": "7efe325c",
|
| 264 |
-
"metadata": {},
|
| 265 |
-
"outputs": [],
|
| 266 |
-
"source": [
|
| 267 |
-
"import torch\n",
|
| 268 |
-
"import gc\n",
|
| 269 |
-
"\n",
|
| 270 |
-
"# Delete all tensors and force garbage collection\n",
|
| 271 |
-
"torch.cuda.empty_cache() # Clears unused memory\n",
|
| 272 |
-
"gc.collect() # Python garbage collection\n",
|
| 273 |
-
"\n",
|
| 274 |
-
"# If you want to delete specific variables:\n",
|
| 275 |
-
"for obj in dir():\n",
|
| 276 |
-
" if 'cuda' in str(locals()[obj]):\n",
|
| 277 |
-
" del locals()[obj]\n",
|
| 278 |
-
"gc.collect()\n",
|
| 279 |
-
"torch.cuda.empty_cache()\n"
|
| 280 |
-
]
|
| 281 |
-
},
|
| 282 |
-
{
|
| 283 |
-
"cell_type": "code",
|
| 284 |
-
"execution_count": 35,
|
| 285 |
-
"id": "a48f2753",
|
| 286 |
-
"metadata": {},
|
| 287 |
-
"outputs": [
|
| 288 |
{
|
| 289 |
-
"
|
| 290 |
-
"
|
| 291 |
-
"
|
| 292 |
-
|
| 293 |
-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 294 |
-
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
| 295 |
-
"\u001b[0;32m/tmp/ipykernel_69/1017109895.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mempty_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Release unused GPU memory\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcollect\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Run Python garbage collector\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
|
| 296 |
-
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite_output_prompt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_format_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_user_ns\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfill_exec_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 297 |
-
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/IPython/core/displayhook.py\u001b[0m in \u001b[0;36mupdate_user_ns\u001b[0;34m(self, result)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 200\u001b[0m \u001b[0;31m# Avoid recursive reference when displaying _oh/Out\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 201\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 202\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshell\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muser_ns\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'_oh'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcache_size\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdo_full_cache\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcull_cache\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 298 |
-
"\u001b[0;31mKeyError\u001b[0m: '_oh'"
|
| 299 |
]
|
| 300 |
-
}
|
| 301 |
-
],
|
| 302 |
-
"source": [
|
| 303 |
-
"import torch\n",
|
| 304 |
-
"import gc\n",
|
| 305 |
-
"\n",
|
| 306 |
-
"torch.cuda.empty_cache() # Release unused GPU memory\n",
|
| 307 |
-
"gc.collect() # Run Python garbage collector"
|
| 308 |
-
]
|
| 309 |
-
},
|
| 310 |
-
{
|
| 311 |
-
"cell_type": "code",
|
| 312 |
-
"execution_count": 36,
|
| 313 |
-
"id": "5a57d765",
|
| 314 |
-
"metadata": {},
|
| 315 |
-
"outputs": [],
|
| 316 |
-
"source": [
|
| 317 |
-
"import torch\n",
|
| 318 |
-
"import gc\n",
|
| 319 |
-
"\n",
|
| 320 |
-
"# Clear CUDA cache and collect garbage\n",
|
| 321 |
-
"torch.cuda.empty_cache()\n",
|
| 322 |
-
"gc.collect()\n",
|
| 323 |
-
"\n",
|
| 324 |
-
"# Delete all user-defined variables except for built-ins and modules\n",
|
| 325 |
-
"for var in list(globals()):\n",
|
| 326 |
-
" if not var.startswith(\"__\") and var not in [\"torch\", \"gc\"]:\n",
|
| 327 |
-
" del globals()[var]\n",
|
| 328 |
-
"\n",
|
| 329 |
-
"gc.collect()\n",
|
| 330 |
-
"torch.cuda.empty_cache()"
|
| 331 |
-
]
|
| 332 |
-
},
|
| 333 |
-
{
|
| 334 |
-
"cell_type": "code",
|
| 335 |
-
"execution_count": 37,
|
| 336 |
-
"id": "5957ec57",
|
| 337 |
-
"metadata": {},
|
| 338 |
-
"outputs": [],
|
| 339 |
-
"source": [
|
| 340 |
-
"import tensorflow as tf\n",
|
| 341 |
-
"tf.keras.backend.clear_session()"
|
| 342 |
-
]
|
| 343 |
-
},
|
| 344 |
-
{
|
| 345 |
-
"cell_type": "code",
|
| 346 |
-
"execution_count": 38,
|
| 347 |
-
"id": "796e8ef7",
|
| 348 |
-
"metadata": {},
|
| 349 |
-
"outputs": [
|
| 350 |
{
|
| 351 |
"name": "stdout",
|
| 352 |
"output_type": "stream",
|
| 353 |
"text": [
|
| 354 |
-
"
|
| 355 |
]
|
| 356 |
-
}
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
" print(f\"GPU memory used: {used:.2f} MB / {total:.2f} MB\")\n",
|
| 365 |
-
"else:\n",
|
| 366 |
-
" print(\"CUDA is not available.\")"
|
| 367 |
-
]
|
| 368 |
-
},
|
| 369 |
-
{
|
| 370 |
-
"cell_type": "code",
|
| 371 |
-
"execution_count": null,
|
| 372 |
-
"id": "32ed173e",
|
| 373 |
-
"metadata": {},
|
| 374 |
-
"outputs": [
|
| 375 |
{
|
| 376 |
"name": "stdout",
|
| 377 |
"output_type": "stream",
|
| 378 |
"text": [
|
| 379 |
-
"
|
| 380 |
-
"Available RAM: 24.16 GB\n"
|
| 381 |
]
|
| 382 |
-
}
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
{
|
| 427 |
"name": "stdout",
|
| 428 |
"output_type": "stream",
|
| 429 |
"text": [
|
| 430 |
-
"
|
| 431 |
]
|
| 432 |
},
|
| 433 |
{
|
| 434 |
-
"
|
| 435 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
"output_type": "error",
|
| 437 |
"traceback": [
|
| 438 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 439 |
-
"\u001b[0;
|
| 440 |
-
"\u001b[0;
|
| 441 |
-
"\u001b[0;
|
| 442 |
-
"\u001b[0;
|
| 443 |
-
"\u001b[0;
|
| 444 |
-
"\u001b[0;
|
| 445 |
-
"\u001b[0;
|
| 446 |
-
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 447 |
-
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 448 |
-
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 898\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrecurse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 899\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 900\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 901\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 449 |
-
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn, recurse)\u001b[0m\n\u001b[1;32m 925\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 926\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 927\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 928\u001b[0m \u001b[0mp_should_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 929\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 450 |
-
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 1324\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1325\u001b[0m )\n\u001b[0;32m-> 1326\u001b[0;31m return t.to(\n\u001b[0m\u001b[1;32m 1327\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1328\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_complex\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
| 451 |
-
"\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacity of 15.89 GiB of which 29.12 MiB is free. Process 3907 has 15.85 GiB memory in use. Of the allocated memory 15.49 GiB is allocated by PyTorch, and 62.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)"
|
| 452 |
]
|
| 453 |
}
|
| 454 |
],
|
| 455 |
"source": [
|
| 456 |
-
"import os\n",
|
| 457 |
-
"import json\n",
|
| 458 |
"import random\n",
|
| 459 |
"import argparse\n",
|
| 460 |
"from pathlib import Path\n",
|
| 461 |
-
"from typing import Dict,
|
| 462 |
"\n",
|
| 463 |
"import torch\n",
|
| 464 |
"import torch.nn as nn\n",
|
| 465 |
"import torch.nn.functional as F\n",
|
| 466 |
-
"from torch.utils.data import
|
| 467 |
"from torch.optim import AdamW\n",
|
| 468 |
-
"from torch.optim.lr_scheduler import CosineAnnealingLR\n",
|
| 469 |
"\n",
|
| 470 |
"import numpy as np\n",
|
| 471 |
"from PIL import Image\n",
|
|
@@ -475,7 +388,7 @@
|
|
| 475 |
"# Import your custom modules\n",
|
| 476 |
"from load_model import preload_models_from_standard_weights\n",
|
| 477 |
"from ddpm import DDPMSampler\n",
|
| 478 |
-
"from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image\n",
|
| 479 |
"from diffusers.utils.torch_utils import randn_tensor\n",
|
| 480 |
"\n",
|
| 481 |
"class CatVTONTrainer:\n",
|
|
@@ -533,15 +446,18 @@
|
|
| 533 |
" # Resume from checkpoint if provided\n",
|
| 534 |
" self.global_step = 0\n",
|
| 535 |
" self.current_epoch = 0\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
" if resume_from_checkpoint:\n",
|
| 537 |
" self._load_checkpoint(resume_from_checkpoint)\n",
|
|
|
|
|
|
|
| 538 |
" \n",
|
| 539 |
" self.encoder = self.models.get('encoder', None)\n",
|
| 540 |
" self.decoder = self.models.get('decoder', None)\n",
|
| 541 |
" self.diffusion = self.models.get('diffusion', None)\n",
|
| 542 |
-
"\n",
|
| 543 |
-
" # Setup models and optimizers\n",
|
| 544 |
-
" self._setup_training()\n",
|
| 545 |
" \n",
|
| 546 |
" def _setup_training(self):\n",
|
| 547 |
" \"\"\"Setup models for training with PEFT\"\"\"\n",
|
|
@@ -595,7 +511,7 @@
|
|
| 595 |
" \"\"\"Enable PEFT training - only self-attention layers\"\"\"\n",
|
| 596 |
" print(\"Enabling PEFT training (self-attention layers only)\")\n",
|
| 597 |
" \n",
|
| 598 |
-
" unet = self.diffusion.unet\n",
|
| 599 |
" \n",
|
| 600 |
" # Enable attention layers in encoders and decoders\n",
|
| 601 |
" for layers in [unet.encoders, unet.decoders]:\n",
|
|
@@ -610,19 +526,14 @@
|
|
| 610 |
" for name, param in layer.named_parameters():\n",
|
| 611 |
" if 'attention_1' in name:\n",
|
| 612 |
" param.requires_grad = True\n",
|
| 613 |
-
" \n",
|
| 614 |
-
" def _apply_cfg_dropout(self, garment_latent: torch.Tensor) -> torch.Tensor:\n",
|
| 615 |
-
" \"\"\"Apply classifier-free guidance dropout (10% chance)\"\"\"\n",
|
| 616 |
-
" if self.training and random.random() < self.cfg_dropout_prob:\n",
|
| 617 |
-
" # Replace with zero tensor for unconditional generation\n",
|
| 618 |
-
" return torch.zeros_like(garment_latent)\n",
|
| 619 |
-
" return garment_latent\n",
|
| 620 |
" \n",
|
| 621 |
" def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:\n",
|
| 622 |
" \"\"\"Compute MSE loss for denoising with DREAM strategy\"\"\"\n",
|
| 623 |
" person_images = batch['person'].to(self.device)\n",
|
| 624 |
" cloth_images = batch['cloth'].to(self.device)\n",
|
| 625 |
" masks = batch['mask'].to(self.device)\n",
|
|
|
|
|
|
|
| 626 |
"\n",
|
| 627 |
" concat_dim = -2 # y axis concat\n",
|
| 628 |
" \n",
|
|
@@ -642,67 +553,65 @@
|
|
| 642 |
" condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
|
| 643 |
" mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
|
| 644 |
" \n",
|
|
|
|
| 645 |
" del image, mask, condition_image\n",
|
| 646 |
-
"\n",
|
| 647 |
-
" # Apply CFG dropout to garment latent\n",
|
| 648 |
-
"
|
|
|
|
| 649 |
" \n",
|
| 650 |
" # Concatenate latents\n",
|
| 651 |
-
"
|
| 652 |
-
"
|
| 653 |
" target_latents = torch.cat([person_latent, condition_latent], dim=concat_dim)\n",
|
| 654 |
"\n",
|
| 655 |
" noise = randn_tensor(\n",
|
| 656 |
-
"
|
| 657 |
" generator=self.generator,\n",
|
| 658 |
-
" device=
|
| 659 |
" dtype=self.weight_dtype,\n",
|
| 660 |
" )\n",
|
| 661 |
"\n",
|
| 662 |
-
" timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\n",
|
| 663 |
-
" timesteps = torch.tensor(timesteps, device=self.device)\n",
|
|
|
|
|
|
|
| 664 |
" timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\n",
|
| 665 |
"\n",
|
| 666 |
" # Add noise to latents\n",
|
| 667 |
" noisy_latents = self.scheduler.add_noise(target_latents, timesteps, noise)\n",
|
| 668 |
"\n",
|
| 669 |
-
"
|
| 670 |
-
"
|
| 671 |
-
"
|
| 672 |
-
"
|
| 673 |
-
"
|
|
|
|
|
|
|
| 674 |
"\n",
|
| 675 |
" # DREAM strategy implementation\n",
|
| 676 |
" if self.dream_lambda > 0:\n",
|
| 677 |
" # Get initial noise prediction\n",
|
| 678 |
" with torch.no_grad():\n",
|
| 679 |
" epsilon_theta = self.diffusion(\n",
|
| 680 |
-
"
|
| 681 |
" timesteps_embedding\n",
|
| 682 |
" )\n",
|
| 683 |
" \n",
|
| 684 |
-
" #
|
| 685 |
-
" alphas_cumprod = self.scheduler.alphas_cumprod.to(device=self.device, dtype=self.weight_dtype)\n",
|
| 686 |
-
" sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5\n",
|
| 687 |
-
" sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5\n",
|
| 688 |
-
" \n",
|
| 689 |
-
" # Reshape for broadcasting\n",
|
| 690 |
-
" sqrt_alpha_prod = sqrt_alpha_prod.view(-1, 1, 1, 1)\n",
|
| 691 |
-
" sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.view(-1, 1, 1, 1)\n",
|
| 692 |
-
" \n",
|
| 693 |
-
" # DREAM noise combination\n",
|
| 694 |
" dream_noise = noise + self.dream_lambda * epsilon_theta\n",
|
|
|
|
|
|
|
|
|
|
| 695 |
"\n",
|
| 696 |
-
"
|
| 697 |
-
"\n",
|
| 698 |
-
"
|
| 699 |
-
" dream_noisy_latents
|
| 700 |
-
"
|
| 701 |
-
" masked_latent_concat\n",
|
| 702 |
-
" ], dim=1)\n",
|
| 703 |
"\n",
|
| 704 |
" predicted_noise = self.diffusion(\n",
|
| 705 |
-
"
|
| 706 |
" timesteps_embedding\n",
|
| 707 |
" )\n",
|
| 708 |
" # DREAM loss: |(ε + λεθ) - εθ(ẑt, t)|²\n",
|
|
@@ -710,13 +619,27 @@
|
|
| 710 |
" else:\n",
|
| 711 |
" # Standard training without DREAM\n",
|
| 712 |
" predicted_noise = self.diffusion(\n",
|
| 713 |
-
"
|
| 714 |
" timesteps_embedding,\n",
|
| 715 |
" )\n",
|
| 716 |
-
"\n",
|
| 717 |
" # Standard MSE loss\n",
|
| 718 |
" loss = F.mse_loss(predicted_noise, noise)\n",
|
| 719 |
" \n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
" return loss\n",
|
| 721 |
" \n",
|
| 722 |
" def train_epoch(self) -> float:\n",
|
|
@@ -883,13 +806,13 @@
|
|
| 883 |
" args.__dict__ = {\n",
|
| 884 |
" \"base_model_path\": \"sd-v1-5-inpainting.ckpt\",\n",
|
| 885 |
" \"dataset_name\": \"vitonhd\",\n",
|
| 886 |
-
" \"data_root_path\": \"
|
| 887 |
" \"output_dir\": \"./checkpoints\",\n",
|
| 888 |
-
" \"resume_from_checkpoint\":
|
| 889 |
" \"seed\": 42,\n",
|
| 890 |
-
" \"batch_size\":
|
| 891 |
" \"width\": 384,\n",
|
| 892 |
-
" \"height\":
|
| 893 |
" \"repaint\": True,\n",
|
| 894 |
" \"eval_pair\": True,\n",
|
| 895 |
" \"concat_eval_results\": True,\n",
|
|
@@ -899,10 +822,10 @@
|
|
| 899 |
" \"learning_rate\": 1e-5,\n",
|
| 900 |
" \"max_grad_norm\": 1.0,\n",
|
| 901 |
" \"cfg_dropout_prob\": 0.1,\n",
|
| 902 |
-
" \"dream_lambda\": 0,\n",
|
| 903 |
" \"use_peft\": True,\n",
|
| 904 |
" \"use_mixed_precision\": True,\n",
|
| 905 |
-
" \"save_steps\":
|
| 906 |
" \"is_train\": True\n",
|
| 907 |
" }\n",
|
| 908 |
" \n",
|
|
@@ -918,10 +841,15 @@
|
|
| 918 |
" torch.backends.cuda.matmul.allow_tf32 = True \n",
|
| 919 |
" torch.backends.cudnn.allow_tf32 = True \n",
|
| 920 |
" torch.set_float32_matmul_precision(\"high\")\n",
|
|
|
|
|
|
|
| 921 |
"\n",
|
| 922 |
" # Load pretrained models\n",
|
| 923 |
" print(\"Loading pretrained models...\")\n",
|
| 924 |
" models = preload_models_from_standard_weights(args.base_model_path, args.device)\n",
|
|
|
|
|
|
|
|
|
|
| 925 |
" \n",
|
| 926 |
" # Create dataloader\n",
|
| 927 |
" print(\"Creating dataloader...\")\n",
|
|
@@ -930,6 +858,8 @@
|
|
| 930 |
" print(f\"Training for {args.num_epochs} epochs\")\n",
|
| 931 |
" print(f\"Batches per epoch: {len(train_dataloader)}\")\n",
|
| 932 |
" \n",
|
|
|
|
|
|
|
| 933 |
" # Initialize trainer\n",
|
| 934 |
" print(\"Initializing trainer...\") \n",
|
| 935 |
" trainer = CatVTONTrainer(\n",
|
|
@@ -954,31 +884,14 @@
|
|
| 954 |
" print(\"Starting training...\")\n",
|
| 955 |
" trainer.train() \n",
|
| 956 |
"\n",
|
| 957 |
-
"\n",
|
| 958 |
"if __name__ == \"__main__\":\n",
|
| 959 |
" main()"
|
| 960 |
]
|
| 961 |
-
},
|
| 962 |
-
{
|
| 963 |
-
"cell_type": "code",
|
| 964 |
-
"execution_count": null,
|
| 965 |
-
"id": "77892d6a",
|
| 966 |
-
"metadata": {},
|
| 967 |
-
"outputs": [],
|
| 968 |
-
"source": []
|
| 969 |
-
},
|
| 970 |
-
{
|
| 971 |
-
"cell_type": "code",
|
| 972 |
-
"execution_count": null,
|
| 973 |
-
"id": "b3917d76",
|
| 974 |
-
"metadata": {},
|
| 975 |
-
"outputs": [],
|
| 976 |
-
"source": []
|
| 977 |
}
|
| 978 |
],
|
| 979 |
"metadata": {
|
| 980 |
"kernelspec": {
|
| 981 |
-
"display_name": "
|
| 982 |
"language": "python",
|
| 983 |
"name": "python3"
|
| 984 |
},
|
|
@@ -992,7 +905,7 @@
|
|
| 992 |
"name": "python",
|
| 993 |
"nbconvert_exporter": "python",
|
| 994 |
"pygments_lexer": "ipython3",
|
| 995 |
-
"version": "3.
|
| 996 |
}
|
| 997 |
},
|
| 998 |
"nbformat": 4,
|
|
|
|
| 5 |
"execution_count": 1,
|
| 6 |
"id": "81e4a1db",
|
| 7 |
"metadata": {},
|
| 8 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
"source": [
|
| 10 |
+
"# !git clone -b CatVTON https://github.com/Harsh-Kesharwani/stable-diffusion.git"
|
| 11 |
]
|
| 12 |
},
|
| 13 |
{
|
|
|
|
| 15 |
"execution_count": 2,
|
| 16 |
"id": "9c89e320",
|
| 17 |
"metadata": {},
|
| 18 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
"source": [
|
| 20 |
+
"# cd stable-diffusion/"
|
| 21 |
]
|
| 22 |
},
|
| 23 |
{
|
| 24 |
"cell_type": "code",
|
| 25 |
"execution_count": 3,
|
| 26 |
+
"id": "ff8b706c",
|
| 27 |
"metadata": {},
|
| 28 |
"outputs": [
|
| 29 |
{
|
| 30 |
"name": "stdout",
|
| 31 |
"output_type": "stream",
|
| 32 |
"text": [
|
| 33 |
+
"Model already downloaded.\n"
|
| 34 |
]
|
| 35 |
}
|
| 36 |
],
|
| 37 |
"source": [
|
| 38 |
+
"# check if the model is downloaded, if not download it\n",
|
| 39 |
+
"import os\n",
|
| 40 |
+
"if not os.path.exists(\"sd-v1-5-inpainting.ckpt\"):\n",
|
| 41 |
+
" !wget https://huggingface.co/sd-legacy/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt\n",
|
| 42 |
+
"else:\n",
|
| 43 |
+
" print(\"Model already downloaded.\")\n"
|
| 44 |
]
|
| 45 |
},
|
| 46 |
{
|
| 47 |
"cell_type": "code",
|
| 48 |
"execution_count": 4,
|
| 49 |
+
"id": "53095103",
|
| 50 |
"metadata": {},
|
| 51 |
"outputs": [
|
| 52 |
{
|
| 53 |
"name": "stdout",
|
| 54 |
"output_type": "stream",
|
| 55 |
"text": [
|
| 56 |
+
"Checkpoints directory already exists.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
]
|
| 58 |
}
|
| 59 |
],
|
| 60 |
"source": [
|
| 61 |
+
"# make output and checkpoints directories if they don't exist\n",
|
| 62 |
+
"import os\n",
|
| 63 |
+
"if not os.path.exists(\"checkpoints\"):\n",
|
| 64 |
+
" os.makedirs(\"checkpoints\")\n",
|
| 65 |
+
"else:\n",
|
| 66 |
+
" print(\"Checkpoints directory already exists.\")"
|
| 67 |
]
|
| 68 |
},
|
| 69 |
{
|
| 70 |
"cell_type": "code",
|
| 71 |
"execution_count": 5,
|
| 72 |
+
"id": "d8978b25",
|
| 73 |
"metadata": {},
|
| 74 |
"outputs": [
|
| 75 |
{
|
| 76 |
"name": "stdout",
|
| 77 |
"output_type": "stream",
|
| 78 |
"text": [
|
| 79 |
+
"VITON-HD dataset already exists.\n",
|
| 80 |
+
"Zip file does not exist, nothing to remove.\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
]
|
| 82 |
}
|
| 83 |
],
|
| 84 |
"source": [
|
| 85 |
+
"import os\n",
|
| 86 |
+
"if not os.path.exists(\"viton-hd-dataset\"):\n",
|
| 87 |
+
" !curl -L -u harshkesherwani:7695128b407febc869a6f5b2cb0cbf26\\\n",
|
| 88 |
+
" -o /home/mahesh/harsh/stable-diffusion/viton-hd-dataset.zip\\\n",
|
| 89 |
+
" https://www.kaggle.com/api/v1/datasets/download/harshkesherwani/viton-hd-dataset\n",
|
| 90 |
+
" \n",
|
| 91 |
+
" import zipfile\n",
|
| 92 |
+
" with zipfile.ZipFile('viton-hd-dataset.zip', 'r') as zip_ref:\n",
|
| 93 |
+
" zip_ref.extractall('viton-hd-dataset')\n",
|
| 94 |
+
" \n",
|
| 95 |
+
" print(\"VITON-HD dataset downloaded and extracted.\")\n",
|
| 96 |
+
"else:\n",
|
| 97 |
+
" print(\"VITON-HD dataset already exists.\")\n",
|
| 98 |
+
" \n",
|
| 99 |
+
"import os\n",
|
| 100 |
+
"if os.path.exists(\"viton-hd-dataset.zip\"):\n",
|
| 101 |
+
" os.remove(\"viton-hd-dataset.zip\")\n",
|
| 102 |
+
" print(\"Removed the zip file after extraction.\")\n",
|
| 103 |
+
"else:\n",
|
| 104 |
+
" print(\"Zip file does not exist, nothing to remove.\")\n"
|
| 105 |
]
|
| 106 |
},
|
| 107 |
{
|
| 108 |
"cell_type": "code",
|
| 109 |
"execution_count": null,
|
| 110 |
+
"id": "3aea80d9",
|
| 111 |
"metadata": {},
|
| 112 |
"outputs": [
|
| 113 |
+
{
|
| 114 |
+
"name": "stderr",
|
| 115 |
+
"output_type": "stream",
|
| 116 |
+
"text": [
|
| 117 |
+
"/home/mahesh/miniconda3/envs/harsh/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 118 |
+
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 119 |
+
]
|
| 120 |
+
},
|
| 121 |
{
|
| 122 |
"name": "stdout",
|
| 123 |
"output_type": "stream",
|
| 124 |
"text": [
|
| 125 |
+
"----------------------------------------------------------------------------------------------------\n",
|
| 126 |
+
"Loading pretrained models...\n",
|
| 127 |
+
"Models loaded successfully.\n",
|
| 128 |
+
"----------------------------------------------------------------------------------------------------\n",
|
| 129 |
+
"Creating dataloader...\n",
|
| 130 |
+
"Dataset vitonhd loaded, total 11647 pairs.\n",
|
| 131 |
+
"Training for 50 epochs\n",
|
| 132 |
+
"Batches per epoch: 5824\n",
|
| 133 |
+
"----------------------------------------------------------------------------------------------------\n",
|
| 134 |
+
"Initializing trainer...\n",
|
| 135 |
+
"Enabling PEFT training (self-attention layers only)\n",
|
| 136 |
+
"Total parameters: 899,226,667\n",
|
| 137 |
+
"Trainable parameters: 49,574,080 (5.51%)\n"
|
| 138 |
]
|
| 139 |
+
},
|
| 140 |
+
{
|
| 141 |
+
"name": "stderr",
|
| 142 |
+
"output_type": "stream",
|
| 143 |
+
"text": [
|
| 144 |
+
"/tmp/ipykernel_1669505/646906096.py:71: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
|
| 145 |
+
" self.scaler = torch.cuda.amp.GradScaler()\n"
|
| 146 |
+
]
|
| 147 |
+
},
|
|
|
|
|
|
|
|
|
|
| 148 |
{
|
| 149 |
"name": "stdout",
|
| 150 |
"output_type": "stream",
|
| 151 |
"text": [
|
| 152 |
+
"Checkpoint loaded: ./checkpoints/checkpoint_step_40000.pth\n",
|
| 153 |
+
"Resuming from epoch 12, step 40000\n",
|
| 154 |
+
"Starting training...\n",
|
| 155 |
+
"Starting training for 50 epochs\n",
|
| 156 |
+
"Total training batches per epoch: 5824\n",
|
| 157 |
+
"Using DREAM with lambda = 0\n",
|
| 158 |
+
"Mixed precision: True\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"name": "stderr",
|
| 163 |
+
"output_type": "stream",
|
| 164 |
+
"text": [
|
| 165 |
+
"Epoch 13: 0%| | 0/5824 [00:00<?, ?it/s]/tmp/ipykernel_1669505/646906096.py:291: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 166 |
+
" with torch.cuda.amp.autocast():\n",
|
| 167 |
+
"/tmp/ipykernel_1669505/646906096.py:181: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 168 |
+
" with torch.cuda.amp.autocast(enabled=self.use_mixed_precision):\n",
|
| 169 |
+
"/tmp/ipykernel_1669505/504089317.py:50: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
|
| 170 |
+
" with torch.cuda.amp.autocast(enabled=False):\n"
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
{
|
| 174 |
"name": "stdout",
|
| 175 |
"output_type": "stream",
|
| 176 |
"text": [
|
| 177 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_040000.jpg\n"
|
| 178 |
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"name": "stderr",
|
| 182 |
+
"output_type": "stream",
|
| 183 |
+
"text": [
|
| 184 |
+
"Epoch 13: 17%|█▋ | 1000/5824 [09:02<41:34, 1.93it/s, loss=0.0591, lr=1e-5, step=41001]"
|
| 185 |
+
]
|
| 186 |
+
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
{
|
| 188 |
"name": "stdout",
|
| 189 |
"output_type": "stream",
|
| 190 |
"text": [
|
| 191 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_041000.jpg\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
]
|
| 193 |
+
},
|
| 194 |
+
{
|
| 195 |
+
"name": "stderr",
|
| 196 |
+
"output_type": "stream",
|
| 197 |
+
"text": [
|
| 198 |
+
"Epoch 13: 34%|███▍ | 2001/5824 [17:46<38:47, 1.64it/s, loss=0.0143, lr=1e-5, step=42001] "
|
| 199 |
+
]
|
| 200 |
+
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
{
|
| 202 |
"name": "stdout",
|
| 203 |
"output_type": "stream",
|
| 204 |
"text": [
|
| 205 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_042000.jpg\n"
|
| 206 |
]
|
| 207 |
+
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
{
|
| 209 |
+
"name": "stderr",
|
| 210 |
+
"output_type": "stream",
|
| 211 |
+
"text": [
|
| 212 |
+
"Epoch 13: 52%|█████▏ | 3000/5824 [26:32<24:32, 1.92it/s, loss=0.0144, lr=1e-5, step=43001] "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
]
|
| 214 |
+
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
{
|
| 216 |
"name": "stdout",
|
| 217 |
"output_type": "stream",
|
| 218 |
"text": [
|
| 219 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_043000.jpg\n"
|
| 220 |
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"name": "stderr",
|
| 224 |
+
"output_type": "stream",
|
| 225 |
+
"text": [
|
| 226 |
+
"Epoch 13: 69%|██████▊ | 4001/5824 [35:19<18:14, 1.67it/s, loss=0.0233, lr=1e-5, step=44001] "
|
| 227 |
+
]
|
| 228 |
+
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
{
|
| 230 |
"name": "stdout",
|
| 231 |
"output_type": "stream",
|
| 232 |
"text": [
|
| 233 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_044000.jpg\n"
|
|
|
|
| 234 |
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"name": "stderr",
|
| 238 |
+
"output_type": "stream",
|
| 239 |
+
"text": [
|
| 240 |
+
"Epoch 13: 86%|████████▌ | 5000/5824 [44:07<07:16, 1.89it/s, loss=0.0609, lr=1e-5, step=45001] "
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"name": "stdout",
|
| 245 |
+
"output_type": "stream",
|
| 246 |
+
"text": [
|
| 247 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_045000.jpg\n"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"name": "stderr",
|
| 252 |
+
"output_type": "stream",
|
| 253 |
+
"text": [
|
| 254 |
+
"Epoch 13: 100%|██████████| 5824/5824 [51:31<00:00, 1.88it/s, loss=0.00715, lr=1e-5, step=45824]\n"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"name": "stdout",
|
| 259 |
+
"output_type": "stream",
|
| 260 |
+
"text": [
|
| 261 |
+
"Epoch 13/50 - Train Loss: 0.030487\n"
|
| 262 |
+
]
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"name": "stderr",
|
| 266 |
+
"output_type": "stream",
|
| 267 |
+
"text": [
|
| 268 |
+
"Epoch 14: 3%|▎ | 177/5824 [01:33<56:36, 1.66it/s, loss=0.0409, lr=1e-5, step=46001] "
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"name": "stdout",
|
| 273 |
+
"output_type": "stream",
|
| 274 |
+
"text": [
|
| 275 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_046000.jpg\n"
|
| 276 |
+
]
|
| 277 |
+
},
|
| 278 |
+
{
|
| 279 |
+
"name": "stderr",
|
| 280 |
+
"output_type": "stream",
|
| 281 |
+
"text": [
|
| 282 |
+
"Epoch 14: 20%|██ | 1177/5824 [10:19<46:38, 1.66it/s, loss=0.00494, lr=1e-5, step=47001]"
|
| 283 |
+
]
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"name": "stdout",
|
| 287 |
+
"output_type": "stream",
|
| 288 |
+
"text": [
|
| 289 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_047000.jpg\n"
|
| 290 |
+
]
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"name": "stderr",
|
| 294 |
+
"output_type": "stream",
|
| 295 |
+
"text": [
|
| 296 |
+
"Epoch 14: 37%|███▋ | 2177/5824 [19:07<36:55, 1.65it/s, loss=0.0527, lr=1e-5, step=48001] "
|
| 297 |
+
]
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"name": "stdout",
|
| 301 |
+
"output_type": "stream",
|
| 302 |
+
"text": [
|
| 303 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_048000.jpg\n"
|
| 304 |
+
]
|
| 305 |
+
},
|
| 306 |
+
{
|
| 307 |
+
"name": "stderr",
|
| 308 |
+
"output_type": "stream",
|
| 309 |
+
"text": [
|
| 310 |
+
"Epoch 14: 55%|█████▍ | 3177/5824 [27:52<26:30, 1.66it/s, loss=0.0266, lr=1e-5, step=49001] "
|
| 311 |
+
]
|
| 312 |
+
},
|
| 313 |
{
|
| 314 |
"name": "stdout",
|
| 315 |
"output_type": "stream",
|
| 316 |
"text": [
|
| 317 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_049000.jpg\n"
|
| 318 |
]
|
| 319 |
},
|
| 320 |
{
|
| 321 |
+
"name": "stderr",
|
| 322 |
+
"output_type": "stream",
|
| 323 |
+
"text": [
|
| 324 |
+
"Epoch 14: 72%|███████▏ | 4176/5824 [36:39<41:17, 1.50s/it, loss=0.0227, lr=1e-5, step=5e+4] "
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"name": "stdout",
|
| 329 |
+
"output_type": "stream",
|
| 330 |
+
"text": [
|
| 331 |
+
"Checkpoint saved: checkpoints/checkpoint_step_50000.pth\n"
|
| 332 |
+
]
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"name": "stderr",
|
| 336 |
+
"output_type": "stream",
|
| 337 |
+
"text": [
|
| 338 |
+
"Epoch 14: 72%|███████▏ | 4177/5824 [36:40<35:22, 1.29s/it, loss=0.0152, lr=1e-5, step=5e+4]"
|
| 339 |
+
]
|
| 340 |
+
},
|
| 341 |
+
{
|
| 342 |
+
"name": "stdout",
|
| 343 |
+
"output_type": "stream",
|
| 344 |
+
"text": [
|
| 345 |
+
"Debug visualization saved: checkpoints/debug_viz/debug_step_050000.jpg\n"
|
| 346 |
+
]
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"name": "stderr",
|
| 350 |
+
"output_type": "stream",
|
| 351 |
+
"text": [
|
| 352 |
+
"Epoch 14: 72%|███████▏ | 4211/5824 [36:58<14:09, 1.90it/s, loss=0.0351, lr=1e-5, step=5e+4] \n"
|
| 353 |
+
]
|
| 354 |
+
},
|
| 355 |
+
{
|
| 356 |
+
"ename": "KeyboardInterrupt",
|
| 357 |
+
"evalue": "",
|
| 358 |
"output_type": "error",
|
| 359 |
"traceback": [
|
| 360 |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 361 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
| 362 |
+
"Cell \u001b[0;32mIn[8], line 520\u001b[0m\n\u001b[1;32m 517\u001b[0m trainer\u001b[38;5;241m.\u001b[39mtrain() \n\u001b[1;32m 519\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;18m__name__\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m__main__\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m--> 520\u001b[0m \u001b[43mmain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 363 |
+
"Cell \u001b[0;32mIn[8], line 517\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[38;5;66;03m# Start training\u001b[39;00m\n\u001b[1;32m 516\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStarting training...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 517\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
| 364 |
+
"Cell \u001b[0;32mIn[8], line 353\u001b[0m, in \u001b[0;36mCatVTONTrainer.train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcurrent_epoch \u001b[38;5;241m=\u001b[39m epoch\n\u001b[1;32m 352\u001b[0m \u001b[38;5;66;03m# Train one epoch\u001b[39;00m\n\u001b[0;32m--> 353\u001b[0m train_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_epoch\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_epochs\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m - Train Loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrain_loss\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.6f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 357\u001b[0m \u001b[38;5;66;03m# Save epoch checkpoint\u001b[39;00m\n",
|
| 365 |
+
"Cell \u001b[0;32mIn[8], line 292\u001b[0m, in \u001b[0;36mCatVTONTrainer.train_epoch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39muse_mixed_precision:\n\u001b[1;32m 291\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mamp\u001b[38;5;241m.\u001b[39mautocast():\n\u001b[0;32m--> 292\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;66;03m# Backward pass with scaling\u001b[39;00m\n\u001b[1;32m 295\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward()\n",
|
| 366 |
+
"Cell \u001b[0;32mIn[8], line 211\u001b[0m, in \u001b[0;36mCatVTONTrainer.compute_loss\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;66;03m# timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\u001b[39;00m\n\u001b[1;32m 208\u001b[0m \u001b[38;5;66;03m# timesteps = torch.tensor(timesteps, device=self.device)\u001b[39;00m\n\u001b[1;32m 209\u001b[0m \u001b[38;5;66;03m# timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\u001b[39;00m\n\u001b[1;32m 210\u001b[0m timesteps \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m1000\u001b[39m, size\u001b[38;5;241m=\u001b[39m(batch_size,))\n\u001b[0;32m--> 211\u001b[0m timesteps_embedding \u001b[38;5;241m=\u001b[39m \u001b[43mget_time_embedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_dtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;66;03m# Add noise to latents\u001b[39;00m\n\u001b[1;32m 214\u001b[0m noisy_latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscheduler\u001b[38;5;241m.\u001b[39madd_noise(target_latents, timesteps, noise)\n",
|
| 367 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
]
|
| 369 |
}
|
| 370 |
],
|
| 371 |
"source": [
|
|
|
|
|
|
|
| 372 |
"import random\n",
|
| 373 |
"import argparse\n",
|
| 374 |
"from pathlib import Path\n",
|
| 375 |
+
"from typing import Dict, Optional\n",
|
| 376 |
"\n",
|
| 377 |
"import torch\n",
|
| 378 |
"import torch.nn as nn\n",
|
| 379 |
"import torch.nn.functional as F\n",
|
| 380 |
+
"from torch.utils.data import DataLoader\n",
|
| 381 |
"from torch.optim import AdamW\n",
|
|
|
|
| 382 |
"\n",
|
| 383 |
"import numpy as np\n",
|
| 384 |
"from PIL import Image\n",
|
|
|
|
| 388 |
"# Import your custom modules\n",
|
| 389 |
"from load_model import preload_models_from_standard_weights\n",
|
| 390 |
"from ddpm import DDPMSampler\n",
|
| 391 |
+
"from utils import check_inputs, get_time_embedding, prepare_image, prepare_mask_image, compute_vae_encodings, save_debug_visualization\n",
|
| 392 |
"from diffusers.utils.torch_utils import randn_tensor\n",
|
| 393 |
"\n",
|
| 394 |
"class CatVTONTrainer:\n",
|
|
|
|
| 446 |
" # Resume from checkpoint if provided\n",
|
| 447 |
" self.global_step = 0\n",
|
| 448 |
" self.current_epoch = 0\n",
|
| 449 |
+
" \n",
|
| 450 |
+
" # Setup models and optimizers\n",
|
| 451 |
+
" self._setup_training()\n",
|
| 452 |
+
" \n",
|
| 453 |
" if resume_from_checkpoint:\n",
|
| 454 |
" self._load_checkpoint(resume_from_checkpoint)\n",
|
| 455 |
+
" \n",
|
| 456 |
+
" \n",
|
| 457 |
" \n",
|
| 458 |
" self.encoder = self.models.get('encoder', None)\n",
|
| 459 |
" self.decoder = self.models.get('decoder', None)\n",
|
| 460 |
" self.diffusion = self.models.get('diffusion', None)\n",
|
|
|
|
|
|
|
|
|
|
| 461 |
" \n",
|
| 462 |
" def _setup_training(self):\n",
|
| 463 |
" \"\"\"Setup models for training with PEFT\"\"\"\n",
|
|
|
|
| 511 |
" \"\"\"Enable PEFT training - only self-attention layers\"\"\"\n",
|
| 512 |
" print(\"Enabling PEFT training (self-attention layers only)\")\n",
|
| 513 |
" \n",
|
| 514 |
+
" unet = self.models['diffusion'].unet\n",
|
| 515 |
" \n",
|
| 516 |
" # Enable attention layers in encoders and decoders\n",
|
| 517 |
" for layers in [unet.encoders, unet.decoders]:\n",
|
|
|
|
| 526 |
" for name, param in layer.named_parameters():\n",
|
| 527 |
" if 'attention_1' in name:\n",
|
| 528 |
" param.requires_grad = True\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 529 |
" \n",
|
| 530 |
" def compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:\n",
|
| 531 |
" \"\"\"Compute MSE loss for denoising with DREAM strategy\"\"\"\n",
|
| 532 |
" person_images = batch['person'].to(self.device)\n",
|
| 533 |
" cloth_images = batch['cloth'].to(self.device)\n",
|
| 534 |
" masks = batch['mask'].to(self.device)\n",
|
| 535 |
+
" \n",
|
| 536 |
+
" batch_size = person_images.shape[0]\n",
|
| 537 |
"\n",
|
| 538 |
" concat_dim = -2 # y axis concat\n",
|
| 539 |
" \n",
|
|
|
|
| 553 |
" condition_latent = compute_vae_encodings(condition_image, self.encoder)\n",
|
| 554 |
" mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode=\"nearest\")\n",
|
| 555 |
" \n",
|
| 556 |
+
" \n",
|
| 557 |
" del image, mask, condition_image\n",
|
| 558 |
+
" \n",
|
| 559 |
+
" # Apply CFG dropout to garment latent (10% chance)\n",
|
| 560 |
+
" if self.training and random.random() < self.cfg_dropout_prob:\n",
|
| 561 |
+
" condition_latent = torch.zeros_like(condition_latent)\n",
|
| 562 |
" \n",
|
| 563 |
" # Concatenate latents\n",
|
| 564 |
+
" input_latents = torch.cat([masked_latent, condition_latent], dim=concat_dim)\n",
|
| 565 |
+
" mask_input = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)\n",
|
| 566 |
" target_latents = torch.cat([person_latent, condition_latent], dim=concat_dim)\n",
|
| 567 |
"\n",
|
| 568 |
" noise = randn_tensor(\n",
|
| 569 |
+
" target_latents.shape,\n",
|
| 570 |
" generator=self.generator,\n",
|
| 571 |
+
" device=target_latents.device,\n",
|
| 572 |
" dtype=self.weight_dtype,\n",
|
| 573 |
" )\n",
|
| 574 |
"\n",
|
| 575 |
+
" # timesteps = torch.randint(1, 1000, size=(1,), device=self.device)[0].long().item()\n",
|
| 576 |
+
" # timesteps = torch.tensor(timesteps, device=self.device)\n",
|
| 577 |
+
" # timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\n",
|
| 578 |
+
" timesteps = torch.randint(1, 1000, size=(batch_size,))\n",
|
| 579 |
" timesteps_embedding = get_time_embedding(timesteps).to(self.device, dtype=self.weight_dtype)\n",
|
| 580 |
"\n",
|
| 581 |
" # Add noise to latents\n",
|
| 582 |
" noisy_latents = self.scheduler.add_noise(target_latents, timesteps, noise)\n",
|
| 583 |
"\n",
|
| 584 |
+
" # UNet(zt ⊙ Mi ⊙ Xi) where ⊙ is channel concatenation\n",
|
| 585 |
+
" unet_input = torch.cat([\n",
|
| 586 |
+
" input_latents, # Xi\n",
|
| 587 |
+
" mask_input, # Mi\n",
|
| 588 |
+
" noisy_latents, # zt\n",
|
| 589 |
+
" ], dim=1).to(self.device, dtype=self.weight_dtype) # Channel dimension\n",
|
| 590 |
+
" \n",
|
| 591 |
"\n",
|
| 592 |
" # DREAM strategy implementation\n",
|
| 593 |
" if self.dream_lambda > 0:\n",
|
| 594 |
" # Get initial noise prediction\n",
|
| 595 |
" with torch.no_grad():\n",
|
| 596 |
" epsilon_theta = self.diffusion(\n",
|
| 597 |
+
" unet_input,\n",
|
| 598 |
" timesteps_embedding\n",
|
| 599 |
" )\n",
|
| 600 |
" \n",
|
| 601 |
+
" # DREAM noise combination: ε + λ*εθ\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
" dream_noise = noise + self.dream_lambda * epsilon_theta\n",
|
| 603 |
+
" \n",
|
| 604 |
+
" # Create new noisy latents with DREAM noise\n",
|
| 605 |
+
" dream_noisy_latents = self.scheduler.add_noise(target_latents, timesteps, dream_noise)\n",
|
| 606 |
"\n",
|
| 607 |
+
" dream_unet_input = torch.cat([\n",
|
| 608 |
+
" input_latents, \n",
|
| 609 |
+
" mask_input,\n",
|
| 610 |
+
" dream_noisy_latents\n",
|
| 611 |
+
" ], dim=1).to(self.device, dtype=self.weight_dtype)\n",
|
|
|
|
|
|
|
| 612 |
"\n",
|
| 613 |
" predicted_noise = self.diffusion(\n",
|
| 614 |
+
" dream_unet_input,\n",
|
| 615 |
" timesteps_embedding\n",
|
| 616 |
" )\n",
|
| 617 |
" # DREAM loss: |(ε + λεθ) - εθ(ẑt, t)|²\n",
|
|
|
|
| 619 |
" else:\n",
|
| 620 |
" # Standard training without DREAM\n",
|
| 621 |
" predicted_noise = self.diffusion(\n",
|
| 622 |
+
" unet_input,\n",
|
| 623 |
" timesteps_embedding,\n",
|
| 624 |
" )\n",
|
| 625 |
+
" \n",
|
| 626 |
" # Standard MSE loss\n",
|
| 627 |
" loss = F.mse_loss(predicted_noise, noise)\n",
|
| 628 |
" \n",
|
| 629 |
+
" if self.global_step % 1000 == 0:\n",
|
| 630 |
+
" save_debug_visualization(\n",
|
| 631 |
+
" person_images=person_images,\n",
|
| 632 |
+
" cloth_images=cloth_images, \n",
|
| 633 |
+
" masks=masks,\n",
|
| 634 |
+
" masked_image=masked_image,\n",
|
| 635 |
+
" noisy_latents=noisy_latents,\n",
|
| 636 |
+
" predicted_noise=predicted_noise,\n",
|
| 637 |
+
" target_latents=target_latents,\n",
|
| 638 |
+
" decoder=self.decoder,\n",
|
| 639 |
+
" global_step=self.global_step,\n",
|
| 640 |
+
" output_dir=self.output_dir,\n",
|
| 641 |
+
" device=self.device\n",
|
| 642 |
+
" )\n",
|
| 643 |
" return loss\n",
|
| 644 |
" \n",
|
| 645 |
" def train_epoch(self) -> float:\n",
|
|
|
|
| 806 |
" args.__dict__ = {\n",
|
| 807 |
" \"base_model_path\": \"sd-v1-5-inpainting.ckpt\",\n",
|
| 808 |
" \"dataset_name\": \"vitonhd\",\n",
|
| 809 |
+
" \"data_root_path\": \"./viton-hd-dataset\",\n",
|
| 810 |
" \"output_dir\": \"./checkpoints\",\n",
|
| 811 |
+
" \"resume_from_checkpoint\": \"./checkpoints/checkpoint_step_40000.pth\",\n",
|
| 812 |
" \"seed\": 42,\n",
|
| 813 |
+
" \"batch_size\": 2,\n",
|
| 814 |
" \"width\": 384,\n",
|
| 815 |
+
" \"height\": 384,\n",
|
| 816 |
" \"repaint\": True,\n",
|
| 817 |
" \"eval_pair\": True,\n",
|
| 818 |
" \"concat_eval_results\": True,\n",
|
|
|
|
| 822 |
" \"learning_rate\": 1e-5,\n",
|
| 823 |
" \"max_grad_norm\": 1.0,\n",
|
| 824 |
" \"cfg_dropout_prob\": 0.1,\n",
|
| 825 |
+
" \"dream_lambda\": 10.0,\n",
|
| 826 |
" \"use_peft\": True,\n",
|
| 827 |
" \"use_mixed_precision\": True,\n",
|
| 828 |
+
" \"save_steps\": 10000,\n",
|
| 829 |
" \"is_train\": True\n",
|
| 830 |
" }\n",
|
| 831 |
" \n",
|
|
|
|
| 841 |
" torch.backends.cuda.matmul.allow_tf32 = True \n",
|
| 842 |
" torch.backends.cudnn.allow_tf32 = True \n",
|
| 843 |
" torch.set_float32_matmul_precision(\"high\")\n",
|
| 844 |
+
" \n",
|
| 845 |
+
" print(\"-\"*100)\n",
|
| 846 |
"\n",
|
| 847 |
" # Load pretrained models\n",
|
| 848 |
" print(\"Loading pretrained models...\")\n",
|
| 849 |
" models = preload_models_from_standard_weights(args.base_model_path, args.device)\n",
|
| 850 |
+
" print(\"Models loaded successfully.\")\n",
|
| 851 |
+
" \n",
|
| 852 |
+
" print(\"-\"*100)\n",
|
| 853 |
" \n",
|
| 854 |
" # Create dataloader\n",
|
| 855 |
" print(\"Creating dataloader...\")\n",
|
|
|
|
| 858 |
" print(f\"Training for {args.num_epochs} epochs\")\n",
|
| 859 |
" print(f\"Batches per epoch: {len(train_dataloader)}\")\n",
|
| 860 |
" \n",
|
| 861 |
+
" print(\"-\"*100)\n",
|
| 862 |
+
" \n",
|
| 863 |
" # Initialize trainer\n",
|
| 864 |
" print(\"Initializing trainer...\") \n",
|
| 865 |
" trainer = CatVTONTrainer(\n",
|
|
|
|
| 884 |
" print(\"Starting training...\")\n",
|
| 885 |
" trainer.train() \n",
|
| 886 |
"\n",
|
|
|
|
| 887 |
"if __name__ == \"__main__\":\n",
|
| 888 |
" main()"
|
| 889 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 890 |
}
|
| 891 |
],
|
| 892 |
"metadata": {
|
| 893 |
"kernelspec": {
|
| 894 |
+
"display_name": "harsh",
|
| 895 |
"language": "python",
|
| 896 |
"name": "python3"
|
| 897 |
},
|
|
|
|
| 905 |
"name": "python",
|
| 906 |
"nbconvert_exporter": "python",
|
| 907 |
"pygments_lexer": "ipython3",
|
| 908 |
+
"version": "3.10.18"
|
| 909 |
}
|
| 910 |
},
|
| 911 |
"nbformat": 4,
|