Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import soundfile as sf | |
| import torchaudio | |
| from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model | |
| import numpy as np | |
| import json | |
| # ============================================================ | |
| # MODEL DEFINITION | |
| # ============================================================ | |
class Wav2Vec2ForSpeakerEmbedding(nn.Module):
    """Speaker-embedding network: a frozen wav2vec2 backbone followed by a
    trainable projection head that outputs L2-normalized embeddings.

    NOTE: the attribute names (`wav2vec2`, `projection`) and the Sequential
    layout must not change — the saved checkpoint's state-dict keys depend
    on them.
    """

    def __init__(self, embedding_size=256):
        super().__init__()
        # Pretrained acoustic encoder, kept frozen: only the head trains.
        self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        for p in self.wav2vec2.parameters():
            p.requires_grad = False
        # Projection head: 768 (wav2vec2-base hidden size) -> 512 -> embedding_size.
        self.projection = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, embedding_size),
        )

    def forward(self, input_values):
        """Map preprocessed waveforms to unit-norm speaker embeddings.

        Args:
            input_values: batch of raw-audio features, shape (B, T).
        Returns:
            Tensor of shape (B, embedding_size), each row L2-normalized.
        """
        frames = self.wav2vec2(input_values).last_hidden_state
        pooled = frames.mean(dim=1)  # temporal average pooling over frames
        return F.normalize(self.projection(pooled), p=2, dim=1)
