"""
exp_recognition_model.py
------------------------
Facial Expression Recognition using ViT (trpakov/vit-face-expression).
This file loads the pretrained model and processor for inference or evaluation.
"""
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
# =====================================================
# DEVICE CONFIGURATION
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =====================================================
# MODEL NAME (ViT pretrained on facial expressions)
# =====================================================
MODEL_NAME = "trpakov/vit-face-expression"
# =====================================================
# SAFE GLOBAL LOAD — ensures only one instance of model & processor
# =====================================================
try:
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    model.to(device)
    model.eval()
    print(f"Loaded ViT model '{MODEL_NAME}' successfully on {device}.")
except Exception as e:
    print(f"Error loading ViT model: {e}")
    processor, model = None, None
# =====================================================
# PREPROCESSING FUNCTION
# =====================================================
def preprocess_image(img_pil: Image.Image):
    """
    Converts a PIL image into ViT-compatible tensors.
    The AutoImageProcessor handles resizing, rescaling, and normalization.
    """
    if processor is None:
        raise RuntimeError("Processor not initialized. Check model loading.")
    inputs = processor(images=img_pil, return_tensors="pt").to(device)
    return inputs
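# Usage sketch (illustrative, not part of the original app flow): assumes a
# local RGB face crop at "face.jpg" (a placeholder path). Shows a manual
# forward pass with the globally loaded model, bypassing the facExpRec
# wrapper defined below.
#   inputs = preprocess_image(Image.open("face.jpg").convert("RGB"))
#   with torch.no_grad():
#       logits = model(**inputs).logits
#   print(logits.argmax(dim=-1).item())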
# =====================================================
# MAIN MODEL WRAPPER CLASS
# =====================================================
class facExpRec(nn.Module):
    """
    Wrapper around the pretrained ViT Face Expression model.
    Provides convenience for inference and app integration.
    """
    def __init__(self):
        super(facExpRec, self).__init__()
        if model is None or processor is None:
            raise RuntimeError("ViT model not loaded correctly.")
        self.model = model
        self.processor = processor
        self.id2label = getattr(self.model.config, "id2label", None)

    def forward(self, x):
        """
        Forward accepts a PIL image, a preprocessed input dict from the
        processor, or a raw pixel tensor of shape (N, 3, 224, 224).
        Returns a dict with the predicted expression and its confidence.
        """
        if isinstance(x, Image.Image):
            inputs = self.processor(images=x, return_tensors="pt").to(device)
        elif isinstance(x, torch.Tensor):
            # Wrap a raw tensor so it can be unpacked as keyword arguments.
            inputs = {"pixel_values": x.to(device)}
        else:
            inputs = x
        with torch.no_grad():
            outputs = self.model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred_idx = torch.argmax(probs, dim=-1).item()
            confidence = probs[0][pred_idx].item()
        # Prefer model.config.id2label if available
        label = (
            self.id2label[pred_idx]
            if self.id2label and pred_idx in self.id2label
            else "Unknown"
        )
        return {
            "expression": label,
            "confidence": round(confidence, 3)
        }
# =====================================================
# TORCHVISION-COMPATIBLE TRANSFORM (for testing/training)
# =====================================================
trnscm = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT expects 224x224 RGB input
    transforms.ToTensor(),
])
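# =====================================================
# MINIMAL SELF-TEST (illustrative sketch)
# =====================================================
# Assumption: a local face image named "face.jpg" exists next to this file;
# the filename is a placeholder, not part of the original app.
if __name__ == "__main__":
    img = Image.open("face.jpg").convert("RGB")
    recognizer = facExpRec()
    result = recognizer(img)
    print(result)  # e.g. {"expression": "happy", "confidence": 0.973}
    # Note: trnscm above omits the processor's normalization, so it is only a
    # rough torchvision equivalent for quick testing, not exact inference.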