Spaces:
Running
Running
| import typing | |
| import types # fusion of forward() of Wav2Vec2 | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn as nn | |
| from transformers import Wav2Vec2Processor | |
| from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model | |
| from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2PreTrainedModel | |
| import audiofile | |
| import audresample | |
# Inference device: the first CUDA GPU (index 0) when available, else the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = "cpu"
# Only this many seconds of each input file are read.
duration = 2  # limit processing of audio
# Hugging Face Hub identifiers of the two models fused in this app.
age_gender_model_name = "audeering/wav2vec2-large-robust-24-ft-age-gender"
expression_model_name = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
class AgeGenderHead(nn.Module):
    r"""Age-gender model head.

    Applies dropout -> dense -> tanh -> dropout -> output projection to
    pooled backbone features.

    Args:
        config: model config; ``hidden_size`` and ``final_dropout`` are read.
        num_labels: width of the output projection.
    """

    def __init__(self, config, num_labels):
        super().__init__()
        width = config.hidden_size
        # These attribute names become the state-dict keys, so they must match
        # the pretrained checkpoint and cannot be renamed.
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, num_labels)

    def forward(self, features, **kwargs):
        """Project ``features`` to ``num_labels`` outputs; extra kwargs are ignored."""
        hidden = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))
class AgeGenderModel(Wav2Vec2PreTrainedModel):
    r"""Age-gender recognition model.

    A wav2vec2 backbone followed by an age head (1 output) and a gender head
    (3 outputs, softmax-normalised). ``forward`` expects ``self.wav2vec2`` to
    have been patched so that it consumes precomputed CNN features rather
    than raw audio.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wav2vec2 = Wav2Vec2Model(config)
        self.age = AgeGenderHead(config, 1)
        self.gender = AgeGenderHead(config, 3)
        self.init_weights()

    def forward(
        self,
        frozen_cnn7,
    ):
        """Predict age and gender from CNN features shared by another model.

        Args:
            frozen_cnn7: CNN feature-extractor output with dims 1 and 2 swapped
                (as returned by ``_forward_and_cnn7``); presumably
                (batch, frames, channels) — TODO confirm.

        Returns:
            ``(pooled, logits_age, logits_gender)``: the encoder output averaged
            over dim 1, the raw age head output, and gender probabilities.
        """
        # Pass positionally: the patched backbone forward names this parameter
        # ``extract_features``, so the keyword ``frozen_cnn7=`` raised TypeError.
        hidden_states = self.wav2vec2(frozen_cnn7)  # runs only Transformer layers
        hidden_states = torch.mean(hidden_states, dim=1)
        logits_age = self.age(hidden_states)
        logits_gender = torch.softmax(self.gender(hidden_states), dim=1)
        return hidden_states, logits_age, logits_gender
| # == Fusion = Define Age Wav2Vec2Model's forward to accept already computed CNN7 features from Emotion | |
| def _forward( | |
| self, | |
| extract_features, | |
| attention_mask=None): | |
| # extract_features : CNN7 fetures of wav2vec2 as they are calc. from CNN7 feature extractor | |
| if attention_mask is not None: | |
| # compute reduced attention_mask corresponding to feature vectors | |
| attention_mask = self._get_feature_vector_attention_mask( | |
| extract_features.shape[1], attention_mask, add_adapter=False | |
| ) | |
| hidden_states, extract_features = self.feature_projection(extract_features) | |
| hidden_states = self._mask_hidden_states( | |
| hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask | |
| ) | |
| encoder_outputs = self.encoder( | |
| hidden_states, | |
| attention_mask=attention_mask, | |
| output_attentions=output_attentions, | |
| output_hidden_states=output_hidden_states, | |
| return_dict=return_dict, | |
| ) | |
| hidden_states = encoder_outputs[0] | |
| if self.adapter is not None: | |
| raise ValueError | |
| hidden_states = self.adapter(hidden_states) | |
| return hidden_states | |
| # =============================================== | |
| # ================== Foward & CNN features | |
| def _forward_and_cnn7( | |
| self, | |
| input_values, | |
| attention_mask=None | |
| ): | |
| frozen_cnn7 = self.feature_extractor(input_values) | |
| frozen_cnn7 = frozen_cnn7.transpose(1, 2) | |
| if attention_mask is not None: | |
| # compute reduced attention_mask corresponding to feature vectors | |
| attention_mask = self._get_feature_vector_attention_mask( | |
| frozen_cnn7.shape[1], attention_mask, add_adapter=False | |
| ) | |
| hidden_states, extract_features = self.feature_projection(frozen_cnn7) # grad=True non frozen | |
| hidden_states = self._mask_hidden_states( | |
| hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask | |
| ) | |
| encoder_outputs = self.encoder( | |
| hidden_states, | |
| attention_mask=attention_mask, | |
| output_attentions=output_attentions, | |
| output_hidden_states=output_hidden_states, | |
| return_dict=return_dict, | |
| ) | |
| hidden_states = encoder_outputs[0] | |
| if self.adapter is not None: | |
| raise ValueError | |
| hidden_states = self.adapter(hidden_states) | |
| return hidden_states, frozen_cnn7 # feature_projection is trainable thus we are unable to use the projected hidden states from official wav2vev2.forward | |
| # ============================= | |
class ExpressionHead(nn.Module):
    r"""Expression model head.

    Dropout, a dense layer with tanh, dropout again, then a projection to
    ``config.num_labels`` outputs.
    """

    def __init__(self, config):
        super().__init__()
        size = config.hidden_size
        # Submodule names define the state-dict keys loaded from the checkpoint.
        self.dense = nn.Linear(size, size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(size, config.num_labels)

    def forward(self, features, **kwargs):
        """Return ``config.num_labels`` scores per row of ``features``; kwargs are unused."""
        out = self.dropout(features)
        out = torch.tanh(self.dense(out))
        out = self.dropout(out)
        return self.out_proj(out)
class ExpressionModel(Wav2Vec2PreTrainedModel):
    r"""Speech expression model.

    A wav2vec2 backbone plus an ``ExpressionHead`` classifier. ``forward``
    expects ``self.wav2vec2`` to have been patched so that it returns
    ``(hidden_states, cnn_features)``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Attribute names match the checkpoint's state-dict keys.
        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ExpressionHead(config)
        self.init_weights()

    def forward(self, input_values):
        """Return ``(pooled, logits, cnn_features)`` for a batch of processed audio."""
        encoded, cnn_features = self.wav2vec2(input_values)
        pooled = encoded.mean(dim=1)
        return pooled, self.classifier(pooled), cnn_features
