John Meade commited on
Commit
96bdb69
·
1 Parent(s): af25078

vad trimming for ref wavs

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. chatterbox/src/chatterbox/tts.py +27 -3
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
chatterbox/src/chatterbox/tts.py CHANGED
@@ -2,10 +2,12 @@ from dataclasses import dataclass
2
  from pathlib import Path
3
 
4
  import librosa
 
5
  import torch
6
  import perth
7
  import torch.nn.functional as F
8
  from huggingface_hub import hf_hub_download
 
9
 
10
  from .models.t3 import T3
11
  from .models.s3tokenizer import S3_SR, drop_invalid_tokens
@@ -121,6 +123,7 @@ class ChatterboxTTS:
121
  self.device = device
122
  self.conds = conds
123
  self.watermarker = perth.PerthImplicitWatermarker()
 
124
 
125
  @classmethod
126
  def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
@@ -162,11 +165,32 @@ class ChatterboxTTS:
162
 
163
  return cls.from_local(Path(local_path).parent, device)
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
166
- ## Load reference wav
167
- s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)
 
168
 
169
- ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)
 
 
170
 
171
  s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
172
  s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)
 
2
  from pathlib import Path
3
 
4
  import librosa
5
+ import numpy as np
6
  import torch
7
  import perth
8
  import torch.nn.functional as F
9
  from huggingface_hub import hf_hub_download
10
+ from silero_vad import load_silero_vad, get_speech_timestamps
11
 
12
  from .models.t3 import T3
13
  from .models.s3tokenizer import S3_SR, drop_invalid_tokens
 
123
  self.device = device
124
  self.conds = conds
125
  self.watermarker = perth.PerthImplicitWatermarker()
126
+ self.silero_vad = load_silero_vad()
127
 
128
  @classmethod
129
  def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS':
 
165
 
166
  return cls.from_local(Path(local_path).parent, device)
167
 
168
+ def trim_excess_silence(self, wav, sr):
169
+ "Trim excess silence from speech. Input must be a multiple of 16kHz."
170
+ assert sr % 16_000 == 0, "Silero requires an integer multiple of 16kHz"
171
+
172
+ # Get VAD as sample-level bool array
173
+ silero_regions = get_speech_timestamps(wav, self.silero_vad, sampling_rate=sr)
174
+ vad = np.zeros_like(wav)
175
+ for region in silero_regions:
176
+ vad[region["start"]:region["end"]] = 1
177
+
178
+ # Dilate VAD
179
+ max_silence_ms = 400
180
+ cfilter = np.ones(int(sr * max_silence_ms / (2 * 1000)))
181
+ dilated_vad = np.convolve(vad, cfilter, mode="same") > 0
182
+
183
+ # Trim out silence
184
+ return wav[dilated_vad]
185
+
186
  def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
187
+ # Load reference wav at high SR and trim silence
188
+ ref_wav, highres_sr = librosa.load(wav_fpath, sr=48_000)
189
+ ref_wav = self.trim_excess_silence(ref_wav, highres_sr)
190
 
191
+ # Resample down
192
+ s3gen_ref_wav = librosa.resample(ref_wav, orig_sr=highres_sr, target_sr=S3GEN_SR)
193
+ ref_16k_wav = librosa.resample(ref_wav, orig_sr=highres_sr, target_sr=S3_SR)
194
 
195
  s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
196
  s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)