"""Shared inference helpers for the aud2seq demo. This module mirrors the chat-style data layout used during training so that both the Gradio app and CLI tooling can invoke the fine-tuned Qwen2.5 Omni model in a consistent way. """ from __future__ import annotations import json from dataclasses import dataclass from functools import lru_cache from pathlib import Path from typing import Any, Dict, List, Optional import logging import torch from peft import AutoPeftModelForCausalLM from qwen_omni_utils import process_mm_info from transformers import Qwen2_5OmniProcessor ADAPTER_PATH = (Path(__file__).resolve().parent / "model" / "checkpoint-22000").resolve() DEFAULT_SYSTEM_PROMPT = "You are an audio segmentation assistant. For each turn, return JSON with the requested keys." DEFAULT_INSTRUCTION = "General note: The segments should be relatively concise (1-8 seconds). To begin, respond with a JSON object containing `start_time` and `end_time` for the initial segment. Return the JSON only, with no extra commentary." USE_AUDIO_IN_VIDEO = False @dataclass class InferenceResult: """Container for model generations.""" text: str parsed: Optional[Any] conversation: List[Dict[str, Any]] @lru_cache(maxsize=1) def get_processor() -> Qwen2_5OmniProcessor: return Qwen2_5OmniProcessor.from_pretrained( str(ADAPTER_PATH), trust_remote_code=True, ) @lru_cache(maxsize=1) def get_model() -> AutoPeftModelForCausalLM: model = AutoPeftModelForCausalLM.from_pretrained( str(ADAPTER_PATH), torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) model.eval() return model def build_conversation( audio_path: str, instruction: Optional[str] = None, *, system_prompt: Optional[str] = None, ) -> List[Dict[str, Any]]: prompt = system_prompt.strip() if system_prompt else DEFAULT_SYSTEM_PROMPT user_instruction = instruction.strip() if instruction else DEFAULT_INSTRUCTION conversation: List[Dict[str, Any]] = [ { "role": "system", "content": [{"type": "text", "text": prompt}], }, { "role": "user", "content": [ {"type": "audio", "audio": audio_path}, {"type": "text", "text": user_instruction}, ], }, ] return conversation def generate_response( conversation: List[Dict[str, Any]], *, max_new_tokens: int = 1024, temperature: float = 0.0, top_p: float = 0.9, do_sample: Optional[bool] = None, ) -> InferenceResult: processor = get_processor() model = get_model() logging.disable(logging.WARNING) try: chat_text = processor.apply_chat_template( conversation, add_generation_prompt=True, tokenize=False, ) finally: logging.disable(logging.NOTSET) audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO) batch = processor( text=chat_text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO, ) model_inputs = { key: value.to(model.device) if isinstance(value, torch.Tensor) else value for key, value in batch.items() } sampling = do_sample if do_sample is not None else temperature > 0.0 generation_kwargs = dict( max_new_tokens=max_new_tokens, use_audio_in_video=USE_AUDIO_IN_VIDEO, do_sample=sampling, ) eos_token_id = processor.tokenizer.eos_token_id or model.config.eos_token_id if eos_token_id is not None: generation_kwargs["eos_token_id"] = eos_token_id pad_token_id = processor.tokenizer.pad_token_id or model.config.pad_token_id if pad_token_id is not None: generation_kwargs.setdefault("pad_token_id", pad_token_id) if sampling: generation_kwargs["temperature"] = max(temperature, 1e-4) generation_kwargs["top_p"] = top_p 
logging.disable(logging.WARNING) try: with torch.no_grad(): generated_ids = model.generate(**model_inputs, **generation_kwargs) finally: logging.disable(logging.NOTSET) prompt_length = model_inputs["input_ids"].shape[1] generated_tokens = generated_ids[:, prompt_length:] decoded = processor.batch_decode( generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False, )[0] text = decoded.strip() parsed = _try_parse_json(text) return InferenceResult(text=text, parsed=parsed, conversation=conversation) def generate_single_turn( audio_path: str, instruction: Optional[str] = None, *, system_prompt: Optional[str] = None, max_new_tokens: int = 1024, temperature: float = 0.0, top_p: float = 0.9, do_sample: Optional[bool] = None, ) -> InferenceResult: conversation = build_conversation( audio_path, instruction=instruction, system_prompt=system_prompt, ) return generate_response( conversation, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample, ) def _try_parse_json(text: str) -> Optional[Any]: candidate = text.strip() if not candidate: return None if candidate.startswith("```") and candidate.endswith("```"): body = candidate.strip("`") lines = body.splitlines() if lines and lines[0].lower().startswith("json"): lines = lines[1:] candidate = "\n".join(lines).strip() for snippet in (candidate, candidate.rstrip("`")): try: return json.loads(snippet) except json.JSONDecodeError: pass prefix = _extract_json_prefix(snippet) if prefix is not None: try: return json.loads(prefix) except json.JSONDecodeError: continue return None def _extract_json_prefix(text: str) -> Optional[str]: start = None opening = None for idx, char in enumerate(text): if char in "{[": start = idx opening = char break if start is None: return None stack = [] in_string = False escape = False for idx in range(start, len(text)): char = text[idx] if char == "\\" and not escape: escape = True continue if in_string: if char == '"' and not escape: in_string = False escape = False continue if char == '"': in_string = True escape = False continue if char in "{[": stack.append(char) elif char in "}]": if not stack: return None opening_char = stack.pop() if (opening_char, char) not in (("{", "}"), ("[", "]")): return None if not stack: return text[start : idx + 1] escape = False return None def format_inference_result(result: InferenceResult, include_pretty_json: bool = True) -> str: """Return a display string for an inference result.""" text = result.text.strip() for stop_token in ("<|im_end|>", "Human:", "Assistant:"): if stop_token in text: text = text.split(stop_token, 1)[0].strip() if not include_pretty_json or result.parsed is None: return text try: pretty = json.dumps(result.parsed, indent=2) return "```json\n" + pretty + "\n```" except (TypeError, ValueError): return text __all__ = [ "ADAPTER_PATH", "DEFAULT_INSTRUCTION", "DEFAULT_SYSTEM_PROMPT", "InferenceResult", "build_conversation", "generate_single_turn", "generate_response", "format_inference_result", "get_model", "get_processor", ]
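

# Minimal usage sketch (an illustration, not part of the documented surface of
# this module): running the file directly performs a single segmentation turn
# on an audio file supplied on the command line, using the default system
# prompt and instruction defined above. The audio path is whatever the caller
# provides; the Gradio app and CLI tooling import the helpers above instead of
# relying on this entry point.
if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        raise SystemExit(f"usage: {sys.argv[0]} <path/to/audio.wav>")

    # Greedy decoding (temperature=0.0 default) and JSON pretty-printing via
    # format_inference_result, mirroring how the app displays results.
    result = generate_single_turn(sys.argv[1])
    print(format_inference_result(result))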