import json
from dataclasses import dataclass
from typing import List, Optional

import torch
from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, DDIMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, BatchFeature

from custum_3d_diffusion.modules import register
from custum_3d_diffusion.trainings.base import BasicTrainer
from custum_3d_diffusion.custum_pipeline.unifield_pipeline_img2mvimg import StableDiffusionImage2MVCustomPipeline


def get_HW(resolution):
    """Normalize a resolution spec (int, JSON string, or [H, W] list) to (H, W)."""
    if isinstance(resolution, str):
        # e.g. "512" or "[320, 512]" coming from a config file
        resolution = json.loads(resolution)
    if isinstance(resolution, int):
        H = W = resolution
    elif isinstance(resolution, list):
        H, W = resolution
    else:
        raise ValueError(f"Unsupported resolution spec: {resolution!r}")
    return H, W
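
# Usage sketch for get_HW (illustrative values):
#   get_HW(512)          -> (512, 512)
#   get_HW("512")        -> (512, 512)
#   get_HW("[320, 512]") -> (320, 512)   # JSON list parsed as [H, W]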


class Image2MVImageTrainer(BasicTrainer):
    """
    Trainer for simple image to multiview images.
    """
    class TrainerConfig(BasicTrainer.TrainerConfig):
        trainer_name: str = "image2mvimage"
        condition_image_column_name: str = "conditioning_image"
        image_column_name: str = "image"
        condition_dropout: float = 0.
        condition_image_resolution: str = "512"
        validation_images: Optional[List[str]] = None
        noise_offset: float = 0.1
        max_loss_drop: float = 0.
        snr_gamma: float = 5.0
        log_distribution: bool = False
        latents_offset: Optional[List[float]] = None
        input_perturbation: float = 0.
        noisy_condition_input: bool = False  # whether to add noise to the reference UNet input
        normal_cls_offset: int = 0
        condition_offset: bool = True
        zero_snr: bool = False
        linear_beta_schedule: bool = False

    cfg: TrainerConfig

    def configure(self) -> None:
        return super().configure()

    def init_shared_modules(self, shared_modules: dict) -> dict:
        # Load the frozen VAE, CLIP image encoder, and feature extractor once
        # and cache them in `shared_modules` so other trainers can reuse them.
        if 'vae' not in shared_modules:
            vae = AutoencoderKL.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.weight_dtype
            )
            vae.requires_grad_(False)
            vae.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['vae'] = vae
        if 'image_encoder' not in shared_modules:
            image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="image_encoder"
            )
            image_encoder.requires_grad_(False)
            image_encoder.to(self.accelerator.device, dtype=self.weight_dtype)
            shared_modules['image_encoder'] = image_encoder
        if 'feature_extractor' not in shared_modules:
            feature_extractor = CLIPImageProcessor.from_pretrained(
                self.cfg.pretrained_model_name_or_path, subfolder="feature_extractor"
            )
            shared_modules['feature_extractor'] = feature_extractor
        return shared_modules
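
    # Usage sketch (hypothetical trainer instances): a fresh dict is filled on
    # the first call and then passed along, so later trainers skip reloading:
    #   shared = trainer.init_shared_modules({})
    #   shared = other_trainer.init_shared_modules(shared)  # reuses cached modules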

    # The following hooks must be provided by concrete subclasses.
    def init_train_dataloader(self, shared_modules: dict) -> torch.utils.data.DataLoader:
        raise NotImplementedError()

    def loss_rescale(self, loss, timesteps=None):
        raise NotImplementedError()

    def forward_step(self, batch, unet, shared_modules, noise_scheduler: DDPMScheduler, global_step) -> torch.Tensor:
        raise NotImplementedError()

    def construct_pipeline(self, shared_modules, unet, old_version=False):
        MyPipeline = StableDiffusionImage2MVCustomPipeline
        pipeline = MyPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            vae=shared_modules['vae'],
            image_encoder=shared_modules['image_encoder'],
            feature_extractor=shared_modules['feature_extractor'],
            unet=unet,
            safety_checker=None,
            torch_dtype=self.weight_dtype,
            latents_offset=self.cfg.latents_offset,
            noisy_cond_latents=self.cfg.noisy_condition_input,
            condition_offset=self.cfg.condition_offset,
        )
        pipeline.set_progress_bar_config(disable=True)
        # Swap in an Euler Ancestral sampler, optionally with zero-terminal-SNR
        # beta rescaling and/or a linear beta schedule.
        scheduler_dict = {}
        if self.cfg.zero_snr:
            scheduler_dict.update(rescale_betas_zero_snr=True)
        if self.cfg.linear_beta_schedule:
            scheduler_dict.update(beta_schedule='linear')
        pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config, **scheduler_dict)
        return pipeline

    def get_forward_args(self):
        if self.cfg.seed is None:
            generator = None
        else:
            generator = torch.Generator(device=self.accelerator.device).manual_seed(self.cfg.seed)
        H, W = get_HW(self.cfg.resolution)
        H_cond, W_cond = get_HW(self.cfg.condition_image_resolution)
        # The output image is a grid of square sub-views: each sub-image spans
        # half the full height, and num_imgs counts how many tiles fit.
        sub_img_H = H // 2
        num_imgs = H // sub_img_H * W // sub_img_H
        forward_args = dict(
            num_images_per_prompt=num_imgs,
            num_inference_steps=50,
            height=sub_img_H,
            width=sub_img_H,
            height_cond=H_cond,
            width_cond=W_cond,
            generator=generator,
        )
        if self.cfg.zero_snr:
            # Zero-terminal-SNR sampling is paired with guidance rescaling to
            # counter over-exposure from classifier-free guidance.
            forward_args.update(guidance_rescale=0.7)
        return forward_args
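
    # Worked example (illustrative): with resolution "1024", get_HW gives
    # H = W = 1024, so sub_img_H = 512 and
    # num_imgs = (1024 // 512) * (1024 // 512) = 4,
    # i.e. a 2x2 multiview grid of 512x512 sub-images.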

    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> StableDiffusionPipelineOutput:
        # Defaults from get_forward_args can be overridden by caller kwargs.
        forward_args = self.get_forward_args()
        forward_args.update(pipeline_call_kwargs)
        return pipeline(**forward_args)
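
    # Usage sketch (hypothetical call; the conditioning-image keyword depends
    # on StableDiffusionImage2MVCustomPipeline's signature):
    #   out = trainer.pipeline_forward(pipeline, image=cond_image)
    #   mv_images = out.images  # generated sub-view images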

    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        raise NotImplementedError()