# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# conda activate hf3.10
import gc
import os
import shutil
import sys
import time
from datetime import datetime

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener

register_heif_opener()

sys.path.append("mapanything/")

from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
    GRADIO_CSS,
    MEASURE_INSTRUCTIONS_HTML,
    get_acknowledgements_html,
    get_description_html,
    get_gradio_theme,
    get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.visual_util import predictions_to_glb
from mapanything.utils.image import load_images, rgb


def get_logo_base64():
    """Convert the WAI logo to base64 for embedding in HTML."""
    import base64

    logo_path = "examples/WAI-Logo/wai_logo.png"
    try:
        with open(logo_path, "rb") as img_file:
            img_data = img_file.read()
        base64_str = base64.b64encode(img_data).decode()
        return f"data:image/png;base64,{base64_str}"
    except FileNotFoundError:
        return None


# MapAnything configuration
high_level_config = {
    "path": "configs/train.yaml",
    "hf_model_name": "facebook/map-anything",
    "model_str": "mapanything",
    "config_overrides": [
        "machine=aws",
        "model=mapanything",
        "model/task=images_only",
        "model.encoder.uses_torch_hub=false",
    ],
    "checkpoint_name": "model.safetensors",
    "config_name": "config.json",
    "trained_with_amp": True,
    "trained_with_amp_dtype": "bf16",
    "data_norm_type": "dinov2",
    "patch_size": 14,
    "resolution": 518,
}

# The model is initialized lazily, on the GPU worker, when first needed
model = None


# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_model(
    target_dir,
    apply_mask=True,
    mask_edges=True,
    filter_black_bg=False,
    filter_white_bg=False,
):
    """
    Run the MapAnything model on images in the 'target_dir/images' folder
    and return predictions.
    """
    global model
    import torch  # Ensure torch is available in function scope

    print(f"Processing images from {target_dir}")

    # Device check
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Initialize the model if not already done
    if model is None:
        model = initialize_mapanything_model(high_level_config, device)
    else:
        model = model.to(device)
    model.eval()

    # Load images using MapAnything's load_images function
    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    views = load_images(image_folder_path)
    print(f"Loaded {len(views)} images")

    if len(views) == 0:
        raise ValueError("No images found. Check your upload.")

    # Run model inference
    print("Running inference...")
    # apply_mask: whether to apply the non-ambiguous mask to the output. Defaults to True.
    # mask_edges: whether to compute an edge mask based on normals and depth
    #             and apply it to the output. Defaults to True.
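    # For reference, each element of `outputs` is consumed below as a dict of
    # batched tensors with (at least) these keys:
    #   "depth_z":      (1, H, W, 1)  z-depth per pixel
    #   "intrinsics":   (1, 3, 3)     pinhole camera intrinsics
    #   "camera_poses": (1, 4, 4)     camera-to-world pose
    #   "img_no_norm":  (1, H, W, 3)  de-normalized RGB image
    #   "mask":         (1, H, W, 1)  optional validity mask
    # The exact contract depends on the installed mapanything version; the
    # shapes here are inferred from how the tensors are indexed and squeezed below.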
    # Pass the checkbox values through; mask_edges has no UI control, so
    # callers rely on its default of True.
    outputs = model.infer(
        views,
        apply_mask=apply_mask,
        mask_edges=mask_edges,
        memory_efficient_inference=False,
    )

    # Convert predictions to the format expected by visualization
    predictions = {}

    # Initialize lists for the required keys
    extrinsic_list = []
    intrinsic_list = []
    world_points_list = []
    depth_maps_list = []
    images_list = []
    final_mask_list = []

    # Loop through the outputs
    for pred in outputs:
        # Extract data from the predictions
        depthmap_torch = pred["depth_z"][0].squeeze(-1)  # (H, W)
        intrinsics_torch = pred["intrinsics"][0]  # (3, 3)
        camera_pose_torch = pred["camera_poses"][0]  # (4, 4)

        # Compute pts3d using the depth, intrinsics, and camera pose
        pts3d_computed, valid_mask = depthmap_to_world_frame(
            depthmap_torch, intrinsics_torch, camera_pose_torch
        )

        # Convert to numpy arrays for visualization. If the prediction has no
        # mask, fall back to an all-True mask the size of the depth map.
        if "mask" in pred:
            mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
        else:
            mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)

        # Combine with the valid-depth mask
        mask = mask & valid_mask.cpu().numpy()

        image = pred["img_no_norm"][0].cpu().numpy()

        # Append to lists
        extrinsic_list.append(camera_pose_torch.cpu().numpy())
        intrinsic_list.append(intrinsics_torch.cpu().numpy())
        world_points_list.append(pts3d_computed.cpu().numpy())
        depth_maps_list.append(depthmap_torch.cpu().numpy())
        images_list.append(image)
        final_mask_list.append(mask)

    # Convert lists to numpy arrays with the required shapes
    # extrinsic: (S, 4, 4) - batch of camera-to-world pose matrices
    predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
    # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
    predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
    # world_points: (S, H, W, 3) - batch of 3D world points
    predictions["world_points"] = np.stack(world_points_list, axis=0)

    # depth: (S, H, W, 1) - batch of depth maps
    depth_maps = np.stack(depth_maps_list, axis=0)
    # Add a channel dimension if needed to match the (S, H, W, 1) format
    if len(depth_maps.shape) == 3:
        depth_maps = depth_maps[..., np.newaxis]
    predictions["depth"] = depth_maps

    # images: (S, H, W, 3) - batch of input images
    predictions["images"] = np.stack(images_list, axis=0)
    # final_mask: (S, H, W) - batch of final masks for filtering
    predictions["final_mask"] = np.stack(final_mask_list, axis=0)

    # Process data for the visualization tabs (depth, normal, measure)
    processed_data = process_predictions_for_visualization(
        predictions, views, high_level_config, filter_black_bg, filter_white_bg
    )

    # Clean up
    torch.cuda.empty_cache()

    return predictions, processed_data


def update_view_selectors(processed_data):
    """Update the view selector dropdowns based on the available views."""
    if processed_data is None or len(processed_data) == 0:
        choices = ["View 1"]
    else:
        num_views = len(processed_data)
        choices = [f"View {i + 1}" for i in range(num_views)]

    return (
        gr.Dropdown(choices=choices, value=choices[0]),  # depth_view_selector
        gr.Dropdown(choices=choices, value=choices[0]),  # normal_view_selector
        gr.Dropdown(choices=choices, value=choices[0]),  # measure_view_selector
    )


def get_view_data_by_index(processed_data, view_index):
    """Get view data by index, clamping out-of-range indices to 0."""
    if processed_data is None or len(processed_data) == 0:
        return None

    view_keys = list(processed_data.keys())
    if view_index < 0 or view_index >= len(view_keys):
        view_index = 0

    return processed_data[view_keys[view_index]]


def update_depth_view(processed_data, view_index):
    """Update the depth view for a specific view index."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["depth"] is None:
        return None

    return colorize_depth(view_data["depth"], mask=view_data.get("mask"))


def update_normal_view(processed_data, view_index):
    """Update the normal view for a specific view index."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["normal"] is None:
        return None

    return colorize_normal(view_data["normal"], mask=view_data.get("mask"))


def update_measure_view(processed_data, view_index):
    """Update the measure view for a specific view index, overlaying the mask."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None:
        return None, []  # image, measure_points

    # Get the base image
    image = view_data["image"].copy()

    # Ensure the image is in uint8 format
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    # Apply a mask overlay if a mask is available. Masked-out areas
    # (False values) are blended with light grey.
    if view_data["mask"] is not None:
        mask = view_data["mask"]
        invalid_mask = ~mask  # Areas where the mask is False
        if invalid_mask.any():
            # Light grey overlay color (RGB: 192, 192, 192)
            overlay_color = np.array([192, 192, 192], dtype=np.uint8)
            alpha = 0.5  # Transparency level
            for c in range(3):  # RGB channels
                image[:, :, c] = np.where(
                    invalid_mask,
                    (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
                    image[:, :, c],
                ).astype(np.uint8)

    return image, []


def navigate_depth_view(processed_data, current_selector_value, direction):
    """Navigate the depth view (direction: -1 for previous, +1 for next)."""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"

    depth_vis = update_depth_view(processed_data, new_view)
    return new_selector_value, depth_vis


def navigate_normal_view(processed_data, current_selector_value, direction):
    """Navigate the normal view (direction: -1 for previous, +1 for next)."""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"

    normal_vis = update_normal_view(processed_data, new_view)
    return new_selector_value, normal_vis


def navigate_measure_view(processed_data, current_selector_value, direction):
    """Navigate the measure view (direction: -1 for previous, +1 for next)."""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None, []

    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"

    measure_image, measure_points = update_measure_view(processed_data, new_view)
    return new_selector_value, measure_image, measure_points
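
# The three navigate_* helpers above share one convention: the dropdown value
# is the 1-based string "View N", the internal index is 0-based, and stepping
# wraps around via modulo arithmetic. For example, with 4 views, "Next" from
# "View 4" computes (3 + 1) % 4 == 0 and lands back on "View 1".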

def populate_visualization_tabs(processed_data):
    """Populate the depth, normal, and measure tabs with processed data."""
    if processed_data is None or len(processed_data) == 0:
        return None, None, None, []

    # Use the update functions so that mask filtering is applied from the start
    depth_vis = update_depth_view(processed_data, 0)
    normal_vis = update_normal_view(processed_data, 0)
    measure_img, _ = update_measure_view(processed_data, 0)
    return depth_vis, normal_vis, measure_img, []


# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or frames extracted from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle uploaded files (both images and videos) ---
    if unified_upload is not None:
        for file_data in unified_upload:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = str(file_data)

            file_ext = os.path.splitext(file_path)[1].lower()

            # Check if it's a video file
            video_extensions = [
                ".mp4",
                ".avi",
                ".mov",
                ".mkv",
                ".wmv",
                ".flv",
                ".webm",
                ".m4v",
                ".3gp",
            ]

            if file_ext in video_extensions:
                # Handle as video
                vs = cv2.VideoCapture(file_path)
                fps = vs.get(cv2.CAP_PROP_FPS)
                # Frames per sampling interval; guard against a reported fps
                # of 0 and against intervals shorter than one frame.
                frame_interval = max(1, int(fps * s_time_interval))

                count = 0
                video_frame_num = 0

                while True:
                    gotit, frame = vs.read()
                    if not gotit:
                        break
                    count += 1
                    if count % frame_interval == 0:
                        # Use the original filename as a prefix for frames
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        image_path = os.path.join(
                            target_dir_images, f"{base_name}_{video_frame_num:06}.png"
                        )
                        cv2.imwrite(image_path, frame)
                        image_paths.append(image_path)
                        video_frame_num += 1

                vs.release()
                print(
                    f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
                )
            else:
                # Handle as image
                # Check if the file is a HEIC image
                if file_ext in [".heic", ".heif"]:
                    # Convert HEIC to JPEG for better gallery compatibility
                    try:
                        with Image.open(file_path) as img:
                            # Convert to RGB if necessary (HEIC can have different color modes)
                            if img.mode not in ("RGB", "L"):
                                img = img.convert("RGB")

                            # Create the JPEG filename
                            base_name = os.path.splitext(os.path.basename(file_path))[0]
                            dst_path = os.path.join(
                                target_dir_images, f"{base_name}.jpg"
                            )

                            # Save as JPEG with high quality
                            img.save(dst_path, "JPEG", quality=95)
                            image_paths.append(dst_path)
                            print(
                                f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
                            )
                    except Exception as e:
                        print(f"Error converting HEIC file {file_path}: {e}")
                        # Fall back to copying as is
                        dst_path = os.path.join(
                            target_dir_images, os.path.basename(file_path)
                        )
                        shutil.copy(file_path, dst_path)
                        image_paths.append(dst_path)
                else:
                    # Regular image files - copy as is
                    dst_path = os.path.join(
                        target_dir_images, os.path.basename(file_path)
                    )
                    shutil.copy(file_path, dst_path)
                    image_paths.append(dst_path)

    # Sort the final images for the gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths


# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(unified_upload, s_time_interval=1.0):
    """
    Whenever the user uploads or changes files, handle them immediately and
    show them in the gallery. Returns (None, target_dir, image_paths, log_message),
    or all Nones if nothing is uploaded.
    """
    if not unified_upload:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(unified_upload, s_time_interval)
    return (
        None,
        target_dir,
        image_paths,
        "Upload complete. Click 'Reconstruct' to begin 3D processing.",
    )


# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def gradio_demo(
    target_dir,
    frame_filter="All",
    show_cam=True,
    filter_black_bg=False,
    filter_white_bg=False,
    apply_mask=True,
    show_mesh=True,
):
    """
    Perform reconstruction using the already-created target_dir/images.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        # Match the arity of the success-path return so Gradio can unpack it
        return (
            None,
            "No valid target directory found. Please upload first.",
            None,
            None,
            None,
            None,
            None,
            "",
            None,
            None,
            None,
        )

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare the frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = (
        sorted(os.listdir(target_dir_images))
        if os.path.isdir(target_dir_images)
        else []
    )
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running MapAnything model...")
    with torch.no_grad():
        predictions, processed_data = run_model(
            target_dir,
            apply_mask=apply_mask,
            filter_black_bg=filter_black_bg,
            filter_white_bg=filter_white_bg,
        )

    # Save predictions
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)
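    # predictions.npz doubles as a cache: update_visualization() and
    # update_all_views_on_filter_change() reload it so the GLB and the
    # per-view tabs can be regenerated without re-running inference.
    # A minimal offline inspection sketch (the timestamped path is illustrative):
    #   loaded = np.load("input_images_<timestamp>/predictions.npz")
    #   loaded["world_points"].shape  # (S, H, W, 3)
    #   loaded["final_mask"].shape    # (S, H, W)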
    # Handle a None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name that encodes the visualization parameters
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,  # Use the show_mesh parameter
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds")
    log_msg = (
        f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
    )

    # Populate the visualization tabs with the processed data
    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
        processed_data
    )

    # Update the view selectors based on the available views
    depth_selector, normal_selector, measure_selector = update_view_selectors(
        processed_data
    )

    return (
        glbfile,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data,
        depth_vis,
        normal_vis,
        measure_img,
        "",  # measure_text (empty initially)
        depth_selector,
        normal_selector,
        measure_selector,
    )


# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def colorize_depth(depth_map, mask=None):
    """Convert a depth map to a colorized visualization, with an optional mask."""
    if depth_map is None:
        return None

    # Normalize depth to the 0-1 range using the 5th-95th percentile window
    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0

    # Apply an additional mask if provided (for background filtering)
    if mask is not None:
        valid_mask = valid_mask & mask

    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5 = np.percentile(valid_depths, 5)
        p95 = np.percentile(valid_depths, 95)
        # Guard against a degenerate (constant-depth) percentile window
        denom = max(p95 - p5, 1e-8)
        depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / denom

    # Apply a colormap
    import matplotlib.pyplot as plt

    colormap = plt.cm.turbo_r
    colored = colormap(depth_normalized)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)

    # Set invalid pixels to white
    colored[~valid_mask] = [255, 255, 255]

    return colored


def colorize_normal(normal_map, mask=None):
    """Convert a normal map to a colorized visualization, with an optional mask."""
    if normal_map is None:
        return None

    # Create a copy for modification
    normal_vis = normal_map.copy()

    # Apply the mask if provided; zeroed normals become mid-grey after the
    # [-1, 1] -> [0, 1] shift below
    if mask is not None:
        invalid_mask = ~mask
        normal_vis[invalid_mask] = [0, 0, 0]

    # Normalize normals from [-1, 1] to [0, 1] for visualization
    normal_vis = (normal_vis + 1.0) / 2.0
    normal_vis = (normal_vis * 255).astype(np.uint8)

    return normal_vis


def process_predictions_for_visualization(
    predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
    """Extract depth, normals, and 3D points from predictions for visualization."""
    processed_data = {}

    # Process each view
    for view_idx, view in enumerate(views):
        # Get the image
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])

        # Get the predicted points
        pred_pts3d = predictions["world_points"][view_idx]

        # Initialize the data for this view
        view_data = {
            "image": image[0],
            "points3d": pred_pts3d,
            "depth": None,
            "normal": None,
            "mask": None,
        }

        # Start with the final mask from the predictions
        mask = predictions["final_mask"][view_idx].copy()

        # Apply black-background filtering if enabled
        if filter_black_bg:
            # Get the image colors (ensure they're in the 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out black background pixels (sum of RGB < 16)
            black_bg_mask = view_colors.sum(axis=2) >= 16
            mask = mask & black_bg_mask

        # Apply white-background filtering if enabled
        if filter_white_bg:
            # Get the image colors (ensure they're in the 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out white background pixels (all RGB > 240)
            white_bg_mask = ~(
                (view_colors[:, :, 0] > 240)
                & (view_colors[:, :, 1] > 240)
                & (view_colors[:, :, 2] > 240)
            )
            mask = mask & white_bg_mask

        view_data["mask"] = mask
        view_data["depth"] = predictions["depth"][view_idx].squeeze()

        normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
        view_data["normal"] = normals

        processed_data[view_idx] = view_data

    return processed_data


def reset_measure(processed_data):
    """Reset the measure points."""
    if processed_data is None or len(processed_data) == 0:
        return None, [], ""
    # Return the first view image
    first_view = list(processed_data.values())[0]
    return first_view["image"], [], ""


def measure(
    processed_data, measure_points, current_view_selector, event: gr.SelectData
):
    """Handle measurement clicks on images."""
    try:
        print(f"Measure function called with selector: {current_view_selector}")

        if processed_data is None or len(processed_data) == 0:
            return None, [], "No data available"

        # Use the currently selected view instead of always using the first view
        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            current_view_index = 0

        print(f"Using view index: {current_view_index}")

        # Get the view data safely
        if current_view_index < 0 or current_view_index >= len(processed_data):
            current_view_index = 0

        view_keys = list(processed_data.keys())
        current_view = processed_data[view_keys[current_view_index]]

        if current_view is None:
            return None, [], "No view data available"

        point2d = event.index[0], event.index[1]
        print(f"Clicked point: {point2d}")

        # Check whether the clicked point is in a masked area (prevent interaction)
        if (
            current_view["mask"] is not None
            and 0 <= point2d[1] < current_view["mask"].shape[0]
            and 0 <= point2d[0] < current_view["mask"].shape[1]
        ):
            # Check if the point is in a masked (invalid) area
            if not current_view["mask"][point2d[1], point2d[0]]:
                print(f"Clicked point {point2d} is in masked area, ignoring click")
                # Always return the image with the mask overlay
                masked_image, _ = update_measure_view(
                    processed_data, current_view_index
                )
                return (
                    masked_image,
                    measure_points,
                    "Cannot measure on masked areas (shown in grey)",
                )

        measure_points.append(point2d)

        # Get the image with the mask overlay and ensure it's valid
        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"

        image = image.copy()
        points3d = current_view["points3d"]

        # Ensure the image is in uint8 format for the cv2 drawing operations
        try:
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    # Image is in the [0, 1] range, convert to [0, 255]
                    image = (image * 255).astype(np.uint8)
                else:
                    # Image is already in the [0, 255] range
                    image = image.astype(np.uint8)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return None, [], f"Image conversion error: {e}"

        # Draw circles for the clicked points
        try:
            for p in measure_points:
                if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                    image = cv2.circle(
                        image, p, radius=5, color=(255, 0, 0), thickness=2
                    )
        except Exception as e:
            print(f"Drawing error: {e}")
            return None, [], f"Drawing error: {e}"

        depth_text = ""
        try:
            for i, p in enumerate(measure_points):
                if (
                    current_view["depth"] is not None
                    and 0 <= p[1] < current_view["depth"].shape[0]
                    and 0 <= p[0] < current_view["depth"].shape[1]
                ):
                    d = current_view["depth"][p[1], p[0]]
                    depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
                else:
                    # Use the Z coordinate of the 3D points if depth is unavailable
                    if (
                        points3d is not None
                        and 0 <= p[1] < points3d.shape[0]
                        and 0 <= p[0] < points3d.shape[1]
                    ):
                        z = points3d[p[1], p[0], 2]
                        depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
        except Exception as e:
            print(f"Depth text error: {e}")
            depth_text = f"Error computing depth: {e}\n"
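        # Once two clicks are recorded, the distance reported below is the
        # straight-line Euclidean norm between the two 3D points,
        # sqrt(dx^2 + dy^2 + dz^2), in meters, since the reconstruction's
        # world points are metric.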
        if len(measure_points) == 2:
            try:
                point1, point2 = measure_points

                # Draw the line between the two points
                if (
                    0 <= point1[0] < image.shape[1]
                    and 0 <= point1[1] < image.shape[0]
                    and 0 <= point2[0] < image.shape[1]
                    and 0 <= point2[1] < image.shape[0]
                ):
                    image = cv2.line(
                        image, point1, point2, color=(255, 0, 0), thickness=2
                    )

                # Compute the 3D distance
                distance_text = "- **Distance: Unable to compute**"
                if (
                    points3d is not None
                    and 0 <= point1[1] < points3d.shape[0]
                    and 0 <= point1[0] < points3d.shape[1]
                    and 0 <= point2[1] < points3d.shape[0]
                    and 0 <= point2[0] < points3d.shape[1]
                ):
                    try:
                        p1_3d = points3d[point1[1], point1[0]]
                        p2_3d = points3d[point2[1], point2[0]]
                        distance = np.linalg.norm(p1_3d - p2_3d)
                        distance_text = f"- **Distance: {distance:.2f}m**"
                    except Exception as e:
                        print(f"Distance computation error: {e}")
                        distance_text = f"- **Distance computation error: {e}**"

                measure_points = []
                text = depth_text + distance_text
                print(f"Measurement complete: {text}")
                return [image, measure_points, text]
            except Exception as e:
                print(f"Final measurement error: {e}")
                return None, [], f"Measurement error: {e}"
        else:
            print(f"Single point measurement: {depth_text}")
            return [image, measure_points, depth_text]

    except Exception as e:
        print(f"Overall measure function error: {e}")
        return None, [], f"Measure function error: {e}"


def clear_fields():
    """
    Clear the 3D viewer before a new reconstruction run.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "Loading and Reconstructing..."


def update_visualization(
    target_dir,
    frame_filter,
    show_cam,
    is_example,
    filter_black_bg=False,
    filter_white_bg=False,
    show_mesh=True,
):
    """
    Reload the saved predictions from npz, create (or reuse) the GLB for the
    new parameters, and return it for the 3D viewer. If is_example == "True", skip.
    """
    # If it's an example click, skip as requested
    if is_example == "True":
        return (
            gr.update(),
            "No reconstruction available. Please click the Reconstruct button first.",
        )

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return (
            gr.update(),
            "No reconstruction available. Please click the Reconstruct button first.",
        )

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return (
            gr.update(),
            f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
        )

    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )

    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        glbscene.export(file_obj=glbfile)

    return (
        glbfile,
        "Visualization updated.",
    )


def update_all_views_on_filter_change(
    target_dir,
    filter_black_bg,
    filter_white_bg,
    processed_data,
    depth_view_selector,
    normal_view_selector,
    measure_view_selector,
):
    """
    Update all individual view tabs when the background-filtering checkboxes
    change. This regenerates the processed data with the new filtering and
    updates all views.
    """
    # Check that we have a valid target directory and predictions
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return processed_data, None, None, None, []

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return processed_data, None, None, None, []

    try:
        # Load the original predictions and views
        loaded = np.load(predictions_path, allow_pickle=True)
        predictions = {key: loaded[key] for key in loaded.keys()}

        # Load images using MapAnything's load_images function
        image_folder_path = os.path.join(target_dir, "images")
        views = load_images(image_folder_path)

        # Regenerate the processed data with the new filtering settings
        new_processed_data = process_predictions_for_visualization(
            predictions, views, high_level_config, filter_black_bg, filter_white_bg
        )

        # Get the current view indices
        try:
            depth_view_idx = (
                int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
            )
        except (AttributeError, IndexError, ValueError):
            depth_view_idx = 0
        try:
            normal_view_idx = (
                int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
            )
        except (AttributeError, IndexError, ValueError):
            normal_view_idx = 0
        try:
            measure_view_idx = (
                int(measure_view_selector.split()[1]) - 1
                if measure_view_selector
                else 0
            )
        except (AttributeError, IndexError, ValueError):
            measure_view_idx = 0

        # Update all views with the newly filtered data
        depth_vis = update_depth_view(new_processed_data, depth_view_idx)
        normal_vis = update_normal_view(new_processed_data, normal_view_idx)
        measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)

        return new_processed_data, depth_vis, normal_vis, measure_img, []
    except Exception as e:
        print(f"Error updating views on filter change: {e}")
        return processed_data, None, None, None, []


# -------------------------------------------------------------------------
# Example scene functions
# -------------------------------------------------------------------------
def get_scene_info(examples_dir):
    """Get information about the scenes in the examples directory."""
    import glob

    scenes = []
    if not os.path.exists(examples_dir):
        return scenes

    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if os.path.isdir(scene_path):
            # Find all image files in the scene folder
            image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
            image_files = []
            for ext in image_extensions:
                image_files.extend(glob.glob(os.path.join(scene_path, ext)))
                image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))

            if image_files:
                # Sort the images and use the first one as the thumbnail
                image_files = sorted(image_files)
                first_image = image_files[0]
                num_images = len(image_files)

                scenes.append(
                    {
                        "name": scene_folder,
                        "path": scene_path,
                        "thumbnail": first_image,
                        "num_images": num_images,
                        "image_files": image_files,
                    }
                )

    return scenes

def load_example_scene(scene_name, examples_dir="examples"):
    """Load a scene from the examples directory."""
    scenes = get_scene_info(examples_dir)

    # Find the selected scene
    selected_scene = None
    for scene in scenes:
        if scene["name"] == scene_name:
            selected_scene = scene
            break

    if selected_scene is None:
        return None, None, None, "Scene not found"

    # Pass the scene's image paths through the unified upload pipeline
    file_objects = []
    for image_path in selected_scene["image_files"]:
        file_objects.append(image_path)

    # Create the target directory and copy the images
    target_dir, image_paths = handle_uploads(file_objects, 1.0)

    return (
        None,  # Clear the reconstruction output
        target_dir,  # Set the target directory
        image_paths,  # Set the gallery
        f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.",
    )


# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = get_gradio_theme()

with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
    # State variables for the tabbed interface
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    processed_data_state = gr.State(value=None)
    measure_points_state = gr.State(value=[])
    # Track the current view index for navigation
    current_view_index = gr.State(value=0)

    gr.HTML(get_header_html(get_logo_base64()))
    gr.HTML(get_description_html())

    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        with gr.Column(scale=2):
            # Unified upload component for both videos and images
            unified_upload = gr.File(
                file_count="multiple",
                label="Upload Video or Images",
                interactive=True,
                file_types=["image", "video"],
            )

            with gr.Row():
                s_time_interval = gr.Slider(
                    minimum=0.1,
                    maximum=5.0,
                    value=1.0,
                    step=0.1,
                    label="Video sample time interval (take a sample every x sec.)",
                    interactive=True,
                    visible=True,
                    scale=3,
                )
                resample_btn = gr.Button(
                    "Resample Video",
                    visible=False,
                    variant="secondary",
                    scale=1,
                )

            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

            clear_uploads_btn = gr.ClearButton(
                [unified_upload, image_gallery],
                value="Clear Uploads",
                variant="secondary",
                size="sm",
            )

        with gr.Column(scale=4):
            with gr.Column():
                gr.Markdown(
                    "**Metric 3D Reconstruction (Point Cloud and Camera Poses)**"
                )
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Reconstruct.",
                    elem_classes=["custom-log"],
                )

                # Tabbed interface, similar to MoGe
                with gr.Tabs():
                    with gr.Tab("3D View"):
                        reconstruction_output = gr.Model3D(
                            height=520,
                            zoom_speed=0.5,
                            pan_speed=0.5,
                            clear_color=[0.0, 0.0, 0.0, 0.0],
                            key="persistent_3d_viewer",
                            elem_id="reconstruction_3d_viewer",
                        )

                    with gr.Tab("Depth"):
                        with gr.Row(elem_classes=["navigation-row"]):
                            prev_depth_btn = gr.Button(
                                "◀ Previous", size="sm", scale=1
                            )
                            depth_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="Select View",
                                scale=2,
                                interactive=True,
                                allow_custom_value=True,
                            )
                            next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
                        depth_map = gr.Image(
                            type="numpy",
                            label="Colorized Depth Map",
                            format="png",
                            interactive=False,
                        )

                    with gr.Tab("Normal"):
                        with gr.Row(elem_classes=["navigation-row"]):
                            prev_normal_btn = gr.Button(
                                "◀ Previous", size="sm", scale=1
                            )
                            normal_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="Select View",
                                scale=2,
                                interactive=True,
                                allow_custom_value=True,
                            )
                            next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
                        normal_map = gr.Image(
                            type="numpy",
                            label="Normal Map",
                            format="png",
                            interactive=False,
                        )

                    with gr.Tab("Measure"):
                        gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
                        with gr.Row(elem_classes=["navigation-row"]):
                            prev_measure_btn = gr.Button(
                                "◀ Previous", size="sm", scale=1
                            )
                            measure_view_selector = gr.Dropdown(
                                choices=["View 1"],
                                value="View 1",
                                label="Select View",
                                scale=2,
                                interactive=True,
                                allow_custom_value=True,
                            )
                            next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
                        measure_image = gr.Image(
                            type="numpy",
                            show_label=False,
                            format="webp",
                            interactive=False,
                            sources=[],
                        )
                        gr.Markdown(
                            "**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken."
                        )
                        measure_text = gr.Markdown("")

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
        clear_btn = gr.ClearButton(
            [
                unified_upload,
                reconstruction_output,
                log_output,
                target_dir_output,
                image_gallery,
            ],
            scale=1,
        )

    with gr.Row():
        frame_filter = gr.Dropdown(
            choices=["All"], value="All", label="Show Points from Frame"
        )
        with gr.Column():
            gr.Markdown("### Pointcloud Options: (live updates)")
            show_cam = gr.Checkbox(label="Show Camera", value=True)
            show_mesh = gr.Checkbox(label="Show Mesh", value=True)
            filter_black_bg = gr.Checkbox(
                label="Filter Black Background", value=False
            )
            filter_white_bg = gr.Checkbox(
                label="Filter White Background", value=False
            )
            gr.Markdown("### Reconstruction Options: (updated on next run)")
            apply_mask_checkbox = gr.Checkbox(
                label="Apply mask for predicted ambiguous depth classes & edges",
                value=True,
            )

    # ---------------------- Example Scenes Section ----------------------
    gr.Markdown("## Example Scenes (lists all scenes in the examples folder)")
    gr.Markdown("Click any thumbnail to load the scene for reconstruction.")

    # Get the scene information
    scenes = get_scene_info("examples")

    # Create a thumbnail grid (4 columns, N rows)
    if scenes:
        for i in range(0, len(scenes), 4):  # Process 4 scenes per row
            with gr.Row():
                for j in range(4):
                    scene_idx = i + j
                    if scene_idx < len(scenes):
                        scene = scenes[scene_idx]
                        with gr.Column(
                            scale=1, elem_classes=["clickable-thumbnail"]
                        ):
                            # Clickable thumbnail
                            scene_img = gr.Image(
                                value=scene["thumbnail"],
                                height=150,
                                interactive=False,
                                show_label=False,
                                elem_id=f"scene_thumb_{scene['name']}",
                                sources=[],
                            )
                            # Scene name and image count below the thumbnail
                            gr.Markdown(
                                f"**{scene['name']}** \n {scene['num_images']} images",
                                elem_classes=["scene-info"],
                            )
                            # Connect the thumbnail click to load the scene
                            scene_img.select(
                                fn=lambda name=scene["name"]: load_example_scene(name),
                                outputs=[
                                    reconstruction_output,
                                    target_dir_output,
                                    image_gallery,
                                    log_output,
                                ],
                            )
                    else:
                        # Empty column to maintain the grid structure
                        with gr.Column(scale=1):
                            pass
    # -------------------------------------------------------------------------
    # "Reconstruct" button logic:
    # - Clear fields
    # - Update log
    # - gradio_demo(...) with the existing target_dir
    # - Then set is_example = "False"
    # -------------------------------------------------------------------------
    submit_btn.click(
        fn=clear_fields, inputs=[], outputs=[reconstruction_output]
    ).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_demo,
        inputs=[
            target_dir_output,
            frame_filter,
            show_cam,
            filter_black_bg,
            filter_white_bg,
            apply_mask_checkbox,
            show_mesh,
        ],
        outputs=[
            reconstruction_output,
            log_output,
            frame_filter,
            processed_data_state,
            depth_map,
            normal_map,
            measure_image,
            measure_text,
            depth_view_selector,
            normal_view_selector,
            measure_view_selector,
        ],
    ).then(
        fn=lambda: "False",
        inputs=[],
        outputs=[is_example],  # set is_example to "False"
    )

    # -------------------------------------------------------------------------
    # Real-time visualization updates. Every handler passes the full parameter
    # set so that toggling one option does not reset the others to defaults.
    # -------------------------------------------------------------------------
    frame_filter.change(
        update_visualization,
        [
            target_dir_output,
            frame_filter,
            show_cam,
            is_example,
            filter_black_bg,
            filter_white_bg,
            show_mesh,
        ],
        [reconstruction_output, log_output],
    )
    show_cam.change(
        update_visualization,
        [
            target_dir_output,
            frame_filter,
            show_cam,
            is_example,
            filter_black_bg,
            filter_white_bg,
            show_mesh,
        ],
        [reconstruction_output, log_output],
    )
    filter_black_bg.change(
        update_visualization,
        [
            target_dir_output,
            frame_filter,
            show_cam,
            is_example,
            filter_black_bg,
            filter_white_bg,
            show_mesh,
        ],
        [reconstruction_output, log_output],
    ).then(
        fn=update_all_views_on_filter_change,
        inputs=[
            target_dir_output,
            filter_black_bg,
            filter_white_bg,
            processed_data_state,
            depth_view_selector,
            normal_view_selector,
            measure_view_selector,
        ],
        outputs=[
            processed_data_state,
            depth_map,
            normal_map,
            measure_image,
            measure_points_state,
        ],
    )
    filter_white_bg.change(
        update_visualization,
        [
            target_dir_output,
            frame_filter,
            show_cam,
            is_example,
            filter_black_bg,
            filter_white_bg,
            show_mesh,
        ],
        [reconstruction_output, log_output],
    ).then(
        fn=update_all_views_on_filter_change,
        inputs=[
            target_dir_output,
            filter_black_bg,
            filter_white_bg,
            processed_data_state,
            depth_view_selector,
            normal_view_selector,
            measure_view_selector,
        ],
        outputs=[
            processed_data_state,
            depth_map,
            normal_map,
            measure_image,
            measure_points_state,
        ],
    )
    show_mesh.change(
        update_visualization,
        [
            target_dir_output,
            frame_filter,
            show_cam,
            is_example,
            filter_black_bg,
            filter_white_bg,
            show_mesh,
        ],
        [reconstruction_output, log_output],
    )

    # -------------------------------------------------------------------------
    # Auto-update the gallery whenever the user uploads or changes files
    # -------------------------------------------------------------------------
    def update_gallery_on_unified_upload(files, interval):
        if not files:
            return None, None, None
        target_dir, image_paths = handle_uploads(files, interval)
        return (
            target_dir,
            image_paths,
            "Upload complete. Click 'Reconstruct' to begin 3D processing.",
        )

    def show_resample_button(files):
        """Show the resample button only if the uploaded files include a video."""
        if not files:
            return gr.update(visible=False)

        # Check whether any uploaded files are videos
        video_extensions = [
            ".mp4",
            ".avi",
            ".mov",
            ".mkv",
            ".wmv",
            ".flv",
            ".webm",
            ".m4v",
            ".3gp",
        ]
        has_video = False
        for file_data in files:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = str(file_data)
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext in video_extensions:
                has_video = True
                break

        return gr.update(visible=has_video)

    def hide_resample_button():
        """Hide the resample button after use."""
        return gr.update(visible=False)

    def resample_video_with_new_interval(files, new_interval, current_target_dir):
        """Resample the uploaded video(s) with the new slider value."""
        if not files:
            return (
                current_target_dir,
                None,
                "No files to resample.",
                gr.update(visible=False),
            )

        # Check whether there are videos to resample
        video_extensions = [
            ".mp4",
            ".avi",
            ".mov",
            ".mkv",
            ".wmv",
            ".flv",
            ".webm",
            ".m4v",
            ".3gp",
        ]
        has_video = any(
            os.path.splitext(
                str(file_data["name"] if isinstance(file_data, dict) else file_data)
            )[1].lower()
            in video_extensions
            for file_data in files
        )

        if not has_video:
            return (
                current_target_dir,
                None,
                "No videos found to resample.",
                gr.update(visible=False),
            )

        # Clean up the old target directory if it exists
        if (
            current_target_dir
            and current_target_dir != "None"
            and os.path.exists(current_target_dir)
        ):
            shutil.rmtree(current_target_dir)

        # Process the files with the new interval
        target_dir, image_paths = handle_uploads(files, new_interval)

        return (
            target_dir,
            image_paths,
            f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
            gr.update(visible=False),
        )

    unified_upload.change(
        fn=update_gallery_on_unified_upload,
        inputs=[unified_upload, s_time_interval],
        outputs=[target_dir_output, image_gallery, log_output],
    ).then(
        fn=show_resample_button,
        inputs=[unified_upload],
        outputs=[resample_btn],
    )

    # Show the resample button when the slider changes (only if files are uploaded)
    s_time_interval.change(
        fn=show_resample_button,
        inputs=[unified_upload],
        outputs=[resample_btn],
    )

    # Handle the resample button click
    resample_btn.click(
        fn=resample_video_with_new_interval,
        inputs=[unified_upload, s_time_interval, target_dir_output],
        outputs=[target_dir_output, image_gallery, log_output, resample_btn],
    )

    # -------------------------------------------------------------------------
    # Measure tab functionality
    # -------------------------------------------------------------------------
    measure_image.select(
        fn=measure,
        inputs=[processed_data_state, measure_points_state, measure_view_selector],
        outputs=[measure_image, measure_points_state, measure_text],
    )

    # -------------------------------------------------------------------------
    # Navigation functionality for the Depth, Normal, and Measure tabs
    # -------------------------------------------------------------------------
    # Depth tab navigation
    prev_depth_btn.click(
        fn=lambda processed_data, current_selector: navigate_depth_view(
            processed_data, current_selector, -1
        ),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map],
    )
    next_depth_btn.click(
        fn=lambda processed_data, current_selector: navigate_depth_view(
            processed_data, current_selector, 1
        ),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map],
    )
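    # The three *_view_selector.change handlers below mirror the navigate_*
    # helpers: they parse the 1-based "View N" dropdown string back to a
    # 0-based index before re-rendering the corresponding tab.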
    depth_view_selector.change(
        fn=lambda processed_data, selector_value: (
            update_depth_view(
                processed_data,
                int(selector_value.split()[1]) - 1,
            )
            if selector_value
            else None
        ),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_map],
    )

    # Normal tab navigation
    prev_normal_btn.click(
        fn=lambda processed_data, current_selector: navigate_normal_view(
            processed_data, current_selector, -1
        ),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map],
    )
    next_normal_btn.click(
        fn=lambda processed_data, current_selector: navigate_normal_view(
            processed_data, current_selector, 1
        ),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map],
    )
    normal_view_selector.change(
        fn=lambda processed_data, selector_value: (
            update_normal_view(
                processed_data,
                int(selector_value.split()[1]) - 1,
            )
            if selector_value
            else None
        ),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_map],
    )

    # Measure tab navigation
    prev_measure_btn.click(
        fn=lambda processed_data, current_selector: navigate_measure_view(
            processed_data, current_selector, -1
        ),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state],
    )
    next_measure_btn.click(
        fn=lambda processed_data, current_selector: navigate_measure_view(
            processed_data, current_selector, 1
        ),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state],
    )
    measure_view_selector.change(
        fn=lambda processed_data, selector_value: (
            update_measure_view(processed_data, int(selector_value.split()[1]) - 1)
            if selector_value
            else (None, [])
        ),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_image, measure_points_state],
    )

    # -------------------------------------------------------------------------
    # Acknowledgement section
    # -------------------------------------------------------------------------
    gr.HTML(get_acknowledgements_html())

demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)