AreebKhan committed on
Commit 7e09910 · verified · 1 Parent(s): 9827b29

Update app.py

Files changed (1)
  1. app.py +37 -36
app.py CHANGED
@@ -1,28 +1,25 @@
-import gradio as gr
 import torch
-from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
+import gradio as gr
 import cv2
 import numpy as np
-import tempfile
-import os
-
-# Load the pre-trained model
-model_name = "Sokaina55/xclip-base-patch32-finetuned-ssl-sign-language-recognition"
-device = "cuda" if torch.cuda.is_available() else "cpu"
+from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor
 
-feature_extractor = VideoMAEFeatureExtractor.from_pretrained(model_name)
-model = VideoMAEForVideoClassification.from_pretrained(model_name).to(device)
+# Model name
+model_name = "MCG-NJU/videomae-base"  # Ensure this is a valid model on Hugging Face
 
-def process_video(video_path):
-    """Processes video and predicts sign language word."""
-    if not os.path.exists(video_path):
-        return "Error: Video file not found"
+# Load model and processor
+model = VideoMAEForVideoClassification.from_pretrained(model_name)
+processor = VideoMAEImageProcessor.from_pretrained(model_name)
 
-    # Read video
+# Function to extract frames from video
+def extract_frames(video_path, num_frames=16):
     cap = cv2.VideoCapture(video_path)
     frames = []
-
-    while cap.isOpened():
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+    # Select frames evenly spaced throughout the video
+    for i in np.linspace(0, total_frames - 1, num_frames, dtype=int):
+        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
         ret, frame = cap.read()
         if not ret:
             break
@@ -30,30 +27,34 @@ def process_video(video_path):
         frames.append(frame)
 
     cap.release()
+
+    # Ensure exactly `num_frames` frames are used
+    while len(frames) < num_frames:
+        frames.append(frames[-1])  # Duplicate last frame if needed
+
+    return frames
 
-    if len(frames) == 0:
-        return "Error: No frames extracted from the video"
-
-    # Preprocess frames
-    inputs = feature_extractor(frames, return_tensors="pt")
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+# Function to process video and make predictions
+def process_video(video):
+    frames = extract_frames(video)
 
-    # Get predictions
+    # Process video frames with correct resizing and normalization
+    inputs = processor(frames, return_tensors="pt", sampling_rate=30, do_resize=True, size={"shortest_edge": 224}, do_normalize=True)
+
     with torch.no_grad():
        outputs = model(**inputs)
 
-    predicted_class = outputs.logits.argmax(-1).item()
-    class_labels = model.config.id2label  # Map predictions to words
+    logits = outputs.logits
+    predicted_class = torch.argmax(logits, dim=1).item()
 
-    return f"Predicted word: {class_labels.get(predicted_class, 'Unknown')}"
+    return f"Predicted Class: {predicted_class}"
 
 # Gradio UI
-with gr.Blocks() as demo:
-    gr.Markdown("## Sign Language to Text Recognition")
-    video_input = gr.Video(label="Upload a sign language video")
-    output_text = gr.Textbox(label="Predicted Word")
-    btn = gr.Button("Predict")
-
-    btn.click(fn=process_video, inputs=video_input, outputs=output_text)
-
-    demo.launch()
+iface = gr.Interface(
+    fn=process_video,
+    inputs=gr.Video(label="Upload a video"),
+    outputs=gr.Textbox(label="Prediction"),
+)
+
+# Launch app
+iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
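As a quick local sanity check of the updated pipeline, a minimal sketch along the lines below can exercise the same frame sampling and preprocessing outside Gradio. The clip path sample_sign.mp4 is hypothetical, and the BGR-to-RGB conversion is an assumption added here because OpenCV returns BGR frames while VideoMAEImageProcessor expects RGB input; resizing and normalization are left to the processor's defaults.

# Minimal, self-contained sketch mirroring the committed frame-sampling logic (not part of the commit).
import cv2
import numpy as np
import torch
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

model_name = "MCG-NJU/videomae-base"
model = VideoMAEForVideoClassification.from_pretrained(model_name)
processor = VideoMAEImageProcessor.from_pretrained(model_name)

def extract_rgb_frames(video_path, num_frames=16):
    # Evenly spaced frame indices, as in the committed extract_frames()
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    for i in np.linspace(0, max(total_frames - 1, 0), num_frames, dtype=int):
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ret, frame = cap.read()
        if not ret:
            break
        # Assumption: convert OpenCV's BGR output to RGB for the image processor
        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    cap.release()
    while frames and len(frames) < num_frames:
        frames.append(frames[-1])  # Pad by repeating the last frame
    return frames

frames = extract_rgb_frames("sample_sign.mp4")  # hypothetical clip
inputs = processor(frames, return_tensors="pt")  # resize/normalize use processor defaults
with torch.no_grad():
    logits = model(**inputs).logits
print("Predicted Class:", logits.argmax(-1).item())

Note that MCG-NJU/videomae-base is a self-supervised pretraining checkpoint, so the classification head loaded this way starts from freshly initialized weights; the printed class index serves as a smoke test rather than a meaningful sign-language prediction.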