Spaces:
Running
Running
jhj0517
committed on
Commit
·
a1b32c1
1
Parent(s):
a377305
Use `generate_file()`
Browse files
modules/utils/subtitle_manager.py
CHANGED
|
@@ -5,10 +5,11 @@ import os
|
|
| 5 |
import re
|
| 6 |
import sys
|
| 7 |
import zlib
|
| 8 |
-
from typing import Callable, List, Optional, TextIO, Union, Dict
|
| 9 |
from datetime import datetime
|
| 10 |
|
| 11 |
from modules.whisper.data_classes import Segment
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def format_timestamp(
|
|
@@ -61,7 +62,7 @@ class ResultWriter:
|
|
| 61 |
|
| 62 |
if add_timestamp:
|
| 63 |
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 64 |
-
output_file_name += timestamp
|
| 65 |
|
| 66 |
output_path = os.path.join(
|
| 67 |
self.output_dir, output_file_name + "." + self.extension
|
|
@@ -264,6 +265,8 @@ class WriteJSON(ResultWriter):
|
|
| 264 |
def get_writer(
|
| 265 |
output_format: str, output_dir: str
|
| 266 |
) -> Callable[[dict, TextIO, dict], None]:
|
|
|
|
|
|
|
| 267 |
writers = {
|
| 268 |
"txt": WriteTXT,
|
| 269 |
"vtt": WriteVTT,
|
|
@@ -286,6 +289,16 @@ def get_writer(
|
|
| 286 |
return writers[output_format](output_dir)
|
| 287 |
|
| 288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
def parse_srt(file_path):
|
| 290 |
"""Reads SRT file and returns as dict"""
|
| 291 |
with open(file_path, 'r', encoding='utf-8') as file:
|
|
|
|
| 5 |
import re
|
| 6 |
import sys
|
| 7 |
import zlib
|
| 8 |
+
from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple
|
| 9 |
from datetime import datetime
|
| 10 |
|
| 11 |
from modules.whisper.data_classes import Segment
|
| 12 |
+
from .files_manager import read_file
|
| 13 |
|
| 14 |
|
| 15 |
def format_timestamp(
|
|
|
|
| 62 |
|
| 63 |
if add_timestamp:
|
| 64 |
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 65 |
+
output_file_name += f"-{timestamp}"
|
| 66 |
|
| 67 |
output_path = os.path.join(
|
| 68 |
self.output_dir, output_file_name + "." + self.extension
|
|
|
|
| 265 |
def get_writer(
|
| 266 |
output_format: str, output_dir: str
|
| 267 |
) -> Callable[[dict, TextIO, dict], None]:
|
| 268 |
+
output_format = output_format.strip().lower()
|
| 269 |
+
|
| 270 |
writers = {
|
| 271 |
"txt": WriteTXT,
|
| 272 |
"vtt": WriteVTT,
|
|
|
|
| 289 |
return writers[output_format](output_dir)
|
| 290 |
|
| 291 |
|
| 292 |
+
def generate_file(
    output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str, add_timestamp: bool = True,
) -> Tuple[str, str]:
    """
    Write a transcription result to a subtitle file and return its content and path.

    Parameters
    ----------
    output_format: str
        Format key passed to ``get_writer()`` (e.g. "txt", "vtt"). Matching is
        case-insensitive and surrounding whitespace is ignored; "webvtt" is treated as "vtt".
    output_dir: str
        Directory path of the output
    result: Union[dict, List[Segment]]
        Transcription result handed to the writer unchanged.
    output_file_name: str
        Output file name without extension
    add_timestamp: bool
        Determines whether to add a "-<MMDDHHMMSS>" timestamp to the end of the filename.

    Returns
    ----------
    content: str
        Content read back from the written file
    file_path: str
        Path of the written file
    """
    # Normalize here as well as in get_writer(): the path below must use the same
    # lowercase key, otherwise "SRT" would be looked up as "Name.SRT" after the
    # writer produced "Name.srt".
    output_format = output_format.strip().lower()
    # The helper this function replaces accepted "webvtt" and wrote a ".vtt" file.
    if output_format == "webvtt":
        output_format = "vtt"

    # Resolve the timestamp once, here, and tell the writer not to add its own.
    # Previously the writer appended "-<timestamp>" to the name while file_path was
    # built from the bare name, so read_file() looked for a file that was never written.
    if add_timestamp:
        timestamp = datetime.now().strftime("%m%d%H%M%S")
        output_file_name = f"{output_file_name}-{timestamp}"

    file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
    file_writer = get_writer(output_format=output_format, output_dir=output_dir)
    file_writer(result=result, output_file_name=output_file_name, add_timestamp=False)
    content = read_file(file_path)
    return content, file_path
|
| 300 |
+
|
| 301 |
+
|
| 302 |
def parse_srt(file_path):
|
| 303 |
"""Reads SRT file and returns as dict"""
|
| 304 |
with open(file_path, 'r', encoding='utf-8') as file:
|
modules/whisper/base_transcription_pipeline.py
CHANGED
|
@@ -13,9 +13,9 @@ from modules.uvr.music_separator import MusicSeparator
|
|
| 13 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
|
| 14 |
UVR_MODELS_DIR)
|
| 15 |
from modules.utils.constants import *
|
| 16 |
-
from modules.utils.subtitle_manager import
|
| 17 |
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
| 18 |
-
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
|
| 19 |
from modules.whisper.data_classes import *
|
| 20 |
from modules.diarize.diarizer import Diarizer
|
| 21 |
from modules.vad.silero_vad import SileroVAD
|
|
@@ -224,14 +224,14 @@ class BaseTranscriptionPipeline(ABC):
|
|
| 224 |
)
|
| 225 |
|
| 226 |
file_name, file_ext = os.path.splitext(os.path.basename(file))
|
| 227 |
-
subtitle, file_path =
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
)
|
| 234 |
-
files_info[file_name] = {"subtitle":
|
| 235 |
|
| 236 |
total_result = ''
|
| 237 |
total_time = 0
|
|
@@ -291,16 +291,17 @@ class BaseTranscriptionPipeline(ABC):
|
|
| 291 |
)
|
| 292 |
progress(1, desc="Completed!")
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
| 300 |
)
|
| 301 |
|
| 302 |
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
| 303 |
-
return [result_str,
|
| 304 |
except Exception as e:
|
| 305 |
print(f"Error transcribing file: {e}")
|
| 306 |
finally:
|
|
@@ -351,19 +352,20 @@ class BaseTranscriptionPipeline(ABC):
|
|
| 351 |
progress(1, desc="Completed!")
|
| 352 |
|
| 353 |
file_name = safe_filename(yt.title)
|
| 354 |
-
subtitle,
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
)
|
|
|
|
| 361 |
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
| 362 |
|
| 363 |
if os.path.exists(audio):
|
| 364 |
os.remove(audio)
|
| 365 |
|
| 366 |
-
return [result_str,
|
| 367 |
|
| 368 |
except Exception as e:
|
| 369 |
print(f"Error transcribing file: {e}")
|
|
@@ -384,58 +386,6 @@ class BaseTranscriptionPipeline(ABC):
|
|
| 384 |
else:
|
| 385 |
return list(ctranslate2.get_supported_compute_types("cpu"))
|
| 386 |
|
| 387 |
-
@staticmethod
|
| 388 |
-
def generate_and_write_file(file_name: str,
|
| 389 |
-
transcribed_segments: list,
|
| 390 |
-
add_timestamp: bool,
|
| 391 |
-
file_format: str,
|
| 392 |
-
output_dir: str
|
| 393 |
-
) -> str:
|
| 394 |
-
"""
|
| 395 |
-
Writes subtitle file
|
| 396 |
-
|
| 397 |
-
Parameters
|
| 398 |
-
----------
|
| 399 |
-
file_name: str
|
| 400 |
-
Output file name
|
| 401 |
-
transcribed_segments: list
|
| 402 |
-
Text segments transcribed from audio
|
| 403 |
-
add_timestamp: bool
|
| 404 |
-
Determines whether to add a timestamp to the end of the filename.
|
| 405 |
-
file_format: str
|
| 406 |
-
File format to write. Supported formats: [SRT, WebVTT, txt]
|
| 407 |
-
output_dir: str
|
| 408 |
-
Directory path of the output
|
| 409 |
-
|
| 410 |
-
Returns
|
| 411 |
-
----------
|
| 412 |
-
content: str
|
| 413 |
-
Result of the transcription
|
| 414 |
-
output_path: str
|
| 415 |
-
output file path
|
| 416 |
-
"""
|
| 417 |
-
if add_timestamp:
|
| 418 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
| 419 |
-
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
|
| 420 |
-
else:
|
| 421 |
-
output_path = os.path.join(output_dir, f"{file_name}")
|
| 422 |
-
|
| 423 |
-
file_format = file_format.strip().lower()
|
| 424 |
-
if file_format == "srt":
|
| 425 |
-
content = get_srt(transcribed_segments)
|
| 426 |
-
output_path += '.srt'
|
| 427 |
-
|
| 428 |
-
elif file_format == "webvtt":
|
| 429 |
-
content = get_vtt(transcribed_segments)
|
| 430 |
-
output_path += '.vtt'
|
| 431 |
-
|
| 432 |
-
elif file_format == "txt":
|
| 433 |
-
content = get_txt(transcribed_segments)
|
| 434 |
-
output_path += '.txt'
|
| 435 |
-
|
| 436 |
-
write_file(content, output_path)
|
| 437 |
-
return content, output_path
|
| 438 |
-
|
| 439 |
@staticmethod
|
| 440 |
def format_time(elapsed_time: float) -> str:
|
| 441 |
"""
|
|
|
|
| 13 |
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
|
| 14 |
UVR_MODELS_DIR)
|
| 15 |
from modules.utils.constants import *
|
| 16 |
+
from modules.utils.subtitle_manager import *
|
| 17 |
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
| 18 |
+
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file
|
| 19 |
from modules.whisper.data_classes import *
|
| 20 |
from modules.diarize.diarizer import Diarizer
|
| 21 |
from modules.vad.silero_vad import SileroVAD
|
|
|
|
| 224 |
)
|
| 225 |
|
| 226 |
file_name, file_ext = os.path.splitext(os.path.basename(file))
|
| 227 |
+
subtitle, file_path = generate_file(
|
| 228 |
+
output_dir=self.output_dir,
|
| 229 |
+
output_file_name=file_name,
|
| 230 |
+
output_format=file_format,
|
| 231 |
+
result=transcribed_segments,
|
| 232 |
+
add_timestamp=add_timestamp
|
| 233 |
)
|
| 234 |
+
files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path}
|
| 235 |
|
| 236 |
total_result = ''
|
| 237 |
total_time = 0
|
|
|
|
| 291 |
)
|
| 292 |
progress(1, desc="Completed!")
|
| 293 |
|
| 294 |
+
file_name = "Mic"
|
| 295 |
+
subtitle, file_path = generate_file(
|
| 296 |
+
output_dir=self.output_dir,
|
| 297 |
+
output_file_name=file_name,
|
| 298 |
+
output_format=file_format,
|
| 299 |
+
result=transcribed_segments,
|
| 300 |
+
add_timestamp=add_timestamp
|
| 301 |
)
|
| 302 |
|
| 303 |
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
| 304 |
+
return [result_str, file_path]
|
| 305 |
except Exception as e:
|
| 306 |
print(f"Error transcribing file: {e}")
|
| 307 |
finally:
|
|
|
|
| 352 |
progress(1, desc="Completed!")
|
| 353 |
|
| 354 |
file_name = safe_filename(yt.title)
|
| 355 |
+
subtitle, file_path = generate_file(
|
| 356 |
+
output_dir=self.output_dir,
|
| 357 |
+
output_file_name=file_name,
|
| 358 |
+
output_format=file_format,
|
| 359 |
+
result=transcribed_segments,
|
| 360 |
+
add_timestamp=add_timestamp
|
| 361 |
)
|
| 362 |
+
|
| 363 |
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
| 364 |
|
| 365 |
if os.path.exists(audio):
|
| 366 |
os.remove(audio)
|
| 367 |
|
| 368 |
+
return [result_str, file_path]
|
| 369 |
|
| 370 |
except Exception as e:
|
| 371 |
print(f"Error transcribing file: {e}")
|
|
|
|
| 386 |
else:
|
| 387 |
return list(ctranslate2.get_supported_compute_types("cpu"))
|
| 388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
@staticmethod
|
| 390 |
def format_time(elapsed_time: float) -> str:
|
| 391 |
"""
|