Spaces:
Running
Running
Update flow.py
Browse files
flow.py
CHANGED
|
@@ -15,7 +15,26 @@ import threading
|
|
| 15 |
import torch
|
| 16 |
import torch.nn.functional as F
|
| 17 |
from matcha.models.components.flow_matching import BASECFM
|
|
|
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
class ConditionalCFM(BASECFM):
|
| 21 |
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
|
@@ -125,22 +144,50 @@ class ConditionalCFM(BASECFM):
|
|
| 125 |
if isinstance(self.estimator, torch.nn.Module):
|
| 126 |
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
| 127 |
else:
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
# run trt engine
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
| 146 |
"""Computes diffusion loss
|
|
|
|
| 15 |
import torch
|
| 16 |
import torch.nn.functional as F
|
| 17 |
from matcha.models.components.flow_matching import BASECFM
|
| 18 |
+
import queue
|
| 19 |
|
| 20 |
+
class EstimatorWrapper:
|
| 21 |
+
def __init__(self, estimator_engine, estimator_count=2,):
|
| 22 |
+
self.estimators = queue.Queue()
|
| 23 |
+
self.estimator_engine = estimator_engine
|
| 24 |
+
for _ in range(estimator_count):
|
| 25 |
+
estimator = estimator_engine.create_execution_context()
|
| 26 |
+
if estimator is not None:
|
| 27 |
+
self.estimators.put(estimator)
|
| 28 |
+
|
| 29 |
+
if self.estimators.empty():
|
| 30 |
+
raise Exception("No available estimator")
|
| 31 |
+
|
| 32 |
+
def acquire_estimator(self):
|
| 33 |
+
return self.estimators.get(), self.estimator_engine
|
| 34 |
+
|
| 35 |
+
def release_estimator(self, estimator):
|
| 36 |
+
self.estimators.put(estimator)
|
| 37 |
+
return
|
| 38 |
|
| 39 |
class ConditionalCFM(BASECFM):
|
| 40 |
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
|
|
|
| 144 |
if isinstance(self.estimator, torch.nn.Module):
|
| 145 |
return self.estimator.forward(x, mask, mu, t, spks, cond)
|
| 146 |
else:
|
| 147 |
+
if isinstance(self.estimator, EstimatorWrapper):
|
| 148 |
+
estimator, engine = self.estimator.acquire_estimator()
|
| 149 |
+
|
| 150 |
+
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
| 151 |
+
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
| 152 |
+
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
| 153 |
+
estimator.set_input_shape('t', (2,))
|
| 154 |
+
estimator.set_input_shape('spks', (2, 80))
|
| 155 |
+
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
| 156 |
+
|
| 157 |
+
data_ptrs = [x.contiguous().data_ptr(),
|
| 158 |
+
mask.contiguous().data_ptr(),
|
| 159 |
+
mu.contiguous().data_ptr(),
|
| 160 |
+
t.contiguous().data_ptr(),
|
| 161 |
+
spks.contiguous().data_ptr(),
|
| 162 |
+
cond.contiguous().data_ptr(),
|
| 163 |
+
x.data_ptr()]
|
| 164 |
+
|
| 165 |
+
for idx, data_ptr in enumerate(data_ptrs):
|
| 166 |
+
estimator.set_tensor_address(engine.get_tensor_name(idx), data_ptr)
|
| 167 |
+
|
| 168 |
# run trt engine
|
| 169 |
+
estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream)
|
| 170 |
+
|
| 171 |
+
torch.cuda.current_stream().synchronize()
|
| 172 |
+
self.estimator.release_estimator(estimator)
|
| 173 |
+
return x
|
| 174 |
+
else:
|
| 175 |
+
with self.lock:
|
| 176 |
+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
|
| 177 |
+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
| 178 |
+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
| 179 |
+
self.estimator.set_input_shape('t', (2,))
|
| 180 |
+
self.estimator.set_input_shape('spks', (2, 80))
|
| 181 |
+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
| 182 |
+
# run trt engine
|
| 183 |
+
self.estimator.execute_v2([x.contiguous().data_ptr(),
|
| 184 |
+
mask.contiguous().data_ptr(),
|
| 185 |
+
mu.contiguous().data_ptr(),
|
| 186 |
+
t.contiguous().data_ptr(),
|
| 187 |
+
spks.contiguous().data_ptr(),
|
| 188 |
+
cond.contiguous().data_ptr(),
|
| 189 |
+
x.data_ptr()])
|
| 190 |
+
return x
|
| 191 |
|
| 192 |
def compute_loss(self, x1, mask, mu, spks=None, cond=None):
|
| 193 |
"""Computes diffusion loss
|