Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import spaces
|
| 2 |
import gradio as gr
|
| 3 |
import cv2
|
| 4 |
from PIL import Image, ImageDraw, ImageFont
|
|
@@ -9,6 +8,7 @@ import os
|
|
| 9 |
import matplotlib.pyplot as plt
|
| 10 |
from io import BytesIO
|
| 11 |
import tempfile
|
|
|
|
| 12 |
|
| 13 |
# Check if CUDA is available, otherwise use CPU
|
| 14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
@@ -33,18 +33,19 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
| 33 |
frame_duration = 1 / output_fps
|
| 34 |
video_duration = frame_count / original_fps
|
| 35 |
|
| 36 |
-
processed_frames = []
|
| 37 |
frame_scores = []
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
for time in progress.tqdm(np.arange(0, video_duration, frame_duration)):
|
| 40 |
frame_number = int(time * original_fps)
|
| 41 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
| 42 |
ret, img = cap.read()
|
| 43 |
if not ret:
|
| 44 |
break
|
| 45 |
|
| 46 |
-
# Resize the frame
|
| 47 |
-
|
| 48 |
pil_img = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
|
| 49 |
|
| 50 |
# Process single image
|
|
@@ -58,7 +59,7 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
| 58 |
max_score = 0
|
| 59 |
|
| 60 |
try:
|
| 61 |
-
font = ImageFont.truetype("arial.ttf", 20)
|
| 62 |
except IOError:
|
| 63 |
font = ImageFont.load_default()
|
| 64 |
|
|
@@ -77,15 +78,22 @@ def process_video(video_path, target, progress=gr.Progress()):
|
|
| 77 |
|
| 78 |
max_score = max(max_score, confidence)
|
| 79 |
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
| 81 |
frame_scores.append(max_score)
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
cap.release()
|
| 84 |
-
return
|
| 85 |
-
|
| 86 |
def create_heatmap(frame_scores, current_frame):
|
| 87 |
plt.figure(figsize=(12, 3))
|
| 88 |
-
plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
|
| 89 |
cbar = plt.colorbar(label='Confidence')
|
| 90 |
cbar.ax.yaxis.set_ticks_position('left')
|
| 91 |
cbar.ax.yaxis.set_label_position('left')
|
|
@@ -93,13 +101,11 @@ def create_heatmap(frame_scores, current_frame):
|
|
| 93 |
plt.xlabel('Frame')
|
| 94 |
plt.yticks([])
|
| 95 |
|
| 96 |
-
# Add more frame numbers on x-axis
|
| 97 |
num_frames = len(frame_scores)
|
| 98 |
-
step = max(1, num_frames // 10)
|
| 99 |
frame_numbers = range(0, num_frames, step)
|
| 100 |
plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
|
| 101 |
|
| 102 |
-
# Add vertical line for current frame
|
| 103 |
plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
|
| 104 |
|
| 105 |
plt.tight_layout()
|
|
@@ -121,6 +127,13 @@ def load_sample_frame(video_path):
|
|
| 121 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 122 |
return frame_rgb
|
| 123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
def gradio_app():
|
| 125 |
with gr.Blocks() as app:
|
| 126 |
gr.Markdown("# Video Object Detection with Owlv2")
|
|
@@ -135,28 +148,23 @@ def gradio_app():
|
|
| 135 |
use_sample_button = gr.Button("Use Sample Video")
|
| 136 |
progress_bar = gr.Progress()
|
| 137 |
|
| 138 |
-
|
| 139 |
frame_scores = gr.State([])
|
| 140 |
|
| 141 |
def process_and_update(video, target):
|
| 142 |
-
|
| 143 |
-
if
|
| 144 |
-
heatmap_path = create_heatmap(scores, 0)
|
| 145 |
-
|
|
|
|
| 146 |
return None, None, None, None, error, gr.Slider(maximum=100, value=0)
|
| 147 |
|
| 148 |
-
def update_frame_and_heatmap(frame_index, frames, scores):
|
| 149 |
-
if frames and 0 <= frame_index < len(frames):
|
| 150 |
-
heatmap_path = create_heatmap(scores, frame_index)
|
| 151 |
-
return frames[frame_index], heatmap_path
|
| 152 |
-
return None, None
|
| 153 |
-
|
| 154 |
video_input.upload(process_and_update,
|
| 155 |
inputs=[video_input, target_input],
|
| 156 |
-
outputs=[
|
| 157 |
|
| 158 |
frame_slider.change(update_frame_and_heatmap,
|
| 159 |
-
inputs=[frame_slider,
|
| 160 |
outputs=[output_image, heatmap_output])
|
| 161 |
|
| 162 |
def use_sample_video():
|
|
@@ -165,7 +173,7 @@ def gradio_app():
|
|
| 165 |
|
| 166 |
use_sample_button.click(use_sample_video,
|
| 167 |
inputs=None,
|
| 168 |
-
outputs=[
|
| 169 |
|
| 170 |
# Layout
|
| 171 |
with gr.Row():
|
|
@@ -179,4 +187,15 @@ def gradio_app():
|
|
| 179 |
|
| 180 |
if __name__ == "__main__":
|
| 181 |
app = gradio_app()
|
| 182 |
-
app.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import cv2
|
| 3 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
from io import BytesIO
|
| 10 |
import tempfile
|
| 11 |
+
import shutil
|
| 12 |
|
| 13 |
# Check if CUDA is available, otherwise use CPU
|
| 14 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 33 |
frame_duration = 1 / output_fps
|
| 34 |
video_duration = frame_count / original_fps
|
| 35 |
|
|
|
|
| 36 |
frame_scores = []
|
| 37 |
+
temp_dir = tempfile.mkdtemp()
|
| 38 |
+
frame_paths = []
|
| 39 |
|
| 40 |
+
for i, time in enumerate(progress.tqdm(np.arange(0, video_duration, frame_duration))):
|
| 41 |
frame_number = int(time * original_fps)
|
| 42 |
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
|
| 43 |
ret, img = cap.read()
|
| 44 |
if not ret:
|
| 45 |
break
|
| 46 |
|
| 47 |
+
# Resize the frame
|
| 48 |
+
img_resized = cv2.resize(img, (640, 360))
|
| 49 |
pil_img = Image.fromarray(cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB))
|
| 50 |
|
| 51 |
# Process single image
|
|
|
|
| 59 |
max_score = 0
|
| 60 |
|
| 61 |
try:
|
| 62 |
+
font = ImageFont.truetype("arial.ttf", 20)
|
| 63 |
except IOError:
|
| 64 |
font = ImageFont.load_default()
|
| 65 |
|
|
|
|
| 78 |
|
| 79 |
max_score = max(max_score, confidence)
|
| 80 |
|
| 81 |
+
# Save frame to disk
|
| 82 |
+
frame_path = os.path.join(temp_dir, f"frame_{i:04d}.png")
|
| 83 |
+
pil_img.save(frame_path)
|
| 84 |
+
frame_paths.append(frame_path)
|
| 85 |
frame_scores.append(max_score)
|
| 86 |
|
| 87 |
+
# Clear GPU cache every 10 frames
|
| 88 |
+
if i % 10 == 0:
|
| 89 |
+
torch.cuda.empty_cache()
|
| 90 |
+
|
| 91 |
cap.release()
|
| 92 |
+
return frame_paths, frame_scores, None
|
| 93 |
+
|
| 94 |
def create_heatmap(frame_scores, current_frame):
|
| 95 |
plt.figure(figsize=(12, 3))
|
| 96 |
+
plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
|
| 97 |
cbar = plt.colorbar(label='Confidence')
|
| 98 |
cbar.ax.yaxis.set_ticks_position('left')
|
| 99 |
cbar.ax.yaxis.set_label_position('left')
|
|
|
|
| 101 |
plt.xlabel('Frame')
|
| 102 |
plt.yticks([])
|
| 103 |
|
|
|
|
| 104 |
num_frames = len(frame_scores)
|
| 105 |
+
step = max(1, num_frames // 10)
|
| 106 |
frame_numbers = range(0, num_frames, step)
|
| 107 |
plt.xticks(frame_numbers, [str(i) for i in frame_numbers])
|
| 108 |
|
|
|
|
| 109 |
plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
|
| 110 |
|
| 111 |
plt.tight_layout()
|
|
|
|
| 127 |
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 128 |
return frame_rgb
|
| 129 |
|
| 130 |
+
def update_frame_and_heatmap(frame_index, frame_paths, scores):
    """Load the saved frame at *frame_index* and regenerate the heatmap.

    Parameters:
        frame_index: Integer index selected on the frame slider.
        frame_paths: List of image file paths written by process_video.
        scores: Per-frame confidence scores (parallel to frame_paths).

    Returns:
        (frame ndarray, heatmap image path), or (None, None) when the
        index is out of range or no frames are available.
    """
    if frame_paths and 0 <= frame_index < len(frame_paths):
        # Use a context manager so the image file handle is closed promptly;
        # the original left the handle to the GC, leaking descriptors when
        # the slider is scrubbed many times in one session.
        with Image.open(frame_paths[frame_index]) as frame:
            frame_array = np.array(frame)
        heatmap_path = create_heatmap(scores, frame_index)
        return frame_array, heatmap_path
    return None, None
|
| 136 |
+
|
| 137 |
def gradio_app():
|
| 138 |
with gr.Blocks() as app:
|
| 139 |
gr.Markdown("# Video Object Detection with Owlv2")
|
|
|
|
| 148 |
use_sample_button = gr.Button("Use Sample Video")
|
| 149 |
progress_bar = gr.Progress()
|
| 150 |
|
| 151 |
+
frame_paths = gr.State([])
|
| 152 |
frame_scores = gr.State([])
|
| 153 |
|
| 154 |
def process_and_update(video, target):
    """Run detection on an uploaded video and refresh all UI outputs.

    Returns a 6-tuple matching the Gradio `outputs=` wiring:
    (frame_paths state, frame_scores state, first frame image,
     heatmap image path, error message, slider update).
    """
    paths, scores, error = process_video(video, target, progress_bar)
    if paths is not None:
        heatmap_path = create_heatmap(scores, 0)
        # Close the image file handle promptly instead of leaking it to GC.
        with Image.open(paths[0]) as first_frame:
            first_frame_array = np.array(first_frame)
        return paths, scores, first_frame_array, heatmap_path, error, gr.Slider(maximum=len(paths) - 1, value=0)
    # Processing failed: clear outputs and reset the slider to a safe default.
    return None, None, None, None, error, gr.Slider(maximum=100, value=0)
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
video_input.upload(process_and_update,
|
| 163 |
inputs=[video_input, target_input],
|
| 164 |
+
outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
|
| 165 |
|
| 166 |
frame_slider.change(update_frame_and_heatmap,
|
| 167 |
+
inputs=[frame_slider, frame_paths, frame_scores],
|
| 168 |
outputs=[output_image, heatmap_output])
|
| 169 |
|
| 170 |
def use_sample_video():
|
|
|
|
| 173 |
|
| 174 |
use_sample_button.click(use_sample_video,
|
| 175 |
inputs=None,
|
| 176 |
+
outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
|
| 177 |
|
| 178 |
# Layout
|
| 179 |
with gr.Row():
|
|
|
|
| 187 |
|
| 188 |
if __name__ == "__main__":
|
| 189 |
app = gradio_app()
|
| 190 |
+
app.launch(share=True)
|
| 191 |
+
|
| 192 |
+
# Cleanup temporary files
|
| 193 |
+
def cleanup():
    """Remove the per-run frame images and their temporary directory."""
    # NOTE(review): `frame_paths` here is the gr.State created inside
    # gradio_app(), which is NOT visible at module scope — as written this
    # raises NameError if called. Confirm where this helper is meant to live.
    for path in frame_paths.value:
        if os.path.exists(path):
            os.remove(path)
    # NOTE(review): `temp_dir` is local to process_video() (tempfile.mkdtemp()
    # is called there), so it is also undefined at module scope. The rmtree
    # below would additionally make the per-file os.remove loop redundant,
    # since the frames are saved inside temp_dir.
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
|
| 199 |
+
|
| 200 |
+
# Make sure to call cleanup when the app is closed
|
| 201 |
+
# This might require additional setup depending on how you're running the app
|