import inspect
import os
import tempfile
from typing import Any, Callable, Dict, List, Optional, Union

import PIL.Image
import pymeshlab
import torch
import trimesh
from diffusers.utils import logging

from step1x3d_geometry.models.autoencoders.surface_extractors import MeshExtractResult

logger = logging.get_logger(__name__)

def preprocess_image(
    images_pil: Union[List[PIL.Image.Image], PIL.Image.Image],
    force: bool = False,
    background_color: List[int] = [255, 255, 255],
    foreground_ratio: float = 0.9,
    rembg_backend: str = "bria",
    **rembg_kwargs,
):
    r"""
    Crop the input image(s) and remove the background.

    Args:
        images_pil (`PIL.Image.Image` or `List[PIL.Image.Image]`):
            Input image or list of input images.
        force (`bool`, *optional*, defaults to `False`):
            Whether to force background removal even if the image already has an alpha channel.

    Returns:
        `PIL.Image.Image` or `List[PIL.Image.Image]`: The preprocessed image(s), matching the input type.
    """
    is_single_image = False
    if isinstance(images_pil, PIL.Image.Image):
        images_pil = [images_pil]
        is_single_image = True
    preprocessed_images = []
    for i in range(len(images_pil)):
        image = images_pil[i]
        width, height, size = image.width, image.height, image.size
        do_remove = True
        if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
            # The alpha channel already carries transparency, so use it as the
            # foreground mask instead of running background removal.
            print(
                "alpha channel is not empty, skip background removal and use the alpha channel as mask"
            )
            do_remove = False
        do_remove = do_remove or force
        if do_remove:
            import rembg  # lazy import

            if rembg_backend == "default":
                image = rembg.remove(image, **rembg_kwargs)
            else:
                image = rembg.remove(
                    image,
                    session=rembg.new_session(
                        model_name="bria",
                        providers=[
                            (
                                "CUDAExecutionProvider",
                                {
                                    "device_id": 0,
                                    "arena_extend_strategy": "kSameAsRequested",
                                    "gpu_mem_limit": 6 * 1024 * 1024 * 1024,
                                    "cudnn_conv_algo_search": "HEURISTIC",
                                },
                            ),
                            "CPUExecutionProvider",
                        ],
                    ),
                    **rembg_kwargs,
                )
        # Compute the tight bounding box of the foreground from the alpha channel
        alpha = image.split()[-1]
        bboxs = alpha.getbbox()
        x1, y1, x2, y2 = bboxs
        dy, dx = y2 - y1, x2 - x1
        s = min(height * foreground_ratio / dy, width * foreground_ratio / dx)
        Ht, Wt = int(dy * s), int(dx * s)

        background = PIL.Image.new("RGBA", image.size, (*background_color, 255))
        image = PIL.Image.alpha_composite(background, image)
        image = image.crop(alpha.getbbox())
        alpha = alpha.crop(alpha.getbbox())

        # Calculate the new size after rescaling
        new_size = tuple(int(dim * foreground_ratio) for dim in size)
        # Resize the image while maintaining the aspect ratio
        resized_image = image.resize((Wt, Ht))
        resized_alpha = alpha.resize((Wt, Ht))
        # Create a new image with the original size, filled with the background color
        padded_image = PIL.Image.new("RGB", size, tuple(background_color))
        padded_alpha = PIL.Image.new("L", size, (0))
        paste_position = (
            (width - resized_image.width) // 2,
            (height - resized_image.height) // 2,
        )
        padded_image.paste(resized_image, paste_position)
        padded_alpha.paste(resized_alpha, paste_position)

        # Expand the image to a 1:1 aspect ratio
        width, height = padded_image.size
        if width == height:
            padded_image.putalpha(padded_alpha)
            preprocessed_images.append(padded_image)
            continue
        new_size = (max(width, height), max(width, height))
        new_image = PIL.Image.new("RGB", new_size, tuple(background_color))
        new_alpha = PIL.Image.new("L", new_size, (0))
        paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
        new_image.paste(padded_image, paste_position)
        new_alpha.paste(padded_alpha, paste_position)
        new_image.putalpha(new_alpha)
        preprocessed_images.append(new_image)

    if is_single_image:
        return preprocessed_images[0]
    return preprocessed_images

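
# Illustrative usage sketch for `preprocess_image` (kept as a comment so nothing
# runs on import; the file name "chair.png" is a placeholder, not part of this module):
#
#     image = PIL.Image.open("chair.png")
#     processed = preprocess_image(image, foreground_ratio=0.85)
#     processed.save("chair_preprocessed.png")
#
# Passing a list of images returns a list of preprocessed images instead.
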
def load_mesh(path):
    if path.endswith(".glb"):
        mesh = trimesh.load(path)
    else:
        mesh = pymeshlab.MeshSet()
        mesh.load_new_mesh(path)
    return mesh

def trimesh2pymeshlab(mesh: trimesh.Trimesh):
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file:
        if isinstance(mesh, trimesh.scene.Scene):
            # Merge all geometries of a scene into a single mesh before exporting
            for idx, obj in enumerate(mesh.geometry.values()):
                if idx == 0:
                    temp_mesh = obj
                else:
                    temp_mesh = temp_mesh + obj
            mesh = temp_mesh
        mesh.export(temp_file.name)
        mesh = pymeshlab.MeshSet()
        mesh.load_new_mesh(temp_file.name)
    return mesh

def pymeshlab2trimesh(mesh: pymeshlab.MeshSet):
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file:
        mesh.save_current_mesh(temp_file.name)
        mesh = trimesh.load(temp_file.name)
    # A scene with multiple geometries is flattened into a single mesh
    if isinstance(mesh, trimesh.Scene):
        combined_mesh = trimesh.Trimesh()
        for geom in mesh.geometry.values():
            combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
        mesh = combined_mesh
    return mesh

def import_mesh(mesh):
    mesh_type = type(mesh)
    if isinstance(mesh, str):
        mesh = load_mesh(mesh)
    elif isinstance(mesh, MeshExtractResult):
        # Build the pymeshlab mesh before rebinding `mesh` to the new MeshSet,
        # otherwise the vertex/face tensors of the original result are lost.
        mesh_pymeshlab = pymeshlab.Mesh(
            vertex_matrix=mesh.verts.cpu().numpy(), face_matrix=mesh.faces.cpu().numpy()
        )
        mesh = pymeshlab.MeshSet()
        mesh.add_mesh(mesh_pymeshlab, "converted_mesh")

    if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
        mesh = trimesh2pymeshlab(mesh)

    return mesh, mesh_type

def remove_floater(mesh):
    mesh, mesh_type = import_mesh(mesh)
    mesh.apply_filter(
        "compute_selection_by_small_disconnected_components_per_face", nbfaceratio=0.001
    )
    mesh.apply_filter("compute_selection_transfer_face_to_vertex", inclusive=False)
    mesh.apply_filter("meshing_remove_selected_vertices_and_faces")
    return pymeshlab2trimesh(mesh)

def remove_degenerate_face(mesh):
    mesh, mesh_type = import_mesh(mesh)
    with tempfile.NamedTemporaryFile(suffix=".ply", delete=False) as temp_file:
        mesh.save_current_mesh(temp_file.name)
        mesh = pymeshlab.MeshSet()
        mesh.load_new_mesh(temp_file.name)
    return pymeshlab2trimesh(mesh)

def reduce_face(mesh, max_facenum=50000):
    mesh, mesh_type = import_mesh(mesh)
    if max_facenum > mesh.current_mesh().face_number():
        return pymeshlab2trimesh(mesh)

    mesh.apply_filter(
        "meshing_decimation_quadric_edge_collapse",
        targetfacenum=max_facenum,
        qualitythr=1.0,
        preserveboundary=True,
        boundaryweight=3,
        preservenormal=True,
        preservetopology=True,
        autoclean=True,
    )
    return pymeshlab2trimesh(mesh)

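
# Illustrative post-processing sketch (the "raw_mesh.glb" path is a placeholder).
# The three helpers above are typically chained; each accepts a file path, a
# trimesh object, or a MeshExtractResult and returns a trimesh.Trimesh:
#
#     mesh = remove_floater("raw_mesh.glb")
#     mesh = remove_degenerate_face(mesh)
#     mesh = reduce_face(mesh, max_facenum=50000)
#     mesh.export("clean_mesh.glb")
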
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

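
# Illustrative usage sketch (FlowMatchEulerDiscreteScheduler is just one example of a
# diffusers scheduler; any scheduler exposing `set_timesteps` works):
#
#     from diffusers import FlowMatchEulerDiscreteScheduler
#
#     scheduler = FlowMatchEulerDiscreteScheduler()
#     timesteps, num_inference_steps = retrieve_timesteps(
#         scheduler, num_inference_steps=30, device="cuda"
#     )
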
class TransformerDiffusionMixin:
    r"""
    Helper mixin for `DiffusionPipeline`s built around a VAE and a transformer (mainly for DiT).
    """

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()
    def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
        """
        self.fusing_transformer = False
        self.fusing_vae = False

        if transformer:
            self.fusing_transformer = True
            self.transformer.fuse_qkv_projections()

        if vae:
            self.fusing_vae = True
            self.vae.fuse_qkv_projections()
    def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
        """Disable QKV projection fusion if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        Args:
            transformer (`bool`, defaults to `True`): To unfuse the projections in the Transformer.
            vae (`bool`, defaults to `True`): To unfuse the projections in the VAE.
        """
        if transformer:
            if not self.fusing_transformer:
                logger.warning(
                    "The Transformer was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.transformer.unfuse_qkv_projections()
                self.fusing_transformer = False

        if vae:
            if not self.fusing_vae:
                logger.warning(
                    "The VAE was not initially fused for QKV projections. Doing nothing."
                )
            else:
                self.vae.unfuse_qkv_projections()
                self.fusing_vae = False
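
# Illustrative usage sketch (the `MyDiTPipeline` name is a placeholder; any pipeline
# that mixes this class in and exposes `vae` and `transformer` modules works):
#
#     class MyDiTPipeline(DiffusionPipeline, TransformerDiffusionMixin):
#         ...
#
#     pipe = MyDiTPipeline(...)
#     pipe.enable_vae_slicing()     # trade speed for memory during VAE decode
#     pipe.fuse_qkv_projections()   # experimental: fuse QKV in attention modules
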
def try_download(model_id, subfolder):
    try:
        from huggingface_hub import snapshot_download

        path = snapshot_download(
            repo_id=model_id,
            allow_patterns=[f"{subfolder}/*"],
        )
        print(path)
        model_path = os.path.join(path, subfolder)
        return model_path
    except Exception:
        # Re-raise download errors unchanged so callers see the original failure.
        raise

def smart_load_model(model_path, subfolder=""):
    if subfolder == "":
        if os.path.exists(model_path):
            return model_path
        else:
            return try_download(model_path, ".")
    else:
        if os.path.exists(os.path.join(model_path, subfolder)):
            return os.path.join(model_path, subfolder)
        else:
            return try_download(model_path, subfolder)
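
# Illustrative usage sketch ("stepfun-ai/Step1X-3D" and the subfolder name are
# placeholders for whatever repo and layout the caller actually uses):
#
#     # Local checkout: returns the existing path directly.
#     path = smart_load_model("/data/checkpoints/Step1X-3D", subfolder="geometry")
#
#     # Not found locally: falls back to downloading the subfolder from the Hugging Face Hub.
#     path = smart_load_model("stepfun-ai/Step1X-3D", subfolder="geometry")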