Update app.py
app.py CHANGED
@@ -5,7 +5,7 @@ import threading
 import time
 import urllib.request
 from pathlib import Path
-from typing import List, Union
+from typing import List, NamedTuple, Union
 
 try:
     from typing import Literal
@@ -15,7 +15,6 @@ except ImportError:
 import av
 import cv2
 import numpy as np
-import PIL
 import streamlit as st
 from aiortc.contrib.media import MediaPlayer
 
@@ -77,6 +76,12 @@ def download_file(url, download_to: Path, expected_size=None):
             progress_bar.empty()
 
 
+WEBRTC_CLIENT_SETTINGS = ClientSettings(
+    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
+    media_stream_constraints={"video": True, "audio": True},
+)
+
+
 def main():
     st.header("WebRTC demo")
 
@@ -230,28 +235,32 @@ def app_object_detection():
 
     DEFAULT_CONFIDENCE_THRESHOLD = 0.5
 
+    class Detection(NamedTuple):
+        name: str
+        prob: float
+
     class MobileNetSSDVideoTransformer(VideoTransformerBase):
         confidence_threshold: float
-        _labels: Union[List[str], None]
-        _labels_lock: threading.Lock
+        _result: Union[List[Detection], None]
+        _result_lock: threading.Lock
 
         def __init__(self) -> None:
             self._net = cv2.dnn.readNetFromCaffe(
                 str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
             )
             self.confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD
-            self._labels = None
-            self._labels_lock = threading.Lock()
+            self._result = None
+            self._result_lock = threading.Lock()
 
         @property
-        def labels(self) -> Union[List[str], None]:
-            with self._labels_lock:
-                return self._labels
+        def result(self) -> Union[List[Detection], None]:
+            with self._result_lock:
+                return self._result
 
         def _annotate_image(self, image, detections):
             # loop over the detections
             (h, w) = image.shape[:2]
-            labels = []
+            result: List[Detection] = []
             for i in np.arange(0, detections.shape[2]):
                 confidence = detections[0, 0, i, 2]
 
@@ -263,9 +272,11 @@ def app_object_detection():
                     box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
                     (startX, startY, endX, endY) = box.astype("int")
 
+                    name = CLASSES[idx]
+                    result.append(Detection(name=name, prob=float(confidence)))
+
                     # display the prediction
-                    label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
-                    labels.append(label)
+                    label = f"{name}: {round(confidence * 100, 2)}%"
                     cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
                     y = startY - 15 if startY - 15 > 15 else startY + 15
                     cv2.putText(
@@ -277,7 +288,7 @@ def app_object_detection():
                         COLORS[idx],
                         2,
                     )
-            return image, labels
+            return image, result
 
         def transform(self, frame: av.VideoFrame) -> np.ndarray:
             image = frame.to_ndarray(format="bgr24")
@@ -286,12 +297,12 @@ def app_object_detection():
             )
             self._net.setInput(blob)
             detections = self._net.forward()
-            annotated_image, labels = self._annotate_image(image, detections)
+            annotated_image, result = self._annotate_image(image, detections)
 
             # NOTE: This `transform` method is called in another thread,
             # so it must be thread-safe.
-            with self._labels_lock:
-                self._labels = labels
+            with self._result_lock:
+                self._result = result
 
             return annotated_image
 
@@ -309,7 +320,7 @@ def app_object_detection():
     if webrtc_ctx.video_transformer:
         webrtc_ctx.video_transformer.confidence_threshold = confidence_threshold
 
-    if st.checkbox("Show the detected labels"):
+    if st.checkbox("Show the detected labels", value=True):
         if webrtc_ctx.state.playing:
             labels_placeholder = st.empty()
             # NOTE: The video transformation with object detection and
@@ -319,7 +330,7 @@ def app_object_detection():
             # are not synchronized.
             while True:
                 if webrtc_ctx.video_transformer:
-                    labels_placeholder.
+                    labels_placeholder.table(webrtc_ctx.video_transformer.result)
                 time.sleep(0.1)
 
     st.markdown(
@@ -371,7 +382,7 @@ def app_streaming():
 
     WEBRTC_CLIENT_SETTINGS.update(
         {
-            "
+            "media_stream_constraints": {
                 "video": media_file_info["type"] == "video",
                 "audio": media_file_info["type"] == "audio",
             }
@@ -405,15 +416,9 @@ def app_sendonly():
                 webrtc_ctx.video_receiver.stop()
                 break
 
-            img = frame.to_ndarray(format="bgr24")
-            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-            image_loc.image(img)
-
+            img_rgb = frame.to_ndarray(format="rgb24")
+            image_loc.image(img_rgb)
 
-WEBRTC_CLIENT_SETTINGS = ClientSettings(
-    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
-    media_stream_constraints={"video": True, "audio": True},
-)
 
 if __name__ == "__main__":
     logging.basicConfig(
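
Note on the threading pattern in the object-detection hunks above: `transform` runs on the WebRTC worker thread while the Streamlit script polls `result` from the script thread, so both sides take the same lock before touching the shared value. A minimal, self-contained sketch of that pattern (the `ResultStore` name is hypothetical, not part of app.py):

import threading
from typing import List, NamedTuple, Union


class Detection(NamedTuple):
    name: str
    prob: float


class ResultStore:
    # Writer and reader run on different threads, so every access to
    # the shared value goes through the same lock, exactly as in
    # MobileNetSSDVideoTransformer above.
    def __init__(self) -> None:
        self._result: Union[List[Detection], None] = None
        self._result_lock = threading.Lock()

    @property
    def result(self) -> Union[List[Detection], None]:
        with self._result_lock:
            return self._result

    def set_result(self, result: List[Detection]) -> None:
        with self._result_lock:
            self._result = result


if __name__ == "__main__":
    store = ResultStore()
    # In app.py the writer is `transform` on the WebRTC worker thread...
    store.set_result([Detection(name="person", prob=0.87)])
    # ...and the reader is the `while True` polling loop in the script.
    print(store.result)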
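Why a `Detection` NamedTuple instead of the old preformatted label strings: `st.table` accepts anything pandas can turn into a DataFrame, so a `List[Detection]` renders as a two-column table (name, prob) rather than one opaque string per row. A quick standalone check of that conversion (pandas only, no Streamlit; assumes the usual namedtuple-to-DataFrame behavior):

from typing import NamedTuple

import pandas as pd


class Detection(NamedTuple):
    name: str
    prob: float


# pandas expands a list of NamedTuples into one column per field,
# which is what labels_placeholder.table(...) ends up displaying.
detections = [Detection(name="person", prob=0.87), Detection(name="dog", prob=0.61)]
print(pd.DataFrame(detections))
#      name  prob
# 0  person  0.87
# 1     dog  0.61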