Azgadel's picture
Bug Fixes
b251a32 verified
import os
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import soundfile as sf
import torchaudio
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
import numpy as np
import json
# ============================================================
# MODEL DEFINITION
# ============================================================
class Wav2Vec2ForSpeakerEmbedding(nn.Module):
def __init__(self, embedding_size=256):
super().__init__()
self.wav2vec2 = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
for param in self.wav2vec2.parameters():
param.requires_grad = False
self.projection = nn.Sequential(
nn.Linear(768, 512),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(512, embedding_size)
)
def forward(self, input_values):
outputs = self.wav2vec2(input_values)
hidden_states = outputs.last_hidden_state
embeddings = torch.mean(hidden_states, dim=1)
embeddings = self.projection(embeddings)
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings
# ============================================================
# GLOBAL SETUP
# ============================================================
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load model
model = Wav2Vec2ForSpeakerEmbedding(embedding_size=256).to(device)
checkpoint = torch.load('best_embedding_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
# ============================================================
# DATABASE
# ============================================================
class EnrollmentDB:
def __init__(self, db_path='enrollments.json'):
self.db_path = db_path
self.load_db()
def load_db(self):
if os.path.exists(self.db_path):
with open(self.db_path, 'r') as f:
data = json.load(f)
self.enrollments = {k: np.array(v) for k, v in data.items()}
else:
self.enrollments = {}
def save_db(self):
data = {k: v.tolist() for k, v in self.enrollments.items()}
with open(self.db_path, 'w') as f:
json.dump(data, f)
def enroll(self, name, embedding):
self.enrollments[name] = embedding
self.save_db()
def verify(self, embedding, threshold=0.75):
if not self.enrollments:
return None, 0.0, False
best_match = None
best_score = -1.0
embedding_tensor = torch.from_numpy(embedding)
for name, enrolled_emb in self.enrollments.items():
enrolled_tensor = torch.from_numpy(enrolled_emb)
similarity = F.cosine_similarity(embedding_tensor, enrolled_tensor, dim=1).item()
if similarity > best_score:
best_score = similarity
best_match = name
is_verified = best_score >= threshold
return best_match, best_score, is_verified
def get_all_users(self):
return list(self.enrollments.keys())
def get_user_count(self):
return len(self.enrollments)
def remove_user(self, name):
if name in self.enrollments:
del self.enrollments[name]
self.save_db()
return True
return False
db = EnrollmentDB()
# ============================================================
# AUDIO PROCESSING
# ============================================================
def process_audio(audio_path, max_length=16000*3):
"""Process audio file"""
try:
waveform, sr = sf.read(audio_path, dtype='float32')
waveform = torch.from_numpy(waveform)
if len(waveform.shape) > 1:
waveform = torch.mean(waveform, dim=-1)
if sr != 16000:
resampler = torchaudio.transforms.Resample(sr, 16000)
waveform = resampler(waveform)
if len(waveform) > max_length:
start = (len(waveform) - max_length) // 2
waveform = waveform[start:start + max_length]
elif len(waveform) < max_length:
padding = max_length - len(waveform)
waveform = torch.nn.functional.pad(waveform, (0, padding))
if waveform.abs().max() > 0:
waveform = waveform / waveform.abs().max()
inputs = feature_extractor(
waveform.numpy(),
sampling_rate=16000,
return_tensors="pt"
)
return inputs.input_values
except Exception as e:
raise ValueError(f"Error processing audio: {e}")
def get_embedding(audio_path):
"""Extract embedding from audio"""
model.eval()
with torch.no_grad():
inputs = process_audio(audio_path)
inputs = inputs.to(device)
embedding = model(inputs)
return embedding.cpu().numpy()
# ============================================================
# GRADIO FUNCTIONS
# ============================================================
def enroll_user(name, audio, threshold):
"""Enroll a new user"""
if not name or not name.strip():
return "❌ Veuillez entrer un nom.", get_user_list(), get_stats()
if not audio:
return "❌ Veuillez uploader un enregistrement audio.", get_user_list(), get_stats()
name = name.strip()
if name in db.get_all_users():
return f"⚠️ L'utilisateur '{name}' existe déjà.", get_user_list(), get_stats()
try:
embedding = get_embedding(audio)
db.enroll(name, embedding)
return f"✅ Enregistrement réussi!\n\n👤 {name} a été enregistré dans le système.\n📊 Total utilisateurs: {db.get_user_count()}", get_user_list(), get_stats()
except Exception as e:
return f"❌ Erreur: {str(e)}", get_user_list(), get_stats()
def verify_user(audio, threshold):
"""Verify a user"""
if not audio:
return "❌ Veuillez uploader un enregistrement audio.", ""
if db.get_user_count() == 0:
return "⚠️ Aucun utilisateur enregistré. Veuillez d'abord enregistrer des utilisateurs.", ""
try:
embedding = get_embedding(audio)
match_name, similarity, is_verified = db.verify(embedding, threshold)
# Build detailed results
details = "📊 **Scores détaillés:**\n\n"
embedding_tensor = torch.from_numpy(embedding)
scores = []
for name, enrolled_emb in db.enrollments.items():
enrolled_tensor = torch.from_numpy(enrolled_emb)
sim = F.cosine_similarity(embedding_tensor, enrolled_tensor, dim=1).item()
status = "✅" if sim >= threshold else "❌"
scores.append((name, sim, status))
scores.sort(key=lambda x: x[1], reverse=True)
for name, sim, status in scores:
details += f"{status} **{name}**: {sim:.1%}\n"
if is_verified:
result = f"""
# ✅ VÉRIFICATION RÉUSSIE
## Identifié comme: **{match_name}**
### Score de confiance: **{similarity:.1%}**
---
"""
return result + details, details
else:
result = f"""
# ❌ VÉRIFICATION ÉCHOUÉE
Meilleure correspondance: **{match_name}**
Similarité: **{similarity:.1%}**
Seuil requis: **{threshold:.1%}**
*Cette voix n'est pas reconnue dans le système.*
---
"""
return result + details, details
except Exception as e:
return f"❌ Erreur: {str(e)}", ""
def get_user_list():
"""Get list of enrolled users"""
users = db.get_all_users()
if not users:
return "Aucun utilisateur enregistré"
return "\n".join([f"• {user}" for user in sorted(users)])
def get_stats():
"""Get system statistics"""
return f"""
**📊 Statistiques du système:**
- Utilisateurs enregistrés: {db.get_user_count()}
- Précision du modèle: 76%
- Score AUC: 0.82
- Architecture: Wav2Vec 2.0
"""
def delete_user(name):
"""Delete a user"""
if not name or not name.strip():
return "❌ Veuillez sélectionner un utilisateur.", get_user_list(), get_stats()
if db.remove_user(name.strip()):
return f"✅ Utilisateur '{name}' supprimé.", get_user_list(), get_stats()
else:
return f"❌ Utilisateur '{name}' non trouvé.", get_user_list(), get_stats()
# ============================================================
# GRADIO INTERFACE
# ============================================================
with gr.Blocks(title="Biométrie Vocale - POC", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🎤 Système de Biométrie Vocale
### Proof of Concept - Wav2Vec 2.0 Fine-tuné
""")
with gr.Row():
with gr.Column(scale=2):
stats_display = gr.Markdown(get_stats())
with gr.Column(scale=1):
threshold = gr.Slider(
minimum=0.5,
maximum=0.95,
value=0.75,
step=0.05,
label="Seuil de vérification",
info="Plus élevé = vérification plus stricte"
)
with gr.Tabs():
# TAB 1: ENROLLMENT
with gr.Tab("📝 Enregistrement"):
gr.Markdown("### Enregistrer un nouvel utilisateur")
with gr.Row():
with gr.Column():
enroll_name_input = gr.Textbox(
label="Nom de l'utilisateur",
placeholder="Ex: Jean Dupont"
)
enroll_audio_input = gr.Audio(
label="Enregistrement vocal",
type="filepath",
sources=["upload", "microphone"]
)
enroll_button = gr.Button("🎯 Enregistrer", variant="primary")
with gr.Column():
gr.Markdown("""
**💡 Conseils:**
- Audio clair et net
- 3-20 secondes recommandées
- Bruit de fond minimal
- Voix normale
""")
enrolled_users = gr.Textbox(
label="Utilisateurs enregistrés",
value=get_user_list(),
lines=8,
interactive=False
)
enroll_output = gr.Markdown()
enroll_button.click(
fn=enroll_user,
inputs=[enroll_name_input, enroll_audio_input, threshold],
outputs=[enroll_output, enrolled_users, stats_display]
)
# TAB 2: VERIFICATION
with gr.Tab("✅ Vérification"):
gr.Markdown("### Vérifier l'identité d'un utilisateur")
with gr.Row():
with gr.Column():
verify_audio_input = gr.Audio(
label="Enregistrement vocal à vérifier",
type="filepath",
sources=["upload", "microphone"]
)
verify_button = gr.Button("🔍 Vérifier", variant="primary")
with gr.Column():
gr.Markdown(f"""
**ℹ️ Information:**
- {db.get_user_count()} utilisateur(s) enregistré(s)
- Seuil: ajustable dans le slider ci-dessus
- Modèle: Wav2Vec 2.0
""")
verify_output = gr.Markdown()
verify_details = gr.Markdown()
verify_button.click(
fn=verify_user,
inputs=[verify_audio_input, threshold],
outputs=[verify_output, verify_details]
)
# TAB 3: MANAGEMENT
with gr.Tab("⚙️ Gestion"):
gr.Markdown("### Gérer les utilisateurs enregistrés")
with gr.Row():
with gr.Column():
delete_name_input = gr.Textbox(
label="Nom de l'utilisateur à supprimer",
placeholder="Ex: Jean Dupont"
)
delete_button = gr.Button("🗑️ Supprimer", variant="stop")
with gr.Column():
delete_users_list = gr.Textbox(
label="Utilisateurs enregistrés",
value=get_user_list(),
lines=8,
interactive=False
)
delete_output = gr.Markdown()
delete_button.click(
fn=delete_user,
inputs=[delete_name_input],
outputs=[delete_output, delete_users_list, stats_display]
)
# TAB 4: ABOUT
with gr.Tab("ℹ️ À propos"):
gr.Markdown("""
## 🎯 Technologie
**Architecture du modèle:**
- Base: Wav2Vec 2.0 (Facebook AI)
- Fine-tuné sur 247 locuteurs
- 1035 échantillons vocaux (qualité téléphonique, 8kHz)
- Dimension d'embedding: 256
**Détails d'entraînement:**
- Loss: Supervised Contrastive Learning
- Framework: PyTorch + Transformers
- Durée d'entraînement: ~50 epochs
- Matériel: NVIDIA RTX 3050
---
## 📊 Métriques de Performance
**Résultats d'évaluation:**
- **Précision:** 76%
- **Score AUC:** 0.82
- **Taux de vrais positifs:** 79%
- **Taux de faux positifs:** 27%
**Ensemble de test:**
- 1000 paires de vérification
- 500 paires même locuteur
- 500 paires locuteurs différents
---
## 🔧 Fonctionnement
1. **Phase d'enregistrement:**
- L'utilisateur uploade un enregistrement vocal
- Le système extrait un embedding de dimension 256
- L'embedding est stocké dans la base de données
2. **Phase de vérification:**
- Enregistrement vocal inconnu uploadé
- Le système extrait l'embedding
- Calcul de similarité cosinus avec tous les utilisateurs enregistrés
- Correspondance si similarité > seuil
3. **Algorithme de correspondance:**
- Similarité cosinus entre embeddings
- Plage: -1 (opposé) à +1 (identique)
- Même locuteur typique: 0.75-0.95
- Locuteurs différents typique: 0.30-0.70
---
**Note:** Ceci est un système proof of concept. Pour un déploiement en production, considérer:
- Dataset plus large (10-20 échantillons par locuteur)
- Meilleur modèle de base (WavLM pour conditions bruitées)
- Mesures anti-spoofing
- Détection de vivacité
- Multi-enregistrement (moyenne de plusieurs enregistrements par utilisateur)
""")
demo.launch(share=False)