Spaces:
Paused
Paused
| ### app.py | |
| # User interface for the demo. | |
| ### | |
| import os | |
| import pandas as pd | |
| import gradio as gr | |
| from gradio_rich_textbox import RichTextbox | |
| from demo import VideoCLSModel | |
def load_samples(data_root):
    """Return the sample video paths listed under ``data_root``.

    Expects ``{data_root}/csv/`` to hold files named ``0.csv``, ``1.csv``, ...,
    ``{n-1}.csv``. The first value of the ``id`` column in each file names
    the video ``{data_root}/video/{id}.mp4``. The result is ordered by that
    numeric index, so position ``i`` corresponds to ``{i}.csv``.

    Args:
        data_root: Dataset directory containing ``csv/`` and ``video/``.

    Returns:
        List of video file paths, one per CSV file.
    """
    csv_dir = f'{data_root}/csv'
    # Count only CSV files. Stray entries such as ``.DS_Store`` or
    # ``.ipynb_checkpoints`` would otherwise inflate the count and make the
    # loop below look for a non-existent ``{n}.csv``.
    n_sample = sum(1 for name in os.listdir(csv_dir) if name.endswith('.csv'))
    sample_videos = []
    for i in range(n_sample):
        df = pd.read_csv(f'{csv_dir}/{i}.csv')
        vid = df['id'].values[0]
        sample_videos.append(f'{data_root}/video/{vid}.mp4')
    return sample_videos
def format_pred(pred, gt):
    """Render predicted labels as a comma-separated rich-text string.

    Each label in ``pred`` is wrapped in a ``[color=...]`` tag: green when it
    is found in ``gt`` (tested with ``in``), red otherwise. The original order
    of ``pred`` is preserved.
    """
    hit_tpl = '[color=green]{}[/color]'
    miss_tpl = '[color=red]{}[/color]'
    pieces = [(hit_tpl if label in gt else miss_tpl).format(label) for label in pred]
    return ', '.join(pieces)
def main():
    """Build and launch the Ego-VPA Gradio demo.

    Loads two ``VideoCLSModel`` instances (the zero-shot LaViLa config and
    the Ego-VPA config under ``configs/charades_ego``) and lists the sample
    videos under ``data/charades_ego``. It then serves a UI in which choosing
    an example and pressing *Predict* shows both models' predictions, coloured
    green/red by membership in the ground-truth labels.
    """
    lavila = VideoCLSModel("configs/charades_ego/zeroshot.yml")
    egovpa = VideoCLSModel("configs/charades_ego/egovpa.yml")
    sample_videos = load_samples('data/charades_ego')
    print(sample_videos)

    def predict(idx):
        """Run both models on sample ``idx`` and format their predictions.

        Returns ``(ground_truth, zeroshot_markup, egovpa_markup)`` in the
        order of the Predict button's ``outputs``.
        """
        # The hidden ``idx`` field is only filled by clicking an example, so
        # it may still be empty if Predict is pressed first.
        if idx is None:
            raise gr.Error("Please select a sample video first.")
        # gr.Number hands over a float (e.g. 3.0); the sample index is an int.
        idx = int(idx)
        # As before, the ground truth used for display and colouring comes
        # from the Ego-VPA call; the LaViLa call's copy is discarded.
        zeroshot_action, _ = lavila.predict(idx)
        egovpa_action, gt_action = egovpa.predict(idx)
        return (
            gt_action,
            format_pred(zeroshot_action, gt_action),
            format_pred(egovpa_action, gt_action),
        )

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Ego-VPA Demo
            Choose a sample video and click predict to view the results
            (<span style="color:green">correct</span>/<span style="color:red">incorrect</span>).
            """
        )
        with gr.Row():
            with gr.Column():
                video = gr.PlayableVideo(label="video", height='300px', interactive=False, autoplay=True)
            with gr.Column():
                idx = gr.Number(label="Idx", visible=False)
                label = RichTextbox(label="Ground Truth", visible=False)
                zeroshot = RichTextbox(label="LaViLa (zero-shot) prediction")
                ours = RichTextbox(label="Ego-VPA prediction")
                btn = gr.Button("Predict", variant="primary")
                btn.click(predict, inputs=[idx], outputs=[label, zeroshot, ours])
        # Each example row sets the hidden index and the video player together.
        gr.Examples(examples=[[i, x] for i, x in enumerate(sample_videos)], inputs=[idx, video])
    demo.launch(share=True)


if __name__ == "__main__":
    main()