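"""Gradio demo for CLIP-based deepfake detection.

Loads a CLIP ViT-L/14 vision backbone with a linear classifier head
(facial or general variant), scores uploaded images or videos on a
REAL (0) to FAKE (1) scale, and shows an attention map for the prediction.
"""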
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel

from detection import detect_image, detect_video
from model import LinearClassifier
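
# Build the detection pipeline for a given detection type ("facial" or
# "general"): a CLIP vision backbone plus a linear classifier whose weights
# are loaded from pretrained_models/<detection_type>/clip_weights.pth.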
def load_model(detection_type):
    device = torch.device("cpu")
    # Attentions are needed downstream to render the attention map.
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
    clip_model = CLIPVisionModel.from_pretrained(
        "openai/clip-vit-large-patch14", output_attentions=True
    )
    model_path = f"pretrained_models/{detection_type}/clip_weights.pth"
    checkpoint = torch.load(model_path, map_location="cpu")
    # Infer the classifier's input size from the saved linear layer.
    input_dim = checkpoint["linear.weight"].shape[1]
    detection_model = LinearClassifier(input_dim)
    detection_model.load_state_dict(checkpoint)
    detection_model = detection_model.to(device)
    return processor, clip_model, detection_model
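
# Run single-image detection and return (fake score, attention map).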
def process_image(image, detection_type):
    processor, clip_model, detection_model = load_model(detection_type)
    results = detect_image(image, processor, clip_model, detection_model)
    # Flip the raw score so it follows the UI convention (0 = REAL, 1 = FAKE).
    pred_score = 1 - results["pred_score"]
    attn_map = results["attn_map"]
    return pred_score, attn_map
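
# Decode a video into RGB PIL frames and run frame-based detection.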
def process_video(video, detection_type):
    processor, clip_model, detection_model = load_model(detection_type)
    cap = cv2.VideoCapture(video)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; convert to RGB before wrapping in PIL.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
    cap.release()
    results = detect_video(frames, processor, clip_model, detection_model)
    pred_score = results["pred_score"]
    attn_map = results["attn_map"]
    return pred_score, attn_map
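
# Toggle visibility of the image/video widgets when the input type changes,
# clearing any previously uploaded value.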
def change_input(input_type):
    if input_type == "Image":
        return gr.update(value=None, visible=True), gr.update(value=None, visible=False)
    elif input_type == "Video":
        return gr.update(value=None, visible=False), gr.update(value=None, visible=True)
    else:
        return gr.update(value=None, visible=False), gr.update(value=None, visible=False)
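
# Pick the model type from an example's file name ("facial"/"general" substring).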
def determine_model_type(image_path):
    if "facial" in image_path.lower():
        return "Facial"
    elif "general" in image_path.lower():
        return "General"
    else:
        return "Facial"  # default
def process_input(input_type, model_type, image, video):
    detection_type = "facial" if model_type == "Facial" else "general"
    if input_type == "Image" and image is not None:
        return process_image(image, detection_type)
    elif input_type == "Video" and video is not None:
        return process_video(video, detection_type)
    else:
        return None, None
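
# When an example is selected, load it and switch to the matching model type.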
def process_example(image_path):
    model_type = determine_model_type(image_path)
    return Image.open(image_path), model_type

# Collect bundled example images, split by ground-truth label.
fake_examples = [os.path.join("examples/fake", f) for f in os.listdir("examples/fake")]
real_examples = [os.path.join("examples/real", f) for f in os.listdir("examples/real")]
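
# Assemble the Gradio interface: input selectors, upload widgets, example
# galleries, and the outputs (score textbox and attention-map image).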
with gr.Blocks() as demo:
    gr.Markdown("## Deepfake Detection: Facial / General")
    input_type = gr.Radio(["Image", "Video"], label="Choose Input Type", value="Image")
    model_type = gr.Radio(["Facial", "General"], label="Choose Model Type", value="General")
    H, W = 300, 300
    image_input = gr.Image(type="pil", label="Upload Image", visible=True, height=H, width=W)
    video_input = gr.Video(label="Upload Video", visible=False, height=H, width=W)
    process_button = gr.Button("Run Model")
    pred_score_output = gr.Textbox(label="Prediction Score: 0 - REAL, 1 - FAKE")
    attn_map_output = gr.Image(type="pil", label="Attention Map", height=H, width=W)
    # Add example images
    gr.Examples(
        examples=fake_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        cache_examples=False,
        examples_per_page=10,
        label="Fake Examples",
    )
    gr.Examples(
        examples=real_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        cache_examples=False,
        examples_per_page=10,
        label="Real Examples",
    )
    input_type.change(fn=change_input, inputs=[input_type], outputs=[image_input, video_input])
    process_button.click(
        fn=process_input,
        inputs=[input_type, model_type, image_input, video_input],
        outputs=[pred_score_output, attn_map_output],
    )

if __name__ == "__main__":
    demo.launch()