Spaces:
Runtime error
Runtime error
update code
Browse files- .gitattributes +34 -0
- app.py +393 -338
- assets/GIF.gif +0 -0
- assets/Teaser_Small.png +3 -0
- assets/examples/Lancia.webp +3 -0
- assets/examples/car.jpeg +3 -0
- assets/examples/car1.webp +3 -0
- assets/examples/carpet2.webp +3 -0
- assets/examples/chair.jpeg +3 -0
- assets/examples/chair1.jpeg +3 -0
- assets/examples/dog.jpeg +3 -0
- assets/examples/door.jpeg +3 -0
- assets/examples/door2.jpeg +3 -0
- assets/examples/grasslands-national-park.jpeg +3 -0
- assets/examples/house.jpeg +3 -0
- assets/examples/house2.jpeg +3 -0
- assets/examples/ian.jpeg +3 -0
- assets/examples/park.webp +3 -0
- assets/examples/ran.webp +3 -0
- assets/hulk.jpeg +0 -0
- assets/ironman.webp +0 -0
- assets/lava.jpg +0 -0
- assets/ski.jpg +0 -0
- assets/truck.png +0 -0
- assets/truck2.jpeg +0 -0
- cldm/appearance_networks.py +75 -0
- cldm/cldm.py +115 -118
- cldm/controlnet.py +306 -0
- cldm/ddim_hacked.py +2 -3
- cldm/logger.py +10 -10
- configs/{sap_fixed_hintnet_v15.yaml → pair_diff.yaml} +20 -7
- ldm/ldm/util.py +197 -0
- ldm/models/diffusion/ddim.py +15 -4
- ldm/modules/attention.py +61 -15
- ldm/modules/diffusionmodules/openaimodel.py +16 -4
- ldm/modules/diffusionmodules/util.py +2 -1
- ldm/modules/encoders/modules.py +3 -3
- pair_diff_demo.py +516 -0
- requirements.txt +2 -1
.gitattributes
CHANGED
|
@@ -32,3 +32,37 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
assets/examples/ian.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/examples/resized_anm_38.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/examples/anm_8.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/examples/house.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/examples/door2.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/examples/door.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/examples/frn_38.jpg filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/examples/park.webp filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
assets/examples/car1.webp filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
assets/examples/car.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
assets/examples/house2.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
assets/examples/Lancia.webp filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
assets/examples/obj_11.jpg filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
assets/examples/resized_anm_8.jpg filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
assets/examples/resized_frn_38.jpg filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
assets/examples/resized_obj_11.jpg filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
assets/examples/dog.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
assets/examples/grasslands-national-park.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
assets/examples/resized_obj_38.jpg filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
assets/examples/chair1.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
assets/examples/chair.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
assets/examples/obj_38.jpg filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
assets/examples/ran.webp filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
assets/examples/anm_38.jpg filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
assets/examples/carpet2.webp filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
assets/ironman.webp filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
assets/truck2.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
assets/truck.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
assets/ski.jpg filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
assets/Teaser_Small.png filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
assets/examples filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
assets/GIF.gif filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
assets/hulk.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
assets/lava.jpg filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -1,429 +1,484 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
import einops
|
| 4 |
import gradio as gr
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import random
|
| 8 |
-
import os
|
| 9 |
-
import subprocess
|
| 10 |
-
import shlex
|
| 11 |
-
|
| 12 |
-
from huggingface_hub import hf_hub_url, hf_hub_download
|
| 13 |
-
from share import *
|
| 14 |
-
|
| 15 |
-
from pytorch_lightning import seed_everything
|
| 16 |
-
from annotator.util import resize_image, HWC3
|
| 17 |
-
from annotator.OneFormer import OneformerSegmenter
|
| 18 |
-
from cldm.model import create_model, load_state_dict
|
| 19 |
-
from cldm.ddim_hacked import DDIMSamplerSpaCFG
|
| 20 |
-
from ldm.models.autoencoder import DiagonalGaussianDistribution
|
| 21 |
-
|
| 22 |
-
urls = {
|
| 23 |
-
'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
|
| 24 |
-
'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['pair_diffusion_epoch62.ckpt']
|
| 25 |
-
}
|
| 26 |
-
|
| 27 |
-
WTS_DICT = {
|
| 28 |
-
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
if os.path.exists('checkpoints') == False:
|
| 32 |
-
os.mkdir('checkpoints')
|
| 33 |
-
for repo in urls:
|
| 34 |
-
files = urls[repo]
|
| 35 |
-
for file in files:
|
| 36 |
-
url = hf_hub_url(repo, file)
|
| 37 |
-
name_ckp = url.split('/')[-1]
|
| 38 |
-
WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file, token=os.environ.get("ACCESS_TOKEN"))
|
| 39 |
-
|
| 40 |
-
print(WTS_DICT)
|
| 41 |
-
apply_segmentor = OneformerSegmenter(WTS_DICT['shi-labs/oneformer_coco_swin_large'])
|
| 42 |
-
|
| 43 |
-
model = create_model('./configs/sap_fixed_hintnet_v15.yaml').cpu()
|
| 44 |
-
model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
|
| 45 |
-
model = model.cuda()
|
| 46 |
-
ddim_sampler = DDIMSamplerSpaCFG(model)
|
| 47 |
-
_COLORS = []
|
| 48 |
-
save_memory = False
|
| 49 |
-
|
| 50 |
-
def gen_color():
|
| 51 |
-
color = tuple(np.round(np.random.choice(range(256), size=3), 3))
|
| 52 |
-
if color not in _COLORS and np.mean(color) != 0.0:
|
| 53 |
-
_COLORS.append(color)
|
| 54 |
-
else:
|
| 55 |
-
gen_color()
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
for _ in range(300):
|
| 59 |
-
gen_color()
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
| 63 |
-
def __init__(self, edit_operation):
|
| 64 |
-
self.input_img = None
|
| 65 |
-
self.input_pmask = None
|
| 66 |
-
self.input_segmask = None
|
| 67 |
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
self.ref_segmask = None
|
| 71 |
-
|
| 72 |
-
self.H = None
|
| 73 |
-
self.W = None
|
| 74 |
-
self.baseoutput = None
|
| 75 |
-
self.kernel = np.ones((5, 5), np.uint8)
|
| 76 |
-
self.edit_operation = edit_operation
|
| 77 |
-
|
| 78 |
-
def init_input_canvas(self, img):
|
| 79 |
-
img = HWC3(img)
|
| 80 |
-
img = resize_image(img, 512)
|
| 81 |
-
detected_mask = apply_segmentor(img, 'panoptic')[0]
|
| 82 |
-
detected_seg = apply_segmentor(img, 'semantic')
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
self.input_segmask = detected_seg
|
| 87 |
-
self.H = img.shape[0]
|
| 88 |
-
self.W = img.shape[1]
|
| 89 |
-
|
| 90 |
-
detected_mask = detected_mask.cpu().numpy()
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
for i in uni:
|
| 95 |
-
color_mask[detected_mask == i] = _COLORS[i]
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
return self.baseoutput
|
| 100 |
-
|
| 101 |
-
def init_ref_canvas(self, img):
|
| 102 |
-
img = HWC3(img)
|
| 103 |
-
img = resize_image(img, 512)
|
| 104 |
-
detected_mask = apply_segmentor(img, 'panoptic')[0]
|
| 105 |
-
detected_seg = apply_segmentor(img, 'semantic')
|
| 106 |
-
|
| 107 |
-
self.ref_img = img
|
| 108 |
-
self.ref_pmask = detected_mask
|
| 109 |
-
self.ref_segmask = detected_seg
|
| 110 |
-
|
| 111 |
-
detected_mask = detected_mask.cpu().numpy()
|
| 112 |
-
|
| 113 |
-
uni = np.unique(detected_mask)
|
| 114 |
-
color_mask = np.zeros((detected_mask.shape[0], detected_mask.shape[1], 3))
|
| 115 |
-
for i in uni:
|
| 116 |
-
color_mask[detected_mask == i] = _COLORS[i]
|
| 117 |
-
|
| 118 |
-
output = color_mask*0.8 + img * 0.2
|
| 119 |
-
self.baseoutput = output.astype(np.uint8)
|
| 120 |
-
return self.baseoutput
|
| 121 |
-
|
| 122 |
-
def _process_mask(self, mask, panoptic_mask, segmask):
|
| 123 |
-
panoptic_mask_ = panoptic_mask + 1
|
| 124 |
-
mask_ = resize_image(mask['mask'][:, :, 0], min(panoptic_mask.shape))
|
| 125 |
-
mask_ = torch.tensor(mask_)
|
| 126 |
-
maski = torch.zeros_like(mask_).cuda()
|
| 127 |
-
maski[mask_ > 127] = 1
|
| 128 |
-
mask = maski * panoptic_mask_
|
| 129 |
-
unique_ids, counts = torch.unique(mask, return_counts=True)
|
| 130 |
-
mask_id = unique_ids[torch.argmax(counts[1:]) + 1]
|
| 131 |
-
final_mask = torch.zeros(mask.shape).cuda()
|
| 132 |
-
final_mask[panoptic_mask_ == mask_id] = 1
|
| 133 |
-
|
| 134 |
-
obj_class = maski * (segmask + 1)
|
| 135 |
-
unique_ids, counts = torch.unique(obj_class, return_counts=True)
|
| 136 |
-
obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
|
| 137 |
-
return final_mask, obj_class
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
def _edit_app(self, input_mask, ref_mask, whole_ref):
|
| 141 |
-
input_pmask = self.input_pmask
|
| 142 |
-
input_segmask = self.input_segmask
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
else:
|
| 147 |
-
reference_mask, _ = self._process_mask(ref_mask, self.ref_pmask, self.ref_segmask)
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
input_pmask[edit_mask == 1] = ma + 1
|
| 152 |
-
return reference_mask, input_pmask, input_segmask, edit_mask, ma
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
input_img = (self.input_img/127.5 - 1)
|
| 157 |
-
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
| 161 |
|
| 162 |
-
|
|
|
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
-
if mean_feat_ref.shape[1] > 1:
|
| 171 |
-
mean_feat_inpt[:, ma + 1] = (1 - inter) * mean_feat_inpt[:, ma + 1] + inter*mean_feat_ref[:, 1]
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
|
|
|
| 179 |
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
whole_ref=whole_ref, inter=inter)
|
| 187 |
-
|
| 188 |
-
null_structure = torch.zeros(structure.shape).cuda() - 1
|
| 189 |
-
null_appearance = torch.zeros(appearance.shape).cuda()
|
| 190 |
-
|
| 191 |
-
null_control = torch.cat([null_structure, null_appearance], dim=1)
|
| 192 |
-
structure_control = torch.cat([structure, null_appearance], dim=1)
|
| 193 |
-
full_control = torch.cat([structure, appearance], dim=1)
|
| 194 |
-
|
| 195 |
-
null_control = torch.cat([null_control for _ in range(num_samples)], dim=0)
|
| 196 |
-
structure_control = torch.cat([structure_control for _ in range(num_samples)], dim=0)
|
| 197 |
-
full_control = torch.cat([full_control for _ in range(num_samples)], dim=0)
|
| 198 |
-
|
| 199 |
-
#Masking for local edit
|
| 200 |
-
if not masking:
|
| 201 |
-
mask, x0 = None, None
|
| 202 |
-
else:
|
| 203 |
-
x0 = model.encode_first_stage(img)
|
| 204 |
-
x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0 # todo: check if we can set random number
|
| 205 |
-
x0 = x0 * model.scale_factor
|
| 206 |
-
mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
|
| 207 |
-
mask = torch.nn.functional.interpolate(mask, x0.shape[2:]).float()
|
| 208 |
-
|
| 209 |
-
if seed == -1:
|
| 210 |
-
seed = random.randint(0, 65535)
|
| 211 |
-
seed_everything(seed)
|
| 212 |
|
| 213 |
-
|
| 214 |
-
print(scale)
|
| 215 |
-
if save_memory:
|
| 216 |
-
model.low_vram_shift(is_diffusing=False)
|
| 217 |
-
# uc_cross = model.get_unconditional_conditioning(num_samples)
|
| 218 |
-
uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
|
| 219 |
-
cond = {"c_concat": [full_control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
|
| 220 |
-
un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
|
| 221 |
-
un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
|
| 222 |
-
un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
|
| 223 |
|
| 224 |
-
|
|
|
|
| 225 |
|
| 226 |
-
|
| 227 |
-
|
| 228 |
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
-
|
| 239 |
-
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 240 |
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
|
| 245 |
-
def
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
-
|
| 252 |
-
|
|
|
|
| 253 |
|
|
|
|
|
|
|
|
|
|
| 254 |
|
|
|
|
| 255 |
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
.
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
.
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
.
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
|
|
|
| 284 |
with gr.Row():
|
| 285 |
-
gr.Markdown("##
|
| 286 |
with gr.Row():
|
| 287 |
gr.HTML(
|
| 288 |
"""
|
| 289 |
-
<div
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
<div class="image">
|
| 301 |
-
<img src="file/assets/GIF.gif" width="400"">
|
| 302 |
-
</div>
|
| 303 |
-
</div>
|
| 304 |
-
""")
|
| 305 |
with gr.Column():
|
| 306 |
with gr.Row():
|
| 307 |
img_edit = gr.State(ImageComp('edit_app'))
|
| 308 |
with gr.Column():
|
| 309 |
-
btn1 = gr.Button("Input Image")
|
| 310 |
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
| 311 |
with gr.Column():
|
| 312 |
-
|
| 313 |
-
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
|
| 314 |
-
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask], queue=False)
|
| 315 |
-
|
| 316 |
-
# with gr.Row():
|
| 317 |
-
with gr.Column():
|
| 318 |
-
btn3 = gr.Button("Reference Image")
|
| 319 |
-
ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
|
| 320 |
-
with gr.Column():
|
| 321 |
-
btn4 = gr.Button("Select Reference Object")
|
| 322 |
-
reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy", tool="sketch")
|
| 323 |
|
| 324 |
-
ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[reference_mask], queue=False)
|
| 325 |
-
|
| 326 |
with gr.Row():
|
| 327 |
-
prompt = gr.Textbox(label="Prompt", value='
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
with gr.Row():
|
| 332 |
run_button = gr.Button(label="Run")
|
|
|
|
| 333 |
|
| 334 |
with gr.Row():
|
| 335 |
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
| 336 |
|
| 337 |
with gr.Accordion("Advanced options", open=False):
|
| 338 |
-
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=
|
|
|
|
| 339 |
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 340 |
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 341 |
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
|
|
|
| 345 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 346 |
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 347 |
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
| 348 |
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 349 |
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 350 |
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
| 352 |
with gr.Column():
|
| 353 |
gr.Examples(
|
| 354 |
-
examples=[['
|
| 355 |
-
['
|
| 356 |
-
['
|
| 357 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
outputs=None,
|
| 359 |
fn=None,
|
| 360 |
cache_examples=False,
|
| 361 |
)
|
| 362 |
-
ips = [input_mask,
|
| 363 |
-
scale_s, scale_f, scale_t, seed, eta,
|
|
|
|
|
|
|
| 364 |
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
|
|
|
| 365 |
|
| 366 |
|
| 367 |
-
|
| 368 |
-
def create_struct_demo():
|
| 369 |
with gr.Row():
|
| 370 |
-
gr.Markdown("##
|
| 371 |
-
|
| 372 |
-
def create_both_demo():
|
| 373 |
with gr.Row():
|
| 374 |
-
gr.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
|
| 377 |
|
| 378 |
-
block = gr.Blocks(css=css).queue()
|
| 379 |
with block:
|
| 380 |
gr.HTML(
|
| 381 |
"""
|
| 382 |
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
| 383 |
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
|
| 384 |
-
PAIR Diffusion
|
| 385 |
</h1>
|
| 386 |
-
<
|
| 387 |
-
<a href="https://vidit98.github.io/" style="color:blue;">Vidit Goel</a><sup>1*</sup>,
|
| 388 |
-
<a href="https://helia95.github.io/" style="color:blue;">Elia Peruzzo</a><sup>1,2*</sup>,
|
| 389 |
-
<a href="https://yifanjiang19.github.io/" style="color:blue;">Yifan Jiang</a><sup>3</sup>,
|
| 390 |
-
<a href="https://ir1d.github.io/" style="color:blue;">Dejia Xu</a><sup>3</sup>,
|
| 391 |
-
<a href="http://disi.unitn.it/~sebe/" style="color:blue;">Nicu Sebe</a><sup>2</sup>, <br>
|
| 392 |
-
<a href=" https://people.eecs.berkeley.edu/~trevor/" style="color:blue;">Trevor Darrell</a><sup>4</sup>,
|
| 393 |
-
<a href="https://vita-group.github.io/" style="color:blue;">Zhangyang Wang</a><sup>1,3</sup>
|
| 394 |
-
and <a href="https://www.humphreyshi.com/home" style="color:blue;">Humphrey Shi</a> <sup>1,5,6</sup> <br>
|
| 395 |
-
[<a href="https://arxiv.org/abs/2303.17546" style="color:red;">arXiv</a>]
|
| 396 |
-
[<a href="https://github.com/Picsart-AI-Research/PAIR-Diffusion" style="color:red;">GitHub</a>]
|
| 397 |
-
</h2>
|
| 398 |
-
<h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
|
| 399 |
-
<sup>1</sup>Picsart AI Resarch (PAIR), <sup>2</sup>UTrenton, <sup>3</sup>UT Austin, <sup>4</sup>UC Berkeley, <sup>5</sup>UOregon, <sup>6</sup>UIUC
|
| 400 |
-
</h3>
|
| 401 |
<h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
where we need to have consistent appearance across time in case of video or across various viewing positions in case of 3D.
|
| 408 |
</h2>
|
| 409 |
-
|
| 410 |
</div>
|
| 411 |
""")
|
| 412 |
|
| 413 |
-
gr.HTML("""
|
| 414 |
-
<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
|
| 415 |
-
<br/>
|
| 416 |
-
<a href="https://huggingface.co/spaces/PAIR/PAIR-Diffusion?duplicate=true">
|
| 417 |
-
<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
| 418 |
-
</p>""")
|
| 419 |
-
|
| 420 |
with gr.Tab('Edit Appearance'):
|
| 421 |
create_app_demo()
|
| 422 |
-
with gr.Tab('
|
| 423 |
-
|
| 424 |
-
with gr.Tab('
|
| 425 |
-
|
| 426 |
-
|
|
|
|
| 427 |
|
| 428 |
block.queue(max_size=20)
|
| 429 |
-
block.launch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from pair_diff_demo import ImageComp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
# torch.cuda.set_per_process_memory_fraction(0.6)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
+
def init_input_canvas_wrapper(obj, *args):
|
| 7 |
+
return obj.init_input_canvas(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
def init_ref_canvas_wrapper(obj, *args):
|
| 10 |
+
return obj.init_ref_canvas(*args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
def select_input_object_wrapper(obj, evt: gr.SelectData):
|
| 13 |
+
return obj.select_input_object(evt)
|
|
|
|
|
|
|
| 14 |
|
| 15 |
+
def select_ref_object_wrapper(obj, evt: gr.SelectData):
|
| 16 |
+
return obj.select_ref_object(evt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
+
def process_wrapper(obj, *args):
|
| 19 |
+
return obj.process(*args)
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
def set_multi_modal_wrapper(obj, *args):
|
| 22 |
+
return obj.set_multi_modal(*args)
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
def save_result_wrapper(obj, *args):
|
| 25 |
+
return obj.save_result(*args)
|
|
|
|
|
|
|
| 26 |
|
| 27 |
+
def return_input_img_wrapper(obj):
|
| 28 |
+
return obj.return_input_img()
|
| 29 |
|
| 30 |
+
def get_caption_wrapper(obj, *args):
|
| 31 |
+
return obj.get_caption(*args)
|
| 32 |
|
| 33 |
+
def multimodal_params(b):
|
| 34 |
+
if b:
|
| 35 |
+
return 10, 3, 6
|
| 36 |
+
else:
|
| 37 |
+
return 6, 8, 9
|
| 38 |
|
| 39 |
+
theme = gr.themes.Soft(
|
| 40 |
+
primary_hue="purple",
|
| 41 |
+
font_mono=[gr.themes.GoogleFont("IBM Plex Mono"), "ui-monospace", "Consolas", 'monospace'],
|
| 42 |
+
).set(
|
| 43 |
+
block_label_background_fill_dark='*neutral_800'
|
| 44 |
+
)
|
| 45 |
|
|
|
|
|
|
|
| 46 |
|
| 47 |
+
css = """
|
| 48 |
+
#customized_imbox {
|
| 49 |
+
min-height: 450px;
|
| 50 |
+
}
|
| 51 |
+
#customized_imbox>div[data-testid="image"] {
|
| 52 |
+
min-height: 450px;
|
| 53 |
+
}
|
| 54 |
+
#customized_imbox>div[data-testid="image"]>div {
|
| 55 |
+
min-height: 450px;
|
| 56 |
+
}
|
| 57 |
+
#customized_imbox>div[data-testid="image"]>iframe {
|
| 58 |
+
min-height: 450px;
|
| 59 |
+
}
|
| 60 |
+
#customized_imbox>div.unpadded_box {
|
| 61 |
+
min-height: 450px;
|
| 62 |
+
}
|
| 63 |
+
#myinst {
|
| 64 |
+
font-size: 0.8rem;
|
| 65 |
+
margin: 0rem;
|
| 66 |
+
color: #6B7280;
|
| 67 |
+
}
|
| 68 |
+
#maskinst {
|
| 69 |
+
text-align: justify;
|
| 70 |
+
min-width: 1200px;
|
| 71 |
+
}
|
| 72 |
+
#maskinst>img {
|
| 73 |
+
min-width:399px;
|
| 74 |
+
max-width:450px;
|
| 75 |
+
vertical-align: top;
|
| 76 |
+
display: inline-block;
|
| 77 |
+
}
|
| 78 |
+
#maskinst:after {
|
| 79 |
+
content: "";
|
| 80 |
+
width: 100%;
|
| 81 |
+
display: inline-block;
|
| 82 |
+
}
|
| 83 |
+
"""
|
| 84 |
|
| 85 |
+
def create_app_demo():
|
| 86 |
|
| 87 |
+
with gr.Row():
|
| 88 |
+
gr.Markdown("## Object Level Appearance Editing")
|
| 89 |
+
with gr.Row():
|
| 90 |
+
gr.HTML(
|
| 91 |
+
"""
|
| 92 |
+
<div style="text-align: left; max-width: 1200px;">
|
| 93 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
| 94 |
+
Instructions </h3>
|
| 95 |
+
<ol>
|
| 96 |
+
<li>Upload an Input Image.</li>
|
| 97 |
+
<li>Mark one of segmented objects in the <i>Select Object to Edit</i> tab.</li>
|
| 98 |
+
<li>Upload an Reference Image.</li>
|
| 99 |
+
<li>Mark one of segmented objects in the <i>Select Reference Object</i> tab, whose appearance needs to used in the selected input object.</li>
|
| 100 |
+
<li>Enter a prompt and press <i>Run</i> button. (A very simple would also work) </li>
|
| 101 |
+
</ol>
|
| 102 |
+
</ol>
|
| 103 |
+
</div>""")
|
| 104 |
+
with gr.Column():
|
| 105 |
+
with gr.Row():
|
| 106 |
+
img_edit = gr.State(ImageComp('edit_app'))
|
| 107 |
+
with gr.Column():
|
| 108 |
+
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
| 109 |
+
with gr.Column():
|
| 110 |
+
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy",)
|
| 111 |
+
|
| 112 |
+
with gr.Column():
|
| 113 |
+
ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
|
| 114 |
+
with gr.Column():
|
| 115 |
+
reference_mask = gr.Image(source="upload", label='Select Object in Refernce Image', type="numpy")
|
| 116 |
|
| 117 |
+
with gr.Row():
|
| 118 |
+
with gr.Column():
|
| 119 |
+
prompt = gr.Textbox(label="Prompt", value='A picture of truck')
|
| 120 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image], show_progress=True)
|
| 125 |
+
input_image.select(fn=select_input_object_wrapper, inputs=[img_edit], outputs=[input_mask, prompt])
|
| 126 |
|
| 127 |
+
ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[ref_img], show_progress=True)
|
| 128 |
+
ref_img.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[reference_mask])
|
| 129 |
|
| 130 |
+
with gr.Column():
|
| 131 |
+
interpolation = gr.Slider(label="Mixing ratio of appearance from reference object", minimum=0.1, maximum=1, value=1.0, step=0.1)
|
| 132 |
+
whole_ref = gr.Checkbox(label='Use whole reference Image for appearance (Only useful for style transfers)', visible=False)
|
| 133 |
+
|
| 134 |
+
# clear_button.click(fn=img_edit.clear_points, inputs=[], outputs=[input_mask, reference_mask])
|
| 135 |
|
| 136 |
+
with gr.Row():
|
| 137 |
+
run_button = gr.Button(label="Run")
|
| 138 |
+
save_button = gr.Button("Save")
|
| 139 |
+
|
| 140 |
+
with gr.Row():
|
| 141 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
| 142 |
+
|
| 143 |
+
with gr.Accordion("Advanced options", open=False):
|
| 144 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
|
| 145 |
+
image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
| 146 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 147 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 148 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 149 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0., maximum=30.0, value=6.0, step=0.1)
|
| 150 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0., maximum=30.0, value=8.0, step=0.1)
|
| 151 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0., maximum=30.0, value=9.0, step=0.1)
|
| 152 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 153 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 154 |
+
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
| 155 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 156 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 157 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 158 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=0, value=0, step=0)
|
| 159 |
+
|
| 160 |
+
with gr.Column():
|
| 161 |
+
gr.Examples(
|
| 162 |
+
examples=[['assets/examples/car.jpeg','assets/examples/ian.jpeg', '', 709736989, 6, 8, 9],
|
| 163 |
+
['assets/examples/ian.jpeg','assets/examples/car.jpeg', '', 709736989, 6, 8, 9],
|
| 164 |
+
['assets/examples/car.jpeg','assets/examples/ran.webp', '', 709736989, 6, 8, 9],
|
| 165 |
+
['assets/examples/car.jpeg','assets/examples/car1.webp', '', 709736989, 6, 8, 9],
|
| 166 |
+
['assets/examples/car1.webp','assets/examples/car.jpeg', '', 709736989, 6, 8, 9],
|
| 167 |
+
['assets/examples/chair.jpeg','assets/examples/chair1.jpeg', '', 1106204668, 6, 8, 9],
|
| 168 |
+
['assets/examples/house.jpeg','assets/examples/house2.jpeg', '', 1106204668, 6, 8, 9],
|
| 169 |
+
['assets/examples/house2.jpeg','assets/examples/house.jpeg', '', 1106204668, 6, 8, 9],
|
| 170 |
+
['assets/examples/park.webp','assets/examples/grasslands-national-park.jpeg', '', 1106204668, 6, 8, 9],
|
| 171 |
+
['assets/examples/door.jpeg','assets/examples/door2.jpeg', '', 709736989, 6, 8, 9]],
|
| 172 |
+
inputs=[input_image, ref_img, prompt, seed, scale_t, scale_f, scale_s],
|
| 173 |
+
cache_examples=False,
|
| 174 |
+
)
|
| 175 |
|
| 176 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
|
|
|
| 177 |
|
| 178 |
+
ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
| 179 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking, whole_ref, interpolation]
|
| 180 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
| 181 |
+
scale_s, scale_f, scale_t, seed, dil, interpolation]
|
| 182 |
+
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
| 183 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
| 184 |
|
| 185 |
|
| 186 |
+
def create_add_obj_demo():
|
| 187 |
+
with gr.Row():
|
| 188 |
+
gr.Markdown("## Add Objects to Image")
|
| 189 |
+
with gr.Row():
|
| 190 |
+
gr.HTML(
|
| 191 |
+
"""
|
| 192 |
+
<div style="text-align: left; max-width: 1200px;">
|
| 193 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
| 194 |
+
Instructions </h3>
|
| 195 |
+
<ol>
|
| 196 |
+
<li> Upload an Input Image.</li>
|
| 197 |
+
<li>Draw the precise shape of object in the image where you want to add object in <i>Draw Object</i> tab.</li>
|
| 198 |
+
<li>Upload an Reference Image.</li>
|
| 199 |
+
<li>Click on the object in the Reference Image tab that you want to add in the Input Image.</li>
|
| 200 |
+
<li>Enter a prompt and press <i>Run</i> button. (A very simple would also work) </li>
|
| 201 |
+
</ol>
|
| 202 |
+
</ol>
|
| 203 |
+
</div>""")
|
| 204 |
+
with gr.Column():
|
| 205 |
+
with gr.Row():
|
| 206 |
+
img_edit = gr.State(ImageComp('add_obj'))
|
| 207 |
+
with gr.Column():
|
| 208 |
+
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
| 209 |
+
with gr.Column():
|
| 210 |
+
input_mask = gr.Image(source="upload", label='Draw the desired Object', type="numpy", tool="sketch")
|
| 211 |
|
| 212 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image])
|
| 213 |
+
input_image.change(fn=return_input_img_wrapper, inputs=[img_edit], outputs=[input_mask], queue=False)
|
| 214 |
+
|
| 215 |
+
with gr.Column():
|
| 216 |
+
ref_img = gr.Image(source='upload', label='Reference Image', type="numpy")
|
| 217 |
+
with gr.Column():
|
| 218 |
+
reference_mask = gr.Image(source="upload", label='Selected Object in Refernce Image', type="numpy")
|
| 219 |
|
| 220 |
+
ref_img.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, ref_img], outputs=[ref_img], queue=False)
|
| 221 |
+
# ref_img.upload(fn=img_edit.init_ref_canvas, inputs=[ref_img], outputs=[ref_img])
|
| 222 |
+
ref_img.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[reference_mask])
|
| 223 |
|
| 224 |
+
with gr.Row():
|
| 225 |
+
prompt = gr.Textbox(label="Prompt", value='A picture of truck')
|
| 226 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False, visible=False)
|
| 227 |
|
| 228 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
| 229 |
|
| 230 |
+
with gr.Row():
|
| 231 |
+
run_button = gr.Button(label="Run")
|
| 232 |
+
save_button = gr.Button("Save")
|
| 233 |
+
|
| 234 |
+
with gr.Row():
|
| 235 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
| 236 |
+
|
| 237 |
+
with gr.Accordion("Advanced options", open=False):
|
| 238 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
|
| 239 |
+
# image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
| 240 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 241 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 242 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 243 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
|
| 244 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0., maximum=30.0, value=6.0, step=0.1)
|
| 245 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0., maximum=30.0, value=8.0, step=0.1)
|
| 246 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0., maximum=30.0, value=9.0, step=0.1)
|
| 247 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 248 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 249 |
+
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
| 250 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 251 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 252 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 253 |
+
|
| 254 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
| 255 |
|
| 256 |
+
with gr.Column():
|
| 257 |
+
gr.Examples(
|
| 258 |
+
examples=[['assets/examples/chair.jpeg','assets/examples/carpet2.webp', 'A picture of living room with carpet', 892905419, 6, 8, 9],
|
| 259 |
+
['assets/examples/chair.jpeg','assets/examples/chair1.jpeg', 'A picture of living room with a orange and white sofa', 892905419, 6, 8, 9],
|
| 260 |
+
['assets/examples/park.webp','assets/examples/dog.jpeg', 'A picture of dog in the park', 892905419, 6, 8, 9]],
|
| 261 |
+
inputs=[input_image, ref_img, prompt, seed, scale_t, scale_f, scale_s],
|
| 262 |
+
outputs=None,
|
| 263 |
+
fn=None,
|
| 264 |
+
cache_examples=False,
|
| 265 |
+
)
|
| 266 |
+
ips = [input_mask, reference_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
| 267 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking]
|
| 268 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
| 269 |
+
scale_s, scale_f, scale_t, seed, dil]
|
| 270 |
+
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
| 271 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
| 272 |
|
| 273 |
+
def create_obj_variation_demo():
|
| 274 |
with gr.Row():
|
| 275 |
+
gr.Markdown("## Objects Variation")
|
| 276 |
with gr.Row():
|
| 277 |
gr.HTML(
|
| 278 |
"""
|
| 279 |
+
<div style="text-align: left; max-width: 1200px;">
|
| 280 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
| 281 |
+
Instructions </h3>
|
| 282 |
+
<ol>
|
| 283 |
+
<li> Upload an Input Image.</li>
|
| 284 |
+
<li>Click on object to have variations</li>
|
| 285 |
+
<li>Press <i>Run</i> button</li>
|
| 286 |
+
</ol>
|
| 287 |
+
</ol>
|
| 288 |
+
</div>""")
|
| 289 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
with gr.Column():
|
| 291 |
with gr.Row():
|
| 292 |
img_edit = gr.State(ImageComp('edit_app'))
|
| 293 |
with gr.Column():
|
|
|
|
| 294 |
input_image = gr.Image(source='upload', label='Input Image', type="numpy",)
|
| 295 |
with gr.Column():
|
| 296 |
+
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy",)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
|
|
|
|
|
|
| 298 |
with gr.Row():
|
| 299 |
+
prompt = gr.Textbox(label="Prompt", value='')
|
| 300 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
| 304 |
+
|
| 305 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_image])
|
| 306 |
+
input_image.select(fn=select_input_object_wrapper, inputs=[img_edit], outputs=[input_mask, prompt])
|
| 307 |
+
input_image.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, input_image], outputs=[], queue=False)
|
| 308 |
+
input_image.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[])
|
| 309 |
+
|
| 310 |
with gr.Row():
|
| 311 |
run_button = gr.Button(label="Run")
|
| 312 |
+
save_button = gr.Button("Save")
|
| 313 |
|
| 314 |
with gr.Row():
|
| 315 |
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
| 316 |
|
| 317 |
with gr.Accordion("Advanced options", open=False):
|
| 318 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=2)
|
| 319 |
+
# image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
| 320 |
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 321 |
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 322 |
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 323 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
|
| 324 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0.0, maximum=30.0, value=6.0, step=0.1)
|
| 325 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0.0, maximum=30.0, value=8.0, step=0.1)
|
| 326 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0.0, maximum=30.0, value=9.0, step=0.1)
|
| 327 |
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 328 |
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 329 |
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
| 330 |
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 331 |
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 332 |
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
| 336 |
+
|
| 337 |
with gr.Column():
|
| 338 |
gr.Examples(
|
| 339 |
+
examples=[['assets/examples/chair.jpeg' , 892905419, 6, 8, 9],
|
| 340 |
+
['assets/examples/chair1.jpeg', 892905419, 6, 8, 9],
|
| 341 |
+
['assets/examples/park.webp', 892905419, 6, 8, 9],
|
| 342 |
+
['assets/examples/car.jpeg', 709736989, 6, 8, 9],
|
| 343 |
+
['assets/examples/ian.jpeg', 709736989, 6, 8, 9],
|
| 344 |
+
['assets/examples/chair.jpeg', 1106204668, 6, 8, 9],
|
| 345 |
+
['assets/examples/door.jpeg', 709736989, 6, 8, 9],
|
| 346 |
+
['assets/examples/carpet2.webp', 892905419, 6, 8, 9],
|
| 347 |
+
['assets/examples/house.jpeg', 709736989, 6, 8, 9],
|
| 348 |
+
['assets/examples/house2.jpeg', 709736989, 6, 8, 9],],
|
| 349 |
+
inputs=[input_image, seed, scale_t, scale_f, scale_s],
|
| 350 |
outputs=None,
|
| 351 |
fn=None,
|
| 352 |
cache_examples=False,
|
| 353 |
)
|
| 354 |
+
ips = [input_mask, input_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
| 355 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking]
|
| 356 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
| 357 |
+
scale_s, scale_f, scale_t, seed, dil]
|
| 358 |
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
| 359 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
| 360 |
|
| 361 |
|
| 362 |
+
def create_free_form_obj_variation_demo():
|
|
|
|
| 363 |
with gr.Row():
|
| 364 |
+
gr.Markdown("## Objects Variation")
|
|
|
|
|
|
|
| 365 |
with gr.Row():
|
| 366 |
+
gr.HTML(
|
| 367 |
+
"""
|
| 368 |
+
<div style="text-align: left; max-width: 1200px;">
|
| 369 |
+
<h3 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
| 370 |
+
Instructions </h3>
|
| 371 |
+
<ol>
|
| 372 |
+
<li> Upload an Input Image.</li>
|
| 373 |
+
<li>Mask the region that you want to have variation</li>
|
| 374 |
+
<li>Press <i>Run</i> button</li>
|
| 375 |
+
</ol>
|
| 376 |
+
</ol>
|
| 377 |
+
</div>""")
|
| 378 |
|
| 379 |
+
with gr.Column():
|
| 380 |
+
with gr.Row():
|
| 381 |
+
img_edit = gr.State(ImageComp('edit_app'))
|
| 382 |
+
with gr.Column():
|
| 383 |
+
input_image = gr.Image(source='upload', label='Input Image', type="numpy", )
|
| 384 |
+
with gr.Column():
|
| 385 |
+
input_mask = gr.Image(source="upload", label='Select Object in Input Image', type="numpy", tool="sketch")
|
| 386 |
+
|
| 387 |
+
with gr.Row():
|
| 388 |
+
prompt = gr.Textbox(label="Prompt", value='')
|
| 389 |
+
ignore_structure = gr.Checkbox(label='Ignore Structure (Please provide a good caption)', visible=False)
|
| 390 |
+
mulitmod = gr.Checkbox(label='Multi-Modal', value=False)
|
| 391 |
+
|
| 392 |
+
mulitmod.change(fn=set_multi_modal_wrapper, inputs=[img_edit, mulitmod])
|
| 393 |
+
|
| 394 |
+
input_image.change(fn=init_input_canvas_wrapper, inputs=[img_edit, input_image], outputs=[input_mask])
|
| 395 |
+
input_mask.edit(fn=get_caption_wrapper, inputs=[img_edit, input_mask], outputs=[prompt])
|
| 396 |
+
input_image.change(fn=init_ref_canvas_wrapper, inputs=[img_edit, input_image], outputs=[], queue=False)
|
| 397 |
+
# input_image.select(fn=select_ref_object_wrapper, inputs=[img_edit], outputs=[])
|
| 398 |
+
|
| 399 |
+
# input_image.edit(fn=img_edit.vis_mask, inputs=[input_image], outputs=[input_mask])
|
| 400 |
+
|
| 401 |
+
with gr.Row():
|
| 402 |
+
run_button = gr.Button(label="Run")
|
| 403 |
+
save_button = gr.Button("Save")
|
| 404 |
+
|
| 405 |
+
with gr.Row():
|
| 406 |
+
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=4, height='auto')
|
| 407 |
+
|
| 408 |
+
with gr.Accordion("Advanced options", open=False):
|
| 409 |
+
num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=2)
|
| 410 |
+
# image_resolution = gr.Slider(label="Image Resolution", minimum=512, maximum=512, value=512, step=64)
|
| 411 |
+
strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 412 |
+
guess_mode = gr.Checkbox(label='Guess Mode', value=False)
|
| 413 |
+
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
|
| 414 |
+
dil = gr.Slider(label="Merging region around Edge", minimum=0, maximum=5, value=2, step=1)
|
| 415 |
+
scale_t = gr.Slider(label="Guidance Scale Text", minimum=0.0, maximum=30.0, value=6.0, step=0.1)
|
| 416 |
+
scale_f = gr.Slider(label="Guidance Scale Appearance", minimum=0.0, maximum=30.0, value=8.0, step=0.1)
|
| 417 |
+
scale_s = gr.Slider(label="Guidance Scale Structure", minimum=0.0, maximum=30.0, value=9.0, step=0.1)
|
| 418 |
+
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
|
| 419 |
+
eta = gr.Number(label="eta (DDIM)", value=0.0)
|
| 420 |
+
masking = gr.Checkbox(label='Only edit the local region', value=True)
|
| 421 |
+
free_form_obj_var = gr.Checkbox(label='', value=True)
|
| 422 |
+
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
|
| 423 |
+
n_prompt = gr.Textbox(label="Negative Prompt",
|
| 424 |
+
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')
|
| 425 |
+
interpolation = gr.Slider(label="Mixing ratio of appearance from reference object", minimum=0.0, maximum=0.1, step=0.1)
|
| 426 |
+
|
| 427 |
+
mulitmod.change(fn=multimodal_params, inputs=[mulitmod], outputs=[scale_t, scale_f, scale_s])
|
| 428 |
+
|
| 429 |
+
with gr.Column():
|
| 430 |
+
gr.Examples(
|
| 431 |
+
examples=[['assets/examples/chair.jpeg' , 892905419, 6, 8, 9],
|
| 432 |
+
['assets/examples/chair1.jpeg', 892905419, 6, 8, 9],
|
| 433 |
+
['assets/examples/park.webp', 892905419, 6, 8, 9],
|
| 434 |
+
['assets/examples/car.jpeg', 709736989, 6, 8, 9],
|
| 435 |
+
['assets/examples/ian.jpeg', 709736989, 6, 8, 9],
|
| 436 |
+
['assets/examples/chair.jpeg', 1106204668, 6, 8, 9],
|
| 437 |
+
['assets/examples/door.jpeg', 709736989, 6, 8, 9],
|
| 438 |
+
['assets/examples/carpet2.webp', 892905419, 6, 8, 9],
|
| 439 |
+
['assets/examples/house.jpeg', 709736989, 6, 8, 9],
|
| 440 |
+
['assets/examples/house2.jpeg', 709736989, 6, 8, 9],],
|
| 441 |
+
inputs=[input_image, seed, scale_t, scale_f, scale_s],
|
| 442 |
+
outputs=None,
|
| 443 |
+
fn=None,
|
| 444 |
+
cache_examples=False,
|
| 445 |
+
)
|
| 446 |
+
ips = [input_mask, input_mask, prompt, a_prompt, n_prompt, num_samples, ddim_steps, guess_mode, strength,
|
| 447 |
+
scale_s, scale_f, scale_t, seed, eta, dil, masking, free_form_obj_var, dil, free_form_obj_var, ignore_structure]
|
| 448 |
+
ips_save = [input_mask, prompt, a_prompt, n_prompt, ddim_steps,
|
| 449 |
+
scale_s, scale_f, scale_t, seed, dil, interpolation, free_form_obj_var]
|
| 450 |
+
run_button.click(fn=process_wrapper, inputs=[img_edit, *ips], outputs=[result_gallery])
|
| 451 |
+
save_button.click(fn=save_result_wrapper, inputs=[img_edit, *ips_save])
|
| 452 |
|
| 453 |
|
| 454 |
+
block = gr.Blocks(css=css, theme=theme).queue()
|
| 455 |
with block:
|
| 456 |
gr.HTML(
|
| 457 |
"""
|
| 458 |
<div style="text-align: center; max-width: 1200px; margin: 20px auto;">
|
| 459 |
<h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
|
| 460 |
+
PAIR Diffusion: A Comprehensive Multimodal Object-Level Image Editor
|
| 461 |
</h1>
|
| 462 |
+
<h3 style="margin-top: 0.6rem; margin-bottom: 1rem">Picsart AI Research</h3>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
<h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.8rem; margin-bottom: 0.8rem">
|
| 464 |
+
PAIR diffusion provides comprehensive multi-modal editing capabilities to edit real images without the need of inverting them. The current suite contains
|
| 465 |
+
<span style="color: #01feee;">Object Variation</span>, <span style="color: #4f82d9;">Edit Appearance of any object using a reference image and text</span>,
|
| 466 |
+
<span style="color: #d402bf;">Add any object from a reference image in the input image</span>. This operations can be mixed with each other to
|
| 467 |
+
develop new editing operations in future.
|
| 468 |
+
</ul>
|
|
|
|
| 469 |
</h2>
|
|
|
|
| 470 |
</div>
|
| 471 |
""")
|
| 472 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
with gr.Tab('Edit Appearance'):
|
| 474 |
create_app_demo()
|
| 475 |
+
with gr.Tab('Object Variation Free Form Mask'):
|
| 476 |
+
create_free_form_obj_variation_demo()
|
| 477 |
+
with gr.Tab('Object Variation'):
|
| 478 |
+
create_obj_variation_demo()
|
| 479 |
+
with gr.Tab('Add Objects'):
|
| 480 |
+
create_add_obj_demo()
|
| 481 |
|
| 482 |
block.queue(max_size=20)
|
| 483 |
+
block.launch(share=True)
|
| 484 |
+
|
assets/GIF.gif
CHANGED
|
|
Git LFS Details
|
assets/Teaser_Small.png
ADDED
|
Git LFS Details
|
assets/examples/Lancia.webp
ADDED
|
Git LFS Details
|
assets/examples/car.jpeg
ADDED
|
Git LFS Details
|
assets/examples/car1.webp
ADDED
|
Git LFS Details
|
assets/examples/carpet2.webp
ADDED
|
Git LFS Details
|
assets/examples/chair.jpeg
ADDED
|
Git LFS Details
|
assets/examples/chair1.jpeg
ADDED
|
Git LFS Details
|
assets/examples/dog.jpeg
ADDED
|
Git LFS Details
|
assets/examples/door.jpeg
ADDED
|
Git LFS Details
|
assets/examples/door2.jpeg
ADDED
|
Git LFS Details
|
assets/examples/grasslands-national-park.jpeg
ADDED
|
Git LFS Details
|
assets/examples/house.jpeg
ADDED
|
Git LFS Details
|
assets/examples/house2.jpeg
ADDED
|
Git LFS Details
|
assets/examples/ian.jpeg
ADDED
|
Git LFS Details
|
assets/examples/park.webp
ADDED
|
Git LFS Details
|
assets/examples/ran.webp
ADDED
|
Git LFS Details
|
assets/hulk.jpeg
CHANGED
|
|
Git LFS Details
|
assets/ironman.webp
CHANGED
|
|
Git LFS Details
|
assets/lava.jpg
CHANGED
|
|
Git LFS Details
|
assets/ski.jpg
CHANGED
|
|
Git LFS Details
|
assets/truck.png
CHANGED
|
|
Git LFS Details
|
assets/truck2.jpeg
CHANGED
|
|
Git LFS Details
|
cldm/appearance_networks.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Neighborhood Attention Transformer.
|
| 3 |
+
https://arxiv.org/abs/2204.07143
|
| 4 |
+
|
| 5 |
+
This source code is licensed under the license found in the
|
| 6 |
+
LICENSE file in the root directory of this source tree.
|
| 7 |
+
"""
|
| 8 |
+
import torch
|
| 9 |
+
import torchvision
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 12 |
+
from timm.models.registry import register_model
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
IMAGENET_MEAN = [0.485, 0.456, 0.406]
|
| 16 |
+
IMAGENET_STD = [0.229, 0.224, 0.225]
|
| 17 |
+
|
| 18 |
+
class VGGPerceptualLoss(torch.nn.Module):
|
| 19 |
+
def __init__(self, resize=True):
|
| 20 |
+
super(VGGPerceptualLoss, self).__init__()
|
| 21 |
+
blocks = []
|
| 22 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
|
| 23 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
|
| 24 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
|
| 25 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
|
| 26 |
+
for bl in blocks:
|
| 27 |
+
for p in bl.parameters():
|
| 28 |
+
p.requires_grad = False
|
| 29 |
+
self.blocks = torch.nn.ModuleList(blocks)
|
| 30 |
+
self.transform = torch.nn.functional.interpolate
|
| 31 |
+
self.resize = resize
|
| 32 |
+
self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
| 33 |
+
self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
| 34 |
+
|
| 35 |
+
def forward(self, input, appearance_layers=[0,1,2,3]):
|
| 36 |
+
if input.shape[1] != 3:
|
| 37 |
+
input = input.repeat(1, 3, 1, 1)
|
| 38 |
+
target = target.repeat(1, 3, 1, 1)
|
| 39 |
+
input = (input-self.mean) / self.std
|
| 40 |
+
if self.resize:
|
| 41 |
+
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
|
| 42 |
+
x = input
|
| 43 |
+
feats = []
|
| 44 |
+
for i, block in enumerate(self.blocks):
|
| 45 |
+
x = block(x)
|
| 46 |
+
if i in appearance_layers:
|
| 47 |
+
feats.append(x)
|
| 48 |
+
|
| 49 |
+
return feats
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class DINOv2(torch.nn.Module):
|
| 53 |
+
def __init__(self, resize=True, size=224, model_type='dinov2_vitl14'):
|
| 54 |
+
super(DINOv2, self).__init__()
|
| 55 |
+
self.size=size
|
| 56 |
+
self.resize = resize
|
| 57 |
+
self.transform = torch.nn.functional.interpolate
|
| 58 |
+
self.model = torch.hub.load('facebookresearch/dinov2', model_type)
|
| 59 |
+
self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
| 60 |
+
self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
| 61 |
+
|
| 62 |
+
def forward(self, input, appearance_layers=[1,2]):
|
| 63 |
+
if input.shape[1] != 3:
|
| 64 |
+
input = input.repeat(1, 3, 1, 1)
|
| 65 |
+
target = target.repeat(1, 3, 1, 1)
|
| 66 |
+
|
| 67 |
+
if self.resize:
|
| 68 |
+
input = self.transform(input, mode='bicubic', size=(self.size, self.size), align_corners=False)
|
| 69 |
+
# mean = torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1).to(input.device)
|
| 70 |
+
# std = torch.tensor(IMAGENET_STD).view(1, 3, 1, 1).to(input.device)
|
| 71 |
+
input = (input-self.mean) / self.std
|
| 72 |
+
feats = self.model.get_intermediate_layers(input, self.model.n_blocks, reshape=True)
|
| 73 |
+
feats = [f.detach() for f in feats]
|
| 74 |
+
|
| 75 |
+
return feats
|
cldm/cldm.py
CHANGED
|
@@ -10,7 +10,6 @@ from ldm.modules.diffusionmodules.util import (
|
|
| 10 |
zero_module,
|
| 11 |
timestep_embedding,
|
| 12 |
)
|
| 13 |
-
import torchvision
|
| 14 |
from einops import rearrange, repeat
|
| 15 |
from torchvision.utils import make_grid
|
| 16 |
from ldm.modules.attention import SpatialTransformer
|
|
@@ -18,46 +17,9 @@ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSeq
|
|
| 18 |
from ldm.models.diffusion.ddpm import LatentDiffusion
|
| 19 |
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
| 20 |
from ldm.models.diffusion.ddim import DDIMSampler
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
-
class VGGPerceptualLoss(torch.nn.Module):
|
| 24 |
-
def __init__(self, resize=True):
|
| 25 |
-
super(VGGPerceptualLoss, self).__init__()
|
| 26 |
-
blocks = []
|
| 27 |
-
vgg_model = torchvision.models.vgg16(pretrained=True)
|
| 28 |
-
print('Loaded VGG weights')
|
| 29 |
-
blocks.append(vgg_model.features[:4].eval())
|
| 30 |
-
blocks.append(vgg_model.features[4:9].eval())
|
| 31 |
-
blocks.append(vgg_model.features[9:16].eval())
|
| 32 |
-
blocks.append(vgg_model.features[16:23].eval())
|
| 33 |
-
|
| 34 |
-
for bl in blocks:
|
| 35 |
-
for p in bl.parameters():
|
| 36 |
-
p.requires_grad = False
|
| 37 |
-
self.blocks = torch.nn.ModuleList(blocks)
|
| 38 |
-
self.transform = torch.nn.functional.interpolate
|
| 39 |
-
self.resize = resize
|
| 40 |
-
self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
| 41 |
-
self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
| 42 |
-
print('Initialized VGG model')
|
| 43 |
-
|
| 44 |
-
def forward(self, input, feature_layers=[0, 1, 2, 3], style_layers=[1,]):
|
| 45 |
-
if input.shape[1] != 3:
|
| 46 |
-
input = input.repeat(1, 3, 1, 1)
|
| 47 |
-
target = target.repeat(1, 3, 1, 1)
|
| 48 |
-
input = (input-self.mean) / self.std
|
| 49 |
-
if self.resize:
|
| 50 |
-
input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
|
| 51 |
-
x = input
|
| 52 |
-
gram_matrices_all = []
|
| 53 |
-
feats = []
|
| 54 |
-
for i, block in enumerate(self.blocks):
|
| 55 |
-
x = block(x)
|
| 56 |
-
if i in style_layers:
|
| 57 |
-
feats.append(x)
|
| 58 |
-
|
| 59 |
-
return feats
|
| 60 |
-
|
| 61 |
|
| 62 |
|
| 63 |
class ControlledUnetModel(UNetModel):
|
|
@@ -325,6 +287,7 @@ class ControlNet(nn.Module):
|
|
| 325 |
def forward(self, x, hint, timesteps, context, **kwargs):
|
| 326 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
| 327 |
emb = self.time_embed(t_emb)
|
|
|
|
| 328 |
guided_hint = self.input_hint_block(hint, emb, context, x.shape)
|
| 329 |
|
| 330 |
outs = []
|
|
@@ -343,57 +306,6 @@ class ControlNet(nn.Module):
|
|
| 343 |
outs.append(self.middle_block_out(h, emb, context))
|
| 344 |
|
| 345 |
return outs
|
| 346 |
-
|
| 347 |
-
class Interpolate(nn.Module):
|
| 348 |
-
def __init__(self, size, mode):
|
| 349 |
-
super(Interpolate, self).__init__()
|
| 350 |
-
self.interp = torch.nn.functional.interpolate
|
| 351 |
-
self.size = size
|
| 352 |
-
self.mode = mode
|
| 353 |
-
self.factor = 8
|
| 354 |
-
|
| 355 |
-
def forward(self, x):
|
| 356 |
-
h,w = x.shape[2]//self.factor, x.shape[3]//self.factor
|
| 357 |
-
x = self.interp(x, size=(h,w), mode=self.mode)
|
| 358 |
-
return x
|
| 359 |
-
|
| 360 |
-
class ControlNetSAP(ControlNet):
|
| 361 |
-
def __init__(
|
| 362 |
-
self,
|
| 363 |
-
hint_channels,
|
| 364 |
-
model_channels,
|
| 365 |
-
input_hint_block='fixed',
|
| 366 |
-
size = 64,
|
| 367 |
-
mode='nearest',
|
| 368 |
-
*args,
|
| 369 |
-
**kwargs
|
| 370 |
-
):
|
| 371 |
-
super().__init__( hint_channels=hint_channels, model_channels=model_channels, *args, **kwargs)
|
| 372 |
-
#hint channels are atleast 128 dims
|
| 373 |
-
|
| 374 |
-
if input_hint_block == 'learnable':
|
| 375 |
-
ch = 2 ** (int(math.log2(hint_channels)))
|
| 376 |
-
self.input_hint_block = TimestepEmbedSequential(
|
| 377 |
-
conv_nd(self.dims, hint_channels, hint_channels, 3, padding=1),
|
| 378 |
-
nn.SiLU(),
|
| 379 |
-
conv_nd(self.dims, hint_channels, 2*ch, 3, padding=1, stride=2),
|
| 380 |
-
nn.SiLU(),
|
| 381 |
-
conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1),
|
| 382 |
-
nn.SiLU(),
|
| 383 |
-
conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1, stride=2),
|
| 384 |
-
nn.SiLU(),
|
| 385 |
-
conv_nd(self.dims, 2*ch, 2*ch, 3, padding=1),
|
| 386 |
-
nn.SiLU(),
|
| 387 |
-
conv_nd(self.dims, 2*ch, model_channels, 3, padding=1, stride=2),
|
| 388 |
-
nn.SiLU(),
|
| 389 |
-
zero_module(conv_nd(self.dims, model_channels, model_channels, 3, padding=1))
|
| 390 |
-
)
|
| 391 |
-
else:
|
| 392 |
-
print("Only interpolation")
|
| 393 |
-
self.input_hint_block = TimestepEmbedSequential(
|
| 394 |
-
Interpolate(size, mode),
|
| 395 |
-
zero_module(conv_nd(self.dims, hint_channels, model_channels, 3, padding=1)))
|
| 396 |
-
|
| 397 |
|
| 398 |
class ControlLDM(LatentDiffusion):
|
| 399 |
|
|
@@ -420,11 +332,11 @@ class ControlLDM(LatentDiffusion):
|
|
| 420 |
diffusion_model = self.model.diffusion_model
|
| 421 |
|
| 422 |
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
| 423 |
-
|
| 424 |
if cond['c_concat'] is None:
|
| 425 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
|
| 426 |
else:
|
| 427 |
-
control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
|
|
|
|
| 428 |
control = [c * scale for c, scale in zip(control, self.control_scales)]
|
| 429 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
| 430 |
|
|
@@ -443,7 +355,7 @@ class ControlLDM(LatentDiffusion):
|
|
| 443 |
use_ddim = ddim_steps is not None
|
| 444 |
|
| 445 |
log = dict()
|
| 446 |
-
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
| 447 |
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
| 448 |
N = min(z.shape[0], N)
|
| 449 |
n_row = min(z.shape[0], n_row)
|
|
@@ -498,8 +410,9 @@ class ControlLDM(LatentDiffusion):
|
|
| 498 |
@torch.no_grad()
|
| 499 |
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
| 500 |
ddim_sampler = DDIMSampler(self)
|
| 501 |
-
b, c, h, w = cond["c_concat"][0].shape
|
| 502 |
-
shape = (self.channels, h // 8, w // 8)
|
|
|
|
| 503 |
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
| 504 |
return samples, intermediates
|
| 505 |
|
|
@@ -525,24 +438,54 @@ class ControlLDM(LatentDiffusion):
|
|
| 525 |
self.cond_stage_model = self.cond_stage_model.cuda()
|
| 526 |
|
| 527 |
|
| 528 |
-
|
|
|
|
| 529 |
@torch.no_grad()
|
| 530 |
-
def __init__(self,control_stage_config, control_key, only_mid_control,
|
|
|
|
| 531 |
super().__init__(control_stage_config=control_stage_config,
|
| 532 |
control_key=control_key,
|
| 533 |
only_mid_control=only_mid_control,
|
| 534 |
*args, **kwargs)
|
| 535 |
-
self.appearance_net = VGGPerceptualLoss().to(self.device)
|
| 536 |
-
print("Loaded VGG model")
|
| 537 |
|
| 538 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
img = (img + 1) * 0.5
|
| 540 |
-
feat =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 541 |
empty_mask_flag = torch.sum(mask, dim=(1,2,3)) == 0
|
| 542 |
|
| 543 |
|
| 544 |
empty_appearance = torch.zeros(feat.shape).to(self.device)
|
| 545 |
-
mask = torch.nn.functional.interpolate(mask.float(), (feat.shape[2
|
| 546 |
one_hot = torch.nn.functional.one_hot(mask[:,0]).permute(0,3,1,2).float()
|
| 547 |
|
| 548 |
feat = torch.einsum('nchw, nmhw->nmchw', feat, one_hot)
|
|
@@ -552,32 +495,68 @@ class SAP(ControlLDM):
|
|
| 552 |
mean_feat[:, 0] = torch.zeros(mean_feat[:,0].shape).to(self.device) #set edges in panopitc mask to empty appearance feature
|
| 553 |
|
| 554 |
splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat, one_hot)
|
| 555 |
-
splatted_feat[empty_mask_flag] = empty_appearance[empty_mask_flag]
|
| 556 |
splatted_feat = torch.nn.functional.normalize(splatted_feat) #l2 normalize on c dim
|
| 557 |
|
| 558 |
if return_all:
|
| 559 |
return splatted_feat, mean_feat, one_hot, empty_mask_flag
|
| 560 |
-
|
| 561 |
return splatted_feat
|
| 562 |
-
|
|
|
|
| 563 |
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
| 564 |
z, c, x_orig, x_recon = super(ControlLDM, self).get_input(batch, self.first_stage_key, return_first_stage_outputs=True , *args, **kwargs)
|
| 565 |
structure = batch['seg'].unsqueeze(1)
|
| 566 |
mask = batch['mask'].unsqueeze(1).to(self.device)
|
| 567 |
-
|
|
|
|
|
|
|
|
|
|
| 568 |
if bs is not None:
|
| 569 |
structure = structure[:bs]
|
| 570 |
-
appearance = appearance[:bs]
|
| 571 |
-
|
| 572 |
structure = structure.to(self.device)
|
| 573 |
-
appearance = appearance.to(self.device)
|
| 574 |
structure = structure.to(memory_format=torch.contiguous_format).float()
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 581 |
@torch.no_grad()
|
| 582 |
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
| 583 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=False,
|
|
@@ -588,11 +567,14 @@ class SAP(ControlLDM):
|
|
| 588 |
|
| 589 |
log = dict()
|
| 590 |
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
| 591 |
-
c_cat, c = c["c_concat"][0]
|
| 592 |
N = min(z.shape[0], N)
|
| 593 |
n_row = min(z.shape[0], n_row)
|
| 594 |
log["reconstruction"] = self.decode_first_stage(z)
|
| 595 |
-
log["control"] =
|
|
|
|
|
|
|
|
|
|
| 596 |
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
| 597 |
|
| 598 |
if plot_diffusion_rows:
|
|
@@ -634,7 +616,7 @@ class SAP(ControlLDM):
|
|
| 634 |
|
| 635 |
if unconditional_guidance_scale > 1.0:
|
| 636 |
uc_cross = self.get_unconditional_conditioning(N)
|
| 637 |
-
uc_cat = c_cat # torch.zeros_like(c_cat)
|
| 638 |
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
| 639 |
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
| 640 |
batch_size=N, ddim=use_ddim,
|
|
@@ -646,3 +628,18 @@ class SAP(ControlLDM):
|
|
| 646 |
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
| 647 |
|
| 648 |
return log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
zero_module,
|
| 11 |
timestep_embedding,
|
| 12 |
)
|
|
|
|
| 13 |
from einops import rearrange, repeat
|
| 14 |
from torchvision.utils import make_grid
|
| 15 |
from ldm.modules.attention import SpatialTransformer
|
|
|
|
| 17 |
from ldm.models.diffusion.ddpm import LatentDiffusion
|
| 18 |
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
| 19 |
from ldm.models.diffusion.ddim import DDIMSampler
|
| 20 |
+
from cldm.appearance_networks import VGGPerceptualLoss, DINOv2
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class ControlledUnetModel(UNetModel):
|
|
|
|
| 287 |
def forward(self, x, hint, timesteps, context, **kwargs):
|
| 288 |
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
| 289 |
emb = self.time_embed(t_emb)
|
| 290 |
+
# hint = hint[:,:-1]
|
| 291 |
guided_hint = self.input_hint_block(hint, emb, context, x.shape)
|
| 292 |
|
| 293 |
outs = []
|
|
|
|
| 306 |
outs.append(self.middle_block_out(h, emb, context))
|
| 307 |
|
| 308 |
return outs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
class ControlLDM(LatentDiffusion):
|
| 311 |
|
|
|
|
| 332 |
diffusion_model = self.model.diffusion_model
|
| 333 |
|
| 334 |
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
|
|
|
| 335 |
if cond['c_concat'] is None:
|
| 336 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
|
| 337 |
else:
|
| 338 |
+
# control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
|
| 339 |
+
control = self.control_model(x=x_noisy, hint=cond['c_concat'][0], timesteps=t, context=cond_txt)
|
| 340 |
control = [c * scale for c, scale in zip(control, self.control_scales)]
|
| 341 |
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
| 342 |
|
|
|
|
| 355 |
use_ddim = ddim_steps is not None
|
| 356 |
|
| 357 |
log = dict()
|
| 358 |
+
z, c = self.get_input(batch, self.first_stage_key, bs=N, logging=True)
|
| 359 |
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
| 360 |
N = min(z.shape[0], N)
|
| 361 |
n_row = min(z.shape[0], n_row)
|
|
|
|
| 410 |
@torch.no_grad()
|
| 411 |
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
| 412 |
ddim_sampler = DDIMSampler(self)
|
| 413 |
+
b, c, h, w = cond["c_concat"][0][0].shape if isinstance(cond["c_concat"][0], list) else cond["c_concat"][0].shape
|
| 414 |
+
# shape = (self.channels, h // 8, w // 8)
|
| 415 |
+
shape = (self.channels, h, w)
|
| 416 |
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
| 417 |
return samples, intermediates
|
| 418 |
|
|
|
|
| 438 |
self.cond_stage_model = self.cond_stage_model.cuda()
|
| 439 |
|
| 440 |
|
| 441 |
+
|
| 442 |
+
class PAIRDiffusion(ControlLDM):
|
| 443 |
@torch.no_grad()
|
| 444 |
+
def __init__(self,control_stage_config, control_key, only_mid_control, app_net='vgg', app_layer_conc=(1,), app_layer_ca=(6,6,18,18),
|
| 445 |
+
appearance_net_locked=True, concat_multi_app=False, train_structure_variation_only=False, instruct=False, *args, **kwargs):
|
| 446 |
super().__init__(control_stage_config=control_stage_config,
|
| 447 |
control_key=control_key,
|
| 448 |
only_mid_control=only_mid_control,
|
| 449 |
*args, **kwargs)
|
|
|
|
|
|
|
| 450 |
|
| 451 |
+
self.appearance_net_conc = VGGPerceptualLoss().to(self.device)
|
| 452 |
+
self.appearance_net_ca = DINOv2().to(self.device)
|
| 453 |
+
self.appearance_net = VGGPerceptualLoss().to(self.device) #need to be removed no use
|
| 454 |
+
self.app_layer_conc = app_layer_conc
|
| 455 |
+
self.app_layer_ca = app_layer_ca
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def get_appearance(self, net, layer, img, mask, return_all=False):
|
| 459 |
img = (img + 1) * 0.5
|
| 460 |
+
feat = net(img)
|
| 461 |
+
splatted_feat = []
|
| 462 |
+
mean_feat = []
|
| 463 |
+
for fe_i in layer:
|
| 464 |
+
v = self.get_appearance_single(feat[fe_i], mask, return_all=return_all)
|
| 465 |
+
if return_all:
|
| 466 |
+
spl, me_f, one_hot, empty_mask = v
|
| 467 |
+
splatted_feat.append(spl)
|
| 468 |
+
mean_feat.append(me_f)
|
| 469 |
+
else:
|
| 470 |
+
splatted_feat.append(v)
|
| 471 |
+
|
| 472 |
+
if len(layer) == 1:
|
| 473 |
+
splatted_feat = splatted_feat[0]
|
| 474 |
+
# mean_feat = mean_feat[0]
|
| 475 |
+
|
| 476 |
+
del feat
|
| 477 |
+
|
| 478 |
+
if return_all:
|
| 479 |
+
return splatted_feat, mean_feat, one_hot, empty_mask
|
| 480 |
+
|
| 481 |
+
return splatted_feat
|
| 482 |
+
|
| 483 |
+
def get_appearance_single(self, feat, mask, return_all):
|
| 484 |
empty_mask_flag = torch.sum(mask, dim=(1,2,3)) == 0
|
| 485 |
|
| 486 |
|
| 487 |
empty_appearance = torch.zeros(feat.shape).to(self.device)
|
| 488 |
+
mask = torch.nn.functional.interpolate(mask.float(), size=(feat.shape[2], feat.shape[3])).long()
|
| 489 |
one_hot = torch.nn.functional.one_hot(mask[:,0]).permute(0,3,1,2).float()
|
| 490 |
|
| 491 |
feat = torch.einsum('nchw, nmhw->nmchw', feat, one_hot)
|
|
|
|
| 495 |
mean_feat[:, 0] = torch.zeros(mean_feat[:,0].shape).to(self.device) #set edges in panopitc mask to empty appearance feature
|
| 496 |
|
| 497 |
splatted_feat = torch.einsum('nmc, nmhw->nchw', mean_feat, one_hot)
|
| 498 |
+
splatted_feat[empty_mask_flag] = empty_appearance[empty_mask_flag]
|
| 499 |
splatted_feat = torch.nn.functional.normalize(splatted_feat) #l2 normalize on c dim
|
| 500 |
|
| 501 |
if return_all:
|
| 502 |
return splatted_feat, mean_feat, one_hot, empty_mask_flag
|
|
|
|
| 503 |
return splatted_feat
|
| 504 |
+
|
| 505 |
+
|
| 506 |
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
| 507 |
z, c, x_orig, x_recon = super(ControlLDM, self).get_input(batch, self.first_stage_key, return_first_stage_outputs=True , *args, **kwargs)
|
| 508 |
structure = batch['seg'].unsqueeze(1)
|
| 509 |
mask = batch['mask'].unsqueeze(1).to(self.device)
|
| 510 |
+
|
| 511 |
+
appearance_conc = self.get_appearance(self.appearance_net_conc, self.app_layer_conc, x_orig, mask)
|
| 512 |
+
appearance_ca = self.get_appearance(self.appearance_net_ca, self.app_layer_ca, x_orig, mask)
|
| 513 |
+
|
| 514 |
if bs is not None:
|
| 515 |
structure = structure[:bs]
|
|
|
|
|
|
|
| 516 |
structure = structure.to(self.device)
|
|
|
|
| 517 |
structure = structure.to(memory_format=torch.contiguous_format).float()
|
| 518 |
+
structure = torch.nn.functional.interpolate(structure, z.shape[2:])
|
| 519 |
+
|
| 520 |
+
mask = torch.nn.functional.interpolate(mask.float(), z.shape[2:])
|
| 521 |
+
|
| 522 |
+
def format_appearance(appearance):
|
| 523 |
+
if isinstance(appearance, list):
|
| 524 |
+
if bs is not None:
|
| 525 |
+
appearance = [ap[:bs] for ap in appearance]
|
| 526 |
+
appearance = [ap.to(self.device) for ap in appearance]
|
| 527 |
+
appearance = [ap.to(memory_format=torch.contiguous_format).float() for ap in appearance]
|
| 528 |
+
appearance = [torch.nn.functional.interpolate(ap, z.shape[2:]) for ap in appearance]
|
| 529 |
+
|
| 530 |
+
else:
|
| 531 |
+
if bs is not None:
|
| 532 |
+
appearance = appearance[:bs]
|
| 533 |
+
appearance = appearance.to(self.device)
|
| 534 |
+
appearance = appearance.to(memory_format=torch.contiguous_format).float()
|
| 535 |
+
appearance = torch.nn.functional.interpolate(appearance, z.shape[2:])
|
| 536 |
+
|
| 537 |
+
return appearance
|
| 538 |
+
|
| 539 |
+
appearance_conc = format_appearance(appearance_conc)
|
| 540 |
+
appearance_ca = format_appearance(appearance_ca)
|
| 541 |
+
|
| 542 |
+
if isinstance(appearance_conc, list):
|
| 543 |
+
concat_control = torch.cat(appearance_conc, dim=1)
|
| 544 |
+
concat_control = torch.cat([structure, concat_control, mask], dim=1)
|
| 545 |
+
else:
|
| 546 |
+
concat_control = torch.cat([structure, appearance_conc, mask], dim=1)
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
if isinstance(appearance_ca, list):
|
| 550 |
+
control = []
|
| 551 |
+
for ap in appearance_ca:
|
| 552 |
+
control.append(torch.cat([structure, ap, mask], dim=1))
|
| 553 |
+
control.append(concat_control)
|
| 554 |
+
return z, dict(c_crossattn=[c], c_concat=[control])
|
| 555 |
+
else:
|
| 556 |
+
control = torch.cat([structure, appearance_ca, mask], dim=1)
|
| 557 |
+
control.append(concat_control)
|
| 558 |
+
return z, dict(c_crossattn=[c], c_concat=[control])
|
| 559 |
+
|
| 560 |
@torch.no_grad()
|
| 561 |
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
| 562 |
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=False,
|
|
|
|
| 567 |
|
| 568 |
log = dict()
|
| 569 |
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
| 570 |
+
c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
|
| 571 |
N = min(z.shape[0], N)
|
| 572 |
n_row = min(z.shape[0], n_row)
|
| 573 |
log["reconstruction"] = self.decode_first_stage(z)
|
| 574 |
+
log["control"] = batch['mask'].unsqueeze(1)
|
| 575 |
+
if 'aug_mask' in batch:
|
| 576 |
+
log['aug_mask'] = batch['aug_mask'].unsqueeze(1)
|
| 577 |
+
|
| 578 |
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
| 579 |
|
| 580 |
if plot_diffusion_rows:
|
|
|
|
| 616 |
|
| 617 |
if unconditional_guidance_scale > 1.0:
|
| 618 |
uc_cross = self.get_unconditional_conditioning(N)
|
| 619 |
+
uc_cat = list(c_cat) # torch.zeros_like(c_cat)
|
| 620 |
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
| 621 |
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
| 622 |
batch_size=N, ddim=use_ddim,
|
|
|
|
| 628 |
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
| 629 |
|
| 630 |
return log
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def configure_optimizers(self):
|
| 634 |
+
lr = self.learning_rate
|
| 635 |
+
|
| 636 |
+
params = list(self.control_model.parameters())
|
| 637 |
+
if not self.sd_locked:
|
| 638 |
+
params += list(self.model.diffusion_model.output_blocks.parameters())
|
| 639 |
+
params += list(self.model.diffusion_model.out.parameters())
|
| 640 |
+
|
| 641 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
| 642 |
+
return opt
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
|
cldm/controlnet.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch as th
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
from ldm.modules.diffusionmodules.util import (
|
| 6 |
+
conv_nd,
|
| 7 |
+
linear,
|
| 8 |
+
zero_module,
|
| 9 |
+
timestep_embedding,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
from ldm.modules.attention import SpatialTransformer
|
| 13 |
+
from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
| 14 |
+
from ldm.util import exists
|
| 15 |
+
|
| 16 |
+
torch.autograd.set_detect_anomaly(True)
|
| 17 |
+
|
| 18 |
+
class Interpolate(nn.Module):
|
| 19 |
+
def __init__(self, mode):
|
| 20 |
+
super(Interpolate, self).__init__()
|
| 21 |
+
self.interp = torch.nn.functional.interpolate
|
| 22 |
+
self.mode = mode
|
| 23 |
+
self.factor = 8
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
return x
|
| 27 |
+
|
| 28 |
+
class ControlNetPAIR(nn.Module):
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
image_size,
|
| 32 |
+
in_channels,
|
| 33 |
+
model_channels,
|
| 34 |
+
hint_channels,
|
| 35 |
+
concat_indices,
|
| 36 |
+
num_res_blocks,
|
| 37 |
+
attention_resolutions,
|
| 38 |
+
concat_channels=130,
|
| 39 |
+
dropout=0,
|
| 40 |
+
channel_mult=(1, 2, 4, 8),
|
| 41 |
+
mode='nearest',
|
| 42 |
+
conv_resample=True,
|
| 43 |
+
dims=2,
|
| 44 |
+
use_checkpoint=False,
|
| 45 |
+
use_fp16=False,
|
| 46 |
+
num_heads=-1,
|
| 47 |
+
num_head_channels=-1,
|
| 48 |
+
num_heads_upsample=-1,
|
| 49 |
+
use_scale_shift_norm=False,
|
| 50 |
+
resblock_updown=False,
|
| 51 |
+
use_new_attention_order=False,
|
| 52 |
+
use_spatial_transformer=False, # custom transformer support
|
| 53 |
+
transformer_depth=1, # custom transformer support
|
| 54 |
+
context_dim=None, # custom transformer support
|
| 55 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
| 56 |
+
legacy=True,
|
| 57 |
+
disable_self_attentions=None,
|
| 58 |
+
num_attention_blocks=None,
|
| 59 |
+
disable_middle_self_attn=False,
|
| 60 |
+
use_linear_in_transformer=False,
|
| 61 |
+
attn_class=['softmax', 'softmax', 'softmax', 'softmax'],
|
| 62 |
+
):
|
| 63 |
+
super().__init__()
|
| 64 |
+
if use_spatial_transformer:
|
| 65 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
| 66 |
+
|
| 67 |
+
if context_dim is not None:
|
| 68 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
| 69 |
+
from omegaconf.listconfig import ListConfig
|
| 70 |
+
if type(context_dim) == ListConfig:
|
| 71 |
+
context_dim = list(context_dim)
|
| 72 |
+
|
| 73 |
+
if num_heads_upsample == -1:
|
| 74 |
+
num_heads_upsample = num_heads
|
| 75 |
+
|
| 76 |
+
if num_heads == -1:
|
| 77 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
| 78 |
+
|
| 79 |
+
if num_head_channels == -1:
|
| 80 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
| 81 |
+
|
| 82 |
+
self.dims = dims
|
| 83 |
+
self.image_size = image_size
|
| 84 |
+
self.in_channels = in_channels
|
| 85 |
+
self.model_channels = model_channels
|
| 86 |
+
if isinstance(num_res_blocks, int):
|
| 87 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
| 88 |
+
else:
|
| 89 |
+
if len(num_res_blocks) != len(channel_mult):
|
| 90 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
| 91 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
| 92 |
+
self.num_res_blocks = num_res_blocks
|
| 93 |
+
if disable_self_attentions is not None:
|
| 94 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
| 95 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
| 96 |
+
if num_attention_blocks is not None:
|
| 97 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
| 98 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
| 99 |
+
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
| 100 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
| 101 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
| 102 |
+
f"attention will still not be set.")
|
| 103 |
+
|
| 104 |
+
self.attention_resolutions = attention_resolutions
|
| 105 |
+
self.dropout = dropout
|
| 106 |
+
self.channel_mult = channel_mult
|
| 107 |
+
self.conv_resample = conv_resample
|
| 108 |
+
self.use_checkpoint = use_checkpoint
|
| 109 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
| 110 |
+
self.num_heads = num_heads
|
| 111 |
+
self.num_head_channels = num_head_channels
|
| 112 |
+
self.num_heads_upsample = num_heads_upsample
|
| 113 |
+
self.predict_codebook_ids = n_embed is not None
|
| 114 |
+
|
| 115 |
+
time_embed_dim = model_channels * 4
|
| 116 |
+
self.time_embed = nn.Sequential(
|
| 117 |
+
linear(model_channels, time_embed_dim),
|
| 118 |
+
nn.SiLU(),
|
| 119 |
+
linear(time_embed_dim, time_embed_dim),
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
self.input_blocks = nn.ModuleList(
|
| 123 |
+
[
|
| 124 |
+
TimestepEmbedSequential(
|
| 125 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
| 126 |
+
)
|
| 127 |
+
]
|
| 128 |
+
)
|
| 129 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
| 130 |
+
self.concat_indices = concat_indices
|
| 131 |
+
self.hint_channels = hint_channels
|
| 132 |
+
h_ch = sum([hint_channels[i] for i in concat_indices ])
|
| 133 |
+
|
| 134 |
+
self.input_hint_block = TimestepEmbedSequential(
|
| 135 |
+
Interpolate('nearest'),
|
| 136 |
+
conv_nd(self.dims, concat_channels, self.model_channels, 3, padding=1),
|
| 137 |
+
nn.SiLU(),
|
| 138 |
+
zero_module(conv_nd(self.dims, self.model_channels, self.model_channels, 3, padding=1)))
|
| 139 |
+
|
| 140 |
+
self._feature_size = model_channels
|
| 141 |
+
input_block_chans = [model_channels]
|
| 142 |
+
ch = model_channels
|
| 143 |
+
ds = 1
|
| 144 |
+
for level, mult in enumerate(channel_mult):
|
| 145 |
+
for nr in range(self.num_res_blocks[level]):
|
| 146 |
+
layers = [
|
| 147 |
+
ResBlock(
|
| 148 |
+
ch,
|
| 149 |
+
time_embed_dim,
|
| 150 |
+
dropout,
|
| 151 |
+
out_channels=mult * model_channels,
|
| 152 |
+
dims=dims,
|
| 153 |
+
use_checkpoint=use_checkpoint,
|
| 154 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 155 |
+
)
|
| 156 |
+
]
|
| 157 |
+
ch = mult * model_channels
|
| 158 |
+
if ds in attention_resolutions:
|
| 159 |
+
if num_head_channels == -1:
|
| 160 |
+
dim_head = ch // num_heads
|
| 161 |
+
else:
|
| 162 |
+
num_heads = ch // num_head_channels
|
| 163 |
+
dim_head = num_head_channels
|
| 164 |
+
if legacy:
|
| 165 |
+
# num_heads = 1
|
| 166 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
| 167 |
+
if exists(disable_self_attentions):
|
| 168 |
+
disabled_sa = disable_self_attentions[level]
|
| 169 |
+
else:
|
| 170 |
+
disabled_sa = False
|
| 171 |
+
|
| 172 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
| 173 |
+
layers.append(
|
| 174 |
+
AttentionBlock(
|
| 175 |
+
ch,
|
| 176 |
+
use_checkpoint=use_checkpoint,
|
| 177 |
+
num_heads=num_heads,
|
| 178 |
+
num_head_channels=dim_head,
|
| 179 |
+
use_new_attention_order=use_new_attention_order,
|
| 180 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
| 181 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
| 182 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
| 183 |
+
use_checkpoint=use_checkpoint, attn1_mode=attn_class[level], obj_feat_dim=hint_channels[level]
|
| 184 |
+
)
|
| 185 |
+
)
|
| 186 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
| 187 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
| 188 |
+
self._feature_size += ch
|
| 189 |
+
input_block_chans.append(ch)
|
| 190 |
+
if level != len(channel_mult) - 1:
|
| 191 |
+
out_ch = ch
|
| 192 |
+
self.input_blocks.append(
|
| 193 |
+
TimestepEmbedSequential(
|
| 194 |
+
ResBlock(
|
| 195 |
+
ch,
|
| 196 |
+
time_embed_dim,
|
| 197 |
+
dropout,
|
| 198 |
+
out_channels=out_ch,
|
| 199 |
+
dims=dims,
|
| 200 |
+
use_checkpoint=use_checkpoint,
|
| 201 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 202 |
+
down=True,
|
| 203 |
+
)
|
| 204 |
+
if resblock_updown
|
| 205 |
+
else Downsample(
|
| 206 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
| 207 |
+
)
|
| 208 |
+
)
|
| 209 |
+
)
|
| 210 |
+
ch = out_ch
|
| 211 |
+
input_block_chans.append(ch)
|
| 212 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
| 213 |
+
ds *= 2
|
| 214 |
+
self._feature_size += ch
|
| 215 |
+
|
| 216 |
+
if num_head_channels == -1:
|
| 217 |
+
dim_head = ch // num_heads
|
| 218 |
+
else:
|
| 219 |
+
num_heads = ch // num_head_channels
|
| 220 |
+
dim_head = num_head_channels
|
| 221 |
+
if legacy:
|
| 222 |
+
# num_heads = 1
|
| 223 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
| 224 |
+
self.middle_block = TimestepEmbedSequential(
|
| 225 |
+
ResBlock(
|
| 226 |
+
ch,
|
| 227 |
+
time_embed_dim,
|
| 228 |
+
# hint_channels[-1],
|
| 229 |
+
dropout,
|
| 230 |
+
dims=dims,
|
| 231 |
+
use_checkpoint=use_checkpoint,
|
| 232 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 233 |
+
),
|
| 234 |
+
AttentionBlock(
|
| 235 |
+
ch,
|
| 236 |
+
use_checkpoint=use_checkpoint,
|
| 237 |
+
num_heads=num_heads,
|
| 238 |
+
num_head_channels=dim_head,
|
| 239 |
+
use_new_attention_order=use_new_attention_order,
|
| 240 |
+
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
| 241 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
| 242 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
| 243 |
+
use_checkpoint=use_checkpoint
|
| 244 |
+
),
|
| 245 |
+
ResBlock(
|
| 246 |
+
ch,
|
| 247 |
+
time_embed_dim,
|
| 248 |
+
# hint_channels[-1],
|
| 249 |
+
dropout,
|
| 250 |
+
dims=dims,
|
| 251 |
+
use_checkpoint=use_checkpoint,
|
| 252 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 253 |
+
),
|
| 254 |
+
)
|
| 255 |
+
self.middle_block_out = self.make_zero_conv(ch)
|
| 256 |
+
self._feature_size += ch
|
| 257 |
+
|
| 258 |
+
def make_zero_conv(self, channels):
|
| 259 |
+
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
| 260 |
+
|
| 261 |
+
def forward(self, x, hint, timesteps, context, **kwargs):
|
| 262 |
+
hint_list = []
|
| 263 |
+
concat_hint = hint[-1]
|
| 264 |
+
hint_c = hint[:-1]
|
| 265 |
+
|
| 266 |
+
if not isinstance(hint_c, list):
|
| 267 |
+
for _ in range(len(self.channel_mult)):
|
| 268 |
+
hint_list.append(hint_c)
|
| 269 |
+
else:
|
| 270 |
+
hint_list = hint_c
|
| 271 |
+
while len(hint_list) < 4:
|
| 272 |
+
hint_list.append(hint_c[-1])
|
| 273 |
+
|
| 274 |
+
mask = hint_c[0][:,-1].unsqueeze(1) #panoptic
|
| 275 |
+
|
| 276 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
| 277 |
+
emb = self.time_embed(t_emb)
|
| 278 |
+
|
| 279 |
+
guided_hint = self.input_hint_block(concat_hint, emb, context, x.shape)
|
| 280 |
+
outs = []
|
| 281 |
+
|
| 282 |
+
h = x.type(self.dtype)
|
| 283 |
+
|
| 284 |
+
cnt = self.num_res_blocks[0] + 1
|
| 285 |
+
i = 0
|
| 286 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
| 287 |
+
if guided_hint is not None:
|
| 288 |
+
h = module(h, emb, context, hint_list[i], mask)
|
| 289 |
+
h += guided_hint
|
| 290 |
+
guided_hint = None
|
| 291 |
+
else:
|
| 292 |
+
h = module(h, emb, context, hint_list[i], mask)
|
| 293 |
+
outs.append(zero_conv(h, emb, context))
|
| 294 |
+
|
| 295 |
+
cnt -= 1
|
| 296 |
+
if cnt == 0:
|
| 297 |
+
if i<len(self.num_res_blocks):
|
| 298 |
+
cnt = self.num_res_blocks[i] + 1
|
| 299 |
+
else:
|
| 300 |
+
if (i+1)<len(self.num_res_blocks):
|
| 301 |
+
i += 1
|
| 302 |
+
|
| 303 |
+
h = self.middle_block(h, emb, context, hint_list[-1], mask)
|
| 304 |
+
outs.append(self.middle_block_out(h, emb, context))
|
| 305 |
+
|
| 306 |
+
return outs
|
cldm/ddim_hacked.py
CHANGED
|
@@ -316,7 +316,6 @@ class DDIMSampler(object):
|
|
| 316 |
return x_dec
|
| 317 |
|
| 318 |
|
| 319 |
-
|
| 320 |
class DDIMSamplerSpaCFG(DDIMSampler):
|
| 321 |
@torch.no_grad()
|
| 322 |
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
|
@@ -332,8 +331,8 @@ class DDIMSamplerSpaCFG(DDIMSampler):
|
|
| 332 |
model_uncond = self.model.apply_model(x, t, unconditional_conditioning[0])
|
| 333 |
model_struct = self.model.apply_model(x, t, unconditional_conditioning[1])
|
| 334 |
model_struct_app = self.model.apply_model(x, t, unconditional_conditioning[2])
|
| 335 |
-
|
| 336 |
-
model_output = model_uncond + sS * (model_struct - model_uncond) + sF * (model_struct_app - model_struct) + sT * (model_t -
|
| 337 |
|
| 338 |
if self.model.parameterization == "v":
|
| 339 |
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
|
|
|
| 316 |
return x_dec
|
| 317 |
|
| 318 |
|
|
|
|
| 319 |
class DDIMSamplerSpaCFG(DDIMSampler):
|
| 320 |
@torch.no_grad()
|
| 321 |
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
|
|
|
| 331 |
model_uncond = self.model.apply_model(x, t, unconditional_conditioning[0])
|
| 332 |
model_struct = self.model.apply_model(x, t, unconditional_conditioning[1])
|
| 333 |
model_struct_app = self.model.apply_model(x, t, unconditional_conditioning[2])
|
| 334 |
+
sS, sF, sT = unconditional_guidance_scale
|
| 335 |
+
model_output = model_uncond + sS * (model_struct - model_uncond) + sF * (model_struct_app - model_struct) + sT * (model_t - model_uncond)
|
| 336 |
|
| 337 |
if self.model.parameterization == "v":
|
| 338 |
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
cldm/logger.py
CHANGED
|
@@ -114,16 +114,16 @@ class SetupCallback(Callback):
|
|
| 114 |
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
| 115 |
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
| 116 |
|
| 117 |
-
else:
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
|
| 128 |
|
| 129 |
class ImageLogger(Callback):
|
|
|
|
| 114 |
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
| 115 |
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
| 116 |
|
| 117 |
+
# else:
|
| 118 |
+
# # ModelCheckpoint callback created log directory --- remove it
|
| 119 |
+
# if not self.resume and os.path.exists(self.logdir):
|
| 120 |
+
# dst, name = os.path.split(self.logdir)
|
| 121 |
+
# dst = os.path.join(dst, "child_runs", name)
|
| 122 |
+
# os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
| 123 |
+
# try:
|
| 124 |
+
# os.rename(self.logdir, dst)
|
| 125 |
+
# except FileNotFoundError:
|
| 126 |
+
# pass
|
| 127 |
|
| 128 |
|
| 129 |
class ImageLogger(Callback):
|
configs/{sap_fixed_hintnet_v15.yaml → pair_diff.yaml}
RENAMED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
model:
|
| 2 |
-
target: cldm.cldm.
|
| 3 |
learning_rate: 1.5e-05
|
| 4 |
sd_locked: True
|
| 5 |
only_mid_control: False
|
| 6 |
-
init_ckpt: './models/
|
| 7 |
params:
|
| 8 |
linear_start: 0.00085
|
| 9 |
linear_end: 0.0120
|
|
@@ -21,14 +21,17 @@ model:
|
|
| 21 |
scale_factor: 0.18215
|
| 22 |
use_ema: False
|
| 23 |
only_mid_control: False
|
|
|
|
|
|
|
| 24 |
|
| 25 |
control_stage_config:
|
| 26 |
-
target: cldm.
|
| 27 |
params:
|
| 28 |
-
input_hint_block: 'fixed'
|
| 29 |
image_size: 32 # unused
|
| 30 |
in_channels: 4
|
| 31 |
-
|
|
|
|
|
|
|
| 32 |
model_channels: 320
|
| 33 |
attention_resolutions: [ 4, 2, 1 ]
|
| 34 |
num_res_blocks: 2
|
|
@@ -39,6 +42,7 @@ model:
|
|
| 39 |
context_dim: 768
|
| 40 |
use_checkpoint: True
|
| 41 |
legacy: False
|
|
|
|
| 42 |
|
| 43 |
unet_config:
|
| 44 |
target: cldm.cldm.ControlledUnetModel
|
|
@@ -87,16 +91,25 @@ model:
|
|
| 87 |
data:
|
| 88 |
target: cldm.data.DataModuleFromConfig
|
| 89 |
params:
|
| 90 |
-
batch_size:
|
| 91 |
wrap: True
|
|
|
|
| 92 |
train:
|
| 93 |
target: dataset.txtseg.COCOTrain
|
| 94 |
params:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
size: 512
|
| 96 |
validation:
|
| 97 |
target: dataset.txtseg.COCOValidation
|
| 98 |
params:
|
| 99 |
size: 512
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
| 101 |
|
| 102 |
lightning:
|
|
@@ -111,4 +124,4 @@ lightning:
|
|
| 111 |
|
| 112 |
trainer:
|
| 113 |
benchmark: True
|
| 114 |
-
accumulate_grad_batches:
|
|
|
|
| 1 |
model:
|
| 2 |
+
target: cldm.cldm.PAIRDiffusion
|
| 3 |
learning_rate: 1.5e-05
|
| 4 |
sd_locked: True
|
| 5 |
only_mid_control: False
|
| 6 |
+
init_ckpt: './models/pair_diff_init.ckpt'
|
| 7 |
params:
|
| 8 |
linear_start: 0.00085
|
| 9 |
linear_end: 0.0120
|
|
|
|
| 21 |
scale_factor: 0.18215
|
| 22 |
use_ema: False
|
| 23 |
only_mid_control: False
|
| 24 |
+
appearance_net_locked: True
|
| 25 |
+
app_net: 'DINO'
|
| 26 |
|
| 27 |
control_stage_config:
|
| 28 |
+
target: cldm.controlnet.ControlNetPAIR
|
| 29 |
params:
|
|
|
|
| 30 |
image_size: 32 # unused
|
| 31 |
in_channels: 4
|
| 32 |
+
concat_indices: [0,1]
|
| 33 |
+
concat_channels: 130
|
| 34 |
+
hint_channels: [1026, 1026, -1, -1] #(1024 + 2)
|
| 35 |
model_channels: 320
|
| 36 |
attention_resolutions: [ 4, 2, 1 ]
|
| 37 |
num_res_blocks: 2
|
|
|
|
| 42 |
context_dim: 768
|
| 43 |
use_checkpoint: True
|
| 44 |
legacy: False
|
| 45 |
+
attn_class: ['maskguided', 'maskguided', 'softmax', 'softmax']
|
| 46 |
|
| 47 |
unet_config:
|
| 48 |
target: cldm.cldm.ControlledUnetModel
|
|
|
|
| 91 |
data:
|
| 92 |
target: cldm.data.DataModuleFromConfig
|
| 93 |
params:
|
| 94 |
+
batch_size: 2
|
| 95 |
wrap: True
|
| 96 |
+
num_workers: 4
|
| 97 |
train:
|
| 98 |
target: dataset.txtseg.COCOTrain
|
| 99 |
params:
|
| 100 |
+
image_dir:
|
| 101 |
+
caption_file:
|
| 102 |
+
panoptic_mask_dir:
|
| 103 |
+
seg_dir:
|
| 104 |
size: 512
|
| 105 |
validation:
|
| 106 |
target: dataset.txtseg.COCOValidation
|
| 107 |
params:
|
| 108 |
size: 512
|
| 109 |
+
image_dir:
|
| 110 |
+
caption_file:
|
| 111 |
+
panoptic_mask_dir:
|
| 112 |
+
seg_dir:
|
| 113 |
|
| 114 |
|
| 115 |
lightning:
|
|
|
|
| 124 |
|
| 125 |
trainer:
|
| 126 |
benchmark: True
|
| 127 |
+
accumulate_grad_batches: 2
|
ldm/ldm/util.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import optim
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from inspect import isfunction
|
| 8 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def log_txt_as_img(wh, xc, size=10):
|
| 12 |
+
# wh a tuple of (width, height)
|
| 13 |
+
# xc a list of captions to plot
|
| 14 |
+
b = len(xc)
|
| 15 |
+
txts = list()
|
| 16 |
+
for bi in range(b):
|
| 17 |
+
txt = Image.new("RGB", wh, color="white")
|
| 18 |
+
draw = ImageDraw.Draw(txt)
|
| 19 |
+
font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
|
| 20 |
+
nc = int(40 * (wh[0] / 256))
|
| 21 |
+
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
| 25 |
+
except UnicodeEncodeError:
|
| 26 |
+
print("Cant encode string for logging. Skipping.")
|
| 27 |
+
|
| 28 |
+
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
| 29 |
+
txts.append(txt)
|
| 30 |
+
txts = np.stack(txts)
|
| 31 |
+
txts = torch.tensor(txts)
|
| 32 |
+
return txts
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def ismap(x):
|
| 36 |
+
if not isinstance(x, torch.Tensor):
|
| 37 |
+
return False
|
| 38 |
+
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def isimage(x):
|
| 42 |
+
if not isinstance(x,torch.Tensor):
|
| 43 |
+
return False
|
| 44 |
+
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def exists(x):
|
| 48 |
+
return x is not None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def default(val, d):
|
| 52 |
+
if exists(val):
|
| 53 |
+
return val
|
| 54 |
+
return d() if isfunction(d) else d
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def mean_flat(tensor):
|
| 58 |
+
"""
|
| 59 |
+
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
| 60 |
+
Take the mean over all non-batch dimensions.
|
| 61 |
+
"""
|
| 62 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def count_params(model, verbose=False):
|
| 66 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 67 |
+
if verbose:
|
| 68 |
+
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
| 69 |
+
return total_params
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def instantiate_from_config(config):
|
| 73 |
+
if not "target" in config:
|
| 74 |
+
if config == '__is_first_stage__':
|
| 75 |
+
return None
|
| 76 |
+
elif config == "__is_unconditional__":
|
| 77 |
+
return None
|
| 78 |
+
raise KeyError("Expected key `target` to instantiate.")
|
| 79 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_obj_from_str(string, reload=False):
|
| 83 |
+
module, cls = string.rsplit(".", 1)
|
| 84 |
+
if reload:
|
| 85 |
+
module_imp = importlib.import_module(module)
|
| 86 |
+
importlib.reload(module_imp)
|
| 87 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class AdamWwithEMAandWings(optim.Optimizer):
|
| 91 |
+
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
| 92 |
+
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
| 93 |
+
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
| 94 |
+
ema_power=1., param_names=()):
|
| 95 |
+
"""AdamW that saves EMA versions of the parameters."""
|
| 96 |
+
if not 0.0 <= lr:
|
| 97 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
| 98 |
+
if not 0.0 <= eps:
|
| 99 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
| 100 |
+
if not 0.0 <= betas[0] < 1.0:
|
| 101 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
| 102 |
+
if not 0.0 <= betas[1] < 1.0:
|
| 103 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
| 104 |
+
if not 0.0 <= weight_decay:
|
| 105 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
| 106 |
+
if not 0.0 <= ema_decay <= 1.0:
|
| 107 |
+
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
| 108 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
| 109 |
+
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
| 110 |
+
ema_power=ema_power, param_names=param_names)
|
| 111 |
+
super().__init__(params, defaults)
|
| 112 |
+
|
| 113 |
+
def __setstate__(self, state):
|
| 114 |
+
super().__setstate__(state)
|
| 115 |
+
for group in self.param_groups:
|
| 116 |
+
group.setdefault('amsgrad', False)
|
| 117 |
+
|
| 118 |
+
@torch.no_grad()
|
| 119 |
+
def step(self, closure=None):
|
| 120 |
+
"""Performs a single optimization step.
|
| 121 |
+
Args:
|
| 122 |
+
closure (callable, optional): A closure that reevaluates the model
|
| 123 |
+
and returns the loss.
|
| 124 |
+
"""
|
| 125 |
+
loss = None
|
| 126 |
+
if closure is not None:
|
| 127 |
+
with torch.enable_grad():
|
| 128 |
+
loss = closure()
|
| 129 |
+
|
| 130 |
+
for group in self.param_groups:
|
| 131 |
+
params_with_grad = []
|
| 132 |
+
grads = []
|
| 133 |
+
exp_avgs = []
|
| 134 |
+
exp_avg_sqs = []
|
| 135 |
+
ema_params_with_grad = []
|
| 136 |
+
state_sums = []
|
| 137 |
+
max_exp_avg_sqs = []
|
| 138 |
+
state_steps = []
|
| 139 |
+
amsgrad = group['amsgrad']
|
| 140 |
+
beta1, beta2 = group['betas']
|
| 141 |
+
ema_decay = group['ema_decay']
|
| 142 |
+
ema_power = group['ema_power']
|
| 143 |
+
|
| 144 |
+
for p in group['params']:
|
| 145 |
+
if p.grad is None:
|
| 146 |
+
continue
|
| 147 |
+
params_with_grad.append(p)
|
| 148 |
+
if p.grad.is_sparse:
|
| 149 |
+
raise RuntimeError('AdamW does not support sparse gradients')
|
| 150 |
+
grads.append(p.grad)
|
| 151 |
+
|
| 152 |
+
state = self.state[p]
|
| 153 |
+
|
| 154 |
+
# State initialization
|
| 155 |
+
if len(state) == 0:
|
| 156 |
+
state['step'] = 0
|
| 157 |
+
# Exponential moving average of gradient values
|
| 158 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
| 159 |
+
# Exponential moving average of squared gradient values
|
| 160 |
+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
| 161 |
+
if amsgrad:
|
| 162 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
| 163 |
+
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
| 164 |
+
# Exponential moving average of parameter values
|
| 165 |
+
state['param_exp_avg'] = p.detach().float().clone()
|
| 166 |
+
|
| 167 |
+
exp_avgs.append(state['exp_avg'])
|
| 168 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
| 169 |
+
ema_params_with_grad.append(state['param_exp_avg'])
|
| 170 |
+
|
| 171 |
+
if amsgrad:
|
| 172 |
+
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
| 173 |
+
|
| 174 |
+
# update the steps for each param group update
|
| 175 |
+
state['step'] += 1
|
| 176 |
+
# record the step after step update
|
| 177 |
+
state_steps.append(state['step'])
|
| 178 |
+
|
| 179 |
+
optim._functional.adamw(params_with_grad,
|
| 180 |
+
grads,
|
| 181 |
+
exp_avgs,
|
| 182 |
+
exp_avg_sqs,
|
| 183 |
+
max_exp_avg_sqs,
|
| 184 |
+
state_steps,
|
| 185 |
+
amsgrad=amsgrad,
|
| 186 |
+
beta1=beta1,
|
| 187 |
+
beta2=beta2,
|
| 188 |
+
lr=group['lr'],
|
| 189 |
+
weight_decay=group['weight_decay'],
|
| 190 |
+
eps=group['eps'],
|
| 191 |
+
maximize=False)
|
| 192 |
+
|
| 193 |
+
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
| 194 |
+
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
| 195 |
+
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
| 196 |
+
|
| 197 |
+
return loss
|
ldm/models/diffusion/ddim.py
CHANGED
|
@@ -194,9 +194,19 @@ class DDIMSampler(object):
|
|
| 194 |
c_in = dict()
|
| 195 |
for k in c:
|
| 196 |
if isinstance(c[k], list):
|
| 197 |
-
c_in[k] = [
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
else:
|
| 201 |
c_in[k] = torch.cat([
|
| 202 |
unconditional_conditioning[k],
|
|
@@ -333,4 +343,5 @@ class DDIMSampler(object):
|
|
| 333 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 334 |
unconditional_conditioning=unconditional_conditioning)
|
| 335 |
if callback: callback(i)
|
| 336 |
-
return x_dec
|
|
|
|
|
|
| 194 |
c_in = dict()
|
| 195 |
for k in c:
|
| 196 |
if isinstance(c[k], list):
|
| 197 |
+
c_in[k] = []
|
| 198 |
+
if isinstance(c[k][0], list):
|
| 199 |
+
for i in range(len(c[k])):
|
| 200 |
+
c_ = []
|
| 201 |
+
for j in range(len(c[k][i])):
|
| 202 |
+
c_.append(torch.cat([
|
| 203 |
+
unconditional_conditioning[k][i][j],
|
| 204 |
+
c[k][i][j]]) )
|
| 205 |
+
c_in[k].append(c_)
|
| 206 |
+
else:
|
| 207 |
+
c_in[k] = [torch.cat([
|
| 208 |
+
unconditional_conditioning[k][i],
|
| 209 |
+
c[k][i]]) for i in range(len(c[k]))]
|
| 210 |
else:
|
| 211 |
c_in[k] = torch.cat([
|
| 212 |
unconditional_conditioning[k],
|
|
|
|
| 343 |
unconditional_guidance_scale=unconditional_guidance_scale,
|
| 344 |
unconditional_conditioning=unconditional_conditioning)
|
| 345 |
if callback: callback(i)
|
| 346 |
+
return x_dec
|
| 347 |
+
|
ldm/modules/attention.py
CHANGED
|
@@ -42,7 +42,7 @@ def init_(tensor):
|
|
| 42 |
dim = tensor.shape[-1]
|
| 43 |
std = 1 / math.sqrt(dim)
|
| 44 |
tensor.uniform_(-std, std)
|
| 45 |
-
return tensor
|
| 46 |
|
| 47 |
|
| 48 |
# feedforward
|
|
@@ -143,7 +143,7 @@ class SpatialSelfAttention(nn.Module):
|
|
| 143 |
|
| 144 |
|
| 145 |
class CrossAttention(nn.Module):
|
| 146 |
-
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0
|
| 147 |
super().__init__()
|
| 148 |
inner_dim = dim_head * heads
|
| 149 |
context_dim = default(context_dim, query_dim)
|
|
@@ -160,7 +160,7 @@ class CrossAttention(nn.Module):
|
|
| 160 |
nn.Dropout(dropout)
|
| 161 |
)
|
| 162 |
|
| 163 |
-
def forward(self, x, context=None, mask=None):
|
| 164 |
h = self.heads
|
| 165 |
|
| 166 |
q = self.to_q(x)
|
|
@@ -194,6 +194,34 @@ class CrossAttention(nn.Module):
|
|
| 194 |
return self.to_out(out)
|
| 195 |
|
| 196 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
class MemoryEfficientCrossAttention(nn.Module):
|
| 198 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
| 199 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
|
@@ -246,17 +274,19 @@ class MemoryEfficientCrossAttention(nn.Module):
|
|
| 246 |
class BasicTransformerBlock(nn.Module):
|
| 247 |
ATTENTION_MODES = {
|
| 248 |
"softmax": CrossAttention, # vanilla attention
|
| 249 |
-
"softmax-xformers": MemoryEfficientCrossAttention
|
|
|
|
| 250 |
}
|
| 251 |
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
| 252 |
-
disable_self_attn=False):
|
| 253 |
super().__init__()
|
| 254 |
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
| 255 |
assert attn_mode in self.ATTENTION_MODES
|
| 256 |
attn_cls = self.ATTENTION_MODES[attn_mode]
|
|
|
|
| 257 |
self.disable_self_attn = disable_self_attn
|
| 258 |
-
self.attn1 =
|
| 259 |
-
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
| 260 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
| 261 |
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
| 262 |
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
|
@@ -265,11 +295,17 @@ class BasicTransformerBlock(nn.Module):
|
|
| 265 |
self.norm3 = nn.LayerNorm(dim)
|
| 266 |
self.checkpoint = checkpoint
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
-
def _forward(self, x, context=None):
|
| 272 |
-
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None
|
|
|
|
| 273 |
x = self.attn2(self.norm2(x), context=context) + x
|
| 274 |
x = self.ff(self.norm3(x)) + x
|
| 275 |
return x
|
|
@@ -287,7 +323,7 @@ class SpatialTransformer(nn.Module):
|
|
| 287 |
def __init__(self, in_channels, n_heads, d_head,
|
| 288 |
depth=1, dropout=0., context_dim=None,
|
| 289 |
disable_self_attn=False, use_linear=False,
|
| 290 |
-
use_checkpoint=True):
|
| 291 |
super().__init__()
|
| 292 |
if exists(context_dim) and not isinstance(context_dim, list):
|
| 293 |
context_dim = [context_dim]
|
|
@@ -305,7 +341,8 @@ class SpatialTransformer(nn.Module):
|
|
| 305 |
|
| 306 |
self.transformer_blocks = nn.ModuleList(
|
| 307 |
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
| 308 |
-
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint
|
|
|
|
| 309 |
for d in range(depth)]
|
| 310 |
)
|
| 311 |
if not use_linear:
|
|
@@ -318,11 +355,20 @@ class SpatialTransformer(nn.Module):
|
|
| 318 |
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
| 319 |
self.use_linear = use_linear
|
| 320 |
|
| 321 |
-
def forward(self, x, context=None):
|
| 322 |
# note: if no context is given, cross-attention defaults to self-attention
|
| 323 |
if not isinstance(context, list):
|
| 324 |
context = [context]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
b, c, h, w = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
x_in = x
|
| 327 |
x = self.norm(x)
|
| 328 |
if not self.use_linear:
|
|
@@ -331,7 +377,7 @@ class SpatialTransformer(nn.Module):
|
|
| 331 |
if self.use_linear:
|
| 332 |
x = self.proj_in(x)
|
| 333 |
for i, block in enumerate(self.transformer_blocks):
|
| 334 |
-
x = block(x, context=context[i])
|
| 335 |
if self.use_linear:
|
| 336 |
x = self.proj_out(x)
|
| 337 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
|
|
|
| 42 |
dim = tensor.shape[-1]
|
| 43 |
std = 1 / math.sqrt(dim)
|
| 44 |
tensor.uniform_(-std, std)
|
| 45 |
+
return tensor
|
| 46 |
|
| 47 |
|
| 48 |
# feedforward
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
class CrossAttention(nn.Module):
|
| 146 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., **kargs):
|
| 147 |
super().__init__()
|
| 148 |
inner_dim = dim_head * heads
|
| 149 |
context_dim = default(context_dim, query_dim)
|
|
|
|
| 160 |
nn.Dropout(dropout)
|
| 161 |
)
|
| 162 |
|
| 163 |
+
def forward(self, x, context=None, mask=None, **kargs):
|
| 164 |
h = self.heads
|
| 165 |
|
| 166 |
q = self.to_q(x)
|
|
|
|
| 194 |
return self.to_out(out)
|
| 195 |
|
| 196 |
|
| 197 |
+
class MaskGuidedSelfAttention(nn.Module):
|
| 198 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., obj_feat_dim=1024):
|
| 199 |
+
super().__init__()
|
| 200 |
+
#here context dim is for object features coming from image encoder
|
| 201 |
+
inner_dim = dim_head * heads
|
| 202 |
+
self.heads = heads
|
| 203 |
+
|
| 204 |
+
self.obj_feats_map = nn.Linear(obj_feat_dim, inner_dim)
|
| 205 |
+
self.to_v = nn.Linear(inner_dim, inner_dim, bias=False)
|
| 206 |
+
|
| 207 |
+
self.to_out = nn.Sequential(
|
| 208 |
+
nn.Linear(inner_dim, query_dim),
|
| 209 |
+
nn.Dropout(dropout)
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
self.scale = dim_head ** -0.5
|
| 213 |
+
|
| 214 |
+
def forward(self, x, context=None, mask=None, obj_mask=None, obj_feat=None):
|
| 215 |
+
_, _, ht, wd = obj_feat.shape
|
| 216 |
+
obj_feat = rearrange(obj_feat, 'b c h w -> b (h w) c').contiguous()
|
| 217 |
+
obj_feat = self.obj_feats_map(obj_feat)
|
| 218 |
+
v = self.to_v(obj_feat)
|
| 219 |
+
return self.to_out(v)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
|
| 225 |
class MemoryEfficientCrossAttention(nn.Module):
|
| 226 |
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
| 227 |
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
|
|
|
| 274 |
class BasicTransformerBlock(nn.Module):
|
| 275 |
ATTENTION_MODES = {
|
| 276 |
"softmax": CrossAttention, # vanilla attention
|
| 277 |
+
"softmax-xformers": MemoryEfficientCrossAttention,
|
| 278 |
+
"maskguided": MaskGuidedSelfAttention
|
| 279 |
}
|
| 280 |
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
| 281 |
+
disable_self_attn=False, attn1_mode="softmax", obj_feat_dim=1024):
|
| 282 |
super().__init__()
|
| 283 |
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
| 284 |
assert attn_mode in self.ATTENTION_MODES
|
| 285 |
attn_cls = self.ATTENTION_MODES[attn_mode]
|
| 286 |
+
attn1_cls = self.ATTENTION_MODES[attn1_mode]
|
| 287 |
self.disable_self_attn = disable_self_attn
|
| 288 |
+
self.attn1 = attn1_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
| 289 |
+
context_dim=context_dim if self.disable_self_attn else None, obj_feat_dim=obj_feat_dim) # is a self-attention if not self.disable_self_attn
|
| 290 |
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
| 291 |
self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
|
| 292 |
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
|
|
|
| 295 |
self.norm3 = nn.LayerNorm(dim)
|
| 296 |
self.checkpoint = checkpoint
|
| 297 |
|
| 298 |
+
# self.ff_text_obj_feat = FeedForward(context_dim, dim_out=dim, mult=1, dropout=dropout, glu=gated_ff)
|
| 299 |
+
|
| 300 |
+
def forward(self, x, context=None, obj_mask=None, obj_feat=None):
|
| 301 |
+
if obj_mask is None:
|
| 302 |
+
# return self._forward(x, context, obj_mask, obj_feat)
|
| 303 |
+
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
| 304 |
+
return checkpoint(self._forward, (x, context, obj_mask, obj_feat), self.parameters(), self.checkpoint)
|
| 305 |
|
| 306 |
+
def _forward(self, x, context=None, obj_mask=None, obj_feat=None):
|
| 307 |
+
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None,
|
| 308 |
+
obj_mask=obj_mask, obj_feat=obj_feat) + x
|
| 309 |
x = self.attn2(self.norm2(x), context=context) + x
|
| 310 |
x = self.ff(self.norm3(x)) + x
|
| 311 |
return x
|
|
|
|
| 323 |
def __init__(self, in_channels, n_heads, d_head,
|
| 324 |
depth=1, dropout=0., context_dim=None,
|
| 325 |
disable_self_attn=False, use_linear=False,
|
| 326 |
+
use_checkpoint=True,attn1_mode='softmax',obj_feat_dim=None):
|
| 327 |
super().__init__()
|
| 328 |
if exists(context_dim) and not isinstance(context_dim, list):
|
| 329 |
context_dim = [context_dim]
|
|
|
|
| 341 |
|
| 342 |
self.transformer_blocks = nn.ModuleList(
|
| 343 |
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
| 344 |
+
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn1_mode=attn1_mode,
|
| 345 |
+
obj_feat_dim=obj_feat_dim)
|
| 346 |
for d in range(depth)]
|
| 347 |
)
|
| 348 |
if not use_linear:
|
|
|
|
| 355 |
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
| 356 |
self.use_linear = use_linear
|
| 357 |
|
| 358 |
+
def forward(self, x, context=None, obj_masks=None, obj_feats=None):
|
| 359 |
# note: if no context is given, cross-attention defaults to self-attention
|
| 360 |
if not isinstance(context, list):
|
| 361 |
context = [context]
|
| 362 |
+
if not isinstance(obj_masks, list):
|
| 363 |
+
obj_masks = [obj_masks]
|
| 364 |
+
if not isinstance(obj_feats, list):
|
| 365 |
+
obj_feats = [obj_feats]
|
| 366 |
+
|
| 367 |
b, c, h, w = x.shape
|
| 368 |
+
if obj_feats[0] is not None:
|
| 369 |
+
obj_feats = [torch.nn.functional.interpolate(ofe, [h,w]) for ofe in obj_feats]
|
| 370 |
+
obj_masks = [torch.nn.functional.interpolate(om, [h,w]) for om in obj_masks]
|
| 371 |
+
|
| 372 |
x_in = x
|
| 373 |
x = self.norm(x)
|
| 374 |
if not self.use_linear:
|
|
|
|
| 377 |
if self.use_linear:
|
| 378 |
x = self.proj_in(x)
|
| 379 |
for i, block in enumerate(self.transformer_blocks):
|
| 380 |
+
x = block(x, context=context[i], obj_mask=obj_masks[i], obj_feat=obj_feats[i])
|
| 381 |
if self.use_linear:
|
| 382 |
x = self.proj_out(x)
|
| 383 |
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
ldm/modules/diffusionmodules/openaimodel.py
CHANGED
|
@@ -69,19 +69,31 @@ class TimestepBlock(nn.Module):
|
|
| 69 |
Apply the module to `x` given `emb` timestep embeddings.
|
| 70 |
"""
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
-
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
| 74 |
"""
|
| 75 |
A sequential module that passes timestep embeddings to the children that
|
| 76 |
support it as an extra input.
|
| 77 |
"""
|
| 78 |
|
| 79 |
-
def forward(self, x, emb, context=None, *args):
|
| 80 |
for layer in self:
|
| 81 |
if isinstance(layer, TimestepBlock):
|
| 82 |
x = layer(x, emb)
|
| 83 |
elif isinstance(layer, SpatialTransformer):
|
| 84 |
-
x = layer(x, context)
|
|
|
|
|
|
|
| 85 |
else:
|
| 86 |
x = layer(x)
|
| 87 |
return x
|
|
@@ -783,4 +795,4 @@ class UNetModel(nn.Module):
|
|
| 783 |
if self.predict_codebook_ids:
|
| 784 |
return self.id_predictor(h)
|
| 785 |
else:
|
| 786 |
-
return self.out(h)
|
|
|
|
| 69 |
Apply the module to `x` given `emb` timestep embeddings.
|
| 70 |
"""
|
| 71 |
|
| 72 |
+
class TimestepBlockSpa(nn.Module):
|
| 73 |
+
"""
|
| 74 |
+
Any module where forward() takes timestep embeddings as a second argument.
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
@abstractmethod
|
| 78 |
+
def forward(self, x, emb, obj_feat):
|
| 79 |
+
"""
|
| 80 |
+
Apply the module to `x` given `emb` timestep embeddings.
|
| 81 |
+
"""
|
| 82 |
|
| 83 |
+
class TimestepEmbedSequential(nn.Sequential, TimestepBlock, TimestepBlockSpa):
|
| 84 |
"""
|
| 85 |
A sequential module that passes timestep embeddings to the children that
|
| 86 |
support it as an extra input.
|
| 87 |
"""
|
| 88 |
|
| 89 |
+
def forward(self, x, emb, context=None, obj_feat=None,obj_masks=None, *args):
|
| 90 |
for layer in self:
|
| 91 |
if isinstance(layer, TimestepBlock):
|
| 92 |
x = layer(x, emb)
|
| 93 |
elif isinstance(layer, SpatialTransformer):
|
| 94 |
+
x = layer(x, context, obj_masks=obj_masks, obj_feats=obj_feat)
|
| 95 |
+
elif isinstance(layer, TimestepBlockSpa):
|
| 96 |
+
x = layer(x, emb, obj_feat)
|
| 97 |
else:
|
| 98 |
x = layer(x)
|
| 99 |
return x
|
|
|
|
| 795 |
if self.predict_codebook_ids:
|
| 796 |
return self.id_predictor(h)
|
| 797 |
else:
|
| 798 |
+
return self.out(h)
|
ldm/modules/diffusionmodules/util.py
CHANGED
|
@@ -215,9 +215,10 @@ class SiLU(nn.Module):
|
|
| 215 |
|
| 216 |
|
| 217 |
class GroupNorm32(nn.GroupNorm):
|
| 218 |
-
def forward(self, x):
|
| 219 |
return super().forward(x.float()).type(x.dtype)
|
| 220 |
|
|
|
|
| 221 |
def conv_nd(dims, *args, **kwargs):
|
| 222 |
"""
|
| 223 |
Create a 1D, 2D, or 3D convolution module.
|
|
|
|
| 215 |
|
| 216 |
|
| 217 |
class GroupNorm32(nn.GroupNorm):
|
| 218 |
+
def forward(self, x, *args):
|
| 219 |
return super().forward(x.float()).type(x.dtype)
|
| 220 |
|
| 221 |
+
|
| 222 |
def conv_nd(dims, *args, **kwargs):
|
| 223 |
"""
|
| 224 |
Create a 1D, 2D, or 3D convolution module.
|
ldm/modules/encoders/modules.py
CHANGED
|
@@ -114,14 +114,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
|
| 114 |
for param in self.parameters():
|
| 115 |
param.requires_grad = False
|
| 116 |
|
| 117 |
-
def forward(self, text):
|
| 118 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
| 119 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
| 120 |
tokens = batch_encoding["input_ids"].to(self.device)
|
| 121 |
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
| 122 |
-
if
|
| 123 |
z = outputs.last_hidden_state
|
| 124 |
-
elif
|
| 125 |
z = outputs.pooler_output[:, None, :]
|
| 126 |
else:
|
| 127 |
z = outputs.hidden_states[self.layer_idx]
|
|
|
|
| 114 |
for param in self.parameters():
|
| 115 |
param.requires_grad = False
|
| 116 |
|
| 117 |
+
def forward(self, text, layer='last'):
|
| 118 |
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
| 119 |
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
| 120 |
tokens = batch_encoding["input_ids"].to(self.device)
|
| 121 |
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
| 122 |
+
if layer == "last":
|
| 123 |
z = outputs.last_hidden_state
|
| 124 |
+
elif layer == "pooled":
|
| 125 |
z = outputs.pooler_output[:, None, :]
|
| 126 |
else:
|
| 127 |
z = outputs.hidden_states[self.layer_idx]
|
pair_diff_demo.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import einops
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import random
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import datetime
|
| 10 |
+
from huggingface_hub import hf_hub_url, hf_hub_download
|
| 11 |
+
|
| 12 |
+
from pytorch_lightning import seed_everything
|
| 13 |
+
from annotator.util import resize_image, HWC3
|
| 14 |
+
from annotator.OneFormer import OneformerSegmenter
|
| 15 |
+
from cldm.model import create_model, load_state_dict
|
| 16 |
+
from cldm.ddim_hacked import DDIMSamplerSpaCFG
|
| 17 |
+
from ldm.models.autoencoder import DiagonalGaussianDistribution
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
SEGMENT_MODEL_DICT = {
|
| 21 |
+
'Oneformer': OneformerSegmenter,
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
MASK_MODEL_DICT = {
|
| 25 |
+
'Oneformer': OneformerSegmenter,
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
urls = {
|
| 29 |
+
'shi-labs/oneformer_coco_swin_large': ['150_16_swin_l_oneformer_coco_100ep.pth'],
|
| 30 |
+
'PAIR/PAIR-diffusion-sdv15-coco-finetune': ['model_e91.ckpt']
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
WTS_DICT = {
|
| 34 |
+
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
if os.path.exists('checkpoints') == False:
|
| 38 |
+
os.mkdir('checkpoints')
|
| 39 |
+
for repo in urls:
|
| 40 |
+
files = urls[repo]
|
| 41 |
+
for file in files:
|
| 42 |
+
url = hf_hub_url(repo, file)
|
| 43 |
+
name_ckp = url.split('/')[-1]
|
| 44 |
+
|
| 45 |
+
WTS_DICT[repo] = hf_hub_download(repo_id=repo, filename=file)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
#main model
|
| 49 |
+
model = create_model('configs/pair_diff.yaml').cpu()
|
| 50 |
+
model.load_state_dict(load_state_dict(WTS_DICT['PAIR/PAIR-diffusion-sdv15-coco-finetune'], location='cuda'))
|
| 51 |
+
|
| 52 |
+
save_dir = 'results/'
|
| 53 |
+
|
| 54 |
+
model = model.cuda()
|
| 55 |
+
ddim_sampler = DDIMSamplerSpaCFG(model)
|
| 56 |
+
save_memory = False
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ImageComp:
|
| 60 |
+
def __init__(self, edit_operation):
|
| 61 |
+
self.input_img = None
|
| 62 |
+
self.input_pmask = None
|
| 63 |
+
self.input_segmask = None
|
| 64 |
+
self.input_mask = None
|
| 65 |
+
self.input_points = []
|
| 66 |
+
self.input_scale = 1
|
| 67 |
+
|
| 68 |
+
self.ref_img = None
|
| 69 |
+
self.ref_pmask = None
|
| 70 |
+
self.ref_segmask = None
|
| 71 |
+
self.ref_mask = None
|
| 72 |
+
self.ref_points = []
|
| 73 |
+
self.ref_scale = 1
|
| 74 |
+
|
| 75 |
+
self.multi_modal = False
|
| 76 |
+
|
| 77 |
+
self.H = None
|
| 78 |
+
self.W = None
|
| 79 |
+
self.kernel = np.ones((5, 5), np.uint8)
|
| 80 |
+
self.edit_operation = edit_operation
|
| 81 |
+
self.init_segmentation_model()
|
| 82 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 83 |
+
|
| 84 |
+
self.base_prompt = 'A picture of {}'
|
| 85 |
+
|
| 86 |
+
def init_segmentation_model(self, mask_model='Oneformer', segment_model='Oneformer'):
|
| 87 |
+
self.segment_model_name = segment_model
|
| 88 |
+
self.mask_model_name = mask_model
|
| 89 |
+
|
| 90 |
+
self.segment_model = SEGMENT_MODEL_DICT[segment_model](WTS_DICT['shi-labs/oneformer_coco_swin_large'])
|
| 91 |
+
|
| 92 |
+
if mask_model == 'Oneformer' and segment_model == 'Oneformer':
|
| 93 |
+
self.mask_model_inp = self.segment_model
|
| 94 |
+
self.mask_model_ref = self.segment_model
|
| 95 |
+
else:
|
| 96 |
+
self.mask_model_inp = MASK_MODEL_DICT[mask_model]()
|
| 97 |
+
self.mask_model_ref = MASK_MODEL_DICT[mask_model]()
|
| 98 |
+
|
| 99 |
+
print(f"Segmentation Models initialized with {mask_model} as mask and {segment_model} as segment")
|
| 100 |
+
|
| 101 |
+
def init_input_canvas(self, img):
|
| 102 |
+
|
| 103 |
+
img = HWC3(img)
|
| 104 |
+
img = resize_image(img, 512)
|
| 105 |
+
if self.segment_model_name == 'Oneformer':
|
| 106 |
+
detected_seg = self.segment_model(img, 'semantic')
|
| 107 |
+
elif self.segment_model_name == 'SAM':
|
| 108 |
+
raise NotImplementedError
|
| 109 |
+
|
| 110 |
+
if self.mask_model_name == 'Oneformer':
|
| 111 |
+
detected_mask = self.mask_model_inp(img, 'panoptic')[0]
|
| 112 |
+
elif self.mask_model_name == 'SAM':
|
| 113 |
+
detected_mask = self.mask_model_inp(img)
|
| 114 |
+
|
| 115 |
+
self.input_points = []
|
| 116 |
+
self.input_img = img
|
| 117 |
+
self.input_pmask = detected_mask
|
| 118 |
+
self.input_segmask = detected_seg
|
| 119 |
+
self.H = img.shape[0]
|
| 120 |
+
self.W = img.shape[1]
|
| 121 |
+
|
| 122 |
+
return img
|
| 123 |
+
|
| 124 |
+
def init_ref_canvas(self, img):
|
| 125 |
+
|
| 126 |
+
img = HWC3(img)
|
| 127 |
+
img = resize_image(img, 512)
|
| 128 |
+
if self.segment_model_name == 'Oneformer':
|
| 129 |
+
detected_seg = self.segment_model(img, 'semantic')
|
| 130 |
+
elif self.segment_model_name == 'SAM':
|
| 131 |
+
raise NotImplementedError
|
| 132 |
+
|
| 133 |
+
if self.mask_model_name == 'Oneformer':
|
| 134 |
+
detected_mask = self.mask_model_ref(img, 'panoptic')[0]
|
| 135 |
+
elif self.mask_model_name == 'SAM':
|
| 136 |
+
detected_mask = self.mask_model_ref(img)
|
| 137 |
+
|
| 138 |
+
self.ref_points = []
|
| 139 |
+
print("Initialized ref", img.shape)
|
| 140 |
+
self.ref_img = img
|
| 141 |
+
self.ref_pmask = detected_mask
|
| 142 |
+
self.ref_segmask = detected_seg
|
| 143 |
+
|
| 144 |
+
return img
|
| 145 |
+
|
| 146 |
+
def select_input_object(self, evt: gr.SelectData):
|
| 147 |
+
idx = list(np.array(evt.index) * self.input_scale)
|
| 148 |
+
self.input_points.append(idx)
|
| 149 |
+
if self.mask_model_name == 'Oneformer':
|
| 150 |
+
mask = self._get_mask_from_panoptic(np.array(self.input_points), self.input_pmask)
|
| 151 |
+
else:
|
| 152 |
+
mask = self.mask_model_inp(self.input_img, self.input_points)
|
| 153 |
+
|
| 154 |
+
c_ids = self.input_segmask[np.array(self.input_points)[:,1], np.array(self.input_points)[:,0]]
|
| 155 |
+
unique_ids, counts = torch.unique(c_ids, return_counts=True)
|
| 156 |
+
c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
|
| 157 |
+
category = self.segment_model.metadata.stuff_classes[c_id]
|
| 158 |
+
# print(self.segment_model.metadata.stuff_classes)
|
| 159 |
+
|
| 160 |
+
self.input_mask = mask
|
| 161 |
+
mask = mask.cpu().numpy()
|
| 162 |
+
output = mask[:,:,None] * self.input_img + (1 - mask[:,:,None]) * self.input_img * 0.2
|
| 163 |
+
return output.astype(np.uint8), self.base_prompt.format(category)
|
| 164 |
+
|
| 165 |
+
def select_ref_object(self, evt: gr.SelectData):
|
| 166 |
+
idx = list(np.array(evt.index) * self.ref_scale)
|
| 167 |
+
self.ref_points.append(idx)
|
| 168 |
+
if self.mask_model_name == 'Oneformer':
|
| 169 |
+
mask = self._get_mask_from_panoptic(np.array(self.ref_points), self.ref_pmask)
|
| 170 |
+
else:
|
| 171 |
+
mask = self.mask_model_ref(self.ref_img, self.ref_points)
|
| 172 |
+
c_ids = self.ref_segmask[np.array(self.ref_points)[:,1], np.array(self.ref_points)[:,0]]
|
| 173 |
+
unique_ids, counts = torch.unique(c_ids, return_counts=True)
|
| 174 |
+
c_id = int(unique_ids[torch.argmax(counts)].cpu().detach().numpy())
|
| 175 |
+
category = self.segment_model.metadata.stuff_classes[c_id]
|
| 176 |
+
print("Category of reference object is:", category)
|
| 177 |
+
|
| 178 |
+
self.ref_mask = mask
|
| 179 |
+
mask = mask.cpu().numpy()
|
| 180 |
+
output = mask[:,:,None] * self.ref_img + (1 - mask[:,:,None]) * self.ref_img * 0.2
|
| 181 |
+
return output.astype(np.uint8)
|
| 182 |
+
|
| 183 |
+
def clear_points(self):
|
| 184 |
+
self.input_points = []
|
| 185 |
+
self.ref_points = []
|
| 186 |
+
zeros_inp = np.zeros(self.input_img.shape)
|
| 187 |
+
zeros_ref = np.zeros(self.ref_img.shape)
|
| 188 |
+
return zeros_inp, zeros_ref
|
| 189 |
+
|
| 190 |
+
def return_input_img(self):
|
| 191 |
+
return self.input_img
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _get_mask_from_panoptic(self, points, panoptic_mask):
|
| 195 |
+
panoptic_mask_ = panoptic_mask + 1
|
| 196 |
+
ids = panoptic_mask_[points[:,1], points[:,0]]
|
| 197 |
+
unique_ids, counts = torch.unique(ids, return_counts=True)
|
| 198 |
+
mask_id = unique_ids[torch.argmax(counts)]
|
| 199 |
+
final_mask = torch.zeros(panoptic_mask.shape).cuda()
|
| 200 |
+
final_mask[panoptic_mask_ == mask_id] = 1
|
| 201 |
+
|
| 202 |
+
return final_mask
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def _process_mask(self, mask, panoptic_mask, segmask):
|
| 206 |
+
obj_class = mask * (segmask + 1)
|
| 207 |
+
unique_ids, counts = torch.unique(obj_class, return_counts=True)
|
| 208 |
+
obj_class = unique_ids[torch.argmax(counts[1:]) + 1] - 1
|
| 209 |
+
return mask, obj_class
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def _edit_app(self, whole_ref):
|
| 213 |
+
"""
|
| 214 |
+
Manipulates the panoptic mask of input image to change appearance
|
| 215 |
+
"""
|
| 216 |
+
input_pmask = self.input_pmask
|
| 217 |
+
input_segmask = self.input_segmask
|
| 218 |
+
|
| 219 |
+
if whole_ref:
|
| 220 |
+
reference_mask = torch.ones(self.ref_pmask.shape).cuda()
|
| 221 |
+
else:
|
| 222 |
+
reference_mask, _ = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)
|
| 223 |
+
|
| 224 |
+
edit_mask, _ = self._process_mask(self.input_mask, self.input_pmask, self.input_segmask)
|
| 225 |
+
# tmp = cv2.dilate(edit_mask.squeeze().cpu().numpy(), self.kernel, iterations = 2)
|
| 226 |
+
# region_mask = torch.tensor(tmp).cuda()
|
| 227 |
+
region_mask = edit_mask
|
| 228 |
+
ma = torch.max(input_pmask)
|
| 229 |
+
|
| 230 |
+
input_pmask[edit_mask == 1] = ma + 1
|
| 231 |
+
return reference_mask, input_pmask, input_segmask, region_mask, ma
|
| 232 |
+
|
| 233 |
+
def _add_object(self, input_mask, dilation_fac):
|
| 234 |
+
"""
|
| 235 |
+
Manipulates the panooptic mask of input image for adding objects
|
| 236 |
+
Args:
|
| 237 |
+
input_mask (numpy array): Region where new objects needs to be added
|
| 238 |
+
dilation factor (float): Controls edge merging region for adding objects
|
| 239 |
+
|
| 240 |
+
"""
|
| 241 |
+
input_pmask = self.input_pmask
|
| 242 |
+
input_segmask = self.input_segmask
|
| 243 |
+
reference_mask, obj_class = self._process_mask(self.ref_mask, self.ref_pmask, self.ref_segmask)
|
| 244 |
+
|
| 245 |
+
tmp = cv2.dilate(input_mask['mask'][:, :, 0], self.kernel, iterations = int(dilation_fac))
|
| 246 |
+
region = torch.tensor(tmp)
|
| 247 |
+
region_mask = torch.zeros_like(region).cuda()
|
| 248 |
+
region_mask[region > 127] = 1
|
| 249 |
+
|
| 250 |
+
mask_ = torch.tensor(input_mask['mask'][:, :, 0])
|
| 251 |
+
edit_mask = torch.zeros_like(mask_).cuda()
|
| 252 |
+
edit_mask[mask_ > 127] = 1
|
| 253 |
+
ma = torch.max(input_pmask)
|
| 254 |
+
input_pmask[edit_mask == 1] = ma + 1
|
| 255 |
+
print(obj_class)
|
| 256 |
+
input_segmask[edit_mask == 1] = obj_class.long()
|
| 257 |
+
|
| 258 |
+
return reference_mask, input_pmask, input_segmask, region_mask, ma
|
| 259 |
+
|
| 260 |
+
def _edit(self, input_mask, ref_mask, dilation_fac=1, whole_ref=False, inter=1):
|
| 261 |
+
"""
|
| 262 |
+
Entry point for all the appearance editing and add objects operations. The function manipulates the
|
| 263 |
+
appearance vectors and structure based on user input
|
| 264 |
+
Args:
|
| 265 |
+
input mask (numpy array): Region in input image which needs to be edited
|
| 266 |
+
dilation factor (float): Controls edge merging region for adding objects
|
| 267 |
+
whole_ref (bool): Flag for specifying if complete reference image should be used
|
| 268 |
+
inter (float): Interpolation of appearance between the reference appearance and the input appearance.
|
| 269 |
+
"""
|
| 270 |
+
input_img = (self.input_img/127.5 - 1)
|
| 271 |
+
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 272 |
+
|
| 273 |
+
reference_img = (self.ref_img/127.5 - 1)
|
| 274 |
+
reference_img = torch.from_numpy(reference_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 275 |
+
|
| 276 |
+
if self.edit_operation == 'add_obj':
|
| 277 |
+
reference_mask, input_pmask, input_segmask, region_mask, ma = self._add_object(input_mask, dilation_fac)
|
| 278 |
+
elif self.edit_operation == 'edit_app':
|
| 279 |
+
reference_mask, input_pmask, input_segmask, region_mask, ma = self._edit_app(whole_ref)
|
| 280 |
+
|
| 281 |
+
#concat featurees
|
| 282 |
+
input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
|
| 283 |
+
_, mean_feat_inpt_conc, one_hot_inpt_conc, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
|
| 284 |
+
|
| 285 |
+
reference_mask = reference_mask.float().cuda().unsqueeze(0).unsqueeze(1)
|
| 286 |
+
_, mean_feat_ref_conc, _, _ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, reference_img, reference_mask, return_all=True)
|
| 287 |
+
|
| 288 |
+
# if mean_feat_ref.shape[1] > 1:
|
| 289 |
+
if isinstance(mean_feat_inpt_conc, list):
|
| 290 |
+
appearance_conc = []
|
| 291 |
+
for i in range(len(mean_feat_inpt_conc)):
|
| 292 |
+
mean_feat_inpt_conc[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[i][:, ma + 1] + inter*mean_feat_ref_conc[i][:, 1]
|
| 293 |
+
splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc[i], one_hot_inpt_conc)
|
| 294 |
+
splatted_feat_conc = torch.nn.functional.normalize(splatted_feat_conc)
|
| 295 |
+
splatted_feat_conc = torch.nn.functional.interpolate(splatted_feat_conc, (self.H//8, self.W//8))
|
| 296 |
+
appearance_conc.append(splatted_feat_conc)
|
| 297 |
+
appearance_conc = torch.cat(appearance_conc, dim=1)
|
| 298 |
+
else:
|
| 299 |
+
print("manipulating")
|
| 300 |
+
mean_feat_inpt_conc[:, ma + 1] = (1 - inter) * mean_feat_inpt_conc[:, ma + 1] + inter*mean_feat_ref_conc[:, 1]
|
| 301 |
+
|
| 302 |
+
splatted_feat_conc = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_conc, one_hot_inpt_conc)
|
| 303 |
+
appearance_conc = torch.nn.functional.normalize(splatted_feat_conc) #l2 normaliz
|
| 304 |
+
appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H//8, self.W//8))
|
| 305 |
+
|
| 306 |
+
#cross attention features
|
| 307 |
+
_, mean_feat_inpt_ca, one_hot_inpt_ca, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, input_pmask, return_all=True)
|
| 308 |
+
|
| 309 |
+
_, mean_feat_ref_ca, _, _ = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, reference_img, reference_mask, return_all=True)
|
| 310 |
+
|
| 311 |
+
# if mean_feat_ref.shape[1] > 1:
|
| 312 |
+
if isinstance(mean_feat_inpt_ca, list):
|
| 313 |
+
appearance_ca = []
|
| 314 |
+
for i in range(len(mean_feat_inpt_ca)):
|
| 315 |
+
mean_feat_inpt_ca[i][:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[i][:, ma + 1] + inter*mean_feat_ref_ca[i][:, 1]
|
| 316 |
+
splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca[i], one_hot_inpt_ca)
|
| 317 |
+
splatted_feat_ca = torch.nn.functional.normalize(splatted_feat_ca)
|
| 318 |
+
splatted_feat_ca = torch.nn.functional.interpolate(splatted_feat_ca, (self.H//8, self.W//8))
|
| 319 |
+
appearance_ca.append(splatted_feat_ca)
|
| 320 |
+
else:
|
| 321 |
+
print("manipulating")
|
| 322 |
+
mean_feat_inpt_ca[:, ma + 1] = (1 - inter) * mean_feat_inpt_ca[:, ma + 1] + inter*mean_feat_ref_ca[:, 1]
|
| 323 |
+
|
| 324 |
+
splatted_feat_ca = torch.einsum('nmc, nmhw->nchw', mean_feat_inpt_ca, one_hot_inpt_ca)
|
| 325 |
+
appearance_ca = torch.nn.functional.normalize(splatted_feat_ca) #l2 normaliz
|
| 326 |
+
appearance_ca = torch.nn.functional.interpolate(appearance_ca, (self.H//8, self.W//8))
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
|
| 331 |
+
structure = torch.nn.functional.interpolate(input_segmask, (self.H//8, self.W//8))
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
return structure, appearance_conc, appearance_ca, region_mask, input_img
|
| 335 |
+
|
| 336 |
+
def _edit_obj_var(self, input_mask, ignore_structure):
|
| 337 |
+
input_img = (self.input_img/127.5 - 1)
|
| 338 |
+
input_img = torch.from_numpy(input_img.astype(np.float32)).cuda().unsqueeze(0).permute(0,3,1,2)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
input_pmask = self.input_pmask
|
| 342 |
+
input_segmask = self.input_segmask
|
| 343 |
+
|
| 344 |
+
ma = torch.max(input_pmask)
|
| 345 |
+
mask_ = torch.tensor(input_mask['mask'][:, :, 0])
|
| 346 |
+
edit_mask = torch.zeros_like(mask_).cuda()
|
| 347 |
+
edit_mask[mask_ > 127] = 1
|
| 348 |
+
tmp = edit_mask * (input_pmask + ma + 1)
|
| 349 |
+
if ignore_structure:
|
| 350 |
+
tmp = edit_mask
|
| 351 |
+
|
| 352 |
+
input_pmask = tmp * edit_mask + (1 - edit_mask) * input_pmask
|
| 353 |
+
|
| 354 |
+
input_pmask = input_pmask.float().cuda().unsqueeze(0).unsqueeze(1)
|
| 355 |
+
|
| 356 |
+
mask_ca_feat = self.input_pmask.float().cuda().unsqueeze(0).unsqueeze(1) if ignore_structure else input_pmask
|
| 357 |
+
print(torch.unique(mask_ca_feat))
|
| 358 |
+
|
| 359 |
+
appearance_conc,_,_,_ = model.get_appearance(model.appearance_net_conc, model.app_layer_conc, input_img, input_pmask, return_all=True)
|
| 360 |
+
appearance_ca = model.get_appearance(model.appearance_net_ca, model.app_layer_ca, input_img, mask_ca_feat)
|
| 361 |
+
|
| 362 |
+
appearance_conc = torch.nn.functional.interpolate(appearance_conc, (self.H//8, self.W//8))
|
| 363 |
+
appearance_ca = [torch.nn.functional.interpolate(ap, (self.H//8, self.W//8)) for ap in appearance_ca]
|
| 364 |
+
|
| 365 |
+
input_segmask = ((input_segmask+1)/ 127.5 - 1.0).cuda().unsqueeze(0).unsqueeze(1)
|
| 366 |
+
structure = torch.nn.functional.interpolate(input_segmask, (self.H//8, self.W//8))
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
tmp = input_mask['mask'][:, :, 0]
|
| 370 |
+
region = torch.tensor(tmp)
|
| 371 |
+
mask = torch.zeros_like(region).cuda()
|
| 372 |
+
mask[region > 127] = 1
|
| 373 |
+
|
| 374 |
+
return structure, appearance_conc, appearance_ca, mask, input_img
|
| 375 |
+
|
| 376 |
+
def get_caption(self, mask):
|
| 377 |
+
"""
|
| 378 |
+
Generates the captions based on a set template
|
| 379 |
+
Args:
|
| 380 |
+
mask (numpy array): Region of image based on which caption needs to be generated
|
| 381 |
+
"""
|
| 382 |
+
mask = mask['mask'][:, :, 0]
|
| 383 |
+
region = torch.tensor(mask).cuda()
|
| 384 |
+
mask = torch.zeros_like(region)
|
| 385 |
+
mask[region > 127] = 1
|
| 386 |
+
|
| 387 |
+
if torch.sum(mask) == 0:
|
| 388 |
+
return ""
|
| 389 |
+
|
| 390 |
+
c_ids = self.input_segmask * mask
|
| 391 |
+
unique_ids, counts = torch.unique(c_ids, return_counts=True)
|
| 392 |
+
c_id = int(unique_ids[torch.argmax(counts[1:]) + 1].cpu().detach().numpy())
|
| 393 |
+
category = self.segment_model.metadata.stuff_classes[c_id]
|
| 394 |
+
|
| 395 |
+
return self.base_prompt.format(category)
|
| 396 |
+
|
| 397 |
+
def save_result(self, input_mask, prompt, a_prompt, n_prompt,
|
| 398 |
+
ddim_steps, scale_s, scale_f, scale_t, seed, dilation_fac=1,inter=1,
|
| 399 |
+
free_form_obj_var=False, ignore_structure=False):
|
| 400 |
+
"""
|
| 401 |
+
Saves the current results with all the meta data
|
| 402 |
+
"""
|
| 403 |
+
|
| 404 |
+
meta_data = {}
|
| 405 |
+
meta_data['prompt'] = prompt
|
| 406 |
+
meta_data['a_prompt'] = a_prompt
|
| 407 |
+
meta_data['n_prompt'] = n_prompt
|
| 408 |
+
meta_data['seed'] = seed
|
| 409 |
+
meta_data['ddim_steps'] = ddim_steps
|
| 410 |
+
meta_data['scale_s'] = scale_s
|
| 411 |
+
meta_data['scale_f'] = scale_f
|
| 412 |
+
meta_data['scale_t'] = scale_t
|
| 413 |
+
meta_data['inter'] = inter
|
| 414 |
+
meta_data['dilation_fac'] = dilation_fac
|
| 415 |
+
meta_data['edit_operation'] = self.edit_operation
|
| 416 |
+
|
| 417 |
+
uuid = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
| 418 |
+
os.makedirs(f'{save_dir}/{uuid}')
|
| 419 |
+
|
| 420 |
+
with open(f'{save_dir}/{uuid}/meta.json', "w") as outfile:
|
| 421 |
+
json.dump(meta_data, outfile)
|
| 422 |
+
cv2.imwrite(f'{save_dir}/{uuid}/input.png', self.input_img[:,:,::-1])
|
| 423 |
+
cv2.imwrite(f'{save_dir}/{uuid}/ref.png', self.ref_img[:,:,::-1])
|
| 424 |
+
if self.ref_mask is not None:
|
| 425 |
+
cv2.imwrite(f'{save_dir}/{uuid}/ref_mask.png', self.ref_mask.cpu().squeeze().numpy() * 200)
|
| 426 |
+
for i in range(len(self.results)):
|
| 427 |
+
cv2.imwrite(f'{save_dir}/{uuid}/edit{i}.png', self.results[i][:,:,::-1])
|
| 428 |
+
|
| 429 |
+
if self.edit_operation == 'add_obj' or free_form_obj_var:
|
| 430 |
+
cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', input_mask['mask'] * 200)
|
| 431 |
+
else:
|
| 432 |
+
cv2.imwrite(f'{save_dir}/{uuid}/input_mask.png', self.input_mask.cpu().squeeze().numpy() * 200)
|
| 433 |
+
|
| 434 |
+
print("Saved results at", f'{save_dir}/{uuid}')
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def process(self, input_mask, ref_mask, prompt, a_prompt, n_prompt,
|
| 438 |
+
num_samples, ddim_steps, guess_mode, strength,
|
| 439 |
+
scale_s, scale_f, scale_t, seed, eta, dilation_fac=1,masking=True,whole_ref=False,inter=1,
|
| 440 |
+
free_form_obj_var=False, ignore_structure=False):
|
| 441 |
+
|
| 442 |
+
print(prompt)
|
| 443 |
+
if free_form_obj_var:
|
| 444 |
+
print("Free form")
|
| 445 |
+
structure, appearance_conc, appearance_ca, mask, img = self._edit_obj_var(input_mask, ignore_structure)
|
| 446 |
+
else:
|
| 447 |
+
structure, appearance_conc, appearance_ca, mask, img = self._edit(input_mask, ref_mask, dilation_fac=dilation_fac,
|
| 448 |
+
whole_ref=whole_ref, inter=inter)
|
| 449 |
+
|
| 450 |
+
input_pmask = torch.nn.functional.interpolate(self.input_pmask.cuda().unsqueeze(0).unsqueeze(1).float(), (self.H//8, self.W//8))
|
| 451 |
+
input_pmask = input_pmask.to(memory_format=torch.contiguous_format)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
if isinstance(appearance_ca, list):
|
| 455 |
+
null_appearance_ca = [torch.zeros(a.shape).cuda() for a in appearance_ca]
|
| 456 |
+
null_appearance_conc = torch.zeros(appearance_conc.shape).cuda()
|
| 457 |
+
null_structure = torch.zeros(structure.shape).cuda() - 1
|
| 458 |
+
|
| 459 |
+
null_control = [torch.cat([null_structure, napp, input_pmask * 0], dim=1) for napp in null_appearance_ca]
|
| 460 |
+
structure_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in null_appearance_ca]
|
| 461 |
+
full_control = [torch.cat([structure, napp, input_pmask], dim=1) for napp in appearance_ca]
|
| 462 |
+
|
| 463 |
+
null_control.append(torch.cat([null_structure, null_appearance_conc, null_structure * 0], dim=1))
|
| 464 |
+
structure_control.append(torch.cat([structure, null_appearance_conc, null_structure], dim=1))
|
| 465 |
+
full_control.append(torch.cat([structure, appearance_conc, input_pmask], dim=1))
|
| 466 |
+
|
| 467 |
+
null_control = [torch.cat([nc for _ in range(num_samples)], dim=0) for nc in null_control]
|
| 468 |
+
structure_control = [torch.cat([sc for _ in range(num_samples)], dim=0) for sc in structure_control]
|
| 469 |
+
full_control = [torch.cat([fc for _ in range(num_samples)], dim=0) for fc in full_control]
|
| 470 |
+
|
| 471 |
+
#Masking for local edit
|
| 472 |
+
if not masking:
|
| 473 |
+
mask, x0 = None, None
|
| 474 |
+
else:
|
| 475 |
+
x0 = model.encode_first_stage(img)
|
| 476 |
+
x0 = x0.sample() if isinstance(x0, DiagonalGaussianDistribution) else x0 # todo: check if we can set random number
|
| 477 |
+
x0 = x0 * model.scale_factor
|
| 478 |
+
mask = 1 - torch.tensor(mask).unsqueeze(0).unsqueeze(1).cuda()
|
| 479 |
+
mask = torch.nn.functional.interpolate(mask.float(), x0.shape[2:]).float()
|
| 480 |
+
|
| 481 |
+
if seed == -1:
|
| 482 |
+
seed = random.randint(0, 65535)
|
| 483 |
+
seed_everything(seed)
|
| 484 |
+
|
| 485 |
+
scale = [scale_s, scale_f, scale_t]
|
| 486 |
+
print(scale)
|
| 487 |
+
if save_memory:
|
| 488 |
+
model.low_vram_shift(is_diffusing=False)
|
| 489 |
+
|
| 490 |
+
uc_cross = model.get_learned_conditioning([n_prompt] * num_samples)
|
| 491 |
+
c_cross = model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)
|
| 492 |
+
cond = {"c_concat": [null_control], "c_crossattn": [c_cross]}
|
| 493 |
+
un_cond = {"c_concat": None if guess_mode else [null_control], "c_crossattn": [uc_cross]}
|
| 494 |
+
un_cond_struct = {"c_concat": None if guess_mode else [structure_control], "c_crossattn": [uc_cross]}
|
| 495 |
+
un_cond_struct_app = {"c_concat": None if guess_mode else [full_control], "c_crossattn": [uc_cross]}
|
| 496 |
+
|
| 497 |
+
shape = (4, self.H // 8, self.W // 8)
|
| 498 |
+
|
| 499 |
+
if save_memory:
|
| 500 |
+
model.low_vram_shift(is_diffusing=True)
|
| 501 |
+
|
| 502 |
+
model.control_scales = [strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
|
| 503 |
+
samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
|
| 504 |
+
shape, cond, verbose=False, eta=eta,
|
| 505 |
+
unconditional_guidance_scale=scale, mask=mask, x0=x0,
|
| 506 |
+
unconditional_conditioning=[un_cond, un_cond_struct, un_cond_struct_app ])
|
| 507 |
+
|
| 508 |
+
if save_memory:
|
| 509 |
+
model.low_vram_shift(is_diffusing=False)
|
| 510 |
+
|
| 511 |
+
x_samples = (model.decode_first_stage(samples) + 1) * 127.5
|
| 512 |
+
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c')).cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 513 |
+
|
| 514 |
+
results = [x_samples[i] for i in range(num_samples)]
|
| 515 |
+
self.results = results
|
| 516 |
+
return [] + results
|
requirements.txt
CHANGED
|
@@ -9,6 +9,7 @@ omegaconf==2.3.0
|
|
| 9 |
open-clip-torch==2.0.2
|
| 10 |
opencv-contrib-python==4.3.0.36
|
| 11 |
opencv-python-headless==4.7.0.72
|
|
|
|
| 12 |
prettytable==3.6.0
|
| 13 |
pytorch-lightning==1.5.0
|
| 14 |
safetensors==0.2.7
|
|
@@ -44,4 +45,4 @@ diffdist
|
|
| 44 |
gdown
|
| 45 |
huggingface_hub
|
| 46 |
tqdm
|
| 47 |
-
wget
|
|
|
|
| 9 |
open-clip-torch==2.0.2
|
| 10 |
opencv-contrib-python==4.3.0.36
|
| 11 |
opencv-python-headless==4.7.0.72
|
| 12 |
+
pillow==9.4.0
|
| 13 |
prettytable==3.6.0
|
| 14 |
pytorch-lightning==1.5.0
|
| 15 |
safetensors==0.2.7
|
|
|
|
| 45 |
gdown
|
| 46 |
huggingface_hub
|
| 47 |
tqdm
|
| 48 |
+
wget
|