""" exp_recognition_model.py ------------------------ Facial Expression Recognition using ViT (trpakov/vit-face-expression). This file loads the pretrained model and processor for inference or evaluation. """ import torch import torch.nn as nn from torchvision import transforms from transformers import AutoImageProcessor, AutoModelForImageClassification from PIL import Image # ===================================================== # DEVICE CONFIGURATION # ===================================================== device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ===================================================== # MODEL NAME (ViT pretrained on facial expressions) # ===================================================== MODEL_NAME = "trpakov/vit-face-expression" # ===================================================== # SAFE GLOBAL LOAD — ensures only one instance of model & processor # ===================================================== try: processor = AutoImageProcessor.from_pretrained(MODEL_NAME) model = AutoModelForImageClassification.from_pretrained(MODEL_NAME) model.to(device) model.eval() print(f"Loaded ViT model '{MODEL_NAME}' successfully on {device}.") except Exception as e: print(f"Error loading ViT model: {e}") processor, model = None, None # ===================================================== # PREPROCESSING FUNCTION # ===================================================== def preprocess_image(img_pil: Image.Image): """ Converts a PIL image into ViT-compatible tensors. Handles normalization, resize, etc. """ if processor is None: raise RuntimeError("Processor not initialized. Check model loading.") inputs = processor(images=img_pil, return_tensors="pt").to(device) return inputs # ===================================================== # MAIN MODEL WRAPPER CLASS # ===================================================== class facExpRec(nn.Module): """ Wrapper around pretrained ViT Face Expression model. Provides convenience for inference and app integration. """ def __init__(self): super(facExpRec, self).__init__() if model is None or processor is None: raise RuntimeError("ViT model not loaded correctly.") self.model = model self.processor = processor self.id2label = getattr(self.model.config, "id2label", None) def forward(self, x): """ Forward expects a PIL image or already-preprocessed input tensor. Returns dict with expression and confidence. """ if isinstance(x, Image.Image): inputs = self.processor(images=x, return_tensors="pt").to(device) else: inputs = x with torch.no_grad(): outputs = self.model(**inputs) probs = torch.nn.functional.softmax(outputs.logits, dim=-1) pred_idx = torch.argmax(probs, dim=-1).item() confidence = probs[0][pred_idx].item() # Prefer model.config.id2label if available label = ( self.id2label[pred_idx] if self.id2label and pred_idx in self.id2label else "Unknown" ) return { "expression": label, "confidence": round(confidence, 3) } # ===================================================== # TORCHVISION-COMPATIBLE TRANSFORM (for testing/training) # ===================================================== trnscm = transforms.Compose([ transforms.Resize((224, 224)), # ViT expects 224x224 RGB transforms.ToTensor(), ])