Kousik Kumar Siddavaram committed on
Commit
39dee13
·
1 Parent(s): 7f46f51

Updated expression recognition

Browse files
app/Hackathon_setup/exp_recognition.py CHANGED
@@ -4,7 +4,7 @@ import torch
4
  from PIL import Image
5
  import os
6
 
7
- # Remove '.' if running locally; keep it for server (Spaces)
8
  from .exp_recognition_model import facExpRec, processor, device
9
 
10
  #############################################################################################################################
@@ -20,10 +20,14 @@ current_path = os.path.dirname(os.path.abspath(__file__))
20
  # =====================================================
21
  # LOAD MODEL ONCE (global)
22
  # =====================================================
23
- print("🔹 Loading Expression Recognition Model...")
24
- exp_model = facExpRec().to(device)
25
- exp_model.eval()
26
- print("✅ Expression model loaded successfully.")
 
 
 
 
27
 
28
  # =====================================================
29
  # FACE DETECTION FUNCTION
@@ -43,8 +47,7 @@ def detected_face(image):
43
  if len(faces) == 0:
44
  return 0
45
 
46
- face_areas = []
47
- images = []
48
  for (x, y, w, h) in faces:
49
  face_cropped = gray[y:y + h, x:x + w]
50
  face_areas.append(w * h)
@@ -64,6 +67,9 @@ def get_expression(img):
64
  Takes an OpenCV BGR image as input, detects the face, and returns the
65
  predicted facial expression as a string.
66
  """
 
 
 
67
  # Detect face
68
  face = detected_face(img)
69
  if face == 0:
@@ -80,8 +86,12 @@ def get_expression(img):
80
  pred_idx = torch.argmax(probs, dim=-1).item()
81
  confidence = probs[0][pred_idx].item()
82
 
83
- # Map index → expression label
84
- expression_label = exp_model.processor.config.id2label.get(pred_idx, "Unknown")
 
 
 
 
85
 
86
  print(f"Detected Expression: {expression_label} ({confidence:.2f})")
87
 
 
4
  from PIL import Image
5
  import os
6
 
7
+ # Keep the dot when deploying (Spaces); remove locally if needed
8
  from .exp_recognition_model import facExpRec, processor, device
9
 
10
  #############################################################################################################################
 
20
  # =====================================================
21
  # LOAD MODEL ONCE (global)
22
  # =====================================================
23
+ print("Loading Expression Recognition Model...")
24
+ try:
25
+ exp_model = facExpRec().to(device)
26
+ exp_model.eval()
27
+ print("Expression model loaded successfully.")
28
+ except Exception as e:
29
+ print(f"Failed to load Expression Model: {e}")
30
+ exp_model = None
31
 
32
  # =====================================================
33
  # FACE DETECTION FUNCTION
 
47
  if len(faces) == 0:
48
  return 0
49
 
50
+ face_areas, images = [], []
 
51
  for (x, y, w, h) in faces:
52
  face_cropped = gray[y:y + h, x:x + w]
53
  face_areas.append(w * h)
 
67
  Takes an OpenCV BGR image as input, detects the face, and returns the
68
  predicted facial expression as a string.
69
  """
70
+ if exp_model is None:
71
+ raise RuntimeError("Expression model not loaded properly.")
72
+
73
  # Detect face
74
  face = detected_face(img)
75
  if face == 0:
 
86
  pred_idx = torch.argmax(probs, dim=-1).item()
87
  confidence = probs[0][pred_idx].item()
88
 
89
+ # Use model.config.id2label safely
90
+ id2label = getattr(exp_model.model.config, "id2label", None)
91
+ if id2label and pred_idx in id2label:
92
+ expression_label = id2label[pred_idx]
93
+ else:
94
+ expression_label = "Unknown"
95
 
96
  print(f"Detected Expression: {expression_label} ({confidence:.2f})")
97
 
app/Hackathon_setup/exp_recognition_model.py CHANGED
@@ -16,19 +16,6 @@ from PIL import Image
16
  # =====================================================
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
19
- # =====================================================
20
- # CLASS DEFINITIONS
21
- # =====================================================
22
- classes = {
23
- 0: 'ANGER',
24
- 1: 'DISGUST',
25
- 2: 'FEAR',
26
- 3: 'HAPPINESS',
27
- 4: 'NEUTRAL',
28
- 5: 'SADNESS',
29
- 6: 'SURPRISE'
30
- }
31
-
32
  # =====================================================
33
  # MODEL NAME (ViT pretrained on facial expressions)
34
  # =====================================================
@@ -42,9 +29,9 @@ try:
42
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
43
  model.to(device)
44
  model.eval()
45
- print(f"✅ Loaded ViT model '{MODEL_NAME}' successfully on {device}.")
46
  except Exception as e:
47
- print(f"❌ Error loading ViT model: {e}")
48
  processor, model = None, None
49
 
50
  # =====================================================
@@ -74,6 +61,7 @@ class facExpRec(nn.Module):
74
  raise RuntimeError("ViT model not loaded correctly.")
75
  self.model = model
76
  self.processor = processor
 
77
 
78
  def forward(self, x):
79
  """
@@ -91,8 +79,15 @@ class facExpRec(nn.Module):
91
  pred_idx = torch.argmax(probs, dim=-1).item()
92
  confidence = probs[0][pred_idx].item()
93
 
 
 
 
 
 
 
 
94
  return {
95
- "expression": classes.get(pred_idx, "Unknown"),
96
  "confidence": round(confidence, 3)
97
  }
98
 
 
16
  # =====================================================
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # =====================================================
20
  # MODEL NAME (ViT pretrained on facial expressions)
21
  # =====================================================
 
29
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
30
  model.to(device)
31
  model.eval()
32
+ print(f"Loaded ViT model '{MODEL_NAME}' successfully on {device}.")
33
  except Exception as e:
34
+ print(f"Error loading ViT model: {e}")
35
  processor, model = None, None
36
 
37
  # =====================================================
 
61
  raise RuntimeError("ViT model not loaded correctly.")
62
  self.model = model
63
  self.processor = processor
64
+ self.id2label = getattr(self.model.config, "id2label", None)
65
 
66
  def forward(self, x):
67
  """
 
79
  pred_idx = torch.argmax(probs, dim=-1).item()
80
  confidence = probs[0][pred_idx].item()
81
 
82
+ # Prefer model.config.id2label if available
83
+ label = (
84
+ self.id2label[pred_idx]
85
+ if self.id2label and pred_idx in self.id2label
86
+ else "Unknown"
87
+ )
88
+
89
  return {
90
+ "expression": label,
91
  "confidence": round(confidence, 3)
92
  }
93