| """ | |
| File: utils.py | |
| Author: Dmitry Ryumin, Maxim Markitantov, Elena Ryumina, Anastasia Dvoynikova, and Alexey Karpov | |
| Description: Utility functions. | |
| License: MIT License | |
| """ | |
| import time | |
| import torch | |
| import os | |
| import subprocess | |
| import bisect | |
| import re | |
| import requests | |
| from torchvision import transforms | |
| from PIL import Image | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
| from pathlib import Path | |
| from contextlib import suppress | |
| from urllib.parse import urlparse | |
| from contextlib import ContextDecorator | |
| from typing import Callable | |


class Timer(ContextDecorator):
    """Context manager for measuring code execution time"""

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.execution_time = f"{self.end - self.start:.2f} seconds"

    def __str__(self):
        return self.execution_time
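
# Usage sketch for Timer (`do_work` is a placeholder, not part of this module):
#
#     with Timer() as timer:
#         do_work()
#     print(timer)  # e.g. "1.23 seconds"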


def load_model(
    model_url: str, folder_path: str, force_reload: bool = False
) -> str | None:
    """Downloads a model file into `folder_path` and returns its local path (None on failure)."""
    file_name = Path(urlparse(model_url).path).name
    file_path = Path(folder_path) / file_name

    if file_path.exists() and not force_reload:
        return str(file_path)

    with suppress(Exception), requests.get(model_url, stream=True) as response:
        response.raise_for_status()  # do not write an HTTP error page to disk
        file_path.parent.mkdir(parents=True, exist_ok=True)
        with file_path.open("wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)

        return str(file_path)

    return None
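
# Usage sketch for load_model (the URL below is a placeholder, not a real checkpoint):
#
#     weights_path = load_model("https://example.com/models/model.pth", "checkpoints")
#     if weights_path is None:
#         raise RuntimeError("Model download failed")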


def readetect_speech(
    file_path: str,
    read_audio: Callable,
    get_speech_timestamps: Callable,
    vad_model: torch.jit.ScriptModule,
    sr: int = 16000,
) -> tuple[torch.Tensor, list[dict]]:
    wav = read_audio(file_path, sampling_rate=sr)
    # get speech timestamps from the full audio file
    speech_timestamps = get_speech_timestamps(wav, vad_model, sampling_rate=sr)
    return wav, speech_timestamps
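
# Sketch of wiring readetect_speech to Silero VAD via torch.hub (assumes the
# standard silero-vad hub utilities; "audio.wav" is a placeholder path):
#
#     vad_model, vad_utils = torch.hub.load("snakers4/silero-vad", "silero_vad")
#     get_speech_timestamps, _, read_audio, _, _ = vad_utils
#     wav, speech_timestamps = readetect_speech(
#         "audio.wav", read_audio, get_speech_timestamps, vad_model
#     )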


def calculate_mode(series):
    """Return the most frequent value of a pandas Series (None if the Series is empty)."""
    mode = series.mode()
    return mode[0] if not mode.empty else None


def pth_processing(fp):
    class PreprocessInput(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = x.to(torch.float32)
            # Reorder channels (RGB -> BGR) and subtract per-channel means
            x = torch.flip(x, dims=(0,))
            x[0, :, :] -= 91.4953
            x[1, :, :] -= 103.8827
            x[2, :, :] -= 131.0912
            return x

    def get_img_torch(img, target_size=(224, 224)):
        transform = transforms.Compose([transforms.PILToTensor(), PreprocessInput()])
        img = img.resize(target_size, Image.Resampling.NEAREST)
        img = transform(img)
        img = torch.unsqueeze(img, 0)
        return img

    return get_img_torch(fp)
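
# Sketch: preprocess a PIL image into a normalized tensor ("face.jpg" is a placeholder):
#
#     face_tensor = pth_processing(Image.open("face.jpg"))  # (1, 3, 224, 224) for an RGB input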


def get_idx_frames_in_windows(
    frames: list[int], window: dict, fps: int, sr: int = 16000
) -> list[int]:
    # Window bounds are given in audio samples; convert them to video frame
    # numbers (samples * fps / sr) and keep the indices of frames that fall inside
    frames_in_windows = [
        idx
        for idx, frame in enumerate(frames)
        if window["start"] * fps / sr <= frame < window["end"] * fps / sr
    ]
    return frames_in_windows
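
# Sketch: with fps=25 and sr=16000, a window of samples {"start": 0, "end": 32000}
# covers frame numbers 0..49, so only the matching entries of `frames` are kept:
#
#     get_idx_frames_in_windows([0, 30, 80], {"start": 0, "end": 32000}, fps=25)
#     # -> [0, 1]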


# Maxim code
def slice_audio(
    start_time: float,
    end_time: float,
    win_max_length: float,
    win_shift: float,
    win_min_length: float,
) -> list[dict]:
    """Slices audio into windows

    Args:
        start_time (float): Start time of audio
        end_time (float): End time of audio
        win_max_length (float): Window max length
        win_shift (float): Window shift
        win_min_length (float): Window min length

    Returns:
        list[dict]: List of dicts with timings, e.g.: {'start': 0, 'end': 12}
    """
    if end_time < start_time:
        return []
    elif (end_time - start_time) > win_max_length:
        timings = []
        while start_time < end_time:
            end_time_chunk = start_time + win_max_length
            if end_time_chunk < end_time:
                timings.append({"start": start_time, "end": end_time_chunk})
            elif end_time_chunk == end_time:  # if the tail is exactly `win_max_length` seconds
                timings.append({"start": start_time, "end": end_time_chunk})
                break
            else:  # if the tail is shorter than `win_max_length` seconds
                if (
                    end_time - start_time < win_min_length
                ):  # if the tail is shorter than `win_min_length` seconds
                    break

                timings.append({"start": start_time, "end": end_time})
                break

            start_time += win_shift

        return timings
    else:
        return [{"start": start_time, "end": end_time}]


def convert_video_to_audio(file_path: str, sr: int = 16000) -> str:
    path_save = os.path.splitext(file_path)[0] + ".wav"
    if not os.path.exists(path_save):
        # Quote the paths so files with spaces do not break the shell command
        ffmpeg_command = (
            f'ffmpeg -y -i "{file_path}" -async 1 -vn -acodec pcm_s16le -ar {sr} "{path_save}"'
        )
        subprocess.call(ffmpeg_command, shell=True)

    return path_save


def find_nearest_frames(target_frames, all_frames):
    # `all_frames` is assumed to be sorted in ascending order (required by bisect)
    nearest_frames = []

    for frame in target_frames:
        pos = bisect.bisect_left(all_frames, frame)

        if pos == 0:
            nearest_frame = all_frames[0]
        elif pos == len(all_frames):
            nearest_frame = all_frames[-1]
        else:
            before = all_frames[pos - 1]
            after = all_frames[pos]
            nearest_frame = before if frame - before <= after - frame else after

        nearest_frames.append(nearest_frame)

    return nearest_frames
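
# Sketch: for each target frame, the closest available frame is returned
# (ties go to the earlier frame):
#
#     find_nearest_frames([3, 7], [0, 5, 10])
#     # -> [5, 5]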


def find_intersections(
    x: list[dict], y: list[dict], min_length: float = 0
) -> list[dict]:
    """Find intersections of two lists of interval dicts, preserving the structure of `x` and adding intersection info

    Args:
        x (list[dict]): First list of intervals
        y (list[dict]): Second list of intervals
        min_length (float, optional): Minimum length of intersection. Defaults to 0.

    Returns:
        list[dict]: Windows with intersections, maintaining the structure of `x` and indicating intersection presence.
    """
    timings = []
    j = 0

    for interval_x in x:
        original_start = int(interval_x["start"])
        original_end = int(interval_x["end"])

        intersections_found = False

        # Skip any intervals in `y` that end before the current interval in `x` starts
        while j < len(y) and y[j]["end"] < original_start:
            j += 1

        # Check all overlapping intervals in `y`
        temp_j = j  # Temporary pointer to check intersections within `y` for the current `x`

        while temp_j < len(y) and y[temp_j]["start"] <= original_end:
            # Calculate the intersection between the current intervals of `x` and `y`
            intersection_start = max(original_start, y[temp_j]["start"])
            intersection_end = min(original_end, y[temp_j]["end"])

            if (
                intersection_start < intersection_end
                and (intersection_end - intersection_start) >= min_length
            ):
                timings.append(
                    {
                        "original_start": original_start,
                        "original_end": original_end,
                        "start": intersection_start,
                        "end": intersection_end,
                        "speech": True,
                    }
                )
                intersections_found = True

            temp_j += 1  # Move to the next interval in `y` for further intersections

        # If no intersections were found, add the interval with `speech` set to False
        if not intersections_found:
            timings.append(
                {
                    "original_start": original_start,
                    "original_end": original_end,
                    "start": None,
                    "end": None,
                    "speech": False,
                }
            )

    return timings
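
# Sketch: intersecting one audio window with one VAD speech segment
# (interval bounds in samples, chosen for illustration):
#
#     find_intersections(
#         x=[{"start": 0, "end": 64000}],
#         y=[{"start": 16000, "end": 80000}],
#     )
#     # -> [{'original_start': 0, 'original_end': 64000,
#     #      'start': 16000, 'end': 64000, 'speech': True}]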


# Anastasia code
class ASRModel:
    def __init__(self, checkpoint_path: str, device: torch.device):
        self.processor = WhisperProcessor.from_pretrained(checkpoint_path)
        self.model = WhisperForConditionalGeneration.from_pretrained(
            checkpoint_path
        ).to(device)
        self.device = device
        self.model.config.forced_decoder_ids = None

    def __call__(
        self, sample: torch.Tensor, audio_windows: list[dict], sr: int = 16000
    ) -> tuple:
        texts = []

        # Transcribe each audio window separately
        for t in range(len(audio_windows)):
            input_features = self.processor(
                sample[audio_windows[t]["start"] : audio_windows[t]["end"]],
                sampling_rate=sr,
                return_tensors="pt",
            ).input_features
            predicted_ids = self.model.generate(input_features.to(self.device))
            transcription = self.processor.batch_decode(
                predicted_ids, skip_special_tokens=False
            )
            # Strip Whisper special tokens, keeping only the decoded text
            curr_text = re.findall(r"> ([^<>]+)", transcription[0])
            if curr_text:
                texts.append(curr_text)
            else:
                texts.append([""])

        # Transcribe the full audio as well (used for drawing)
        input_features = self.processor(
            sample, sampling_rate=sr, return_tensors="pt"
        ).input_features
        predicted_ids = self.model.generate(input_features.to(self.device))
        transcription = self.processor.batch_decode(
            predicted_ids, skip_special_tokens=False
        )
        total_text = re.findall(r"> ([^<>]+)", transcription[0])

        return texts, total_text
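
# Usage sketch for ASRModel (the checkpoint name and window bounds are placeholders;
# `wav` is a 16 kHz waveform tensor, window bounds are sample indices):
#
#     asr = ASRModel("openai/whisper-base", device=torch.device("cpu"))
#     windows = [{"start": 0, "end": 16000 * 4}]
#     texts, total_text = asr(wav, windows)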


def convert_webm_to_mp4(input_file):
    path_save = os.path.splitext(input_file)[0] + ".mp4"
    # Quote the paths so files with spaces do not break the shell command
    ff_video = f'ffmpeg -i "{input_file}" -c:v libx264 -c:a aac -strict experimental "{path_save}"'
    subprocess.run(ff_video, shell=True, check=True, capture_output=True, text=True)

    return path_save