Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Callable, Dict, List, Optional, Union | |
| import gc | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import ( | |
| _resize_with_antialiasing, | |
| StableVideoDiffusionPipeline, | |
| retrieve_timesteps, | |
| ) | |
| from diffusers.utils import logging | |
| from kornia.utils import create_meshgrid | |
| from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
def normalize_point_map(point_map, valid_mask):
    """Rescale a point map so that its masked mean depth becomes ~1.

    Args:
        point_map: tensor of shape (T, H, W, 3) holding XYZ coordinates.
        valid_mask: tensor of shape (T, H, W); truthy entries are trusted pixels.

    Returns:
        ``point_map`` divided by the mean Z over valid pixels. The divisor is
        clamped from below at 1e-3 so near-zero depths cannot blow the result up.
    """
    weights = valid_mask.float()
    depth = point_map[..., 2]
    # Masked mean written as (mean of masked depth) / (fraction of valid pixels).
    masked_depth_mean = (depth * weights).mean()
    valid_fraction = weights.mean() + 1e-8
    scale = (masked_depth_mean / valid_fraction).clip(min=1e-3)
    return point_map / scale
def point_map_xy2intrinsic_map(point_map_xy):
    """Fit a per-map center and focal scale to normalized XY coordinates.

    Models ``point_map_xy`` as approximately ``grid * focal + center``, where
    ``grid`` is the kornia pixel grid in normalized coordinates. The center is
    the spatial mean of the input. The focal scale is the spatial mean of
    ``(xy - center) / grid``.

    Args:
        point_map_xy: tensor of shape (*, H, W, 2); H and W must be even.

    Returns:
        Tensor of shape (*, H, W, 4). The last axis holds
        [center_x, center_y, focal_x, focal_y], broadcast to every pixel.
    """
    rows, cols = point_map_xy.shape[-3], point_map_xy.shape[-2]
    # The even-size asserts and the min-abs check below presumably exist so
    # that no grid coordinate is 0, which keeps the division by the grid finite.
    assert rows % 2 == 0
    assert cols % 2 == 0
    grid = create_meshgrid(
        height=rows,
        width=cols,
        normalized_coordinates=True,
        device=point_map_xy.device,
        dtype=point_map_xy.dtype,
    )[0]  # (H, W, 2)
    assert grid.abs().min() > 1e-4
    grid = grid.expand_as(point_map_xy)

    def spatial_mean(values):
        # (*, H, W, 2) -> mean over W, then over H -> (*, 2)
        return values.mean(dim=-2).mean(dim=-2)

    def to_map(per_map):
        # (*, 2) -> (*, H, W, 2)
        return per_map[..., None, None, :].expand_as(point_map_xy)

    center_map = to_map(spatial_mean(point_map_xy))
    focal_map = to_map(spatial_mean((point_map_xy - center_map) / grid))
    return torch.cat([center_map, focal_map], dim=-1)
def robust_min_max(tensor, quantile=0.99):
    """Return a global (low, high) range that ignores per-frame outliers.

    For every frame ``tensor[i]``, this takes the ``1 - quantile`` and
    ``quantile`` values using ``interpolation='nearest'``. It then returns the
    smallest low and the largest high across frames, as Python floats.

    Args:
        tensor: 3-D tensor of shape (T, H, W).
        quantile: upper quantile in [0, 1]; the lower one is ``1 - quantile``.
    """
    num_frames, _, _ = tensor.shape  # unpacking also enforces a 3-D input
    lower_q = 1 - quantile

    # NOTE: quantiles are taken one frame at a time. torch.quantile may reject
    # very large inputs, which could be why the whole stack is not used at once.
    def frame_quantile(idx, q):
        return torch.quantile(tensor[idx], q=q, interpolation='nearest').item()

    lows = [frame_quantile(idx, lower_q) for idx in range(num_frames)]
    highs = [frame_quantile(idx, quantile) for idx in range(num_frames)]
    return min(lows), max(highs)
