Spaces:
Running
Running
jhj0517
commited on
Commit
·
131f180
1
Parent(s):
80e4171
Add resamples
Browse files
modules/whisper/whisper_base.py
CHANGED
|
@@ -2,6 +2,7 @@ import os
|
|
| 2 |
import torch
|
| 3 |
import whisper
|
| 4 |
import gradio as gr
|
|
|
|
| 5 |
from abc import ABC, abstractmethod
|
| 6 |
from typing import BinaryIO, Union, Tuple, List
|
| 7 |
import numpy as np
|
|
@@ -111,12 +112,19 @@ class WhisperBase(ABC):
|
|
| 111 |
|
| 112 |
if params.is_bgm_separate:
|
| 113 |
music, audio = self.music_separator.separate(
|
| 114 |
-
|
| 115 |
model_name=params.uvr_model_size,
|
| 116 |
device=params.uvr_device,
|
| 117 |
segment_size=params.uvr_segment_size,
|
|
|
|
| 118 |
progress=progress
|
| 119 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
self.music_separator.offload()
|
| 121 |
|
| 122 |
if params.vad_filter:
|
|
@@ -473,3 +481,18 @@ class WhisperBase(ABC):
|
|
| 473 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
| 474 |
|
| 475 |
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import whisper
|
| 4 |
import gradio as gr
|
| 5 |
+
import torchaudio
|
| 6 |
from abc import ABC, abstractmethod
|
| 7 |
from typing import BinaryIO, Union, Tuple, List
|
| 8 |
import numpy as np
|
|
|
|
| 112 |
|
| 113 |
if params.is_bgm_separate:
|
| 114 |
music, audio = self.music_separator.separate(
|
| 115 |
+
audio=audio,
|
| 116 |
model_name=params.uvr_model_size,
|
| 117 |
device=params.uvr_device,
|
| 118 |
segment_size=params.uvr_segment_size,
|
| 119 |
+
save_file=params.uvr_save_file,
|
| 120 |
progress=progress
|
| 121 |
)
|
| 122 |
+
|
| 123 |
+
if audio.ndim >= 2:
|
| 124 |
+
audio = audio.mean(axis=1)
|
| 125 |
+
origin_sample_rate = self.music_separator.audio_info.sample_rate
|
| 126 |
+
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
|
| 127 |
+
|
| 128 |
self.music_separator.offload()
|
| 129 |
|
| 130 |
if params.vad_filter:
|
|
|
|
| 481 |
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
| 482 |
|
| 483 |
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
| 484 |
+
|
| 485 |
+
@staticmethod
|
| 486 |
+
def resample_audio(audio: Union[str, np.ndarray],
|
| 487 |
+
new_sample_rate: int = 16000,
|
| 488 |
+
original_sample_rate: Optional[int] = None,) -> np.ndarray:
|
| 489 |
+
"""Resamples audio to 16k sample rate, standard on Whisper model"""
|
| 490 |
+
if isinstance(audio, str):
|
| 491 |
+
audio, original_sample_rate = torchaudio.load(audio)
|
| 492 |
+
else:
|
| 493 |
+
if original_sample_rate is None:
|
| 494 |
+
raise ValueError("original_sample_rate must be provided when audio is numpy array.")
|
| 495 |
+
audio = torch.from_numpy(audio)
|
| 496 |
+
resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
|
| 497 |
+
resampled_audio = resampler(audio).numpy()
|
| 498 |
+
return resampled_audio
|