# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # conda activate hf3.10 import gc import os import shutil import sys import time from datetime import datetime import cv2 import gradio as gr import numpy as np import spaces import torch sys.path.append("mapanything/") from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals from mapanything.utils.hf_utils.css_and_html import ( GRADIO_CSS, MEASURE_INSTRUCTIONS_HTML, get_acknowledgements_html, get_description_html, get_gradio_theme, get_header_html, ) from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model from mapanything.utils.hf_utils.visual_util import predictions_to_glb from mapanything.utils.image import load_images, rgb def get_logo_base64(): """Convert WAI logo to base64 for embedding in HTML""" import base64 logo_path = "examples/wai_logo/wai_logo.png" try: with open(logo_path, "rb") as img_file: img_data = img_file.read() base64_str = base64.b64encode(img_data).decode() return f"data:image/png;base64,{base64_str}" except FileNotFoundError: return None # MapAnything Configuration high_level_config = { "path": "configs/train.yaml", "hf_model_name": "facebook/map-anything", "model_str": "mapanything", "config_overrides": [ "machine=aws", "model=mapanything", "model/task=images_only", "model.encoder.uses_torch_hub=false", ], "checkpoint_name": "model.safetensors", "config_name": "config.json", "trained_with_amp": True, "trained_with_amp_dtype": "bf16", "data_norm_type": "dinov2", "patch_size": 14, "resolution": 518, } # Initialize model - this will be done on GPU when needed model = None # ------------------------------------------------------------------------- # 1) Core model inference # ------------------------------------------------------------------------- @spaces.GPU(duration=120) def run_model( target_dir, apply_mask=True, mask_edges=True, filter_black_bg=False, filter_white_bg=False, ): """ Run the MapAnything model on images in the 'target_dir/images' folder and return predictions. """ global model import torch # Ensure torch is available in function scope print(f"Processing images from {target_dir}") # Device check device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) # Initialize model if not already done if model is None: model = initialize_mapanything_model(high_level_config, device) else: model = model.to(device) model.eval() # Load images using MapAnything's load_images function print("Loading images...") image_folder_path = os.path.join(target_dir, "images") views = load_images(image_folder_path) print(f"Loaded {len(views)} images") if len(views) == 0: raise ValueError("No images found. Check your upload.") # Run model inference print("Running inference...") # apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True. # mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True. # Use checkbox values outputs = model.infer(views, apply_mask=apply_mask, mask_edges=mask_edges) # Convert predictions to format expected by visualization predictions = {} # Initialize lists for the required keys extrinsic_list = [] intrinsic_list = [] world_points_list = [] depth_maps_list = [] images_list = [] final_mask_list = [] # Loop through the outputs for pred in outputs: # Extract data from predictions depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W) intrinsics_torch = pred["intrinsics"][0] # (3, 3) camera_pose_torch = pred["camera_poses"][0] # (4, 4) # Compute new pts3d using depth, intrinsics, and camera pose pts3d_computed, valid_mask = depthmap_to_world_frame( depthmap_torch, intrinsics_torch, camera_pose_torch ) # Convert to numpy arrays for visualization # Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch if "mask" in pred: mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool) else: # Fill with boolean trues in the size of depthmap_torch mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool) # Combine with valid depth mask mask = mask & valid_mask.cpu().numpy() image = pred["img_no_norm"][0].cpu().numpy() # Append to lists extrinsic_list.append(camera_pose_torch.cpu().numpy()) intrinsic_list.append(intrinsics_torch.cpu().numpy()) world_points_list.append(pts3d_computed.cpu().numpy()) depth_maps_list.append(depthmap_torch.cpu().numpy()) images_list.append(image) # Add image to list final_mask_list.append(mask) # Add final_mask to list # Convert lists to numpy arrays with required shapes # extrinsic: (S, 3, 4) - batch of camera extrinsic matrices predictions["extrinsic"] = np.stack(extrinsic_list, axis=0) # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices predictions["intrinsic"] = np.stack(intrinsic_list, axis=0) # world_points: (S, H, W, 3) - batch of 3D world points predictions["world_points"] = np.stack(world_points_list, axis=0) # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps depth_maps = np.stack(depth_maps_list, axis=0) # Add channel dimension if needed to match (S, H, W, 1) format if len(depth_maps.shape) == 3: depth_maps = depth_maps[..., np.newaxis] predictions["depth"] = depth_maps # images: (S, H, W, 3) - batch of input images predictions["images"] = np.stack(images_list, axis=0) # final_mask: (S, H, W) - batch of final masks for filtering predictions["final_mask"] = np.stack(final_mask_list, axis=0) # Process data for visualization tabs (depth, normal, measure) processed_data = process_predictions_for_visualization( predictions, views, high_level_config, filter_black_bg, filter_white_bg ) # Clean up torch.cuda.empty_cache() return predictions, processed_data def update_view_selectors(processed_data): """Update view selector dropdowns based on available views""" if processed_data is None or len(processed_data) == 0: choices = ["View 1"] else: num_views = len(processed_data) choices = [f"View {i + 1}" for i in range(num_views)] return ( gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector ) def get_view_data_by_index(processed_data, view_index): """Get view data by index, handling bounds""" if processed_data is None or len(processed_data) == 0: return None view_keys = list(processed_data.keys()) if view_index < 0 or view_index >= len(view_keys): view_index = 0 return processed_data[view_keys[view_index]] def update_depth_view(processed_data, view_index): """Update depth view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["depth"] is None: return None return colorize_depth(view_data["depth"], mask=view_data.get("mask")) def update_normal_view(processed_data, view_index): """Update normal view for a specific view index""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None or view_data["normal"] is None: return None return colorize_normal(view_data["normal"], mask=view_data.get("mask")) def update_measure_view(processed_data, view_index): """Update measure view for a specific view index with mask overlay""" view_data = get_view_data_by_index(processed_data, view_index) if view_data is None: return None, [] # image, measure_points # Get the base image image = view_data["image"].copy() # Ensure image is in uint8 format if image.dtype != np.uint8: if image.max() <= 1.0: image = (image * 255).astype(np.uint8) else: image = image.astype(np.uint8) # Apply mask overlay if mask is available if view_data["mask"] is not None: mask = view_data["mask"] # Create light grey overlay for masked areas # Masked areas (False values) will be overlaid with light grey invalid_mask = ~mask # Areas where mask is False if invalid_mask.any(): # Create a light grey overlay (RGB: 192, 192, 192) overlay_color = np.array([255, 220, 220], dtype=np.uint8) # Apply overlay with some transparency alpha = 0.5 # Transparency level for c in range(3): # RGB channels image[:, :, c] = np.where( invalid_mask, (1 - alpha) * image[:, :, c] + alpha * overlay_color[c], image[:, :, c], ).astype(np.uint8) return image, [] def navigate_depth_view(processed_data, current_selector_value, direction): """Navigate depth view (direction: -1 for previous, +1 for next)""" if processed_data is None or len(processed_data) == 0: return "View 1", None # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" depth_vis = update_depth_view(processed_data, new_view) return new_selector_value, depth_vis def navigate_normal_view(processed_data, current_selector_value, direction): """Navigate normal view (direction: -1 for previous, +1 for next)""" if processed_data is None or len(processed_data) == 0: return "View 1", None # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" normal_vis = update_normal_view(processed_data, new_view) return new_selector_value, normal_vis def navigate_measure_view(processed_data, current_selector_value, direction): """Navigate measure view (direction: -1 for previous, +1 for next)""" if processed_data is None or len(processed_data) == 0: return "View 1", None, [] # Parse current view number try: current_view = int(current_selector_value.split()[1]) - 1 except: current_view = 0 num_views = len(processed_data) new_view = (current_view + direction) % num_views new_selector_value = f"View {new_view + 1}" measure_image, measure_points = update_measure_view(processed_data, new_view) return new_selector_value, measure_image, measure_points def populate_visualization_tabs(processed_data): """Populate the depth, normal, and measure tabs with processed data""" if processed_data is None or len(processed_data) == 0: return None, None, None, [] # Use update functions to ensure confidence filtering is applied from the start depth_vis = update_depth_view(processed_data, 0) normal_vis = update_normal_view(processed_data, 0) measure_img, _ = update_measure_view(processed_data, 0) return depth_vis, normal_vis, measure_img, [] # ------------------------------------------------------------------------- # 2) Handle uploaded video/images --> produce target_dir + images # ------------------------------------------------------------------------- def handle_uploads(input_video, input_images, s_time_interval=1.0): """ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded images or extracted frames from video into it. Return (target_dir, image_paths). """ start_time = time.time() gc.collect() torch.cuda.empty_cache() # Create a unique folder name timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") target_dir = f"input_images_{timestamp}" target_dir_images = os.path.join(target_dir, "images") # Clean up if somehow that folder already exists if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir) os.makedirs(target_dir_images) image_paths = [] # --- Handle images --- if input_images is not None: for file_data in input_images: if isinstance(file_data, dict) and "name" in file_data: file_path = file_data["name"] else: file_path = file_data dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) shutil.copy(file_path, dst_path) image_paths.append(dst_path) # --- Handle video --- if input_video is not None: if isinstance(input_video, dict) and "name" in input_video: video_path = input_video["name"] else: video_path = input_video vs = cv2.VideoCapture(video_path) fps = vs.get(cv2.CAP_PROP_FPS) frame_interval = int(fps * s_time_interval) # 1 frame/sec count = 0 video_frame_num = 0 while True: gotit, frame = vs.read() if not gotit: break count += 1 if count % frame_interval == 0: image_path = os.path.join( target_dir_images, f"{video_frame_num:06}.png" ) cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 # Sort final images for gallery image_paths = sorted(image_paths) end_time = time.time() print( f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds" ) return target_dir, image_paths # ------------------------------------------------------------------------- # 3) Update gallery on upload # ------------------------------------------------------------------------- def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0): """ Whenever user uploads or changes files, immediately handle them and show in the gallery. Return (target_dir, image_paths). If nothing is uploaded, returns "None" and empty list. """ if not input_video and not input_images: return None, None, None, None target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval) return ( None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing.", ) # ------------------------------------------------------------------------- # 4) Reconstruction: uses the target_dir plus any viz parameters # ------------------------------------------------------------------------- @spaces.GPU(duration=120) def gradio_demo( target_dir, frame_filter="All", show_cam=True, filter_black_bg=False, filter_white_bg=False, apply_mask=True, mask_edges=True, show_mesh=True, ): """ Perform reconstruction using the already-created target_dir/images. """ if not os.path.isdir(target_dir) or target_dir == "None": return None, "No valid target directory found. Please upload first.", None, None start_time = time.time() gc.collect() torch.cuda.empty_cache() # Prepare frame_filter dropdown target_dir_images = os.path.join(target_dir, "images") all_files = ( sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] ) all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)] frame_filter_choices = ["All"] + all_files print("Running MapAnything model...") with torch.no_grad(): predictions, processed_data = run_model(target_dir, apply_mask, mask_edges) # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions) # Handle None frame_filter if frame_filter is None: frame_filter = "All" # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb", ) # Convert predictions to GLB glbscene = predictions_to_glb( predictions, filter_by_frames=frame_filter, show_cam=show_cam, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, as_mesh=show_mesh, # Use the show_mesh parameter ) glbscene.export(file_obj=glbfile) # Cleanup del predictions gc.collect() torch.cuda.empty_cache() end_time = time.time() print(f"Total time: {end_time - start_time:.2f} seconds") log_msg = ( f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." ) # Populate visualization tabs with processed data depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs( processed_data ) # Update view selectors based on available views depth_selector, normal_selector, measure_selector = update_view_selectors( processed_data ) return ( glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), processed_data, depth_vis, normal_vis, measure_img, "", # measure_text (empty initially) depth_selector, normal_selector, measure_selector, ) # ------------------------------------------------------------------------- # 5) Helper functions for UI resets + re-visualization # ------------------------------------------------------------------------- def colorize_depth(depth_map, mask=None): """Convert depth map to colorized visualization with optional mask""" if depth_map is None: return None # Normalize depth to 0-1 range depth_normalized = depth_map.copy() valid_mask = depth_normalized > 0 # Apply additional mask if provided (for background filtering) if mask is not None: valid_mask = valid_mask & mask if valid_mask.sum() > 0: valid_depths = depth_normalized[valid_mask] p5 = np.percentile(valid_depths, 5) p95 = np.percentile(valid_depths, 95) depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5) # Apply colormap import matplotlib.pyplot as plt colormap = plt.cm.turbo_r colored = colormap(depth_normalized) colored = (colored[:, :, :3] * 255).astype(np.uint8) # Set invalid pixels to white colored[~valid_mask] = [255, 255, 255] return colored def colorize_normal(normal_map, mask=None): """Convert normal map to colorized visualization with optional mask""" if normal_map is None: return None # Create a copy for modification normal_vis = normal_map.copy() # Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization) if mask is not None: invalid_mask = ~mask normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero # Normalize normals to [0, 1] range for visualization normal_vis = (normal_vis + 1.0) / 2.0 normal_vis = (normal_vis * 255).astype(np.uint8) return normal_vis def process_predictions_for_visualization( predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False ): """Extract depth, normal, and 3D points from predictions for visualization""" processed_data = {} # Process each view for view_idx, view in enumerate(views): # Get image image = rgb(view["img"], norm_type=high_level_config["data_norm_type"]) # Get predicted points pred_pts3d = predictions["world_points"][view_idx] # Initialize data for this view view_data = { "image": image[0], "points3d": pred_pts3d, "depth": None, "normal": None, "mask": None, } # Start with the final mask from predictions mask = predictions["final_mask"][view_idx].copy() # Apply black background filtering if enabled if filter_black_bg: # Get the image colors (ensure they're in 0-255 range) view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] # Filter out black background pixels (sum of RGB < 16) black_bg_mask = view_colors.sum(axis=2) >= 16 mask = mask & black_bg_mask # Apply white background filtering if enabled if filter_white_bg: # Get the image colors (ensure they're in 0-255 range) view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0] # Filter out white background pixels (all RGB > 240) white_bg_mask = ~( (view_colors[:, :, 0] > 240) & (view_colors[:, :, 1] > 240) & (view_colors[:, :, 2] > 240) ) mask = mask & white_bg_mask view_data["mask"] = mask view_data["depth"] = predictions["depth"][view_idx].squeeze() normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"]) view_data["normal"] = normals processed_data[view_idx] = view_data return processed_data def reset_measure(processed_data): """Reset measure points""" if processed_data is None or len(processed_data) == 0: return None, [], "" # Return the first view image first_view = list(processed_data.values())[0] return first_view["image"], [], "" def measure( processed_data, measure_points, current_view_selector, event: gr.SelectData ): """Handle measurement on images""" try: print(f"Measure function called with selector: {current_view_selector}") if processed_data is None or len(processed_data) == 0: return None, [], "No data available" # Use the currently selected view instead of always using the first view try: current_view_index = int(current_view_selector.split()[1]) - 1 except: current_view_index = 0 print(f"Using view index: {current_view_index}") # Get view data safely if current_view_index < 0 or current_view_index >= len(processed_data): current_view_index = 0 view_keys = list(processed_data.keys()) current_view = processed_data[view_keys[current_view_index]] if current_view is None: return None, [], "No view data available" point2d = event.index[0], event.index[1] print(f"Clicked point: {point2d}") # Check if the clicked point is in a masked area (prevent interaction) if ( current_view["mask"] is not None and 0 <= point2d[1] < current_view["mask"].shape[0] and 0 <= point2d[0] < current_view["mask"].shape[1] ): # Check if the point is in a masked (invalid) area if not current_view["mask"][point2d[1], point2d[0]]: print(f"Clicked point {point2d} is in masked area, ignoring click") # Always return image with mask overlay masked_image, _ = update_measure_view( processed_data, current_view_index ) return ( masked_image, measure_points, 'Cannot measure on masked areas (shown in grey)', ) measure_points.append(point2d) # Get image with mask overlay and ensure it's valid image, _ = update_measure_view(processed_data, current_view_index) if image is None: return None, [], "No image available" image = image.copy() points3d = current_view["points3d"] # Ensure image is in uint8 format for proper cv2 operations try: if image.dtype != np.uint8: if image.max() <= 1.0: # Image is in [0, 1] range, convert to [0, 255] image = (image * 255).astype(np.uint8) else: # Image is already in [0, 255] range image = image.astype(np.uint8) except Exception as e: print(f"Image conversion error: {e}") return None, [], f"Image conversion error: {e}" # Draw circles for points try: for p in measure_points: if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: image = cv2.circle( image, p, radius=5, color=(255, 0, 0), thickness=2 ) except Exception as e: print(f"Drawing error: {e}") return None, [], f"Drawing error: {e}" depth_text = "" try: for i, p in enumerate(measure_points): if ( current_view["depth"] is not None and 0 <= p[1] < current_view["depth"].shape[0] and 0 <= p[0] < current_view["depth"].shape[1] ): d = current_view["depth"][p[1], p[0]] depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n" else: # Use Z coordinate of 3D points if depth not available if ( points3d is not None and 0 <= p[1] < points3d.shape[0] and 0 <= p[0] < points3d.shape[1] ): z = points3d[p[1], p[0], 2] depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n" except Exception as e: print(f"Depth text error: {e}") depth_text = f"Error computing depth: {e}\n" if len(measure_points) == 2: try: point1, point2 = measure_points # Draw line if ( 0 <= point1[0] < image.shape[1] and 0 <= point1[1] < image.shape[0] and 0 <= point2[0] < image.shape[1] and 0 <= point2[1] < image.shape[0] ): image = cv2.line( image, point1, point2, color=(255, 0, 0), thickness=2 ) # Compute 3D distance distance_text = "- **Distance: Unable to compute**" if ( points3d is not None and 0 <= point1[1] < points3d.shape[0] and 0 <= point1[0] < points3d.shape[1] and 0 <= point2[1] < points3d.shape[0] and 0 <= point2[0] < points3d.shape[1] ): try: p1_3d = points3d[point1[1], point1[0]] p2_3d = points3d[point2[1], point2[0]] distance = np.linalg.norm(p1_3d - p2_3d) distance_text = f"- **Distance: {distance:.2f}m**" except Exception as e: print(f"Distance computation error: {e}") distance_text = f"- **Distance computation error: {e}**" measure_points = [] text = depth_text + distance_text print(f"Measurement complete: {text}") return [image, measure_points, text] except Exception as e: print(f"Final measurement error: {e}") return None, [], f"Measurement error: {e}" else: print(f"Single point measurement: {depth_text}") return [image, measure_points, depth_text] except Exception as e: print(f"Overall measure function error: {e}") return None, [], f"Measure function error: {e}" def clear_fields(): """ Clears the 3D viewer, the stored target_dir, and empties the gallery. """ return None def update_log(): """ Display a quick log message while waiting. """ return "Loading and Reconstructing..." def update_visualization( target_dir, frame_filter, show_cam, is_example, filter_black_bg=False, filter_white_bg=False, show_mesh=True, ): """ Reload saved predictions from npz, create (or reuse) the GLB for new parameters, and return it for the 3D viewer. If is_example == "True", skip. """ # If it's an example click, skip as requested if is_example == "True": return ( gr.update(), "No reconstruction available. Please click the Reconstruct button first.", ) if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return ( gr.update(), "No reconstruction available. Please click the Reconstruct button first.", ) predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return ( gr.update(), f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.", ) loaded = np.load(predictions_path, allow_pickle=True) predictions = {key: loaded[key] for key in loaded.keys()} glbfile = os.path.join( target_dir, f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb", ) if not os.path.exists(glbfile): glbscene = predictions_to_glb( predictions, filter_by_frames=frame_filter, show_cam=show_cam, mask_black_bg=filter_black_bg, mask_white_bg=filter_white_bg, as_mesh=show_mesh, ) glbscene.export(file_obj=glbfile) return ( glbfile, "Visualization updated.", ) def update_all_views_on_filter_change( target_dir, filter_black_bg, filter_white_bg, processed_data, depth_view_selector, normal_view_selector, measure_view_selector, ): """ Update all individual view tabs when background filtering checkboxes change. This regenerates the processed data with new filtering and updates all views. """ # Check if we have a valid target directory and predictions if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return processed_data, None, None, None, [] predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return processed_data, None, None, None, [] try: # Load the original predictions and views loaded = np.load(predictions_path, allow_pickle=True) predictions = {key: loaded[key] for key in loaded.keys()} # Load images using MapAnything's load_images function image_folder_path = os.path.join(target_dir, "images") views = load_images(image_folder_path) # Regenerate processed data with new filtering settings new_processed_data = process_predictions_for_visualization( predictions, views, high_level_config, filter_black_bg, filter_white_bg ) # Get current view indices try: depth_view_idx = ( int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0 ) except: depth_view_idx = 0 try: normal_view_idx = ( int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0 ) except: normal_view_idx = 0 try: measure_view_idx = ( int(measure_view_selector.split()[1]) - 1 if measure_view_selector else 0 ) except: measure_view_idx = 0 # Update all views with new filtered data depth_vis = update_depth_view(new_processed_data, depth_view_idx) normal_vis = update_normal_view(new_processed_data, normal_view_idx) measure_img, _ = update_measure_view(new_processed_data, measure_view_idx) return new_processed_data, depth_vis, normal_vis, measure_img, [] except Exception as e: print(f"Error updating views on filter change: {e}") return processed_data, None, None, None, [] # ------------------------------------------------------------------------- # Example scene functions # ------------------------------------------------------------------------- def get_scene_info(examples_dir): """Get information about scenes in the examples directory""" import glob scenes = [] if not os.path.exists(examples_dir): return scenes for scene_folder in sorted(os.listdir(examples_dir)): scene_path = os.path.join(examples_dir, scene_folder) if os.path.isdir(scene_path): # Find all image files in the scene folder image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"] image_files = [] for ext in image_extensions: image_files.extend(glob.glob(os.path.join(scene_path, ext))) image_files.extend(glob.glob(os.path.join(scene_path, ext.upper()))) if image_files: # Sort images and get the first one for thumbnail image_files = sorted(image_files) first_image = image_files[0] num_images = len(image_files) scenes.append( { "name": scene_folder, "path": scene_path, "thumbnail": first_image, "num_images": num_images, "image_files": image_files, } ) return scenes def load_example_scene(scene_name, examples_dir="examples"): """Load a scene from examples directory""" scenes = get_scene_info(examples_dir) # Find the selected scene selected_scene = None for scene in scenes: if scene["name"] == scene_name: selected_scene = scene break if selected_scene is None: return None, None, None, "Scene not found" # Create target directory and copy images target_dir, image_paths = handle_uploads(None, selected_scene["image_files"]) return ( None, # Clear reconstruction output target_dir, # Set target directory image_paths, # Set gallery f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.", ) # ------------------------------------------------------------------------- # 6) Build Gradio UI # ------------------------------------------------------------------------- theme = get_gradio_theme() with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo: # State variables for the tabbed interface is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") processed_data_state = gr.State(value=None) measure_points_state = gr.State(value=[]) current_view_index = gr.State(value=0) # Track current view index for navigation gr.HTML(get_header_html(get_logo_base64())) gr.HTML(get_description_html()) target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") with gr.Row(): with gr.Column(scale=2): input_video = gr.Video(label="Upload Video", interactive=True) s_time_interval = gr.Slider( minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Sample time interval (take a sample every x sec.)", interactive=True, visible=True, ) input_images = gr.File( file_count="multiple", label="Upload Images", interactive=True ) image_gallery = gr.Gallery( label="Preview", columns=4, height="300px", show_download_button=True, object_fit="contain", preview=True, ) with gr.Column(scale=4): with gr.Column(): gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**") log_output = gr.Markdown( "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"], ) # Add tabbed interface similar to MoGe with gr.Tabs(): with gr.Tab("3D View"): reconstruction_output = gr.Model3D( height=520, zoom_speed=0.5, pan_speed=0.5, clear_color=[0.0, 0.0, 0.0, 0.0], key="persistent_3d_viewer", elem_id="reconstruction_3d_viewer", ) with gr.Tab("Depth"): with gr.Row(elem_classes=["navigation-row"]): prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1) depth_view_selector = gr.Dropdown( choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True, ) next_depth_btn = gr.Button("Next ▶", size="sm", scale=1) depth_map = gr.Image( type="numpy", label="Colorized Depth Map", format="png", interactive=False, ) with gr.Tab("Normal"): with gr.Row(elem_classes=["navigation-row"]): prev_normal_btn = gr.Button( "◀ Previous", size="sm", scale=1 ) normal_view_selector = gr.Dropdown( choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True, ) next_normal_btn = gr.Button("Next ▶", size="sm", scale=1) normal_map = gr.Image( type="numpy", label="Normal Map", format="png", interactive=False, ) with gr.Tab("Measure"): gr.Markdown(MEASURE_INSTRUCTIONS_HTML) with gr.Row(elem_classes=["navigation-row"]): prev_measure_btn = gr.Button( "◀ Previous", size="sm", scale=1 ) measure_view_selector = gr.Dropdown( choices=["View 1"], value="View 1", label="Select View", scale=2, interactive=True, allow_custom_value=True, ) next_measure_btn = gr.Button("Next ▶", size="sm", scale=1) measure_image = gr.Image( type="numpy", show_label=False, format="webp", interactive=False, sources=[], ) gr.Markdown( "**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken." ) measure_text = gr.Markdown("") with gr.Row(): submit_btn = gr.Button("Reconstruct", scale=1, variant="primary") clear_btn = gr.ClearButton( [ input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, ], scale=1, ) with gr.Row(): frame_filter = gr.Dropdown( choices=["All"], value="All", label="Show Points from Frame" ) with gr.Column(): gr.Markdown("### Pointcloud Options: (live updates)") show_cam = gr.Checkbox(label="Show Camera", value=True) show_mesh = gr.Checkbox(label="Show Mesh", value=True) filter_black_bg = gr.Checkbox( label="Filter Black Background", value=False ) filter_white_bg = gr.Checkbox( label="Filter White Background", value=False ) gr.Markdown("### Reconstruction Options: (updated on next run)") apply_mask_checkbox = gr.Checkbox( label="Apply mask for predicted ambiguous depth classes & edges", value=True ) # ---------------------- Example Scenes Section ---------------------- gr.Markdown("## Example Scenes (lists all scenes in the examples folder)") gr.Markdown("Click any thumbnail to load the scene for reconstruction.") # Get scene information scenes = get_scene_info("examples") # Create thumbnail grid (4 columns, N rows) if scenes: for i in range(0, len(scenes), 4): # Process 4 scenes per row with gr.Row(): for j in range(4): scene_idx = i + j if scene_idx < len(scenes): scene = scenes[scene_idx] with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]): # Clickable thumbnail scene_img = gr.Image( value=scene["thumbnail"], height=150, interactive=False, show_label=False, elem_id=f"scene_thumb_{scene['name']}", sources=[], ) # Scene name and image count as text below thumbnail gr.Markdown( f"**{scene['name']}** \n {scene['num_images']} images", elem_classes=["scene-info"], ) # Connect thumbnail click to load scene scene_img.select( fn=lambda name=scene["name"]: load_example_scene(name), outputs=[ reconstruction_output, target_dir_output, image_gallery, log_output, ], ) else: # Empty column to maintain grid structure with gr.Column(scale=1): pass # ------------------------------------------------------------------------- # "Reconstruct" button logic: # - Clear fields # - Update log # - gradio_demo(...) with the existing target_dir # - Then set is_example = "False" # ------------------------------------------------------------------------- submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then( fn=update_log, inputs=[], outputs=[log_output] ).then( fn=gradio_demo, inputs=[ target_dir_output, frame_filter, show_cam, filter_black_bg, filter_white_bg, apply_mask_checkbox, show_mesh, ], outputs=[ reconstruction_output, log_output, frame_filter, processed_data_state, depth_map, normal_map, measure_image, measure_text, depth_view_selector, normal_view_selector, measure_view_selector, ], ).then( fn=lambda: "False", inputs=[], outputs=[is_example], # set is_example to "False" ) # ------------------------------------------------------------------------- # Real-time Visualization Updates # ------------------------------------------------------------------------- frame_filter.change( update_visualization, [ target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh, ], [reconstruction_output, log_output], ) show_cam.change( update_visualization, [ target_dir_output, frame_filter, show_cam, is_example, ], [reconstruction_output, log_output], ) filter_black_bg.change( update_visualization, [ target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, ], [reconstruction_output, log_output], ).then( fn=update_all_views_on_filter_change, inputs=[ target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector, ], outputs=[ processed_data_state, depth_map, normal_map, measure_image, measure_points_state, ], ) filter_white_bg.change( update_visualization, [ target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh, ], [reconstruction_output, log_output], ).then( fn=update_all_views_on_filter_change, inputs=[ target_dir_output, filter_black_bg, filter_white_bg, processed_data_state, depth_view_selector, normal_view_selector, measure_view_selector, ], outputs=[ processed_data_state, depth_map, normal_map, measure_image, measure_points_state, ], ) show_mesh.change( update_visualization, [ target_dir_output, frame_filter, show_cam, is_example, filter_black_bg, filter_white_bg, show_mesh, ], [reconstruction_output, log_output], ) # ------------------------------------------------------------------------- # Auto-update gallery whenever user uploads or changes their files # ------------------------------------------------------------------------- input_video.change( fn=update_gallery_on_upload, inputs=[input_video, input_images, s_time_interval], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) input_images.change( fn=update_gallery_on_upload, inputs=[input_video, input_images, s_time_interval], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) # ------------------------------------------------------------------------- # Measure tab functionality # ------------------------------------------------------------------------- measure_image.select( fn=measure, inputs=[processed_data_state, measure_points_state, measure_view_selector], outputs=[measure_image, measure_points_state, measure_text], ) # ------------------------------------------------------------------------- # Navigation functionality for Depth, Normal, and Measure tabs # ------------------------------------------------------------------------- # Depth tab navigation prev_depth_btn.click( fn=lambda processed_data, current_selector: navigate_depth_view( processed_data, current_selector, -1 ), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map], ) next_depth_btn.click( fn=lambda processed_data, current_selector: navigate_depth_view( processed_data, current_selector, 1 ), inputs=[processed_data_state, depth_view_selector], outputs=[depth_view_selector, depth_map], ) depth_view_selector.change( fn=lambda processed_data, selector_value: ( update_depth_view( processed_data, int(selector_value.split()[1]) - 1, ) if selector_value else None ), inputs=[processed_data_state, depth_view_selector], outputs=[depth_map], ) # Normal tab navigation prev_normal_btn.click( fn=lambda processed_data, current_selector: navigate_normal_view( processed_data, current_selector, -1 ), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map], ) next_normal_btn.click( fn=lambda processed_data, current_selector: navigate_normal_view( processed_data, current_selector, 1 ), inputs=[processed_data_state, normal_view_selector], outputs=[normal_view_selector, normal_map], ) normal_view_selector.change( fn=lambda processed_data, selector_value: ( update_normal_view( processed_data, int(selector_value.split()[1]) - 1, ) if selector_value else None ), inputs=[processed_data_state, normal_view_selector], outputs=[normal_map], ) # Measure tab navigation prev_measure_btn.click( fn=lambda processed_data, current_selector: navigate_measure_view( processed_data, current_selector, -1 ), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state], ) next_measure_btn.click( fn=lambda processed_data, current_selector: navigate_measure_view( processed_data, current_selector, 1 ), inputs=[processed_data_state, measure_view_selector], outputs=[measure_view_selector, measure_image, measure_points_state], ) measure_view_selector.change( fn=lambda processed_data, selector_value: ( update_measure_view(processed_data, int(selector_value.split()[1]) - 1) if selector_value else (None, []) ), inputs=[processed_data_state, measure_view_selector], outputs=[measure_image, measure_points_state], ) # ------------------------------------------------------------------------- # Acknowledgement section # ------------------------------------------------------------------------- gr.HTML(get_acknowledgements_html()) demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)