import gc
import os
import shutil
import time
from datetime import datetime
import io
import sys
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener
register_heif_opener()
from src.utils.inference_utils import load_and_preprocess_images
from src.utils.geometry import (
    depth_edge,
    normals_edge
)
from src.utils.visual_util import (
    convert_predictions_to_glb_scene,
    segment_sky,
    download_file_from_url
)
from src.utils.save_utils import save_camera_params, save_gs_ply, process_ply_to_splat, convert_gs_to_ply
from src.utils.render_utils import render_interpolated_video
import onnxruntime
# Lazily-created WorldMirror model. It stays None until run_model() loads it
# with WorldMirror.from_pretrained(...) on its first call; later calls reuse it.
model = None
# Latest text captured by TeeOutput: overwritten on every TeeOutput.write and
# reset to "" by TeeOutput.clear; returned by get_terminal_output().
current_terminal_output = ""
# Helper class to capture terminal output
class TeeOutput:
    """Capture output while still printing to console"""
    def __init__(self, max_chars=10000):
        self.terminal = sys.stdout
        self.log = io.StringIO()
        self.max_chars = max_chars  # 限制最大字符数
    
    def write(self, message):
        global current_terminal_output
        self.terminal.write(message)
        self.log.write(message)
        
        # 获取当前内容并限制长度
        content = self.log.getvalue()
        if len(content) > self.max_chars:
            # 只保留最后 max_chars 个字符
            content = "...(earlier output truncated)...\n" + content[-self.max_chars:]
            self.log = io.StringIO()
            self.log.write(content)
        
        current_terminal_output = self.log.getvalue()
    
    def flush(self):
        self.terminal.flush()
    
    def getvalue(self):
        return self.log.getvalue()
    
    def clear(self):
        global current_terminal_output
        self.log = io.StringIO()
        current_terminal_output = ""
# -------------------------------------------------------------------------
# Model inference
# -------------------------------------------------------------------------
def _compute_sky_mask(image_paths, height, width):
    """Segment the sky in each source image and return a boolean mask stack.

    Downloads ``skyseg.onnx`` into the working directory on first use.

    Args:
        image_paths: Source image paths, in the same order as the model views.
        height: Height of the preprocessed model input.
        width: Width of the preprocessed model input.

    Returns:
        np.ndarray of shape [S, height, width], True where ``segment_sky``
        returned a positive value.
    """
    if not os.path.exists("skyseg.onnx"):
        print("Downloading skyseg.onnx...")
        download_file_from_url(
            "https://huggingface.co/JianyuanWang/skyseg/resolve/main/skyseg.onnx", "skyseg.onnx"
        )
    skyseg_session = onnxruntime.InferenceSession("skyseg.onnx")
    sky_mask_list = []
    for img_path in image_paths:
        sky_mask = segment_sky(img_path, skyseg_session)
        # Resize mask to match H x W if needed (cv2.resize takes (width, height)).
        if sky_mask.shape[0] != height or sky_mask.shape[1] != width:
            sky_mask = cv2.resize(sky_mask, (width, height))
        sky_mask_list.append(sky_mask)
    return np.stack(sky_mask_list, axis=0) > 0  # [S, H, W]


def _compute_final_mask(
    num_views,
    depth_preds,
    depth_conf,
    normal_preds,
    confidence_percentile,
    edge_normal_threshold,
    edge_depth_threshold,
    apply_confidence_mask,
    apply_edge_mask,
):
    """Combine the confidence and edge filters into one boolean mask per view.

    Returns an [S, H, W] bool array; all True when both filters are disabled.
    """
    if not apply_confidence_mask and not apply_edge_mask:
        return np.ones(depth_conf.shape[:3], dtype=bool)  # [S, H, W]

    final_mask_list = []
    for i in range(num_views):
        final_mask = None
        if apply_confidence_mask:
            # Keep pixels at or above the per-view confidence percentile.
            confidences = depth_conf[i, :, :]  # [H, W]
            percentile_threshold = np.quantile(confidences, confidence_percentile / 100.0)
            final_mask = confidences >= percentile_threshold
        if apply_edge_mask:
            normal_edges = normals_edge(
                normal_preds[i], tol=edge_normal_threshold, mask=final_mask
            )
            depth_edges = depth_edge(
                depth_preds[i, :, :, 0], rtol=edge_depth_threshold, mask=final_mask
            )
            # Only pixels flagged by BOTH edge detectors are removed.
            edge_mask = ~(depth_edges & normal_edges)
            final_mask = edge_mask if final_mask is None else final_mask & edge_mask
        final_mask_list.append(final_mask)
    return np.stack(final_mask_list, axis=0)  # [S, H, W]


@spaces.GPU(duration=120)
def run_model(
    target_dir,
    confidence_percentile: float = 10,
    edge_normal_threshold: float = 5.0,
    edge_depth_threshold: float = 0.03,
    apply_confidence_mask: bool = True,
    apply_edge_mask: bool = True,
):
    """Run WorldMirror on ``target_dir/images`` and package its predictions.

    The model is created once (module-level ``model``) and moved to CUDA when
    available. Images are processed in sorted filename order so that view
    ``i`` matches entry ``i`` of the frame selector that ``gradio_demo`` builds
    from ``sorted(os.listdir(...))``.

    Args:
        target_dir: Directory containing an ``images`` sub-folder.
        confidence_percentile: Per-view percentile of depth confidence below
            which pixels are masked out (when ``apply_confidence_mask``).
        edge_normal_threshold: ``tol`` passed to ``normals_edge``.
        edge_depth_threshold: ``rtol`` passed to ``depth_edge``.
        apply_confidence_mask: Enable the confidence filter.
        apply_edge_mask: Enable the depth/normal edge filter.

    Returns:
        ``(outputs, processed_data)``. ``outputs`` is a dict of NumPy arrays
        (``images``, ``world_points``, ``depth``, ``normal``, ``final_mask``,
        ``sky_mask``, ``camera_poses``, ``camera_intrs``) plus, when the model
        predicts them, a ``splats`` dict holding the raw splat tensors.
        ``processed_data`` is the per-view dict from
        ``prepare_visualization_data``.

    Raises:
        ValueError: If no images are found.
    """
    global model

    from src.models.models.worldmirror import WorldMirror
    from src.models.utils.geometry import depth_to_world_coords_points
    print(f"Processing images from {target_dir}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Initialize model if not already done
    if model is None:
        model = WorldMirror.from_pretrained("tencent/HunyuanWorld-Mirror").to(device)
    else:
        model.to(device)
    model.eval()

    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    # os.listdir order is arbitrary; sort so views line up with the frame
    # selector and with the sky-mask pass below (which reuses this list).
    image_file_paths = [
        os.path.join(image_folder_path, name) for name in sorted(os.listdir(image_folder_path))
    ]
    if not image_file_paths:
        raise ValueError("No images found. Check your upload.")
    img = load_and_preprocess_images(image_file_paths).to(device)
    print(f"Loaded {img.shape[1]} images")
    if img.shape[1] == 0:
        raise ValueError("No images found. Check your upload.")

    print("Running inference...")
    inputs = {"img": img}
    use_amp = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    amp_dtype = torch.bfloat16 if use_amp else torch.float32
    # no_grad here too, so direct callers don't build an autograd graph.
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=bool(use_amp), dtype=amp_dtype):
        predictions = model(inputs)

    imgs = inputs["img"].permute(0, 1, 3, 4, 2)[0].detach().cpu().numpy()  # S H W 3
    depth_preds = predictions["depth"][0].detach().cpu().numpy()  # S H W 1
    depth_conf = predictions["depth_conf"][0].detach().cpu().numpy()  # S H W
    normal_preds = predictions["normals"][0].detach().cpu().numpy()  # S H W 3
    camera_poses = predictions["camera_poses"][0].detach().cpu().numpy()  # [S,4,4]
    camera_intrs = predictions["camera_intrs"][0].detach().cpu().numpy()  # [S,3,3]
    pts3d_preds = depth_to_world_coords_points(
        predictions["depth"][0, ..., 0],
        predictions["camera_poses"][0],
        predictions["camera_intrs"][0],
    )[0]
    pts3d_preds = pts3d_preds.detach().cpu().numpy()  # S H W 3

    sky_mask = _compute_sky_mask(image_file_paths, imgs.shape[1], imgs.shape[2])
    final_mask = _compute_final_mask(
        inputs["img"].shape[1],
        depth_preds,
        depth_conf,
        normal_preds,
        confidence_percentile,
        edge_normal_threshold,
        edge_depth_threshold,
        apply_confidence_mask,
        apply_edge_mask,
    )

    outputs = {
        'images': imgs,
        'world_points': pts3d_preds,
        'depth': depth_preds,
        'normal': normal_preds,
        'final_mask': final_mask,
        'sky_mask': sky_mask,
        'camera_poses': camera_poses,
        'camera_intrs': camera_intrs,
    }
    if "splats" in predictions:
        splats = predictions["splats"]
        splats_dict = {key: splats[key] for key in ("means", "scales", "quats", "opacities")}
        for optional_key in ("sh", "colors"):
            if optional_key in splats:
                splats_dict[optional_key] = splats[optional_key]
        outputs['splats'] = splats_dict

    # Process data for visualization tabs (depth, normal)
    processed_data = prepare_visualization_data(outputs, inputs)
    torch.cuda.empty_cache()
    return outputs, processed_data
