Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import warnings | |
| import cv2 | |
| import dlib | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| import numpy as np | |
| import torch | |
| from retinaface.pre_trained_models import get_model | |
| from Scripts.model import create_cam, create_model | |
| from Scripts.preprocess import crop_face, extract_face, extract_frames | |
| from Scripts.ca_generator import get_augs | |
| import spaces | |
# Suppress every warning category for the lifetime of the process.
warnings.filterwarnings('ignore')

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Deepfake classifier, loaded from the bundled checkpoint and moved to `device`.
sbcl = create_model("Weights/weights.tar")
sbcl.to(device)

# Face detector used by the preprocessing helpers; switched to eval mode.
face_detector = get_model("resnet50_2020-07-20", max_size=1024, device=device)
face_detector.eval()

# CAM wrapper around the classifier. The target is output index 1, the same
# index the predict functions read from the softmax and report as 'Fake'.
cam_sbcl = create_cam(sbcl)
targets = [ClassifierOutputTarget(1)]

# Examples
examples = [
    "Examples/Fake/Fake1.PNG",
    "Examples/Real/Real1.PNG",
    "Examples/Real/Real2.PNG",
    "Examples/Fake/Fake3.PNG",
    "Examples/Fake/Fake2.PNG",
]
examples_videos = [
    'Examples/Fake1.mp4',
    'Examples/Real1.mp4',
]

# dlib Models
# NOTE(review): neither dlib object is referenced by the code visible in this
# file; confirm they are still needed before removing them.
dlib_face_detector = dlib.get_frontal_face_detector()
dlib_face_predictor = dlib.shape_predictor(
    'Weights/shape_predictor_81_face_landmarks.dat')
def predict_image(inp):
    """Classify the face(s) found in an image and render a CAM overlay.

    Args:
        inp: Image handed over by the Gradio ``Image`` input; it is passed
            unchanged to ``extract_face``.

    Returns:
        A ``(confidences, cam_image)`` pair. ``confidences`` maps ``'Real'``
        and ``'Fake'`` to the scores of the first detected face, and
        ``cam_image`` is the CAM heat map blended over that face. When no face
        is detected the pair is ``({'No face detected!': 1}, None)``.
    """
    faces = extract_face(inp, face_detector)
    if len(faces) == 0:
        return {'No face detected!': 1}, None

    # Plain inference needs no autograd graph.
    with torch.no_grad():
        batch = torch.tensor(faces).to(device).float() / 255
        fake_column = sbcl(batch).softmax(1)[:, 1]
        fake_score = fake_column.cpu().data.numpy().tolist()[0]
        confidences = {'Real': 1 - fake_score, 'Fake': fake_score}

    # The CAM call stays outside no_grad: it needs gradients to build the map.
    cam_maps = cam_sbcl(input_tensor=batch, targets=targets, aug_smooth=True)
    first_cam = cam_maps[0, :]

    # Faces are indexed channel-first here; the overlay helper is given a
    # channel-last image scaled to [0, 1].
    first_face = faces[0].transpose(1, 2, 0) / 255
    overlay = show_cam_on_image(first_face, first_cam, use_rgb=True)
    return confidences, overlay
def predict_video(inp):
    """Classify a video as real or fake and render a CAM for its most suspicious face.

    Every face returned by ``extract_frames`` is scored by the classifier.
    Scores are then collapsed to one value per frame (the maximum over that
    frame's faces) and the video-level 'Fake' score is the mean of those
    per-frame maxima.

    Args:
        inp: Video handed over by the Gradio ``Video`` input; it is passed
            unchanged to ``extract_frames``.

    Returns:
        A ``(confidences, cam_image)`` pair. ``confidences`` maps ``'Real'``
        and ``'Fake'`` to plain floats, and ``cam_image`` is the CAM heat map
        blended over the single face with the highest 'Fake' score. When no
        face is found the pair is ``({'No face detected!': 1}, None)``, the
        same value ``predict_image`` returns in that case.
    """
    # NOTE(review): the literal 10 is presumably the number of frames to
    # sample -- confirm against Scripts.preprocess.extract_frames.
    face_list, idx_list = extract_frames(inp, 10, face_detector)

    # Without this guard an empty batch reaches the model, and the mean and
    # argmax below would operate on empty data.
    if len(face_list) == 0:
        return {'No face detected!': 1}, None

    # Plain inference needs no autograd graph.
    with torch.no_grad():
        img = torch.tensor(face_list).to(device).float() / 255
        pred = sbcl(img).softmax(1)[:, 1]

    face_scores = pred.cpu().tolist()

    # idx_list[i] is treated as the frame id of face i. Consecutive faces that
    # share an id belong to one frame, and that frame keeps its highest score.
    frame_scores = []
    current_frame = None
    for frame_id, score in zip(idx_list, face_scores):
        if not frame_scores or frame_id != current_frame:
            frame_scores.append(score)
            current_frame = frame_id
        elif score > frame_scores[-1]:
            frame_scores[-1] = score

    fake_prob = float(np.mean(frame_scores))

    # Pick the face to visualise by its position in the per-face batch. The
    # index of the most suspicious frame must not be used here: the two differ
    # as soon as any frame holds more than one face. The globally highest face
    # score always lies in the frame with the highest per-frame maximum.
    most_fake = int(np.argmax(face_scores))

    # The CAM call stays outside no_grad: it needs gradients to build the map.
    grayscale_cam = cam_sbcl(
        input_tensor=img[most_fake].unsqueeze(0),
        targets=targets,
        aug_smooth=True,
    )
    grayscale_cam = grayscale_cam[0, :]

    # Faces are indexed channel-first here; the overlay helper is given a
    # channel-last image scaled to [0, 1].
    cam_image = show_cam_on_image(
        face_list[most_fake].transpose(1, 2, 0) / 255,
        grayscale_cam,
        use_rgb=True,
    )
    return {'Real': 1 - fake_prob, 'Fake': fake_prob}, cam_image