# ============================================================
# GLOBAL SETUP
# ============================================================
# Use the GPU when available; every inference tensor is moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the fine-tuned embedding model.
model = Wav2Vec2ForSpeakerEmbedding(embedding_size=256).to(device)
# weights_only=True restricts unpickling to tensors/basic containers,
# preventing arbitrary-code execution from a tampered checkpoint file.
checkpoint = torch.load('best_embedding_model.pth', map_location=device,
                        weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # inference mode: disables dropout in the projection head

# Preprocessor that converts raw 16 kHz audio into wav2vec2 input features.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
# ============================================================
# DATABASE
# ============================================================
class EnrollmentDB:
    """JSON-file-backed store mapping user names to speaker embeddings."""

    def __init__(self, db_path='enrollments.json'):
        self.db_path = db_path  # path of the JSON persistence file
        self.load_db()

    def load_db(self):
        """Load enrollments from disk; start empty if absent or corrupted.

        Robustness fix: a truncated/corrupted JSON file used to raise at
        import time and crash the whole app; now it falls back to an empty DB.
        """
        self.enrollments = {}
        if os.path.exists(self.db_path):
            try:
                with open(self.db_path, 'r') as f:
                    data = json.load(f)
                self.enrollments = {k: np.array(v) for k, v in data.items()}
            except (json.JSONDecodeError, ValueError):
                self.enrollments = {}

    def save_db(self):
        """Persist every enrollment to the JSON file."""
        data = {k: v.tolist() for k, v in self.enrollments.items()}
        with open(self.db_path, 'w') as f:
            json.dump(data, f)

    def enroll(self, name, embedding):
        """Store (or overwrite) *name*'s embedding and persist immediately."""
        self.enrollments[name] = embedding
        self.save_db()

    def verify(self, embedding, threshold=0.75):
        """Compare *embedding* against every enrolled user.

        Args:
            embedding: numpy array of shape (1, D) — presumably L2-normalized
                by the model; cosine similarity is scale-invariant anyway.
            threshold: minimum cosine similarity to accept the best match.
        Returns:
            (best_name, best_score, is_verified); (None, 0.0, False) when no
            users are enrolled.
        """
        if not self.enrollments:
            return None, 0.0, False
        best_match = None
        best_score = -1.0
        embedding_tensor = torch.from_numpy(embedding)
        for name, enrolled_emb in self.enrollments.items():
            enrolled_tensor = torch.from_numpy(enrolled_emb)
            similarity = F.cosine_similarity(embedding_tensor, enrolled_tensor, dim=1).item()
            if similarity > best_score:
                best_score = similarity
                best_match = name
        return best_match, best_score, best_score >= threshold

    def get_all_users(self):
        """Return the list of enrolled user names."""
        return list(self.enrollments.keys())

    def get_user_count(self):
        """Return how many users are enrolled."""
        return len(self.enrollments)

    def remove_user(self, name):
        """Delete *name* and persist; return True if the user existed."""
        if name in self.enrollments:
            del self.enrollments[name]
            self.save_db()
            return True
        return False


db = EnrollmentDB()
# ============================================================
# AUDIO PROCESSING
# ============================================================
def process_audio(audio_path, max_length=16000*3):
    """Process audio file"""
    try:
        samples, sample_rate = sf.read(audio_path, dtype='float32')
        signal = torch.from_numpy(samples)
        # Collapse multi-channel audio to mono by averaging the channels.
        if len(signal.shape) > 1:
            signal = torch.mean(signal, dim=-1)
        # wav2vec2 expects 16 kHz input.
        if sample_rate != 16000:
            signal = torchaudio.transforms.Resample(sample_rate, 16000)(signal)
        # Fix the duration: center-crop long clips, zero-pad short ones.
        n = len(signal)
        if n > max_length:
            offset = (n - max_length) // 2
            signal = signal[offset:offset + max_length]
        elif n < max_length:
            signal = torch.nn.functional.pad(signal, (0, max_length - n))
        # Peak-normalize unless the clip is pure silence.
        peak = signal.abs().max()
        if peak > 0:
            signal = signal / peak
        features = feature_extractor(
            signal.numpy(),
            sampling_rate=16000,
            return_tensors="pt",
        )
        return features.input_values
    except Exception as e:
        raise ValueError(f"Error processing audio: {e}")
def get_embedding(audio_path):
    """Extract embedding from audio"""
    model.eval()  # be defensive: guarantee inference mode per call
    with torch.no_grad():
        features = process_audio(audio_path).to(device)
        return model(features).cpu().numpy()
# ============================================================
# GRADIO FUNCTIONS
# ============================================================
def enroll_user(name, audio, threshold):
    """Enroll a new user"""
    # Validate inputs before touching the model.
    if not name or not name.strip():
        return "❌ Veuillez entrer un nom.", get_user_list(), get_stats()
    if not audio:
        return "❌ Veuillez uploader un enregistrement audio.", get_user_list(), get_stats()
    name = name.strip()
    # Reject duplicate names rather than silently overwriting.
    if name in db.get_all_users():
        return f"⚠️ L'utilisateur '{name}' existe déjà.", get_user_list(), get_stats()
    try:
        db.enroll(name, get_embedding(audio))
        message = (
            f"✅ Enregistrement réussi!\n\n👤 {name} a été enregistré dans le système."
            f"\n📊 Total utilisateurs: {db.get_user_count()}"
        )
        return message, get_user_list(), get_stats()
    except Exception as e:
        return f"❌ Erreur: {str(e)}", get_user_list(), get_stats()
def verify_user(audio, threshold):
    """Verify a user"""
    if not audio:
        return "❌ Veuillez uploader un enregistrement audio.", ""
    if db.get_user_count() == 0:
        return "⚠️ Aucun utilisateur enregistré. Veuillez d'abord enregistrer des utilisateurs.", ""
    try:
        embedding = get_embedding(audio)
        match_name, similarity, is_verified = db.verify(embedding, threshold)
        # Per-user similarity breakdown, best match first.
        query = torch.from_numpy(embedding)
        scored = []
        for user, stored in db.enrollments.items():
            sim = F.cosine_similarity(query, torch.from_numpy(stored), dim=1).item()
            scored.append((user, sim, "✅" if sim >= threshold else "❌"))
        scored.sort(key=lambda item: item[1], reverse=True)
        details = "📊 **Scores détaillés:**\n\n"
        for user, sim, mark in scored:
            details += f"{mark} **{user}**: {sim:.1%}\n"
        if is_verified:
            result = f"""
# ✅ VÉRIFICATION RÉUSSIE
## Identifié comme: **{match_name}**
### Score de confiance: **{similarity:.1%}**
---
"""
        else:
            result = f"""
# ❌ VÉRIFICATION ÉCHOUÉE
Meilleure correspondance: **{match_name}**
Similarité: **{similarity:.1%}**
Seuil requis: **{threshold:.1%}**
*Cette voix n'est pas reconnue dans le système.*
---
"""
        return result + details, details
    except Exception as e:
        return f"❌ Erreur: {str(e)}", ""
def get_user_list():
    """Get list of enrolled users"""
    users = db.get_all_users()
    if not users:
        return "Aucun utilisateur enregistré"
    lines = [f"• {user}" for user in sorted(users)]
    return "\n".join(lines)
def get_stats():
    """Get system statistics"""
    # Only the user count is dynamic; the metrics are fixed evaluation results.
    count = db.get_user_count()
    return f"""
**📊 Statistiques du système:**
- Utilisateurs enregistrés: {count}
- Précision du modèle: 76%
- Score AUC: 0.82
- Architecture: Wav2Vec 2.0
"""
def delete_user(name):
    """Delete a user"""
    if not name or not name.strip():
        return "❌ Veuillez sélectionner un utilisateur.", get_user_list(), get_stats()
    removed = db.remove_user(name.strip())
    if removed:
        return f"✅ Utilisateur '{name}' supprimé.", get_user_list(), get_stats()
    return f"❌ Utilisateur '{name}' non trouvé.", get_user_list(), get_stats()
# ============================================================
# GRADIO INTERFACE
# ============================================================
# Declarative UI tree: header, shared stats panel + threshold slider, then
# four tabs (enrollment, verification, user management, documentation).
with gr.Blocks(title="Biométrie Vocale - POC", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎤 Système de Biométrie Vocale
### Proof of Concept - Wav2Vec 2.0 Fine-tuné
""")
    with gr.Row():
        with gr.Column(scale=2):
            # Live statistics panel; refreshed by the enroll/delete callbacks below.
            stats_display = gr.Markdown(get_stats())
        with gr.Column(scale=1):
            # Decision threshold shared by the enrollment and verification tabs.
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                value=0.75,
                step=0.05,
                label="Seuil de vérification",
                info="Plus élevé = vérification plus stricte"
            )
    with gr.Tabs():
        # TAB 1: ENROLLMENT
        with gr.Tab("📝 Enregistrement"):
            gr.Markdown("### Enregistrer un nouvel utilisateur")
            with gr.Row():
                with gr.Column():
                    enroll_name_input = gr.Textbox(
                        label="Nom de l'utilisateur",
                        placeholder="Ex: Jean Dupont"
                    )
                    enroll_audio_input = gr.Audio(
                        label="Enregistrement vocal",
                        type="filepath",
                        sources=["upload", "microphone"]
                    )
                    enroll_button = gr.Button("🎯 Enregistrer", variant="primary")
                with gr.Column():
                    gr.Markdown("""
**💡 Conseils:**
- Audio clair et net
- 3-20 secondes recommandées
- Bruit de fond minimal
- Voix normale
""")
                    enrolled_users = gr.Textbox(
                        label="Utilisateurs enregistrés",
                        value=get_user_list(),
                        lines=8,
                        interactive=False
                    )
            enroll_output = gr.Markdown()
            # enroll_user returns (status message, user list, stats markdown).
            enroll_button.click(
                fn=enroll_user,
                inputs=[enroll_name_input, enroll_audio_input, threshold],
                outputs=[enroll_output, enrolled_users, stats_display]
            )
        # TAB 2: VERIFICATION
        with gr.Tab("✅ Vérification"):
            gr.Markdown("### Vérifier l'identité d'un utilisateur")
            with gr.Row():
                with gr.Column():
                    verify_audio_input = gr.Audio(
                        label="Enregistrement vocal à vérifier",
                        type="filepath",
                        sources=["upload", "microphone"]
                    )
                    verify_button = gr.Button("🔍 Vérifier", variant="primary")
                with gr.Column():
                    # NOTE(review): this f-string captures the user count at UI
                    # build time; it is not refreshed after enroll/delete.
                    gr.Markdown(f"""
**ℹ️ Information:**
- {db.get_user_count()} utilisateur(s) enregistré(s)
- Seuil: ajustable dans le slider ci-dessus
- Modèle: Wav2Vec 2.0
""")
            verify_output = gr.Markdown()
            verify_details = gr.Markdown()
            # verify_user returns (verdict markdown, per-user score details).
            verify_button.click(
                fn=verify_user,
                inputs=[verify_audio_input, threshold],
                outputs=[verify_output, verify_details]
            )
        # TAB 3: MANAGEMENT
        with gr.Tab("⚙️ Gestion"):
            gr.Markdown("### Gérer les utilisateurs enregistrés")
            with gr.Row():
                with gr.Column():
                    delete_name_input = gr.Textbox(
                        label="Nom de l'utilisateur à supprimer",
                        placeholder="Ex: Jean Dupont"
                    )
                    delete_button = gr.Button("🗑️ Supprimer", variant="stop")
                with gr.Column():
                    delete_users_list = gr.Textbox(
                        label="Utilisateurs enregistrés",
                        value=get_user_list(),
                        lines=8,
                        interactive=False
                    )
            delete_output = gr.Markdown()
            # delete_user returns (status message, user list, stats markdown).
            delete_button.click(
                fn=delete_user,
                inputs=[delete_name_input],
                outputs=[delete_output, delete_users_list, stats_display]
            )
        # TAB 4: ABOUT — static documentation, no callbacks.
        with gr.Tab("ℹ️ À propos"):
            gr.Markdown("""
## 🎯 Technologie
**Architecture du modèle:**
- Base: Wav2Vec 2.0 (Facebook AI)
- Fine-tuné sur 247 locuteurs
- 1035 échantillons vocaux (qualité téléphonique, 8kHz)
- Dimension d'embedding: 256
**Détails d'entraînement:**
- Loss: Supervised Contrastive Learning
- Framework: PyTorch + Transformers
- Durée d'entraînement: ~50 epochs
- Matériel: NVIDIA RTX 3050
---
## 📊 Métriques de Performance
**Résultats d'évaluation:**
- **Précision:** 76%
- **Score AUC:** 0.82
- **Taux de vrais positifs:** 79%
- **Taux de faux positifs:** 27%
**Ensemble de test:**
- 1000 paires de vérification
- 500 paires même locuteur
- 500 paires locuteurs différents
---
## 🔧 Fonctionnement
1. **Phase d'enregistrement:**
- L'utilisateur uploade un enregistrement vocal
- Le système extrait un embedding de dimension 256
- L'embedding est stocké dans la base de données
2. **Phase de vérification:**
- Enregistrement vocal inconnu uploadé
- Le système extrait l'embedding
- Calcul de similarité cosinus avec tous les utilisateurs enregistrés
- Correspondance si similarité > seuil
3. **Algorithme de correspondance:**
- Similarité cosinus entre embeddings
- Plage: -1 (opposé) à +1 (identique)
- Même locuteur typique: 0.75-0.95
- Locuteurs différents typique: 0.30-0.70
---
**Note:** Ceci est un système proof of concept. Pour un déploiement en production, considérer:
- Dataset plus large (10-20 échantillons par locuteur)
- Meilleur modèle de base (WavLM pour conditions bruitées)
- Mesures anti-spoofing
- Détection de vivacité
- Multi-enregistrement (moyenne de plusieurs enregistrements par utilisateur)
""")

# share=False keeps the app local to the hosting environment.
demo.launch(share=False)