"""
exp_recognition_model.py
------------------------
Facial Expression Recognition using ViT (trpakov/vit-face-expression).
This file loads the pretrained model and processor for inference or evaluation.
"""
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
# =====================================================
# DEVICE CONFIGURATION
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =====================================================
# MODEL NAME (ViT pretrained on facial expressions)
# =====================================================
MODEL_NAME = "trpakov/vit-face-expression"
# =====================================================
# SAFE GLOBAL LOAD β€” ensures only one instance of model & processor
# =====================================================
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()
    print(f"Loaded ViT model '{MODEL_NAME}' successfully on {device}.")
except Exception as e:
    print(f"Error loading ViT model: {e}")
    processor, model = None, None
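# Quick sanity check (hypothetical interactive session): the checkpoint's
# label map can be inspected via the config. Labels shown are illustrative;
# verify against the actual loaded config.
#
#   >>> model.config.id2label
#   {0: 'angry', 1: 'disgust', ...}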
# =====================================================
# PREPROCESSING FUNCTION
# =====================================================
def preprocess_image(img_pil: Image.Image):
    """
    Convert a PIL image into ViT-compatible input tensors.
    The processor handles resizing, normalization, and batching.
    """
    if processor is None:
        raise RuntimeError("Processor not initialized. Check model loading.")
    inputs = processor(images=img_pil, return_tensors="pt").to(device)
    return inputs
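# Minimal usage sketch (assumes a local image file "face.jpg", a
# placeholder path):
#
#   >>> img = Image.open("face.jpg").convert("RGB")
#   >>> inputs = preprocess_image(img)
#   >>> logits = model(**inputs).logits   # shape: [1, num_labels]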
# =====================================================
# MAIN MODEL WRAPPER CLASS
# =====================================================
class facExpRec(nn.Module):
    """
    Wrapper around the pretrained ViT Face Expression model.
    Provides a convenient inference interface for app integration.
    """
    def __init__(self):
        super(facExpRec, self).__init__()
        if model is None or processor is None:
            raise RuntimeError("ViT model not loaded correctly.")
        self.model = model
        self.processor = processor
        self.id2label = getattr(self.model.config, "id2label", None)

    def forward(self, x):
        """
        Accepts a PIL image, a preprocessed pixel tensor, or a dict of
        model inputs. Returns a dict with the predicted expression and
        its confidence.
        """
        if isinstance(x, Image.Image):
            inputs = self.processor(images=x, return_tensors="pt").to(device)
        elif isinstance(x, torch.Tensor):
            # A bare tensor (e.g. from trnscm) needs a batch dimension
            # and the keyword argument the model expects.
            if x.dim() == 3:
                x = x.unsqueeze(0)
            inputs = {"pixel_values": x.to(device)}
        else:
            inputs = x
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred_idx = torch.argmax(probs, dim=-1).item()
            confidence = probs[0][pred_idx].item()
        # Prefer model.config.id2label if available
        label = (
            self.id2label[pred_idx]
            if self.id2label and pred_idx in self.id2label
            else "Unknown"
        )
        return {
            "expression": label,
            "confidence": round(confidence, 3),
        }
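# Example usage (sketch): wrap the global model and classify one face.
# The image path is a placeholder and the output below is illustrative.
#
#   >>> recognizer = facExpRec()
#   >>> recognizer(Image.open("face.jpg").convert("RGB"))
#   {'expression': 'happy', 'confidence': 0.953}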
# =====================================================
# TORCHVISION-COMPATIBLE TRANSFORM (for testing/training)
# =====================================================
trnscm = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT expects 224x224 RGB input
    transforms.ToTensor(),
    # Match the processor's normalization so tensors produced here agree
    # with what the model saw during training; fall back to the standard
    # ViT stats (0.5 per channel) if the processor failed to load.
    transforms.Normalize(
        mean=processor.image_mean if processor else [0.5, 0.5, 0.5],
        std=processor.image_std if processor else [0.5, 0.5, 0.5],
    ),
])
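# =====================================================
# SMOKE TEST (optional)
# =====================================================
# A minimal sketch for running this module directly; the default image
# path is a placeholder, and the block exits early if loading failed above.
if __name__ == "__main__":
    import sys
    if model is None:
        sys.exit("Model failed to load; see error above.")
    path = sys.argv[1] if len(sys.argv) > 1 else "sample_face.jpg"  # hypothetical default
    img = Image.open(path).convert("RGB")
    print(facExpRec()(img))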