import sys
import logging
import io
import soundfile as sf
import math

try:
    import torch
except ImportError:
    torch = None

from typing import List

import numpy as np

from timed_objects import ASRToken

logger = logging.getLogger(__name__)


class ASRBase:
    sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
               # "" for faster-whisper because it emits the spaces when needed)

    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
        self.logfile = logfile
        self.transcribe_kargs = {}
        if lan == "auto":
            self.original_language = None
        else:
            self.original_language = lan
        self.model = self.load_model(modelsize, cache_dir, model_dir)

    def with_offset(self, offset: float) -> ASRToken:
        # This method is kept for compatibility (typically you will use ASRToken.with_offset)
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

    def load_model(self, modelsize, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")


class WhisperTimestampedASR(ASRBase):
    """Uses whisper_timestamped as the backend."""

    sep = " "

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading whisper_timestamped model")
        import whisper
        import whisper_timestamped
        from whisper_timestamped import transcribe_timestamped

        self.transcribe_timestamped = transcribe_timestamped
        if model_dir is not None:
            logger.debug("ignoring model_dir, not implemented")
        return whisper.load_model(modelsize, download_root=cache_dir)

    def transcribe(self, audio, init_prompt=""):
        result = self.transcribe_timestamped(
            self.model,
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            verbose=None,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return result

    def ts_words(self, r) -> List[ASRToken]:
        """Convert the whisper_timestamped result to a list of ASRToken objects."""
        tokens = []
        for segment in r["segments"]:
            for word in segment["words"]:
                token = ASRToken(word["start"], word["end"], word["text"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [segment["end"] for segment in res["segments"]]

    def use_vad(self):
        self.transcribe_kargs["vad"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"
    def detect_language(self, audio_file_path):
        """
        Detect the language of the audio using Whisper's language detection.

        Args:
            audio_file_path (str): Path to the audio file.

        Returns:
            tuple: (detected_language, confidence, probabilities)
                - detected_language (str): The detected language code
                - confidence (float): Confidence score for the detected language
                - probabilities (dict): Dictionary of language probabilities
        """
        import whisper

        try:
            # Pad or trim the audio to the 30-second window Whisper expects
            audio = whisper.load_audio(audio_file_path)
            audio = whisper.pad_or_trim(audio)
            # Create a mel spectrogram matching the loaded model's expected number of mel bins
            mel = whisper.log_mel_spectrogram(audio, n_mels=self.model.dims.n_mels).to(self.model.device)
            # Detect language
            _, probs = self.model.detect_language(mel)
            detected_lang = max(probs, key=probs.get)
            confidence = probs[detected_lang]
            return detected_lang, confidence, probs
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise
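

# --- Usage sketch (illustrative addition, not part of the original module) -----
# A minimal sketch of how the whisper_timestamped backend might be driven.
# The "base" model size and the silent 16 kHz placeholder buffer are assumptions
# made for illustration only; adapt them to your setup.
def _example_whisper_timestamped_usage():
    asr = WhisperTimestampedASR(lan="en", modelsize="base")
    audio = np.zeros(16000 * 5, dtype=np.float32)  # 5 s of silence as a stand-in chunk
    result = asr.transcribe(audio, init_prompt="")
    tokens = asr.ts_words(result)        # List[ASRToken] with word-level timestamps
    ends = asr.segments_end_ts(result)   # segment end timestamps in seconds
    return tokens, ends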


class FasterWhisperASR(ASRBase):
    """Uses faster-whisper as the backend."""

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading faster-whisper model")
        from faster_whisper import WhisperModel

        if model_dir is not None:
            logger.debug(
                f"Loading whisper model from model_dir {model_dir}. "
                f"modelsize and cache_dir parameters are not used."
            )
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = modelsize
        else:
            raise ValueError("Either modelsize or model_dir must be set")
        device = "cuda" if torch and torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "float32"
        print(f"Loading whisper model {model_size_or_path} on {device} with compute type {compute_type}")
        model = WhisperModel(
            model_size_or_path,
            device=device,
            compute_type=compute_type,
            download_root=cache_dir,
        )
        return model
    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
        segments, info = self.model.transcribe(
            audio,
            language=self.original_language,  # None when lan == "auto", letting faster-whisper detect
            initial_prompt=init_prompt,
            beam_size=5,
            word_timestamps=True,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return list(segments)
    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                tokens.append(token)
        return tokens

    def segments_end_ts(self, segments) -> List[float]:
        return [segment.end for segment in segments]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"
    def detect_language(self, audio_file_path):
        """
        Detect the language of the audio using faster-whisper's language detection.

        Args:
            audio_file_path: Path to the audio file.

        Returns:
            tuple: (detected_language, confidence, probabilities)
                - detected_language (str): The detected language code
                - confidence (float): Confidence score for the detected language
                - probabilities (dict): Dictionary of language probabilities
        """
        from faster_whisper.audio import decode_audio

        try:
            audio = decode_audio(audio_file_path, sampling_rate=self.model.feature_extractor.sampling_rate)
            # Calculate the number of 30-second segments covered by the audio
            audio_duration = len(audio) / self.model.feature_extractor.sampling_rate
            segments_num = max(1, int(audio_duration / 30))  # at least 1 segment
            logger.info(f"Audio duration: {audio_duration:.2f}s, using {segments_num} segments for language detection")
            # Use faster-whisper's detect_language method over all available segments
            language, language_probability, all_language_probs = self.model.detect_language(
                audio=audio,
                vad_filter=False,  # disable VAD for language detection
                language_detection_segments=segments_num,
                language_detection_threshold=0.5,  # default threshold
            )
            # Convert the list of (language, probability) tuples to a dict for a consistent return format
            probs = {lang: prob for lang, prob in all_language_probs}
            return language, language_probability, probs
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise
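

# --- Usage sketch (illustrative addition, not part of the original module) -----
# A minimal sketch of the faster-whisper backend on a 16 kHz mono float32 chunk.
# The "large-v3" model size and the idea of calling it per streaming chunk are
# assumptions for illustration; adapt them to your pipeline.
def _example_faster_whisper_usage(audio_chunk: np.ndarray) -> List[ASRToken]:
    asr = FasterWhisperASR(lan="auto", modelsize="large-v3")
    asr.use_vad()                      # enables faster-whisper's vad_filter
    segments = asr.transcribe(audio_chunk, init_prompt="")
    return asr.ts_words(segments)      # drops segments with no_speech_prob > 0.9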


class MLXWhisper(ASRBase):
    """Uses MLX Whisper optimized for Apple Silicon."""

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading mlx whisper model")
        from mlx_whisper.transcribe import ModelHolder, transcribe
        import mlx.core as mx

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
        else:
            raise ValueError("Either modelsize or model_dir must be set")
        self.model_size_or_path = model_size_or_path
        dtype = mx.float16
        ModelHolder.get_model(model_size_or_path, dtype)
        return transcribe

    def translate_model_name(self, model_name):
        model_mapping = {
            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
            "tiny": "mlx-community/whisper-tiny-mlx",
            "base.en": "mlx-community/whisper-base.en-mlx",
            "base": "mlx-community/whisper-base-mlx",
            "small.en": "mlx-community/whisper-small.en-mlx",
            "small": "mlx-community/whisper-small-mlx",
            "medium.en": "mlx-community/whisper-medium.en-mlx",
            "medium": "mlx-community/whisper-medium-mlx",
            "large-v1": "mlx-community/whisper-large-v1-mlx",
            "large-v2": "mlx-community/whisper-large-v2-mlx",
            "large-v3": "mlx-community/whisper-large-v3-mlx",
            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "large": "mlx-community/whisper-large-mlx",
        }
        mlx_model_path = model_mapping.get(model_name)
        if mlx_model_path:
            return mlx_model_path
        raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
    def transcribe(self, audio, init_prompt=""):
        if self.transcribe_kargs:
            logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
        segments = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
        )
        return segments.get("segments", [])

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.get("no_speech_prob", 0) > 0.9:
                continue
            for word in segment.get("words", []):
                token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s["end"] for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

    def detect_language(self, audio):
        raise NotImplementedError("MLX Whisper does not support language detection.")


class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for transcription."""

    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
        print("Loading openai api model")
        self.logfile = logfile
        self.modelname = "whisper-1"
        self.original_language = None if lan == "auto" else lan
        self.response_format = "verbose_json"
        self.temperature = temperature
        self.load_model()
        self.use_vad_opt = False
        self.task = "transcribe"

    def load_model(self, *args, **kwargs):
        from openai import OpenAI

        self.client = OpenAI()
        self.transcribed_seconds = 0

    def ts_words(self, segments) -> List[ASRToken]:
        """
        Convert OpenAI API response words into ASRToken objects while
        optionally skipping words that fall into no-speech segments.
        """
        no_speech_segments = []
        if self.use_vad_opt:
            for segment in segments.segments:
                if segment.no_speech_prob > 0.8:
                    no_speech_segments.append((segment.start, segment.end))
        tokens = []
        for word in segments.words:
            start = word.start
            end = word.end
            if any(s[0] <= start <= s[1] for s in no_speech_segments):
                continue
            tokens.append(ASRToken(start, end, word.word))
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s.end for s in res.words]

    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
        buffer = io.BytesIO()
        buffer.name = "temp.wav"
        sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
        buffer.seek(0)
        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
        params = {
            "model": self.modelname,
            "file": buffer,
            "response_format": self.response_format,
            "temperature": self.temperature,
            "timestamp_granularities": ["word", "segment"],
        }
        if self.task != "translate" and self.original_language:
            params["language"] = self.original_language
        if prompt:
            params["prompt"] = prompt
        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
        transcript = proc.create(**params)
        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
        return transcript

    def use_vad(self):
        self.use_vad_opt = True

    def set_translate_task(self):
        self.task = "translate"
    def detect_language(self, audio):
        raise NotImplementedError("The OpenAI API backend does not support language detection.")
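

# --- Backend selection sketch (illustrative addition, not part of the original module) ---
# A hypothetical factory showing that all four backends share the ASRBase interface.
# The function name `create_asr_backend` and the string keys are assumptions made
# for illustration, not an API defined elsewhere in this project.
def create_asr_backend(backend, lan="auto", modelsize=None):
    backends = {
        "whisper_timestamped": WhisperTimestampedASR,
        "faster-whisper": FasterWhisperASR,
        "mlx-whisper": MLXWhisper,
        "openai-api": OpenaiApiASR,
    }
    if backend not in backends:
        raise ValueError(f"Unknown backend: {backend!r}")
    if backend == "openai-api":
        return OpenaiApiASR(lan=lan)  # the API backend takes no local model size
    return backends[backend](lan, modelsize=modelsize)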