Spaces:
Runtime error
| import csv | |
| import os | |
| import tempfile | |
| import gradio as gr | |
| import requests | |
| import torch | |
| import torchvision | |
| import torchvision.transforms as T | |
| from PIL import Image | |
| from featup.util import norm | |
| from torchaudio.functional import resample | |
| from denseav.train import LitAVAligner | |
| from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video | |
| from denseav.shared import norm, crop_to_divisor, blur_dim | |
| from os.path import join | |
# Pretrained DenseAV variants; each name is loaded from the hub repo
# "mhamilton723/DenseAV-<name>" and is also the value shown in the model radio.
MODEL_OPTIONS = ["language", "sound-language", "sound"]
# The variant for which the two-head attention video is rendered. It must be spelled
# exactly like its MODEL_OPTIONS entry: it is used both as a key into `models` and
# as the trigger for the extra plot.
MULTI_HEAD_OPTION = "sound-language"

SAMPLE_VIDEOS_BASE_URL = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/"
SAMPLE_VIDEO_NAMES = ("puppies.mp4", "peppers.mp4", "boat.mp4", "elephant2.mp4")


def download_video(url, save_path, timeout=60):
    """Download ``url`` and store the response body at ``save_path``.

    Non-2xx responses raise ``requests.HTTPError`` instead of saving the error page
    as a video. The body is first written to ``save_path + ".part"`` and renamed into
    place only once complete, so an interrupted download is fetched again on the next
    start instead of being skipped by the "already exists" check.

    Args:
        url: Address of the file to fetch.
        save_path: Destination file path.
        timeout: Seconds passed to ``requests.get``; without it a stalled server can
            block start-up indefinitely.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    partial_path = save_path + ".part"
    with open(partial_path, "wb") as file:
        file.write(response.content)
    os.replace(partial_path, save_path)


def new_output_paths(*names):
    """Return one path per name inside a newly created private temporary directory.

    ``tempfile.mkdtemp`` creates the directory atomically and accessible only to the
    current user, so unlike the deprecated ``tempfile.mktemp`` the returned paths
    cannot be claimed by another process and never collide between requests.
    The files themselves are not created here.
    """
    out_dir = tempfile.mkdtemp(prefix="denseav_")
    return [os.path.join(out_dir, name) for name in names]


if __name__ == "__main__":
    # "local" serves on port 6006 with videos from ./samples; any other value assumes
    # a hosted run that caches under /tmp and downloads the sample videos.
    mode = "hf"

    if mode == "local":
        sample_videos_dir = "samples"
    else:
        # NOTE(review): these are set after the imports at the top of the file, so any
        # library that reads them at import time will not see them — confirm they take
        # effect, or move this block above the imports.
        os.environ['TORCH_HOME'] = '/tmp/.cache'
        os.environ['HF_HOME'] = '/tmp/.cache'
        os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache'
        os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
        os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
        sample_videos_dir = "/tmp/samples"

        os.makedirs(sample_videos_dir, exist_ok=True)
        for filename in SAMPLE_VIDEO_NAMES:
            save_path = os.path.join(sample_videos_dir, filename)
            if os.path.exists(save_path):
                print(f"{filename} already exists. Skipping download.")
                continue
            print(f"Downloading {filename}...")
            try:
                download_video(SAMPLE_VIDEOS_BASE_URL + filename, save_path)
            except requests.RequestException as err:
                # Sample videos only feed the examples row; keep serving the demo.
                print(f"Could not download {filename}: {err}")

    # Nothing in this file reads CSV; presumably a denseav dependency needs the larger
    # field limit — TODO confirm.
    csv.field_size_limit(100000000)
    load_size = 224  # shorter-side size of frames fed to the model (T.Resize with an int)
    plot_size = 224  # shorter-side size of frames drawn into the output videos
    # Fall back to CPU so the demo still runs on hardware without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    video_input = gr.Video(label="Choose a video to featurize", height=480)
    model_option = gr.Radio(MODEL_OPTIONS, value="language", label='Choose a model')
    video_output1 = gr.Video(label="Audio Video Attention", height=480)
    video_output2 = gr.Video(
        label=f"Multi-Head Audio Video Attention (Only Available for {MULTI_HEAD_OPTION})",
        height=480)
    video_output3 = gr.Video(label="Visual Features", height=480)

    # Every variant is loaded once at start-up; requests only pick one from the dict.
    models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in MODEL_OPTIONS}

    def process_video(video, model_option):
        """Run the selected DenseAV model on ``video`` and render the result videos.

        Only the first 10 seconds of the clip are decoded. The audio is resampled to
        16 kHz and reduced to its first channel before it reaches the model.

        Args:
            video: Path of the clip chosen in ``video_input`` (None if nothing chosen).
            model_option: Key into ``models``; one of MODEL_OPTIONS.

        Returns:
            ``(attention_path, multi_head_path, visual_feats_path)`` for the three
            output players; ``multi_head_path`` is None unless ``model_option`` is
            MULTI_HEAD_OPTION.

        Raises:
            gr.Error: if no video was chosen or the clip has no audio, so the user
                sees a readable message instead of a generic failure.
        """
        if video is None:
            raise gr.Error("Please choose a video first.")
        model = models[model_option].to(device)

        original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec')
        # Fail with a readable message for silent clips instead of deeper in the pipeline.
        if "audio_fps" not in info or audio.numel() == 0:
            raise gr.Error("The selected video has no audio track.")

        sample_rate = 16000
        if info["audio_fps"] != sample_rate:
            audio = resample(audio, info["audio_fps"], sample_rate)
        # read_video returns audio as (channels, samples); keep channel 0 as (1, samples).
        audio = audio[0].unsqueeze(0)

        # Model input: resize, crop_to_divisor(x, 8) (presumably crops H/W to multiples
        # of 8 — TODO confirm), scale uint8 to [0, 1], then apply denseav's `norm`.
        img_transform = T.Compose([
            T.Resize(load_size, T.InterpolationMode.BILINEAR),
            lambda x: crop_to_divisor(x, 8),
            lambda x: x.to(torch.float32) / 255,
            norm])
        # read_video yields frames as (T, H, W, C); the transforms expect (C, H, W).
        frames = torch.stack([img_transform(f.permute(2, 0, 1)) for f in original_frames], dim=0)

        # Display copy: same resize/crop, left un-normalized in [0, 1].
        plotting_img_transform = T.Compose([
            T.Resize(plot_size, T.InterpolationMode.BILINEAR),
            lambda x: crop_to_divisor(x, 8),
            lambda x: x.to(torch.float32) / 255])
        frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))

        with torch.no_grad():
            audio_feats = model.forward_audio({"audio": audio.to(device)})
            audio_feats = {k: v.cpu() for k, v in audio_feats.items()}
            image_feats = model.forward_image({"frames": frames.unsqueeze(0).to(device)}, max_batch_size=2)
            image_feats = {k: v.cpu() for k, v in image_feats.items()}

            # Per-head similarities (agg_heads=False), averaged over dim -2 and then
            # smoothed with a window of 3 along the last dim.
            sim_by_head = model.sim_agg.get_pairwise_sims(
                {**image_feats, **audio_feats},
                raw=False,
                agg_sim=False,
                agg_heads=False,
            ).mean(dim=-2).cpu()
            sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)

        attention_path, multi_head_path, visual_feats_path, audio_feats_path = new_output_paths(
            "attention.mp4", "multi_head_attention.mp4", "visual_features.mp4", "audio_features.mp4")
        fps = info["video_fps"]

        plot_attention_video(sim_by_head, frames_to_plot, audio, fps, sample_rate, attention_path)

        if model_option == MULTI_HEAD_OPTION:
            plot_2head_attention_video(sim_by_head, frames_to_plot, audio, fps, sample_rate, multi_head_path)
        else:
            multi_head_path = None

        # plot_feature_video takes two output paths; only the first one is shown
        # ("Visual Features"). The second path is not displayed in the UI.
        plot_feature_video(
            image_feats["image_feats"],
            audio_feats["audio_feats"],
            frames_to_plot,
            audio,
            fps,
            sample_rate,
            visual_feats_path,
            audio_feats_path,
        )
        return attention_path, multi_head_path, visual_feats_path

    # Offer only examples whose file is present: downloads are best-effort, and in
    # "local" mode the files must already exist in ./samples.
    examples = [
        [join(sample_videos_dir, "puppies.mp4"), MULTI_HEAD_OPTION],
        [join(sample_videos_dir, "peppers.mp4"), "language"],
        [join(sample_videos_dir, "elephant2.mp4"), "language"],
        [join(sample_videos_dir, "boat.mp4"), "language"],
    ]
    examples = [example for example in examples if os.path.exists(example[0])]

    with gr.Blocks() as demo:
        with gr.Column():
            gr.Markdown("## Visualizing Sound and Language with DenseAV")
            gr.Markdown(
                "This demo allows you to explore the inner attention maps of DenseAV's dense multi-head contrastive operator.")
            with gr.Row():
                with gr.Column(scale=1):
                    model_option.render()
                with gr.Column(scale=3):
                    video_input.render()
            with gr.Row():
                submit_button = gr.Button("Submit")
            if examples:
                with gr.Row():
                    gr.Examples(examples=examples, inputs=[video_input, model_option])
            with gr.Row():
                video_output1.render()
                video_output2.render()
                video_output3.render()
        submit_button.click(fn=process_video, inputs=[video_input, model_option],
                            outputs=[video_output1, video_output2, video_output3])

    # 7860 is Gradio's default port; local runs use 6006.
    demo.launch(server_name="0.0.0.0", server_port=6006 if mode == "local" else 7860, debug=True)