"""
face_recognition_model.py
--------------------------
Defines the FaceNet-based face recognition model for deployment in Hugging Face Spaces.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from facenet_pytorch import InceptionResnetV1, MTCNN
import os
# =====================================================
# DEVICE CONFIGURATION
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# =====================================================
# PREPROCESSING PIPELINE
# =====================================================
# Matches the original FaceNet preprocessing
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
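# --------------------------------------------------
# Illustrative usage sketch (not part of the original pipeline): how the
# `transform` above would turn one RGB PIL image into a 1 x 3 x 160 x 160
# batch tensor ready for InceptionResnetV1. The helper name is hypothetical.
# --------------------------------------------------
def _example_preprocess(pil_image):
    """Apply the FaceNet preprocessing to a single PIL image and add a batch dim."""
    tensor = transform(pil_image)   # 3 x 160 x 160, values scaled to [-1, 1]
    return tensor.unsqueeze(0)      # 1 x 3 x 160 x 160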
# =====================================================
# FACENET MODEL WRAPPER CLASS
# =====================================================
class FaceRecognitionModel(nn.Module):
    """
    Wrapper around a pretrained FaceNet (InceptionResnetV1).
    Provides:
    - Face embedding extraction
    - Pairwise similarity computation
    - Single-image embedding for classification
    """
    def __init__(self, pretrained=True):
        super(FaceRecognitionModel, self).__init__()
        # Load pretrained FaceNet model
        self.model = InceptionResnetV1(pretrained='vggface2' if pretrained else None).eval()
        self.model = self.model.to(device)
        # Optional face detector (used externally)
        self.mtcnn = MTCNN(keep_all=False, device=device, post_process=True)

    # --------------------------------------------------
    # Forward once: get normalized 512-D embeddings
    # --------------------------------------------------
    def forward_once(self, x):
        x = x.to(device)
        emb = self.model(x)
        emb = F.normalize(emb, p=2, dim=1)
        return emb

    # --------------------------------------------------
    # Compute cosine similarity between two images
    # --------------------------------------------------
    def forward(self, img1, img2):
        emb1 = self.forward_once(img1)
        emb2 = self.forward_once(img2)
        similarity = F.cosine_similarity(emb1, emb2)
        return similarity

    # --------------------------------------------------
    # Extract embedding from a PIL image directly
    # --------------------------------------------------
    def extract_embedding(self, image):
        """
        Extracts a 512-D normalized embedding from a single PIL image.
        Handles detection, alignment, and transformation.
        """
        with torch.no_grad():
            face = self.mtcnn(image)
            if face is None:
                return None
            face = face.unsqueeze(0).to(device)
            emb = self.model(face)
            emb = F.normalize(emb, p=2, dim=1)
            return emb.cpu().numpy()
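# --------------------------------------------------
# Illustrative usage sketch (not part of the original file): scoring two
# already-preprocessed batches with the wrapper's forward() method. Inputs are
# assumed to be 1 x 3 x 160 x 160 tensors produced by `transform`; the helper
# name is hypothetical.
# --------------------------------------------------
def _example_pair_similarity(model, img1, img2):
    """Return the cosine similarity between two preprocessed face batches."""
    with torch.no_grad():
        score = model(img1, img2)   # forward() -> cosine similarity, shape (1,)
    return score.item()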
# =====================================================
# MODEL LOADER (FINAL ROBUST FIX)
# =====================================================
def load_model(model_path='face_recognition_model.t7', pretrained=True):
    """
    Loads the FaceRecognitionModel and applies the saved .t7 checkpoint.
    FIX: Dynamically strips the 'model.' prefix and loads the weights into the
    inner model.model (InceptionResnetV1) so that layer names match the checkpoint.
    """
    # 1. Initialize the full wrapper model
    model = FaceRecognitionModel(pretrained=pretrained).to(device)
    if os.path.exists(model_path):
        checkpoint = torch.load(model_path, map_location=device)
        # Determine the actual state dictionary source (some checkpoints nest it under 'net_dict')
        state_dict = checkpoint.get('net_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
        try:
            # -------------------------------------------------------------------
            # ROBUST FIX: Prepare a clean state dictionary for the inner model.
            # -------------------------------------------------------------------
            # Check whether keys are prefixed with 'model.' (saved from the wrapper)
            has_wrapper_prefix = any(k.startswith('model.') for k in state_dict.keys())
            new_state_dict = {}
            for k, v in state_dict.items():
                if has_wrapper_prefix and k.startswith('model.'):
                    # Strip the 'model.' prefix so keys match InceptionResnetV1 layer names
                    new_key = k[len('model.'):]
                else:
                    new_key = k
                new_state_dict[new_key] = v
            # Load the cleaned state dictionary into the INNER InceptionResnetV1 model.
            # strict=False tolerates missing/unexpected keys (e.g. the 'logits'
            # classification head that facenet_pytorch adds) when only the core
            # network was saved.
            model.model.load_state_dict(new_state_dict, strict=False)
            print(f"Loaded weights from: {model_path} into inner InceptionResnetV1 model.")
        except Exception as e:
            # Fallback for unexpected formats (e.g. a checkpoint saved with different key names)
            print(f"Error during state_dict manipulation and loading: {e}")
            print("Attempting simple load into the inner model as a final fallback.")
            # Attempt to load the original state dictionary directly, again with strict=False
            model.model.load_state_dict(state_dict, strict=False)
            print(f"Loaded weights using final fallback: {model_path}.")
    else:
        print(f"Warning: Model weights not found at {model_path}. Using pretrained FaceNet only.")
    model.eval()
    return model
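# --------------------------------------------------
# Illustrative usage sketch (assumptions: two PIL images and a local checkpoint
# path; the helper name is hypothetical): end-to-end comparison of two photos
# using load_model() and extract_embedding() defined above.
# --------------------------------------------------
def _example_compare_faces(image_a, image_b, model_path='face_recognition_model.t7'):
    """Return the cosine similarity between two PIL images, or None if no face is found."""
    model = load_model(model_path)
    emb_a = model.extract_embedding(image_a)   # numpy array of shape (1, 512) or None
    emb_b = model.extract_embedding(image_b)
    if emb_a is None or emb_b is None:
        return None
    return F.cosine_similarity(torch.from_numpy(emb_a),
                               torch.from_numpy(emb_b)).item()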
# =====================================================
# UTILITY VARIABLES
# =====================================================
classifier = None # Loaded dynamically by face_recognition.py
classes = [] # Populated dynamically with known identities
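# --------------------------------------------------
# Illustrative sketch (assumptions: `known_embeddings` is an L2-normalized
# (N, 512) float tensor aligned index-for-index with `classes`, built elsewhere
# by face_recognition.py; the helper name and the 0.6 threshold are hypothetical):
# nearest-neighbour identification of a query embedding by cosine similarity.
# --------------------------------------------------
def _example_identify(query_embedding, known_embeddings, threshold=0.6):
    """Return (identity, score) for the best match, or (None, score) below threshold."""
    query = F.normalize(torch.as_tensor(query_embedding, dtype=torch.float32), p=2, dim=1)
    scores = query @ known_embeddings.t()        # (1, N) cosine similarities
    score, idx = scores.max(dim=1)
    name = classes[idx.item()] if score.item() >= threshold else None
    return name, score.item()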