import torch
import gradio as gr
import cv2
import numpy as np
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
# Model checkpoint. "MCG-NJU/videomae-base" is a pretrain-only backbone without a
# fine-tuned classification head, so a fine-tuned checkpoint (here Kinetics-400)
# is needed for meaningful predictions.
model_name = "MCG-NJU/videomae-base-finetuned-kinetics"
# Load model and processor
model = VideoMAEForVideoClassification.from_pretrained(model_name)
processor = VideoMAEImageProcessor.from_pretrained(model_name)
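
# Optional sanity check (a small sketch): a fine-tuned checkpoint ships a
# human-readable label mapping in its config, which the prediction code below
# relies on. Uncomment to inspect it.
# print(model.config.num_labels)
# print(list(model.config.id2label.items())[:5])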
# Extract `num_frames` frames, evenly spaced across the video, as RGB arrays
def extract_frames(video_path, num_frames=16):
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Select frames evenly spaced throughout the video
    for i in np.linspace(0, max(total_frames - 1, 0), num_frames, dtype=int):
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes to BGR; VideoMAE expects RGB
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()
    if not frames:
        raise gr.Error("Could not read any frames from the uploaded video.")
    # Pad by repeating the last frame so exactly `num_frames` frames are returned
    while len(frames) < num_frames:
        frames.append(frames[-1])
    return frames
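
# Optional alternative sampler (a sketch, not wired into the app): some codecs
# handle CAP_PROP_POS_FRAMES seeks poorly, so reading the stream sequentially
# and keeping only the sampled indices can be more robust, at the cost of
# decoding every frame. The function name is illustrative.
def extract_frames_sequential(video_path, num_frames=16):
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    wanted = set(np.linspace(0, max(total_frames - 1, 0), num_frames, dtype=int).tolist())
    frames = []
    index = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if index in wanted:
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        index += 1
    cap.release()
    # Pad with the last frame so callers still get exactly `num_frames`
    while frames and len(frames) < num_frames:
        frames.append(frames[-1])
    return frames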
# Run the classifier on an uploaded video and return the predicted label
def process_video(video):
    frames = extract_frames(video)
    # The processor resizes to 224x224 and normalizes the frames for VideoMAE
    inputs = processor(frames, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    # Map the class index to its human-readable name from the model config
    label = model.config.id2label[predicted_class]
    return f"Predicted Class: {label}"
# Gradio UI
iface = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload a video"),
    outputs=gr.Textbox(label="Prediction"),
)
# Launch app
iface.launch(server_name="0.0.0.0", server_port=7860, share=True)