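"""Gradio demo for audio deepfake detection.

Wraps Wave2Vec2BERT and HuBERT speech foundation models (fine-tuned on the
WaveFake dataset) behind a simple upload-and-analyze interface.
"""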
import gradio as gr
import torch
import numpy as np
import librosa
import os
from transformers import Wav2Vec2BertModel, AutoFeatureExtractor, HubertModel
import torch.nn as nn
from typing import Optional, Tuple
from transformers.file_utils import ModelOutput
from dataclasses import dataclass
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model,
)
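# Note: Wav2Vec2PreTrainedModel and Wav2Vec2Model are imported for parity with the
# original training code but are not used directly in this script.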
class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
class Wav2Vec2ForSpeechClassification(nn.Module):
    """Wave2Vec2BERT encoder with a pooled binary classification head."""

    def __init__(self, model_name):
        super().__init__()
        self.num_labels = 2
        self.pooling_mode = 'mean'
        self.wav2vec2bert = Wav2Vec2BertModel.from_pretrained(model_name)
        self.config = self.wav2vec2bert.config
        self.classifier = Wav2Vec2ClassificationHead(self.wav2vec2bert.config)

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                "Unsupported pooling mode; it must be one of ['mean', 'sum', 'max'].")
        return outputs
    def forward(self, input_features, attention_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None, labels=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )
class HuBERT(nn.Module):
    """HuBERT encoder with a pooled binary classification head."""

    def __init__(self, model_name):
        super().__init__()
        self.num_labels = 2
        self.pooling_mode = 'mean'
        # Attribute name kept as `wav2vec2` so the released checkpoint's state-dict keys still match.
        self.wav2vec2 = HubertModel.from_pretrained(model_name)
        self.config = self.wav2vec2.config
        self.classifier = Wav2Vec2ClassificationHead(self.wav2vec2.config)

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                "Unsupported pooling mode; it must be one of ['mean', 'sum', 'max'].")
        return outputs

    def forward(self, input_values, attention_mask=None, output_attentions=None,
                output_hidden_states=None, return_dict=None, labels=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs.last_hidden_state
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )
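# Both wrappers share the same mean-pooling + classification head; they differ only in the
# encoder and its expected input (log-mel `input_features` for Wave2Vec2BERT vs. raw-waveform
# `input_values` for HuBERT).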
def pad(x, max_len=64000):
    """Random-crop waveforms longer than max_len samples; zero-pad shorter ones."""
    x_len = x.shape[0]
    if x_len > max_len:
        stt = np.random.randint(x_len - max_len)
        return x[stt:stt + max_len]
        # return x[:max_len]
        # num_repeats = int(max_len / x_len) + 1
        # padded_x = np.tile(x, (num_repeats))[:max_len]
    pad_length = max_len - x_len
    padded_x = np.concatenate([x, np.zeros(pad_length)], axis=0)
    return padded_x
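# Illustrative behavior (not part of the app): a 1 s clip at 16 kHz is zero-padded to the
# fixed 4 s window, while longer clips get a random 4 s crop.
#   pad(np.zeros(16000)).shape   -> (64000,)
#   pad(np.zeros(100000)).shape  -> (64000,)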
class AudioDeepfakeDetector:
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.models = {}
        self.feature_extractors = {}
        self.current_model = None
        # model_name = 'facebook/w2v-bert-2.0'
        # self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        # self.model = Wav2Vec2ForSpeechClassification(model_name).to(self.device)
        # ckpt = torch.load("wave2vec2bert_wavefake.pth", map_location=self.device)
        # self.model.load_state_dict(ckpt)
        print(f"Using device: {self.device}")
        print("Audio deepfake detector initialized")
    def load_model(self, model_type):
        """Load the specified model type, caching it for reuse."""
        if model_type in self.models:
            self.current_model = model_type
            return
        try:
            print(f"🔄 Loading {model_type} model...")
            if model_type == "Wave2Vec2BERT":
                model_name = 'facebook/w2v-bert-2.0'
                self.feature_extractors[model_type] = AutoFeatureExtractor.from_pretrained(model_name)
                self.models[model_type] = Wav2Vec2ForSpeechClassification(model_name).to(self.device)
                # checkpoint_path = "wave2vec2bert_wavefake.pth"
                # if os.path.exists(checkpoint_path):
                #     ckpt = torch.load(checkpoint_path, map_location=self.device)
                #     self.models[model_type].load_state_dict(ckpt)
                #     print(f"✅ Loaded checkpoint for {model_type}")
                # else:
                #     print(f"⚠️ Checkpoint not found for {model_type}, using pretrained weights only")
                try:
                    from huggingface_hub import hf_hub_download
                    checkpoint_path = hf_hub_download(
                        repo_id="TrustSafeAI/AudioDeepfakeDetectors",
                        filename="wave2vec2bert_wavefake.pth",
                        cache_dir="./models"
                    )
                    ckpt = torch.load(checkpoint_path, map_location=self.device)
                    self.models[model_type].load_state_dict(ckpt)
                    print(f"✅ Loaded checkpoint for {model_type}")
                except Exception as e:
                    print(f"⚠️ Could not load checkpoint for {model_type}: {e}")
                    print("Using pretrained weights only")
            elif model_type == "HuBERT":
                model_name = 'facebook/hubert-large-ls960-ft'
                self.feature_extractors[model_type] = AutoFeatureExtractor.from_pretrained(model_name)
                self.models[model_type] = HuBERT(model_name).to(self.device)
                # checkpoint_path = "hubert_large_wavefake.pth"
                # if os.path.exists(checkpoint_path):
                #     ckpt = torch.load(checkpoint_path, map_location=self.device)
                #     self.models[model_type].load_state_dict(ckpt)
                #     print(f"✅ Loaded checkpoint for {model_type}")
                # else:
                #     print(f"⚠️ Checkpoint not found for {model_type}, using pretrained weights only")
                try:
                    from huggingface_hub import hf_hub_download
                    checkpoint_path = hf_hub_download(
                        repo_id="TrustSafeAI/AudioDeepfakeDetectors",  # replace with your model repo
                        filename="hubert_large_wavefake.pth",
                        cache_dir="./models"
                    )
                    ckpt = torch.load(checkpoint_path, map_location=self.device)
                    self.models[model_type].load_state_dict(ckpt)
                    print(f"✅ Loaded checkpoint for {model_type}")
                except Exception as e:
                    print(f"⚠️ Could not load checkpoint for {model_type}: {e}")
                    print("Using pretrained weights only")

            if model_type in self.models:
                self.models[model_type].eval()  # inference only: disable dropout
            self.current_model = model_type
            print(f"✅ {model_type} model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading {model_type} model: {str(e)}")
            raise
    def preprocess_audio(self, audio_path, target_sr=16000, max_length=4):
        """Load audio at 16 kHz and crop/pad it to the fixed-length window expected by the models."""
        try:
            print(f"🎵 Loading audio file: {os.path.basename(audio_path)}")
            audio, sr = librosa.load(audio_path, sr=target_sr)
            original_duration = len(audio) / sr
            audio = pad(audio).reshape(-1)
            audio = audio[np.newaxis, :]
            print(f"✅ Audio loaded successfully: {original_duration:.2f}s, {sr}Hz")
            return audio, sr
        except Exception as e:
            print(f"❌ Audio processing error: {str(e)}")
            raise
    def extract_features(self, audio, sr, model_type):
        """Run the model-specific feature extractor on the padded waveform."""
        print("🔍 Extracting audio features...")
        feature_extractor = self.feature_extractors[model_type]
        inputs = feature_extractor(audio, sampling_rate=sr, return_attention_mask=True,
                                   padding_value=0, return_tensors="pt").to(self.device)
        print("✅ Feature extraction completed")
        return inputs
    def classifier(self, features, model_type):
        """Return the fake probability (the code treats class index 0 as the fake label)."""
        model = self.models[model_type]
        with torch.no_grad():
            outputs = model(**features)
            prob = outputs.logits.softmax(dim=-1)
            fake_prob = prob[0][0].item()
        return fake_prob
    def predict(self, audio_path, model_type):
        try:
            print("🎵 Start analyzing...")
            self.load_model(model_type)
            audio, sr = self.preprocess_audio(audio_path)
            features = self.extract_features(audio, sr, model_type)
            fake_probability = self.classifier(features, model_type)
            real_probability = 1 - fake_probability

            threshold = 0.5
            if fake_probability > threshold:
                status = "SUSPICIOUS"
                prediction = "🚨 Likely fake audio"
                confidence = fake_probability
                color = "red"
            else:
                status = "LIKELY_REAL"
                prediction = "✅ Likely real audio"
                confidence = real_probability
                color = "green"

            print(f"\n{'='*50}")
            print(f"🎯 Result: {prediction}")
            print(f"📊 Confidence: {confidence:.1%}")
            print(f"📊 Real Probability: {real_probability:.1%}")
            print(f"📊 Fake Probability: {fake_probability:.1%}")
            print(f"{'='*50}")

            duration = audio.shape[-1] / sr  # audio is (1, num_samples) after preprocessing
            file_size = os.path.getsize(audio_path) / 1024
            result_data = {
                "status": status,
                "prediction": prediction,
                "confidence": confidence,
                "real_probability": real_probability,
                "fake_probability": fake_probability,
                "duration": duration,
                "sample_rate": sr,
                "file_size_kb": file_size,
                "model_used": model_type
            }
            return result_data
        except Exception as e:
            print(f"❌ Failed: {str(e)}")
            return {"error": str(e)}
detector = AudioDeepfakeDetector()
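# Minimal programmatic usage sketch (assumes a local file named "sample.wav" exists):
#
#   result = detector.predict("sample.wav", "Wave2Vec2BERT")
#   print(result["fake_probability"])
#
# The same predict() call is what the Gradio callbacks below invoke.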
def analyze_uploaded_audio(audio_file, model_choice):
    """Gradio callback: run detection and format the result as HTML plus a JSON summary."""
    if audio_file is None:
        return "Please upload audio", {}
    try:
        result = detector.predict(audio_file, model_choice)
        if "error" in result:
            return f"Error: {result['error']}", {}
        status_color = "#ff4444" if result['status'] == "SUSPICIOUS" else "#44ff44"
        result_html = f"""
        <div style="padding: 20px; border-radius: 10px; background-color: {status_color}20; border: 2px solid {status_color};">
            <h3 style="color: {status_color}; margin-top: 0;">{result['prediction']}</h3>
            <p><strong>Status:</strong> {result['status']}</p>
            <p><strong>Confidence:</strong> {result['confidence']:.1%}</p>
        </div>
        """
        analysis_data = {
            "status": result['status'],
            "real_probability": f"{result['real_probability']:.1%}",
            "fake_probability": f"{result['fake_probability']:.1%}",
        }
        return result_html, analysis_data
    except Exception as e:
        error_html = f"""
        <div style="padding: 20px; border-radius: 10px; background-color: #ff444420; border: 2px solid #ff4444;">
            <h3 style="color: #ff4444;">❌ Processing error</h3>
            <p>{str(e)}</p>
        </div>
        """
        return error_html, {"error": str(e)}
def create_audio_interface():
    with gr.Blocks(title="Audio Deepfake Detection", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        <div style="text-align: center; margin-bottom: 30px;">
            <h1 style="font-size: 28px; font-weight: bold; margin-bottom: 20px; color: #333;">
                Measuring the Robustness of Audio Deepfake Detection under Real-World Corruptions
            </h1>
            <p style="font-size: 16px; color: #666; margin-bottom: 15px;">
                Audio deepfake detectors based on Wave2Vec2BERT and HuBERT speech foundation models (fine-tuned on the WaveFake dataset).
            </p>
            <div style="font-size: 14px; color: #555; line-height: 1.8; text-align: left;">
                <p><strong>Paper:</strong> <a href="https://arxiv.org/pdf/2503.17577" target="_blank" style="color: #4285f4; text-decoration: none;">https://arxiv.org/pdf/2503.17577</a></p>
                <p><strong>Project Page:</strong> <a href="https://huggingface.co/spaces/TrustSafeAI/AudioPerturber" target="_blank" style="color: #4285f4; text-decoration: none;">https://huggingface.co/spaces/TrustSafeAI/AudioPerturber</a></p>
                <p><strong>Model Checkpoints:</strong> <a href="https://huggingface.co/TrustSafeAI/AudioDeepfakeDetectors" target="_blank" style="color: #4285f4; text-decoration: none;">https://huggingface.co/TrustSafeAI/AudioDeepfakeDetectors</a></p>
                <p><strong>GitHub Codebase:</strong> <a href="https://github.com/Jessegator/Audio_robustness_evaluation" target="_blank" style="color: #4285f4; text-decoration: none;">https://github.com/Jessegator/Audio_robustness_evaluation</a></p>
            </div>
        </div>
        <hr style="margin: 30px 0; border: none; border-top: 1px solid #e0e0e0;">
        """)
        gr.Markdown("""
        # Audio Deepfake Detection
        **Supported formats**: .wav, .mp3, .flac, .m4a, etc.
        """)
        with gr.Row():
            # model_choice = gr.Dropdown(
            #     choices=["Wave2Vec2BERT", "HuBERT"],
            #     value="Wave2Vec2BERT",
            #     label="🎤 Select Model",
            #     info="Choose the foundation model for detection"
            # )
            with gr.Column(scale=1):
                model_choice = gr.Dropdown(
                    choices=["Wave2Vec2BERT", "HuBERT"],
                    value="Wave2Vec2BERT",
                    label="🎤 Select Model",
                    info="Choose the foundation model for detection"
                )
                audio_input = gr.Audio(
                    label="📁 Upload audio file",
                    type="filepath",
                    show_label=True,
                    interactive=True
                )
                analyze_btn = gr.Button(
                    "🔍 Start analyzing",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown("### 🎵 Play uploaded audio")
                audio_player = gr.Audio(
                    label="Audio Player",
                    interactive=False,
                    show_label=False
                )
            with gr.Column(scale=1):
                result_display = gr.HTML(
                    label="🎯 Results",
                    value="<p style='text-align: center; color: #666;'>Waiting for upload...</p>"
                )
                analysis_json = gr.JSON(
                    label="📊 Detailed analysis",
                    value={}
                )
        def update_player_and_analyze(audio_file, model_type):
            if audio_file is not None:
                result_html, result_data = analyze_uploaded_audio(audio_file, model_type)
                return audio_file, result_html, result_data
            else:
                return None, "<p style='text-align: center; color: #666;'>Waiting for upload...</p>", {}

        audio_input.change(
            fn=update_player_and_analyze,
            inputs=[audio_input, model_choice],
            outputs=[audio_player, result_display, analysis_json]
        )
        analyze_btn.click(
            fn=analyze_uploaded_audio,
            inputs=[audio_input, model_choice],
            outputs=[result_display, analysis_json]
        )
        model_choice.change(
            fn=lambda audio_file, model_type: analyze_uploaded_audio(audio_file, model_type) if audio_file is not None else ("Please upload audio first", {}),
            inputs=[audio_input, model_choice],
            outputs=[result_display, analysis_json]
        )
    return interface
if __name__ == "__main__":
    print("🚀 Creating interface...")
    demo = create_audio_interface()
    print("📱 Launching...")
    demo.launch(
        share=False,
        debug=True,
        show_error=True
    )