from dataclasses import dataclass, field

import numpy as np
import json
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage import measure
from einops import repeat
from tqdm import tqdm
from PIL import Image
from diffusers import (
    DDPMScheduler,
    DDIMScheduler,
    UniPCMultistepScheduler,
    KarrasVeScheduler,
    DPMSolverMultistepScheduler,
)
from diffusers.training_utils import (
    compute_density_for_timestep_sampling,
    compute_loss_weighting_for_sd3,
    free_memory,
)

import step1x3d_geometry
from step1x3d_geometry.systems.base import BaseSystem
from step1x3d_geometry.utils.misc import get_rank
from step1x3d_geometry.utils.typing import *
from step1x3d_geometry.systems.utils import read_image, preprocess_image, flow_sample
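
# Helper for flow-matching training: given the sampled timesteps, look up the
# matching sigma values from the scheduler's (timesteps -> sigmas) table, then
# broadcast them to `n_dim` dimensions so they can multiply latent tensors
# elementwise (e.g. n_dim=3 yields shape [B, 1, 1] for [B, n_token, dim] latents).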
def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=timesteps.device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(timesteps.device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)
    return sigma


class RectifiedFlowSystem(BaseSystem):
    @dataclass
    class Config(BaseSystem.Config):
        skip_validation: bool = True
        val_samples_json: str = ""
        bounds: float = 1.05
        mc_level: float = 0.0
        octree_resolution: int = 256

        # diffusion config
        guidance_scale: float = 7.5
        num_inference_steps: int = 30
        eta: float = 0.0
        snr_gamma: float = 5.0

        # flow
        weighting_scheme: str = "logit_normal"
        logit_mean: float = 0.0
        logit_std: float = 1.0
        mode_scale: float = 1.29
        precondition_outputs: bool = True
        precondition_t: int = 1000

        # shape vae model
        shape_model_type: Optional[str] = None
        shape_model: dict = field(default_factory=dict)

        # condition model
        visual_condition_type: Optional[str] = None
        visual_condition: dict = field(default_factory=dict)
        caption_condition_type: Optional[str] = None
        caption_condition: dict = field(default_factory=dict)
        label_condition_type: Optional[str] = None
        label_condition: dict = field(default_factory=dict)

        # diffusion model
        denoiser_model_type: Optional[str] = None
        denoiser_model: dict = field(default_factory=dict)

        # noise scheduler
        noise_scheduler_type: Optional[str] = None
        noise_scheduler: dict = field(default_factory=dict)

        # denoise scheduler
        denoise_scheduler_type: Optional[str] = None
        denoise_scheduler: dict = field(default_factory=dict)

        # lora
        use_lora: bool = False
        lora_layers: Optional[str] = None
        rank: int = 128  # the dimension of the LoRA update matrices
        alpha: int = 128  # LoRA scaling factor

    cfg: Config
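
    # A hypothetical sketch of how these fields might be populated in a YAML
    # config (the type strings and nested keys below are illustrative, not the
    # project's actual defaults):
    #   noise_scheduler_type: "diffusers.FlowMatchEulerDiscreteScheduler"
    #   noise_scheduler: { num_train_timesteps: 1000 }
    #   denoise_scheduler_type: "diffusers.FlowMatchEulerDiscreteScheduler"
    #   denoiser_model_type: "..."   # resolved via step1x3d_geometry.find(...)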

    def configure(self):
        super().configure()

        self.shape_model = step1x3d_geometry.find(self.cfg.shape_model_type)(
            self.cfg.shape_model
        )
        self.shape_model.eval()
        self.shape_model.requires_grad_(False)

        if self.cfg.visual_condition_type is not None:
            self.visual_condition = step1x3d_geometry.find(
                self.cfg.visual_condition_type
            )(self.cfg.visual_condition)
            self.visual_condition.requires_grad_(False)

        if self.cfg.caption_condition_type is not None:
            self.caption_condition = step1x3d_geometry.find(
                self.cfg.caption_condition_type
            )(self.cfg.caption_condition)
            self.caption_condition.requires_grad_(False)

        if self.cfg.label_condition_type is not None:
            self.label_condition = step1x3d_geometry.find(
                self.cfg.label_condition_type
            )(self.cfg.label_condition)

        self.denoiser_model = step1x3d_geometry.find(self.cfg.denoiser_model_type)(
            self.cfg.denoiser_model
        )
        if self.cfg.use_lora:  # we only train the additional adapter LoRA layers
            self.denoiser_model.requires_grad_(False)

        self.noise_scheduler = step1x3d_geometry.find(self.cfg.noise_scheduler_type)(
            **self.cfg.noise_scheduler
        )
        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
        self.denoise_scheduler = step1x3d_geometry.find(
            self.cfg.denoise_scheduler_type
        )(**self.cfg.denoise_scheduler)

        if self.cfg.use_lora:
            from peft import LoraConfig, set_peft_model_state_dict

            if self.cfg.lora_layers is not None:
                self.target_modules = [
                    layer.strip() for layer in self.cfg.lora_layers.split(",")
                ]
            else:
                self.target_modules = [
                    "attn.to_k",
                    "attn.to_q",
                    "attn.to_v",
                    "attn.to_out.0",
                    "attn.add_k_proj",
                    "attn.add_q_proj",
                    "attn.add_v_proj",
                    "attn.to_add_out",
                    "ff.net.0.proj",
                    "ff.net.2",
                    "ff_context.net.0.proj",
                    "ff_context.net.2",
                ]
            self.transformer_lora_config = LoraConfig(
                r=self.cfg.rank,
                lora_alpha=self.cfg.alpha,
                init_lora_weights="gaussian",
                target_modules=self.target_modules,
            )
            self.denoiser_model.dit_model.add_adapter(self.transformer_lora_config)
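            # With use_lora=True, only the low-rank adapter matrices injected
            # into the attention and feed-forward projections above are
            # trainable; the base DiT weights stay frozen via
            # requires_grad_(False), and rank/alpha set the adapter's capacity
            # and scaling.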

    def forward(self, batch: Dict[str, Any], skip_noise=False) -> Dict[str, Any]:
        # 1. encode shape latents
        if "sharp_surface" in batch.keys():
            sharp_surface = batch["sharp_surface"][
                ..., : 3 + self.cfg.shape_model.point_feats
            ]
        else:
            sharp_surface = None
        shape_embeds, latents, _ = self.shape_model.encode(
            batch["surface"][..., : 3 + self.cfg.shape_model.point_feats],
            sample_posterior=True,
            sharp_surface=sharp_surface,
        )

        # 2. obtain the visual condition
        visual_cond = None
        bs, n_images = latents.shape[0], 1  # defaults when no visual condition is configured
        if self.cfg.visual_condition_type is not None:
            assert "image" in batch.keys(), "image is required for the visual condition encoder"
            if "image" in batch and batch["image"].dim() == 5:
                if self.training:
                    bs, n_images = batch["image"].shape[:2]
                    batch["image"] = batch["image"].view(
                        bs * n_images, *batch["image"].shape[-3:]
                    )
                else:
                    batch["image"] = batch["image"][:, 0, ...]
                    n_images = 1
                    bs = batch["image"].shape[0]
                visual_cond = self.visual_condition(batch).to(latents)

                # repeat latents to match the flattened multi-view batch
                latents = latents.unsqueeze(1).repeat(1, n_images, 1, 1)
                latents = latents.view(bs * n_images, *latents.shape[-2:])
            else:
                visual_cond = self.visual_condition(batch).to(latents)
                bs = visual_cond.shape[0]
                n_images = 1

        ## 2.1 text condition if provided
        caption_cond = None
        if self.cfg.caption_condition_type is not None:
            assert "caption" in batch.keys(), "caption is required for caption encoder"
            assert bs == len(
                batch["caption"]
            ), "Batch size must be the same as the caption length."
            caption_cond = (
                self.caption_condition(batch)
                .repeat_interleave(n_images, dim=0)
                .to(latents)
            )

        ## 2.2 label condition if provided
        label_cond = None
        if self.cfg.label_condition_type is not None:
            assert "label" in batch.keys(), "label is required for label encoder"
            assert bs == len(
                batch["label"]
            ), "Batch size must be the same as the label length."
            label_cond = (
                self.label_condition(batch)
                .repeat_interleave(n_images, dim=0)
                .to(latents)
            )

        # 3. sample noise that we'll add to the latents
        noise = torch.randn_like(latents).to(
            latents
        )  # [batch_size, n_token, latent_dim]

        # 4. sample a random timestep for each latent
        u = compute_density_for_timestep_sampling(
            weighting_scheme=self.cfg.weighting_scheme,
            batch_size=bs * n_images,
            logit_mean=self.cfg.logit_mean,
            logit_std=self.cfg.logit_std,
            mode_scale=self.cfg.mode_scale,
        )
        indices = (u * self.cfg.noise_scheduler.num_train_timesteps).long()
        timesteps = self.noise_scheduler_copy.timesteps[indices].to(
            device=latents.device
        )
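        # With the default "logit_normal" weighting scheme (the SD3 recipe),
        # u is drawn so that mid-range noise levels are sampled more often than
        # the extremes, then quantized to an index into the scheduler's
        # timestep table.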

        # 5. add noise
        sigmas = get_sigmas(
            self.noise_scheduler_copy, timesteps, n_dim=3, dtype=latents.dtype
        )
        noisy_z = (1.0 - sigmas) * latents + sigmas * noise
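        # Rectified flow noising: x_t = (1 - sigma) * x_0 + sigma * noise is a
        # straight-line interpolation between data and noise; sigma = 1 gives
        # pure noise, sigma = 0 gives the clean latents.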

        # 6. diffusion model forward
        output = self.denoiser_model(
            noisy_z, timesteps.long(), visual_cond, caption_cond, label_cond
        ).sample

        # 7. compute loss
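        # Along the straight-line path the velocity is noise - x_0, so a
        # velocity prediction can be converted back to the clean sample via
        # output * (-sigma) + x_t = x_0; with precondition_outputs the loss is
        # therefore taken in data space against the clean latents.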
        if self.cfg.precondition_outputs:
            output = output * (-sigmas) + noisy_z

        # these weighting schemes use a uniform timestep sampling
        # and instead post-weight the loss
        weighting = compute_loss_weighting_for_sd3(
            weighting_scheme=self.cfg.weighting_scheme, sigmas=sigmas
        )

        # flow matching loss
        if self.cfg.precondition_outputs:
            target = latents
        else:
            target = noise - latents

        # compute the regular loss
        loss = torch.mean(
            (weighting.float() * (output.float() - target.float()) ** 2).reshape(
                target.shape[0], -1
            ),
            1,
        )
        loss = loss.mean()

        return {
            "loss_diffusion": loss,
            "latents": latents,
            "x_t": noisy_z,
            "noise": noise,
            "noise_pred": output,  # the x_0 prediction when precondition_outputs is set
            "timesteps": timesteps,
        }

    def training_step(self, batch, batch_idx):
        out = self(batch)

        loss = 0.0
        for name, value in out.items():
            if name.startswith("loss_"):
                self.log(f"train/{name}", value)
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])
            if name.startswith("log_"):
                self.log(f"log/{name.replace('log_', '')}", value.mean())

        for name, value in self.cfg.loss.items():
            if name.startswith("lambda_"):
                self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        if self.cfg.skip_validation:
            return {}
        self.eval()

        if get_rank() == 0:
            with open(self.cfg.val_samples_json) as f:
                sample_inputs = json.load(f)  # condition
            sample_inputs_ = copy.deepcopy(sample_inputs)
            sample_outputs = self.sample(sample_inputs)  # list
            for i, latents in enumerate(sample_outputs["latents"]):
                meshes = self.shape_model.extract_geometry(
                    latents,
                    bounds=self.cfg.bounds,
                    mc_level=self.cfg.mc_level,
                    octree_resolution=self.cfg.octree_resolution,
                    enable_pbar=False,
                )

                for j in range(len(meshes)):
                    name = ""
                    if "image" in sample_inputs_:
                        name += (
                            sample_inputs_["image"][j]
                            .split("/")[-1]
                            .replace(".png", "")
                        )
                    elif "mvimages" in sample_inputs_:
                        name += (
                            sample_inputs_["mvimages"][j][0]
                            .split("/")[-2]
                            .replace(".png", "")
                        )
                    if "caption" in sample_inputs_:
                        name += "_" + sample_inputs_["caption"][j].replace(
                            " ", "_"
                        ).replace(".", "")
                    if "label" in sample_inputs_:
                        name += (
                            "_"
                            + sample_inputs_["label"][j]["symmetry"]
                            + sample_inputs_["label"][j]["edge_type"]
                        )

                    if (
                        meshes[j].verts is not None
                        and meshes[j].verts.shape[0] > 0
                        and meshes[j].faces is not None
                        and meshes[j].faces.shape[0] > 0
                    ):
                        self.save_mesh(
                            f"it{self.true_global_step}/{name}_{i}.obj",
                            meshes[j].verts,
                            meshes[j].faces,
                        )

        torch.cuda.empty_cache()

        out = self(batch)
        if self.global_step == 0:
            latents = self.shape_model.decode(out["latents"])
            meshes = self.shape_model.extract_geometry(
                latents,
                bounds=self.cfg.bounds,
                mc_level=self.cfg.mc_level,
                octree_resolution=self.cfg.octree_resolution,
                enable_pbar=False,
            )
            for i, mesh in enumerate(meshes):
                self.save_mesh(
                    f"it{self.true_global_step}/{batch['uid'][i]}.obj",
                    mesh.verts,
                    mesh.faces,
                )

        return {"val/loss": out["loss_diffusion"]}

    def sample(
        self,
        sample_inputs: Dict[str, Union[torch.FloatTensor, List[str]]],
        sample_times: int = 1,
        steps: Optional[int] = None,
        guidance_scale: Optional[float] = None,
        eta: float = 0.0,
        seed: Optional[int] = None,
        **kwargs,
    ):
        if steps is None:
            steps = self.cfg.num_inference_steps
        if guidance_scale is None:
            guidance_scale = self.cfg.guidance_scale
        do_classifier_free_guidance = guidance_scale != 1.0
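        # Classifier-free guidance: when enabled, unconditional and conditional
        # embeddings are concatenated along the batch dimension below, so the
        # denoiser evaluates both branches in one forward pass;
        # guidance_scale == 1.0 disables CFG entirely.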

        # conditional encode
        visual_cond = None
        if "image" in sample_inputs:
            sample_inputs["image"] = [
                Image.open(img) if isinstance(img, str) else img
                for img in sample_inputs["image"]
            ]
            sample_inputs["image"] = preprocess_image(sample_inputs["image"], **kwargs)
            cond = self.visual_condition.encode_image(sample_inputs["image"])
            if do_classifier_free_guidance:
                un_cond = self.visual_condition.empty_image_embeds.repeat(
                    len(sample_inputs["image"]), 1, 1
                ).to(cond)
                visual_cond = torch.cat([un_cond, cond], dim=0)
            else:
                # without CFG, still pass the conditional embedding
                visual_cond = cond

        caption_cond = None
        if "caption" in sample_inputs:
            # assumption: the caption encoder exposes encode_caption, parallel
            # to encode_image/encode_label on the other condition modules
            cond = self.caption_condition.encode_caption(sample_inputs["caption"])
            if do_classifier_free_guidance:
                un_cond = self.caption_condition.empty_caption_embeds.repeat(
                    len(sample_inputs["caption"]), 1, 1
                ).to(cond)
                caption_cond = torch.cat([un_cond, cond], dim=0)
            else:
                caption_cond = cond

        label_cond = None
        if "label" in sample_inputs:
            cond = self.label_condition.encode_label(sample_inputs["label"])
            if do_classifier_free_guidance:
                un_cond = self.label_condition.empty_label_embeds.repeat(
                    len(sample_inputs["label"]), 1, 1
                ).to(cond)
                label_cond = torch.cat([un_cond, cond], dim=0)
            else:
                label_cond = cond

        latents_list = []
        if seed is not None:
            generator = torch.Generator(device="cuda").manual_seed(seed)
        else:
            generator = None

        for _ in range(sample_times):
            sample_loop = flow_sample(
                self.denoise_scheduler,
                self.denoiser_model.eval(),
                shape=self.shape_model.latent_shape,
                visual_cond=visual_cond,
                caption_cond=caption_cond,
                label_cond=label_cond,
                steps=steps,
                guidance_scale=guidance_scale,
                do_classifier_free_guidance=do_classifier_free_guidance,
                device=self.device,
                eta=eta,
                disable_prog=False,
                generator=generator,
            )
            for sample, t in sample_loop:
                latents = sample
            latents_list.append(self.shape_model.decode(latents))

        return {"latents": latents_list, "inputs": sample_inputs}

    def on_validation_epoch_end(self):
        pass

    def test_step(self, batch, batch_idx):
        return
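

if __name__ == "__main__":
    # Minimal self-check of the rectified-flow algebra used in forward()
    # (a sketch: shapes are arbitrary and no project components are built,
    # though the module-level imports above must still resolve).
    x0 = torch.randn(2, 512, 64)  # stand-in for clean shape latents
    eps = torch.randn_like(x0)  # Gaussian noise
    sigma = torch.rand(2, 1, 1)  # per-sample noise level in (0, 1)

    x_t = (1.0 - sigma) * x0 + sigma * eps  # straight-line noising
    v = eps - x0  # velocity target of flow matching

    # preconditioning identity: v * (-sigma) + x_t recovers x0
    assert torch.allclose(v * (-sigma) + x_t, x0, atol=1e-5)
    print("rectified-flow preconditioning identity holds")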