# Hugging Face Space: Video-to-Audio Generation with Hidden Alignment (runs on ZeroGPU)
import spaces
import gradio as gr
import huggingface_hub
import os
import subprocess
import threading
import shutil
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from moviepy.editor import VideoFileClip, AudioFileClip

# Download the pre-trained VTA-LDM model checkpoint
huggingface_hub.snapshot_download(
    repo_id='ariesssxu/vta-ldm-clip4clip-v-large',
    local_dir='./ckpt/vta-ldm-clip4clip-v-large'
)

def stream_output(pipe):
    # Drain the subprocess pipe line by line so it cannot fill up and block the child process
    for line in iter(pipe.readline, ''):
        print(line, end='')

def print_directory_contents(path):
    for root, dirs, files in os.walk(path):
        level = root.replace(path, '').count(os.sep)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print(f"{subindent}{f}")

# Print the ckpt directory contents
print_directory_contents('./ckpt')

def get_wav_files(path):
    wav_files = []  # Initialize an empty list to store the paths of .wav files
    for root, dirs, files in os.walk(path):
        level = root.replace(path, '').count(os.sep)
        indent = ' ' * 4 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            file_path = os.path.join(root, f)
            if f.lower().endswith('.wav'):
                wav_files.append(file_path)  # Add .wav file paths to the list
                print(f"{subindent}{file_path}")
            else:
                print(f"{subindent}{f}")
    return wav_files  # Return the list of .wav file paths

def check_outputs_folder(folder_path):
    # Check if the folder exists
    if os.path.exists(folder_path) and os.path.isdir(folder_path):
        # Delete all contents inside the folder
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)  # Remove file or symlink
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # Remove directory
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        print(f'The folder {folder_path} does not exist.')

def plot_spectrogram(wav_file, output_image):
    # Read the WAV file
    sample_rate, audio_data = wavfile.read(wav_file)

    # If the audio is stereo (2 channels), average it down to mono
    if len(audio_data.shape) == 2:
        audio_data = audio_data.mean(axis=1)

    # Create a plot for the spectrogram
    plt.figure(figsize=(10, 2))
    plt.specgram(audio_data, Fs=sample_rate, NFFT=1024, noverlap=512, cmap='gray', aspect='auto')

    # Remove gridlines and ticks for a cleaner look
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    # Save the plot as an image file
    plt.savefig(output_image, bbox_inches='tight', pad_inches=0, dpi=300)
    plt.close()  # Free the figure so repeated calls don't leak memory
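
# Usage sketch (the path below is hypothetical; the real WAV name is produced
# by the inference script at runtime):
# plot_spectrogram('./outputs/tmp/sample.wav', 'spectrogram.png')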

def merge_audio_to_video(input_vid, input_aud):
    # Load the video file
    video = VideoFileClip(input_vid)

    # Load the new audio file
    new_audio = AudioFileClip(input_aud)

    # Set the new audio track on the video
    video_with_new_audio = video.set_audio(new_audio)

    # Save the result to a new file
    video_with_new_audio.write_videofile("output_video.mp4", codec='libx264', audio_codec='aac')
    return "output_video.mp4"
def infer(video_in):
    """Generate an audio track from a silent video using a pre-trained VTA (Video-to-Audio) model.

    This function performs the following steps:
    1. Ensures the output directory is clean.
    2. Optionally trims the video to a maximum of 10 seconds.
    3. Runs inference using a pre-trained latent diffusion model to generate audio.
    4. Finds the generated WAV audio output.
    5. Plots a spectrogram of the generated audio.
    6. Merges the audio back into the input video.

    Args:
        video_in (str): The file path to the input silent video (MP4 format).
            If the video is longer than 10 seconds, it will be trimmed.

    Returns:
        Tuple[str, str, str]:
            - The path to the generated `.wav` audio file.
            - The path to the generated spectrogram `.png` image.
            - The path to the final `.mp4` video with the generated audio merged in.

    Example:
        Given a silent video of a lion, this function will return:
        - a realistic generated audio track simulating the lion's sound,
        - a visual spectrogram representation of the audio,
        - and a new video file where the generated audio is synced to the original visuals.
    """
    # Check if the 'outputs' dir exists and empty it if necessary
    check_outputs_folder('./outputs/tmp')

    # The inference script scans a folder, so we need the directory that
    # holds the Gradio temp video
    print(f"VIDEO IN PATH: {video_in}")
    folder_path = os.path.dirname(video_in)

    # Path to the input video file
    input_video_path = video_in

    # Load the video file
    video = VideoFileClip(input_video_path)

    # Get the length of the video in seconds
    video_duration = int(video.duration)
    print(f"Video duration: {video_duration} seconds")

    # Check if the video duration is more than 10 seconds
    if video_duration > 10:
        # Cut the video to the first 10 seconds
        cut_video = video.subclip(0, 10)
        video_duration = 10

        # Extract the directory and filename
        dir_name = os.path.dirname(input_video_path)
        base_name = os.path.basename(input_video_path)

        # Generate the new filename and save the cut video
        new_base_name = base_name.replace(".mp4", "_10sec_cut.mp4")
        output_video_path = os.path.join(dir_name, new_base_name)
        cut_video.write_videofile(output_video_path, codec='libx264', audio_codec='aac')
        print(f"Cut video saved as: {output_video_path}")
        video_in = output_video_path

        # Delete the original video so the inference script only sees the cut one
        os.remove(input_video_path)
        print(f"Original video file {input_video_path} deleted.")
    else:
        print("Video is 10 seconds or shorter; no cutting needed.")
    # Execute the inference command on the folder containing the (possibly cut) video
    command = ['python', 'inference_from_video.py',
               '--original_args', 'ckpt/vta-ldm-clip4clip-v-large/summary.jsonl',
               '--model', 'ckpt/vta-ldm-clip4clip-v-large/pytorch_model_2.bin',
               '--data_path', folder_path,
               '--max_duration', f"{video_duration}"]
    process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1)

    # Stream stdout and stderr on separate threads so neither pipe fills up and deadlocks
    stdout_thread = threading.Thread(target=stream_output, args=(process.stdout,))
    stderr_thread = threading.Thread(target=stream_output, args=(process.stderr,))
    stdout_thread.start()
    stderr_thread.start()

    # Wait for the process to complete and the threads to finish
    process.wait()
    stdout_thread.join()
    stderr_thread.join()
    print("Inference script finished with return code:", process.returncode)
    # Results are stored in "./outputs/tmp" by default; list the directory contents
    print_directory_contents('./outputs/tmp')
    wave_files = get_wav_files('./outputs/tmp')
    print(wave_files)
    if not wave_files:
        raise gr.Error("No audio was generated; check the inference logs.")

    plot_spectrogram(wave_files[0], 'spectrogram.png')
    final_merged_out = merge_audio_to_video(video_in, wave_files[0])
    return wave_files[0], 'spectrogram.png', final_merged_out
| css=""" | |
| #col-container { | |
| max-width: 920px; | |
| margin: 0 auto; | |
| } | |
| """ | |

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Video-to-Audio Generation with Hidden Alignment")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href='https://sites.google.com/view/vta-ldm'>
                <img src='https://img.shields.io/badge/Project-Page-Green'>
            </a>
            <a href='https://huggingface.co/papers/2407.07464'>
                <img src='https://img.shields.io/badge/HF-Paper-red'>
            </a>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                video_in = gr.Video(label='Video IN', format="mp4", include_audio=False)
                submit_btn = gr.Button("Submit")
                gr.Examples(
                    examples=[
                        ["./examples/lion_gt.mp4"],
                        ["./examples/ice_gt.mp4"],
                        ["./examples/seashore.mp4"],
                        ["./examples/typewriter.mp4"],
                        ["./examples/tennis_gt.mp4"],
                        ["./examples/chew.mp4"],
                    ],
                    inputs=[video_in]
                )
            with gr.Column():
                output_sound = gr.Audio(label="Audio OUT")
                output_spectrogram = gr.Image(label='Spectrogram')
                merged_out = gr.Video(label="Merged video + generated audio")

    submit_btn.click(
        fn=infer,
        inputs=[video_in],
        outputs=[output_sound, output_spectrogram, merged_out],
        show_api=True
    )

demo.launch(show_api=True, show_error=True, ssr_mode=False, mcp_server=True)
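
# Minimal client-side sketch. Since show_api=True, the click handler is exposed
# under api_name "/infer"; "user/vta-ldm-demo" is a hypothetical Space repo id,
# not the real one:
#
# from gradio_client import Client, handle_file
# client = Client("user/vta-ldm-demo")
# wav, spectrogram, merged = client.predict(
#     video_in=handle_file("my_silent_clip.mp4"), api_name="/infer"
# )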