# type: ignore
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/pipeline.py
# Inspired by https://github.com/ixarchakos/try-off-anyone/blob/aa3045453013065573a647e4536922bac696b968/src/model/attention.py
import torch
from accelerate import load_checkpoint_in_model
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from huggingface_hub import hf_hub_download
from PIL import Image

class Skip(torch.nn.Module):
    """Attention processor that skips attention entirely and returns the
    hidden states unchanged. Used to disable the UNet's cross-attention
    layers, since this pipeline has no text conditioning."""

    def __init__(self) -> None:
        super().__init__()

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        temb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return hidden_states

def fine_tuned_modules(unet: UNet2DConditionModel) -> torch.nn.ModuleList:
    """Collect the UNet's attention blocks, the only modules covered by the
    fine-tuned TryOffAnyone checkpoint."""
    trainable_modules = torch.nn.ModuleList()
    for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
        if hasattr(blocks, "attentions"):  # mid_block is a single block
            trainable_modules.append(blocks.attentions)
        else:  # down_blocks / up_blocks are lists of blocks
            for block in blocks:
                if hasattr(block, "attentions"):
                    trainable_modules.append(block.attentions)
    return trainable_modules

def skip_cross_attentions(unet: UNet2DConditionModel) -> dict[str, AttnProcessor | Skip]:
    """Keep the self-attention processors ("attn1") and replace every
    cross-attention processor ("attn2") with a no-op Skip module."""
    return {
        name: unet.attn_processors[name] if name.endswith("attn1.processor") else Skip()
        for name in unet.attn_processors
    }

def encode(image: torch.Tensor, vae: AutoencoderKL) -> torch.Tensor:
    """Encode an image tensor in [-1, 1] into scaled VAE latents."""
    image = image.to(memory_format=torch.contiguous_format).float().to(vae.device, dtype=vae.dtype)
    with torch.no_grad():
        return vae.encode(image).latent_dist.sample() * vae.config.scaling_factor

class TryOffAnyone:
    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        concat_dim: int = -2,
    ) -> None:
        # Latents are concatenated along the height axis (dim -2), so the
        # garment is generated on one half of the canvas while the other
        # half holds the conditioning image.
        self.concat_dim = concat_dim
        self.device = device
        self.dtype = dtype

        self.noise_scheduler = DDIMScheduler.from_pretrained(
            pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
            subfolder="scheduler",
        )
        self.vae = AutoencoderKL.from_pretrained(
            pretrained_model_name_or_path="stabilityai/sd-vae-ft-mse",
        ).to(device, dtype=dtype)
        self.unet = UNet2DConditionModel.from_pretrained(
            pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-inpainting",
            subfolder="unet",
            variant="fp16",
        ).to(device, dtype=dtype)

        # Disable cross-attention (there is no text prompt), then load the
        # fine-tuned attention weights on top of the inpainting UNet.
        self.unet.set_attn_processor(skip_cross_attentions(self.unet))
        load_checkpoint_in_model(
            model=fine_tuned_modules(unet=self.unet),
            checkpoint=hf_hub_download(
                repo_id="ixarchakos/tryOffAnyone",
                filename="model.safetensors",
            ),
        )
    @torch.no_grad()  # inference only; avoids accumulating autograd state across steps
    def __call__(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
        inference_steps: int,
        scale: float,
        generator: torch.Generator,
    ) -> list[Image.Image]:
        # Batch the inputs and binarize the mask.
        image = image.unsqueeze(0).to(self.device, dtype=self.dtype)
        mask = (mask.unsqueeze(0) > 0.5).to(self.device, dtype=self.dtype)
        masked_image = image * (mask < 0.5)

        # Encode the masked image and the full image, then lay them out side
        # by side: the masked half is generated, the full-image half serves
        # as the condition. The mask is zeroed on the image half so that
        # half is treated as known.
        masked_latent = encode(masked_image, self.vae)
        image_latent = encode(image, self.vae)
        mask = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
        masked_latent_concat = torch.cat([masked_latent, image_latent], dim=self.concat_dim)
        mask_concat = torch.cat([mask, torch.zeros_like(mask)], dim=self.concat_dim)

        latents = randn_tensor(
            shape=masked_latent_concat.shape,
            generator=generator,
            device=self.device,
            dtype=self.dtype,
        )

        self.noise_scheduler.set_timesteps(inference_steps, device=self.device)
        timesteps = self.noise_scheduler.timesteps

        # Classifier-free guidance: pair the conditional input with an
        # unconditional one whose conditioning half is zeroed out.
        if do_classifier_free_guidance := (scale > 1.0):
            masked_latent_concat = torch.cat(
                [
                    torch.cat([masked_latent, torch.zeros_like(image_latent)], dim=self.concat_dim),
                    masked_latent_concat,
                ]
            )
            mask_concat = torch.cat([mask_concat] * 2)

        extra_step = {"generator": generator, "eta": 1.0}
        for t in timesteps:
            input_latents = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            input_latents = self.noise_scheduler.scale_model_input(input_latents, t)
            # Inpainting UNet input: noisy latents, resized mask, and
            # masked-image latents, concatenated along the channel axis.
            input_latents = torch.cat([input_latents, mask_concat, masked_latent_concat], dim=1)
            noise_pred = self.unet(
                input_latents,
                t.to(self.device),
                encoder_hidden_states=None,
                return_dict=False,
            )[0]
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + scale * (noise_pred_cond - noise_pred_uncond)
            latents = self.noise_scheduler.step(noise_pred, t, latents, **extra_step).prev_sample

        # Keep only the generated half of the canvas, then decode to images.
        latents = latents.split(latents.shape[self.concat_dim] // 2, dim=self.concat_dim)[0]
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.device, dtype=self.dtype)).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = (image * 255).round().astype("uint8")
        return [Image.fromarray(im) for im in image]
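
# A minimal usage sketch, not part of the original file: the file paths,
# step count, and guidance scale below are illustrative assumptions, and
# `preprocess` is the hypothetical helper defined above.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    pipeline = TryOffAnyone(device=device, dtype=dtype)

    image = preprocess(Image.open("person.jpg"))  # CHW, in [-1, 1]
    # Binary garment mask: 1 where the garment should be generated.
    mask = (preprocess(Image.open("mask.png"))[:1] > 0).float()

    results = pipeline(
        image=image,
        mask=mask,
        inference_steps=50,
        scale=2.5,
        generator=torch.Generator(device=device).manual_seed(0),
    )
    results[0].save("garment.png")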