# -------------------------------------------------------------------------
# Update and navigation function
# -------------------------------------------------------------------------
def update_view_info(current_view, total_views, view_type="Depth"):
        """Build the navigation caption shown for a depth/normal view.

        Args:
            current_view: Index of the displayed view as shown to the user
                (presumably 1-based; callers in this file pass 1).
            total_views: Total number of available views.
            view_type: Label inserted into the caption, e.g. "Depth" or "Normal".

        Returns:
            A multi-line string. Its leading/trailing newlines and indentation
            are part of the literal and are returned verbatim.
        """
        return f"""
        
            {view_type} View Navigation | 
            Current: View {current_view} / {total_views} views
        
        """
        
def update_view_selectors(processed_data):
    """Reset the depth/normal view sliders and their captions for a new result.

    Both sliders are re-ranged to 1..N, where N is the number of entries in
    ``processed_data`` (1 when it is None or empty), and moved back to 1.

    Returns:
        ``(depth_slider_update, normal_slider_update, depth_info, normal_info)``
    """
    view_count = (
        1 if processed_data is None or len(processed_data) == 0 else len(processed_data)
    )
    # Update the existing slider components via gr.update() instead of
    # constructing new ones; both sliders share the same range.
    slider_settings = {"minimum": 1, "maximum": view_count, "value": 1, "step": 1}
    return (
        gr.update(**slider_settings),
        gr.update(**slider_settings),
        update_view_info(1, view_count, "Depth"),
        update_view_info(1, view_count, "Normal"),
    )
def get_view_data_by_index(processed_data, view_index):
    """Return the entry at position ``view_index`` of ``processed_data``.

    Positions follow the dict's insertion order. An out-of-range index
    (negative or >= the number of entries) falls back to the first entry.
    Returns None when ``processed_data`` is None or empty.

    NOTE(review): the sliders set up by update_view_selectors run from 1..N;
    if slider values reach this function unchanged, the last view falls back
    to the first one. Confirm that the caller converts to a 0-based index.
    """
    if processed_data is None or len(processed_data) == 0:
        return None
    ordered_keys = list(processed_data)
    position = view_index if 0 <= view_index < len(ordered_keys) else 0
    return processed_data[ordered_keys[position]]
def update_depth_view(processed_data, view_index):
    """Render the colorized depth image for one view, or None if unavailable."""
    selected = get_view_data_by_index(processed_data, view_index)
    depth_map = None if selected is None else selected["depth"]
    if depth_map is None:
        return None
    # Apply the stored per-view mask (if any) while rendering.
    return render_depth_visualization(depth_map, mask=selected.get("mask"))
def update_normal_view(processed_data, view_index):
    """Render the RGB-encoded surface normals for one view, or None if unavailable."""
    selected = get_view_data_by_index(processed_data, view_index)
    normal_map = None if selected is None else selected["normal"]
    if normal_map is None:
        return None
    # Apply the stored per-view mask (if any) while rendering.
    return render_normal_visualization(normal_map, mask=selected.get("mask"))
def initialize_depth_normal_views(processed_data):
    """Render the first view's depth and normal images for the initial display.

    Returns ``(None, None)`` when there is no processed data. Otherwise it
    returns the results of update_depth_view / update_normal_view for index 0,
    so the stored per-view mask is applied from the start.
    """
    has_views = processed_data is not None and len(processed_data) != 0
    if not has_views:
        return None, None
    first_view = 0
    return (
        update_depth_view(processed_data, first_view),
        update_normal_view(processed_data, first_view),
    )
# -------------------------------------------------------------------------
# File upload and update preview gallery
# -------------------------------------------------------------------------
def _upload_source_path(file_data):
    """Return the filesystem path for one uploaded-file entry.

    Accepts a path string / os.PathLike, a dict carrying the path under
    ``"name"`` (or ``"path"``), or a file-like wrapper exposing ``.name``.
    """
    if isinstance(file_data, dict):
        for key in ("name", "path"):
            if key in file_data:
                return file_data[key]
    elif not isinstance(file_data, (str, os.PathLike)):
        wrapped_name = getattr(file_data, "name", None)
        if isinstance(wrapped_name, str):
            return wrapped_name
    return str(file_data)


def _unique_destination(directory, filename):
    """Return a not-yet-existing path for ``filename`` inside ``directory``.

    On collision, ``_1``, ``_2``, ... is inserted before the extension so two
    uploads that share a basename do not overwrite each other.
    """
    stem, ext = os.path.splitext(filename)
    candidate = os.path.join(directory, filename)
    suffix = 1
    while os.path.exists(candidate):
        candidate = os.path.join(directory, f"{stem}_{suffix}{ext}")
        suffix += 1
    return candidate


