"""Gradio demo for NVIDIA Parakeet-TDT ASR: batch file transcription plus
chunked microphone streaming, both decoded with a MALSD batch beam search."""

from __future__ import annotations

import copy
import logging
import os
import uuid
from typing import Dict, List, Optional, Tuple

# Quiet third-party progress bars and tokenizer thread chatter in the logs.
os.environ.setdefault("TQDM_DISABLE", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Imported after the environment tweaks above so they take effect.
import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio

from nemo.collections.asr.models import ASRModel
from nemo.utils import logging as nemo_logging
from omegaconf import OmegaConf

# Configuration, overridable via environment variables.
MODEL_NAME = os.environ.get("PARAKEET_MODEL", "nvidia/parakeet-tdt-0.6b-v3")
TARGET_SR = 16_000
BEAM_SIZE = int(os.environ.get("PARAKEET_BEAM_SIZE", "16"))
OFFLINE_BATCH = int(os.environ.get("PARAKEET_BATCH", "8"))
CHUNK_S = float(os.environ.get("PARAKEET_CHUNK_S", "2.0"))
FLUSH_PAD_S = float(os.environ.get("PARAKEET_FLUSH_PAD_S", "2.0"))
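
# Example invocation overriding the defaults above (illustrative; the script
# name is a placeholder and any PARAKEET_* variable can be combined):
#
#   PARAKEET_BEAM_SIZE=8 PARAKEET_CHUNK_S=1.0 LOG_LEVEL=DEBUG python app.py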

LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()

logger = logging.getLogger("parakeet_app")
logger.setLevel(getattr(logging, LOG_LEVEL, logging.INFO))
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s: %(message)s"))
logger.handlers = [_handler]
logger.propagate = False

# Silence NeMo's verbose internal logging; this app logs through `logger` only.
nemo_logging.setLevel(logging.ERROR)
logging.getLogger("nemo").setLevel(logging.ERROR)
logging.getLogger("nemo.collections.asr").setLevel(logging.ERROR)

# Inference only: disable autograd globally.
torch.set_grad_enabled(False)


def to_mono_np(x: np.ndarray) -> np.ndarray:
    """Downmix a (samples, channels) array to mono float32."""
    if x.ndim == 2:
        x = x.mean(axis=1)
    return x.astype(np.float32, copy=False)


class ResamplerCache:
    """Caches one torchaudio Resample transform per source sample rate."""

    def __init__(self):
        self._cache: Dict[int, torchaudio.transforms.Resample] = {}

    def resample(self, wav: np.ndarray, src_sr: int) -> np.ndarray:
        if src_sr == TARGET_SR:
            return wav
        if src_sr not in self._cache:
            logger.debug(f"create_resampler src_sr={src_sr} -> {TARGET_SR}")
            self._cache[src_sr] = torchaudio.transforms.Resample(orig_freq=src_sr, new_freq=TARGET_SR)
        t = torch.from_numpy(wav)
        if t.ndim == 1:
            t = t.unsqueeze(0)
        y = self._cache[src_sr](t)
        return y.squeeze(0).numpy()


RESAMPLER = ResamplerCache()


def load_mono16k(path: str) -> np.ndarray:
    """Load any audio file and convert it to mono float32 at 16 kHz."""
    try:
        wav, sr = sf.read(path, dtype="float32", always_2d=True)
        wav = wav.mean(axis=1).astype(np.float32, copy=False)
        return RESAMPLER.resample(wav, sr)
    except Exception:
        # Fall back to torchaudio for formats soundfile cannot decode.
        wav_t, sr = torchaudio.load(path)
        if wav_t.dtype != torch.float32:
            wav_t = wav_t.float()
        wav = wav_t.mean(dim=0).numpy()
        return RESAMPLER.resample(wav, int(sr))
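
# For instance, a 3-second stereo 44.1 kHz clip (hypothetical file) comes back
# as a mono float32 array of shape (48000,):
#
#   wav = load_mono16k("clip.flac")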


class ParakeetManager:
    """Owns the Parakeet model and exposes offline and streaming transcription."""

    def __init__(self, device: str = "cpu"):
        self.device = torch.device(device)
        logger.info(f"loading_model name={MODEL_NAME} device={self.device}")
        self.model: ASRModel = ASRModel.from_pretrained(model_name=MODEL_NAME)
        self.model.to(self.device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

        # Snapshot the pristine decoding config so the beam strategy can always
        # be rebuilt from a clean base. NeMo keeps it at `model.cfg.decoding`;
        # `model.decoder.decoder.cfg` is the prediction network's module
        # config, not a decoding config, and is not a valid base here.
        self._base_decoding = copy.deepcopy(self.model.cfg.decoding)

        self._set_malsd_beam()
        logger.info(f"model_loaded strategy=malsd_batch beam_size={BEAM_SIZE}")

    def _set_malsd_beam(self):
        cfg = copy.deepcopy(self._base_decoding)
        # Unlock the struct so keys absent from the base config can be added.
        OmegaConf.set_struct(cfg, False)
        cfg.strategy = "malsd_batch"
        cfg.beam = OmegaConf.create({
            "beam_size": BEAM_SIZE,
            "return_best_hypothesis": True,
            "score_norm": True,
            "allow_cuda_graphs": False,
            "max_symbols_per_step": 10,
        })
        cfg["loop_labels"] = True
        cfg["fused_batch_size"] = -1
        cfg["compute_timestamps"] = False
        if hasattr(cfg, "greedy"):
            cfg.greedy.use_cuda_graph_decoder = False
        self.model.change_decoding_strategy(cfg)
        logger.info("decoding_set strategy=malsd_batch loop_labels=True")
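
    # If a faster baseline is ever needed, the same mechanism can switch the
    # model to greedy decoding (a sketch; exact keys depend on the installed
    # NeMo version):
    #
    #   cfg = copy.deepcopy(self._base_decoding)
    #   cfg.strategy = "greedy_batch"
    #   self.model.change_decoding_strategy(cfg)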

    def _transcribe(self, items: List, *, partial=None):
        with torch.inference_mode():
            return self.model.transcribe(
                items,
                batch_size=1 if len(items) == 1 else OFFLINE_BATCH,
                num_workers=0,
                return_hypotheses=True,
                partial_hypothesis=partial,
            )

    def transcribe_files(self, paths: List[str]):
        n = len(paths) if paths else 0
        logger.info(f"files_run start count={n} batch={OFFLINE_BATCH}")
        if not paths:
            return []
        arrays = [load_mono16k(p) for p in paths]
        out = self._transcribe(arrays, partial=None)
        results = []
        for p, o in zip(paths, out):
            # transcribe() may yield strings, Hypothesis objects, or lists of either.
            h = o[0] if isinstance(o, list) and o else o
            text = h if isinstance(h, str) else getattr(h, "text", "")
            results.append({"path": p, "text": text})
        logger.info("files_run ok")
        return results
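
    # Illustrative call and result shape (paths are hypothetical):
    #
    #   MANAGER.transcribe_files(["a.wav", "b.flac"])
    #   -> [{"path": "a.wav", "text": "..."}, {"path": "b.flac", "text": "..."}]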

    def stream_step(self, audio_16k: np.ndarray, prev_hyp) -> object:
        """Decode one chunk, continuing from the previous partial hypothesis."""
        out = self._transcribe([audio_16k], partial=[prev_hyp] if prev_hyp is not None else None)
        return out[0][0] if isinstance(out[0], list) else out[0]


class StreamingSession:
    """Buffers incoming mic audio and decodes it in fixed-size chunks."""

    def __init__(self, manager: ParakeetManager, chunk_s: float, flush_pad_s: float):
        self.mgr = manager
        self.chunk_s = chunk_s
        self.flush_pad_s = flush_pad_s
        self.hyp = None
        self.pending = np.zeros(0, dtype=np.float32)
        self.text = ""
        logger.info(f"mic_reset chunk={self.chunk_s}s flush_pad={self.flush_pad_s}s")

    def add_audio(self, audio: np.ndarray, src_sr: int):
        mono = to_mono_np(audio)
        res = RESAMPLER.resample(mono, src_sr)
        self.pending = np.concatenate([self.pending, res]) if self.pending.size else res
        self._drain()

    def _drain(self):
        chunk_len = int(self.chunk_s * TARGET_SR)
        while self.pending.size >= chunk_len:
            chunk = self.pending[:chunk_len]
            self.pending = self.pending[chunk_len:]
            try:
                self.hyp = self.mgr.stream_step(chunk, self.hyp)
                self.text = getattr(self.hyp, "text", self.text)
            except Exception:
                logger.exception("mic_step failed")
                break

    def flush(self) -> str:
        if self.pending.size:
            # Pad the tail with silence so the decoder can emit the final tokens.
            pad = np.zeros(int(self.flush_pad_s * TARGET_SR), dtype=np.float32)
            final = np.concatenate([self.pending, pad])
            try:
                self.hyp = self.mgr.stream_step(final, self.hyp)
                self.text = getattr(self.hyp, "text", self.text)
            except Exception:
                logger.exception("mic_flush failed")
        self.pending = np.zeros(0, dtype=np.float32)
        return self.text
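
# A StreamingSession can also be driven without Gradio, e.g. to replay a file
# through the chunked path (a sketch; "sample.wav" is a placeholder):
#
#   sess = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S)
#   sess.add_audio(load_mono16k("sample.wav"), TARGET_SR)
#   print(sess.flush())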


SESS: Dict[str, StreamingSession] = {}


def _new_session_id() -> str:
    return uuid.uuid4().hex


MANAGER = ParakeetManager(device="cpu")


def _parse_gr_audio(x) -> Tuple[np.ndarray, int]:
    """Normalize the payload Gradio hands to stream callbacks into (wav, sr)."""
    if x is None:
        return np.zeros(0, dtype=np.float32), TARGET_SR
    if isinstance(x, tuple) and len(x) == 2:
        sr = int(x[0])
        arr = np.asarray(x[1])
        if np.issubdtype(arr.dtype, np.integer):
            # Mic streams typically arrive as int16 PCM; scale to [-1.0, 1.0].
            arr = arr.astype(np.float32) / 32768.0
        else:
            arr = arr.astype(np.float32, copy=False)
        return arr, sr
    if isinstance(x, dict) and "data" in x and "sampling_rate" in x:
        arr = np.array(x["data"], dtype=np.float32)
        sr = int(x["sampling_rate"])
        return arr, sr
    if isinstance(x, np.ndarray):
        return x.astype(np.float32, copy=False), TARGET_SR
    logger.error(f"unsupported_gr_audio_payload type={type(x)}")
    raise ValueError("Unsupported audio payload")
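
# Payload forms this normalizer handles (shapes are typical; exact types can
# vary across Gradio versions):
#
#   (16000, np.ndarray)                        # (sample_rate, samples) tuple
#   {"sampling_rate": 16000, "data": [...]}    # dict form
#   np.ndarray                                 # bare samples, assumed 16 kHz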


def mic_step(audio_chunk, sess_id: Optional[str]):
    # Lazily create a session on the first stream event.
    if not sess_id or sess_id not in SESS:
        sess_id = _new_session_id()
        SESS[sess_id] = StreamingSession(MANAGER, CHUNK_S, FLUSH_PAD_S)
    sess = SESS[sess_id]
    try:
        wav, sr = _parse_gr_audio(audio_chunk)
    except Exception:
        logger.exception("mic_parse failed")
        return sess_id, sess.text
    if wav.size:
        sess.add_audio(wav, sr)
    return sess_id, sess.text


def mic_flush(sess_id: Optional[str]):
    if not sess_id or sess_id not in SESS:
        return None, ""
    text = SESS[sess_id].flush()
    # Drop the finished session so SESS does not grow without bound.
    del SESS[sess_id]
    logger.info("mic_flush ok")
    return None, text


def files_run(files):
    n = len(files) if files else 0
    logger.info(f"files_ui start count={n}")
    if not files:
        return []
    paths: List[str] = []
    for f in files:
        if isinstance(f, str):
            paths.append(f)
        elif hasattr(f, "name"):
            paths.append(f.name)
    try:
        results = MANAGER.transcribe_files(paths)
    except Exception:
        logger.exception("files_run failed")
        raise
    table = [[os.path.basename(r["path"]), r["text"]] for r in results]
    logger.info("files_ui ok")
    return table


with gr.Blocks(title="Parakeet-TDT v3 (Unified MALSD Beam)") as demo:
    with gr.Tab("Mic"):
        mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, label="Speak")
        text_out = gr.Textbox(label="Transcript", lines=8)
        flush_btn = gr.Button("Flush")
        state_id = gr.State()
        mic.stream(mic_step, inputs=[mic, state_id], outputs=[state_id, text_out])
        flush_btn.click(mic_flush, inputs=[state_id], outputs=[state_id, text_out])

    with gr.Tab("Files"):
        files = gr.File(file_count="multiple", type="filepath", label="Upload audio files")
        run_btn = gr.Button("Run")
        results_table = gr.Dataframe(
            headers=["file", "text"],
            label="Results",
            row_count=(0, "dynamic"),
            col_count=(2, "fixed"),
        )
        run_btn.click(files_run, inputs=[files], outputs=[results_table])


if __name__ == "__main__":
    demo.queue().launch(ssr_mode=False)