Commit 00613da
Parent(s): f217a11

Add demo inference

Files changed:
- README.md (+24, -1)
- app.py (+135, -0)
- requirements.txt (+5, -0)
README.md
CHANGED

@@ -11,4 +11,27 @@ license: apache-2.0
 short_description: Demo deployment for voice safety classifier
 ---
 
-
+# Voice Safety Classifier Demo
+
+This is a demo application for the [Roblox Voice Safety Classifier v2](https://huggingface.co/Roblox/voice-safety-classifier-v2) model.
+
+## Model Information
+
+The Voice Safety Classifier is designed to detect potentially unsafe content in audio. It classifies audio into six safety categories (Discrimination, Harassment, Sexual, IllegalAndRegulated, DatingAndRomantic, and Profanity) to help identify problematic content.
+
+## Usage
+
+1. Upload an audio file or record audio directly in your browser.
+2. The model processes the audio and returns classification results.
+3. Results are displayed with confidence scores for each category.
+
+## Technical Details
+
+This demo uses:
+- Hugging Face Transformers
+- Gradio for the web interface
+- PyTorch and Librosa for audio processing
+
+## License
+
+This demo uses the Roblox Voice Safety Classifier v2 model. Please refer to the [model card](https://huggingface.co/Roblox/voice-safety-classifier-v2) for licensing information.
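The Usage and Technical Details sections above describe the browser flow; the same inference path can also be reproduced programmatically. Below is a minimal sketch of that flow, assuming a local file named `sample.wav` (a placeholder, not part of this commit) and skipping the 15-second chunking that `app.py` applies to longer clips.

```python
# Minimal standalone sketch of the inference flow the README describes.
# "sample.wav" is a placeholder file name; chunking of long audio is
# omitted for brevity -- see feature_extract_simple in app.py below.
import librosa
import numpy as np
import torch
from transformers import WavLMForSequenceClassification

labels = [
    "Discrimination", "Harassment", "Sexual",
    "IllegalAndRegulated", "DatingAndRomantic", "Profanity",
]

model = WavLMForSequenceClassification.from_pretrained(
    "Roblox/voice-safety-classifier-v2", num_labels=len(labels)
)
model.eval()

# Load at 16 kHz and mean/variance-normalize, matching the demo's preprocessing.
signal, _ = librosa.load("sample.wav", sr=16_000)
signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)

with torch.no_grad():
    logits = model(torch.tensor(signal, dtype=torch.float32).unsqueeze(0)).logits
scores = torch.sigmoid(logits).squeeze(0)  # one multi-label score per category

for label, score in zip(labels, scores.tolist()):
    print(f"{label}: {score:.3f}")
```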
app.py
ADDED

@@ -0,0 +1,135 @@
+from typing import Any, Dict
+
+import gradio as gr
+import librosa
+import numpy as np
+import torch
+from transformers import WavLMForSequenceClassification
+
+
+def feature_extract_simple(
+    wav,
+    sr=16_000,
+    win_len=15.0,
+    win_stride=15.0,
+    do_normalize=False,
+) -> np.ndarray:
+    """Simple feature extraction for WavLM.
+
+    Parameters
+    ----------
+    wav : str or array-like
+        path to the wav file, or array-like
+    sr : int, optional
+        sample rate, by default 16_000
+    win_len : float, optional
+        window length in seconds, by default 15.0
+    win_stride : float, optional
+        window stride in seconds, by default 15.0
+    do_normalize : bool, optional
+        whether to normalize the input, by default False
+
+    Returns
+    -------
+    np.ndarray
+        batched input to WavLM
+    """
+    if isinstance(wav, str):
+        signal, _ = librosa.load(wav, sr=sr)
+    else:
+        try:
+            signal = np.array(wav).squeeze()
+        except Exception as e:
+            print(e)
+            raise RuntimeError
+
+    batched_input = []
+    stride = int(win_stride * sr)
+    chunk_len = int(win_len * sr)
+    if len(signal) / sr > win_len:
+        # Split long audio into fixed-length windows.
+        for i in range(0, len(signal), stride):
+            if i + chunk_len > len(signal):
+                # Pad the last chunk to the same length as the others.
+                chunked = np.pad(signal[i:], (0, chunk_len - len(signal[i:])))
+            else:
+                chunked = signal[i : i + chunk_len]
+            if do_normalize:
+                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
+            batched_input.append(chunked)
+            if i + chunk_len > len(signal):
+                break
+    else:
+        if do_normalize:
+            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
+        batched_input.append(signal)
+    return np.stack(batched_input)  # [N, T]
+
+
+def infer(model, inputs) -> torch.Tensor:
+    # Run the classifier without tracking gradients and map logits to
+    # per-label probabilities (multi-label head, hence sigmoid).
+    with torch.no_grad():
+        output = model(inputs)
+    return torch.sigmoid(output.logits)
+
+
+def predict(audio_file) -> Dict[str, Any]:
+    if audio_file is None:
+        return {"No prediction available": 0.0}
+
+    try:
+        input_np = feature_extract_simple(audio_file, sr=16000, do_normalize=True)
+        input_pt = torch.Tensor(input_np)
+
+        probs = infer(model, input_pt)
+        probs_list = probs.reshape(-1, len(labels)).detach().tolist()
+
+        # Create a results dictionary
+        if len(probs_list) > 0:
+            first_segment_probs = probs_list[0]
+            results = {
+                label: float(prob) for label, prob in zip(labels, first_segment_probs)
+            }
+
+            # Sort the numeric scores by confidence before appending the textual
+            # note below, so mixed str/float values never reach the sort.
+            results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
+
+            # If there are multiple segments, include that information in the results
+            if len(probs_list) > 1:
+                results["Note"] = (
+                    f"Audio contains {len(probs_list)} segments. Showing first segment only."
+                )
+        else:
+            results = {"Error": "No segments detected in audio"}
+
+        return results
+    except Exception as e:
+        return {"Error": str(e)}
+
+
+if __name__ == "__main__":
+    model_path = "Roblox/voice-safety-classifier-v2"
+    labels = [
+        "Discrimination",
+        "Harassment",
+        "Sexual",
+        "IllegalAndRegulated",
+        "DatingAndRomantic",
+        "Profanity",
+    ]
+
+    model = WavLMForSequenceClassification.from_pretrained(
+        model_path, num_labels=len(labels)
+    )
+    model.eval()
+
+    demo = gr.Interface(
+        fn=predict,
+        inputs=gr.Audio(type="filepath", label="Upload or record audio"),
+        outputs=gr.Label(num_top_classes=6, label="Classification Results"),
+        title="Voice Safety Classifier",
+        description="""This app uses the Roblox Voice Safety Classifier v2 model to identify potentially unsafe content in audio.
+        Upload or record an audio file to get started. The model classifies audio into categories including Discrimination,
+        Harassment, Sexual, IllegalAndRegulated, DatingAndRomantic, and Profanity.
+
+        The model processes audio in 15-second chunks and returns probability scores for each category.""",
+        examples=[],
+        flagging_mode="never",
+    )
+
+    demo.launch()
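Once app.py is running (locally with `python app.py`, or deployed as a Space), the Gradio endpoint it exposes can also be called outside the browser. A hedged sketch using the `gradio_client` package follows; the Space id and file path are placeholders, not part of this commit.

```python
# Hypothetical client-side call to the running demo via gradio_client.
# "your-username/voice-safety-demo" and "sample.wav" are placeholders.
from gradio_client import Client, handle_file

client = Client("your-username/voice-safety-demo")
result = client.predict(
    handle_file("sample.wav"),  # audio input for gr.Audio(type="filepath")
    api_name="/predict",        # default endpoint name for gr.Interface(fn=predict)
)
print(result)  # label/confidence payload produced by gr.Label
```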
requirements.txt
ADDED

@@ -0,0 +1,5 @@
+gradio>=5
+librosa>=0.10.0
+numpy>=1.24.0
+torch>=2.0.0
+transformers>=4.30.0