import os import gradio as gr import torch import torch.nn as nn import torch.nn.functional as F import soundfile as sf import torchaudio from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model import numpy as np import json # ============================================================ # MODEL DEFINITION # ============================================================ class Wav2Vec2ForSpeakerEmbedding(nn.Module): def __init__(self, embedding_size=256): super().__init__() self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base") for param in self.wav2vec2.parameters(): param.requires_grad = False self.projection = nn.Sequential( nn.Linear(768, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, embedding_size) ) def forward(self, input_values): outputs = self.wav2vec2(input_values) hidden_states = outputs.last_hidden_state embeddings = torch.mean(hidden_states, dim=1) embeddings = self.projection(embeddings) embeddings = F.normalize(embeddings, p=2, dim=1) return embeddings # ============================================================ # GLOBAL SETUP # ============================================================ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Load model model = Wav2Vec2ForSpeakerEmbedding(embedding_size=256).to(device) checkpoint = torch.load('best_embedding_model.pth', map_location=device) model.load_state_dict(checkpoint['model_state_dict']) model.eval() feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base") # ============================================================ # DATABASE # ============================================================ class EnrollmentDB: def __init__(self, db_path='enrollments.json'): self.db_path = db_path self.load_db() def load_db(self): if os.path.exists(self.db_path): with open(self.db_path, 'r') as f: data = json.load(f) self.enrollments = {k: np.array(v) for k, v in data.items()} else: self.enrollments = {} def save_db(self): data = {k: v.tolist() for k, v in self.enrollments.items()} with open(self.db_path, 'w') as f: json.dump(data, f) def enroll(self, name, embedding): self.enrollments[name] = embedding self.save_db() def verify(self, embedding, threshold=0.75): if not self.enrollments: return None, 0.0, False best_match = None best_score = -1.0 embedding_tensor = torch.from_numpy(embedding) for name, enrolled_emb in self.enrollments.items(): enrolled_tensor = torch.from_numpy(enrolled_emb) similarity = F.cosine_similarity(embedding_tensor, enrolled_tensor, dim=1).item() if similarity > best_score: best_score = similarity best_match = name is_verified = best_score >= threshold return best_match, best_score, is_verified def get_all_users(self): return list(self.enrollments.keys()) def get_user_count(self): return len(self.enrollments) def remove_user(self, name): if name in self.enrollments: del self.enrollments[name] self.save_db() return True return False db = EnrollmentDB() # ============================================================ # AUDIO PROCESSING # ============================================================ def process_audio(audio_path, max_length=16000*3): """Process audio file""" try: waveform, sr = sf.read(audio_path, dtype='float32') waveform = torch.from_numpy(waveform) if len(waveform.shape) > 1: waveform = torch.mean(waveform, dim=-1) if sr != 16000: resampler = torchaudio.transforms.Resample(sr, 16000) waveform = resampler(waveform) if len(waveform) > max_length: start = (len(waveform) - max_length) // 2 waveform = waveform[start:start + max_length] elif len(waveform) < max_length: padding = max_length - len(waveform) waveform = torch.nn.functional.pad(waveform, (0, padding)) if waveform.abs().max() > 0: waveform = waveform / waveform.abs().max() inputs = feature_extractor( waveform.numpy(), sampling_rate=16000, return_tensors="pt" ) return inputs.input_values except Exception as e: raise ValueError(f"Error processing audio: {e}") def get_embedding(audio_path): """Extract embedding from audio""" model.eval() with torch.no_grad(): inputs = process_audio(audio_path) inputs = inputs.to(device) embedding = model(inputs) return embedding.cpu().numpy() # ============================================================ # GRADIO FUNCTIONS # ============================================================ def enroll_user(name, audio, threshold): """Enroll a new user""" if not name or not name.strip(): return "❌ Veuillez entrer un nom.", get_user_list(), get_stats() if not audio: return "❌ Veuillez uploader un enregistrement audio.", get_user_list(), get_stats() name = name.strip() if name in db.get_all_users(): return f"⚠️ L'utilisateur '{name}' existe déjà.", get_user_list(), get_stats() try: embedding = get_embedding(audio) db.enroll(name, embedding) return f"✅ Enregistrement réussi!\n\n👤 {name} a été enregistré dans le système.\n📊 Total utilisateurs: {db.get_user_count()}", get_user_list(), get_stats() except Exception as e: return f"❌ Erreur: {str(e)}", get_user_list(), get_stats() def verify_user(audio, threshold): """Verify a user""" if not audio: return "❌ Veuillez uploader un enregistrement audio.", "" if db.get_user_count() == 0: return "⚠️ Aucun utilisateur enregistré. Veuillez d'abord enregistrer des utilisateurs.", "" try: embedding = get_embedding(audio) match_name, similarity, is_verified = db.verify(embedding, threshold) # Build detailed results details = "📊 **Scores détaillés:**\n\n" embedding_tensor = torch.from_numpy(embedding) scores = [] for name, enrolled_emb in db.enrollments.items(): enrolled_tensor = torch.from_numpy(enrolled_emb) sim = F.cosine_similarity(embedding_tensor, enrolled_tensor, dim=1).item() status = "✅" if sim >= threshold else "❌" scores.append((name, sim, status)) scores.sort(key=lambda x: x[1], reverse=True) for name, sim, status in scores: details += f"{status} **{name}**: {sim:.1%}\n" if is_verified: result = f""" # ✅ VÉRIFICATION RÉUSSIE ## Identifié comme: **{match_name}** ### Score de confiance: **{similarity:.1%}** --- """ return result + details, details else: result = f""" # ❌ VÉRIFICATION ÉCHOUÉE Meilleure correspondance: **{match_name}** Similarité: **{similarity:.1%}** Seuil requis: **{threshold:.1%}** *Cette voix n'est pas reconnue dans le système.* --- """ return result + details, details except Exception as e: return f"❌ Erreur: {str(e)}", "" def get_user_list(): """Get list of enrolled users""" users = db.get_all_users() if not users: return "Aucun utilisateur enregistré" return "\n".join([f"• {user}" for user in sorted(users)]) def get_stats(): """Get system statistics""" return f""" **📊 Statistiques du système:** - Utilisateurs enregistrés: {db.get_user_count()} - Précision du modèle: 76% - Score AUC: 0.82 - Architecture: Wav2Vec 2.0 """ def delete_user(name): """Delete a user""" if not name or not name.strip(): return "❌ Veuillez sélectionner un utilisateur.", get_user_list(), get_stats() if db.remove_user(name.strip()): return f"✅ Utilisateur '{name}' supprimé.", get_user_list(), get_stats() else: return f"❌ Utilisateur '{name}' non trouvé.", get_user_list(), get_stats() # ============================================================ # GRADIO INTERFACE # ============================================================ with gr.Blocks(title="Biométrie Vocale - POC", theme=gr.themes.Soft()) as demo: gr.Markdown(""" # 🎤 Système de Biométrie Vocale ### Proof of Concept - Wav2Vec 2.0 Fine-tuné """) with gr.Row(): with gr.Column(scale=2): stats_display = gr.Markdown(get_stats()) with gr.Column(scale=1): threshold = gr.Slider( minimum=0.5, maximum=0.95, value=0.75, step=0.05, label="Seuil de vérification", info="Plus élevé = vérification plus stricte" ) with gr.Tabs(): # TAB 1: ENROLLMENT with gr.Tab("📝 Enregistrement"): gr.Markdown("### Enregistrer un nouvel utilisateur") with gr.Row(): with gr.Column(): enroll_name_input = gr.Textbox( label="Nom de l'utilisateur", placeholder="Ex: Jean Dupont" ) enroll_audio_input = gr.Audio( label="Enregistrement vocal", type="filepath", sources=["upload", "microphone"] ) enroll_button = gr.Button("🎯 Enregistrer", variant="primary") with gr.Column(): gr.Markdown(""" **💡 Conseils:** - Audio clair et net - 3-20 secondes recommandées - Bruit de fond minimal - Voix normale """) enrolled_users = gr.Textbox( label="Utilisateurs enregistrés", value=get_user_list(), lines=8, interactive=False ) enroll_output = gr.Markdown() enroll_button.click( fn=enroll_user, inputs=[enroll_name_input, enroll_audio_input, threshold], outputs=[enroll_output, enrolled_users, stats_display] ) # TAB 2: VERIFICATION with gr.Tab("✅ Vérification"): gr.Markdown("### Vérifier l'identité d'un utilisateur") with gr.Row(): with gr.Column(): verify_audio_input = gr.Audio( label="Enregistrement vocal à vérifier", type="filepath", sources=["upload", "microphone"] ) verify_button = gr.Button("🔍 Vérifier", variant="primary") with gr.Column(): gr.Markdown(f""" **ℹ️ Information:** - {db.get_user_count()} utilisateur(s) enregistré(s) - Seuil: ajustable dans le slider ci-dessus - Modèle: Wav2Vec 2.0 """) verify_output = gr.Markdown() verify_details = gr.Markdown() verify_button.click( fn=verify_user, inputs=[verify_audio_input, threshold], outputs=[verify_output, verify_details] ) # TAB 3: MANAGEMENT with gr.Tab("⚙️ Gestion"): gr.Markdown("### Gérer les utilisateurs enregistrés") with gr.Row(): with gr.Column(): delete_name_input = gr.Textbox( label="Nom de l'utilisateur à supprimer", placeholder="Ex: Jean Dupont" ) delete_button = gr.Button("🗑️ Supprimer", variant="stop") with gr.Column(): delete_users_list = gr.Textbox( label="Utilisateurs enregistrés", value=get_user_list(), lines=8, interactive=False ) delete_output = gr.Markdown() delete_button.click( fn=delete_user, inputs=[delete_name_input], outputs=[delete_output, delete_users_list, stats_display] ) # TAB 4: ABOUT with gr.Tab("ℹ️ À propos"): gr.Markdown(""" ## 🎯 Technologie **Architecture du modèle:** - Base: Wav2Vec 2.0 (Facebook AI) - Fine-tuné sur 247 locuteurs - 1035 échantillons vocaux (qualité téléphonique, 8kHz) - Dimension d'embedding: 256 **Détails d'entraînement:** - Loss: Supervised Contrastive Learning - Framework: PyTorch + Transformers - Durée d'entraînement: ~50 epochs - Matériel: NVIDIA RTX 3050 --- ## 📊 Métriques de Performance **Résultats d'évaluation:** - **Précision:** 76% - **Score AUC:** 0.82 - **Taux de vrais positifs:** 79% - **Taux de faux positifs:** 27% **Ensemble de test:** - 1000 paires de vérification - 500 paires même locuteur - 500 paires locuteurs différents --- ## 🔧 Fonctionnement 1. **Phase d'enregistrement:** - L'utilisateur uploade un enregistrement vocal - Le système extrait un embedding de dimension 256 - L'embedding est stocké dans la base de données 2. **Phase de vérification:** - Enregistrement vocal inconnu uploadé - Le système extrait l'embedding - Calcul de similarité cosinus avec tous les utilisateurs enregistrés - Correspondance si similarité > seuil 3. **Algorithme de correspondance:** - Similarité cosinus entre embeddings - Plage: -1 (opposé) à +1 (identique) - Même locuteur typique: 0.75-0.95 - Locuteurs différents typique: 0.30-0.70 --- **Note:** Ceci est un système proof of concept. Pour un déploiement en production, considérer: - Dataset plus large (10-20 échantillons par locuteur) - Meilleur modèle de base (WavLM pour conditions bruitées) - Mesures anti-spoofing - Détection de vivacité - Multi-enregistrement (moyenne de plusieurs enregistrements par utilisateur) """) demo.launch(share=False)