Spaces:
Running
on
L4
Running
on
L4
| import argparse | |
| import os | |
| from contextlib import nullcontext | |
| import torch | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from transparent_background import Remover | |
| from spar3d.models.mesh import QUAD_REMESH_AVAILABLE, TRIANGLE_REMESH_AVAILABLE | |
| from spar3d.system import SPAR3D | |
| from spar3d.utils import foreground_crop, get_device, remove_background | |
def check_positive(value):
    """Parse *value* as a strictly positive integer for use as an argparse ``type``.

    Args:
        value: The raw command-line token (or anything ``int()`` accepts).

    Returns:
        The parsed integer, guaranteed to be greater than zero.

    Raises:
        argparse.ArgumentTypeError: If *value* is not an integer literal or
            is zero or negative.
    """
    message = f"{value} is an invalid positive int value"
    try:
        ivalue = int(value)
    except (TypeError, ValueError):
        # Report non-numeric input with the same message as out-of-range
        # input instead of letting a bare ValueError escape to argparse.
        raise argparse.ArgumentTypeError(message) from None
    if ivalue <= 0:
        raise argparse.ArgumentTypeError(message)
    return ivalue
# Device families this script knows how to run on.
_SUPPORTED_DEVICES = ("cuda", "mps", "cpu")

# File extensions picked up when an input path is a folder.
_IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")


def _device_is_supported(device):
    """Return True if *device* names a supported family.

    Both bare names (``"cuda"``) and indexed names (``"cuda:1"``) are
    accepted; only the part before the first ``:`` is compared.
    """
    family = device.split(":", 1)[0]
    return family in _SUPPORTED_DEVICES


def _expand_image_paths(inputs):
    """Flatten a mix of image files and folders into a list of file paths.

    Folder entries are listed in sorted order and matched case-insensitively
    against ``_IMAGE_EXTENSIONS`` so the index assigned to each image is
    reproducible across runs and platforms. Non-folder entries are passed
    through untouched.
    """
    paths = []
    for entry in inputs:
        if os.path.isdir(entry):
            paths.extend(
                os.path.join(entry, name)
                for name in sorted(os.listdir(entry))
                if name.lower().endswith(_IMAGE_EXTENSIONS)
            )
        else:
            paths.append(entry)
    return paths


def _resolve_vertex_count(reduction_count_type, target_count):
    """Translate the reduction options into the ``vertex_count`` model argument.

    ``"keep"`` maps to ``-1``, ``"vertex"`` uses *target_count* directly and
    any other value (``"faces"``) uses half of *target_count*.
    """
    if reduction_count_type == "keep":
        return -1
    if reduction_count_type == "vertex":
        return target_count
    # Face target: halved to obtain a vertex target (presumably because a
    # triangle mesh has about twice as many faces as vertices).
    return target_count // 2


