Spaces:
Running
Running
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| import spaces | |
| import os | |
| import argparse | |
| import gradio as gr | |
| from goldfish_lv import GoldFish_LV | |
| from theme import minigptlv_style, custom_css,text_css | |
| import re | |
| from huggingface_hub import login, hf_hub_download | |
| import time | |
| import moviepy.editor as mp | |
| from index import MemoryIndex | |
| # hf_token = os.environ.get('HF_TKN') | |
| # login(token=hf_token) | |
def str2bool(v):
    """Convert a command-line token into a boolean.

    Booleans are returned unchanged. Strings are compared case-insensitively
    against the accepted spellings of true and false.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    spellings = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }
    parsed = spellings.get(v.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed
def get_arguments():
    """Build the command-line parser for the demo and parse ``sys.argv``.

    Returns:
        argparse.Namespace: the parsed options, with the defaults listed below
        applied for any flag that was not supplied.
    """
    # One (flags, keyword-arguments) pair per option, registered in this order.
    option_specs = (
        (("--cfg-path",), {"default": "test_configs/llama2_test_config.yaml"}),
        (("--name",), {"type": str, "default": 'test'}),
        (("--ckpt",), {"type": str, "default": "checkpoints/video_llama_checkpoint_last.pth"}),
        (("--add_subtitles",), {"action": 'store_true'}),
        (("--neighbours",), {"type": int, "default": 3}),
        (("--eval_opt",), {"type": str, "default": 'all'}),
        (("--max_new_tokens",), {"type": int, "default": 512}),
        (("--use_openai_embedding",), {"type": str2bool, "default": False}),
        (("--batch_size",), {"type": int, "default": 2, "help": "Batch size for short video clips"}),
        (("--lora_r",), {"type": int, "default": 64}),
        (("--lora_alpha",), {"type": int, "default": 16}),
        (("--video_path",), {"type": str, "help": "Path to the video file"}),
        (("--options",), {"nargs": "+"}),
    )

    cli = argparse.ArgumentParser(description="Inference parameters")
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
def download_video(youtube_url, download_finish):
    """Fetch the video behind a YouTube link so it can be shown in the player.

    Used as the ``change`` handler of the YouTube link textbox.

    Args:
        youtube_url: current content of the link textbox.
        download_finish: current value of the "download finished" session state.

    Returns:
        A ``(video_path, download_finish)`` pair. For a recognised YouTube URL
        this is the path returned by ``goldfish_obj.process_video_url`` and
        ``True``; otherwise ``None`` (no video) and the unchanged state value.
    """
    if not is_youtube_url(youtube_url):
        return None, download_finish

    processed_video_path = goldfish_obj.process_video_url(youtube_url)
    # Return the plain value for the State output; handing back a new
    # gr.State component would store the component object itself as the state.
    return processed_video_path, True
# Every fragment is a raw string so that `\.` and `\?` reach the regex engine
# untouched instead of being parsed as (invalid) string escape sequences.
_YOUTUBE_URL_RE = re.compile(
    r'(https?://)?(www\.)?'
    r'(youtube|youtu|youtube-nocookie)\.(com|be)/'
    r'(watch\?v=|embed/|v/|.+\?v=)?([^&=%\?]{11})'
)


def is_youtube_url(url: str) -> bool:
    """Tell whether *url* looks like a link to a YouTube video.

    The pattern is anchored at the start of the string and expects an optional
    scheme and ``www.`` prefix, a YouTube host, and an 11-character video id.

    Args:
        url: candidate string. Non-string values (for example ``None`` coming
            from an empty input widget) are reported as not being a URL.

    Returns:
        ``True`` when the start of *url* matches the YouTube pattern.
    """
    if not isinstance(url, str):
        return False
    return bool(_YOUTUBE_URL_RE.match(url))
def gradio_long_inference_video(videos_list,tmp_save_path,subtitle_paths, use_subtitles=True):
    """Run the long-video pass over a list of clips.

    Thin wrapper around ``goldfish_obj.long_inference_video``.

    Args:
        videos_list: clips to process.
        tmp_save_path: directory the clips are stored in.
        subtitle_paths: one subtitle path (or ``None``) per clip.
        use_subtitles: accepted for interface symmetry; not forwarded.

    Returns:
        Whatever ``goldfish_obj.long_inference_video`` returns.
    """
    return goldfish_obj.long_inference_video(videos_list, tmp_save_path, subtitle_paths)
def gradio_short_inference_video(video_path, instruction, use_subtitles=True):
    """Answer *instruction* about a short video.

    Thin wrapper around ``goldfish_obj.short_video_inference``; all three
    arguments are forwarded positionally and its result is returned as is.
    """
    return goldfish_obj.short_video_inference(video_path, instruction, use_subtitles)
def gradio_inference_RAG (instruction,related_information):
    """Answer one question given its retrieved context.

    ``goldfish_obj.inference_RAG`` is called with single-element lists, so the
    one instruction/context pair is wrapped here and the first result unwrapped.
    """
    batch_predictions = goldfish_obj.inference_RAG([instruction], [related_information])
    return batch_predictions[0]
# Videos longer than this many seconds (3 minutes) take the long-video path.
_LONG_VIDEO_THRESHOLD_SECONDS = 180


def _get_video_duration(video_path):
    """Return the duration of the video in seconds and release the reader."""
    clip = mp.VideoFileClip(video_path)
    try:
        return clip.duration
    finally:
        # The clip is only opened to read its duration; close it so the
        # underlying reader is not left open for every request.
        clip.close()


def _build_clips_summary(video_path, use_subtitles):
    """Split a long video into clips and run the long-video pass over them."""
    videos_list, tmp_save_path = goldfish_obj.split_long_video_into_clips(video_path)
    subtitle_paths = [
        goldfish_obj.get_subtitles(os.path.join(tmp_save_path, clip_name)) if use_subtitles else None
        for clip_name in videos_list
    ]
    gradio_long_inference_video(videos_list, tmp_save_path, subtitle_paths, use_subtitles=use_subtitles)


def _load_external_memory(video_name, summary_path):
    """Return a MemoryIndex for the video, from cached embeddings when present."""
    os.makedirs("new_workspace/embedding/demo", exist_ok=True)
    os.makedirs("new_workspace/open_ai_embedding/demo", exist_ok=True)
    if goldfish_obj.args.use_openai_embedding:
        embedding_path = f"new_workspace/open_ai_embedding/demo/{video_name}.pkl"
    else:
        embedding_path = f"new_workspace/embedding/demo/{video_name}.pkl"

    external_memory = MemoryIndex(goldfish_obj.args.neighbours, use_openai=goldfish_obj.args.use_openai_embedding)
    if os.path.exists(embedding_path):
        print("Loading embeddings from pkl file")
        external_memory.load_embeddings_from_pkl(embedding_path)
    else:
        # will embed the information and save it in the pkl file
        external_memory.load_documents_from_json(summary_path, embedding_path)
    return external_memory


def inference(video_path, use_subtitles=True, instruction="", number_of_neighbours=3):
    """Answer *instruction* about the video stored at *video_path*.

    Videos longer than 180 seconds go through the long-video path: the clips
    summary is built if its JSON file is missing, embeddings are loaded or
    created, the context most similar to the instruction is retrieved, and the
    answer is generated from that context. Shorter videos are passed straight
    to the short-video model.

    Args:
        video_path: path to a local video file.
        use_subtitles: whether subtitles should be used.
        instruction: the question to answer.
        number_of_neighbours: number of neighbours for the memory index; it is
            written to ``goldfish_obj.args.neighbours``.

    Returns:
        The prediction produced by the selected inference path.
    """
    start_time = time.time()
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    # The UI number field may deliver a float such as 3.0, while the CLI
    # option it overrides is declared with type=int; normalise to an int.
    goldfish_obj.args.neighbours = int(number_of_neighbours)
    print(f"Video name: {video_name}")
    video_duration = _get_video_duration(video_path)
    print(f"Video duration: {video_duration:.2f} seconds")

    if video_duration > _LONG_VIDEO_THRESHOLD_SECONDS:
        print("Long video")
        # Reuse the stored clips summary when it exists; otherwise build it.
        file_path = f'new_workspace/clips_summary/demo/{video_name}.json'
        if not os.path.exists(file_path):
            print("Clips summary is not ready")
            _build_clips_summary(video_path, use_subtitles)
        else:
            print("External memory is ready")

        external_memory = _load_external_memory(video_name, file_path)
        # get the most similar context from the external memory to this instruction
        related_context_documents, related_context_keys = external_memory.search_by_similarity(instruction)
        related_information = goldfish_obj.get_related_context(external_memory, related_context_keys)
        pred = gradio_inference_RAG(instruction, related_information)
        # The summary and embedding files are left on disk; the existence
        # checks above reuse them on later questions about the same video.
    else:
        print("Short video")
        # NOTE(review): this name is cut at the first '.', unlike video_name
        # above (cut at the last '.'); kept as is because consumers of
        # goldfish_obj.video_name are outside this file.
        goldfish_obj.video_name = video_path.split('/')[-1].split('.')[0]
        pred = gradio_short_inference_video(video_path, instruction, use_subtitles)

    processing_time = time.time() - start_time
    print(f"Processing time: {processing_time:.2f} seconds")
    return pred
def process_video(path_url, has_subtitles, instruction, number_of_neighbours):
    """Answer a question about a YouTube link or a local video file.

    Used as the click handler of both "Answer the Question" buttons.

    Args:
        path_url: a YouTube URL or the path of an uploaded video file.
        has_subtitles: whether subtitles should be used.
        instruction: the user's question; when blank, the default question
            advertised by the textbox placeholder is used.
        number_of_neighbours: neighbour count forwarded to ``inference``.

    Returns:
        The prediction returned by ``inference``.

    Raises:
        ValueError: if no link or video was provided.
    """
    # Fail with a readable message instead of a TypeError from deeper down
    # when the button is pressed before a link or video was supplied.
    if not path_url:
        raise ValueError("No video provided: paste a YouTube link or upload a video file.")

    # The question textboxes promise this default in their placeholder text.
    if not instruction or not instruction.strip():
        instruction = "What's this video talking about?"

    if is_youtube_url(path_url):
        video_path = return_video_path(path_url)
    else:
        video_path = path_url
    return inference(video_path, has_subtitles, instruction, number_of_neighbours)
# Finds the video id after the usual YouTube URL markers (watch?v=ID,
# youtu.be/ID, /embed/ID, /v/ID, /shorts/ID) and stops at the next delimiter.
_YOUTUBE_ID_RE = re.compile(r'(?:[?&]v=|youtu\.be/|/embed/|/v/|/shorts/)([^&=%?/#]+)')


def return_video_path(youtube_url):
    """Map a YouTube URL to the local path of its video file.

    Args:
        youtube_url: a YouTube link in any of the common forms.

    Returns:
        ``workspace/tmp/<video_id>.mp4`` built with ``os.path.join``.

    Raises:
        ValueError: if no video id can be extracted from the URL.
    """
    found = _YOUTUBE_ID_RE.search(youtube_url)
    if found:
        video_id = found.group(1)
    else:
        # Previous behaviour, kept for inputs the pattern does not recognise:
        # strip the canonical watch prefix and drop any trailing parameters.
        video_id = youtube_url.split("https://www.youtube.com/watch?v=")[-1].split('&')[0]

    if not video_id:
        raise ValueError("Invalid YouTube URL provided.")
    # NOTE(review): assumes the downloader stores the file under the bare video id.
    return os.path.join("workspace", "tmp", f"{video_id}.mp4")
def run_gradio():
    """Build the two-tab Gradio interface and launch it.

    The "Youtube videos" tab takes a link, shows the downloaded video and
    answers a question about it; the "Local videos" tab does the same for an
    uploaded file. Both buttons are wired to ``process_video``. The call does
    not return until the launched server stops.
    """
    title = """<h1 align="center">Goldfish Demo </h1>"""
    description = """<h5>[ECCV 2024 Accepted]Goldfish: Vision-Language Understanding of Arbitrarily Long Videos</h5>"""
    # NOTE(review): project_page, code_link and paper_link are defined here but
    # are not referenced anywhere else in this function.
    project_page = """<p><a href='https://vision-cair.github.io/MiniGPT4-video/'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p>"""
    code_link="""<p><a href='https://github.com/Vision-CAIR/MiniGPT4-video'><img src='repo_imgs/goldfishai_png.png'></a></p>"""
    paper_link="""<p><a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>"""
    with gr.Blocks(title="Goldfish demo",css=text_css ) as demo :
        gr.Markdown(title)
        gr.Markdown(description)
        # ---- Tab 1: question answering on a YouTube link ----
        with gr.Tab("Youtube videos") as youtube_tab:
            with gr.Row():
                with gr.Column():
                    youtube_link = gr.Textbox(label="YouTube link", placeholder="Paste YouTube URL here")
                    video_player = gr.Video(autoplay=False)
                    download_finish = gr.State(value=False)
                    # Every edit of the link textbox calls download_video, whose
                    # results feed the player and the download_finish state.
                    youtube_link.change(
                        fn=download_video,
                        inputs=[youtube_link, download_finish],
                        outputs=[video_player, download_finish]
                    )
            with gr.Row():
                with gr.Column(scale=2) :
                    youtube_question = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?")
                    youtube_has_subtitles = gr.Checkbox(label="Use subtitles", value=True)
                    youtube_input_note = """<p>For the global questions set the number of neighbours=-1 otherwise use 3 as the defualt.</p>"""
                    gr.Markdown(youtube_input_note)
                    # input number
                    youtube_number_of_neighbours=gr.Number(label="Number of Neighbours",interactive=True,value=3)
                    youtube_process_button = gr.Button("⛓️ Answer the Question (QA)")
                with gr.Column(scale=3):
                    youtube_answer = gr.Textbox(label="Answer of the question", lines=8, interactive=True, placeholder="Answer of the question will show up here.")
            # The link text itself (not the player) is passed to process_video.
            youtube_process_button.click(fn=process_video, inputs=[youtube_link, youtube_has_subtitles, youtube_question,youtube_number_of_neighbours], outputs=[youtube_answer])
        # ---- Tab 2: question answering on an uploaded video ----
        with gr.Tab("Local videos") as local_tab:
            with gr.Row():
                with gr.Column():
                    local_video_player = gr.Video(sources=["upload"])
            with gr.Row():
                with gr.Column(scale=2):
                    local_question = gr.Textbox(label="Your Question", placeholder="Default: What's this video talking about?")
                    local_has_subtitles = gr.Checkbox(label="Use subtitles", value=True)
                    local_input_note = """<p>For the global questions set the number of neighbours=-1 otherwise use 3 as the defualt.</p>"""
                    gr.Markdown(local_input_note)
                    local_number_of_neighbours=gr.Number(label="Number of Neighbours",interactive=True,value=3)
                    local_process_button = gr.Button("⛓️ Answer the Question (QA)")
                with gr.Column(scale=3):
                    local_answer = gr.Textbox(label="Answer of the question", lines=8, interactive=True, placeholder="Answer of the question will show up here.")
            # Here the video component's value is what process_video receives as path_url.
            local_process_button.click(fn=process_video, inputs=[local_video_player, local_has_subtitles, local_question,local_number_of_neighbours], outputs=[local_answer])
    # Queue capped at 10 pending requests; public share link on; fixed port 5000.
    demo.queue(max_size=10).launch(show_error=True,share=True, show_api=False,server_port=5000)
if __name__ == "__main__":
    # Parse the command-line options for this run.
    args=get_arguments()
    # goldfish_obj is created as a module-level name on purpose: the handler
    # functions in this file read it as a global instead of taking it as a parameter.
    goldfish_obj = GoldFish_LV(args)
    # Build the interface and start serving it.
    run_gradio()