Spaces:
Running
Running
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -47,13 +47,8 @@ def load_wav(wav, target_sr):
|
|
| 47 |
return speech
|
| 48 |
|
| 49 |
|
| 50 |
-
def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
| 51 |
import tensorrt as trt
|
| 52 |
-
_min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2,), (2, 80), (2, 80, 4)]
|
| 53 |
-
_opt_shape = [(2, 80, 193), (2, 1, 193), (2, 80, 193), (2,), (2, 80), (2, 80, 193)]
|
| 54 |
-
_max_shape = [(2, 80, 6800), (2, 1, 6800), (2, 80, 6800), (2,), (2, 80), (2, 80, 6800)]
|
| 55 |
-
input_names = ["x", "mask", "mu", "t", "spks", "cond"]
|
| 56 |
-
|
| 57 |
logging.info("Converting onnx to trt...")
|
| 58 |
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
| 59 |
logger = trt.Logger(trt.Logger.INFO)
|
|
@@ -61,7 +56,7 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
|
| 61 |
network = builder.create_network(network_flags)
|
| 62 |
parser = trt.OnnxParser(network, logger)
|
| 63 |
config = builder.create_builder_config()
|
| 64 |
-
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 <<
|
| 65 |
if fp16:
|
| 66 |
config.set_flag(trt.BuilderFlag.FP16)
|
| 67 |
profile = builder.create_optimization_profile()
|
|
@@ -72,8 +67,8 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
|
| 72 |
print(parser.get_error(error))
|
| 73 |
raise ValueError('failed to parse {}'.format(onnx_model))
|
| 74 |
# set input shapes
|
| 75 |
-
for i in range(len(input_names)):
|
| 76 |
-
profile.set_shape(input_names[i], _min_shape[i], _opt_shape[i], _max_shape[i])
|
| 77 |
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
| 78 |
# set input and output data type
|
| 79 |
for i in range(network.num_inputs):
|
|
@@ -86,4 +81,5 @@ def convert_onnx_to_trt(trt_model, onnx_model, fp16):
|
|
| 86 |
engine_bytes = builder.build_serialized_network(network, config)
|
| 87 |
# save trt engine
|
| 88 |
with open(trt_model, "wb") as f:
|
| 89 |
-
f.write(engine_bytes)
|
|
|
|
|
|
| 47 |
return speech
|
| 48 |
|
| 49 |
|
| 50 |
+
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
| 51 |
import tensorrt as trt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
logging.info("Converting onnx to trt...")
|
| 53 |
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
| 54 |
logger = trt.Logger(trt.Logger.INFO)
|
|
|
|
| 56 |
network = builder.create_network(network_flags)
|
| 57 |
parser = trt.OnnxParser(network, logger)
|
| 58 |
config = builder.create_builder_config()
|
| 59 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 33) # 8GB
|
| 60 |
if fp16:
|
| 61 |
config.set_flag(trt.BuilderFlag.FP16)
|
| 62 |
profile = builder.create_optimization_profile()
|
|
|
|
| 67 |
print(parser.get_error(error))
|
| 68 |
raise ValueError('failed to parse {}'.format(onnx_model))
|
| 69 |
# set input shapes
|
| 70 |
+
for i in range(len(trt_kwargs['input_names'])):
|
| 71 |
+
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
| 72 |
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
| 73 |
# set input and output data type
|
| 74 |
for i in range(network.num_inputs):
|
|
|
|
| 81 |
engine_bytes = builder.build_serialized_network(network, config)
|
| 82 |
# save trt engine
|
| 83 |
with open(trt_model, "wb") as f:
|
| 84 |
+
f.write(engine_bytes)
|
| 85 |
+
logging.info("Succesfully convert onnx to trt...")
|