import os
from typing import Any

import numpy as np
import torch
import transformers

try:
    from .asr_modeling import ASRModel
except ImportError:
    from asr_modeling import ASRModel


class ForcedAligner:
    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2."""

    _bundle = None
    _model = None
    _labels = None
    _dictionary = None

    @classmethod
    def get_instance(cls, device: str = "cuda"):
        if cls._model is None:
            import torchaudio

            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            cls._model = cls._bundle.get_model().to(device)
            cls._model.eval()
            cls._labels = cls._bundle.get_labels()
            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
        return cls._model, cls._labels, cls._dictionary

    @classmethod
    def align(
        cls,
        audio: np.ndarray,
        text: str,
        sample_rate: int = 16000,
        language: str = "eng",
        batch_size: int = 16,
    ) -> list[dict]:
        """Align transcript to audio and return word-level timestamps.

        Args:
            audio: Audio waveform as a numpy array
            text: Transcript text to align
            sample_rate: Audio sample rate (default 16000)
            language: ISO-639-3 language code (default "eng"; currently unused)
            batch_size: Batch size for the alignment model (currently unused)

        Returns:
            List of dicts with 'word', 'start', and 'end' keys
        """
        import torchaudio
        from torchaudio.functional import forced_align, merge_tokens

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, labels, dictionary = cls.get_instance(device)

        # Normalize the input to a (1, num_samples) float tensor.
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy()).float()
        else:
            waveform = audio.clone().float()

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        if sample_rate != cls._bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, cls._bundle.sample_rate
            )

        waveform = waveform.to(device)

        with torch.inference_mode():
            emissions, _ = model(waveform)
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu()

        # The wav2vec2 label set is uppercase characters with "|" as the word separator.
        transcript = text.upper()

        tokens = []
        for char in transcript:
            if char in dictionary:
                tokens.append(dictionary[char])
            elif char == " ":
                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))

        if not tokens:
            return []

        targets = torch.tensor([tokens], dtype=torch.int32)

        aligned_tokens, scores = forced_align(emission.unsqueeze(0), targets, blank=0)
        token_spans = merge_tokens(aligned_tokens[0], scores[0])

        # WAV2VEC2_ASR_BASE_960H emits one frame per 320 input samples (20 ms at 16 kHz).
        frame_duration = 320 / cls._bundle.sample_rate

        # Group character spans into words, using "|" spans as word boundaries.
        words = text.split()
        word_timestamps = []
        current_word_start = None
        current_word_end = None
        word_idx = 0

        for span in token_spans:
            token_char = labels[span.token]
            if token_char == "|":
                if current_word_start is not None and word_idx < len(words):
                    word_timestamps.append({
                        "word": words[word_idx],
                        "start": current_word_start * frame_duration,
                        "end": current_word_end * frame_duration,
                    })
                    word_idx += 1
                current_word_start = None
                current_word_end = None
            else:
                if current_word_start is None:
                    current_word_start = span.start
                current_word_end = span.end

        # Flush the final word (the transcript usually does not end with a separator).
        if current_word_start is not None and word_idx < len(words):
            word_timestamps.append({
                "word": words[word_idx],
                "start": current_word_start * frame_duration,
                "end": current_word_end * frame_duration,
            })

        return word_timestamps
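
    # Hypothetical usage sketch (not exercised in this module): given a 16 kHz
    # mono float32 waveform and its transcript,
    #
    #     words = ForcedAligner.align(audio, "hello world", sample_rate=16000)
    #     # -> [{"word": "hello", "start": 0.24, "end": 0.61},
    #     #     {"word": "world", "start": 0.68, "end": 1.02}]
    #
    # the timestamp values above are illustrative, not real output.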


class SpeakerDiarizer:
    """Lazy-loaded speaker diarization using pyannote.audio."""

    _pipeline = None

    @classmethod
    def get_instance(cls, hf_token: str | None = None):
        """Get or create the diarization pipeline.

        Args:
            hf_token: HuggingFace token with access to the pyannote models.
                Can also be set via the HF_TOKEN environment variable.
        """
        if cls._pipeline is None:
            from pyannote.audio import Pipeline

            # pyannote.audio 3.x accepts `use_auth_token`; fall back to HF_TOKEN.
            cls._pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=hf_token or os.environ.get("HF_TOKEN"),
            )

            # Move the pipeline to an accelerator when one is available.
            if torch.cuda.is_available():
                cls._pipeline.to(torch.device("cuda"))
            elif torch.backends.mps.is_available():
                cls._pipeline.to(torch.device("mps"))

        return cls._pipeline

    @classmethod
    def diarize(
        cls,
        audio: np.ndarray | str,
        sample_rate: int = 16000,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        hf_token: str | None = None,
    ) -> list[dict]:
        """Run speaker diarization on audio.

        Args:
            audio: Audio waveform as a numpy array, or a path to an audio file
            sample_rate: Audio sample rate (default 16000)
            num_speakers: Exact number of speakers (if known)
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers
            hf_token: HuggingFace token for the pyannote models

        Returns:
            List of dicts with 'speaker', 'start', and 'end' keys
        """
        pipeline = cls.get_instance(hf_token)

        # pyannote expects either a file path or a dict holding a
        # (channels, samples) float tensor plus its sample rate.
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            audio_input = {"waveform": waveform, "sample_rate": sample_rate}
        else:
            audio_input = audio

        diarization_args = {}
        if num_speakers is not None:
            diarization_args["num_speakers"] = num_speakers
        if min_speakers is not None:
            diarization_args["min_speakers"] = min_speakers
        if max_speakers is not None:
            diarization_args["max_speakers"] = max_speakers

        diarization = pipeline(audio_input, **diarization_args)

        # Normalize across pyannote versions: the pipeline may return an Annotation
        # directly, an object wrapping one, or a tuple whose first item is one.
        if hasattr(diarization, "itertracks"):
            annotation = diarization
        elif hasattr(diarization, "speaker_diarization"):
            annotation = diarization.speaker_diarization
        elif isinstance(diarization, tuple):
            annotation = diarization[0]
        else:
            raise TypeError(f"Unexpected diarization output type: {type(diarization)}")

        segments = []
        for turn, _, speaker in annotation.itertracks(yield_label=True):
            segments.append({
                "speaker": speaker,
                "start": turn.start,
                "end": turn.end,
            })

        return segments
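
    # Hypothetical usage sketch: boundaries are in seconds, and speaker names are
    # opaque labels assigned by the pipeline.
    #
    #     segments = SpeakerDiarizer.diarize(audio, num_speakers=2)
    #     # -> [{"speaker": "SPEAKER_00", "start": 0.0, "end": 4.2},
    #     #     {"speaker": "SPEAKER_01", "start": 4.5, "end": 9.8}, ...]
    #
    # the values above are illustrative, not real output.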

    @classmethod
    def assign_speakers_to_words(
        cls,
        words: list[dict],
        speaker_segments: list[dict],
    ) -> list[dict]:
        """Assign speaker labels to words based on timestamp overlap.

        Args:
            words: List of word dicts with 'word', 'start', 'end' keys
            speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys

        Returns:
            The words list with a 'speaker' key added to each word
        """
        for word in words:
            word_mid = (word["start"] + word["end"]) / 2

            # Prefer the segment that contains the word's midpoint.
            best_speaker = None
            for seg in speaker_segments:
                if seg["start"] <= word_mid <= seg["end"]:
                    best_speaker = seg["speaker"]
                    break

            # Otherwise fall back to the segment whose midpoint is closest.
            if best_speaker is None and speaker_segments:
                min_dist = float("inf")
                for seg in speaker_segments:
                    seg_mid = (seg["start"] + seg["end"]) / 2
                    dist = abs(word_mid - seg_mid)
                    if dist < min_dist:
                        min_dist = dist
                        best_speaker = seg["speaker"]

            word["speaker"] = best_speaker

        return words
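
    # Hypothetical sketch combining the two helpers above:
    #
    #     words = ForcedAligner.align(audio, text)
    #     segments = SpeakerDiarizer.diarize(audio)
    #     words = SpeakerDiarizer.assign_speakers_to_words(words, segments)
    #     # each word dict now carries a "speaker" key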


class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR pipeline for audio-to-text transcription."""

    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Intercept our custom parameters before the parent class validates them."""
        kwargs.pop("return_timestamps", None)
        kwargs.pop("return_speakers", None)
        kwargs.pop("num_speakers", None)
        kwargs.pop("min_speakers", None)
        kwargs.pop("max_speakers", None)
        kwargs.pop("hf_token", None)

        return super()._sanitize_parameters(**kwargs)

    def __call__(self, inputs, **kwargs):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            hf_token: HuggingFace token for pyannote models (or set the HF_TOKEN env var)
            **kwargs: Additional arguments passed to the underlying pipeline

        Returns:
            Dict with a 'text' key, plus a 'words' key if return_timestamps=True,
            and speaker labels on the words if return_speakers=True
        """
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
            "hf_token": kwargs.pop("hf_token", None),
        }

        # Speaker assignment needs word timestamps, so diarization implies alignment.
        if return_speakers:
            return_timestamps = True

        # Keep the raw audio around so alignment/diarization can re-use it.
        if return_timestamps or return_speakers:
            self._current_audio = self._extract_audio(inputs)

        result = super().__call__(inputs, **kwargs)

        if return_timestamps and self._current_audio is not None:
            text = result.get("text", "")
            if text:
                try:
                    words = ForcedAligner.align(
                        self._current_audio["array"],
                        text,
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                    )
                    result["words"] = words
                except Exception as e:
                    result["words"] = []
                    result["timestamp_error"] = str(e)
            else:
                result["words"] = []

        if return_speakers and self._current_audio is not None:
            try:
                speaker_segments = SpeakerDiarizer.diarize(
                    self._current_audio["array"],
                    sample_rate=self._current_audio.get("sampling_rate", 16000),
                    **{k: v for k, v in diarization_params.items() if v is not None},
                )
                result["speaker_segments"] = speaker_segments

                if result.get("words"):
                    result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                        result["words"],
                        speaker_segments,
                    )
            except Exception as e:
                result["speaker_segments"] = []
                result["diarization_error"] = str(e)

        self._current_audio = None

        return result
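
    # Hypothetical result shape with both options enabled:
    #
    #     {
    #         "text": "hello world",
    #         "words": [{"word": "hello", "start": 0.2, "end": 0.6,
    #                    "speaker": "SPEAKER_00"}, ...],
    #         "speaker_segments": [{"speaker": "SPEAKER_00",
    #                               "start": 0.0, "end": 4.2}, ...],
    #     }
    #
    # On alignment/diarization failure the lists are empty and a
    # "timestamp_error" / "diarization_error" key carries the message.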

    def _extract_audio(self, inputs) -> dict | None:
        """Extract the audio array from various input formats using HF utilities."""
        from transformers.pipelines.audio_utils import ffmpeg_read

        if isinstance(inputs, dict):
            if "array" in inputs:
                return {
                    "array": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw" in inputs:
                return {
                    "array": inputs["raw"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
        elif isinstance(inputs, str):
            # Treat strings as file paths and decode via ffmpeg.
            with open(inputs, "rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, np.ndarray):
            return {"array": inputs, "sampling_rate": 16000}

        return None

    def preprocess(self, inputs, **preprocess_params):
        # Normalize {"array": ...} inputs to the {"raw": ...} form expected downstream.
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }

        for item in super().preprocess(inputs, **preprocess_params):
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        generated_ids = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        return {"tokens": generated_ids, "is_last": is_last}

    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            return super().postprocess(model_outputs, **kwargs)

        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        return {"text": text}
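

if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes ASRModel exposes the standard
    # transformers `from_pretrained` loader, and "path/to/checkpoint" and
    # "sample.wav" are placeholders, not part of this module.
    model = ASRModel.from_pretrained("path/to/checkpoint")
    pipe = ASRPipeline(model=model)

    result = pipe(
        "sample.wav",
        return_timestamps=True,
        return_speakers=True,
        num_speakers=2,
    )
    print(result["text"])
    for word in result.get("words", []):
        print(f'{word["speaker"]}: {word["word"]} [{word["start"]:.2f}-{word["end"]:.2f}]')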