# MIT License
#
# Copyright (c) 2024 Jiahao Shao
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import functools
import os
import tempfile
import time
import zipfile
from io import BytesIO

import spaces
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
import xformers
from PIL import Image
from tqdm import tqdm
import mediapy as media
from huggingface_hub import login

from gradio_patches.examples import Examples
from chronodepth.unet_chronodepth import (
    DiffusersUNetSpatioTemporalConditionModelChronodepth,
)
from chronodepth.chronodepth_pipeline import ChronoDepthPipeline
from chronodepth.video_utils import resize_max_res, colorize_video_depth

MAX_FRAME = 60

default_seed = 2024
default_num_inference_steps = 5
default_n_tokens = 10
default_chunk_size = 5
default_video_processing_resolution = 768
default_decode_chunk_size = 8
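
# Notes on the defaults above: MAX_FRAME caps how many frames are processed per
# upload, and default_video_processing_resolution bounds the longer side of each
# frame during inference. default_n_tokens and default_chunk_size are applied to
# the pipeline in main(); they are assumed to control ChronoDepth's temporal
# windowing.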


def run_pipeline(pipe, video_rgb, generator, device):
    """
    Run the ChronoDepth pipeline on the input video.

    args:
        pipe: ChronoDepthPipeline object
        video_rgb: input video, torch.Tensor or np.ndarray, shape [T, H, W, 3], range [0, 255]
        generator: torch.Generator
    returns:
        video_depth_pred: predicted depth, torch.Tensor, shape [T, H, W], range [0, 1]
    """
    if isinstance(video_rgb, torch.Tensor):
        video_rgb = video_rgb.cpu().numpy()

    original_height = video_rgb.shape[1]
    original_width = video_rgb.shape[2]

    # Resize the video so it does not exceed the processing resolution, then normalize to [0, 1].
    video_rgb = resize_max_res(video_rgb, default_video_processing_resolution)
    video_rgb = video_rgb.astype(np.float32) / 255.0

    pipe_out = pipe(
        video_rgb,
        num_inference_steps=default_num_inference_steps,
        decode_chunk_size=default_decode_chunk_size,
        motion_bucket_id=127,
        fps=7,
        noise_aug_strength=0.0,
        generator=generator,
        infer_mode="ours",
        sigma_epsilon=-4,
    )

    depth_frames_pred = pipe_out.frames
    depth_frames_pred = torch.from_numpy(depth_frames_pred).to(device)
    # Upsample the predictions back to the original resolution and clamp to [0, 1].
    depth_frames_pred = F.interpolate(
        depth_frames_pred,
        size=(original_height, original_width),
        mode="bilinear",
        align_corners=False,
    )
    depth_frames_pred = depth_frames_pred.clamp(0, 1)
    depth_frames_pred = depth_frames_pred.squeeze(1)

    return depth_frames_pred
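
# Usage sketch for run_pipeline (assumes `pipe` is a loaded ChronoDepthPipeline
# and `frames` is a uint8 array of shape [T, H, W, 3]):
#
#   generator = torch.Generator(device=pipe.device).manual_seed(default_seed)
#   depth = run_pipeline(pipe, frames, generator, pipe.device)  # [T, H, W], values in [0, 1]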


def process_video(
    pipe,
    path_input,
    num_inference_steps=default_num_inference_steps,
    out_max_frames=MAX_FRAME,
    progress=gr.Progress(),
):
    if path_input is None:
        raise gr.Error(
            "Missing video in the first pane: upload a file or use one from the gallery below."
        )

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    print(f"Processing video {name_base}{name_ext}")

    path_output_dir = tempfile.mkdtemp()
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")

    generator = torch.Generator(device=pipe.device).manual_seed(default_seed)

    start_time = time.time()
    zipf = None
    try:
        # -------------------- Data --------------------
        video_name = path_input.split('/')[-1].split('.')[0]
        video_data = media.read_video(path_input)
        fps = video_data.metadata.fps
        video_length = len(video_data)
        video_rgb = np.array(video_data)

        duration_sec = video_length / fps
        out_duration_sec = out_max_frames / fps
        if duration_sec > out_duration_sec:
            gr.Warning(
                f"Only the first ~{int(out_duration_sec)} seconds will be processed; "
                "use the ChronoDepth repository on GitHub to process the full video."
            )
            video_rgb = video_rgb[:out_max_frames]

        zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED)

        # -------------------- Inference --------------------
        depth_pred = run_pipeline(pipe, video_rgb, generator, pipe.device)  # range [0, 1]
        depth_pred = depth_pred.cpu().numpy()
        depth_colored_pred = colorize_video_depth(depth_pred)  # range [0, 1] -> [0, 255]

        # -------------------- Save results --------------------
        # Each depth frame in [0, 1] is quantized to a 16-bit PNG and streamed into the zip archive.
        for i in tqdm(range(len(depth_pred))):
            archive_path = os.path.join(
                f"{name_base}_depth_16bit", f"{i:05d}.png"
            )
            img_byte_arr = BytesIO()
            depth_16bit = Image.fromarray((depth_pred[i] * 65535.0).astype(np.uint16))
            depth_16bit.save(img_byte_arr, format="png")
            img_byte_arr.seek(0)
            zipf.writestr(archive_path, img_byte_arr.read())

        # Export the colorized depth to video.
        media.write_video(path_out_vis, depth_colored_pred, fps=fps)
    finally:
        if zipf is not None:
            zipf.close()

    end_time = time.time()
    print(f"Processing time: {end_time - start_time} seconds")

    return (
        path_out_vis,
        [path_out_vis, path_out_16bit],
    )
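
# process_video returns (path to the colorized depth MP4, [colorized MP4, 16-bit PNG zip]);
# these feed the gr.Video and gr.Files output components wired up in run_demo_server below.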


def run_demo_server(pipe):
    process_pipe_video = spaces.GPU(
        functools.partial(process_video, pipe), duration=70
    )
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="ChronoDepth Video Depth Estimation",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
            h1 {
                text-align: center;
                display: block;
            }
            h2 {
                text-align: center;
                display: block;
            }
            h3 {
                text-align: center;
                display: block;
            }
        """,
    ) as demo:
        gr.HTML(
            """
            <h1>⏰ChronoDepth: Learning Temporally Consistent Video Depth from Video Diffusion Priors</h1>
            <div style="text-align: center; margin-top: 20px;">
                <a title="Website" href="https://jiahao-shao1.github.io/ChronoDepth/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/website?url=https%3A%2F%2Fjhaoshao.github.io%2FChronoDepth%2F&up_message=ChronoDepth&up_color=blue&style=flat&logo=timescale&logoColor=%23FFDC0F">
                </a>
                <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/badge/arXiv-PDF-b31b1b">
                </a>
                <a title="Github" href="https://github.com/jiahao-shao1/ChronoDepth" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                    <img src="https://img.shields.io/github/stars/jhaoshao/ChronoDepth?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
                </a>
            </div>
            <p style="margin-top: 20px; text-align: justify;">
                ChronoDepth is the state-of-the-art video depth estimator for streaming videos in the wild.
            </p>
            <p style="margin-top: 20px; text-align: justify;">
                Note: the maximum video length is limited to 60 frames in this demo. To process longer videos, please use ChronoDepth from the GitHub repository.
            </p>
            """
        )
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(
                    label="Input Video",
                    sources=["upload"],
                )
                with gr.Row():
                    video_submit_btn = gr.Button(
                        value="Compute Depth", variant="primary"
                    )
                    video_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                video_output_video = gr.Video(
                    label="Output video depth (red-near, blue-far)",
                    interactive=False,
                )
                video_output_files = gr.Files(
                    label="Depth outputs",
                    elem_id="download",
                    interactive=False,
                )

        Examples(
            fn=process_pipe_video,
            examples=[
                ["files/elephant.mp4"],
                ["files/kitti360_seq_0000.mp4"],
            ],
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            cache_examples=True,
            directory_name="examples_video",
        )
        video_submit_btn.click(
            fn=process_pipe_video,
            inputs=[video_input],
            outputs=[video_output_video, video_output_files],
            concurrency_limit=1,
        )

        # Reset clears the input video and both output components.
        video_reset_btn.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[video_input, video_output_video, video_output_files],
            concurrency_limit=1,
        )
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def main():
    CHECKPOINT = "jhshao/ChronoDepth-v1"

    if "HF_TOKEN_LOGIN" in os.environ:
        login(token=os.environ["HF_TOKEN_LOGIN"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    # -------------------- Model --------------------
    unet = DiffusersUNetSpatioTemporalConditionModelChronodepth.from_pretrained(
        CHECKPOINT,
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
    )
    pipe = ChronoDepthPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt",
        unet=unet,
        torch_dtype=torch.float16,
        variant="fp16",
    )
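    # The ChronoDepth-finetuned UNet replaces the stock SVD UNet inside the
    # Stable Video Diffusion img2vid-xt pipeline; the remaining components
    # (VAE, image encoder) are reused from the base checkpoint.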
    pipe.n_tokens = default_n_tokens
    pipe.chunk_size = default_chunk_size

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception:
        pass  # run without xformers
    pipe = pipe.to(device)

    run_demo_server(pipe)


if __name__ == "__main__":
    main()
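
# Local launch sketch (assumes this file is saved as app.py and the dependencies
# above are installed; set HF_TOKEN_LOGIN if the checkpoint requires authentication):
#
#   python app.py
#
# The Gradio app then listens on http://0.0.0.0:7860.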