Spaces:
Running
on
Zero
Running
on
Zero
John Meade
committed on
Commit
·
96bdb69
1
Parent(s):
af25078
vad trimming for ref wavs
Browse files- .gitignore +1 -0
- chatterbox/src/chatterbox/tts.py +27 -3
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
chatterbox/src/chatterbox/tts.py
CHANGED
|
@@ -2,10 +2,12 @@ from dataclasses import dataclass
|
|
| 2 |
from pathlib import Path
|
| 3 |
|
| 4 |
import librosa
|
|
|
|
| 5 |
import torch
|
| 6 |
import perth
|
| 7 |
import torch.nn.functional as F
|
| 8 |
from huggingface_hub import hf_hub_download
|
|
|
|
| 9 |
|
| 10 |
from .models.t3 import T3
|
| 11 |
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
|
|
@@ -121,6 +123,7 @@ class ChatterboxTTS:
|
|
| 121 |
self.device = device
|
| 122 |
self.conds = conds
|
| 123 |
self.watermarker = perth.PerthImplicitWatermarker()
|
|
|
|
| 124 |
|
| 125 |
@classmethod
|
| 126 |
def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
|
|
@@ -162,11 +165,32 @@ class ChatterboxTTS:
|
|
| 162 |
|
| 163 |
return cls.from_local(Path(local_path).parent, device)
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
|
| 166 |
-
|
| 167 |
-
|
|
|
|
| 168 |
|
| 169 |
-
|
|
|
|
|
|
|
| 170 |
|
| 171 |
s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
|
| 172 |
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
|
| 4 |
import librosa
|
| 5 |
+
import numpy as np
|
| 6 |
import torch
|
| 7 |
import perth
|
| 8 |
import torch.nn.functional as F
|
| 9 |
from huggingface_hub import hf_hub_download
|
| 10 |
+
from silero_vad import load_silero_vad, get_speech_timestamps
|
| 11 |
|
| 12 |
from .models.t3 import T3
|
| 13 |
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
|
|
|
|
| 123 |
self.device = device
|
| 124 |
self.conds = conds
|
| 125 |
self.watermarker = perth.PerthImplicitWatermarker()
|
| 126 |
+
self.silero_vad = load_silero_vad()
|
| 127 |
|
| 128 |
@classmethod
|
| 129 |
def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
|
|
|
|
| 165 |
|
| 166 |
return cls.from_local(Path(local_path).parent, device)
|
| 167 |
|
| 168 |
+
def trim_excess_silence(self, wav, sr, max_silence_ms=400):
    """Trim excess silence from a speech waveform using Silero VAD.

    Speech regions reported by ``get_speech_timestamps`` are widened by a
    window of ``max_silence_ms / 2`` milliseconds (split evenly before and
    after each region), and every sample outside the widened regions is
    dropped.

    Args:
        wav: 1-D array of audio samples.
        sr: Sample rate of ``wav`` in Hz. Must be an integer multiple of
            16000.
        max_silence_ms: Controls how much silence is kept around speech;
            the widening window is half of this value. Defaults to 400.

    Returns:
        The samples of ``wav`` that fall inside the widened speech regions.
        If no speech is detected, ``wav`` is returned unchanged instead of
        an empty array.

    Raises:
        AssertionError: If ``sr`` is not an integer multiple of 16000.
    """
    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`, while keeping the exception type callers may expect.
    if sr % 16_000 != 0:
        raise AssertionError("Silero requires an integer multiple of 16kHz")

    silero_regions = get_speech_timestamps(wav, self.silero_vad, sampling_rate=sr)
    if not silero_regions:
        # Nothing was classified as speech: keep the clip as-is rather than
        # handing an empty reference waveform to the caller.
        return wav

    num_samples = len(wav)

    # Width (in samples) of the dilation window; at least one sample so a
    # zero/small `max_silence_ms` simply means "no widening".
    window = max(1, int(sr * max_silence_ms / (2 * 1000)))
    # Same asymmetric split that a "same"-mode convolution with a box filter
    # of this width produces: (window - 1) // 2 before, window // 2 after.
    pad_before = (window - 1) // 2
    pad_after = window // 2

    # Widen each region directly instead of convolving a sample-level mask
    # with a long box filter; this is linear in the clip length and always
    # yields a mask of exactly len(wav), even for clips shorter than the
    # window.
    keep = np.zeros(num_samples, dtype=bool)
    for region in silero_regions:
        lo = max(0, region["start"] - pad_before)
        hi = min(num_samples, region["end"] + pad_after)
        keep[lo:hi] = True

    # Trim out silence
    return wav[keep]
|
| 185 |
+
|
| 186 |
def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
|
| 187 |
+
# Load reference wav at high SR and trim silence
|
| 188 |
+
ref_wav, highres_sr = librosa.load(wav_fpath, sr=48_000)
|
| 189 |
+
ref_wav = self.trim_excess_silence(ref_wav, highres_sr)
|
| 190 |
|
| 191 |
+
# Resample down
|
| 192 |
+
s3gen_ref_wav = librosa.resample(ref_wav, orig_sr=highres_sr, target_sr=S3GEN_SR)
|
| 193 |
+
ref_16k_wav = librosa.resample(ref_wav, orig_sr=highres_sr, target_sr=S3_SR)
|
| 194 |
|
| 195 |
s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
|
| 196 |
s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
|