# Some parts of this file are adapted from the Hugging Face Diffusers library.
import os
import json
import warnings
from typing import Callable, List, Optional, Union, Dict, Any

import PIL.Image
import trimesh
import rembg
import torch
import torchvision.transforms.functional as TF  # used to convert tensor inputs to PIL images
import numpy as np
from huggingface_hub import hf_hub_download
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.loaders import (
    FluxIPAdapterMixin,
    FluxLoraLoaderMixin,
    FromSingleFileMixin,
    TextualInversionLoaderMixin,
)
from .pipeline_utils import (
    TransformerDiffusionMixin,
    preprocess_image,
    retrieve_timesteps,
    remove_floater,
    remove_degenerate_face,
    reduce_face,
    smart_load_model,
)
from transformers import (
    BitImageProcessor,
)

import step1x3d_geometry
from step1x3d_geometry.models.autoencoders.surface_extractors import MeshExtractResult
from step1x3d_geometry.utils.config import ExperimentConfig, load_config

from ..autoencoders.michelangelo_autoencoder import MichelangeloAutoencoder
from ..conditional_encoders.dinov2_encoder import Dinov2Encoder
from ..conditional_encoders.t5_encoder import T5Encoder
from ..conditional_encoders.label_encoder import LabelEncoder
from ..transformers.flux_transformer_1d import FluxDenoiser


class Step1X3DGeometryPipelineOutput(BaseOutput):
    """
    Output class for Step1X-3D geometry pipelines.

    Args:
        image (`PIL.Image.Image`):
            The preprocessed conditioning image.
        mesh (`trimesh.Trimesh`, `MeshExtractResult`, or `np.ndarray`):
            The generated meshes: a list of `trimesh.Trimesh` of length `batch_size`, raw
            `MeshExtractResult` objects, or `(vertices, faces)` NumPy array pairs, depending on
            `output_type`.
    """

    image: PIL.Image.Image
    mesh: Union[trimesh.Trimesh, MeshExtractResult, np.ndarray]


class Step1X3DGeometryPipeline(
    DiffusionPipeline, FromSingleFileMixin, TransformerDiffusionMixin
):
    """
    Step1X-3D geometry pipeline: generates high-quality meshes conditioned on image, caption, and/or label inputs.

    Args:
        scheduler (FlowMatchEulerDiscreteScheduler):
            The diffusion scheduler controlling the denoising process.
        vae (MichelangeloAutoencoder):
            Variational autoencoder for latent-space compression/reconstruction.
        transformer (FluxDenoiser):
            Transformer-based denoising model.
        visual_encoder (Dinov2Encoder):
            Pretrained visual encoder for image feature extraction.
        caption_encoder (T5Encoder):
            Text encoder for processing natural-language captions.
        label_encoder (LabelEncoder):
            Auxiliary text encoder for label conditioning.
        visual_eature_extractor (BitImageProcessor):
            Preprocessor for input images.

    Note:
        - CPU offloading sequence: visual_encoder → caption_encoder → label_encoder → transformer → vae
        - Optional components: visual_encoder, visual_eature_extractor, caption_encoder, label_encoder
    """

    model_cpu_offload_seq = (
        "visual_encoder->caption_encoder->label_encoder->transformer->vae"
    )
    _optional_components = [
        "visual_encoder",
        "visual_eature_extractor",
        "caption_encoder",
        "label_encoder",
    ]

    @classmethod
    def from_pretrained(cls, model_path, subfolder=".", **kwargs):
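        # `smart_load_model` resolves `model_path` (a Hub repo id or a local directory) to a
        # local checkpoint directory, then the standard DiffusionPipeline loading logic is used.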
        local_model_path = smart_load_model(model_path, subfolder)
        return super().from_pretrained(local_model_path, **kwargs)

    def __init__(
        self,
        scheduler: FlowMatchEulerDiscreteScheduler,
        vae: MichelangeloAutoencoder,
        transformer: FluxDenoiser,
        visual_encoder: Dinov2Encoder,
        caption_encoder: T5Encoder,
        label_encoder: LabelEncoder,
        visual_eature_extractor: BitImageProcessor,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            transformer=transformer,
            scheduler=scheduler,
            visual_encoder=visual_encoder,
            caption_encoder=caption_encoder,
            label_encoder=label_encoder,
            visual_eature_extractor=visual_eature_extractor,
        )

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1

    @property
    def num_timesteps(self):
        return self._num_timesteps

    def check_inputs(
        self,
        image,
    ):
        r"""
        Check if the inputs are valid. Raise an error if not.
        """
        if isinstance(image, str):
            assert os.path.isfile(image) or image.startswith(
                "http"
            ), "Input image must be a valid URL or a file path."
        elif not isinstance(image, (torch.Tensor, PIL.Image.Image)):
            raise ValueError(
                "Input image must be a `torch.Tensor`, a `PIL.Image.Image`, or a path/URL string."
            )

    def encode_image(self, image, device, num_meshes_per_prompt):
        dtype = next(self.visual_encoder.parameters()).dtype
        image_embeds = self.visual_encoder.encode_image(image)
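        # duplicate the conditional embeddings once per requested mesh so the batch
        # dimension matches `batch_size * num_meshes_per_prompt`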
        image_embeds = image_embeds.repeat_interleave(num_meshes_per_prompt, dim=0)

        uncond_image_embeds = self.visual_encoder.empty_image_embeds.repeat(
            image_embeds.shape[0], 1, 1
        ).to(image_embeds)
        return image_embeds, uncond_image_embeds

    def encode_caption(self, caption, device, num_meshes_per_prompt):
        dtype = next(self.caption_encoder.parameters()).dtype
        caption_embeds = self.caption_encoder.encode_text([caption])
        caption_embeds = caption_embeds.repeat_interleave(num_meshes_per_prompt, dim=0)

        uncond_caption_embeds = self.caption_encoder.empty_text_embeds.repeat(
            caption_embeds.shape[0], 1, 1
        ).to(caption_embeds)
        return caption_embeds, uncond_caption_embeds

    def encode_label(self, label, device, num_meshes_per_prompt):
        dtype = next(self.label_encoder.parameters()).dtype
        label_embeds = self.label_encoder.encode_label([label])
        label_embeds = label_embeds.repeat_interleave(num_meshes_per_prompt, dim=0)

        uncond_label_embeds = self.label_encoder.empty_label_embeds.repeat(
            label_embeds.shape[0], 1, 1
        ).to(label_embeds)
        return label_embeds, uncond_label_embeds

    def prepare_latents(
        self,
        batch_size,
        num_tokens,
        num_channels_latents,
        dtype,
        device,
        generator,
        latents: Optional[torch.Tensor] = None,
    ):
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (batch_size, num_tokens, num_channels_latents)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents

    @torch.no_grad()
    def __call__(
        self,
        image: Union[torch.FloatTensor, PIL.Image.Image, str],
        label: Optional[str] = None,
        caption: Optional[str] = None,
        num_inference_steps: int = 30,
        timesteps: List[int] = None,
        num_meshes_per_prompt: int = 1,
        guidance_scale: float = 7.5,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        force_remove_background: bool = False,
        background_color: List[int] = [255, 255, 255],
        foreground_ratio: float = 0.95,
        surface_extractor_type: Optional[str] = None,
        bounds: float = 1.05,
        mc_level: float = 0.0,
        octree_resolution: int = 384,
        output_type: str = "trimesh",
        do_remove_floater: bool = True,
        do_remove_degenerate_face: bool = False,
        do_reduce_face: bool = True,
        do_shade_smooth: bool = True,
        max_facenum: int = 200000,
        return_dict: bool = True,
        use_zero_init: Optional[bool] = True,
        zero_steps: Optional[int] = 0,
    ):
| r""" | |
| Function invoked when calling the pipeline for generation. | |
| Args: | |
| image (`torch.FloatTensor` or `PIL.Image.Image` or `str`): | |
| `Image`, or tensor representing an image batch, or path to an image file. The image will be encoded to | |
| its CLIP/DINO-v2 embedding which the DiT will be conditioned on. | |
| label (`str`): | |
| The label of the generated mesh, like {"symmetry": "asymmetry", "edge_type": "smooth"} | |
| num_inference_steps (`int`, *optional*, defaults to 30): | |
| The number of denoising steps. More denoising steps usually lead to a higher quality mesh at the expense | |
| of slower inference. | |
| timesteps (`List[int]`, *optional*): | |
| Custom timesteps to use for the denoising process. If not provided, will use equally spaced timesteps. | |
| num_meshes_per_prompt (`int`, *optional*, defaults to 1): | |
| The number of meshes to generate per input image. | |
| guidance_scale (`float`, *optional*, defaults to 7.5): | |
| Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). | |
| Higher guidance scale encourages generation that closely matches the input image. | |
| generator (`int`, *optional*): | |
| A seed to make the generation deterministic. | |
| latents (`torch.FloatTensor`, *optional*): | |
| Pre-generated noisy latents to use as inputs for mesh generation. | |
| force_remove_background (`bool`, *optional*, defaults to `False`): | |
| Whether to force remove the background from the input image before processing. | |
| background_color (`List[int]`, *optional*, defaults to `[255, 255, 255]`): | |
| RGB color values for the background if it needs to be removed or modified. | |
| foreground_ratio (`float`, *optional*, defaults to 0.95): | |
| Ratio of the image to consider as foreground when processing. | |
| surface_extractor_type (`str`, *optional*, defaults to "mc"): | |
| Type of surface extraction method to use ("mc" for Marching Cubes or other available methods). | |
| bounds (`float`, *optional*, defaults to 1.05): | |
| Bounding box size for the generated mesh. | |
| mc_level (`float`, *optional*, defaults to 0.0): | |
| Iso-surface level value for Marching Cubes extraction. | |
| octree_resolution (`int`, *optional*, defaults to 256): | |
| Resolution of the octree used for mesh generation. | |
| output_type (`str`, *optional*, defaults to "trimesh"): | |
| Type of output mesh format ("trimesh" or other supported formats). | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a `MeshPipelineOutput` instead of a plain tuple. | |
| Returns: | |
| [`MeshPipelineOutput`] or `tuple`: | |
| If `return_dict` is `True`, [`MeshPipelineOutput`] is returned, otherwise a `tuple` is returned where the | |
| first element is a list of generated meshes and the second element is a list of corresponding metadata. | |
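
        Examples:
            A minimal usage sketch (the repository id, subfolder, module path, and file names below are
            illustrative assumptions and may differ from the released checkpoints):

            ```py
            >>> from step1x3d_geometry.models.pipelines.pipeline import Step1X3DGeometryPipeline

            >>> pipe = Step1X3DGeometryPipeline.from_pretrained(
            ...     "stepfun-ai/Step1X-3D", subfolder="Step1X-3D-Geometry-1300m"
            ... ).to("cuda")
            >>> output = pipe("input.png", guidance_scale=7.5, num_inference_steps=30)
            >>> output.mesh[0].export("output.glb")  # `mesh` is a list of trimesh.Trimesh by default
            ```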
| """ | |
        # 0. Check inputs. Raise error if not correct
        self.check_inputs(
            image=image,
        )
        device = self._execution_device
        self._guidance_scale = guidance_scale

        # 1. Define call parameters
        if isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        elif isinstance(image, PIL.Image.Image) or isinstance(image, str):
            batch_size = 1

        # 2. Preprocess input image
        if isinstance(image, torch.Tensor):
            assert image.ndim == 3  # H, W, 3
            image_pil = TF.to_pil_image(image)
        elif isinstance(image, PIL.Image.Image):
            image_pil = image
        elif isinstance(image, str):
            if image.startswith("http"):
                import requests

                image_pil = PIL.Image.open(requests.get(image, stream=True).raw)
            else:
                image_pil = PIL.Image.open(image)
        # remove the background of the input image
        image_pil = preprocess_image(
            image_pil,
            force=force_remove_background,
            background_color=background_color,
            foreground_ratio=foreground_ratio,
        )

        # 3. Encode condition
        image_embeds, negative_image_embeds = self.encode_image(
            image_pil, device, num_meshes_per_prompt
        )
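        # for classifier-free guidance the unconditional embeddings are concatenated first,
        # so `noise_pred.chunk(2)` in the denoising loop yields (uncond, cond) in that order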
        if self.do_classifier_free_guidance and image_embeds is not None:
            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

        # 3.1 Encode label condition
        label_embeds = None
        if self.transformer.cfg.use_label_condition:
            if label is not None:
                label_embeds, negative_label_embeds = self.encode_label(
                    label, device, num_meshes_per_prompt
                )
                if self.do_classifier_free_guidance:
                    label_embeds = torch.cat(
                        [negative_label_embeds, label_embeds], dim=0
                    )
            else:
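                # no label was provided: fall back to the encoder's learned empty-label
                # embedding for both halves of the classifier-free-guidance batch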
                uncond_label_embeds = self.label_encoder.empty_label_embeds.repeat(
                    num_meshes_per_prompt, 1, 1
                ).to(image_embeds)
                if self.do_classifier_free_guidance:
                    label_embeds = torch.cat(
                        [uncond_label_embeds, uncond_label_embeds], dim=0
                    )

        # 3.3 Encode caption condition
        caption_embeds = None
        if self.transformer.cfg.use_caption_condition:
            if caption is not None:
                caption_embeds, negative_caption_embeds = self.encode_caption(
                    caption, device, num_meshes_per_prompt
                )
                if self.do_classifier_free_guidance:
                    caption_embeds = torch.cat(
                        [negative_caption_embeds, caption_embeds], dim=0
                    )
            else:
                uncond_caption_embeds = self.caption_encoder.empty_text_embeds.repeat(
                    num_meshes_per_prompt, 1, 1
                ).to(image_embeds)
                if self.do_classifier_free_guidance:
                    caption_embeds = torch.cat(
                        [uncond_caption_embeds, uncond_caption_embeds], dim=0
                    )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0
        )
        self._num_timesteps = len(timesteps)

        # 5. Prepare latent variables
        num_latents = self.vae.cfg.num_latents
        num_channels_latents = self.transformer.cfg.input_channels
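        # latents have shape (batch_size * num_meshes_per_prompt, num_latents, num_channels_latents),
        # i.e. one sequence of `num_latents` latent tokens per generated mesh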
        latents = self.prepare_latents(
            batch_size * num_meshes_per_prompt,
            num_latents,
            num_channels_latents,
            image_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                noise_pred = self.transformer(
                    latent_model_input,
                    timestep,
                    visual_condition=image_embeds,
                    label_condition=label_embeds,
                    caption_condition=caption_embeds,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_image = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_image - noise_pred_uncond
                    )
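
                    # optionally zero out the model prediction for the first `zero_steps + 1`
                    # guided steps (zero-initialization of classifier-free guidance)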
| if (i <= zero_steps) and use_zero_init: | |
| noise_pred = noise_pred * 0.0 | |
| # compute the previous noisy sample x_t -> x_t-1 | |
| latents_dtype = latents.dtype | |
| latents = self.scheduler.step( | |
| noise_pred, t, latents, return_dict=False | |
| )[0] | |
| if latents.dtype != latents_dtype: | |
| if torch.backends.mps.is_available(): | |
| # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 | |
| latents = latents.to(latents_dtype) | |
| if i == len(timesteps) - 1 or ( | |
| (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 | |
| ): | |
| progress_bar.update() | |
| # 4. Post-processing | |
| if not output_type == "latent": | |
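            # bfloat16 latents are cast to float16 (together with the VAE) before decoding
            # and surface extraction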
            if latents.dtype == torch.bfloat16:
                self.vae.to(torch.float16)
                latents = latents.to(torch.float16)
            mesh = self.vae.extract_geometry(
                self.vae.decode(latents),
                surface_extractor_type=surface_extractor_type,
                bounds=bounds,
                mc_level=mc_level,
                octree_resolution=octree_resolution,
                enable_pbar=False,
            )
            if output_type != "raw":
                mesh_list = []
                for i, cur_mesh in enumerate(mesh):
                    print(f"Generating mesh {i + 1}/{num_meshes_per_prompt}")
                    if output_type == "trimesh":
                        import trimesh

                        cur_mesh = trimesh.Trimesh(
                            vertices=cur_mesh.verts.cpu().numpy(),
                            faces=cur_mesh.faces.cpu().numpy(),
                        )
                        cur_mesh.fix_normals()
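                        # touch the cached `face_normals` / `vertex_normals` properties so
                        # normals are computed before the mesh is post-processed or exported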
                        cur_mesh.face_normals
                        cur_mesh.vertex_normals
                        cur_mesh.visual = trimesh.visual.TextureVisuals(
                            material=trimesh.visual.material.PBRMaterial(
                                baseColorFactor=(255, 255, 255),
                                main_color=(255, 255, 255),
                                metallicFactor=0.05,
                                roughnessFactor=1.0,
                            )
                        )
                        if do_remove_floater:
                            cur_mesh = remove_floater(cur_mesh)
                        if do_remove_degenerate_face:
                            cur_mesh = remove_degenerate_face(cur_mesh)
                        if do_reduce_face and max_facenum > 0:
                            cur_mesh = reduce_face(cur_mesh, max_facenum)
                        if do_shade_smooth:
                            cur_mesh = cur_mesh.smooth_shaded
                        mesh_list.append(cur_mesh)
                    elif output_type == "np":
                        if do_remove_floater:
                            print('remove floater is NOT used when output_type is "np".')
                        if do_remove_degenerate_face:
                            print(
                                'remove degenerate face is NOT used when output_type is "np".'
                            )
                        if do_reduce_face:
                            print('reduce face is NOT used when output_type is "np".')
                        if do_shade_smooth:
                            print('shade smooth is NOT used when output_type is "np".')
                        mesh_list.append(
                            [
                                cur_mesh.verts.cpu().numpy(),
                                cur_mesh.faces.cpu().numpy(),
                            ]
                        )
                mesh = mesh_list
            else:
                if do_remove_floater:
                    print('remove floater is NOT used when output_type is "raw".')
                if do_remove_degenerate_face:
                    print(
                        'remove degenerate face is NOT used when output_type is "raw".'
                    )
                if do_reduce_face:
                    print('reduce face is NOT used when output_type is "raw".')
        else:
            mesh = latents

        if not return_dict:
            return (image_pil, mesh)

        return Step1X3DGeometryPipelineOutput(image=image_pil, mesh=mesh)