File size: 3,343 Bytes
c5d4931
 
 
 
0d9ff36
c5d4931
 
 
0d9ff36
c5d4931
 
0d9ff36
 
 
c5d4931
 
0d9ff36
c5d4931
 
 
 
 
 
0d9ff36
 
 
 
 
 
 
 
 
 
 
 
 
 
c5d4931
 
0d9ff36
c5d4931
 
 
 
 
0d9ff36
 
 
 
 
 
 
 
c5d4931
0d9ff36
c5d4931
0d9ff36
c5d4931
0d9ff36
c5d4931
 
0d9ff36
c5d4931
0d9ff36
 
 
 
 
 
c5d4931
0d9ff36
 
 
c5d4931
 
0d9ff36
c5d4931
0d9ff36
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import difflib
import re
from functools import lru_cache

import gradio.components.audio as gr_audio
import torch
from transformers import pipeline


# ------------------- Utilities -------------------
def normalize_text(t: str, lower: bool = True) -> str:
    """For normalizing LLM-generated and human-generated strings.
    For LLMs, this removes extraneous quote marks and spaces."""
    # English-only normalization: lowercase, keep letters/digits/' and -
    if lower:
        t = t.lower()
    # TODO: Previously was re.sub(r"[^a-z0-9'\-]+", " ", t); discuss normalizing for LLMs too.
    t = re.sub(r"[^a-zA-Z0-9'\-.,]+", " ", t)
    t = re.sub(r"\s+", " ", t).strip()
    return t


@lru_cache(maxsize=2)
def get_asr_pipeline(model_id: str, device_preference: str) -> pipeline:
    """Cache an ASR pipeline.
    Parameters:
        model_id: String of desired ASR model.
        device_preference: String of desired device for ASR processing, "cuda", "cpu", or "auto".
    Returns:
        transformers.pipeline ASR component.
    """
    if device_preference == "cuda" and torch.cuda.is_available():
        device = 0
    elif device_preference == "auto":
        device = 0 if torch.cuda.is_available() else -1
    else:
        device = -1
    return pipeline(
        "automatic-speech-recognition",
        model=model_id,           # use English-only Whisper models (.en)
        device=device,
        chunk_length_s=30,
        return_timestamps=False,
    )

def run_asr(audio_path: gr_audio, model_id: str, device_pref: str) -> str | Exception:
    """Returns the recognized user utterance from the input audio stream.
    Parameters:
        audio_path: gradio.Audio component.
        model_id: String of desired ASR model.
        device_preference: String of desired device for ASR processing, "cuda", "cpu", or "auto".
    Returns:
        hyp_raw: Recognized user utterance.
    """
    asr = get_asr_pipeline(model_id, device_pref)
    try:
        # IMPORTANT: For English-only Whisper (.en), do NOT pass language/task args.
        result = asr(audio_path)
        hyp_raw = result["text"].strip()
    except Exception as e:
        return e
    return hyp_raw

def similarity_and_diff(ref_tokens: list, hyp_tokens: list) -> (float, list[str, int, int, int]):
    """
    Returns:
        ratio: Similarity ratio (0..1).
        opcodes: List of differences between target and recognized user utterance.
    """
    sm = difflib.SequenceMatcher(a=ref_tokens, b=hyp_tokens)
    ratio = sm.ratio()
    opcodes = sm.get_opcodes()
    return ratio, opcodes

class SentenceMatcher:
    """Class for keeping track of (target sentence, user utterance) match features."""
    def __init__(self, target_sentence, user_transcript, pass_threshold):
        self.target_sentence: str = target_sentence
        self.user_transcript: str = user_transcript
        self.pass_threshold: float = pass_threshold
        self.target_tokens: list = normalize_text(target_sentence).split()
        self.user_tokens: list = normalize_text(user_transcript).split()
        self.ratio: float
        self.alignments: list
        self.ratio, self.alignments = similarity_and_diff(self.target_tokens,
                                                          self.user_tokens)
        self.passed: bool = self.ratio >= self.pass_threshold