class GeometryCrafterDiffPipeline(StableVideoDiffusionPipeline):
    """Video-to-point-map diffusion pipeline built on Stable Video Diffusion.

    The pipeline runs in four stages:

    1. ``prior_model`` produces per-frame point maps and validity masks.
    2. Those priors are encoded into latents with ``point_map_vae``.
    3. The priors, the VAE latents and the image embeddings of the input frames
       condition the UNet.
    4. Long clips are denoised in temporal windows. Consecutive windows share
       ``overlap`` frames, which are cross-faded with a linear ramp.
    """

    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        Compute image embeddings for every frame, ``chunk_size`` frames at a time.

        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: image_embeddings in shape of [b, 1024]
        """
        # Resize to 224x224 and map [-1, 1] -> [0, 1]. The feature extractor is
        # then told only to normalize (no resize / crop / rescale).
        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0
        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            emb = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(emb).image_embeds)  # [chunk, 1024]
        embeddings = torch.cat(embeddings, dim=0)  # [b, 1024]
        return embeddings

    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ):
        """
        Encode frames with ``self.vae``, using the mode of the latent distribution.

        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: vae latents in shape of [b, c, h, w]
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents

    def produce_priors(self, prior_model, frame, chunk_size=8):
        """Run the per-frame prior model and turn its output into conditioning maps.

        Args:
            prior_model: object exposing ``forward_image(frames) -> (point_map, mask)``
                with point maps shaped (N, H, W, 3) and masks shaped (N, H, W).
            frame: (T, C, H, W) frames; ``__call__`` passes them in [0, 1].
            chunk_size: number of frames per ``forward_image`` call.

        Returns:
            pred_disps: (T, H, W) disparity. Invalid pixels are zeroed, then the
                whole map is rescaled by a robust (quantile) range to [-1, 1].
            pred_masks: (T, H, W) mask mapped through ``m * 2 - 1``. This gives
                {-1, +1} if the prior returns a boolean mask.
            pred_point_maps: (T, 3, H, W) channels [x/z, y/z, log z]. log z is
                set to 0 at invalid pixels.
            pred_intr_maps: (T, 4, H, W) per-pixel [cx, cy, fx, fy] fitted from
                the x/z, y/z channels.
        """
        pred_point_maps = []
        pred_masks = []
        for i in range(0, len(frame), chunk_size):
            pred_p, pred_m = prior_model.forward_image(frame[i:i + chunk_size])
            pred_point_maps.append(pred_p)
            pred_masks.append(pred_m)
        pred_point_maps = torch.cat(pred_point_maps, dim=0)  # T,H,W,3
        pred_masks = torch.cat(pred_masks, dim=0)  # T,H,W
        pred_masks = pred_masks.float() * 2 - 1
        valid = pred_masks > 0

        pred_point_maps = normalize_point_map(pred_point_maps, valid)

        # Disparity of valid pixels, rescaled by a robust range to [-1, 1].
        pred_disps = 1.0 / pred_point_maps[..., 2].clamp_min(1e-3)
        pred_disps = pred_disps * valid
        min_disparity, max_disparity = robust_min_max(pred_disps)
        pred_disps = ((pred_disps - min_disparity) / (max_disparity - min_disparity + 1e-4)).clamp(0, 1)
        pred_disps = pred_disps * 2 - 1

        # [x, y, z] -> [x/z, y/z, log z]. The xy update must run before z is
        # overwritten. torch.where (rather than multiplying by the mask) stops
        # NaN/inf from log of bad depths at invalid pixels leaking through,
        # because NaN * 0 is still NaN.
        pred_point_maps[..., :2] = pred_point_maps[..., :2] / (pred_point_maps[..., 2:3] + 1e-7)
        log_depth = torch.log(pred_point_maps[..., 2] + 1e-7)
        pred_point_maps[..., 2] = torch.where(valid, log_depth, torch.zeros_like(log_depth))

        # T,H,W,4 -> T,4,H,W
        pred_intr_maps = point_map_xy2intrinsic_map(pred_point_maps[..., :2]).permute(0, 3, 1, 2)
        pred_point_maps = pred_point_maps.permute(0, 3, 1, 2)  # T,3,H,W
        return pred_disps, pred_masks, pred_point_maps, pred_intr_maps

    def encode_point_map(self, point_map_vae, disparity, valid_mask, point_map, intrinsic_map, chunk_size=8):
        """Encode the prior maps into latents scaled by ``vae.config.scaling_factor``.

        Encoding happens in two steps:

        1. The disparity, repeated to 3 channels, goes through ``self.vae.encode``.
        2. ``point_map_vae.encode`` receives the stacked [focal-norm, log-depth,
           disparity, mask] channels together with the latent distribution from
           step 1.

        Args:
            point_map_vae: model whose ``encode(x, latent_dist)`` returns a
                ``DiagonalGaussianDistribution`` or a tensor.
            disparity: (T, H, W) tensor.
            valid_mask: (T, H, W) tensor.
            point_map: (T, 3, H, W) tensor; channel 2 is used as depth.
            intrinsic_map: (T, 4, H, W) tensor; channels 2:4 are reduced to their L2 norm.
            chunk_size: frames per encoder call.

        Returns:
            Concatenated latents for all T frames, scaled by the VAE scaling factor.
        """
        num_frames = point_map.shape[0]
        pseudo_image = disparity[:, None].repeat(1, 3, 1, 1)
        # Collapse the (fx, fy) channels into one focal magnitude per pixel.
        focal_norm = torch.norm(intrinsic_map[:, 2:4], p=2, dim=1, keepdim=False)
        latents = []
        for start in range(0, num_frames, chunk_size):
            stop = start + chunk_size
            image_dist = self.vae.encode(pseudo_image[start:stop].to(self.vae.dtype)).latent_dist
            encoded = point_map_vae.encode(
                torch.cat([
                    focal_norm[start:stop, None],
                    point_map[start:stop, 2:3],
                    disparity[start:stop, None],
                    valid_mask[start:stop, None]], dim=1),
                image_dist
            )
            if isinstance(encoded, DiagonalGaussianDistribution):
                latent = encoded.mode()
            else:
                latent = encoded
            assert isinstance(latent, torch.Tensor)
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        return latents * self.vae.config.scaling_factor

    def decode_point_map(self, point_map_vae, latents, chunk_size=8, force_projection=True, force_fixed_focal=True, use_extract_interp=False, need_resize=False, height=None, width=None):
        """Decode (unscaled) latents back into 3D point maps.

        Args:
            point_map_vae: model whose ``decode(lat, num_frames=...)`` returns
                (focal-norm map, log-depth map, validity map).
            latents: (T, C, h, w) latents, already divided by the scaling factor.
            chunk_size: frames per decoder call.
            force_projection: replace the per-pixel focal map with a masked mean.
            force_fixed_focal: with ``force_projection``, use one focal pair for
                the whole clip rather than one per frame.
            use_extract_interp: resize depth/mask with ``nearest-exact`` instead
                of bilinear interpolation.
            need_resize, height, width: if ``need_resize``, resize the decoded
                maps to ``(height, width)``.

        Returns:
            rec_point_maps: (T, H, W, 3) XYZ, where Z = exp(min(log_depth, 10)).
            rec_valid_masks: (T, H, W) boolean mask.
        """
        num_frames = latents.shape[0]
        rec_intrinsic_maps = []
        rec_depth_maps = []
        rec_valid_masks = []
        for i in range(0, num_frames, chunk_size):
            lat = latents[i:i + chunk_size]
            rec_imap, rec_dmap, rec_vmask = point_map_vae.decode(
                lat,
                num_frames=lat.shape[0],
            )
            rec_intrinsic_maps.append(rec_imap)
            rec_depth_maps.append(rec_dmap)
            rec_valid_masks.append(rec_vmask)
        rec_intrinsic_maps = torch.cat(rec_intrinsic_maps, dim=0)
        rec_depth_maps = torch.cat(rec_depth_maps, dim=0)
        rec_valid_masks = torch.cat(rec_valid_masks, dim=0)

        if need_resize:
            if use_extract_interp:
                rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='nearest-exact')
                rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='nearest-exact')
            else:
                rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='bilinear', align_corners=False)
                rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='bilinear', align_corners=False)
            rec_intrinsic_maps = F.interpolate(rec_intrinsic_maps, (height, width), mode='bilinear', align_corners=False)

        H, W = rec_intrinsic_maps.shape[-2], rec_intrinsic_maps.shape[-1]
        mesh_grid = create_meshgrid(
            H, W,
            normalized_coordinates=True
        ).to(rec_intrinsic_maps.device, rec_intrinsic_maps.dtype, non_blocking=True)
        mesh_grid = mesh_grid.permute(0, 3, 1, 2)  # 1,2,H,W

        # Split the decoded focal magnitude back into (fx, fy) in proportion
        # to the image width / height over the diagonal.
        diagonal = np.sqrt(W ** 2 + H ** 2)
        rec_intrinsic_maps = torch.cat(
            [rec_intrinsic_maps * W / diagonal, rec_intrinsic_maps * H / diagonal], dim=1
        )  # t,2,h,w
        rec_valid_masks = rec_valid_masks.squeeze(1) > 0

        if force_projection:
            valid = rec_valid_masks.float()
            if force_fixed_focal:
                # One focal pair for the whole clip: masked mean over all frames and pixels.
                valid_ratio = valid.mean() + 1e-4
                nfx = (rec_intrinsic_maps[:, 0, :, :] * valid).mean() / valid_ratio
                nfy = (rec_intrinsic_maps[:, 1, :, :] * valid).mean() / valid_ratio
                # torch.stack keeps device/dtype and avoids moving the scalars through the host.
                rec_intrinsic_maps = torch.stack([nfx, nfy])[None, :, None, None].repeat(num_frames, 1, 1, 1)
            else:
                # One focal pair per frame.
                valid_ratio = valid.mean(dim=[-1, -2]) + 1e-4
                nfx = (rec_intrinsic_maps[:, 0, :, :] * valid).mean(dim=[-1, -2]) / valid_ratio
                nfy = (rec_intrinsic_maps[:, 1, :, :] * valid).mean(dim=[-1, -2]) / valid_ratio
                rec_intrinsic_maps = torch.stack([nfx, nfy], dim=-1)[:, :, None, None]
            # t,2,1,1

        # Channels [fx*u, fy*v, log z] -> t,h,w,3
        rec_point_maps = torch.cat([rec_intrinsic_maps * mesh_grid, rec_depth_maps], dim=1).permute(0, 2, 3, 1)
        xy, z = rec_point_maps.split([2, 1], dim=-1)
        z = torch.clamp_max(z, 10)  # for numerical stability
        z = torch.exp(z)
        rec_point_maps = torch.cat([xy * z, z], dim=-1)
        return rec_point_maps, rec_valid_masks

    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        point_map_vae,
        prior_model,
        height: int = 320,
        width: int = 640,
        num_inference_steps: int = 5,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 14,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        overlap: int = 4,
        force_projection: bool = True,
        force_fixed_focal: bool = True,
        use_extract_interp: bool = False,
        track_time: bool = False,
    ):
        """Predict per-frame point maps for a video.

        Args:
            video: frames in [0, 1], shaped [t, h, w, c] (np.ndarray) or
                [t, c, h, w] (torch.Tensor).
            point_map_vae: model used by ``encode_point_map`` / ``decode_point_map``.
            prior_model: per-frame prior exposing ``forward_image``.
            height, width: working resolution; both must be multiples of 64.
                Falsy values fall back to the input size. Outputs are resized
                back to the input resolution.
            num_inference_steps: scheduler steps per window.
            guidance_scale: classifier-free guidance weight; 1.0 disables it.
            window_size: frames denoised jointly. ``None``, or a value >= the
                clip length, processes the clip in a single window.
            noise_aug_strength: only forwarded to the added time ids. No noise
                is added to the conditioning frames here.
            decode_chunk_size: frames per chunk for every encoder/decoder call
                (default 8).
            generator, latents: forwarded to ``prepare_latents`` for the initial
                noise of one window.
            callback_on_step_end: called as ``callback(self, i, t, kwargs)``
                after each step. If it returns a dict with a "latents" entry,
                that entry replaces the current latents.
            callback_on_step_end_tensor_inputs: names of local variables passed
                to the callback (default ``["latents"]``).
            overlap: frames shared by consecutive windows; 0 <= overlap < window_size.
            force_projection, force_fixed_focal, use_extract_interp: forwarded
                to ``decode_point_map``.
            track_time: print CUDA-event timings per stage instead of running
                ``gc.collect()`` / ``empty_cache()`` between stages.

        Returns:
            point_map: [t, h, w, 3] tensor at the input resolution.
            valid_mask: [t, h, w] boolean tensor.

        Raises:
            ValueError: if the video has no frames or ``overlap`` is outside
                [0, window_size).
        """
        if callback_on_step_end_tensor_inputs is None:
            callback_on_step_end_tensor_inputs = ["latents"]

        # 0. Bring the input to [t, c, h, w] and resolve sizes / window layout.
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)
        height = height or video.shape[-2]
        width = width or video.shape[-1]
        original_height = video.shape[-2]
        original_width = video.shape[-1]
        num_frames = video.shape[0]
        if num_frames == 0:
            raise ValueError("`video` must contain at least one frame")
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if window_size is None or num_frames <= window_size:
            # The whole clip fits in one window: nothing to blend.
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap

        # 1. Check inputs. Raise error if not correct
        assert height % 64 == 0 and width % 64 == 0, (
            f"`height` and `width` must be multiples of 64, got {height}x{width}"
        )
        if overlap < 0 or stride <= 0:
            # A non-positive stride would never advance the window loop below
            # (or would leave nothing to denoise).
            raise ValueError(
                f"`overlap` must satisfy 0 <= overlap < window_size, "
                f"got overlap={overlap}, window_size={window_size}"
            )
        need_resize = original_height != height or original_width != width

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # `guidance_scale` is the guidance weight `w` of equation (2) of the
        # Imagen paper (https://arxiv.org/pdf/2205.11487.pdf); 1 means no CFG.
        self._guidance_scale = guidance_scale

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            prior_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        # 3. Per-frame priors, computed on the [0, 1] frames.
        pred_disparity, pred_valid_mask, pred_point_map, pred_intrinsic_map = self.produce_priors(
            prior_model,
            video.to(device=device, dtype=torch.float32),
            chunk_size=decode_chunk_size
        )  # T,H,W  T,H,W  T,3,H,W  T,4,H,W
        if need_resize:
            pred_disparity = F.interpolate(pred_disparity.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_valid_mask = F.interpolate(pred_valid_mask.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_point_map = F.interpolate(pred_point_map, (height, width), mode='bilinear', align_corners=False)
            pred_intrinsic_map = F.interpolate(pred_intrinsic_map, (height, width), mode='bilinear', align_corners=False)
        if track_time:
            prior_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(prior_event)
            print(f"Elapsed time for computing per-frame prior: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # 4. Image embeddings of the frames and latents of the priors.
        if need_resize:
            video = F.interpolate(video, (height, width), mode="bicubic", align_corners=False, antialias=True).clamp(0, 1)
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0  # [0,1] -> [-1,1], in [t, c, h, w]
        video_embeddings = self.encode_video(video, chunk_size=decode_chunk_size).unsqueeze(0)
        prior_latents = self.encode_point_map(
            point_map_vae,
            pred_disparity,
            pred_valid_mask,
            pred_point_map,
            pred_intrinsic_map,
            chunk_size=decode_chunk_size
        ).unsqueeze(0).to(video_embeddings.dtype)  # 1,T,C,H,W

        # 5. VAE latents of the frames; temporarily upcast an fp16 VAE if its config asks for it.
        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)
        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0).to(video_embeddings.dtype)  # [1, t, c, h, w]
        torch.cuda.empty_cache()
        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = prior_event.elapsed_time(encode_event)
            print(f"Elapsed time for encode prior and frames: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()
        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 6. Added time ids: positional fps=7, motion_bucket_id=127,
        # one video per prompt, no CFG duplication.
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )  # [1, 3]
        added_time_ids = added_time_ids.to(device)

        # 7. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 8. Initial noise for one window. The number of denoised latent
        # channels is hard-coded to 8 rather than derived from the UNet config.
        num_channels_latents = 8
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        )  # [1, t, c, h, w]

        # 9. Sliding-window denoising.
        latents_all = None
        idx_start = 0
        if overlap > 0:
            # Linear 0 -> 1 ramp used to cross-fade the overlapping frames.
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)

            latents = latents_init[:, : idx_end - idx_start].clone()
            if overlap > 0:
                # Rotate the noise buffer for the next window. Skipped when
                # overlap == 0: `[:, -0:]` would select the whole tensor and
                # grow the buffer by `stride` frames on every window.
                latents_init = torch.cat(
                    [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
                )

            video_latents_current = video_latents[:, idx_start:idx_end]
            prior_latents_current = prior_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if latents_all is not None and overlap > 0 and i == 0:
                        # Start the overlapping frames from the previous window's
                        # result, plus this window's noise rescaled from
                        # init_noise_sigma to sigmas[0].
                        latents[:, :overlap] = (
                            latents_all[:, -overlap:]
                            + latents[:, :overlap]
                            / self.scheduler.init_noise_sigma
                            * self.scheduler.sigmas[i]
                        )

                    latent_model_input = self.scheduler.scale_model_input(
                        latents, t
                    )  # [1, t, c, h, w]
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current, prior_latents_current], dim=2
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]

                    # perform guidance
                    if self.do_classifier_free_guidance:
                        # Unconditional pass: zero out the conditioning, with
                        # zeros shaped like the real conditioning tensors so
                        # the channel count matches the conditional pass.
                        latent_model_input = self.scheduler.scale_model_input(
                            latents, t
                        )
                        latent_model_input = torch.cat(
                            [
                                latent_model_input,
                                torch.zeros_like(video_latents_current),
                                torch.zeros_like(prior_latents_current),
                            ],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )

                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        # Plain loop on purpose: locals() inside a comprehension
                        # would see the comprehension's scope on Python < 3.12.
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )
                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

            if latents_all is None:
                latents_all = latents.clone()
            else:
                if overlap > 0:
                    # Cross-fade the overlap: the old window fades out, the new one fades in.
                    latents_all[:, -overlap:] = latents[
                        :, :overlap
                    ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)
            idx_start += stride

        latents_all = 1 / self.vae.config.scaling_factor * latents_all.squeeze(0).to(torch.float32)
        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoise latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # 10. Decode back to point maps at the input resolution.
        point_map, valid_mask = self.decode_point_map(
            point_map_vae,
            latents_all,
            chunk_size=decode_chunk_size,
            force_projection=force_projection,
            force_fixed_focal=force_fixed_focal,
            use_extract_interp=use_extract_interp,
            need_resize=need_resize,
            height=original_height,
            width=original_width)
        if track_time:
            decode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = denoise_event.elapsed_time(decode_event)
            print(f"Elapsed time for decode latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()
        self.maybe_free_model_hooks()
        # t,h,w,3 t,h,w
        return point_map, valid_mask