import torch
import torchaudio
from torch import nn
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification, pipeline
# Preprocessing the data
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
max_duration = 2.0  # seconds
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
softmax = nn.Softmax(dim=-1)  # explicit dim avoids the implicit-dimension warning; logits are (batch, num_labels)
label2id, id2label = dict(), dict()
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
num_labels = 10
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label
def get_pipeline(model_name):
    if model_name.split('-')[-1].strip() != 'ibo':
        return None
    return pipeline(task="audio-classification", model=model_name)
def load_model(model_checkpoint):
    #if model_checkpoint.split('-')[-1].strip()!='ibo': #This is for DEBUGGING
    #    return None, None
    # construct model and assign it to device
    model = AutoModelForAudioClassification.from_pretrained(
        model_checkpoint,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
    ).to(device)
    return model
language_dict = {
    "Igbo": 'ibo',
    "Oshiwambo": 'kua',
    "Yoruba": 'yor',
    "Oromo": 'gax',
    "Shona": 'sna',
    "Rundi": 'run',
    "Choose language": 'none',
    "MULTILINGUAL": 'all'
}
AUDIO_CLASSIFICATION_MODELS = {
    'ibo': load_model('chrisjay/afrospeech-wav2vec-ibo'),
    'kua': load_model('chrisjay/afrospeech-wav2vec-kua'),
    'sna': load_model('chrisjay/afrospeech-wav2vec-sna'),
    'yor': load_model('chrisjay/afrospeech-wav2vec-yor'),
    'gax': load_model('chrisjay/afrospeech-wav2vec-gax'),
    'run': load_model('chrisjay/afrospeech-wav2vec-run'),
    'all': load_model('chrisjay/afrospeech-wav2vec-all-6'),
}
def cut_if_necessary(signal, num_samples):
    # Truncate the signal to at most num_samples samples.
    if signal.shape[1] > num_samples:
        signal = signal[:, :num_samples]
    return signal

def right_pad_if_necessary(signal, num_samples):
    # Zero-pad the signal on the right up to num_samples samples.
    length_signal = signal.shape[1]
    if length_signal < num_samples:
        num_missing_samples = num_samples - length_signal
        last_dim_padding = (0, num_missing_samples)
        signal = torch.nn.functional.pad(signal, last_dim_padding)
    return signal

def resample_if_necessary(signal, sr, target_sample_rate, device):
    # Resample to target_sample_rate if the source sample rate differs.
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        signal = resampler(signal)
    return signal

def mix_down_if_necessary(signal):
    # Average multi-channel audio down to mono.
    if signal.shape[0] > 1:
        signal = torch.mean(signal, dim=0, keepdim=True)
    return signal
def preprocess_audio(waveform, sample_rate, feature_extractor):
    # Resample to 16 kHz, mix down to mono, and fix the length to 16000 samples (one second).
    waveform = resample_if_necessary(waveform, sample_rate, 16000, device)
    waveform = mix_down_if_necessary(waveform)
    waveform = cut_if_necessary(waveform, 16000)
    waveform = right_pad_if_necessary(waveform, 16000)
    transformed = feature_extractor(
        waveform,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=16000,
        truncation=True,
    )
    return transformed
def make_inference(drop_down, audio):
    # Load and preprocess the recording, then classify it with the model for the chosen language.
    waveform, sample_rate = torchaudio.load(audio)
    preprocessed_audio = preprocess_audio(waveform, sample_rate, feature_extractor)
    language_code_chosen = language_dict[drop_down]
    model = AUDIO_CLASSIFICATION_MODELS[language_code_chosen]
    model.eval()
    torch_preprocessed_audio = torch.from_numpy(preprocessed_audio.input_values[0]).to(device)
    # make prediction
    with torch.no_grad():
        prediction = softmax(model(torch_preprocessed_audio).logits)
    sorted_prediction = torch.sort(prediction, descending=True)
    confidences = {}
    for s, v in zip(sorted_prediction.indices.cpu().numpy().tolist()[0],
                    sorted_prediction.values.cpu().numpy().tolist()[0]):
        confidences[s] = v
    return confidences
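
# The block below is a minimal sketch, not part of the original file, of how make_inference
# could be wired into a Gradio interface for this Space: the dropdown feeds language_dict,
# the audio component supplies a file path for torchaudio.load, and the returned confidences
# dict maps onto a gr.Label output. Component choices and labels here are assumptions.
import gradio as gr

demo = gr.Interface(
    fn=make_inference,
    inputs=[
        gr.Dropdown(choices=list(language_dict.keys()), label="Language"),
        gr.Audio(type="filepath", label="Spoken digit"),
    ],
    outputs=gr.Label(num_top_classes=10, label="Predicted digit"),
)

if __name__ == "__main__":
    demo.launch()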