Spaces:
Runtime error
Runtime error
import csv
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from IPython.display import Audio
from scipy.io import wavfile
# Fetch the YAMNet model from TensorFlow Hub once, at import time.
model = hub.load(handle='https://tfhub.dev/google/yamnet/1')
# Helper that maps positions in the model's score vector to readable names.
def class_names_from_csv(class_map_csv_text):
    """Read the class-map CSV and return its display names in row order.

    Args:
        class_map_csv_text: Location of the class-map CSV. Despite the name
            it is used as a path and handed straight to
            ``tf.io.gfile.GFile``.

    Returns:
        A list holding the ``display_name`` column of every row, in file
        order, so it can be indexed with a position in the score vector.
    """
    with tf.io.gfile.GFile(class_map_csv_text) as csv_handle:
        return [record['display_name'] for record in csv.DictReader(csv_handle)]
# Ask the model where its class-map CSV lives, then load the label list once
# so that later lookups are a plain list index.
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_csv_text=class_map_path)
def ensure_sample_rate(original_sample_rate, waveform,
                       desired_sample_rate=16000):
    """Resample ``waveform`` to ``desired_sample_rate`` when the rates differ.

    Args:
        original_sample_rate: Sample rate of ``waveform``.
        waveform: Array of samples; resampling is applied along the first
            axis.
        desired_sample_rate: Target sample rate (default 16000).

    Returns:
        A ``(desired_sample_rate, waveform)`` tuple. The waveform object is
        returned untouched when the rates already match.
    """
    if original_sample_rate == desired_sample_rate:
        return desired_sample_rate, waveform

    # Imported here because the module only imports ``scipy.io.wavfile``;
    # referring to ``scipy.signal`` without this import raised NameError the
    # first time a clip actually needed resampling.
    from scipy.signal import resample

    desired_length = int(round(
        float(len(waveform)) / original_sample_rate * desired_sample_rate))
    return desired_sample_rate, resample(waveform, desired_length)
# Download the sample clip next to the app. Skip the fetch when the file is
# already present: a plain ``wget`` would otherwise hit the network on every
# start and save a duplicate copy (``miaow_16k.wav.1``, ``.2``, ...).
if not os.path.exists("miaow_16k.wav"):
    os.system("wget https://storage.googleapis.com/audioset/miaow_16k.wav")
def _to_mono_float32(samples):
    """Return ``samples`` as a 1-D float32 array scaled to about [-1.0, 1.0]."""
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.floating):
        # Float PCM is already expressed on the [-1.0, 1.0] scale.
        scaled = samples.astype(np.float32)
    elif np.issubdtype(samples.dtype, np.unsignedinteger):
        # Unsigned PCM (8-bit WAV) is centred on the midpoint, not on zero.
        midpoint = (np.iinfo(samples.dtype).max + 1) / 2
        scaled = (samples.astype(np.float32) - midpoint) / midpoint
    else:
        # Signed PCM: divide by the dtype's own full-scale value, which for
        # int16 is the same 32767 the code previously hard-coded.
        scaled = samples.astype(np.float32) / np.iinfo(samples.dtype).max
    if scaled.ndim > 1:
        # Layout is (samples, channels); average the channels down to mono
        # because the model is called with a single 1-D waveform.
        scaled = scaled.mean(axis=1)
    return scaled


def inference(audio):
    """Name the dominant sound class in an audio clip.

    Args:
        audio: Either a path to a WAV file, or a ``(sample_rate, samples)``
            pair such as the one Gradio's Audio input passes in numpy mode.
            ``None`` (nothing submitted) is tolerated.

    Returns:
        A sentence naming the class whose score, averaged over all frames,
        is highest; or a short notice when no audio was supplied.
    """
    if audio is None:
        return 'No audio provided.'

    if isinstance(audio, (tuple, list)):
        sample_rate, wav_data = audio
    else:
        # The second positional parameter of ``wavfile.read`` is ``mmap``,
        # not a file mode, so the former ``'rb'`` argument is dropped.
        sample_rate, wav_data = wavfile.read(audio)

    # Normalise before resampling: resampling yields floats, after which the
    # original integer width (and therefore the right scale) is lost.
    waveform = _to_mono_float32(wav_data)
    sample_rate, waveform = ensure_sample_rate(sample_rate, waveform)
    waveform = np.asarray(waveform, dtype=np.float32)

    # Only the per-frame class scores are needed for the answer.
    scores, _embeddings, _spectrogram = model(waveform)
    top_index = scores.numpy().mean(axis=0).argmax()
    return f'The main sound is: {class_names[top_index]}'
# Serve ``inference`` through a Gradio UI: one audio input, one text output.
demo = gr.Interface(fn=inference, inputs="audio", outputs="text")
demo.launch()