if __name__ == "__main__":
    default_device = get_device()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "image", type=str, nargs="+", help="Path to input image(s) or folder."
    )
    parser.add_argument(
        "--device",
        default=default_device,
        type=str,
        help=f"Device to use. If no CUDA/MPS-compatible device is found, the baking will fail. Default: '{default_device}'",
    )
    parser.add_argument(
        "--pretrained-model",
        default="stabilityai/stable-point-aware-3d",
        type=str,
        help="Path to the pretrained model. Could be either a Hugging Face model id or a local path. Default: 'stabilityai/stable-point-aware-3d'",
    )
    parser.add_argument(
        "--foreground-ratio",
        default=1.3,
        type=float,
        help="Ratio of the foreground size to the image size. Default: 1.3",
    )
    parser.add_argument(
        "--output-dir",
        default="output/",
        type=str,
        help="Output directory to save the results. Default: 'output/'",
    )
    parser.add_argument(
        "--texture-resolution",
        default=1024,
        type=int,
        help="Texture atlas resolution. Default: 1024",
    )
    parser.add_argument(
        "--low-vram-mode",
        action="store_true",
        help=(
            "Use low VRAM mode. SPAR3D consumes 10.5GB of VRAM by default. "
            "This mode will reduce the VRAM consumption to roughly 7GB but in exchange "
            "the model will be slower. Default: False"
        ),
    )

    # Only offer the remeshing modes whose backends were importable.
    remesh_choices = ["none"]
    if TRIANGLE_REMESH_AVAILABLE:
        remesh_choices.append("triangle")
    if QUAD_REMESH_AVAILABLE:
        remesh_choices.append("quad")
    parser.add_argument(
        "--remesh_option",
        choices=remesh_choices,
        default="none",
        help="Remeshing option",
    )
    if TRIANGLE_REMESH_AVAILABLE or QUAD_REMESH_AVAILABLE:
        parser.add_argument(
            "--reduction_count_type",
            choices=["keep", "vertex", "faces"],
            default="keep",
            help="Vertex count type",
        )
        parser.add_argument(
            "--target_count",
            type=check_positive,
            help="Selected target count.",
            default=2000,
        )
    parser.add_argument(
        "--batch_size",
        default=1,
        # A zero or negative batch size would make the batching range invalid.
        type=check_positive,
        help="Batch size for inference",
    )
    args = parser.parse_args()

    # Ensure args.device names a supported device family (cuda, mps or cpu).
    if not _device_is_supported(args.device):
        raise ValueError("Invalid device. Use cuda, mps or cpu")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Fall back to the CPU when no accelerator backend is available at all.
    device = args.device
    if not (torch.cuda.is_available() or torch.backends.mps.is_available()):
        device = "cpu"
    print("Device used: ", device)

    model = SPAR3D.from_pretrained(
        args.pretrained_model,
        config_name="config.yaml",
        weight_name="model.safetensors",
        low_vram_mode=args.low_vram_mode,
    )
    model.to(device)
    model.eval()

    bg_remover = Remover(device=device)

    # Preprocess every input: strip the background, crop around the
    # foreground and store the result as <output_dir>/<index>/input.png.
    images = []
    for idx, image_path in enumerate(_expand_image_paths(args.image)):
        # convert() returns an independent copy, so the file handle can be
        # released as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            rgba_image = raw_image.convert("RGBA")
        image = remove_background(rgba_image, bg_remover)
        image = foreground_crop(image, args.foreground_ratio)
        sample_dir = os.path.join(output_dir, str(idx))
        os.makedirs(sample_dir, exist_ok=True)
        image.save(os.path.join(sample_dir, "input.png"))
        images.append(image)

    # The reduction options are only registered when a remesher is available,
    # so read them defensively and default to keeping the vertex count.
    vertex_count = _resolve_vertex_count(
        getattr(args, "reduction_count_type", "keep"),
        getattr(args, "target_count", 2000),
    )

    for start in tqdm(range(0, len(images), args.batch_size)):
        batch = images[start : start + args.batch_size]
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        # autocast expects the device *type*, so pass "cuda" even when the
        # selected device carries an index such as "cuda:0".
        if "cuda" in device:
            autocast_ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)
        else:
            autocast_ctx = nullcontext()

        with torch.no_grad(), autocast_ctx:
            mesh, glob_dict = model.run_image(
                batch,
                bake_resolution=args.texture_resolution,
                remesh=args.remesh_option,
                vertex_count=vertex_count,
                return_points=True,
            )

        if torch.cuda.is_available():
            print("Peak Memory:", torch.cuda.max_memory_allocated() / 1024 / 1024, "MB")
        elif torch.backends.mps.is_available():
            print(
                "Peak Memory:", torch.mps.driver_allocated_memory() / 1024 / 1024, "MB"
            )

        # A single-image batch is handled as one mesh object, larger batches
        # as an indexable collection; normalise so one loop covers both.
        meshes = [mesh] if len(batch) == 1 else mesh
        for offset, single_mesh in enumerate(meshes):
            sample_dir = os.path.join(output_dir, str(start + offset))
            single_mesh.export(
                os.path.join(sample_dir, "mesh.glb"), include_normals=True
            )
            glob_dict["point_clouds"][offset].export(
                os.path.join(sample_dir, "points.ply")
            )