map-anything / app.py
aknapitsch user
common video and image upload field
478d9e0
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# conda activate hf3.10
import gc
import os
import shutil
import sys
import time
from datetime import datetime
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener
register_heif_opener()
sys.path.append("mapanything/")
from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
GRADIO_CSS,
MEASURE_INSTRUCTIONS_HTML,
get_acknowledgements_html,
get_description_html,
get_gradio_theme,
get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.visual_util import predictions_to_glb
from mapanything.utils.image import load_images, rgb
def get_logo_base64():
"""Convert WAI logo to base64 for embedding in HTML"""
import base64
logo_path = "examples/WAI-Logo/wai_logo.png"
try:
with open(logo_path, "rb") as img_file:
img_data = img_file.read()
base64_str = base64.b64encode(img_data).decode()
return f"data:image/png;base64,{base64_str}"
except FileNotFoundError:
return None
# MapAnything Configuration
high_level_config = {
"path": "configs/train.yaml",
"hf_model_name": "facebook/map-anything",
"model_str": "mapanything",
"config_overrides": [
"machine=aws",
"model=mapanything",
"model/task=images_only",
"model.encoder.uses_torch_hub=false",
],
"checkpoint_name": "model.safetensors",
"config_name": "config.json",
"trained_with_amp": True,
"trained_with_amp_dtype": "bf16",
"data_norm_type": "dinov2",
"patch_size": 14,
"resolution": 518,
}
# Initialize model - this will be done on GPU when needed
model = None
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_model(
target_dir,
apply_mask=True,
mask_edges=True,
filter_black_bg=False,
filter_white_bg=False,
):
"""
Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
"""
global model
import torch # Ensure torch is available in function scope
print(f"Processing images from {target_dir}")
# Device check
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
# Initialize model if not already done
if model is None:
model = initialize_mapanything_model(high_level_config, device)
else:
model = model.to(device)
model.eval()
# Load images using MapAnything's load_images function
print("Loading images...")
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
print(f"Loaded {len(views)} images")
if len(views) == 0:
raise ValueError("No images found. Check your upload.")
# Run model inference
print("Running inference...")
# apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
# mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
# Use checkbox values - mask_edges is set to True by default since there's no UI control for it
outputs = model.infer(
views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
)
# Convert predictions to format expected by visualization
predictions = {}
# Initialize lists for the required keys
extrinsic_list = []
intrinsic_list = []
world_points_list = []
depth_maps_list = []
images_list = []
final_mask_list = []
# Loop through the outputs
for pred in outputs:
# Extract data from predictions
depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
intrinsics_torch = pred["intrinsics"][0] # (3, 3)
camera_pose_torch = pred["camera_poses"][0] # (4, 4)
# Compute new pts3d using depth, intrinsics, and camera pose
pts3d_computed, valid_mask = depthmap_to_world_frame(
depthmap_torch, intrinsics_torch, camera_pose_torch
)
# Convert to numpy arrays for visualization
# Check if mask key exists in pred, if not, fill with boolean trues in the size of depthmap_torch
if "mask" in pred:
mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
else:
# Fill with boolean trues in the size of depthmap_torch
mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
# Combine with valid depth mask
mask = mask & valid_mask.cpu().numpy()
image = pred["img_no_norm"][0].cpu().numpy()
# Append to lists
extrinsic_list.append(camera_pose_torch.cpu().numpy())
intrinsic_list.append(intrinsics_torch.cpu().numpy())
world_points_list.append(pts3d_computed.cpu().numpy())
depth_maps_list.append(depthmap_torch.cpu().numpy())
images_list.append(image) # Add image to list
final_mask_list.append(mask) # Add final_mask to list
# Convert lists to numpy arrays with required shapes
# extrinsic: (S, 3, 4) - batch of camera extrinsic matrices
predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
# intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
# world_points: (S, H, W, 3) - batch of 3D world points
predictions["world_points"] = np.stack(world_points_list, axis=0)
# depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
depth_maps = np.stack(depth_maps_list, axis=0)
# Add channel dimension if needed to match (S, H, W, 1) format
if len(depth_maps.shape) == 3:
depth_maps = depth_maps[..., np.newaxis]
predictions["depth"] = depth_maps
# images: (S, H, W, 3) - batch of input images
predictions["images"] = np.stack(images_list, axis=0)
# final_mask: (S, H, W) - batch of final masks for filtering
predictions["final_mask"] = np.stack(final_mask_list, axis=0)
# Process data for visualization tabs (depth, normal, measure)
processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
# Clean up
torch.cuda.empty_cache()
return predictions, processed_data
def update_view_selectors(processed_data):
"""Update view selector dropdowns based on available views"""
if processed_data is None or len(processed_data) == 0:
choices = ["View 1"]
else:
num_views = len(processed_data)
choices = [f"View {i + 1}" for i in range(num_views)]
return (
gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
)
def get_view_data_by_index(processed_data, view_index):
"""Get view data by index, handling bounds"""
if processed_data is None or len(processed_data) == 0:
return None
view_keys = list(processed_data.keys())
if view_index < 0 or view_index >= len(view_keys):
view_index = 0
return processed_data[view_keys[view_index]]
def update_depth_view(processed_data, view_index):
"""Update depth view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["depth"] is None:
return None
return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
def update_normal_view(processed_data, view_index):
"""Update normal view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["normal"] is None:
return None
return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
def update_measure_view(processed_data, view_index):
"""Update measure view for a specific view index with mask overlay"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None:
return None, [] # image, measure_points
# Get the base image
image = view_data["image"].copy()
# Ensure image is in uint8 format
if image.dtype != np.uint8:
if image.max() <= 1.0:
image = (image * 255).astype(np.uint8)
else:
image = image.astype(np.uint8)
# Apply mask overlay if mask is available
if view_data["mask"] is not None:
mask = view_data["mask"]
# Create light grey overlay for masked areas
# Masked areas (False values) will be overlaid with light grey
invalid_mask = ~mask # Areas where mask is False
if invalid_mask.any():
# Create a light grey overlay (RGB: 192, 192, 192)
overlay_color = np.array([255, 220, 220], dtype=np.uint8)
# Apply overlay with some transparency
alpha = 0.5 # Transparency level
for c in range(3): # RGB channels
image[:, :, c] = np.where(
invalid_mask,
(1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
image[:, :, c],
).astype(np.uint8)
return image, []
def navigate_depth_view(processed_data, current_selector_value, direction):
"""Navigate depth view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
except:
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
depth_vis = update_depth_view(processed_data, new_view)
return new_selector_value, depth_vis
def navigate_normal_view(processed_data, current_selector_value, direction):
"""Navigate normal view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
except:
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
normal_vis = update_normal_view(processed_data, new_view)
return new_selector_value, normal_vis
def navigate_measure_view(processed_data, current_selector_value, direction):
"""Navigate measure view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None, []
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
except:
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
measure_image, measure_points = update_measure_view(processed_data, new_view)
return new_selector_value, measure_image, measure_points
def populate_visualization_tabs(processed_data):
"""Populate the depth, normal, and measure tabs with processed data"""
if processed_data is None or len(processed_data) == 0:
return None, None, None, []
# Use update functions to ensure confidence filtering is applied from the start
depth_vis = update_depth_view(processed_data, 0)
normal_vis = update_normal_view(processed_data, 0)
measure_img, _ = update_measure_view(processed_data, 0)
return depth_vis, normal_vis, measure_img, []
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
"""
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
images or extracted frames from video into it. Return (target_dir, image_paths).
"""
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
# Create a unique folder name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
target_dir = f"input_images_{timestamp}"
target_dir_images = os.path.join(target_dir, "images")
# Clean up if somehow that folder already exists
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir)
os.makedirs(target_dir_images)
image_paths = []
# --- Handle uploaded files (both images and videos) ---
if unified_upload is not None:
for file_data in unified_upload:
if isinstance(file_data, dict) and "name" in file_data:
file_path = file_data["name"]
else:
file_path = str(file_data)
file_ext = os.path.splitext(file_path)[1].lower()
# Check if it's a video file
video_extensions = [
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
".m4v",
".3gp",
]
if file_ext in video_extensions:
# Handle as video
vs = cv2.VideoCapture(file_path)
fps = vs.get(cv2.CAP_PROP_FPS)
frame_interval = int(fps * s_time_interval) # frames per interval
count = 0
video_frame_num = 0
while True:
gotit, frame = vs.read()
if not gotit:
break
count += 1
if count % frame_interval == 0:
# Use original filename as prefix for frames
base_name = os.path.splitext(os.path.basename(file_path))[0]
image_path = os.path.join(
target_dir_images, f"{base_name}_{video_frame_num:06}.png"
)
cv2.imwrite(image_path, frame)
image_paths.append(image_path)
video_frame_num += 1
vs.release()
print(
f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
)
else:
# Handle as image
# Check if the file is a HEIC image
if file_ext in [".heic", ".heif"]:
# Convert HEIC to JPEG for better gallery compatibility
try:
with Image.open(file_path) as img:
# Convert to RGB if necessary (HEIC can have different color modes)
if img.mode not in ("RGB", "L"):
img = img.convert("RGB")
# Create JPEG filename
base_name = os.path.splitext(os.path.basename(file_path))[0]
dst_path = os.path.join(
target_dir_images, f"{base_name}.jpg"
)
# Save as JPEG with high quality
img.save(dst_path, "JPEG", quality=95)
image_paths.append(dst_path)
print(
f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
)
except Exception as e:
print(f"Error converting HEIC file {file_path}: {e}")
# Fall back to copying as is
dst_path = os.path.join(
target_dir_images, os.path.basename(file_path)
)
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
else:
# Regular image files - copy as is
dst_path = os.path.join(
target_dir_images, os.path.basename(file_path)
)
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
# Sort final images for gallery
image_paths = sorted(image_paths)
end_time = time.time()
print(
f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
)
return target_dir, image_paths
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
"""
Whenever user uploads or changes files, immediately handle them
and show in the gallery. Return (target_dir, image_paths).
If nothing is uploaded, returns "None" and empty list.
"""
if not input_video and not input_images:
return None, None, None, None
target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
return (
None,
target_dir,
image_paths,
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
)
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def gradio_demo(
target_dir,
frame_filter="All",
show_cam=True,
filter_black_bg=False,
filter_white_bg=False,
apply_mask=True,
show_mesh=True,
):
"""
Perform reconstruction using the already-created target_dir/images.
"""
if not os.path.isdir(target_dir) or target_dir == "None":
return None, "No valid target directory found. Please upload first.", None, None
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
# Prepare frame_filter dropdown
target_dir_images = os.path.join(target_dir, "images")
all_files = (
sorted(os.listdir(target_dir_images))
if os.path.isdir(target_dir_images)
else []
)
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
frame_filter_choices = ["All"] + all_files
print("Running MapAnything model...")
with torch.no_grad():
predictions, processed_data = run_model(target_dir, apply_mask)
# Save predictions
prediction_save_path = os.path.join(target_dir, "predictions.npz")
np.savez(prediction_save_path, **predictions)
# Handle None frame_filter
if frame_filter is None:
frame_filter = "All"
# Build a GLB file name
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
)
# Convert predictions to GLB
glbscene = predictions_to_glb(
predictions,
filter_by_frames=frame_filter,
show_cam=show_cam,
mask_black_bg=filter_black_bg,
mask_white_bg=filter_white_bg,
as_mesh=show_mesh, # Use the show_mesh parameter
)
glbscene.export(file_obj=glbfile)
# Cleanup
del predictions
gc.collect()
torch.cuda.empty_cache()
end_time = time.time()
print(f"Total time: {end_time - start_time:.2f} seconds")
log_msg = (
f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
)
# Populate visualization tabs with processed data
depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
processed_data
)
# Update view selectors based on available views
depth_selector, normal_selector, measure_selector = update_view_selectors(
processed_data
)
return (
glbfile,
log_msg,
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
processed_data,
depth_vis,
normal_vis,
measure_img,
"", # measure_text (empty initially)
depth_selector,
normal_selector,
measure_selector,
)
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def colorize_depth(depth_map, mask=None):
"""Convert depth map to colorized visualization with optional mask"""
if depth_map is None:
return None
# Normalize depth to 0-1 range
depth_normalized = depth_map.copy()
valid_mask = depth_normalized > 0
# Apply additional mask if provided (for background filtering)
if mask is not None:
valid_mask = valid_mask & mask
if valid_mask.sum() > 0:
valid_depths = depth_normalized[valid_mask]
p5 = np.percentile(valid_depths, 5)
p95 = np.percentile(valid_depths, 95)
depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
# Apply colormap
import matplotlib.pyplot as plt
colormap = plt.cm.turbo_r
colored = colormap(depth_normalized)
colored = (colored[:, :, :3] * 255).astype(np.uint8)
# Set invalid pixels to white
colored[~valid_mask] = [255, 255, 255]
return colored
def colorize_normal(normal_map, mask=None):
"""Convert normal map to colorized visualization with optional mask"""
if normal_map is None:
return None
# Create a copy for modification
normal_vis = normal_map.copy()
# Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization)
if mask is not None:
invalid_mask = ~mask
normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero
# Normalize normals to [0, 1] range for visualization
normal_vis = (normal_vis + 1.0) / 2.0
normal_vis = (normal_vis * 255).astype(np.uint8)
return normal_vis
def process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
"""Extract depth, normal, and 3D points from predictions for visualization"""
processed_data = {}
# Process each view
for view_idx, view in enumerate(views):
# Get image
image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
# Get predicted points
pred_pts3d = predictions["world_points"][view_idx]
# Initialize data for this view
view_data = {
"image": image[0],
"points3d": pred_pts3d,
"depth": None,
"normal": None,
"mask": None,
}
# Start with the final mask from predictions
mask = predictions["final_mask"][view_idx].copy()
# Apply black background filtering if enabled
if filter_black_bg:
# Get the image colors (ensure they're in 0-255 range)
view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
# Filter out black background pixels (sum of RGB < 16)
black_bg_mask = view_colors.sum(axis=2) >= 16
mask = mask & black_bg_mask
# Apply white background filtering if enabled
if filter_white_bg:
# Get the image colors (ensure they're in 0-255 range)
view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
# Filter out white background pixels (all RGB > 240)
white_bg_mask = ~(
(view_colors[:, :, 0] > 240)
& (view_colors[:, :, 1] > 240)
& (view_colors[:, :, 2] > 240)
)
mask = mask & white_bg_mask
view_data["mask"] = mask
view_data["depth"] = predictions["depth"][view_idx].squeeze()
normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
view_data["normal"] = normals
processed_data[view_idx] = view_data
return processed_data
def reset_measure(processed_data):
"""Reset measure points"""
if processed_data is None or len(processed_data) == 0:
return None, [], ""
# Return the first view image
first_view = list(processed_data.values())[0]
return first_view["image"], [], ""
def measure(
processed_data, measure_points, current_view_selector, event: gr.SelectData
):
"""Handle measurement on images"""
try:
print(f"Measure function called with selector: {current_view_selector}")
if processed_data is None or len(processed_data) == 0:
return None, [], "No data available"
# Use the currently selected view instead of always using the first view
try:
current_view_index = int(current_view_selector.split()[1]) - 1
except:
current_view_index = 0
print(f"Using view index: {current_view_index}")
# Get view data safely
if current_view_index < 0 or current_view_index >= len(processed_data):
current_view_index = 0
view_keys = list(processed_data.keys())
current_view = processed_data[view_keys[current_view_index]]
if current_view is None:
return None, [], "No view data available"
point2d = event.index[0], event.index[1]
print(f"Clicked point: {point2d}")
# Check if the clicked point is in a masked area (prevent interaction)
if (
current_view["mask"] is not None
and 0 <= point2d[1] < current_view["mask"].shape[0]
and 0 <= point2d[0] < current_view["mask"].shape[1]
):
# Check if the point is in a masked (invalid) area
if not current_view["mask"][point2d[1], point2d[0]]:
print(f"Clicked point {point2d} is in masked area, ignoring click")
# Always return image with mask overlay
masked_image, _ = update_measure_view(
processed_data, current_view_index
)
return (
masked_image,
measure_points,
'<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in grey)</span>',
)
measure_points.append(point2d)
# Get image with mask overlay and ensure it's valid
image, _ = update_measure_view(processed_data, current_view_index)
if image is None:
return None, [], "No image available"
image = image.copy()
points3d = current_view["points3d"]
# Ensure image is in uint8 format for proper cv2 operations
try:
if image.dtype != np.uint8:
if image.max() <= 1.0:
# Image is in [0, 1] range, convert to [0, 255]
image = (image * 255).astype(np.uint8)
else:
# Image is already in [0, 255] range
image = image.astype(np.uint8)
except Exception as e:
print(f"Image conversion error: {e}")
return None, [], f"Image conversion error: {e}"
# Draw circles for points
try:
for p in measure_points:
if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
image = cv2.circle(
image, p, radius=5, color=(255, 0, 0), thickness=2
)
except Exception as e:
print(f"Drawing error: {e}")
return None, [], f"Drawing error: {e}"
depth_text = ""
try:
for i, p in enumerate(measure_points):
if (
current_view["depth"] is not None
and 0 <= p[1] < current_view["depth"].shape[0]
and 0 <= p[0] < current_view["depth"].shape[1]
):
d = current_view["depth"][p[1], p[0]]
depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
else:
# Use Z coordinate of 3D points if depth not available
if (
points3d is not None
and 0 <= p[1] < points3d.shape[0]
and 0 <= p[0] < points3d.shape[1]
):
z = points3d[p[1], p[0], 2]
depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
except Exception as e:
print(f"Depth text error: {e}")
depth_text = f"Error computing depth: {e}\n"
if len(measure_points) == 2:
try:
point1, point2 = measure_points
# Draw line
if (
0 <= point1[0] < image.shape[1]
and 0 <= point1[1] < image.shape[0]
and 0 <= point2[0] < image.shape[1]
and 0 <= point2[1] < image.shape[0]
):
image = cv2.line(
image, point1, point2, color=(255, 0, 0), thickness=2
)
# Compute 3D distance
distance_text = "- **Distance: Unable to compute**"
if (
points3d is not None
and 0 <= point1[1] < points3d.shape[0]
and 0 <= point1[0] < points3d.shape[1]
and 0 <= point2[1] < points3d.shape[0]
and 0 <= point2[0] < points3d.shape[1]
):
try:
p1_3d = points3d[point1[1], point1[0]]
p2_3d = points3d[point2[1], point2[0]]
distance = np.linalg.norm(p1_3d - p2_3d)
distance_text = f"- **Distance: {distance:.2f}m**"
except Exception as e:
print(f"Distance computation error: {e}")
distance_text = f"- **Distance computation error: {e}**"
measure_points = []
text = depth_text + distance_text
print(f"Measurement complete: {text}")
return [image, measure_points, text]
except Exception as e:
print(f"Final measurement error: {e}")
return None, [], f"Measurement error: {e}"
else:
print(f"Single point measurement: {depth_text}")
return [image, measure_points, depth_text]
except Exception as e:
print(f"Overall measure function error: {e}")
return None, [], f"Measure function error: {e}"
def clear_fields():
"""
Clears the 3D viewer, the stored target_dir, and empties the gallery.
"""
return None
def update_log():
"""
Display a quick log message while waiting.
"""
return "Loading and Reconstructing..."
def update_visualization(
target_dir,
frame_filter,
show_cam,
is_example,
filter_black_bg=False,
filter_white_bg=False,
show_mesh=True,
):
"""
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
and return it for the 3D viewer. If is_example == "True", skip.
"""
# If it's an example click, skip as requested
if is_example == "True":
return (
gr.update(),
"No reconstruction available. Please click the Reconstruct button first.",
)
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return (
gr.update(),
"No reconstruction available. Please click the Reconstruct button first.",
)
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
return (
gr.update(),
f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
)
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
)
if not os.path.exists(glbfile):
glbscene = predictions_to_glb(
predictions,
filter_by_frames=frame_filter,
show_cam=show_cam,
mask_black_bg=filter_black_bg,
mask_white_bg=filter_white_bg,
as_mesh=show_mesh,
)
glbscene.export(file_obj=glbfile)
return (
glbfile,
"Visualization updated.",
)
def update_all_views_on_filter_change(
target_dir,
filter_black_bg,
filter_white_bg,
processed_data,
depth_view_selector,
normal_view_selector,
measure_view_selector,
):
"""
Update all individual view tabs when background filtering checkboxes change.
This regenerates the processed data with new filtering and updates all views.
"""
# Check if we have a valid target directory and predictions
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return processed_data, None, None, None, []
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
return processed_data, None, None, None, []
try:
# Load the original predictions and views
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
# Load images using MapAnything's load_images function
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
# Regenerate processed data with new filtering settings
new_processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
# Get current view indices
try:
depth_view_idx = (
int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
)
except:
depth_view_idx = 0
try:
normal_view_idx = (
int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
)
except:
normal_view_idx = 0
try:
measure_view_idx = (
int(measure_view_selector.split()[1]) - 1
if measure_view_selector
else 0
)
except:
measure_view_idx = 0
# Update all views with new filtered data
depth_vis = update_depth_view(new_processed_data, depth_view_idx)
normal_vis = update_normal_view(new_processed_data, normal_view_idx)
measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
return new_processed_data, depth_vis, normal_vis, measure_img, []
except Exception as e:
print(f"Error updating views on filter change: {e}")
return processed_data, None, None, None, []
# -------------------------------------------------------------------------
# Example scene functions
# -------------------------------------------------------------------------
def get_scene_info(examples_dir):
"""Get information about scenes in the examples directory"""
import glob
scenes = []
if not os.path.exists(examples_dir):
return scenes
for scene_folder in sorted(os.listdir(examples_dir)):
scene_path = os.path.join(examples_dir, scene_folder)
if os.path.isdir(scene_path):
# Find all image files in the scene folder
image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
if image_files:
# Sort images and get the first one for thumbnail
image_files = sorted(image_files)
first_image = image_files[0]
num_images = len(image_files)
scenes.append(
{
"name": scene_folder,
"path": scene_path,
"thumbnail": first_image,
"num_images": num_images,
"image_files": image_files,
}
)
return scenes
def load_example_scene(scene_name, examples_dir="examples"):
"""Load a scene from examples directory"""
scenes = get_scene_info(examples_dir)
# Find the selected scene
selected_scene = None
for scene in scenes:
if scene["name"] == scene_name:
selected_scene = scene
break
if selected_scene is None:
return None, None, None, "Scene not found"
# Create file-like objects for the unified upload system
# Convert image file paths to the format expected by unified_upload
file_objects = []
for image_path in selected_scene["image_files"]:
file_objects.append(image_path)
# Create target directory and copy images using the unified upload system
target_dir, image_paths = handle_uploads(file_objects, 1.0)
return (
None, # Clear reconstruction output
target_dir, # Set target directory
image_paths, # Set gallery
f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. Click 'Reconstruct' to begin 3D processing.",
)
# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = get_gradio_theme()
with gr.Blocks(theme=theme, css=GRADIO_CSS) as demo:
# State variables for the tabbed interface
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
processed_data_state = gr.State(value=None)
measure_points_state = gr.State(value=[])
current_view_index = gr.State(value=0) # Track current view index for navigation
gr.HTML(get_header_html(get_logo_base64()))
gr.HTML(get_description_html())
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
with gr.Row():
with gr.Column(scale=2):
# Unified upload component for both videos and images
unified_upload = gr.File(
file_count="multiple",
label="Upload Video or Images",
interactive=True,
file_types=["image", "video"],
)
with gr.Row():
s_time_interval = gr.Slider(
minimum=0.1,
maximum=5.0,
value=1.0,
step=0.1,
label="Video sample time interval (take a sample every x sec.)",
interactive=True,
visible=True,
scale=3,
)
resample_btn = gr.Button(
"Resample Video",
visible=False,
variant="secondary",
scale=1,
)
image_gallery = gr.Gallery(
label="Preview",
columns=4,
height="300px",
show_download_button=True,
object_fit="contain",
preview=True,
)
clear_uploads_btn = gr.ClearButton(
[unified_upload, image_gallery],
value="Clear Uploads",
variant="secondary",
size="sm",
)
with gr.Column(scale=4):
with gr.Column():
gr.Markdown(
"**Metric 3D Reconstruction (Point Cloud and Camera Poses)**"
)
log_output = gr.Markdown(
"Please upload a video or images, then click Reconstruct.",
elem_classes=["custom-log"],
)
# Add tabbed interface similar to MoGe
with gr.Tabs():
with gr.Tab("3D View"):
reconstruction_output = gr.Model3D(
height=520,
zoom_speed=0.5,
pan_speed=0.5,
clear_color=[0.0, 0.0, 0.0, 0.0],
key="persistent_3d_viewer",
elem_id="reconstruction_3d_viewer",
)
with gr.Tab("Depth"):
with gr.Row(elem_classes=["navigation-row"]):
prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
depth_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
depth_map = gr.Image(
type="numpy",
label="Colorized Depth Map",
format="png",
interactive=False,
)
with gr.Tab("Normal"):
with gr.Row(elem_classes=["navigation-row"]):
prev_normal_btn = gr.Button(
"◀ Previous", size="sm", scale=1
)
normal_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_normal_btn = gr.Button("Next ▶", size="sm", scale=1)
normal_map = gr.Image(
type="numpy",
label="Normal Map",
format="png",
interactive=False,
)
with gr.Tab("Measure"):
gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
with gr.Row(elem_classes=["navigation-row"]):
prev_measure_btn = gr.Button(
"◀ Previous", size="sm", scale=1
)
measure_view_selector = gr.Dropdown(
choices=["View 1"],
value="View 1",
label="Select View",
scale=2,
interactive=True,
allow_custom_value=True,
)
next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
measure_image = gr.Image(
type="numpy",
show_label=False,
format="webp",
interactive=False,
sources=[],
)
gr.Markdown(
"**Note:** Light-grey areas indicate regions with no depth information where measurements cannot be taken."
)
measure_text = gr.Markdown("")
with gr.Row():
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
clear_btn = gr.ClearButton(
[
unified_upload,
reconstruction_output,
log_output,
target_dir_output,
image_gallery,
],
scale=1,
)
with gr.Row():
frame_filter = gr.Dropdown(
choices=["All"], value="All", label="Show Points from Frame"
)
with gr.Column():
gr.Markdown("### Pointcloud Options: (live updates)")
show_cam = gr.Checkbox(label="Show Camera", value=True)
show_mesh = gr.Checkbox(label="Show Mesh", value=True)
filter_black_bg = gr.Checkbox(
label="Filter Black Background", value=False
)
filter_white_bg = gr.Checkbox(
label="Filter White Background", value=False
)
gr.Markdown("### Reconstruction Options: (updated on next run)")
apply_mask_checkbox = gr.Checkbox(
label="Apply mask for predicted ambiguous depth classes & edges",
value=True,
)
# ---------------------- Example Scenes Section ----------------------
gr.Markdown("## Example Scenes (lists all scenes in the examples folder)")
gr.Markdown("Click any thumbnail to load the scene for reconstruction.")
# Get scene information
scenes = get_scene_info("examples")
# Create thumbnail grid (4 columns, N rows)
if scenes:
for i in range(0, len(scenes), 4): # Process 4 scenes per row
with gr.Row():
for j in range(4):
scene_idx = i + j
if scene_idx < len(scenes):
scene = scenes[scene_idx]
with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
# Clickable thumbnail
scene_img = gr.Image(
value=scene["thumbnail"],
height=150,
interactive=False,
show_label=False,
elem_id=f"scene_thumb_{scene['name']}",
sources=[],
)
# Scene name and image count as text below thumbnail
gr.Markdown(
f"**{scene['name']}** \n {scene['num_images']} images",
elem_classes=["scene-info"],
)
# Connect thumbnail click to load scene
scene_img.select(
fn=lambda name=scene["name"]: load_example_scene(name),
outputs=[
reconstruction_output,
target_dir_output,
image_gallery,
log_output,
],
)
else:
# Empty column to maintain grid structure
with gr.Column(scale=1):
pass
# -------------------------------------------------------------------------
# "Reconstruct" button logic:
# - Clear fields
# - Update log
# - gradio_demo(...) with the existing target_dir
# - Then set is_example = "False"
# -------------------------------------------------------------------------
submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
fn=update_log, inputs=[], outputs=[log_output]
).then(
fn=gradio_demo,
inputs=[
target_dir_output,
frame_filter,
show_cam,
filter_black_bg,
filter_white_bg,
apply_mask_checkbox,
show_mesh,
],
outputs=[
reconstruction_output,
log_output,
frame_filter,
processed_data_state,
depth_map,
normal_map,
measure_image,
measure_text,
depth_view_selector,
normal_view_selector,
measure_view_selector,
],
).then(
fn=lambda: "False",
inputs=[],
outputs=[is_example], # set is_example to "False"
)
# -------------------------------------------------------------------------
# Real-time Visualization Updates
# -------------------------------------------------------------------------
frame_filter.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,
],
[reconstruction_output, log_output],
)
show_cam.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
],
[reconstruction_output, log_output],
)
filter_black_bg.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
],
[reconstruction_output, log_output],
).then(
fn=update_all_views_on_filter_change,
inputs=[
target_dir_output,
filter_black_bg,
filter_white_bg,
processed_data_state,
depth_view_selector,
normal_view_selector,
measure_view_selector,
],
outputs=[
processed_data_state,
depth_map,
normal_map,
measure_image,
measure_points_state,
],
)
filter_white_bg.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,
],
[reconstruction_output, log_output],
).then(
fn=update_all_views_on_filter_change,
inputs=[
target_dir_output,
filter_black_bg,
filter_white_bg,
processed_data_state,
depth_view_selector,
normal_view_selector,
measure_view_selector,
],
outputs=[
processed_data_state,
depth_map,
normal_map,
measure_image,
measure_points_state,
],
)
show_mesh.change(
update_visualization,
[
target_dir_output,
frame_filter,
show_cam,
is_example,
filter_black_bg,
filter_white_bg,
show_mesh,
],
[reconstruction_output, log_output],
)
# -------------------------------------------------------------------------
# Auto-update gallery whenever user uploads or changes their files
# -------------------------------------------------------------------------
def update_gallery_on_unified_upload(files, interval):
if not files:
return None, None, None
target_dir, image_paths = handle_uploads(files, interval)
return (
target_dir,
image_paths,
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
)
def show_resample_button(files):
"""Show the resample button only if there are uploaded files containing videos"""
if not files:
return gr.update(visible=False)
# Check if any uploaded files are videos
video_extensions = [
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
".m4v",
".3gp",
]
has_video = False
for file_data in files:
if isinstance(file_data, dict) and "name" in file_data:
file_path = file_data["name"]
else:
file_path = str(file_data)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext in video_extensions:
has_video = True
break
return gr.update(visible=has_video)
def hide_resample_button():
"""Hide the resample button after use"""
return gr.update(visible=False)
def resample_video_with_new_interval(files, new_interval, current_target_dir):
"""Resample video with new slider value"""
if not files:
return (
current_target_dir,
None,
"No files to resample.",
gr.update(visible=False),
)
# Check if we have videos to resample
video_extensions = [
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
".m4v",
".3gp",
]
has_video = any(
os.path.splitext(
str(file_data["name"] if isinstance(file_data, dict) else file_data)
)[1].lower()
in video_extensions
for file_data in files
)
if not has_video:
return (
current_target_dir,
None,
"No videos found to resample.",
gr.update(visible=False),
)
# Clean up old target directory if it exists
if (
current_target_dir
and current_target_dir != "None"
and os.path.exists(current_target_dir)
):
shutil.rmtree(current_target_dir)
# Process files with new interval
target_dir, image_paths = handle_uploads(files, new_interval)
return (
target_dir,
image_paths,
f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
gr.update(visible=False),
)
unified_upload.change(
fn=update_gallery_on_unified_upload,
inputs=[unified_upload, s_time_interval],
outputs=[target_dir_output, image_gallery, log_output],
).then(
fn=show_resample_button,
inputs=[unified_upload],
outputs=[resample_btn],
)
# Show resample button when slider changes (only if files are uploaded)
s_time_interval.change(
fn=show_resample_button,
inputs=[unified_upload],
outputs=[resample_btn],
)
# Handle resample button click
resample_btn.click(
fn=resample_video_with_new_interval,
inputs=[unified_upload, s_time_interval, target_dir_output],
outputs=[target_dir_output, image_gallery, log_output, resample_btn],
)
# -------------------------------------------------------------------------
# Measure tab functionality
# -------------------------------------------------------------------------
measure_image.select(
fn=measure,
inputs=[processed_data_state, measure_points_state, measure_view_selector],
outputs=[measure_image, measure_points_state, measure_text],
)
# -------------------------------------------------------------------------
# Navigation functionality for Depth, Normal, and Measure tabs
# -------------------------------------------------------------------------
# Depth tab navigation
prev_depth_btn.click(
fn=lambda processed_data, current_selector: navigate_depth_view(
processed_data, current_selector, -1
),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_view_selector, depth_map],
)
next_depth_btn.click(
fn=lambda processed_data, current_selector: navigate_depth_view(
processed_data, current_selector, 1
),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_view_selector, depth_map],
)
depth_view_selector.change(
fn=lambda processed_data, selector_value: (
update_depth_view(
processed_data,
int(selector_value.split()[1]) - 1,
)
if selector_value
else None
),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_map],
)
# Normal tab navigation
prev_normal_btn.click(
fn=lambda processed_data, current_selector: navigate_normal_view(
processed_data, current_selector, -1
),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_view_selector, normal_map],
)
next_normal_btn.click(
fn=lambda processed_data, current_selector: navigate_normal_view(
processed_data, current_selector, 1
),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_view_selector, normal_map],
)
normal_view_selector.change(
fn=lambda processed_data, selector_value: (
update_normal_view(
processed_data,
int(selector_value.split()[1]) - 1,
)
if selector_value
else None
),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_map],
)
# Measure tab navigation
prev_measure_btn.click(
fn=lambda processed_data, current_selector: navigate_measure_view(
processed_data, current_selector, -1
),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_view_selector, measure_image, measure_points_state],
)
next_measure_btn.click(
fn=lambda processed_data, current_selector: navigate_measure_view(
processed_data, current_selector, 1
),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_view_selector, measure_image, measure_points_state],
)
measure_view_selector.change(
fn=lambda processed_data, selector_value: (
update_measure_view(processed_data, int(selector_value.split()[1]) - 1)
if selector_value
else (None, [])
),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_image, measure_points_state],
)
# -------------------------------------------------------------------------
# Acknowledgement section
# -------------------------------------------------------------------------
gr.HTML(get_acknowledgements_html())
demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)