# Load models from hub
age_gender_processor = Wav2Vec2Processor.from_pretrained(age_gender_model_name)
age_gender_model = AgeGenderModel.from_pretrained(age_gender_model_name)
expression_processor = Wav2Vec2Processor.from_pretrained(expression_model_name)
expression_model = ExpressionModel.from_pretrained(expression_model_name)
# process_func moves its input tensor to ``device``, so the weights must be
# there as well (otherwise CUDA hosts fail with a device mismatch).
age_gender_model.to(device).eval()
expression_model.to(device).eval()
# Fusion: patch the *backbones'* forward. The replacement functions use
# Wav2Vec2Model attributes (feature_extractor, feature_projection, encoder,
# adapter), which live on ``.wav2vec2``, not on the wrapper models.
age_gender_model.wav2vec2.forward = types.MethodType(
    _forward, age_gender_model.wav2vec2
)
# Expression backbone additionally returns its CNN features for reuse.
expression_model.wav2vec2.forward = types.MethodType(
    _forward_and_cnn7, expression_model.wav2vec2
)
def process_func(x: np.ndarray, sampling_rate: int) -> typing.Tuple[str, dict, str]:
    """Estimate age, gender and expression for one audio signal.

    Args:
        x: waveform passed straight to ``expression_processor``.
        sampling_rate: sampling rate of ``x`` in Hz.

    Returns:
        ``(age_text, gender_scores, expression_png_path)``: the age head output
        scaled by 100 and rounded, formatted as "<n> years"; softmax scores for
        female/male/child; and the path of the saved arousal/dominance/valence plot.
    """
    # batch audio: take the first processed item and add a batch axis of 1
    y = expression_processor(x, sampling_rate=sampling_rate)
    y = y['input_values'][0]
    y = y.reshape(1, -1)
    y = torch.from_numpy(y).to(device)
    # The expression model also returns its CNN features; the age-gender model
    # reuses them instead of running its own feature extractor.
    with torch.no_grad():
        _, logits_expression, frozen_cnn7 = expression_model(y)
        _, logits_age, logits_gender = age_gender_model(frozen_cnn7=frozen_cnn7)
    # Plot A/D/V values (.item() yields plain Python floats)
    plot_expression(logits_expression[0, 0].item(),
                    logits_expression[0, 1].item(),
                    logits_expression[0, 2].item())
    expression_file = "expression.png"
    figure = plt.gcf()  # the figure plot_expression just created
    figure.savefig(expression_file)
    # plot_expression opens a new figure on every call; close it so figures
    # do not accumulate across requests in this long-running app.
    plt.close(figure)
    return (
        f"{round(100 * logits_age[0, 0].item())} years",  # age
        {
            "female": logits_gender[0, 0].item(),
            "male": logits_gender[0, 1].item(),
            "child": logits_gender[0, 2].item(),
        },
        expression_file,
    )
