| """ | |
| face_recognition_model.py | |
| -------------------------- | |
| Defines the FaceNet-based face recognition model for deployment in Hugging Face Spaces. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| from facenet_pytorch import InceptionResnetV1, MTCNN | |
| import os | |
# =====================================================
# DEVICE CONFIGURATION
# =====================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =====================================================
# PREPROCESSING PIPELINE
# =====================================================
# Matches the original FaceNet preprocessing
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# =====================================================
# FACENET MODEL WRAPPER CLASS
# =====================================================
class FaceRecognitionModel(nn.Module):
    """
    Wrapper around pretrained FaceNet (InceptionResnetV1).

    Provides:
        - Face embedding extraction
        - Pairwise similarity computation
        - Single-image embedding for classification
    """

    def __init__(self, pretrained=True):
        super(FaceRecognitionModel, self).__init__()

        # Load pretrained FaceNet model
        self.model = InceptionResnetV1(pretrained='vggface2' if pretrained else None).eval()
        self.model = self.model.to(device)

        # Optional face detector (used externally)
        self.mtcnn = MTCNN(keep_all=False, device=device, post_process=True)
    # --------------------------------------------------
    # Forward once: get normalized 512-D embeddings
    # --------------------------------------------------
    def forward_once(self, x):
        x = x.to(device)
        emb = self.model(x)
        emb = F.normalize(emb, p=2, dim=1)
        return emb

    # --------------------------------------------------
    # Compute cosine similarity between two images
    # --------------------------------------------------
    def forward(self, img1, img2):
        emb1 = self.forward_once(img1)
        emb2 = self.forward_once(img2)
        similarity = F.cosine_similarity(emb1, emb2)
        return similarity
    # --------------------------------------------------
    # Extract embedding from a PIL image directly
    # --------------------------------------------------
    def extract_embedding(self, image):
        """
        Extracts a 512-D normalized embedding from a single PIL image.
        Handles detection, alignment, and transformation.
        """
        with torch.no_grad():
            face = self.mtcnn(image)
            if face is None:
                return None
            face = face.unsqueeze(0).to(device)
            emb = self.model(face)
            emb = F.normalize(emb, p=2, dim=1)
            return emb.cpu().numpy()
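
# ------------------------------------------------------------------
# Usage sketch (illustrative only): pairwise verification with the
# wrapper's forward(). The PIL images `face_a` and `face_b` below are
# hypothetical placeholders, assumed to be pre-cropped face images.
#
#   model = FaceRecognitionModel()
#   t1 = transform(face_a).unsqueeze(0)   # (1, 3, 160, 160) tensor
#   t2 = transform(face_b).unsqueeze(0)
#   score = model(t1, t2)                 # cosine similarity in [-1, 1]
# ------------------------------------------------------------------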
# =====================================================
# MODEL LOADER (FINAL ROBUST FIX)
# =====================================================
def load_model(model_path='face_recognition_model.t7', pretrained=True):
    """
    Loads the FaceRecognitionModel and applies the saved .t7 checkpoint.

    FIX: Dynamically strips the 'model.' prefix and loads into the inner
    model.model (InceptionResnetV1) to ensure layer names match the source.
    """
    # 1. Initialize the full wrapper model
    model = FaceRecognitionModel(pretrained=pretrained).to(device)

    if os.path.exists(model_path):
        state_dict = None
        try:
            checkpoint = torch.load(model_path, map_location=device)

            # Determine the actual state dictionary source
            state_dict = checkpoint.get('net_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint

            # -------------------------------------------------------------------
            # ROBUST FIX: Prepare a clean state dictionary for the inner model.
            # -------------------------------------------------------------------
            # Keys saved from the wrapper are prefixed with 'model.'; strip that
            # prefix so they line up with InceptionResnetV1's own layer names.
            new_state_dict = {}
            for k, v in state_dict.items():
                new_key = k[len('model.'):] if k.startswith('model.') else k
                new_state_dict[new_key] = v

            # Load the cleaned state dictionary into the INNER InceptionResnetV1 model.
            # strict=False is essential to ignore the 'logits', 'last_linear', etc.
            # layers that facenet_pytorch adds if we only saved the core network.
            model.model.load_state_dict(new_state_dict, strict=False)
            print(f"Loaded weights from: {model_path} into inner InceptionResnetV1 model.")
        except Exception as e:
            # Fallback for unexpected formats (e.g., if a file was saved entirely wrong)
            print(f"Error during state_dict manipulation and loading: {e}")
            if state_dict is not None:
                print("Attempting simple load into the inner model as a final fallback.")
                # Load the original checkpoint directly into the inner model with strict=False
                model.model.load_state_dict(state_dict, strict=False)
                print(f"Loaded weights using final fallback: {model_path}.")
            else:
                print("Checkpoint could not be read; using pretrained FaceNet weights only.")
    else:
        print(f"Warning: Model weights not found at {model_path}. Using pretrained FaceNet only.")

    model.eval()
    return model
# =====================================================
# UTILITY VARIABLES
# =====================================================
classifier = None  # Loaded dynamically by face_recognition.py
classes = []       # Populated dynamically with known identities
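
# ------------------------------------------------------------------
# Minimal end-to-end sketch (assumes two local image files exist; the
# paths below are hypothetical, not part of the deployed Space).
# Demonstrates load_model() together with extract_embedding() and a
# NumPy cosine similarity between the resulting 512-D embeddings.
# ------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    import numpy as np

    model = load_model()                                # falls back to pretrained FaceNet if no .t7 file
    img1 = Image.open("person_a.jpg").convert("RGB")    # hypothetical path
    img2 = Image.open("person_b.jpg").convert("RGB")    # hypothetical path

    emb1 = model.extract_embedding(img1)
    emb2 = model.extract_embedding(img2)

    if emb1 is not None and emb2 is not None:
        # Embeddings are already L2-normalized, so the dot product equals cosine similarity.
        similarity = float(np.dot(emb1[0], emb2[0]))
        print(f"Cosine similarity: {similarity:.4f}")
    else:
        print("No face detected in at least one of the images.")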