jhj0517 committed · Commit e284444 · 1 Parent(s): 8a43431
Add gradio parameter `file_format` to cache
modules/whisper/base_transcription_pipeline.py
CHANGED
@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
+            file_format: str = "SRT",
             add_timestamp: bool = True,
             *pipeline_params,
             ) -> Tuple[List[Segment], float]:
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
             Audio input. This can be file path or binary type.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        file_format: str
+            Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
         add_timestamp: bool
            Whether to add a timestamp at the end of the filename.
         *pipeline_params: tuple
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
 
         self.cache_parameters(
             params=params,
+            file_format=file_format,
             add_timestamp=add_timestamp
         )
         return result, elapsed_time
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 file,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 mic_audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
     @staticmethod
     def cache_parameters(
             params: TranscriptionPipelineParams,
-            add_timestamp: bool = True
+            file_format: str = "SRT",
+            add_timestamp: bool = True
             ):
         """Cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
 
         cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+        cached_yaml["whisper"]["file_format"] = file_format
 
         supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
         if supress_token and isinstance(supress_token, list):
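
The commit threads the new `file_format` argument from the three `self.run(...)` call sites down into `cache_parameters`, which persists it under the `whisper` key of the cached YAML config. Below is a minimal, self-contained sketch of that caching step; it uses plain PyYAML and a dict in place of the project's own `load_yaml` helper and `TranscriptionPipelineParams`, and the config path is only a stand-in for `DEFAULT_PARAMETERS_CONFIG_PATH`, so names outside the diff are assumptions.

import yaml

# Assumed stand-in for DEFAULT_PARAMETERS_CONFIG_PATH in the real project.
CONFIG_PATH = "configs/default_parameters.yaml"

def cache_parameters(param_to_cache: dict,
                     file_format: str = "SRT",
                     add_timestamp: bool = True,
                     config_path: str = CONFIG_PATH) -> dict:
    """Merge new parameters into the cached YAML, including the file_format key this commit adds."""
    with open(config_path, encoding="utf-8") as f:
        cached_params = yaml.safe_load(f) or {}

    # Merge previously cached values with the new ones, as in the diff.
    cached_yaml = {**cached_params, **param_to_cache}
    cached_yaml.setdefault("whisper", {})
    cached_yaml["whisper"]["add_timestamp"] = add_timestamp
    cached_yaml["whisper"]["file_format"] = file_format  # value introduced by this commit

    with open(config_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(cached_yaml, f)
    return cached_yaml

Note that `file_format` is passed positionally between `progress` and `add_timestamp` in each `self.run(...)` call, so Gradio handlers that assemble the input list need to supply the file-format component in that same position.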