File size: 3,613 Bytes
bf73c48
 
 
 
 
 
 
163d359
bf73c48
 
 
c24064d
53743af
bf73c48
 
 
 
 
 
7d4c476
bf73c48
 
 
7d4c476
 
 
 
 
 
 
 
39dee13
7d4c476
39dee13
7d4c476
bf73c48
 
7d4c476
bf73c48
 
 
 
7d4c476
bf73c48
7d4c476
 
bf73c48
 
ba73318
bf73c48
 
 
 
 
7d4c476
 
bf73c48
53743af
bf73c48
7d4c476
 
bf73c48
 
39dee13
c24064d
53743af
bf73c48
7d4c476
 
bf73c48
 
 
ba73318
bf73c48
 
 
 
 
 
 
 
39dee13
 
 
 
 
 
 
bf73c48
39dee13
bf73c48
 
 
 
7d4c476
bf73c48
 
7d4c476
 
bf73c48
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
exp_recognition_model.py
------------------------
Facial Expression Recognition using ViT (trpakov/vit-face-expression).
This file loads the pretrained model and processor for inference or evaluation.
"""

import torch
import torch.nn as nn
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

# =====================================================
# DEVICE CONFIGURATION
# =====================================================
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# =====================================================
# MODEL NAME (ViT pretrained on facial expressions)
# =====================================================
MODEL_NAME = "trpakov/vit-face-expression"

# =====================================================
# SAFE GLOBAL LOAD — model & processor are loaded once, at import time.
# If anything fails, both globals are left as None so that callers
# (preprocess_image, facExpRec) can raise a clear RuntimeError later.
# =====================================================
def _load_pretrained(name, target):
    """
    Fetch the image processor and classification model for ``name``.

    The model is moved to ``target`` and switched to eval mode before being
    returned as ``(processor, model)``. Any loading error propagates.
    """
    proc = AutoImageProcessor.from_pretrained(name)
    net = AutoModelForImageClassification.from_pretrained(name)
    net.to(target)
    net.eval()
    return proc, net


try:
    processor, model = _load_pretrained(MODEL_NAME, device)
    print(f"Loaded ViT model '{MODEL_NAME}' successfully on {device}.")
except Exception as err:
    # Deliberately best-effort: report the failure instead of crashing import.
    print(f"Error loading ViT model: {err}")
    processor, model = None, None

# =====================================================
# PREPROCESSING FUNCTION
# =====================================================
def preprocess_image(img_pil: Image.Image):
    """
    Convert a PIL image into ViT-compatible model inputs.

    Images in any mode other than RGB (e.g. grayscale "L", "RGBA", "P") are
    first converted to RGB so that the processor always receives a
    3-channel image. The processor then handles resizing and normalization.

    Args:
        img_pil: the input image, in any PIL mode.

    Returns:
        The processor output (``return_tensors="pt"``) moved to ``device``.
        It can be passed to the model as ``model(**inputs)``.

    Raises:
        RuntimeError: if the processor failed to load at import time.
    """
    if processor is None:
        raise RuntimeError("Processor not initialized. Check model loading.")
    if img_pil.mode != "RGB":
        # The Hugging Face image processor infers the channel layout from the
        # array shape and cannot handle 2-D (grayscale) or 4-channel input.
        img_pil = img_pil.convert("RGB")
    inputs = processor(images=img_pil, return_tensors="pt").to(device)
    return inputs

# =====================================================
# MAIN MODEL WRAPPER CLASS
# =====================================================
class facExpRec(nn.Module):
    """
    Wrapper around the pretrained ViT face-expression model.

    Holds references to the module-level ``model`` and ``processor`` that are
    loaded at import time. It exposes a single-image ``forward`` that returns
    the top-1 expression label together with its softmax probability.

    Raises:
        RuntimeError: on construction, if the module-level model or processor
            failed to load (both are ``None`` in that case).
    """

    def __init__(self):
        super().__init__()
        if model is None or processor is None:
            raise RuntimeError("ViT model not loaded correctly.")
        self.model = model
        self.processor = processor
        # Index -> label mapping from the model config; None if the config has none.
        self.id2label = getattr(self.model.config, "id2label", None)

    def _prepare_inputs(self, x):
        """Turn any accepted input form into keyword arguments for the model."""
        if isinstance(x, Image.Image):
            # The processor needs a 3-channel image. Grayscale ("L") or RGBA
            # images would otherwise fail channel inference.
            if x.mode != "RGB":
                x = x.convert("RGB")
            return self.processor(images=x, return_tensors="pt").to(device)
        if isinstance(x, torch.Tensor):
            # A bare tensor cannot be unpacked with ``**``, so pass it under
            # the ``pixel_values`` keyword the image-classification model
            # takes. Add a batch axis when given a single (C, H, W) image.
            if x.dim() == 3:
                x = x.unsqueeze(0)
            return {"pixel_values": x.to(device)}
        # Anything else is assumed to already be a mapping of model inputs
        # (e.g. the output of preprocess_image) and is passed through as-is.
        return x

    def forward(self, x):
        """
        Classify the facial expression in a single image.

        Args:
            x: one of the following:
                - a PIL image (any mode);
                - an already-preprocessed pixel tensor shaped (C, H, W) or
                  (1, C, H, W);
                - a mapping of model inputs, such as the output of
                  ``preprocess_image``.

        Returns:
            A dict with two keys:
                - ``"expression"``: the predicted label, or ``"Unknown"`` when
                  the predicted index has no entry in ``id2label``.
                - ``"confidence"``: the softmax probability of that label,
                  rounded to 3 decimals.
        """
        inputs = self._prepare_inputs(x)

        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            # .item() needs exactly one prediction, so only a batch of one
            # image is supported here.
            pred_idx = torch.argmax(probs, dim=-1).item()
            confidence = probs[0][pred_idx].item()

        # Prefer model.config.id2label if available.
        if self.id2label and pred_idx in self.id2label:
            label = self.id2label[pred_idx]
        else:
            label = "Unknown"

        return {
            "expression": label,
            "confidence": round(confidence, 3),
        }

# =====================================================
# TORCHVISION-COMPATIBLE TRANSFORM (for testing/training)
# =====================================================
# Resize to the 224x224 input size ViT uses, then convert the PIL image to a tensor.
# NOTE(review): unlike the Hugging Face processor used by preprocess_image,
# this pipeline applies no mean/std normalization and does not force the image
# to 3 channels. Confirm that this matches whatever consumes its output.
trnscm = transforms.Compose(
    transforms=[
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)