| """ | |
| exp_recognition_model.py | |
| ------------------------ | |
| Facial Expression Recognition using ViT (trpakov/vit-face-expression). | |
| This file loads the pretrained model and processor for inference or evaluation. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |

# =====================================================
# DEVICE CONFIGURATION
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================================================
# MODEL NAME (ViT pretrained on facial expressions)
# =====================================================
MODEL_NAME = "trpakov/vit-face-expression"

# =====================================================
# SAFE GLOBAL LOAD - ensures only one instance of model & processor
# =====================================================
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()
    print(f"Loaded ViT model '{MODEL_NAME}' successfully on {device}.")
except Exception as e:
    print(f"Error loading ViT model: {e}")
    processor, model = None, None

# =====================================================
# PREPROCESSING FUNCTION
# =====================================================
def preprocess_image(img_pil: Image.Image):
    """
    Converts a PIL image into ViT-compatible tensors.
    Handles resizing, normalization, etc.
    """
    if processor is None:
        raise RuntimeError("Processor not initialized. Check model loading.")
    inputs = processor(images=img_pil, return_tensors="pt").to(device)
    return inputs
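
# Example (sketch, kept as a comment so importing this module stays
# side-effect free): how preprocess_image would be called. The solid-gray
# image is a hypothetical placeholder for a real face crop.
#   img = Image.new("RGB", (256, 256), color=(127, 127, 127))
#   inputs = preprocess_image(img)
#   # inputs["pixel_values"] has shape (1, 3, 224, 224) for this ViT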

# =====================================================
# MAIN MODEL WRAPPER CLASS
# =====================================================
class facExpRec(nn.Module):
    """
    Wrapper around the pretrained ViT Face Expression model.
    Provides a convenient interface for inference and app integration.
    """
    def __init__(self):
        super(facExpRec, self).__init__()
        if model is None or processor is None:
            raise RuntimeError("ViT model not loaded correctly.")
        self.model = model
        self.processor = processor
        self.id2label = getattr(self.model.config, "id2label", None)

    def forward(self, x):
        """
        Forward expects a PIL image or an already-preprocessed input dict
        (as returned by preprocess_image). Returns a dict with the predicted
        expression and its confidence.
        """
        if isinstance(x, Image.Image):
            inputs = self.processor(images=x, return_tensors="pt").to(device)
        else:
            inputs = x
        with torch.no_grad():
            outputs = self.model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        pred_idx = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][pred_idx].item()
        # Prefer model.config.id2label if available
        label = (
            self.id2label[pred_idx]
            if self.id2label and pred_idx in self.id2label
            else "Unknown"
        )
        return {
            "expression": label,
            "confidence": round(confidence, 3),
        }

# =====================================================
# TORCHVISION-COMPATIBLE TRANSFORM (for testing/training)
# =====================================================
# Note: unlike the AutoImageProcessor above, this transform only resizes and
# converts to a tensor; it does not apply the processor's mean/std
# normalization, so its outputs are not interchangeable with preprocess_image.
trnscm = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT expects 224x224 RGB input
    transforms.ToTensor(),
])
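
# =====================================================
# QUICK SELF-TEST (minimal sketch)
# =====================================================
# A smoke test, assuming the global model loaded successfully above. The
# synthetic gray image is a stand-in; real inputs should be RGB face crops.
if __name__ == "__main__":
    if model is not None:
        demo_img = Image.new("RGB", (224, 224), color=(127, 127, 127))
        recognizer = facExpRec()
        result = recognizer(demo_img)
        print(result)  # e.g. {"expression": "<label>", "confidence": 0.xxx}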