Update app.py
app.py CHANGED
@@ -42,11 +42,18 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
 def download_audio(url, method_choice):
     parsed_url = urlparse(url)
     logging.info(f"Downloading audio from URL: {url} using method: {method_choice}")
-
-
-
-
-
+    try:
+        if parsed_url.netloc in ['www.youtube.com', 'youtu.be', 'youtube.com']:
+            audio_file = download_youtube_audio(url, method_choice)
+        else:
+            audio_file = download_direct_audio(url, method_choice)
+        if not audio_file or not os.path.exists(audio_file):
+            raise Exception(f"Failed to download audio from {url}")
+        return audio_file
+    except Exception as e:
+        logging.error(f"Error downloading audio: {str(e)}")
+        return f"Error: {str(e)}"
+
 def download_youtube_audio(url, method_choice):
     methods = {
         'yt-dlp': youtube_dl_method,
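For context on the hunk above: the `netloc` check matches only the three listed hosts, so anything else (including a mobile link) falls through to `download_direct_audio`. A minimal sketch of the same dispatch idea; `m.youtube.com` here is an assumed extra host, not part of the commit:

```python
from urllib.parse import urlparse

# m.youtube.com is an assumed addition, not in the commit's list
YOUTUBE_HOSTS = {'www.youtube.com', 'youtube.com', 'youtu.be', 'm.youtube.com'}

def is_youtube_url(url):
    # netloc is the host portion of the parsed URL, e.g. 'youtu.be'
    return urlparse(url).netloc in YOUTUBE_HOSTS

print(is_youtube_url("https://youtu.be/abc123"))        # True
print(is_youtube_url("https://example.com/audio.mp3"))  # False
```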
@@ -66,19 +73,24 @@ def download_youtube_audio(url, method_choice):
 
 def youtube_dl_method(url):
     logging.info("Using yt-dlp method")
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        ydl_opts = {
+            'format': 'bestaudio/best',
+            'postprocessors': [{
+                'key': 'FFmpegExtractAudio',
+                'preferredcodec': 'mp3',
+                'preferredquality': '192',
+            }],
+            'outtmpl': '%(id)s.%(ext)s',
+        }
+        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+            info = ydl.extract_info(url, download=True)
+            output_file = f"{info['id']}.mp3"
+            logging.info(f"Downloaded YouTube audio: {output_file}")
+            return output_file
+    except Exception as e:
+        logging.error(f"Error in youtube_dl_method: {str(e)}")
+        return None
 
 def pytube_method(url):
     logging.info("Using pytube method")
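One caveat with the hunk above: it assumes the postprocessed file is named `<id>.mp3`, which holds for this `outtmpl` but silently breaks if the template ever changes. A hedged alternative sketch that derives the name from yt-dlp's own `prepare_filename` (a real yt-dlp method, though note it reflects the pre-postprocessing extension, hence the `splitext` swap):

```python
import os
import yt_dlp

def youtube_dl_method_alt(url):
    # Same options as the commit; only the output-name handling differs.
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': '%(id)s.%(ext)s',
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)
        # prepare_filename returns the pre-postprocessing name (e.g. <id>.webm),
        # so replace its extension with the postprocessor's target codec.
        base, _ = os.path.splitext(ydl.prepare_filename(info))
    return base + '.mp3'
```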
@@ -183,11 +195,11 @@ def trim_audio(audio_path, start_time, end_time):
 
     # Validate times
     if start_time < 0 or end_time < 0:
-        raise
+        raise gr.Error("Start time and end time must be non-negative.")
     if start_time >= end_time:
-        raise gr.Error("End time must be greater than start time.")
+        raise gr.Error("End time must be greater than start time.")
     if start_time > audio_duration:
-        raise
+        raise gr.Error("Start time exceeds audio duration.")
 
     trimmed_audio = audio[start_time * 1000:end_time * 1000]
     trimmed_audio_path = tempfile.mktemp(suffix='.wav')
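The `* 1000` in the slice exists because pydub's `AudioSegment` is indexed in milliseconds while the UI times are in seconds. A minimal worked example (assumes pydub plus ffmpeg are installed; `input.wav` is a placeholder path):

```python
from pydub import AudioSegment

audio = AudioSegment.from_file("input.wav")       # placeholder path
start_time, end_time = 5, 10                      # seconds, as entered in the UI
clip = audio[start_time * 1000:end_time * 1000]   # pydub slices in milliseconds
clip.export("clip.wav", format="wav")             # writes the 5 s excerpt
```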
@@ -212,67 +224,40 @@ def get_model_options(pipeline_type):
     else:
         return []
 
+loaded_models = {}
+
 def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
     try:
+        if verbose:
+            logging.getLogger().setLevel(logging.INFO)
+        else:
+            logging.getLogger().setLevel(logging.WARNING)
+
+        logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
+        verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
+
+        if verbose:
+            yield verbose_messages, "", None
+
         # Determine if input_source is a URL or file
         if isinstance(input_source, str):
             if input_source.startswith('http://') or input_source.startswith('https://'):
                 audio_path = download_audio(input_source, download_method)
-                # Handle potential errors during download
                 if not audio_path or audio_path.startswith("Error"):
                     yield f"Error: {audio_path}", "", None
                     return
-
-
+            else:
+                # Assume it's a local file path
+                audio_path = input_source
+        elif input_source is not None:
+            # Uploaded file object
             audio_path = input_source.name
             logging.info(f"Using uploaded audio file: {audio_path}")
-
-        try:
-            logging.info(f"Transcription parameters: pipeline_type={pipeline_type}, model_id={model_id}, dtype={dtype}, batch_size={batch_size}, download_method={download_method}")
-            verbose_messages = f"Starting transcription with parameters:\nPipeline Type: {pipeline_type}\nModel ID: {model_id}\nData Type: {dtype}\nBatch Size: {batch_size}\nDownload Method: {download_method}\n"
-
-            if verbose:
-                yield verbose_messages, "", None
-
-        if pipeline_type == "faster-batched":
-            model = WhisperModel(model_id, device="auto", compute_type=dtype)
-            pipeline = BatchedInferencePipeline(model=model)
-        elif pipeline_type == "faster-sequenced":
-            model = WhisperModel(model_id)
-            pipeline = model.transcribe
-        elif pipeline_type == "transformers":
-            torch_dtype = torch.float16 if dtype == "float16" else torch.float32
-            model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
-            )
-            model.to(device)
-            processor = AutoProcessor.from_pretrained(model_id)
-            pipeline = pipeline(
-                "automatic-speech-recognition",
-                model=model,
-                tokenizer=processor.tokenizer,
-                feature_extractor=processor.feature_extractor,
-                chunk_length_s=30,
-                batch_size=batch_size,
-                return_timestamps=True,
-                torch_dtype=torch_dtype,
-                device=device,
-            )
         else:
-
-
-        if isinstance(input_source, str) and (input_source.startswith('http://') or input_source.startswith('https://')):
-            audio_path = download_audio(input_source, download_method)
-            verbose_messages += f"Audio file downloaded: {audio_path}\n"
-            if verbose:
-                yield verbose_messages, "", None
-
-            if not audio_path or audio_path.startswith("Error"):
-                yield f"Error: {audio_path}", "", None
-                return
-        else:
-            audio_path = input_source
+            yield "No audio source provided.", "", None
+            return
 
+        # Convert start_time and end_time to float or None
         start_time = float(start_time) if start_time else None
         end_time = float(end_time) if end_time else None
 
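The `verbose` flag in the hunk above works by raising or lowering the root logger's level, so every `logging.info` call in the module is shown or muted globally. A small sketch of the effect:

```python
import logging

logging.basicConfig(format="%(levelname)s: %(message)s")

logging.getLogger().setLevel(logging.WARNING)
logging.info("suppressed")   # INFO < WARNING, so nothing is printed

logging.getLogger().setLevel(logging.INFO)
logging.info("visible")      # prints "INFO: visible"
```

Since this mutates process-wide state, concurrent requests with different `verbose` settings would fight over the level; it is a pragmatic choice for a single-user Space rather than a general pattern.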
@@ -283,11 +268,48 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
         if verbose:
             yield verbose_messages, "", None
 
+        # Model caching
+        model_key = (pipeline_type, model_id, dtype)
+        if model_key in loaded_models:
+            model_or_pipeline = loaded_models[model_key]
+            logging.info("Loaded model from cache")
+        else:
+            # Build once and bind to model_or_pipeline so the cached object is
+            # defined on every branch (and the imported `pipeline` helper is not shadowed).
+            if pipeline_type == "faster-batched":
+                model = WhisperModel(model_id, device=device, compute_type=dtype)
+                model_or_pipeline = BatchedInferencePipeline(model=model)
+            elif pipeline_type == "faster-sequenced":
+                model_or_pipeline = WhisperModel(model_id, device=device, compute_type=dtype)
+            elif pipeline_type == "transformers":
+                torch_dtype = torch.float16 if dtype == "float16" else torch.float32
+                model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
+                )
+                model.to(device)
+                processor = AutoProcessor.from_pretrained(model_id)
+                model_or_pipeline = pipeline(
+                    "automatic-speech-recognition",
+                    model=model,
+                    tokenizer=processor.tokenizer,
+                    feature_extractor=processor.feature_extractor,
+                    chunk_length_s=30,
+                    batch_size=batch_size,
+                    return_timestamps=True,
+                    torch_dtype=torch_dtype,
+                    device=device,
+                )
+            else:
+                raise ValueError("Invalid pipeline type")
+            loaded_models[model_key] = model_or_pipeline  # Cache the model or pipeline
+
         start_time_perf = time.time()
-        if pipeline_type
-        segments, info =
+        if pipeline_type == "faster-batched":
+            segments, info = model_or_pipeline.transcribe(audio_path, batch_size=batch_size)
+        elif pipeline_type == "faster-sequenced":
+            segments, info = model_or_pipeline.transcribe(audio_path)
         else:
-            result =
+            result = model_or_pipeline(audio_path)
             segments = result["chunks"]
         end_time_perf = time.time()
 
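The cache introduced above is a plain dict keyed by the `(pipeline_type, model_id, dtype)` tuple, so a repeat request with identical settings skips model construction entirely. A self-contained sketch of the same memoization pattern with a generic loader (names here are illustrative, not from the commit):

```python
loaded_models = {}  # module-level cache, lives for the process lifetime

def get_or_load(pipeline_type, model_id, dtype, loader):
    key = (pipeline_type, model_id, dtype)   # every part must be hashable
    if key not in loaded_models:
        loaded_models[key] = loader()        # only the first call pays the load cost
    return loaded_models[key]

# Usage: the lambda is evaluated only on a cache miss.
model = get_or_load("faster-batched", "large-v3", "int8",
                    loader=lambda: object())  # stand-in for WhisperModel(...)
```

`functools.lru_cache` could express the same idea, but an explicit dict keeps inspection and manual eviction trivial.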
@@ -305,11 +326,10 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
         transcription = ""
 
         for segment in segments:
-
-            f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}\n"
-
-            f"[{segment['timestamp'][0]:.2f}s -> {segment['timestamp'][1]:.2f}s] {segment['text']}\n"
-        )
+            if pipeline_type in ["faster-batched", "faster-sequenced"]:
+                transcription_segment = f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}\n"
+            else:
+                transcription_segment = f"[{segment['timestamp'][0]:.2f}s -> {segment['timestamp'][1]:.2f}s] {segment['text']}\n"
             transcription += transcription_segment
             if verbose:
                 yield verbose_messages + metrics_output, transcription, None
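The branch above exists because the two backends yield different segment shapes: faster-whisper returns objects with `.start`/`.end`/`.text` attributes, while the transformers pipeline returns dicts with a `(start, end)` tuple under `'timestamp'`. A tiny sketch formatting both shapes (sample data is made up):

```python
from types import SimpleNamespace

def format_segment(segment, from_faster_whisper):
    if from_faster_whisper:
        return f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}"
    return f"[{segment['timestamp'][0]:.2f}s -> {segment['timestamp'][1]:.2f}s] {segment['text']}"

fw = SimpleNamespace(start=0.0, end=2.5, text="hello")  # faster-whisper style
tf = {"timestamp": (0.0, 2.5), "text": "hello"}         # transformers chunk style
print(format_segment(fw, True))    # [0.00s -> 2.50s] hello
print(format_segment(tf, False))   # [0.00s -> 2.50s] hello
```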
@@ -322,23 +342,21 @@ def transcribe_audio(input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time=None, end_time=None, verbose=False):
             yield f"An error occurred: {str(e)}", "", None
 
     finally:
-        #
+        # Clean up temporary files
         if audio_path and os.path.exists(audio_path):
             os.remove(audio_path)
-        # Remove trimmed audio file
         if 'trimmed_audio_path' in locals() and os.path.exists(trimmed_audio_path):
             os.remove(trimmed_audio_path)
-
-        if transcription_file and os.path.exists(transcription_file):
+        if 'transcription_file' in locals() and os.path.exists(transcription_file):
             os.remove(transcription_file)
-
 
 with gr.Blocks() as iface:
     gr.Markdown("# Multi-Pipeline Transcription")
     gr.Markdown("Transcribe audio using multiple pipelines and models.")
 
     with gr.Row():
-        input_source = gr.File(label="Audio Source (Upload a file or enter a URL/YouTube URL)")
+        #input_source = gr.File(label="Audio Source (Upload a file or enter a URL/YouTube URL)")
+        input_source = gr.Textbox(label="Audio Source (Upload a file or enter a URL/YouTube URL)")
         pipeline_type = gr.Dropdown(
             choices=["faster-batched", "faster-sequenced", "transformers"],
             label="Pipeline Type",
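The `'trimmed_audio_path' in locals()` guard in the `finally` block matters because that name is only bound if trimming ran before the exception; an unguarded reference would raise `NameError`. A minimal sketch of the pattern (file names are placeholders):

```python
import os
import tempfile

def process(trim=False):
    try:
        if trim:
            trimmed_path = tempfile.mktemp(suffix=".wav")  # only bound on this branch
            open(trimmed_path, "w").close()
        raise RuntimeError("simulated failure")
    finally:
        # 'trimmed_path' may never have been assigned on this code path,
        # so check locals() before touching it.
        if 'trimmed_path' in locals() and os.path.exists(trimmed_path):
            os.remove(trimmed_path)

# process(trim=True) still raises, but its temp file is cleaned up first.
```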
@@ -375,7 +393,6 @@ with gr.Blocks() as iface:
         try:
             model_choices = get_model_options(pipeline_type)
             logging.info(f"Model choices for {pipeline_type}: {model_choices}")
-
             if model_choices:
                 return gr.update(choices=model_choices, value=model_choices[0], visible=True)
             else:
@@ -383,9 +400,9 @@ with gr.Blocks() as iface:
         except Exception as e:
             logging.error(f"Error in update_model_dropdown: {str(e)}")
             return gr.update(choices=["Error"], value="Error", visible=True)
-
-    #
-    pipeline_type.change(update_model_dropdown, inputs=[pipeline_type], outputs=model_id)
+
+    # event handler for pipeline_type change
+    pipeline_type.change(update_model_dropdown, inputs=[pipeline_type], outputs=[model_id])
 
     def transcribe_with_progress(*args):
         for result in transcribe_audio(*args):
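The `.change` wiring above is standard Gradio Blocks: when the dropdown value changes, the callback runs and its `gr.update(...)` return value is applied to the output component. A self-contained sketch of the same pattern (the demo choices are made up):

```python
import gradio as gr

MODELS = {"a": ["model-a1", "model-a2"], "b": ["model-b1"]}

with gr.Blocks() as demo:
    family = gr.Dropdown(choices=list(MODELS), label="Family", value="a")
    model = gr.Dropdown(choices=MODELS["a"], label="Model")

    def on_family_change(family_choice):
        choices = MODELS[family_choice]
        # gr.update patches the target component in place
        return gr.update(choices=choices, value=choices[0])

    family.change(on_family_change, inputs=[family], outputs=[model])

# demo.launch()  # uncomment to serve
```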
@@ -399,9 +416,9 @@ with gr.Blocks() as iface:
 
     gr.Examples(
         examples=[
-
-
-
+            ["https://www.youtube.com/watch?v=daQ_hqA6HDo", "faster-batched", "cstr/whisper-large-v3-turbo-int8_float32", "int8", 16, "yt-dlp", None, None, True],
+            ["https://mcdn.podbean.com/mf/web/dir5wty678b6g4vg/HoP_453_-_The_Price_is_Right_-_Law_and_Economics_in_the_Second_Scholastic5yxzh.mp3", "faster-sequenced", "deepdml/faster-whisper-large-v3-turbo-ct2", "float16", 1, "ffmpeg", 0, 300, True],
+            ["path/to/local/audio.mp3", "transformers", "openai/whisper-large-v3", "float16", 16, "yt-dlp", 60, 180, True]
         ],
         inputs=[input_source, pipeline_type, model_id, dtype, batch_size, download_method, start_time, end_time, verbose],
     )