import gradio as gr
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
import cv2
import numpy as np
import os

# Load the pre-trained model and its feature extractor
model_name = "Sokaina55/xclip-base-patch32-finetuned-ssl-sign-language-recognition"
device = "cuda" if torch.cuda.is_available() else "cpu"
feature_extractor = VideoMAEFeatureExtractor.from_pretrained(model_name)
model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
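
# Inference only: put the model in eval mode (disables dropout)
model.eval()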


def process_video(video_path):
    """Processes a sign language video and predicts the signed word."""
    if video_path is None or not os.path.exists(video_path):
        return "Error: Video file not found"

    # Read the video and collect RGB frames
    cap = cv2.VideoCapture(video_path)
    frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; convert to RGB for the feature extractor
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(frame)
    cap.release()

    if len(frames) == 0:
        return "Error: No frames extracted from the video"

    # Preprocess frames for the model
    inputs = feature_extractor(frames, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = outputs.logits.argmax(-1).item()

    # Map the predicted class id to its word label
    class_labels = model.config.id2label
    return f"Predicted word: {class_labels.get(predicted_class, 'Unknown')}"


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Sign Language to Text Recognition")
    video_input = gr.Video(label="Upload a sign language video")
    output_text = gr.Textbox(label="Predicted Word")
    btn = gr.Button("Predict")
    btn.click(fn=process_video, inputs=video_input, outputs=output_text)

demo.launch()