File size: 3,606 Bytes
163d359
 
 
 
53743af
 
39dee13
bf73c48
 
 
 
7d4c476
bf73c48
 
7d4c476
 
 
53743af
bf73c48
7d4c476
 
 
39dee13
 
 
 
 
 
 
 
7d4c476
bf73c48
 
 
163d359
 
bf73c48
 
163d359
7d4c476
 
163d359
7d4c476
163d359
bf73c48
 
163d359
bf73c48
 
39dee13
bf73c48
 
 
 
 
7d4c476
bf73c48
7d4c476
bf73c48
 
 
 
 
 
163d359
 
bf73c48
 
163d359
39dee13
 
 
163d359
 
 
7d4c476
bf73c48
 
7d4c476
bf73c48
 
 
163d359
bf73c48
 
 
 
 
39dee13
 
 
 
 
 
163d359
7d4c476
 
bf73c48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import cv2
import torch
from PIL import Image
import os

# Keep the dot when deploying (Spaces); remove locally if needed
from .exp_recognition_model import facExpRec, processor, device

#############################################################################################################################
#   Caution: Don't change any of the filenames, function names and definitions                                              #
#   Always use the current_path + file_name for referring any files, without it we cannot access files on the server        #
#############################################################################################################################

# =====================================================
# PATH SETUP
# =====================================================
# Directory holding this module; bundled files (e.g. the Haar cascade XML)
# are resolved relative to it.
current_path = os.path.split(os.path.abspath(__file__))[0]

# =====================================================
# LOAD MODEL ONCE (global)
# =====================================================
# The classifier is built a single time at import. Any failure is reported and
# exp_model is left as None, so the module still imports; get_expression()
# checks for None and raises RuntimeError in that case.
print("Loading Expression Recognition Model...")
try:
    exp_model = facExpRec().to(device)
    exp_model.eval()  # inference mode: disables dropout / batch-norm updates
except Exception as e:
    print(f"Failed to load Expression Model: {e}")
    exp_model = None
else:
    print("Expression model loaded successfully.")

# =====================================================
# FACE DETECTION FUNCTION
# =====================================================
def detected_face(image):
    """
    Detect faces with a Haar cascade and return the one with the largest area.

    Args:
        image: OpenCV image as a NumPy array. Colour (BGR/BGRA) images are
            converted to grayscale; single-channel images (2-D or HxWx1) are
            used directly.

    Returns:
        A PIL "RGB" image made from the grayscale crop of the face whose
        bounding box has the largest area, or 0 if no face was detected
        (callers test for this sentinel, so it is kept for compatibility).

    Raises:
        ValueError: if ``image`` is None (e.g. the result of a failed imread).
        FileNotFoundError: if the Haar cascade XML next to this module
            cannot be loaded.
    """
    if image is None:
        raise ValueError("detected_face() received no image (got None).")

    face_haar = os.path.join(current_path, "haarcascade_frontalface_default.xml")
    face_cascade = cv2.CascadeClassifier(face_haar)
    # A bad path does not raise here: OpenCV returns an empty classifier that
    # would only fail later with an opaque assertion in detectMultiScale.
    if face_cascade.empty():
        raise FileNotFoundError(f"Could not load Haar cascade: {face_haar}")

    # BGR2GRAY rejects single-channel input, so pass grayscale through as-is.
    if image.ndim == 2:
        gray = image
    elif image.ndim == 3 and image.shape[2] == 1:
        gray = image.reshape(image.shape[:2])
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Positional args: scaleFactor=1.3, minNeighbors=5.
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)
    if len(faces) == 0:
        return 0

    # Pick the largest bounding box; on ties the first detection wins,
    # matching the previous np.argmax behaviour.
    x, y, w, h = max(faces, key=lambda box: int(box[2]) * int(box[3]))
    face_cropped = gray[y:y + h, x:x + w]
    return Image.fromarray(face_cropped).convert("RGB")


# =====================================================
# EXPRESSION PREDICTION FUNCTION
# =====================================================
def get_expression(img):
    """
    Predict the facial expression shown in an OpenCV image.

    The largest face found by detected_face() is classified. If no face is
    found, the whole image is classified instead.

    Args:
        img: OpenCV image as a NumPy array (BGR/BGRA colour, or
            single-channel grayscale).

    Returns:
        The predicted label string taken from the model config's
        ``id2label`` mapping, or "Unknown" if the predicted index has no label.

    Raises:
        RuntimeError: if the expression model failed to load at import time.
    """
    if exp_model is None:
        raise RuntimeError("Expression model not loaded properly.")

    # Detect face; detected_face() signals "no face" with the int sentinel 0,
    # so test for an actual image rather than comparing objects against 0.
    face = detected_face(img)
    if not isinstance(face, Image.Image):
        # Fallback: no face detected, use entire image converted to RGB.
        if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
            rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        else:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        face = Image.fromarray(rgb)

    # Preprocess with ViT processor
    inputs = processor(images=face, return_tensors="pt").to(device)

    # Inference (single image, so batch index 0)
    with torch.no_grad():
        outputs = exp_model.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        pred_idx = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][pred_idx].item()

    # HF configs normally key id2label by int; also tolerate str keys so a
    # raw-JSON mapping does not silently degrade every result to "Unknown".
    id2label = getattr(exp_model.model.config, "id2label", None) or {}
    expression_label = id2label.get(pred_idx, id2label.get(str(pred_idx), "Unknown"))

    print(f"Detected Expression: {expression_label} ({confidence:.2f})")

    return expression_label