"""HF Spaces app (running on ZeroGPU): generate views with Shap-E, refine them with
the Sharp-It pipeline, and reconstruct a mesh with InstantMesh."""
import os
import math
import time
from typing import Union

# HF Spaces specific: importing `spaces` enables the ZeroGPU runtime hooks
import spaces
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
from requests.exceptions import ReadTimeout, ConnectionError
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from einops import rearrange

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.models.transmitter.base import Transmitter, VectorDecoder
from shap_e.models.nn.camera import DifferentiableCameraBatch, DifferentiableProjectiveCamera
from shap_e.util.collections import AttrDict
from shap_e.util.notebooks import create_pan_cameras

from src.utils.train_util import instantiate_from_config
from src.utils.camera_util import (
    FOV_to_intrinsics,
    get_zero123plus_input_cameras,
    get_circular_camera_poses,
    spherical_camera_pose,
)
from src.utils.mesh_util import save_obj, save_glb
from src.utils.infer_util import remove_background, resize_foreground
def decode_latent_images(
    xm: Union[Transmitter, VectorDecoder],
    latent: torch.Tensor,
    cameras: DifferentiableCameraBatch,
    rendering_mode: str = "stf",
    params=None,
    background_color: torch.Tensor = torch.tensor([255.0, 255.0, 255.0], dtype=torch.float32),
):
    """Render a Shap-E latent from the given cameras, composited over a solid background."""
    if params is None:
        params = (xm.encoder if isinstance(xm, Transmitter) else xm).bottleneck_to_params(
            latent[None]
        )
    params = xm.renderer.update(params)
    decoded = xm.renderer.render_views(
        AttrDict(cameras=cameras),
        params=params,
        options=AttrDict(rendering_mode=rendering_mode, render_with_direction=False),
    )
    # Alpha-composite: transmittance is the fraction of background showing through
    bg_color = background_color.to(decoded.channels.device)
    images = bg_color * decoded.transmittance + (1 - decoded.transmittance) * decoded.channels
    return images
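
# Note: unlike shap_e.util.notebooks.decode_latent_images, this variant returns the
# raw composited float tensor (values in [0, 255]) rather than PIL images; callers
# are expected to clamp/convert themselves (see ShapERenderer.generate_views below).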
def create_custom_cameras(size: int, device: torch.device, azimuths: list, elevations: list,
                          fov_degrees: float, distance: float) -> DifferentiableCameraBatch:
    """Build a camera batch from explicit azimuth/elevation pairs (in degrees)."""
    # The object is assumed to fit a 2x2x2 bounding box (-1 to 1 in each dimension);
    # `distance` is treated as the diagonal extent to frame
    object_diagonal = distance

    # Choose the orbit radius so the object's extent fills the field of view
    fov_radians = math.radians(fov_degrees)
    radius = (object_diagonal / 2) / math.tan(fov_radians / 2)

    origins = []
    xs = []
    ys = []
    zs = []
    for azimuth, elevation in zip(azimuths, elevations):
        azimuth_rad = np.radians(azimuth - 90)
        elevation_rad = np.radians(elevation)

        # Camera position on a sphere of the computed radius
        x = radius * np.cos(elevation_rad) * np.cos(azimuth_rad)
        y = radius * np.cos(elevation_rad) * np.sin(azimuth_rad)
        z = radius * np.sin(elevation_rad)
        origin = np.array([x, y, z])

        # Camera orientation: z looks at the origin, x stays horizontal
        z_axis = -origin / np.linalg.norm(origin)
        x_axis = np.array([-np.sin(azimuth_rad), np.cos(azimuth_rad), 0])
        y_axis = np.cross(z_axis, x_axis)

        origins.append(origin)
        zs.append(z_axis)
        xs.append(x_axis)
        ys.append(y_axis)

    return DifferentiableCameraBatch(
        shape=(1, len(origins)),
        flat_camera=DifferentiableProjectiveCamera(
            origin=torch.from_numpy(np.stack(origins, axis=0)).float().to(device),
            x=torch.from_numpy(np.stack(xs, axis=0)).float().to(device),
            y=torch.from_numpy(np.stack(ys, axis=0)).float().to(device),
            z=torch.from_numpy(np.stack(zs, axis=0)).float().to(device),
            width=size,
            height=size,
            x_fov=fov_radians,
            y_fov=fov_radians,
        ),
    )
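
# Example with the values used below in ShapERenderer.generate_views: fov_degrees=30
# and distance=3.0 give radius = 1.5 / tan(15 deg) ~= 5.60, i.e. the cameras orbit
# well outside the unit cube.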
def load_models():
    """Initialize and load all required models."""
    config = OmegaConf.load('configs/instant-nerf-large-best.yaml')
    model_config = config.model_config
    infer_config = config.infer_config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load diffusion pipeline with retry logic
    print('Loading diffusion pipeline...')
    max_retries = 3
    retry_delay = 5
    for attempt in range(max_retries):
        try:
            pipeline = DiffusionPipeline.from_pretrained(
                "sudo-ai/zero123plus-v1.2",
                custom_pipeline="zero123plus",
                torch_dtype=torch.float16,
                local_files_only=False,
                resume_download=True,
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download pipeline after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff

    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )

    # Modify the UNet to accept 8 input channels instead of 4, zero-initializing
    # the weights for the 4 new channels so behavior is initially unchanged
    in_channels = 8
    out_channels = pipeline.unet.conv_in.out_channels
    pipeline.unet.register_to_config(in_channels=in_channels)
    with torch.no_grad():
        new_conv_in = nn.Conv2d(
            in_channels, out_channels, pipeline.unet.conv_in.kernel_size,
            pipeline.unet.conv_in.stride, pipeline.unet.conv_in.padding
        )
        new_conv_in.weight.zero_()
        new_conv_in.weight[:, :4, :, :].copy_(pipeline.unet.conv_in.weight)
        pipeline.unet.conv_in = new_conv_in

    # Load custom UNet with retry logic
    print('Loading custom UNet...')
    for attempt in range(max_retries):
        try:
            pipeline.unet = pipeline.unet.from_pretrained(
                "YiftachEde/Sharp-It",
                local_files_only=False,
                resume_download=True,
            ).to(torch.float16)
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download UNet after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    # Weights are already fp16; just move the pipeline to the target device
    pipeline = pipeline.to(device)

    # Load reconstruction model with retry logic
    print('Loading reconstruction model...')
    model = instantiate_from_config(model_config)
    for attempt in range(max_retries):
        try:
            model_path = hf_hub_download(
                repo_id="TencentARC/InstantMesh",
                filename="instant_nerf_large.ckpt",
                repo_type="model",
                local_files_only=False,
                resume_download=True,
                cache_dir="model_cache"  # Use a specific cache directory
            )
            break
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download model after {max_retries} attempts: {str(e)}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    # Keep only the LRM generator weights, stripping the 'lrm_generator.' prefix
    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    state_dict = {k[14:]: v for k, v in state_dict.items()
                  if k.startswith('lrm_generator.') and 'source_camera' not in k}
    model.load_state_dict(state_dict, strict=True)
    model = model.to(device)
    model.eval()

    return pipeline, model, infer_config
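
# The three download blocks above share the same retry/backoff shape. A minimal
# helper sketch that could factor them out (not wired in, kept for reference;
# `fn` is any zero-argument callable that performs the download):
def _with_retries(fn, what, max_retries=3, retry_delay=5):
    for attempt in range(max_retries):
        try:
            return fn()
        except (ReadTimeout, ConnectionError) as e:
            if attempt == max_retries - 1:
                raise Exception(f"Failed to download {what} after {max_retries} attempts: {e}")
            print(f"Download attempt {attempt + 1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2  # exponential backoff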
def process_images(input_images, prompt, steps=75, guidance_scale=7.5, pipeline=None):
    """Process input images and run refinement."""
    device = pipeline.device
    if isinstance(input_images, list):
        if len(input_images) == 1:
            # Check if this is a pre-arranged layout
            img = Image.open(input_images[0].name).convert('RGB')
            if img.size == (640, 960):
                # Already a 640x960 (W x H) layout, use it directly
                input_image = img
            else:
                # Single view - replicate it into all 6 slots
                img = img.resize((320, 320))
                img_array = np.array(img) / 255.0
                images = [img_array] * 6
                images = np.stack(images)

                # Tile the 6 views into a 3-row x 2-column grid (960 high x 640 wide);
                # see the rearrange sketch after this function
                images = torch.from_numpy(images).float()   # (6, 320, 320, 3)
                images = images.permute(0, 3, 1, 2)         # (6, ch, h, w)
                images = images.reshape(3, 2, 3, 320, 320)  # (row, col, ch, h, w)
                images = images.permute(2, 0, 3, 1, 4)      # (ch, row, h, col, w)
                images = images.reshape(3, 960, 640)        # rows stack along H, cols along W

                # Convert back to PIL
                images = images.permute(1, 2, 0)
                images = (images.numpy() * 255).astype(np.uint8)
                input_image = Image.fromarray(images)
        else:
            # Multiple individual views
            images = []
            for img_file in input_images:
                img = Image.open(img_file.name).convert('RGB')
                img = img.resize((320, 320))
                img = np.array(img) / 255.0
                images.append(img)

            # Pad to 6 images if needed
            while len(images) < 6:
                images.append(np.zeros_like(images[0]))
            images = np.stack(images[:6])

            # Same 3x2 tiling as the single-view branch above
            images = torch.from_numpy(images).float()
            images = images.permute(0, 3, 1, 2)
            images = images.reshape(3, 2, 3, 320, 320)
            images = images.permute(2, 0, 3, 1, 4)
            images = images.reshape(3, 960, 640)

            # Convert back to PIL
            images = images.permute(1, 2, 0)
            images = (images.numpy() * 255).astype(np.uint8)
            input_image = Image.fromarray(images)
    else:
        raise ValueError("Expected a list of images")

    # Generate refined output
    output = pipeline.refine(
        input_image,
        prompt=prompt,
        num_inference_steps=int(steps),
        guidance_scale=guidance_scale
    ).images[0]
    return output, input_image
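
# The tiling above (and its inverse in create_mesh below) can be written with the
# already-imported einops.rearrange. An unused reference sketch:
def tile_views(views: np.ndarray) -> np.ndarray:
    """Sketch: tile 6 views (6, h, w, ch) into a 3x2 grid (3h, 2w, ch), row-major."""
    return rearrange(views, '(r c) h w ch -> (r h) (c w) ch', r=3, c=2)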
def create_mesh(refined_image, model, infer_config):
    """Generate a mesh from the refined 960x640 (H x W) layout image."""
    # Convert PIL image to a (ch, H, W) tensor
    image = np.array(refined_image) / 255.0
    image = torch.from_numpy(image).float().permute(2, 0, 1)

    # Split the 3x2 grid back into 6 views of 320x320
    image = image.reshape(3, 960, 640)
    image = image.reshape(3, 3, 320, 640)     # (ch, row, h, W)
    image = image.permute(1, 0, 2, 3)         # (row, ch, h, W)
    image = image.reshape(3, 3, 320, 2, 320)  # (row, ch, h, col, w)
    image = image.permute(0, 3, 1, 2, 4)      # (row, col, ch, h, w)
    image = image.reshape(6, 3, 320, 320)

    # Add batch dimension
    image = image.unsqueeze(0)
    input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to("cuda")
    image = image.to("cuda")

    with torch.no_grad():
        planes = model.forward_planes(image, input_cameras)
        mesh_out = model.extract_mesh(planes, **infer_config)
        vertices, faces, vertex_colors = mesh_out

    return vertices, faces, vertex_colors
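
# Reference sketch for the inverse of tile_views, matching the split above (not wired in):
def split_layout(image: torch.Tensor) -> torch.Tensor:
    """Sketch: split a (ch, 3h, 2w) grid back into 6 views (6, ch, h, w), row-major."""
    return rearrange(image, 'ch (r h) (c w) -> (r c) ch h w', r=3, c=2)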
class ShapERenderer:
    def __init__(self, device):
        print("Initializing Shap-E models...")
        self.device = device
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.xm = load_model('transmitter', device=self.device)
        self.model = load_model('text300M', device=self.device)
        self.diffusion = diffusion_from_config(load_config('diffusion'))
        print("Shap-E models initialized!")

    def generate_views(self, prompt, guidance_scale=15.0, num_steps=64):
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before generation

            # Generate latents using the text-to-3D model
            batch_size = 1
            guidance_scale = float(guidance_scale)
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                latents = sample_latents(
                    batch_size=batch_size,
                    model=self.model,
                    diffusion=self.diffusion,
                    guidance_scale=guidance_scale,
                    model_kwargs=dict(texts=[prompt] * batch_size),
                    progress=True,
                    clip_denoised=True,
                    use_fp16=True,
                    use_karras=True,
                    karras_steps=num_steps,
                    sigma_min=1e-3,
                    sigma_max=160,
                    s_churn=0,
                )

            # Render the 6 views we need, with viewing angles matching refine.py
            size = 320  # Size of each rendered image
            images = []
            azimuths = [30, 90, 150, 210, 270, 330]
            elevations = [20, -10, 20, -10, 20, -10]
            for azimuth, elevation in zip(azimuths, elevations):
                cameras = create_custom_cameras(size, self.device, azimuths=[azimuth],
                                                elevations=[elevation], fov_degrees=30, distance=3.0)
                with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                    rendered_image = decode_latent_images(
                        self.xm,
                        latents[0],
                        cameras=cameras,
                        rendering_mode='stf'
                    )
                images.append(rendered_image[0])
                torch.cuda.empty_cache()  # Clear GPU memory after each view

            # Convert the rendered float tensors to uint8 arrays on the CPU
            images = [img.clamp(0, 255).to(torch.uint8).cpu().numpy() for img in images]

            # Tile into a 3-row x 2-column grid (960 high x 640 wide)
            layout = np.zeros((960, 640, 3), dtype=np.uint8)
            for i, img in enumerate(images):
                row = i // 2
                col = i % 2
                layout[row*320:(row+1)*320, col*320:(col+1)*320] = img

            return Image.fromarray(layout), images
        except Exception as e:
            print(f"Error in generate_views: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise
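
# Typical usage (sketch): the demo below constructs this once at startup, e.g.
#   renderer = ShapERenderer(torch.device('cuda'))
#   layout, views = renderer.generate_views("a wooden chair")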
class RefinerInterface:
    def __init__(self):
        print("Initializing InstantMesh models...")
        torch.cuda.empty_cache()  # Clear GPU memory before loading
        self.pipeline, self.model, self.infer_config = load_models()
        print("InstantMesh models initialized!")

    def refine_model(self, input_image, prompt, steps=75, guidance_scale=7.5):
        """Main refinement function."""
        try:
            torch.cuda.empty_cache()  # Clear GPU memory before processing
            input_image = Image.fromarray(input_image)

            # If we received a wide 2x3 grid (960 wide x 640 high), rearrange it into
            # the tall 3x2 grid (640 wide x 960 high) the pipeline expects
            if input_image.width == 960 and input_image.height == 640:
                input_array = np.array(input_image)
                new_layout = np.zeros((960, 640, 3), dtype=np.uint8)
                for i in range(6):
                    src_row = i // 3
                    src_col = i % 3
                    dst_row = i // 2
                    dst_col = i % 2
                    new_layout[dst_row*320:(dst_row+1)*320, dst_col*320:(dst_col+1)*320] = \
                        input_array[src_row*320:(src_row+1)*320, src_col*320:(src_col+1)*320]
                input_image = Image.fromarray(new_layout)

            # Process with the pipeline (expects the tall 960-high x 640-wide layout)
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                refined_output = self.pipeline.refine(
                    input_image,
                    prompt=prompt,
                    num_inference_steps=int(steps),
                    guidance_scale=guidance_scale
                ).images[0]
            torch.cuda.empty_cache()  # Clear GPU memory after refinement

            # Generate mesh from the refined layout
            with torch.amp.autocast('cuda'):  # Use automatic mixed precision
                vertices, faces, vertex_colors = create_mesh(
                    refined_output,
                    self.model,
                    self.infer_config
                )
            torch.cuda.empty_cache()  # Clear GPU memory after mesh generation

            # Save temporary mesh file
            os.makedirs("temp", exist_ok=True)
            temp_obj = os.path.join("temp", "refined_mesh.obj")
            save_obj(vertices, faces, vertex_colors, temp_obj)

            # The refined output is already in the tall 640-wide x 960-high layout
            # used for display, so it can be returned as-is
            return refined_output, temp_obj
        except Exception as e:
            print(f"Error in refine_model: {e}")
            torch.cuda.empty_cache()  # Clear GPU memory on error
            raise
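
# The wide-to-tall grid shuffle above can also be written with einops, keeping the
# row-major view order (reference sketch, not wired in):
def wide_to_tall(layout: np.ndarray) -> np.ndarray:
    """Sketch: (640, 960, 3) 2x3 grid -> (960, 640, 3) 3x2 grid, views kept in order."""
    tiles = rearrange(layout, '(r h) (c w) ch -> (r c) h w ch', r=2, c=3)
    return rearrange(tiles, '(r c) h w ch -> (r h) (c w) ch', r=3, c=2)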
def create_demo():
    print("Initializing models...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize models at startup
    shap_e = ShapERenderer(device)
    refiner = RefinerInterface()
    print("All models initialized!")

    with gr.Blocks() as demo:
        gr.Markdown("# Shap-E to InstantMesh Pipeline")

        # First row: Controls
        with gr.Row():
            with gr.Column():
                # Shap-E inputs
                shape_prompt = gr.Textbox(
                    label="Shap-E Prompt",
                    placeholder="Enter text to generate initial 3D model..."
                )
                shape_guidance = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=15.0,
                    label="Shap-E Guidance Scale"
                )
                shape_steps = gr.Slider(
                    minimum=16,
                    maximum=128,
                    value=64,
                    step=16,
                    label="Shap-E Steps"
                )
                generate_btn = gr.Button("Generate Views")

            with gr.Column():
                # Refinement inputs
                refine_prompt = gr.Textbox(
                    label="Refinement Prompt",
                    placeholder="Enter prompt to guide refinement..."
                )
                refine_steps = gr.Slider(
                    minimum=30,
                    maximum=100,
                    value=75,
                    step=1,
                    label="Refinement Steps"
                )
                refine_guidance = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=7.5,
                    label="Refinement Guidance Scale"
                )
                refine_btn = gr.Button("Refine")
                error_output = gr.Textbox(label="Status/Error Messages", interactive=False)

        # Second row: Image panels side by side
        with gr.Row():
            shape_output = gr.Image(
                label="Generated Views",
                width=640,
                height=960
            )
            refined_output = gr.Image(
                label="Refined Output",
                width=640,
                height=960
            )

        # Third row: 3D mesh panel below
        with gr.Row():
            mesh_output = gr.Model3D(
                label="3D Mesh",
                clear_color=[1.0, 1.0, 1.0, 1.0],
            )

        # Event handlers. On ZeroGPU the heavy functions carry the spaces.GPU
        # decorator so a GPU is attached for the duration of each call
        @spaces.GPU
        def generate(prompt, guidance_scale, num_steps):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                with torch.no_grad():
                    layout, _ = shap_e.generate_views(prompt, guidance_scale, num_steps)
                return layout, None  # No error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during generation: {str(e)}"
                print(error_msg)
                return None, error_msg

        @spaces.GPU
        def refine(input_image, prompt, steps, guidance_scale):
            try:
                torch.cuda.empty_cache()  # Clear GPU memory before starting
                refined_img, mesh_path = refiner.refine_model(
                    input_image,
                    prompt,
                    steps,
                    guidance_scale
                )
                return refined_img, mesh_path, None  # No error message
            except Exception as e:
                torch.cuda.empty_cache()  # Clear GPU memory on error
                error_msg = f"Error during refinement: {str(e)}"
                print(error_msg)
                return None, None, error_msg

        generate_btn.click(
            fn=generate,
            inputs=[shape_prompt, shape_guidance, shape_steps],
            outputs=[shape_output, error_output]
        )
        refine_btn.click(
            fn=refine,
            inputs=[shape_output, refine_prompt, refine_steps, refine_guidance],
            outputs=[refined_output, mesh_output, error_output]
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch(share=True)