def recognize(input_file: str) -> typing.Tuple[str, dict, str]:
    """Gradio callback: analyse the first ``duration`` seconds of ``input_file``.

    The audio is resampled to 16 kHz before being handed to ``process_func``.

    Raises:
        gr.Error: if no file was submitted.
    """
    if input_file is None:
        raise gr.Error(
            "No audio file submitted! "
            "Please upload or record an audio file "
            "before submitting your request."
        )
    target_rate = 16000  # sampling rate supported by the models
    signal, sampling_rate = audiofile.read(input_file, duration=duration)
    resampled = audresample.resample(signal, sampling_rate, target_rate)
    return process_func(resampled, target_rate)
def plot_expression(arousal, dominance, valence):
    r"""3D pixel plot of arousal, dominance, valence.

    Marks the (arousal, dominance, valence) point as one voxel on a 7-step
    grid per axis, with a lighter column beneath it down to valence 0.
    Values are clamped to [0, 1] first. The figure is created with
    ``plt.figure()`` and left open as the current figure; the caller saves and
    closes it.
    """
    # Voxels per dimension
    voxels = 7
    # A value outside [0, 1] would round to a cell outside the grid and the
    # point would silently vanish from the plot, so clamp before rounding.
    a, d, v = (
        round(min(max(value, 0.0), 1.0) * voxels)
        for value in (arousal, dominance, valence)
    )
    # Create voxel grid
    x, y, z = np.indices((voxels + 1, voxels + 1, voxels + 1))
    column = (x == a) & (y == d)
    voxel = column & (z == v)
    projection = column & (z < v)
    filled = voxel | projection
    colors = np.empty(filled.shape, dtype=object)
    colors[voxel] = "#fcb06c"
    colors[projection] = "#fed7a9"
    ax = plt.figure().add_subplot(projection='3d')
    ax.voxels(filled, facecolors=colors, edgecolor='k')
    ax.set_xlim([0, voxels])
    ax.set_ylim([0, voxels])
    ax.set_zlim([0, voxels])
    ax.set_aspect("equal")
    ax.set_xlabel("arousal", fontsize="large", labelpad=0)
    ax.set_ylabel("dominance", fontsize="large", labelpad=0)
    ax.set_zlabel("valence", fontsize="large", labelpad=0)
    # Label only the ends of each axis: 0 at the first tick, 1 at the last.
    tick_labels = [0] + [None] * (voxels - 1) + [1]
    for set_ticks, alignment in (
        (ax.set_xticks, "bottom"),
        (ax.set_yticks, "bottom"),
        (ax.set_zticks, "top"),
    ):
        set_ticks(
            list(range(voxels + 1)),
            labels=tick_labels,
            verticalalignment=alignment,
        )
# Markdown shown above the audio input; links each model to its Hub page.
description = "".join(
    [
        "Estimate **age**, **gender**, and **expression** ",
        "of the speaker contained in an audio file or microphone recording. \n",
        f"The model [{age_gender_model_name}]",
        f"(https://huggingface.co/{age_gender_model_name}) ",
        "recognises age and gender, ",
        f"whereas [{expression_model_name}]",
        f"(https://huggingface.co/{expression_model_name}) ",
        "recognises the expression dimensions arousal, dominance, and valence. ",
    ]
)
# Two-column UI: audio input and examples on the left, predictions on the right.
with gr.Blocks() as demo:
    with gr.Tab(label="Speech analysis"):
        with gr.Row():
            with gr.Column():
                gr.Markdown(description)
                # Named ``audio_input`` (not ``input``) to avoid shadowing the
                # builtin at module level.
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Audio input",
                    min_length=0.025,  # seconds
                )
                gr.Examples(
                    [
                        "female-46-neutral.wav",
                        "female-20-happy.wav",
                        "male-60-angry.wav",
                        "male-27-sad.wav",
                    ],
                    [audio_input],
                    label="Examples from CREMA-D, ODbL v1.0 license",
                )
                gr.Markdown("Only the first two seconds of the audio will be processed.")
                submit_btn = gr.Button(value="Submit")
            with gr.Column():
                output_age = gr.Textbox(label="Age")
                output_gender = gr.Label(label="Gender")
                output_expression = gr.Image(label="Expression")
        # Output order matches the tuple returned by recognize/process_func.
        outputs = [output_age, output_gender, output_expression]
        submit_btn.click(recognize, audio_input, outputs)
demo.launch(debug=True)