Spaces:
Paused
Paused
| import os, sys, re | |
| import shutil | |
| import subprocess | |
| import soundfile | |
| from process_audio import segment_audio | |
| from write_srt import write_to_file | |
| from clean_text import clean_english, clean_german, clean_spanish | |
| from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC | |
| from transformers import AutoModelForCTC, AutoProcessor | |
| import torch | |
| import gradio as gr | |
# Pretrained Wav2Vec2 checkpoints, one per supported language.
#
# NOTE: get_subs() resolves these objects dynamically through
# globals()[lang + '_tokenizer'] and globals()[lang + '_asr_model'], so the
# "<language>_tokenizer" / "<language>_asr_model" naming must be preserved
# when adding or renaming a language.
english_model = "facebook/wav2vec2-large-960h-lv60-self"
english_tokenizer = Wav2Vec2Processor.from_pretrained(english_model)
english_asr_model = Wav2Vec2ForCTC.from_pretrained(english_model)

german_model = "flozi00/wav2vec2-large-xlsr-53-german-with-lm"
german_tokenizer = Wav2Vec2Processor.from_pretrained(german_model)
german_asr_model = Wav2Vec2ForCTC.from_pretrained(german_model)

spanish_model = "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm"
spanish_tokenizer = Wav2Vec2Processor.from_pretrained(spanish_model)
spanish_asr_model = Wav2Vec2ForCTC.from_pretrained(spanish_model)

# Get German corpus and update nltk
# Run the downloader with the interpreter that is executing this script so the
# corpora land in the same environment; a bare "python" resolved from PATH may
# be a different installation (or missing entirely). The exit status is not
# checked, so a failed download does not stop the app from starting.
command = [sys.executable or "python", "-m", "textblob.download_corpora"]
subprocess.run(command)

# Line count for SRT file (running cue number, incremented by transcribe_audio)
line_count = 0
def sort_alphanumeric(data):
    """Return the strings in *data* as a new list in natural order.

    Every string is broken into alternating text and ASCII-digit runs. Digit
    runs are compared as integers and text runs case-insensitively, so
    ``"seg2"`` is placed before ``"seg10"``.
    """
    def natural_key(name):
        # The capturing group keeps the digit runs in the split result.
        pieces = []
        for token in re.split('([0-9]+)', name):
            pieces.append(int(token) if token.isdigit() else token.lower())
        return pieces

    ordered = list(data)
    ordered.sort(key=natural_key)
    return ordered
def transcribe_audio(tokenizer, asr_model, audio_file, file_handle):
    """Transcribe one audio segment and write it to the SRT file.

    Runs Wav2Vec2.0 inference on a single audio file generated after VAD
    segmentation. Transcripts of at most one character are skipped and do
    not consume a cue number.

    Args:
        tokenizer: Processor used to build the model input and to decode the
            predicted token ids.
        asr_model: CTC model whose output exposes ``.logits``.
        audio_file: Path to the segment. The text after the last ``_`` in the
            file name (extension removed) is split on ``-`` and forwarded to
            ``write_to_file`` -- presumably the segment's start/end times;
            verify against ``segment_audio``.
        file_handle: Open, writable handle of the SRT file.

    Side effects:
        Increments the module-level ``line_count`` for every written cue,
        prints the cleaned transcript, and reads the module-level ``lang``
        that ``get_subs`` sets before calling this function.
    """
    global line_count

    # The sample rate reported by the file is not used; the processor is told
    # the audio is 16 kHz, which is what get_subs asks ffmpeg to produce.
    speech, _rate = soundfile.read(audio_file)
    input_values = tokenizer(
        speech, sampling_rate=16000, return_tensors="pt", padding='longest'
    ).input_values

    # Inference only: disable autograd so no graph is recorded per segment.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    prediction = torch.argmax(logits, dim=-1)
    infered_text = tokenizer.batch_decode(prediction)[0].lower()

    if len(infered_text) <= 1:
        return

    cleaners = {
        'english': clean_english,
        'german': clean_german,
        'spanish': clean_spanish,
    }
    cleaner = cleaners.get(lang)
    if cleaner is not None:
        infered_text = cleaner(infered_text)
    print(infered_text)

    # Drop the directory and the 4-character extension (e.g. ".wav"), then
    # take the "<a>-<b>" suffix that follows the last underscore.
    segment_name = audio_file.split(os.sep)[-1][:-4]
    limits = segment_name.split("_")[-1].split("-")

    line_count += 1
    write_to_file(file_handle, infered_text, line_count, limits)