def _extract_video_frames(src_path, images_dir, base_name, time_interval):
    """Save every N-th video frame as PNG (N = fps * time_interval, at least 1).

    Returns the list of written frame paths.
    """
    saved_paths = []
    cap = cv2.VideoCapture(src_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # fps may be 0 (unknown) and fps * time_interval may be < 1; a zero
        # interval would make the modulo below raise ZeroDivisionError.
        interval = max(1, int(fps * time_interval)) if fps > 0 else 1
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            if frame_count % interval == 0:
                dst_path = _unique_destination(
                    images_dir, f"{base_name}_{len(saved_paths):06}.png"
                )
                cv2.imwrite(dst_path, frame)
                saved_paths.append(dst_path)
    finally:
        cap.release()
    print(f"Extracted {len(saved_paths)} frames from: {os.path.basename(src_path)}")
    return saved_paths


def process_uploaded_files(files, time_interval=1.0):
    """
    Process uploaded files by extracting video frames or copying images.

    A fresh ``input_images_<timestamp>/images`` directory is created. Videos
    are sampled every ``time_interval`` seconds. HEIC/HEIF files are converted
    to JPEG; if conversion fails, the raw file is copied instead. All other
    files are copied as-is. Name collisions get a numeric suffix.

    Args:
        files: Uploaded entries (paths, dicts with "name"/"path", or file
            wrappers with ``.name``). A single entry is also accepted. None
            yields an empty directory.
        time_interval: Interval in seconds for video frame extraction

    Returns:
        tuple: (target_dir, image_paths) where target_dir is the output directory
               and image_paths is a sorted list of processed image file paths
    """
    gc.collect()
    torch.cuda.empty_cache()
    # Create unique output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    images_dir = os.path.join(target_dir, "images")
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(images_dir)
    image_paths = []
    if files is None:
        return target_dir, image_paths
    # A lone path would otherwise be iterated character by character.
    if isinstance(files, (str, os.PathLike, dict)):
        files = [files]
    video_exts = {".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp"}
    for file_data in files:
        src_path = _upload_source_path(file_data)
        ext = os.path.splitext(src_path)[1].lower()
        base_name = os.path.splitext(os.path.basename(src_path))[0]
        if ext in video_exts:
            image_paths.extend(
                _extract_video_frames(src_path, images_dir, base_name, time_interval)
            )
        elif ext in (".heic", ".heif"):
            try:
                with Image.open(src_path) as img:
                    if img.mode not in ("RGB", "L"):
                        img = img.convert("RGB")
                    dst_path = _unique_destination(images_dir, f"{base_name}.jpg")
                    img.save(dst_path, "JPEG", quality=95)
                    image_paths.append(dst_path)
                    print(f"Converted HEIC: {os.path.basename(src_path)} -> {os.path.basename(dst_path)}")
            except Exception as e:
                # Best effort: fall back to copying the original file.
                print(f"HEIC conversion failed for {src_path}: {e}")
                dst_path = _unique_destination(images_dir, os.path.basename(src_path))
                shutil.copy(src_path, dst_path)
                image_paths.append(dst_path)
        else:
            dst_path = _unique_destination(images_dir, os.path.basename(src_path))
            shutil.copy(src_path, dst_path)
            image_paths.append(dst_path)
    image_paths = sorted(image_paths)
    print(f"Processed files to {images_dir}")
    return target_dir, image_paths
# Handle file upload and update preview gallery
def update_gallery_on_upload(input_video, input_images, time_interval=1.0):
    """
    Process uploaded files immediately when user uploads or changes files,
    and display them in the gallery.

    ``input_video`` (a single entry or a list) and ``input_images`` (a list or
    a single entry) are merged into one list. That list is handed to
    ``process_uploaded_files``, which accepts only ``(files, time_interval)``;
    passing the two inputs separately raised TypeError.

    Returns:
        ``(None, target_dir, image_paths, status_message)``, or four Nones
        when nothing is uploaded.
    """
    if not input_video and not input_images:
        return None, None, None, None
    files = []
    for source in (input_video, input_images):
        if not source:
            continue
        if isinstance(source, (list, tuple)):
            files.extend(source)
        else:
            files.append(source)
    target_dir, image_paths = process_uploaded_files(files, time_interval)
    return (
        None,
        target_dir,
        image_paths,
        "Upload complete. Click 'Reconstruct' to begin 3D processing.",
    )
        
# -------------------------------------------------------------------------
# Visualization data preparation and reconstruction entry point
# -------------------------------------------------------------------------
def prepare_visualization_data(
    model_outputs, input_views
):
    """Build a per-view lookup of the arrays used by the depth/normal tabs.

    Args:
        model_outputs: Dict with per-view arrays under ``world_points``,
            ``depth``, ``normal`` and ``final_mask`` (indexed by view first).
        input_views: Dict whose ``img`` tensor is indexed as ``[0, view]``;
            its dimension 1 gives the number of views.

    Returns:
        ``{view_index: {"image", "points3d", "depth", "normal", "mask"}}``, where
        ``depth`` is squeezed and ``mask`` is a copy of the final mask.
    """
    batch_images = input_views["img"]
    view_table = {}
    for view_idx in range(batch_images.shape[1]):
        view_table[view_idx] = {
            "image": batch_images[0, view_idx].detach().cpu().numpy(),
            "points3d": model_outputs["world_points"][view_idx],
            "depth": model_outputs["depth"][view_idx].squeeze(),
            "normal": model_outputs["normal"][view_idx],
            # Copy so later edits to the displayed mask can't touch the outputs.
            "mask": model_outputs["final_mask"][view_idx].copy(),
        }
    return view_table
def _export_gaussian_splats(predictions, target_dir, splat_mode="ply"):
    """Write the first batch item's predicted Gaussian splats to ``target_dir``.

    Returns the written file path. Returns None when ``predictions`` has no
    ``"splats"`` entry or ``splat_mode`` is neither ``'ply'`` nor ``'splat'``.
    """
    if "splats" not in predictions:
        return None
    splats = predictions["splats"]
    # Get Gaussian parameters (already filtered by GaussianSplatRenderer)
    means = splats["means"][0].reshape(-1, 3)
    scales = splats["scales"][0].reshape(-1, 3)
    quats = splats["quats"][0].reshape(-1, 4)
    colors = (splats["sh"][0] if "sh" in splats else splats["colors"][0]).reshape(-1, 3)
    opacities = splats["opacities"][0].reshape(-1)

    def _as_tensor(value):
        return value if isinstance(value, torch.Tensor) else torch.from_numpy(value)

    means, scales, quats, colors, opacities = (
        _as_tensor(v) for v in (means, scales, quats, colors, opacities)
    )

    gs_file = None
    if splat_mode == 'ply':
        gs_file = os.path.join(target_dir, "gaussians.ply")
        save_gs_ply(gs_file, means, scales, quats, colors, opacities)
        print(f"Saved Gaussian Splatting PLY to: {gs_file}")
        print(f"File exists: {os.path.exists(gs_file)}")
        if os.path.exists(gs_file):
            print(f"File size: {os.path.getsize(gs_file)} bytes")
    elif splat_mode == 'splat':
        plydata = convert_gs_to_ply(means, scales, quats, colors, opacities)
        gs_file = os.path.join(target_dir, "gaussians.splat")
        gs_file = process_ply_to_splat(plydata, gs_file)
    return gs_file


def _render_splat_videos(predictions, target_dir):
    """Render interpolated RGB/depth videos from the predicted splats.

    Returns ``(rgb_video_path, depth_video_path)``, or ``(None, None)`` when
    there are no splats or neither video file was produced.
    """
    if "splats" not in predictions:
        return None, None
    from pathlib import Path

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    camera_poses = torch.tensor(predictions['camera_poses']).unsqueeze(0).to(device)
    camera_intrs = torch.tensor(predictions['camera_intrs']).unsqueeze(0).to(device)
    H, W = predictions['images'].shape[1], predictions['images'].shape[2]

    out_path = Path(target_dir) / "rendered_video"
    render_interpolated_video(
        model.gs_renderer,
        predictions["splats"],
        camera_poses,
        camera_intrs,
        (H, W),
        out_path,
        interp_per_pair=15,
        loop_reverse=True,
        save_mode="split"
    )
    rgb_video_path = str(out_path) + "_rgb.mp4"
    depth_video_path = str(out_path) + "_depth.mp4"
    if not os.path.exists(rgb_video_path) and not os.path.exists(depth_video_path):
        return None, None
    return rgb_video_path, depth_video_path


@spaces.GPU(duration=120)
def gradio_demo(
    target_dir,
    frame_selector="All",
    show_camera=False,
    filter_sky_bg=False,
    show_mesh=False,
    filter_ambiguous=False,
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Runs the model and saves ``predictions.npz``, camera parameters and a GLB
    scene. When splats are predicted, it also exports them and renders preview
    videos. stdout is captured through TeeOutput for the whole call and is
    always restored, even if an error occurs.

    Returns a 15-tuple: (glb path, log message, frame dropdown, processed data,
    depth image, normal image, depth slider, normal slider, depth info,
    normal info, camera params file, gaussian file, rgb video, depth video,
    terminal log). If ``target_dir`` is missing or invalid, every slot except
    the message and the terminal log is None.
    """
    # Capture terminal output
    tee = TeeOutput()
    old_stdout = sys.stdout
    sys.stdout = tee

    try:
        # target_dir may be None (nothing uploaded) or the string "None";
        # os.path.isdir(None) raises TypeError, so test those cases first.
        if target_dir is None or target_dir == "None" or not os.path.isdir(target_dir):
            terminal_log = tee.getvalue()
            return None, "No valid target directory found. Please upload first.", None, None, None, None, None, None, None, None, None, None, None, None, terminal_log
        start_time = time.time()
        gc.collect()
        torch.cuda.empty_cache()
        # Prepare frame_selector dropdown
        target_dir_images = os.path.join(target_dir, "images")
        all_files = (
            sorted(os.listdir(target_dir_images))
            if os.path.isdir(target_dir_images)
            else []
        )
        all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
        frame_selector_choices = ["All"] + all_files
        print("Running WorldMirror model...")
        with torch.no_grad():
            predictions, processed_data = run_model(target_dir)
        # Save predictions
        prediction_save_path = os.path.join(target_dir, "predictions.npz")
        np.savez(prediction_save_path, **predictions)
        # Save camera parameters as JSON
        camera_params_file = save_camera_params(
            predictions['camera_poses'],
            predictions['camera_intrs'],
            target_dir
        )
        if frame_selector is None:
            frame_selector = "All"
        # Build a GLB file name
        glbfile = os.path.join(
            target_dir,
            f"glbscene_{frame_selector.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_camera}_mesh{show_mesh}.glb",
        )
        glbscene = convert_predictions_to_glb_scene(
            predictions,
            filter_by_frames=frame_selector,
            show_camera=show_camera,
            mask_sky_bg=filter_sky_bg,
            as_mesh=show_mesh,
            mask_ambiguous=filter_ambiguous
        )
        glbscene.export(file_obj=glbfile)

        end_time = time.time()
        print(f"Total time: {end_time - start_time:.2f} seconds")
        log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

        gs_file = _export_gaussian_splats(predictions, target_dir, splat_mode='ply')
        depth_vis, normal_vis = initialize_depth_normal_views(processed_data)
        depth_slider, normal_slider, depth_info, normal_info = update_view_selectors(
            processed_data
        )
        rgb_video_path, depth_video_path = _render_splat_videos(predictions, target_dir)

        # Cleanup
        del predictions
        gc.collect()
        torch.cuda.empty_cache()
        terminal_log = tee.getvalue()
        return (
            glbfile,
            log_msg,
            gr.Dropdown(choices=frame_selector_choices, value=frame_selector, interactive=True),
            processed_data,
            depth_vis,
            normal_vis,
            depth_slider,
            normal_slider,
            depth_info,
            normal_info,
            camera_params_file,
            gs_file,
            rgb_video_path,
            depth_video_path,
            terminal_log,
        )
    except Exception as e:
        # Restore stdout first so the error reaches the real console.
        sys.stdout = old_stdout
        print(f"Error occurred: {e}")
        raise
    finally:
        sys.stdout = old_stdout
# -------------------------------------------------------------------------
# Helper functions for visualization
# -------------------------------------------------------------------------
def render_depth_visualization(depth_map, mask=None):
    """Generate a color-coded depth visualization image with masking capabilities.

    Args:
        depth_map: 2-D array of depth values; pixels with depth <= 0 are
            treated as invalid.
        mask: Optional array of the same shape; truthy pixels are kept. It is
            coerced to bool, so uint8/int masks behave like boolean masks.

    Returns:
        A uint8 RGB array [H, W, 3], or None if ``depth_map`` is None. Valid
        depths are normalized by their 5th-95th percentiles and mapped through
        matplotlib's ``turbo_r`` colormap. Invalid or masked pixels are white.
        If all valid depths are equal, they map to the start of the colormap
        instead of producing NaN.
    """
    if depth_map is None:
        return None
    depth_copy = np.array(depth_map, copy=True)
    if not np.issubdtype(depth_copy.dtype, np.floating):
        # Normalized values are fractions; an integer buffer would truncate them.
        depth_copy = depth_copy.astype(np.float64)
    valid = depth_copy > 0
    if mask is not None:
        # A non-bool mask would turn the indexing below into integer fancy
        # indexing (wrong pixels or IndexError), so coerce it first.
        valid = valid & np.asarray(mask, dtype=bool)
    if valid.any():
        values = depth_copy[valid]
        lower_bound = np.percentile(values, 5)
        upper_bound = np.percentile(values, 95)
        span = upper_bound - lower_bound
        if span > 0:
            depth_copy[valid] = (values - lower_bound) / span
        else:
            # Constant depth: 0/0 would yield NaN, which renders as black.
            depth_copy[valid] = 0.0
    import matplotlib.pyplot as plt
    color_mapper = plt.cm.turbo_r
    rgb_result = color_mapper(depth_copy)
    rgb_result = (rgb_result[:, :, :3] * 255).astype(np.uint8)
    # Mark invalid regions with white color
    rgb_result[~valid] = [255, 255, 255]
    return rgb_result
def render_normal_visualization(normal_map, mask=None):
    """Convert surface normal vectors to an RGB uint8 image for display.

    Each component is mapped from [-1, 1] to [0, 255] via ``(n + 1) / 2 * 255``.
    Components outside [-1, 1] are clipped instead of wrapping around in the
    uint8 cast. NaN components are treated as 0, and masked-out pixels are
    zeroed before the mapping; both therefore render as mid-gray (127).

    Args:
        normal_map: Array whose last axis holds the 3 normal components.
        mask: Optional array over the pixel axes; truthy = keep. It is coerced
            to bool.

    Returns:
        uint8 array of the same shape, or None if ``normal_map`` is None.
    """
    if normal_map is None:
        return None
    # Work on a copy so the caller's array is left untouched.
    normal_display = np.array(normal_map, copy=True)
    if mask is not None:
        normal_display[~np.asarray(mask, dtype=bool)] = [0, 0, 0]
    normal_display = np.nan_to_num(normal_display, nan=0.0, posinf=1.0, neginf=-1.0)
    # Transform from [-1, 1] to [0, 1] range for RGB display
    normal_display = (normal_display + 1.0) / 2.0
    # Predicted normals can drift slightly outside [-1, 1]; clip before the
    # cast so out-of-range values saturate rather than wrap.
    normal_display = np.clip(normal_display * 255, 0, 255).astype(np.uint8)
    return normal_display
def clear_fields():
    """Return None, the reset value for a single output component.

    NOTE(review): an earlier docstring described clearing the 3D viewer,
    target_dir and the gallery, but only one value is returned. Confirm how
    this handler is wired to its outputs.
    """
    reset_value = None
    return reset_value
def update_log():
    """Return the placeholder status text shown while reconstruction runs."""
    loading_message = "Loading and Reconstructing..."
    return loading_message
def get_terminal_output():
    """Return the most recent stdout text captured by TeeOutput, for live display."""
    # Reading a module-level name needs no `global` declaration.
    return current_terminal_output
# -------------------------------------------------------------------------
# Example scene metadata extraction
# -------------------------------------------------------------------------
def extract_example_scenes_metadata(base_directory):
    """
    Extract metadata for every sub-directory of ``base_directory`` that holds images.

    A file counts as an image when its extension (compared case-insensitively)
    is one of jpg, jpeg, png, bmp, tiff or tif. Hidden dot-files and
    sub-directories are ignored. Matching by lower-cased name, rather than
    globbing ``*.jpg`` and ``*.JPG`` separately, avoids listing a file twice
    on case-insensitive filesystems.

    Args:
        base_directory: Root path where example scene directories are located

    Returns:
        List of dicts (sorted by directory name) with keys ``name``, ``path``,
        ``thumbnail`` (first image in sorted order), ``num_images`` and
        ``image_files`` (sorted full paths). Empty if ``base_directory`` is
        not a directory.
    """
    if not os.path.isdir(base_directory):
        return []

    # Supported image format extensions (matched case-insensitively)
    VALID_IMAGE_FORMATS = ('jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif')
    valid_suffixes = tuple(f".{fmt}" for fmt in VALID_IMAGE_FORMATS)

    scenes_data = []
    for directory_name in sorted(os.listdir(base_directory)):
        current_directory = os.path.join(base_directory, directory_name)
        if not os.path.isdir(current_directory):
            continue

        discovered_images = sorted(
            os.path.join(current_directory, file_name)
            for file_name in os.listdir(current_directory)
            if not file_name.startswith(".")
            and file_name.lower().endswith(valid_suffixes)
            and os.path.isfile(os.path.join(current_directory, file_name))
        )
        if not discovered_images:
            continue

        scenes_data.append({
            'name': directory_name,
            'path': current_directory,
            'thumbnail': discovered_images[0],
            'num_images': len(discovered_images),
            'image_files': discovered_images,
        })

    return scenes_data
def load_example_scenes(scene_name, scenes):
    """
    Stage an example scene's images for 3D reconstruction.

    Args:
        scene_name: Name of the scene to load (compared against each record's "name").
        scenes: Scene records such as those produced by extract_example_scenes_metadata;
            each needs "name", "image_files" and "num_images" keys.

    Returns:
        5-tuple: (reconstruction view, gaussian-splat view, working directory,
        image gallery list, status message). Both views are reset to None.
        When no scene matches, the directory and gallery entries are None and
        the message is "Scene not found".
    """
    target_scene_config = next(
        (scene_config for scene_config in scenes if scene_config["name"] == scene_name),
        None,
    )
    if target_scene_config is None:
        # Must have the same arity as the success path: callers wire this
        # function to five output components.
        return None, None, None, None, "Scene not found"

    # Copy the path list so the pipeline cannot mutate the scene record
    image_file_paths = list(target_scene_config["image_files"])
    processed_target_dir, processed_image_list = process_uploaded_files(image_file_paths, 1.0)

    status_message = f"Successfully loaded scene '{scene_name}' containing {target_scene_config['num_images']} images. Click 'Reconstruct' to begin 3D processing."

    return (
        None,  # Reset reconstruction visualization
        None,  # Reset gaussian splatting output
        processed_target_dir,  # Provide working directory path
        processed_image_list,  # Update image gallery display
        status_message,
    )
# -------------------------------------------------------------------------
# UI and event handling
# -------------------------------------------------------------------------
# Plain Gradio base theme; the demo's custom styling is supplied separately via gr.Blocks(css=...).
theme = gr.themes.Base()
with gr.Blocks(
    theme=theme,
    css="""
    .custom-log * {
        font-style: italic;
        font-size: 22px !important;
        background-image: linear-gradient(120deg, #a9b8f8 0%, #7081e8 60%, #4254c5 100%);
        -webkit-background-clip: text;
        background-clip: text;
        font-weight: bold !important;
        color: transparent !important;
        text-align: center !important;
    }
    .normal-weight-btn button,
    .normal-weight-btn button span,
    .normal-weight-btn button *,
    .normal-weight-btn * {
        font-weight: 400 !important;
    }
    .terminal-output {
        max-height: 400px !important;
        overflow-y: auto !important;
    }
    .terminal-output textarea {
        font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
        font-size: 13px !important;
        line-height: 1.5 !important;
        color: #333 !important;
        background-color: #f8f9fa !important;
        max-height: 400px !important;
    }
    .example-gallery {
        width: 100% !important;
    }
    .example-gallery img {
        width: 100% !important;
        height: 280px !important;
        object-fit: contain !important;
        aspect-ratio: 16 / 9 !important;
    }
    .example-gallery .grid-wrap {
        width: 100% !important;
    }
    
    /* 滑块导航样式 */
    .depth-tab-improved .gradio-slider input[type="range"] {
        height: 8px !important;
        border-radius: 4px !important;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
    }
    .depth-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb {
        height: 20px !important;
        width: 20px !important;
        border-radius: 50% !important;
        background: #fff !important;
        box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important;
    }
    .depth-tab-improved button {
        transition: all 0.3s ease !important;
        border-radius: 6px !important;
        font-weight: 500 !important;
    }
    .depth-tab-improved button:hover {
        transform: translateY(-1px) !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    }
    
    .normal-tab-improved .gradio-slider input[type="range"] {
        height: 8px !important;
        border-radius: 4px !important;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%) !important;
    }
    .normal-tab-improved .gradio-slider input[type="range"]::-webkit-slider-thumb {
        height: 20px !important;
        width: 20px !important;
        border-radius: 50% !important;
        background: #fff !important;
        box-shadow: 0 2px 6px rgba(0,0,0,0.3) !important;
    }
    .normal-tab-improved button {
        transition: all 0.3s ease !important;
        border-radius: 6px !important;
        font-weight: 500 !important;
    }
    .normal-tab-improved button:hover {
        transform: translateY(-1px) !important;
        box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
    }
    #depth-view-info, #normal-view-info {
        animation: fadeIn 0.5s ease-in-out;
    }
    @keyframes fadeIn {
        from { opacity: 0; transform: translateY(-10px); }
        to { opacity: 1; transform: translateY(0); }
    }
    """
) as demo:
    # State variables for the tabbed interface
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    processed_data_state = gr.State(value=None)
    current_view_index = gr.State(value=0)  # Track current view index for navigation
    # Header and description
    gr.HTML(
    """
    
    
        WorldMirror supports any combination of inputs (images, intrinsics, poses, and depth) and multiple outputs including point clouds, camera parameters, depth maps, normal maps, and 3D Gaussian Splatting (3DGS). 
    How to Use:
    
        - Upload Your Data: Click the "Upload Video or Images" button to add your files. Videos are automatically extracted into frames at one-second intervals.
- Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction.
- Visualize: Explore multiple reconstruction results across different tabs:
                
                    - 3D View: Interactive point cloud/mesh visualization with camera poses (downloadable as GLB)
- 3D Gaussian Splatting: Interactive 3D Gaussian Splatting visualization with RGB and depth videos (downloadable as PLY)
- Depth Maps: Per-view depth estimation results (downloadable as PNG)
- Normal Maps: Per-view surface orientation visualization (downloadable as PNG)
- Camera Parameters: Estimated camera poses and intrinsics (downloadable as JSON)
 
Please note: Loading data and displaying 3D effects may take a moment. For faster performance, we recommend downloading the code from our GitHub and running it locally.
     
    """)
    output_path_state = gr.Textbox(label="Output Path", visible=False, value="None")
    # Main UI components
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            file_upload = gr.File(
                file_count="multiple",
                label="Upload Video or Images",
                interactive=True,
                file_types=["image", "video"],
                height="200px",
            )
            time_interval = gr.Slider(
                minimum=0.1,
                maximum=10.0,
                value=1.0,
                step=0.1,
                label="Video Sample interval",
                interactive=True,
                visible=True,
                scale=4,
            )
            resample_btn = gr.Button(
                "Resample",
                visible=True,
                scale=1,
                elem_classes=["normal-weight-btn"],
            )
            image_gallery = gr.Gallery(
                label="Image Preview",
                columns=4,
                height="200px",
                show_download_button=True,
                object_fit="contain",
                preview=True
            )
            
            terminal_output = gr.Textbox(
                label="Terminal Output",
                lines=6,
                max_lines=6,
                interactive=False,
                show_copy_button=True,
                container=True,
                elem_classes=["terminal-output"],
                autoscroll=True
            )
        with gr.Column(scale=3):
            log_output = gr.Markdown(
                "Upload video or images first, then click Reconstruct to start processing",
                elem_classes=["custom-log"],
            )
            with gr.Tabs() as tabs:
                with gr.Tab("3D Gaussian Splatting", id=1) as gs_tab:
                    with gr.Row():
                        with gr.Column(scale=3):
                            gs_output = gr.Model3D(
                                label="Gaussian Splatting",
                                height=500,
                            )
                        with gr.Column(scale=1):
                            gs_rgb_video = gr.Video(
                                label="Rendered RGB Video",
                                height=250,
                                autoplay=False,
                                loop=False,
                                interactive=False,
                            )
                            gs_depth_video = gr.Video(
                                label="Rendered Depth Video",
                                height=250,
                                autoplay=False,
                                loop=False,
                                interactive=False,
                            )
                with gr.Tab("Point Cloud/Mesh", id=0):
                    reconstruction_output = gr.Model3D(
                        label="3D Pointmap/Mesh",
                        height=500,
                        zoom_speed=0.4,
                        pan_speed=0.4,
                    )
                with gr.Tab("Depth", elem_classes=["depth-tab-improved"]):
                    depth_view_info = gr.HTML(
                        value=""
                              "Depth View Navigation | Current: View 1 / 1 views
",
                        elem_id="depth-view-info"
                    )
                    depth_view_slider = gr.Slider(
                        minimum=1, 
                        maximum=1, 
                        step=1, 
                        value=1,
                        label="View Selection Slider",
                        interactive=True,
                        elem_id="depth-view-slider"
                    )
                    depth_map = gr.Image(
                        type="numpy",
                        label="Depth Map",
                        format="png",
                        interactive=False,
                        height=340
                    )
                with gr.Tab("Normal", elem_classes=["normal-tab-improved"]):
                    normal_view_info = gr.HTML(
                        value=""
                              "Normal View Navigation | Current: View 1 / 1 views
",
                        elem_id="normal-view-info"
                    )
                    normal_view_slider = gr.Slider(
                        minimum=1, 
                        maximum=1, 
                        step=1, 
                        value=1,
                        label="View Selection Slider",
                        interactive=True,
                        elem_id="normal-view-slider"
                    )
                    normal_map = gr.Image(
                        type="numpy",
                        label="Normal Map",
                        format="png",
                        interactive=False,
                        height=340
                    )
                with gr.Tab("Camera Parameters", elem_classes=["camera-tab"]):
                    with gr.Row():
                        gr.HTML("")
                        camera_params = gr.DownloadButton(
                            label="Download Camera Parameters",
                            scale=1,
                            variant="primary",
                        )
                        gr.HTML("")
                    
            with gr.Row():
                reconstruct_btn = gr.Button(
                    "Reconstruct", 
                    scale=1, 
                    variant="primary"
                )
                clear_btn = gr.ClearButton(
                    [
                        file_upload,
                        reconstruction_output,
                        log_output,
                        output_path_state,
                        image_gallery,
                        depth_map,
                        normal_map,
                        depth_view_slider,
                        normal_view_slider,
                        depth_view_info,
                        normal_view_info,
                        camera_params,
                        gs_output,
                        gs_rgb_video,
                        gs_depth_video,
                    ],
                    scale=1,
                )
                
            with gr.Row():
                frame_selector = gr.Dropdown(
                        choices=["All"], value="All", label="Show Points of a Specific Frame"
                    )
                
            gr.Markdown("### Reconstruction Options: (not applied to 3DGS)")
            with gr.Row():
                show_camera = gr.Checkbox(label="Show Camera", value=True)
                show_mesh = gr.Checkbox(label="Show Mesh", value=True)
                filter_ambiguous = gr.Checkbox(label="Filter low confidence & depth/normal edges", value=True)
                filter_sky_bg = gr.Checkbox(label="Filter Sky Background", value=False)
        with gr.Column(scale=1):            
            gr.Markdown("### Click to load example scenes")
            realworld_scenes = extract_example_scenes_metadata("examples/realistic") if os.path.exists("examples/realistic") else extract_example_scenes_metadata("examples")
            generated_scenes = extract_example_scenes_metadata("examples/stylistic") if os.path.exists("examples/stylistic") else []
            
            # If no subdirectories exist, fall back to single gallery
            if not os.path.exists("examples/realistic") and not os.path.exists("examples/stylistic"):
                # Fallback: use all scenes from examples directory
                all_scenes = extract_example_scenes_metadata("examples")
                if all_scenes:
                    gallery_items = [
                        (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
                        for scene in all_scenes
                    ]
                    
                    example_gallery = gr.Gallery(
                        value=gallery_items,
                        label="Example Scenes",
                        columns=1,
                        rows=None,
                        height=800,
                        object_fit="contain",
                        show_label=False,
                        interactive=True,
                        preview=False,
                        allow_preview=False,
                        elem_classes=["example-gallery"]
                    )
                    
                    def handle_example_selection(evt: gr.SelectData):
                        """Load the example scene whose thumbnail was clicked in the gallery."""
                        if not evt:
                            return (None, None, None, None, "No scene selected")
                        chosen_scene_name = all_scenes[evt.index]["name"]
                        return load_example_scenes(chosen_scene_name, all_scenes)
                    
                    example_gallery.select(
                        fn=handle_example_selection,
                        outputs=[
                            reconstruction_output,
                            gs_output,
                            output_path_state,
                            image_gallery,
                            log_output,
                        ],
                    )
            else:
                # Tabbed interface for categorized examples
                with gr.Tabs():
                    with gr.Tab("🌍 Realistic Cases"):
                        if realworld_scenes:
                            realworld_items = [
                                (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
                                for scene in realworld_scenes
                            ]
                            
                            realworld_gallery = gr.Gallery(
                                value=realworld_items,
                                label="Real-world Examples",
                                columns=1,
                                rows=None,
                                height=750,
                                object_fit="contain",
                                show_label=False,
                                interactive=True,
                                preview=False,
                                allow_preview=False,
                                elem_classes=["example-gallery"]
                            )
                            
                            def handle_realworld_selection(evt: gr.SelectData):
                                """Load the clicked real-world example scene."""
                                if not evt:
                                    return (None, None, None, None, "No scene selected")
                                chosen_scene_name = realworld_scenes[evt.index]["name"]
                                return load_example_scenes(chosen_scene_name, realworld_scenes)
                            
                            realworld_gallery.select(
                                fn=handle_realworld_selection,
                                outputs=[
                                    reconstruction_output,
                                    gs_output,
                                    output_path_state,
                                    image_gallery,
                                    log_output,
                                ],
                            )
                        else:
                            gr.Markdown("No real-world examples available")
                    
                    with gr.Tab("🎨 Stylistic Cases"):
                        if generated_scenes:
                            generated_items = [
                                (scene["thumbnail"], f"{scene['name']}\n📷 {scene['num_images']} images")
                                for scene in generated_scenes
                            ]
                            
                            generated_gallery = gr.Gallery(
                                value=generated_items,
                                label="Generated Examples",
                                columns=1,
                                rows=None,
                                height=750,
                                object_fit="contain",
                                show_label=False,
                                interactive=True,
                                preview=False,
                                allow_preview=False,
                                elem_classes=["example-gallery"]
                            )
                            
                            def handle_generated_selection(evt: gr.SelectData):
                                """Load the clicked stylistic (generated) example scene."""
                                if not evt:
                                    return (None, None, None, None, "No scene selected")
                                chosen_scene_name = generated_scenes[evt.index]["name"]
                                return load_example_scenes(chosen_scene_name, generated_scenes)
                            
                            generated_gallery.select(
                                fn=handle_generated_selection,
                                outputs=[
                                    reconstruction_output,
                                    gs_output,
                                    output_path_state,
                                    image_gallery,
                                    log_output,
                                ],
                            )
                        else:
                            gr.Markdown("No generated examples available")
    
    # -------------------------------------------------------------------------
    # Click logic
    # -------------------------------------------------------------------------
    # Reconstruct button: a chain of Gradio events run in sequence.
    #   1. clear_fields  - takes no inputs and updates no components.
    #   2. update_log    - refreshes the status markdown (log_output).
    #   3. gradio_demo   - presumably runs the reconstruction on the staged
    #                      workspace; it fills every result view listed below.
    #   4. finally sets is_example to "False" so live-update handlers render
    #      this workspace instead of skipping it.
    reconstruct_btn.click(fn=clear_fields, inputs=[], outputs=[]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_demo,
        inputs=[
            output_path_state,
            frame_selector,
            show_camera,
            filter_sky_bg,
            show_mesh,
            filter_ambiguous
        ],
        outputs=[
            reconstruction_output,
            log_output,
            frame_selector,
            processed_data_state,
            depth_map,
            normal_map,
            depth_view_slider,
            normal_view_slider,
            depth_view_info,
            normal_view_info,
            camera_params,
            gs_output,
            gs_rgb_video,
            gs_depth_video,
            terminal_output,
        ],
    ).then(
        fn=lambda: "False",
        inputs=[],
        outputs=[is_example],  # set is_example to "False"
    )
    # -------------------------------------------------------------------------
    # Live update logic
    # -------------------------------------------------------------------------
    def refresh_3d_scene(
        workspace_path,
        frame_selector,
        show_camera,
        is_example,
        filter_sky_bg=False,
        show_mesh=False,
        filter_ambiguous=False
    ):
        """
        Refresh the 3D scene visualization for the current display options.

        GLB files are cached in the workspace under a name derived from the
        option combination. Predictions are only loaded from ``predictions.npz``
        when that cached GLB does not exist yet.

        Args:
            workspace_path: Workspace directory path for reconstruction results.
            frame_selector: Frame selector value for filtering points from specific frames.
            show_camera: Whether to display camera positions.
            is_example: The string "True" when an example scene is selected; rendering is skipped then.
            filter_sky_bg: Whether to filter sky background.
            show_mesh: Whether to display as mesh mode.
            filter_ambiguous: Whether to filter low-confidence ambiguous areas.

        Returns:
            tuple: (GLB scene file path, Gaussian point cloud file path or None,
            status message). On the early-exit paths the first two entries are
            ``gr.update()`` no-ops so the viewers keep their current content.
        """
        no_results_message = "No reconstruction results available. Please click the Reconstruct button first."
        # If an example scene is selected, skip processing directly
        if is_example == "True":
            return gr.update(), gr.update(), no_results_message
        if not workspace_path or workspace_path == "None" or not os.path.isdir(workspace_path):
            return gr.update(), gr.update(), no_results_message

        prediction_file_path = os.path.join(workspace_path, "predictions.npz")
        if not os.path.exists(prediction_file_path):
            return (
                gr.update(),
                gr.update(),
                f"Prediction file does not exist: {prediction_file_path}. Please run reconstruction first.",
            )

        # Cache key: one GLB per combination of display parameters
        safe_frame_name = frame_selector.replace('.', '_').replace(':', '').replace(' ', '_')
        scene_filename = f"scene_{safe_frame_name}_cam{show_camera}_mesh{show_mesh}_edges{filter_ambiguous}_sky{filter_sky_bg}.glb"
        scene_glb_path = os.path.join(workspace_path, scene_filename)

        if not os.path.exists(scene_glb_path):
            # Only load predictions on a cache miss. The NpzFile is used as a
            # context manager so its underlying file handle is closed.
            # 'splats' is excluded; it is not passed to the GLB conversion.
            with np.load(prediction_file_path, allow_pickle=True) as prediction_data:
                predictions = {
                    key: prediction_data[key]
                    for key in prediction_data.keys()
                    if key != 'splats'
                }
            scene_model = convert_predictions_to_glb_scene(
                predictions,
                filter_by_frames=frame_selector,
                show_camera=show_camera,
                mask_sky_bg=filter_sky_bg,
                as_mesh=show_mesh,
                mask_ambiguous=filter_ambiguous
            )
            scene_model.export(file_obj=scene_glb_path)

        # Gaussian splat file is optional
        gaussian_file_path = os.path.join(workspace_path, "gaussians.ply")
        if not os.path.exists(gaussian_file_path):
            gaussian_file_path = None
        return (
            scene_glb_path,
            gaussian_file_path,
            "3D scene updated.",
        )
    
    def refresh_view_displays_on_filter_update(
        workspace_dir,
        sky_background_filter,
        current_processed_data,
        depth_slider_position,
        normal_slider_position,
    ):
        """
        Rebuild the per-view depth/normal display data after a filter toggle.

        Reloads ``predictions.npz`` and the images from ``workspace_dir``,
        rebuilds the per-view records (optionally AND-ing each view's sky mask
        into its validity mask), and re-renders the depth and normal views at
        the current slider positions.

        Args:
            workspace_dir: Workspace directory containing ``predictions.npz`` and an ``images`` folder.
            sky_background_filter: When truthy, combine each view's ``sky_mask`` into its mask.
            current_processed_data: Previously processed data; returned unchanged on any failure.
            depth_slider_position: 1-based depth slider value (falsy means the first view).
            normal_slider_position: 1-based normal slider value (falsy means the first view).

        Returns:
            tuple: (refreshed processed data, result of update_depth_view,
            result of update_normal_view). When the workspace or prediction file
            is missing, or any error occurs, returns
            ``(current_processed_data, None, None)``.
        """
        if not workspace_dir or workspace_dir == "None" or not os.path.isdir(workspace_dir):
            return current_processed_data, None, None
        prediction_data_path = os.path.join(workspace_dir, "predictions.npz")
        if not os.path.exists(prediction_data_path):
            return current_processed_data, None, None
        try:
            # Context manager closes the NpzFile's file handle; 'splats' is not
            # needed for the 2D depth/normal views, so it is never loaded.
            with np.load(prediction_data_path, allow_pickle=True) as raw_prediction_data:
                predictions_dict = {
                    key: raw_prediction_data[key]
                    for key in raw_prediction_data.keys()
                    if key != "splats"
                }

            images_directory = os.path.join(workspace_dir, "images")
            # os.listdir() order is arbitrary; sort so view i pairs
            # deterministically with prediction i.
            # NOTE(review): assumes predictions were produced from the images in
            # sorted filename order — confirm against gradio_demo.
            image_file_paths = [
                os.path.join(images_directory, file_name)
                for file_name in sorted(os.listdir(images_directory))
            ]
            img = load_and_preprocess_images(image_file_paths)
            img = img.detach().cpu().numpy()

            # Views are indexed as img[0, view_idx] (leading batch axis)
            refreshed_data = {}
            for view_idx in range(img.shape[1]):
                mask = predictions_dict["final_mask"][view_idx].copy()
                if sky_background_filter:
                    mask = mask & predictions_dict["sky_mask"][view_idx]
                refreshed_data[view_idx] = {
                    "image": img[0, view_idx],
                    "points3d": predictions_dict["world_points"][view_idx],
                    "depth": predictions_dict["depth"][view_idx].squeeze(),
                    "normal": predictions_dict["normal"][view_idx],
                    "mask": mask,
                }

            # Slider positions are 1-based; convert to 0-based view indices
            current_depth_index = int(depth_slider_position) - 1 if depth_slider_position else 0
            current_normal_index = int(normal_slider_position) - 1 if normal_slider_position else 0
            updated_depth_visualization = update_depth_view(refreshed_data, current_depth_index)
            updated_normal_visualization = update_normal_view(refreshed_data, current_normal_index)
            return refreshed_data, updated_depth_visualization, updated_normal_visualization
        except Exception as error:
            # UI callback boundary: report the problem and keep the previous state
            print(f"Error occurred while refreshing view displays: {error}")
            return current_processed_data, None, None
    # Components read/written whenever the 3D scene has to be re-rendered.
    scene_refresh_inputs = [
        output_path_state,
        frame_selector,
        show_camera,
        is_example,
        filter_sky_bg,
        show_mesh,
        filter_ambiguous,
    ]
    scene_refresh_outputs = [reconstruction_output, gs_output, log_output]
    # Components read/written when the per-view depth/normal displays are rebuilt.
    view_refresh_inputs = [
        output_path_state,
        filter_sky_bg,
        processed_data_state,
        depth_view_slider,
        normal_view_slider,
    ]
    view_refresh_outputs = [processed_data_state, depth_map, normal_map]

    # Display toggles: only the 3D scene needs regenerating.
    for scene_control in (frame_selector, show_camera, show_mesh):
        scene_control.change(
            refresh_3d_scene,
            scene_refresh_inputs,
            scene_refresh_outputs,
        )

    # Filter toggles: regenerate the 3D scene, then rebuild the depth/normal views.
    for filter_control in (filter_sky_bg, filter_ambiguous):
        filter_control.change(
            refresh_3d_scene,
            scene_refresh_inputs,
            scene_refresh_outputs,
        ).then(
            fn=refresh_view_displays_on_filter_update,
            inputs=view_refresh_inputs,
            outputs=view_refresh_outputs,
        )
    # -------------------------------------------------------------------------
    # Auto update gallery when user uploads or changes files
    # -------------------------------------------------------------------------
    def update_gallery_on_file_upload(files, interval):
        """
        Process newly uploaded files and refresh the image gallery.

        Anything printed while processing is captured (and still echoed to the
        console) by a TeeOutput, and returned as the terminal log.

        Args:
            files: Uploaded files from the gr.File component (falsy when cleared).
            interval: Video sampling interval in seconds, passed to process_uploaded_files.

        Returns:
            tuple: (target directory, image paths, status message, captured terminal log).
            When ``files`` is empty, returns ``(None, None, None, "")``.

        Raises:
            Re-raises any exception from process_uploaded_files after printing it.
        """
        from contextlib import redirect_stdout

        if not files:
            return None, None, None, ""

        # Created before redirecting so TeeOutput tees to the real stdout.
        tee = TeeOutput()
        try:
            # redirect_stdout restores sys.stdout on every exit path, including
            # BaseExceptions such as KeyboardInterrupt.
            with redirect_stdout(tee):
                target_dir, image_paths = process_uploaded_files(files, interval)
        except Exception as e:
            print(f"Error occurred: {e}")
            raise

        return (
            target_dir,
            image_paths,
            "Upload complete. Click 'Reconstruct' to begin 3D processing.",
            tee.getvalue(),
        )
    def resample_video_with_new_interval(files, new_interval, current_target_dir):
        """Resample video with new slider value.

        Deletes the previous extraction directory (when it exists) and re-runs
        ``process_uploaded_files`` with ``new_interval``. Stdout is redirected
        into a ``TeeOutput`` meanwhile so the log can be shown in the terminal
        panel.

        Args:
            files: Value of the ``file_upload`` component. Items are either
                dicts with a ``"name"`` key or path-like values.
            new_interval: Value of the ``time_interval`` control.
            current_target_dir: Current ``output_path_state`` value. It may be
                ``None`` or the string ``"None"``.

        Returns:
            ``(target_dir, image_paths, status_message, terminal_log)``. When
            there are no files or no videos, the current directory is returned
            unchanged with a ``None`` gallery and an explanatory message.

        Raises:
            Re-raises any exception from the cleanup/processing step after
            printing it to the real console.
        """
        from contextlib import redirect_stdout

        if not files:
            return (
                current_target_dir,
                None,
                "No files to resample.",
                "",
            )

        # Check if we have videos to resample
        video_extensions = {
            ".mp4",
            ".avi",
            ".mov",
            ".mkv",
            ".wmv",
            ".flv",
            ".webm",
            ".m4v",
            ".3gp",
        }
        has_video = any(
            os.path.splitext(
                str(file_data["name"] if isinstance(file_data, dict) else file_data)
            )[1].lower()
            in video_extensions
            for file_data in files
        )
        if not has_video:
            return (
                current_target_dir,
                None,
                "No videos found to resample.",
                "",
            )

        tee = TeeOutput()
        try:
            # redirect_stdout restores sys.stdout on every exit path (including
            # BaseExceptions), so the process is never left writing into the tee.
            with redirect_stdout(tee):
                # Clean up old target directory if it exists
                if (
                    current_target_dir
                    and current_target_dir != "None"
                    and os.path.exists(current_target_dir)
                ):
                    shutil.rmtree(current_target_dir)
                # Process files with new interval
                target_dir, image_paths = process_uploaded_files(files, new_interval)
        except Exception as e:
            # stdout has already been restored, so this reaches the console.
            print(f"Error occurred: {e}")
            raise

        return (
            target_dir,
            image_paths,
            f"Video resampled with {new_interval}s interval. Click 'Reconstruct' to begin 3D processing.",
            tee.getvalue(),
        )
    # A new upload and an explicit "resample" click both rebuild the frame set
    # and write the same four outputs: directory state, gallery, status, log.
    file_upload.change(
        update_gallery_on_file_upload,
        [file_upload, time_interval],
        [output_path_state, image_gallery, log_output, terminal_output],
    )
    resample_btn.click(
        resample_video_with_new_interval,
        [file_upload, time_interval, output_path_state],
        [output_path_state, image_gallery, log_output, terminal_output],
    )
    # -------------------------------------------------------------------------
    # Navigation for Depth, Normal tabs
    # -------------------------------------------------------------------------
    def navigate_with_slider(processed_data, target_view):
        """Navigate to specified view using slider.

        Args:
            processed_data: Per-view processed results (the value held in
                ``processed_data_state``); ``None`` or empty before any
                reconstruction.
            target_view: 1-based view number from the slider. ``None`` and
                values that cannot be converted to a finite number fall back
                to view 1. Out-of-range values are clamped to the valid range.

        Returns:
            ``(depth_vis, info_html)``: the result of ``update_depth_view`` for
            the selected view and the ``update_view_info`` label.
        """
        if processed_data is None or len(processed_data) == 0:
            return None, update_view_info(1, 1)

        # Go through float first so decimal input such as "2.0" works.
        # int() raises ValueError for NaN and OverflowError for +/-inf.
        try:
            target_view = 1 if target_view is None else int(float(target_view))
        except (ValueError, TypeError, OverflowError):
            target_view = 1

        total_views = len(processed_data)
        # Clamp to [1, total_views], then convert to a 0-based index.
        view_index = max(1, min(target_view, total_views)) - 1

        depth_vis = update_depth_view(processed_data, view_index)
        info_html = update_view_info(view_index + 1, total_views)
        return depth_vis, info_html
    def navigate_with_slider_normal(processed_data, target_view):
        """Navigate to specified normal view using slider.

        Args:
            processed_data: Per-view processed results (the value held in
                ``processed_data_state``); ``None`` or empty before any
                reconstruction.
            target_view: 1-based view number from the slider. ``None`` and
                values that cannot be converted to a finite number fall back
                to view 1. Out-of-range values are clamped to the valid range.

        Returns:
            ``(normal_vis, info_html)``: the result of ``update_normal_view``
            for the selected view and the ``update_view_info`` label (tagged
            "Normal").
        """
        if processed_data is None or len(processed_data) == 0:
            return None, update_view_info(1, 1, "Normal")

        # Go through float first so decimal input such as "2.0" works.
        # int() raises ValueError for NaN and OverflowError for +/-inf.
        try:
            target_view = 1 if target_view is None else int(float(target_view))
        except (ValueError, TypeError, OverflowError):
            target_view = 1

        total_views = len(processed_data)
        # Clamp to [1, total_views], then convert to a 0-based index.
        view_index = max(1, min(target_view, total_views)) - 1

        normal_vis = update_normal_view(processed_data, view_index)
        info_html = update_view_info(view_index + 1, total_views, "Normal")
        return normal_vis, info_html
    def handle_depth_slider_change(processed_data, target_view):
        """Depth-slider callback; delegates to ``navigate_with_slider``."""
        outcome = navigate_with_slider(processed_data, target_view)
        return outcome

    def handle_normal_slider_change(processed_data, target_view):
        """Normal-slider callback; delegates to ``navigate_with_slider_normal``."""
        outcome = navigate_with_slider_normal(processed_data, target_view)
        return outcome
    
    # Moving a view slider re-renders that tab's map and its view-info label.
    depth_view_slider.change(
        handle_depth_slider_change,
        [processed_data_state, depth_view_slider],
        [depth_map, depth_view_info],
    )

    normal_view_slider.change(
        handle_normal_slider_change,
        [processed_data_state, normal_view_slider],
        [normal_map, normal_view_info],
    )
    
    # -------------------------------------------------------------------------
    # Real-time terminal output update
    # -------------------------------------------------------------------------
    # Use a timer to periodically update terminal output
    # Every 0.5 s, call get_terminal_output and push its result into the
    # terminal_output component.
    timer = gr.Timer(value=0.5)
    timer.tick(
        get_terminal_output,
        [],
        [terminal_output],
    )
    
    gr.HTML("""
    
    
    """)
    # Enable queuing and start the app.
    demo.queue().launch(
        ssr_mode=False,
        share=True,
        show_error=True,
    )