# Stylesheet for the app: an animated banner for the section headers plus a
# full-width button class with a hover highlight.
_CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css?family=Source+Code+Pro:200');
#custom_header {
    min-height: 3rem;
    background-image: url('https://static.pexels.com/photos/414171/pexels-photo-414171.jpeg');
    background-size: cover;
    background-position: top;
    color: white;
    text-align: center;
    padding: 0.5rem;
    font-family: 'Source Code Pro', monospace;
    text-transform: uppercase;
}
#custom_header:hover {
    -webkit-animation: slidein 10s;
    animation: slidein 10s;
    -webkit-animation-fill-mode: forwards;
    animation-fill-mode: forwards;
    -webkit-animation-iteration-count: infinite;
    animation-iteration-count: infinite;
    -webkit-animation-direction: alternate;
    animation-direction: alternate;
}
@-webkit-keyframes slidein {
    from {
        background-position: top;
        background-size: 3000px;
    }
    to {
        background-position: -100px 0px;
        background-size: 2750px;
    }
}
@keyframes slidein {
    from {
        background-position: top;
        background-size: 3000px;
    }
    to {
        background-position: -100px 0px;
        background-size: 2750px;
    }
}
#custom_title {
    min-height: 3rem;
    text-align: center;
}
.full-width {
    width: 100%;
}
.full-width:hover {
    background: rgba(75, 75, 250, 0.3);
    color: white;
}
"""

# Two-tab interface: one tab scores a single image, the other a video. Each tab
# has an input column, a result column (CAM overlay + label scores), a set of
# examples and a submit button wired to the matching predict function.
with gr.Blocks(
    title="Deepfake Detection CL",
    theme='upsatwal/mlsc_tiet',
    css=_CUSTOM_CSS,
) as demo:
    with gr.Tab("Image"):
        with gr.Row():
            # Left column: image upload and submit button.
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Deepfake Detection", elem_id="custom_header")
                    input_image = gr.Image(label="Input Image", height=240)
                    btn = gr.Button(
                        value="Submit",
                        variant="primary",
                        elem_classes="full-width",
                    )
            # Right column: CAM overlay and class scores.
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Result", elem_id="custom_header")
                    output_image = gr.Image(label="GradCAM Image", height=240)
                    label_probs = gr.Label()
        # cache_examples=True asks Gradio to cache the outputs of predict_image
        # for the example files.
        gr.Examples(
            examples=examples,
            inputs=input_image,
            outputs=[label_probs, output_image],
            fn=predict_image,
            cache_examples=True,
        )
        btn.click(
            predict_image,
            inputs=input_image,
            outputs=[label_probs, output_image],
            api_name="/predict_image",
        )

    with gr.Tab("Video"):
        with gr.Row():
            # Left column: video upload and submit button.
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Deepfake Detection", elem_id="custom_header")
                    input_video = gr.Video(label="Input Video", height=240)
                    btn_video = gr.Button(
                        value="Submit",
                        variant="primary",
                        elem_classes="full-width",
                    )
            # Right column: CAM overlay and class scores.
            with gr.Column():
                with gr.Group():
                    gr.Markdown("## Result", elem_id="custom_header")
                    output_image_video = gr.Image(label="GradCAM", height=240)
                    label_probs_video = gr.Label()
        # cache_examples=True asks Gradio to cache the outputs of predict_video
        # for the example files.
        gr.Examples(
            examples=examples_videos,
            inputs=input_video,
            outputs=[label_probs_video, output_image_video],
            fn=predict_video,
            cache_examples=True,
        )
        btn_video.click(
            predict_video,
            inputs=input_video,
            outputs=[label_probs_video, output_image_video],
            api_name="/predict_video",
        )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()