def get_subs(input_file, language):
    """Generate an SRT subtitle file for a video.

    Steps: extract a mono 16 kHz WAV track with ffmpeg into a scratch
    ``audio`` directory under the current working directory, split it with
    ``segment_audio``, transcribe every resulting segment in natural file
    order, and write the cues to ``<video name>.srt`` in the current working
    directory.

    Args:
        input_file: Path to the uploaded video file.
        language: Language name in any letter case (e.g. ``"English"``). Its
            lower-cased form must match the prefix of a module-level
            ``<language>_tokenizer`` / ``<language>_asr_model`` pair.

    Returns:
        The name of the SRT file that was written.

    Raises:
        ValueError: If no input file or language is given, or the language
            has no loaded model.
        subprocess.CalledProcessError: If ffmpeg exits with an error.
    """
    global lang, line_count

    if not input_file:
        raise ValueError("No video file was provided.")
    if not language:
        raise ValueError("No language was selected.")

    # transcribe_audio reads ``lang`` from module scope to pick its cleaner.
    lang = language.lower()
    try:
        tokenizer = globals()[lang + '_tokenizer']
        asr_model = globals()[lang + '_asr_model']
    except KeyError as err:
        raise ValueError(f"Unsupported language: {language!r}") from err

    # Every subtitle file numbers its cues from 1.
    line_count = 0

    # Get a fresh directory for audio
    audio_directory = os.path.join(os.getcwd(), "audio")
    if os.path.isdir(audio_directory):
        shutil.rmtree(audio_directory)
    os.mkdir(audio_directory)

    try:
        # Extract audio from video file
        audio_file = os.path.join(audio_directory, "temp.wav")
        command = ["ffmpeg", "-i", input_file, "-ac", "1", "-ar", "16000",
                   "-vn", "-f", "wav", audio_file]
        # Fail here, with ffmpeg's exit status, rather than later on a
        # missing WAV file.
        subprocess.run(command, check=True)

        # Split audio file based on VAD silent segments
        segment_audio(audio_file)
        os.remove(audio_file)

        # Output SRT file; splitext copes with extensions of any length.
        video_name = os.path.splitext(os.path.basename(input_file))[0]
        srt_file_name = video_name + ".srt"

        temp_name = os.path.basename(audio_file)
        # "w" so that a previous run's output for the same video name is
        # replaced instead of appended to.
        with open(srt_file_name, "w", encoding="utf-8") as file_handle:
            for file in sort_alphanumeric(os.listdir(audio_directory)):
                if file == temp_name:
                    continue
                audio_segment_path = os.path.join(audio_directory, file)
                transcribe_audio(tokenizer, asr_model, audio_segment_path, file_handle)
    finally:
        # Remove the scratch directory even when a step above failed.
        shutil.rmtree(audio_directory, ignore_errors=True)

    return srt_file_name
# Input widgets, listed in the order of get_subs' positional parameters
# (input_file, language), followed by the downloadable result.
video_input = gr.inputs.Video(label="Upload Video File")
language_input = gr.inputs.Radio(label="Choose Language", choices=['English', 'German', 'Spanish'])
transcript_output = gr.outputs.File(label="Auto-Transcript")

# Web UI wiring the widgets to get_subs, with request queueing enabled.
gradio_ui = gr.Interface(
    fn=get_subs,
    inputs=[video_input, language_input],
    outputs=transcript_output,
    title="Video to Subtitle",
    description="Get subtitles (SRT file) for your videos. Inference speed is about 10s/per 1min of video BUT the speed of uploading your video depends on your internet connection.",
    enable_queue=True,
)
gradio_ui.launch()