diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fafde6777cebb2d5c4e995b98942592f7ae533de --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__/ +*.pyc +*.pyo +*.pyd +*.pyw +*.pyz +*.pywz +*.pyzw +*.pyzwz +.ruff_cache/ \ No newline at end of file diff --git a/README.md b/README.md index ccf4730ffc58509b4cff9827544be54f41e6a2f6..537ec7771d66bade22f67f7f65db688d9ad9ff08 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ --- title: Higgs Audio Demo -emoji: 🏢 +emoji: 🎤 colorFrom: green colorTo: purple sdk: gradio diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..d7c6a2af561ce24542a9f75479148b832a8883e3 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,531 @@ +""" +Gradio UI for Text-to-Speech using HiggsAudioServeEngine +""" + +import argparse +import base64 +import os +import uuid +import json +from typing import Optional import gradio as gr +from loguru import logger +import numpy as np +import time +from functools import lru_cache +import re +import spaces + + +# Import HiggsAudio components +from higgs_audio.serve.serve_engine import HiggsAudioServeEngine +from higgs_audio.data_types import ChatMLSample, AudioContent, Message + +# Global engine instance +engine = None + +# Set up default paths and resources +EXAMPLES_DIR = os.path.join(os.path.dirname(__file__), "examples") +os.makedirs(EXAMPLES_DIR, exist_ok=True) + +# Default model configuration +DEFAULT_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-staging" +DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer-staging" +SAMPLE_RATE = 24000 + +DEFAULT_SYSTEM_PROMPT = ( + "Generate audio following instruction.\n\n" + "<|scene_desc_start|>\n" + "Audio is recorded from a quiet room.\n" + "<|scene_desc_end|>" +) + +DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"] + +# Predefined examples for system and input messages +PREDEFINED_EXAMPLES = { + "None": {"system_prompt": "", "input_text": "", "description": "Default example"}, + "multispeaker-interleave": { + "system_prompt": "Generate audio following instruction.\n\n" + "<|scene_desc_start|>\n" + "SPEAKER0: vocal fry;feminism;slightly fast\n" + "SPEAKER1: masculine;moderate;moderate pitch;monotone;mature\n" + "In this scene, a group of adventurers is debating whether to investigate a potentially dangerous situation.\n" + "<|scene_desc_end|>", + "input_text": "<|generation_instruction_start|>\nGenerate interleaved transcript and audio that lasts for around 10 seconds.\n<|generation_instruction_end|>", + "description": "Multispeaker interleave example", + }, + "single-speaker": { + "system_prompt": "Generate audio following instruction.\n\n" + "<|scene_desc_start|>\n" + "SPEAKER0: british accent\n" + "<|scene_desc_end|>", + "input_text": "Hey, everyone! 
Welcome back to Tech Talk Tuesdays.\n" + "It's your host, Alex, and today, we're diving into a topic that's become absolutely crucial in the tech world — deep learning.\n" + "And let's be honest, if you've been even remotely connected to tech, AI, or machine learning lately, you know that deep learning is everywhere.\n" + "\n" + "So here's the big question: Do you want to understand how deep learning works?\n" + "How to use it to build powerful models that can predict, automate, and transform industries?\n" + "Well, today, I've got some exciting news for you.\n" + "\n" + "We're going to talk about a course that I highly recommend: Dive into Deep Learning.\n" + "It's not just another course; it's an entire experience that will take you from a beginner to someone who is well-versed in deep learning techniques.", + "description": "Single speaker example", + }, + "single-speaker-zh": { + "system_prompt": "Generate audio following instruction.\n\n" + "<|scene_desc_start|>\n" + "\nAudio is recorded from a quiet room.\n" + "\nSPEAKER0: feminine\n" + "<|scene_desc_end|>", + "input_text": "大家好, 欢迎收听本期的跟李沐学AI. 今天沐哥在忙着洗数据, 所以由我, 希格斯主播代替他讲这期视频.\n" + "今天我们要聊的是一个你绝对不能忽视的话题: 多模态学习.\n" + "无论你是开发者, 数据科学爱好者, 还是只是对人工智能感兴趣的人都一定听说过这个词. 它已经成为AI时代的一个研究热点.\n" + "那么, 问题来了, 你真的了解多模态吗? 你知道如何自己动手构建多模态大模型吗.\n" + "或者说, 你能察觉到我其实是个机器人吗?", + "description": "Single speaker with Chinese text", + }, +} + + +@lru_cache(maxsize=20) +def encode_audio_file(file_path): + """Encode an audio file to base64.""" + with open(file_path, "rb") as audio_file: + return base64.b64encode(audio_file.read()).decode("utf-8") + + +def load_voice_presets(): + """Load the voice presets from the voice_examples directory.""" + try: + with open( + os.path.join(os.path.dirname(__file__), "voice_examples", "config.json"), + "r", + ) as f: + voice_dict = json.load(f) + voice_presets = {k: v["transcript"] for k, v in voice_dict.items()} + voice_presets["EMPTY"] = "No reference voice" + logger.info(f"Loaded voice presets: {list(voice_presets.keys())}") + return voice_presets + except FileNotFoundError: + logger.warning("Voice examples config file not found. 
Using empty voice presets.") + return {"EMPTY": "No reference voice"} + except Exception as e: + logger.error(f"Error loading voice presets: {e}") + return {"EMPTY": "No reference voice"} + + +def get_voice_present(voice_preset): + """Get the voice path and text for a given voice preset.""" + voice_path = os.path.join(os.path.dirname(__file__), "voice_examples", f"{voice_preset}.wav") + if not os.path.exists(voice_path): + logger.warning(f"Voice preset file not found: {voice_path}") + return None, "Voice preset not found" + + text = VOICE_PRESETS.get(voice_preset, "No transcript available") + return voice_path, text + + +@spaces.GPU +def initialize_engine(model_path, audio_tokenizer_path, device="cuda") -> bool: + """Initialize the HiggsAudioServeEngine.""" + global engine + try: + engine = HiggsAudioServeEngine( + model_name_or_path=model_path, + audio_tokenizer_name_or_path=audio_tokenizer_path, + device=device, + ) + logger.info(f"Successfully initialized HiggsAudioServeEngine with model: {model_path}") + return True + except Exception as e: + logger.error(f"Failed to initialize engine: {e}") + return False + + +def check_return_audio(audio_wv: np.ndarray): + # check if the audio returned is all silent + if np.all(audio_wv == 0): + logger.warning("Audio is silent, returning None") + + +def process_text_output(text_output: str): + # remove all the continuous <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|> + text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output) + return text_output + + +def prepare_chatml_sample( + voice_present: str, + text: str, + reference_audio: Optional[str] = None, + reference_text: Optional[str] = None, + system_prompt: str = DEFAULT_SYSTEM_PROMPT, +): + """Prepare a ChatMLSample for the HiggsAudioServeEngine.""" + messages = [] + + # Add system message if provided + if len(system_prompt) > 0: + messages.append(Message(role="system", content=system_prompt)) + + # Add reference audio if provided + audio_base64 = None + ref_text = "" + + if reference_audio: + # Custom reference audio + audio_base64 = encode_audio_file(reference_audio) + ref_text = reference_text or "" + elif voice_present != "EMPTY": + # Voice preset + voice_path, ref_text = get_voice_present(voice_present) + if voice_path is None: + logger.warning(f"Voice preset {voice_present} not found, skipping reference audio") + else: + audio_base64 = encode_audio_file(voice_path) + + # Only add reference audio if we have it + if audio_base64 is not None: + # Add user message with reference text + messages.append(Message(role="user", content=ref_text)) + + # Add assistant message with audio content + audio_content = AudioContent(raw_audio=audio_base64, audio_url="") + messages.append(Message(role="assistant", content=[audio_content])) + + # Add the main user message + messages.append(Message(role="user", content=text)) + + return ChatMLSample(messages=messages) + + +@spaces.GPU(duration=500) +def text_to_speech( + text, + voice_preset, + reference_audio=None, + reference_text=None, + max_completion_tokens=1024, + temperature=1.0, + top_p=0.95, + top_k=50, + system_prompt=DEFAULT_SYSTEM_PROMPT, + stop_strings=None, +): + """Convert text to speech using HiggsAudioServeEngine.""" + global engine + + if engine is None: + error_msg = "Engine not initialized. Please load a model first." 
+ logger.error(error_msg) + gr.Error(error_msg) + return f"❌ {error_msg}", None + + try: + # Prepare ChatML sample + chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt) + + # Convert stop strings format + if stop_strings is None: + stop_list = DEFAULT_STOP_STRINGS + else: + stop_list = [s for s in stop_strings["stops"] if s.strip()] + + request_id = f"tts-playground-{str(uuid.uuid4())}" + logger.info( + f"{request_id}: Generating speech for text: {text[:100]}..., \n" + f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}" + ) + start_time = time.time() + + # Generate using the engine + response = engine.generate( + chat_ml_sample=chatml_sample, + max_new_tokens=max_completion_tokens, + temperature=temperature, + top_k=top_k if top_k > 0 else None, + top_p=top_p, + stop_strings=stop_list, + ) + + generation_time = time.time() - start_time + logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds") + gr.Info(f"Generated audio in {generation_time:.3f} seconds") + + # Process the response + text_output = process_text_output(response.generated_text) + + if response.audio is not None: + # Convert to int16 for Gradio + audio_data = (response.audio * 32767).astype(np.int16) + check_return_audio(audio_data) + return text_output, (response.sampling_rate, audio_data) + else: + logger.warning("No audio generated") + return text_output, None + + except Exception as e: + error_msg = f"Error generating speech: {e}" + logger.error(error_msg) + gr.Error(error_msg) + return f"❌ {error_msg}", None + + +def create_ui(): + my_theme = "JohnSmith9982/small_and_pretty" + + # Add custom CSS to disable focus highlighting on textboxes + custom_css = """ + .gradio-container input:focus, + .gradio-container textarea:focus, + .gradio-container select:focus, + .gradio-container .gr-input:focus, + .gradio-container .gr-textarea:focus, + .gradio-container .gr-textbox:focus, + .gradio-container .gr-textbox:focus-within, + .gradio-container .gr-form:focus-within, + .gradio-container *:focus { + box-shadow: none !important; + border-color: var(--border-color-primary) !important; + outline: none !important; + background-color: var(--input-background-fill) !important; + } + + /* Override any hover effects as well */ + .gradio-container input:hover, + .gradio-container textarea:hover, + .gradio-container select:hover, + .gradio-container .gr-input:hover, + .gradio-container .gr-textarea:hover, + .gradio-container .gr-textbox:hover { + border-color: var(--border-color-primary) !important; + background-color: var(--input-background-fill) !important; + } + + /* Style for checked checkbox */ + .gradio-container input[type="checkbox"]:checked { + background-color: var(--primary-500) !important; + border-color: var(--primary-500) !important; + } + """ + + """Create the Gradio UI.""" + with gr.Blocks(theme=my_theme, css=custom_css) as demo: + gr.Markdown("# Higgs Audio Text-to-Speech Playground") + + # Main UI section + with gr.Row(): + with gr.Column(scale=2): + # Template selection dropdown + template_dropdown = gr.Dropdown( + label="Message examples", + choices=list(PREDEFINED_EXAMPLES.keys()), + value="None", + info="Select a predefined example for system and input messages. 
Voice preset will be set to EMPTY when a example is selected.", + ) + + system_prompt = gr.TextArea( + label="System Prompt", + placeholder="Enter system prompt to guide the model...", + value=DEFAULT_SYSTEM_PROMPT, + lines=2, + ) + + input_text = gr.TextArea( + label="Input Text", + placeholder="Type the text you want to convert to speech...", + lines=5, + ) + + voice_preset = gr.Dropdown( + label="Voice Preset", + choices=list(VOICE_PRESETS.keys()), + value="EMPTY", + ) + + with gr.Accordion("Custom Reference (Optional)", open=False): + reference_audio = gr.Audio(label="Reference Audio", type="filepath") + reference_text = gr.TextArea( + label="Reference Text (transcript of the reference audio)", + placeholder="Enter the transcript of your reference audio...", + lines=3, + ) + + with gr.Accordion("Advanced Parameters", open=False): + max_completion_tokens = gr.Slider( + minimum=128, + maximum=4096, + value=1024, + step=10, + label="Max Completion Tokens", + ) + temperature = gr.Slider( + minimum=0.0, + maximum=1.5, + value=1.0, + step=0.1, + label="Temperature", + ) + top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top P") + top_k = gr.Slider(minimum=-1, maximum=100, value=50, step=1, label="Top K") + # Add stop strings component + stop_strings = gr.Dataframe( + label="Stop Strings", + headers=["stops"], + datatype=["str"], + value=[[s] for s in DEFAULT_STOP_STRINGS], + interactive=True, + col_count=(1, "fixed"), + ) + + submit_btn = gr.Button("Generate Speech", variant="primary", scale=1) + + with gr.Column(scale=2): + output_text = gr.TextArea(label="Model Response", lines=2) + + # Audio output + output_audio = gr.Audio(label="Generated Audio", interactive=False, autoplay=True) + + stop_btn = gr.Button("Stop Playback", variant="primary") + + # Example voice + with gr.Row(): + voice_samples_table = gr.Dataframe( + headers=["Voice Preset", "Sample Text"], + datatype=["str", "str"], + value=[[preset, text] for preset, text in VOICE_PRESETS.items() if preset != "EMPTY"], + interactive=False, + ) + sample_audio = gr.Audio(label="Voice Sample", visible=True) + + # Function to play voice sample when clicking on a row + def play_voice_sample(evt: gr.SelectData): + try: + # Get the preset name from the clicked row + preset_names = [preset for preset in VOICE_PRESETS.keys() if preset != "EMPTY"] + if evt.index[0] < len(preset_names): + preset = preset_names[evt.index[0]] + voice_path, _ = get_voice_present(preset) + if voice_path and os.path.exists(voice_path): + return voice_path + else: + gr.Warning(f"Voice sample file not found for preset: {preset}") + return None + else: + gr.Warning("Invalid voice preset selection") + return None + except Exception as e: + logger.error(f"Error playing voice sample: {e}") + gr.Error(f"Error playing voice sample: {e}") + return None + + voice_samples_table.select(fn=play_voice_sample, outputs=[sample_audio]) + + # Function to handle template selection + def apply_template(template_name): + if template_name in PREDEFINED_EXAMPLES: + template = PREDEFINED_EXAMPLES[template_name] + return ( + template["system_prompt"], # system_prompt + template["input_text"], # input_text + "EMPTY", # voice_preset (always set to EMPTY for examples) + ) + else: + return ( + gr.update(), + gr.update(), + gr.update(), + ) # No change if template not found + + # Set up event handlers + + # Connect template dropdown to handler + template_dropdown.change( + fn=apply_template, + inputs=[template_dropdown], + outputs=[system_prompt, input_text, voice_preset], 
+ ) + + # Connect submit button to the TTS function + submit_btn.click( + fn=text_to_speech, + inputs=[ + input_text, + voice_preset, + reference_audio, + reference_text, + max_completion_tokens, + temperature, + top_p, + top_k, + system_prompt, + stop_strings, + ], + outputs=[output_text, output_audio], + api_name="generate_speech", + ) + + # Stop button functionality + stop_btn.click( + fn=lambda: None, + inputs=[], + outputs=[output_audio], + js="() => {const audio = document.querySelector('audio'); if(audio) audio.pause(); return null;}", + ) + + return demo + + +def main(): + """Main function to parse arguments and launch the UI.""" + global DEFAULT_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH, VOICE_PRESETS + + parser = argparse.ArgumentParser(description="Gradio UI for Text-to-Speech using HiggsAudioServeEngine") + parser.add_argument( + "--model-path", + type=str, + default=DEFAULT_MODEL_PATH, + help="Path to the Higgs Audio model.", + ) + parser.add_argument( + "--audio-tokenizer-path", + type=str, + default=DEFAULT_AUDIO_TOKENIZER_PATH, + help="Path to the audio tokenizer.", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help="Device to run the model on.", + ) + parser.add_argument("--host", type=str, default="0.0.0.0", help="Host for the Gradio interface.") + parser.add_argument("--port", type=int, default=7860, help="Port for the Gradio interface.") + + args = parser.parse_args() + + # Update default values if provided via command line + DEFAULT_MODEL_PATH = args.model_path + DEFAULT_AUDIO_TOKENIZER_PATH = args.audio_tokenizer_path + VOICE_PRESETS = load_voice_presets() + + # Load model on startup + logger.info("Loading model...") + result = initialize_engine(args.model_path, args.audio_tokenizer_path, args.device) + + # Exit if model loading failed + if not result: + logger.error("Failed to load model. Exiting.") + return + + logger.info(f"Model loaded: {DEFAULT_MODEL_PATH}") + + # Create and launch the UI + demo = create_ui() + demo.launch(server_name=args.host, server_port=args.port) -def greet(name): - return "Hello " + name + "!!" -demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() +if __name__ == "__main__": + main() diff --git a/higgs_audio/__init__.py b/higgs_audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4a60464c84d36b56145fa7aff66bcf970ab30346 --- /dev/null +++ b/higgs_audio/__init__.py @@ -0,0 +1 @@ +from .model import HiggsAudioConfig, HiggsAudioModel diff --git a/higgs_audio/audio_processing/LICENSE b/higgs_audio/audio_processing/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..c27989d900614d1a53e5bdcc1f9c2bee23ef3969 --- /dev/null +++ b/higgs_audio/audio_processing/LICENSE @@ -0,0 +1,51 @@ +Third-Party License Attribution for Audio Processing Module +=========================================================== + +This directory contains code derived from multiple open-source projects. +The following sections detail the licenses and attributions for third-party code. + +## XCodec Repository +The code in this directory is derived from: +https://github.com/zhenye234/xcodec + +## Individual File Attributions + +### Quantization Module (quantization/) +- Several files contain code derived from Meta Platforms, Inc. 
and the vector-quantize-pytorch repository +- Individual files contain their own license headers where applicable +- The vector-quantize-pytorch portions are licensed under the MIT License + +## License Terms + +### MIT License (for applicable portions) +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +## Attribution Requirements +When using this code, please ensure proper attribution to: +1. The original xcodec repository: https://github.com/zhenye234/xcodec +2. Any other repositories mentioned in individual file headers +3. This derivative work and its modifications + +## Disclaimer +This directory contains modified versions of the original code. Please refer to +the original repositories for the canonical implementations and their specific +license terms. + +For any questions about licensing or attribution, please check the individual +file headers and the original source repositories. 
\ No newline at end of file diff --git a/higgs_audio/audio_processing/descriptaudiocodec/__init__.py b/higgs_audio/audio_processing/descriptaudiocodec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py b/higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..08e39a2d9016c6ddc2491d0e2644b80c8efe3986 --- /dev/null +++ b/higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py @@ -0,0 +1,286 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.") + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + 
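+    # Length bookkeeping for the chunked compress() path below: the two helpers
+    # above walk every Conv1d / ConvTranspose1d registered on the model, applying
+    #   Conv1d:          L_out = floor((L_in - d*(k - 1) - 1) / s + 1)
+    #   ConvTranspose1d: L_out = (L_in - 1) * s + d*(k - 1) + 1
+    # and get_delay() inverts that walk to recover the symmetric delay added by the
+    # un-padded layers. For example, a hypothetical model containing only a single
+    # Conv1d(kernel_size=7, stride=1, dilation=1) gives get_output_length(100) == 94
+    # and get_delay() == 3. compress() relies on self.delay for zero-padding and on
+    # get_output_length() to derive its chunk hop.
+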
+ @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. + + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = audio_signal.signal_duration if win_duration is None else win_duration + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. 
+ verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length) + + self.padding = original_padding + return recons diff --git a/higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py b/higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..efaed1c25eee7cbb55a96b4f12376b9d26d4a685 --- /dev/null +++ b/higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py @@ -0,0 +1,365 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from dac.nn.layers import Snake1d +from dac.nn.layers import WNConv1d +from dac.nn.layers import WNConvTranspose1d +from dac.nn.quantize import ResidualVectorQuantize + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 256, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + 
WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, # out_pad, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + if i == 1: + out_pad = 1 + else: + out_pad = 0 + layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + # nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p / 1e6:<.3f}M params." 
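+        # partial() pins the current o/p values so each module's extra_repr reports
+        # its own parameter count in the print(model) call below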
+ setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py b/higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py b/higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..8861224cbb49813816dc41b63059faa13d246cc7 --- /dev/null +++ b/higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py @@ -0,0 +1,251 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from dac.nn.layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = z_e + (z_q - z_e).detach() # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual) + + # Create mask to apply quantizer dropout + mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
+ + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/higgs_audio/audio_processing/higgs_audio_tokenizer.py b/higgs_audio/audio_processing/higgs_audio_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fde0962f81dd74520c9fe81328e2166677a1f3 --- /dev/null +++ b/higgs_audio/audio_processing/higgs_audio_tokenizer.py @@ -0,0 +1,341 @@ +# Based on code from: https://github.com/zhenye234/xcodec +# Licensed under MIT License +# Modifications by BosonAI + +import math +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Union, Sequence +import numpy as np +from transformers import AutoModel +import torchaudio +import json +import librosa +from huggingface_hub import snapshot_download + +from vector_quantize_pytorch import ResidualFSQ +from .descriptaudiocodec.dac.model import dac as dac2 +from .quantization.vq import ResidualVectorQuantizer +from .semantic_module import Encoder, Decoder + + +class EncodedResult: + def __init__(self, audio_codes): + self.audio_codes = audio_codes + + +class HiggsAudioFeatureExtractor(nn.Module): + def __init__(self, sampling_rate=16000): + super().__init__() + self.sampling_rate = sampling_rate + + def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"): + # Convert from librosa to torch + audio_signal = torch.tensor(raw_audio) + audio_signal = audio_signal.unsqueeze(0) + if len(audio_signal.shape) < 3: + audio_signal = audio_signal.unsqueeze(0) + return {"input_values": audio_signal} + + +class HiggsAudioTokenizer(nn.Module): + def __init__( + self, + n_filters: int = 32, + D: int = 128, + target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6], + ratios: Sequence[int] = [8, 5, 4, 2], # downsampling by 320 + sample_rate: int = 16000, + bins: int = 1024, + n_q: int = 8, + codebook_dim: int = None, + normalize: bool = False, + causal: bool = False, + semantic_techer: str = "hubert_base_general", + last_layer_semantic: bool = True, + merge_mode: str = "concat", + downsample_mode: str = "step_down", + semantic_mode: str = "classic", + vq_scale: int = 1, + semantic_sample_rate: int = None, + device: str = "cuda", + ): + super().__init__() + self.hop_length = np.prod(ratios) + self.semantic_techer = semantic_techer + + self.frame_rate = math.ceil(sample_rate / np.prod(ratios)) # 50 Hz + + self.target_bandwidths = target_bandwidths + self.n_q = n_q + self.sample_rate = sample_rate + self.encoder = dac2.Encoder(64, ratios, D) + + self.decoder_2 = dac2.Decoder(D, 1024, ratios) + self.last_layer_semantic = last_layer_semantic + self.device = device + if semantic_techer == "hubert_base": 
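+            # the semantic "teacher" is a pretrained speech encoder pulled from the
+            # Hugging Face Hub; it is frozen (eval mode, requires_grad=False) later
+            # in __init__ and is only used to produce semantic feature targets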
+ self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960") + self.semantic_sample_rate = 16000 + self.semantic_dim = 768 + self.encoder_semantic_dim = 768 + + elif semantic_techer == "wavlm_base_plus": + self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus") + self.semantic_sample_rate = 16000 + self.semantic_dim = 768 + self.encoder_semantic_dim = 768 + + elif semantic_techer == "hubert_base_general": + self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio") + self.semantic_sample_rate = 16000 + self.semantic_dim = 768 + self.encoder_semantic_dim = 768 + + # Overwrite semantic model sr to ensure semantic_downsample_factor is an integer + if semantic_sample_rate is not None: + self.semantic_sample_rate = semantic_sample_rate + + self.semantic_model.eval() + + # make the semantic model parameters do not need gradient + for param in self.semantic_model.parameters(): + param.requires_grad = False + + self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320) + + self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale) + self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim) + self.decoder_semantic = Decoder( + code_dim=self.encoder_semantic_dim, + output_channels=self.semantic_dim, + decode_channels=self.semantic_dim, + ) + + # out_D=D+768 + if isinstance(bins, int): # RVQ + self.quantizer = ResidualVectorQuantizer( + dimension=self.quantizer_dim, + codebook_dim=codebook_dim, + n_q=n_q, + bins=bins, + ) + self.quantizer_type = "RVQ" + else: # RFSQ + self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q) + self.quantizer_type = "RFSQ" + + self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim) + self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim) + self.fc_post2 = nn.Linear(self.quantizer_dim, D) + + self.downsample_mode = downsample_mode + if downsample_mode == "avg": + self.semantic_pooling = nn.AvgPool1d( + kernel_size=self.semantic_downsample_factor, + stride=self.semantic_downsample_factor, + ) + + self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate) + + @property + def tps(self): + return self.frame_rate + + @property + def sampling_rate(self): + return self.sample_rate + + @property + def num_codebooks(self): + return self.n_q + + @property + def codebook_size(self): + return self.quantizer_dim + + def get_last_layer(self): + return self.decoder.layers[-1].weight + + def calculate_rec_loss(self, rec, target): + target = target / target.norm(dim=-1, keepdim=True) + rec = rec / rec.norm(dim=-1, keepdim=True) + rec_loss = (1 - (target * rec).sum(-1)).mean() + + return rec_loss + + @torch.no_grad() + def get_regress_target(self, x): + x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate) + + if ( + self.semantic_techer == "hubert_base" + or self.semantic_techer == "hubert_base_general" + or self.semantic_techer == "wavlm_base_plus" + ): + x = x[:, 0, :] + x = F.pad(x, (160, 160)) + target = self.semantic_model(x, output_hidden_states=True).hidden_states + target = torch.stack(target, dim=1) # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2) + + # average for all layers + target = target.mean(1) + # target = target[9] + # if self.hop_length > 320: + # target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2) + + elif self.semantic_techer == 
"w2v_bert2": + target = self.semantic_model(x) + + elif self.semantic_techer.startswith("whisper"): + if self.last_layer_semantic: + target = self.semantic_model(x, avg_layers=False) + else: + target = self.semantic_model(x, avg_layers=True) + + elif self.semantic_techer.startswith("mert_music"): + if self.last_layer_semantic: + target = self.semantic_model(x, avg_layers=False) + else: + target = self.semantic_model(x, avg_layers=True) + + elif self.semantic_techer.startswith("qwen_audio_omni"): + target = self.semantic_model(x) + + if self.downsample_mode == "step_down": + if self.semantic_downsample_factor > 1: + target = target[:, :: self.semantic_downsample_factor, :] + + elif self.downsample_mode == "avg": + target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2) + return target + + def forward(self, x: torch.Tensor, bw: int): + e_semantic_input = self.get_regress_target(x).detach() + + e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) + e_acoustic = self.encoder(x) + + e = torch.cat([e_acoustic, e_semantic], dim=1) + + e = self.fc_prior(e.transpose(1, 2)) + + if self.quantizer_type == "RVQ": + e = e.transpose(1, 2) + quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) + quantized = quantized.transpose(1, 2) + else: + quantized, codes = self.quantizer(e) + commit_loss = torch.tensor(0.0) + + quantized_semantic = self.fc_post1(quantized).transpose(1, 2) + quantized_acoustic = self.fc_post2(quantized).transpose(1, 2) + + o = self.decoder_2(quantized_acoustic) + + o_semantic = self.decoder_semantic(quantized_semantic) + semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic) + + return o, commit_loss, semantic_recon_loss, None + + def encode( + self, + audio_path_or_wv, + sr=None, + loudness_normalize=False, + loudness_threshold=-23.0, + ): + if isinstance(audio_path_or_wv, str): + wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None) + else: + wv = audio_path_or_wv + assert sr is not None + if loudness_normalize: + import pyloudnorm as pyln + + meter = pyln.Meter(sr) + l = meter.integrated_loudness(wv) + wv = pyln.normalize.loudness(wv, l, loudness_threshold) + if sr != self.sampling_rate: + wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate) + if self.audio_tokenizer_feature_extractor is not None: + inputs = self.audio_tokenizer_feature_extractor( + raw_audio=wv, + sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate, + return_tensors="pt", + ) + input_values = inputs["input_values"].to(self.device) + else: + input_values = torch.from_numpy(wv).float().unsqueeze(0) + with torch.no_grad(): + encoder_outputs = self._xcodec_encode(input_values) + vq_code = encoder_outputs.audio_codes[0] + return vq_code + + def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor: + bw = target_bw + + e_semantic_input = self.get_regress_target(x).detach() + + e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2)) + e_acoustic = self.encoder(x) + + if e_acoustic.shape[2] != e_semantic.shape[2]: + pad_size = 160 * self.semantic_downsample_factor + e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0)) + + if e_acoustic.shape[2] != e_semantic.shape[2]: + if e_acoustic.shape[2] > e_semantic.shape[2]: + e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]] + else: + e_semantic = e_semantic[:, :, : e_acoustic.shape[2]] + + e = torch.cat([e_acoustic, e_semantic], dim=1) + + e = self.fc_prior(e.transpose(1, 2)) + 
+ if self.quantizer_type == "RVQ": + e = e.transpose(1, 2) + quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw) + codes = codes.permute(1, 0, 2) + else: + quantized, codes = self.quantizer(e) + codes = codes.permute(0, 2, 1) + + # return codes + return EncodedResult(codes) + + def decode(self, vq_code: torch.Tensor) -> torch.Tensor: + if self.quantizer_type == "RVQ": + vq_code = vq_code.permute(1, 0, 2) + quantized = self.quantizer.decode(vq_code) + quantized = quantized.transpose(1, 2) + else: + vq_code = vq_code.permute(0, 2, 1) + quantized = self.quantizer.get_output_from_indices(vq_code) + quantized_acoustic = self.fc_post2(quantized).transpose(1, 2) + + o = self.decoder_2(quantized_acoustic) + return o.cpu().numpy() + + +def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"): + is_local = os.path.exists(tokenizer_name_or_path) + if not is_local: + tokenizer_path = snapshot_download(tokenizer_name_or_path) + else: + tokenizer_path = tokenizer_name_or_path + config_path = os.path.join(tokenizer_path, "config.json") + model_path = os.path.join(tokenizer_path, "model.pth") + config = json.load(open(config_path)) + model = HiggsAudioTokenizer( + **config, + device=device, + ) + parameter_dict = torch.load(model_path, map_location=device) + model.load_state_dict(parameter_dict, strict=False) + model.to(device) + model.eval() + return model diff --git a/higgs_audio/audio_processing/quantization/__init__.py b/higgs_audio/audio_processing/quantization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bfabe52b8cb6f260cdda6137b34df2f4736bd02f --- /dev/null +++ b/higgs_audio/audio_processing/quantization/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# flake8: noqa +from .vq import QuantizedResult, ResidualVectorQuantizer diff --git a/higgs_audio/audio_processing/quantization/ac.py b/higgs_audio/audio_processing/quantization/ac.py new file mode 100644 index 0000000000000000000000000000000000000000..3e6e0edacdc646e705350a5b3e6ac0a2042cee66 --- /dev/null +++ b/higgs_audio/audio_processing/quantization/ac.py @@ -0,0 +1,301 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Arithmetic coder.""" + +import io +import math +import random +import typing as tp +import torch + +from ..binary import BitPacker, BitUnpacker + + +def build_stable_quantized_cdf( + pdf: torch.Tensor, + total_range_bits: int, + roundoff: float = 1e-8, + min_range: int = 2, + check: bool = True, +) -> torch.Tensor: + """Turn the given PDF into a quantized CDF that splits + [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional + to the PDF. + + Args: + pdf (torch.Tensor): probability distribution, shape should be `[N]`. + total_range_bits (int): see `ArithmeticCoder`, the typical range we expect + during the coding process is `[0, 2 ** total_range_bits - 1]`. + roundoff (float): will round the pdf up to that level to remove difference coming + from e.g. evaluating the Language Model on different architectures. + min_range (int): minimum range width. Should always be at least 2 for numerical + stability. 
Use this to avoid pathological behavior if a value
+            that is expected to be rare actually happens in real life.
+        check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
+    """
+    pdf = pdf.detach()
+    if roundoff:
+        pdf = (pdf / roundoff).floor() * roundoff
+    # interpolate with uniform distribution to achieve desired minimum probability.
+    total_range = 2**total_range_bits
+    cardinality = len(pdf)
+    alpha = min_range * cardinality / total_range
+    assert alpha <= 1, "you must reduce min_range"
+    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+    ranges += min_range
+    quantized_cdf = torch.cumsum(ranges, dim=-1)
+    if min_range < 2:
+        raise ValueError("min_range must be at least 2.")
+    if check:
+        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
+        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
+            raise ValueError("You must increase your total_range_bits.")
+    return quantized_cdf
+
+
+class ArithmeticCoder:
+    """ArithmeticCoder,
+    Let us take a distribution `p` over `N` symbols, and assume we have a stream
+    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
+    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
+    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
+    sequence `(s_t)` by doing the following:
+
+    1) Initialize the current range to `[0, 2 ** B - 1]`.
+    2) For each time step t, split the current range into contiguous chunks,
+        one for each possible outcome, with size roughly proportional to `p`.
+        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
+        would be `{[0, 2], [3, 3]}`.
+    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
+    4) When done encoding all the values, just select any value remaining in the range.
+
+    You will notice that this procedure can fail: for instance if at any point in time
+    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
+    possible outcome. Intuitively, the more likely a value is, the less the range width
+    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
+    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
+    with a fixed budget.
+
+    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
+    when the current range decreases below a given limit (given by `total_range_bits`), without
+    having to redo all the computations. If we mostly encode likely values, we will seldom
+    need to inject new bits, but a single rare value can deplete our stock of entropy!
+
+    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
+    code works for any sequence `(p_t)` possibly different for each timestep.
+    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
+    the KL between the true distribution and `p_t`, the more efficient the coding will be.
+
+    Args:
+        fo (IO[bytes]): file-like object to which the bytes will be written.
+        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
+            Any time the current range width falls under this limit, new bits will
+            be injected to rescale the initial range.
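+
+    Example sketch (mirrors the `test()` helper at the bottom of this file; the pdf here is
+    an arbitrary softmax distribution and `fo` is an in-memory buffer):
+
+        fo = io.BytesIO()
+        coder = ArithmeticCoder(fo)
+        pdf = torch.softmax(torch.randn(16), dim=0)
+        q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
+        coder.push(3, q_cdf)  # encode symbol index 3 under this pdf
+        coder.flush()         # write out any remaining bits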
+ """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + assert total_range_bits <= 30 + self.total_range_bits = total_range_bits + self.packer = BitPacker(bits=1, fo=fo) # we push single bits at a time. + self.low: int = 0 + self.high: int = 0 + self.max_bit: int = -1 + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + + @property + def delta(self) -> int: + """Return the current range width.""" + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # If self.low and self.high start with the sames bits, + # those won't change anymore as we always just increase the range + # by powers of 2, and we can flush them out to the bit stream. + assert self.high >= self.low, (self.low, self.high) + assert self.high < 2 ** (self.max_bit + 1) + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + assert self.high >= self.low, (self.high, self.low, self.max_bit) + assert self.low >= 0 + self.max_bit -= 1 + self.packer.push(b1) + else: + break + + def push(self, symbol: int, quantized_cdf: torch.Tensor): + """Push the given symbol on the stream, flushing out bits + if possible. + + Args: + symbol (int): symbol to encode with the AC. + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. + """ + while self.delta < 2**self.total_range_bits: + self.low *= 2 + self.high = self.high * 2 + 1 + self.max_bit += 1 + + range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item() + range_high = quantized_cdf[symbol].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits)))) + assert self.low <= self.high + self.high = self.low + effective_high + self.low = self.low + effective_low + assert self.low <= self.high, ( + effective_low, + effective_high, + range_low, + range_high, + ) + self._dbg.append((self.low, self.high)) + self._dbg2.append((self.low, self.high)) + outs = self._flush_common_prefix() + assert self.low <= self.high + assert self.max_bit >= -1 + assert self.max_bit <= 61, self.max_bit + return outs + + def flush(self): + """Flush the remaining information to the stream.""" + while self.max_bit >= 0: + b1 = (self.low >> self.max_bit) & 1 + self.packer.push(b1) + self.max_bit -= 1 + self.packer.flush() + + +class ArithmeticDecoder: + """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation. + + Note that this must be called with **exactly** the same parameters and sequence + of quantized cdf as the arithmetic encoder or the wrong values will be decoded. + + If the AC encoder current range is [L, H], with `L` and `H` having the some common + prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream. + For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside + `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained + for a specific sequence of symbols and a binary-search allows us to decode those symbols. + At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols, + and we will need to read new bits from the stream and repeat the process. 
+ + """ + + def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24): + self.total_range_bits = total_range_bits + self.low: int = 0 + self.high: int = 0 + self.current: int = 0 + self.max_bit: int = -1 + self.unpacker = BitUnpacker(bits=1, fo=fo) # we pull single bits at a time. + # Following is for debugging + self._dbg: tp.List[tp.Any] = [] + self._dbg2: tp.List[tp.Any] = [] + self._last: tp.Any = None + + @property + def delta(self) -> int: + return self.high - self.low + 1 + + def _flush_common_prefix(self): + # Given the current range [L, H], if both have a common prefix, + # we know we can remove it from our representation to avoid handling large numbers. + while self.max_bit >= 0: + b1 = self.low >> self.max_bit + b2 = self.high >> self.max_bit + if b1 == b2: + self.low -= b1 << self.max_bit + self.high -= b1 << self.max_bit + self.current -= b1 << self.max_bit + assert self.high >= self.low + assert self.low >= 0 + self.max_bit -= 1 + else: + break + + def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]: + """Pull a symbol, reading as many bits from the stream as required. + This returns `None` when the stream has been exhausted. + + Args: + quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf` + to build this from your pdf estimate. This must be **exatly** + the same cdf as the one used at encoding time. + """ + while self.delta < 2**self.total_range_bits: + bit = self.unpacker.pull() + if bit is None: + return None + self.low *= 2 + self.high = self.high * 2 + 1 + self.current = self.current * 2 + bit + self.max_bit += 1 + + def bin_search(low_idx: int, high_idx: int): + # Binary search is not just for coding interviews :) + if high_idx < low_idx: + raise RuntimeError("Binary search failed") + mid = (low_idx + high_idx) // 2 + range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0 + range_high = quantized_cdf[mid].item() - 1 + effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits)))) + effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits)))) + low = effective_low + self.low + high = effective_high + self.low + if self.current >= low: + if self.current <= high: + return (mid, low, high, self.current) + else: + return bin_search(mid + 1, high_idx) + else: + return bin_search(low_idx, mid - 1) + + self._last = (self.low, self.high, self.current, self.max_bit) + sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1) + self._dbg.append((self.low, self.high, self.current)) + self._flush_common_prefix() + self._dbg2.append((self.low, self.high, self.current)) + + return sym + + +def test(): + torch.manual_seed(1234) + random.seed(1234) + for _ in range(4): + pdfs = [] + cardinality = random.randrange(4000) + steps = random.randrange(100, 500) + fo = io.BytesIO() + encoder = ArithmeticCoder(fo) + symbols = [] + for step in range(steps): + pdf = torch.softmax(torch.randn(cardinality), dim=0) + pdfs.append(pdf) + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + symbol = torch.multinomial(pdf, 1).item() + symbols.append(symbol) + encoder.push(symbol, q_cdf) + encoder.flush() + + fo.seek(0) + decoder = ArithmeticDecoder(fo) + for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)): + q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits) + decoded_symbol = decoder.pull(q_cdf) + assert decoded_symbol == symbol, idx + assert decoder.pull(torch.zeros(1)) is None + + +if __name__ == "__main__": + test() diff --git 
a/higgs_audio/audio_processing/quantization/core_vq.py b/higgs_audio/audio_processing/quantization/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..ad368a980582bcbba901f28b568c4bfb8f4099e6 --- /dev/null +++ b/higgs_audio/audio_processing/quantization/core_vq.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
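+# NOTE: quantization/vq.py in this package builds its ResidualVectorQuantizer on top of
+# core_vq_lsx_version (the distributed-friendly variant) rather than on this module.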
+ +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +from xcodec.quantization.distrib import broadcast_tensors, rank + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d") + dists = -(diffs**2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
+ """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) # get embedding based on index + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) # get index based on Euclidean distance + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + + self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
+ self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)]) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/higgs_audio/audio_processing/quantization/core_vq_lsx_version.py b/higgs_audio/audio_processing/quantization/core_vq_lsx_version.py new file mode 100644 index 0000000000000000000000000000000000000000..96c9282302dea87e1be690d0ee69f6fb296083e5 --- /dev/null +++ b/higgs_audio/audio_processing/quantization/core_vq_lsx_version.py @@ -0,0 +1,431 @@ +# Copyright (c) +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# This implementation is inspired from +# https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and +# https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81 +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange +import torch +from torch import nn +import torch.nn.functional as F +import torch.distributed as dist + +from .distrib import broadcast_tensors, is_distributed +from .ddp_utils import SyncFunction + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans( + samples, + num_clusters: int, + num_iters: int = 10, + frames_to_use: int = 10_000, + batch_size: int = 64, +): + """ + Memory-efficient K-means clustering. + Args: + samples (tensor): shape [N, D] + num_clusters (int): number of centroids. + num_iters (int): number of iterations. + frames_to_use (int): subsample size from total samples. + batch_size (int): batch size used in distance computation. + Returns: + means: [num_clusters, D] + bins: [num_clusters] (number of points per cluster) + """ + N, D = samples.shape + dtype, device = samples.dtype, samples.device + + if frames_to_use < N: + indices = torch.randperm(N, device=device)[:frames_to_use] + samples = samples[indices] + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + # Store cluster assignments + all_assignments = [] + + for i in range(0, samples.shape[0], batch_size): + batch = samples[i : i + batch_size] # [B, D] + dists = torch.cdist(batch, means, p=2) # [B, C] + assignments = dists.argmin(dim=1) # [B] + all_assignments.append(assignments) + + buckets = torch.cat(all_assignments, dim=0) # [N] + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + # Compute new means + new_means = torch.zeros_like(means) + for i in range(num_clusters): + mask = buckets == i + if mask.any(): + new_means[i] = samples[mask].mean(dim=0) + + means = torch.where(zero_mask[:, None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. 
Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + # Flag variable to indicate whether the codebook is initialized + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + # Runing EMA cluster size/count: N_i^t in eq. (6) in vqvae paper + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + # Codebook + self.register_buffer("embed", embed) + # EMA codebook: eq. (7) in vqvae paper + self.register_buffer("embed_avg", embed.clone()) + + @torch.jit.ignore + def init_embed_(self, data): + """Initialize codebook. + Args: + data (tensor): [B * T, D]. + """ + if self.inited: + return + + ## NOTE (snippet added by Songxiang Liu): gather data from all gpus + if dist.is_available() and dist.is_initialized(): + # [B * T * world_size, D] + data = SyncFunction.apply(data) + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + ## NOTE (snippet added by Songxiang Liu): gather data from all gpus + if is_distributed(): + # [B * T * world_size, D] + batch_samples = SyncFunction.apply(batch_samples) + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) 
d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True)) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) # [B, T, D] -> [B*T, D] + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + # shape: [B, T, D] + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) # [B, T, D] -> [B*T, D] + + # Initialize codebook + self.init_embed_(x) + + embed_ind = self.quantize(x) # [B*T,] + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) # [B*T, cb-size] + embed_ind = self.postprocess_emb(embed_ind, shape) # [B, T] + quantize = self.dequantize(embed_ind) # [B, T, D] + + if self.training: + ### Update codebook by EMA + embed_onehot_sum = embed_onehot.sum(0) # [cb-size,] + embed_sum = x.t() @ embed_onehot # [D, cb-size] + if is_distributed(): + dist.all_reduce(embed_onehot_sum) + dist.all_reduce(embed_sum) + # Update ema cluster count N_i^t, eq. (6) in vqvae paper + self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay) + # Update ema embed: eq. (7) in vqvae paper + self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay) + # apply laplace smoothing + n = self.cluster_size.sum() + cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n + # Update ema embed: eq. (8) in vqvae paper + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. 
+ """ + + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1.0, + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook( + dim=_codebook_dim, + codebook_size=codebook_size, + kmeans_init=kmeans_init, + kmeans_iters=kmeans_iters, + decay=decay, + epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code, + ) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x): + device = x.device + x = x.transpose(1, 2).contiguous() # [b d n] -> [b n d] + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + + quantize = self.project_out(quantize) + quantize = quantize.transpose(1, 2).contiguous() # [b n d] -> [b d n] + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)]) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layer in self.layers[:n_q]: + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/higgs_audio/audio_processing/quantization/ddp_utils.py b/higgs_audio/audio_processing/quantization/ddp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..990dca85fd518f09e2fcd528e28e7d256f64a15a --- /dev/null +++ b/higgs_audio/audio_processing/quantization/ddp_utils.py @@ -0,0 +1,197 @@ +import logging +import random +import subprocess +from datetime import datetime + +import numpy as np +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import _find_tensors +import torch.optim +import torch.utils.data +from packaging import version +from omegaconf import OmegaConf + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def is_logging_process(): + return not dist.is_initialized() or dist.get_rank() == 0 + + +def get_logger(cfg, name=None): + # log_file_path is used when unit testing + if is_logging_process(): + logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True)) + return logging.getLogger(name) + + +# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20 +class SyncFunction(torch.autograd.Function): + @staticmethod + # @torch.no_grad() + def forward(ctx, tensor): + ctx.batch_size = tensor.shape[0] + + gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] + + torch.distributed.all_gather(gathered_tensor, tensor) + gathered_tensor = torch.cat(gathered_tensor, 0) + + return gathered_tensor + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) + + idx_from = torch.distributed.get_rank() * ctx.batch_size + idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size + return grad_input[idx_from:idx_to] + + +def get_timestamp(): + return datetime.now().strftime("%y%m%d-%H%M%S") + + +def get_commit_hash(): + message 
= subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + return message.strip().decode("utf-8") + + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): # pragma: no cover + if version.parse(torch.__version__[:6]) < version.parse("1.11"): + self._sync_params() + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + assert len(self.device_ids) == 1 + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + from torch.nn.parallel.distributed import ( + logging, + Join, + _DDPSink, + _tree_flatten_with_rref, + _tree_unflatten_with_rref, + ) + + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.logger.set_runtime_stats_and_log() + self.num_iterations += 1 + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size) + + # Calling _rebuild_buckets before forward compuation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logging.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + buffer_hook_registered = hasattr(self, "buffer_hook") + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. 
We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. + if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and self.num_iterations == 1 + ): + state_dict = { + "static_graph": self.static_graph, + "num_iterations": self.num_iterations, + } + + output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output) + output_placeholders = [None for _ in range(len(output_tensor_list))] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + self.reducer, + state_dict, + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref) + return output diff --git a/higgs_audio/audio_processing/quantization/distrib.py b/higgs_audio/audio_processing/quantization/distrib.py new file mode 100644 index 0000000000000000000000000000000000000000..cabf8f8a24eb710ab0eb83ce29ba054b7c11ccf3 --- /dev/null +++ b/higgs_audio/audio_processing/quantization/distrib.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Torch distributed utilities.""" + +import typing as tp + +import torch + + +def rank(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + else: + return 0 + + +def world_size(): + if torch.distributed.is_initialized(): + return torch.distributed.get_world_size() + else: + return 1 + + +def is_distributed(): + return world_size() > 1 + + +def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM): + if is_distributed(): + return torch.distributed.all_reduce(tensor, op) + + +def _is_complex_or_float(tensor): + return torch.is_floating_point(tensor) or torch.is_complex(tensor) + + +def _check_number_of_params(params: tp.List[torch.Tensor]): + # utility function to check that the number of params in all workers is the same, + # and thus avoid a deadlock with distributed all reduce. 
+    if not is_distributed() or not params:
+        return
+    # print('params[0].device ', params[0].device)
+    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
+    all_reduce(tensor)
+    if tensor.item() != len(params) * world_size():
+        # If not all the workers have the same number, for at least one of them,
+        # this inequality will be verified.
+        raise RuntimeError(
+            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
+        )
+
+
+def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
+    """Broadcast the tensors from the given parameters to all workers.
+    This can be used to ensure that all workers have the same model to start with.
+    """
+    if not is_distributed():
+        return
+    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
+    _check_number_of_params(tensors)
+    handles = []
+    for tensor in tensors:
+        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
+        handles.append(handle)
+    for handle in handles:
+        handle.wait()
+
+
+def sync_buffer(buffers, average=True):
+    """
+    Sync grad for buffers. If average is False, broadcast instead of averaging.
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for buffer in buffers:
+        if torch.is_floating_point(buffer.data):
+            if average:
+                handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            else:
+                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
+            handles.append((buffer, handle))
+    for buffer, handle in handles:
+        handle.wait()
+        if average:
+            buffer.data /= world_size()
+
+
+def sync_grad(params):
+    """
+    Simpler alternative to DistributedDataParallel, that doesn't rely
+    on any black magic. For simple models it can also be as fast.
+    Just call this on your model parameters after the call to backward!
+    """
+    if not is_distributed():
+        return
+    handles = []
+    for p in params:
+        if p.grad is not None:
+            handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
+            handles.append((p, handle))
+    for p, handle in handles:
+        handle.wait()
+        p.grad.data /= world_size()
+
+
+def average_metrics(metrics: tp.Dict[str, float], count=1.0):
+    """Average a dictionary of metrics across all workers, using the optional
+    `count` as unnormalized weight.
+    """
+    if not is_distributed():
+        return metrics
+    keys, values = zip(*metrics.items())
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
+    tensor *= count
+    all_reduce(tensor)
+    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
+    return dict(zip(keys, averaged))
diff --git a/higgs_audio/audio_processing/quantization/vq.py b/higgs_audio/audio_processing/quantization/vq.py
new file mode 100644
index 0000000000000000000000000000000000000000..dac26ba2a3bc2c97d6178fa33c629f324980d5a0
--- /dev/null
+++ b/higgs_audio/audio_processing/quantization/vq.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+ +"""Residual vector quantizer implementation.""" + +from dataclasses import dataclass, field +import math +import typing as tp + +import torch +from torch import nn + +# from .core_vq import ResidualVectorQuantization +from .core_vq_lsx_version import ResidualVectorQuantization + + +@dataclass +class QuantizedResult: + quantized: torch.Tensor + codes: torch.Tensor + bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item. + penalty: tp.Optional[torch.Tensor] = None + metrics: dict = field(default_factory=dict) + + +class ResidualVectorQuantizer(nn.Module): + """Residual Vector Quantizer. + Args: + dimension (int): Dimension of the codebooks. + n_q (int): Number of residual vector quantizers used. + bins (int): Codebook size. + decay (float): Decay for exponential moving average over the codebooks. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + + def __init__( + self, + dimension: int = 256, + codebook_dim: int = None, + n_q: int = 8, + bins: int = 1024, + decay: float = 0.99, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.n_q = n_q + self.dimension = dimension + self.codebook_dim = codebook_dim + self.bins = bins + self.decay = decay + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.threshold_ema_dead_code = threshold_ema_dead_code + self.vq = ResidualVectorQuantization( + dim=self.dimension, + codebook_dim=self.codebook_dim, + codebook_size=self.bins, + num_quantizers=self.n_q, + decay=self.decay, + kmeans_init=self.kmeans_init, + kmeans_iters=self.kmeans_iters, + threshold_ema_dead_code=self.threshold_ema_dead_code, + ) + + def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None): # -> QuantizedResult: + """Residual vector quantization on the given input tensor. + Args: + x (torch.Tensor): Input tensor. + sample_rate (int): Sample rate of the input tensor. + bandwidth (float): Target bandwidth. + Returns: + QuantizedResult: + The quantized (or approximately quantized) representation with + the associated bandwidth and any penalty term for the loss. 
+ """ + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + quantized, codes, commit_loss = self.vq(x, n_q=n_q) + bw = torch.tensor(n_q * bw_per_q).to(x) + return quantized, codes, bw, torch.mean(commit_loss) + # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss)) + + def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int: + """Return n_q based on specified target bandwidth.""" + bw_per_q = self.get_bandwidth_per_quantizer(sample_rate) + n_q = self.n_q + if bandwidth and bandwidth > 0.0: + n_q = int(max(1, math.floor(bandwidth / bw_per_q))) + return n_q + + def get_bandwidth_per_quantizer(self, sample_rate: int): + """Return bandwidth per quantizer for a given input sample rate.""" + return math.log2(self.bins) * sample_rate / 1000 + + def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor: + """Encode a given input tensor with the specified sample rate at the given bandwidth. + The RVQ encode method sets the appropriate number of quantizer to use + and returns indices for each quantizer. + """ + n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth) + codes = self.vq.encode(x, n_q=n_q) + return codes + + def decode(self, codes: torch.Tensor) -> torch.Tensor: + """Decode the given codes to the quantized representation.""" + quantized = self.vq.decode(codes) + return quantized diff --git a/higgs_audio/audio_processing/semantic_module.py b/higgs_audio/audio_processing/semantic_module.py new file mode 100644 index 0000000000000000000000000000000000000000..c4fd352909afede0de1ed7f3d96630bc0934a668 --- /dev/null +++ b/higgs_audio/audio_processing/semantic_module.py @@ -0,0 +1,310 @@ +# Based on code from: https://github.com/zhenye234/xcodec +# Licensed under MIT License +# Modifications by BosonAI + +import torch +import torch.nn as nn + + +class Conv1d1x1(nn.Conv1d): + """1x1 Conv1d.""" + + def __init__(self, in_channels, out_channels, bias=True): + super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias) + + +class Conv1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = -1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + if padding < 0: + padding = (kernel_size - 1) // 2 * dilation + self.dilation = dilation + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C, T). 
+ """ + x = self.conv(x) + return x + + +class ResidualUnit(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size=3, + dilation=1, + bias=False, + nonlinear_activation="ELU", + nonlinear_activation_params={}, + ): + super().__init__() + self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params) + self.conv1 = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + dilation=dilation, + bias=bias, + ) + self.conv2 = Conv1d1x1(out_channels, out_channels, bias) + + def forward(self, x): + y = self.conv1(self.activation(x)) + y = self.conv2(self.activation(y)) + return x + y + + +class ConvTranspose1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding=-1, + output_padding=-1, + groups=1, + bias=True, + ): + super().__init__() + if padding < 0: + padding = (stride + 1) // 2 + if output_padding < 0: + output_padding = 1 if stride % 2 else 0 + self.deconv = nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + groups=groups, + bias=bias, + ) + + def forward(self, x): + """ + Args: + x (Tensor): Float tensor variable with the shape (B, C, T). + Returns: + Tensor: Float tensor variable with the shape (B, C', T'). + """ + x = self.deconv(x) + return x + + +class EncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + dilations=(1, 1), + unit_kernel_size=3, + bias=True, + ): + super().__init__() + self.res_units = torch.nn.ModuleList() + for dilation in dilations: + self.res_units += [ + ResidualUnit( + in_channels, + in_channels, + kernel_size=unit_kernel_size, + dilation=dilation, + ) + ] + self.num_res = len(self.res_units) + + self.conv = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 if stride == 1 else (2 * stride), # special case: stride=1, do not use kernel=2 + stride=stride, + bias=bias, + ) + + def forward(self, x): + for idx in range(self.num_res): + x = self.res_units[idx](x) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + input_channels: int, + encode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv = Conv1d( + in_channels=input_channels, + out_channels=encode_channels, + kernel_size=kernel_size, + stride=1, + bias=False, + ) + self.conv_blocks = torch.nn.ModuleList() + in_channels = encode_channels + for idx, stride in enumerate(strides): + out_channels = int(encode_channels * channel_ratios[idx]) # could be float + self.conv_blocks += [ + EncoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + in_channels = out_channels + self.num_blocks = len(self.conv_blocks) + self.out_channels = out_channels + + def forward(self, x): + x = self.conv(x) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + return x + + +class DecoderBlock(nn.Module): + """Decoder block (no up-sampling)""" + + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + dilations=(1, 1), + unit_kernel_size=3, + bias=True, + ): + super().__init__() + + if stride == 1: + self.conv = Conv1d( + in_channels=in_channels, + 
out_channels=out_channels, + kernel_size=3, # fix kernel=3 when stride=1 for unchanged shape + stride=stride, + bias=bias, + ) + else: + self.conv = ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(2 * stride), + stride=stride, + bias=bias, + ) + + self.res_units = torch.nn.ModuleList() + for idx, dilation in enumerate(dilations): + self.res_units += [ + ResidualUnit( + out_channels, + out_channels, + kernel_size=unit_kernel_size, + dilation=dilation, + ) + ] + self.num_res = len(self.res_units) + + def forward(self, x): + x = self.conv(x) + for idx in range(self.num_res): + x = self.res_units[idx](x) + return x + + +class Decoder(nn.Module): + def __init__( + self, + code_dim: int, + output_channels: int, + decode_channels: int, + channel_ratios=(1, 1), + strides=(1, 1), + kernel_size=3, + bias=True, + block_dilations=(1, 1), + unit_kernel_size=3, + ): + super().__init__() + assert len(channel_ratios) == len(strides) + + self.conv1 = Conv1d( + in_channels=code_dim, + out_channels=int(decode_channels * channel_ratios[0]), + kernel_size=kernel_size, + stride=1, + bias=False, + ) + + self.conv_blocks = torch.nn.ModuleList() + for idx, stride in enumerate(strides): + in_channels = int(decode_channels * channel_ratios[idx]) + if idx < (len(channel_ratios) - 1): + out_channels = int(decode_channels * channel_ratios[idx + 1]) + else: + out_channels = decode_channels + self.conv_blocks += [ + DecoderBlock( + in_channels, + out_channels, + stride, + dilations=block_dilations, + unit_kernel_size=unit_kernel_size, + bias=bias, + ) + ] + self.num_blocks = len(self.conv_blocks) + + self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False) + + def forward(self, z): + x = self.conv1(z) + for i in range(self.num_blocks): + x = self.conv_blocks[i](x) + x = self.conv2(x) + return x diff --git a/higgs_audio/constants.py b/higgs_audio/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..addf77d2512980bfbc84b389ba830924f1f2ba33 --- /dev/null +++ b/higgs_audio/constants.py @@ -0,0 +1,3 @@ +AUDIO_IN_TOKEN = "<|AUDIO|>" +AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>" +EOS_TOKEN = "<|end_of_text|>" diff --git a/higgs_audio/data_collator/__init__.py b/higgs_audio/data_collator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/higgs_audio/data_collator/higgs_audio_collator.py b/higgs_audio/data_collator/higgs_audio_collator.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a41116c9c36618b2de038e22aa0815cc1e6a9e --- /dev/null +++ b/higgs_audio/data_collator/higgs_audio_collator.py @@ -0,0 +1,583 @@ +import librosa +import torch +import torch.nn.functional as F +import math +import numpy as np +from typing import List, Tuple, Dict + +from dataclasses import dataclass +from typing import List, Optional +from transformers.models.whisper.processing_whisper import WhisperProcessor + +from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple +from ..model.utils import build_delay_pattern_mask + + +def _ceil_to_nearest(n, round_to): + return (n + round_to - 1) // round_to * round_to + + +@dataclass +class HiggsAudioBatchInput: + input_ids: torch.LongTensor # shape (bsz, seq_len). + attention_mask: torch.Tensor # shape (bsz, seq_len). + audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len). 
+ audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len). + audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length) + audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,) + # The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover group location in a batch for an audio segment + # Currently, we concatenante audio segments along dim 0 to handle variadic audio segment length. However, in the alignment stage, we need the location information + # For example, + # audio_out_ids_start = [0, 2, 4, 8]; and the first two audio segments come from the same sample in a batch, and other two come from different samples. + # This is a batch of 3 samples, then we will have the group location as: + # audio_out_ids_start_group_loc = [0, 0, 1, 2] + audio_out_ids_start_group_loc: Optional[ + torch.LongTensor + ] # shape (num_audio_out,), specify which a sample's group location in the batch + audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length) + audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,) + label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len) + label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length) + reward: Optional[float] = None + + +class HiggsAudioSampleCollator: + """Sample collator for Higgs-Audio model. + + Args: + whisper_processor (WhisperProcessor): The whisper processor. + audio_in_token_id (int): The token id for audio-in. + audio_out_token_id (int): The token id for audio-out. + pad_token_id (int): The token id for padding. + audio_stream_bos_id (int): The token id for audio-stream beginning of sentence. + audio_stream_eos_id (int): The token id for audio-stream end of sentence. + round_to (int): The round-to value. + pad_left (bool): Whether to pad left. + return_audio_in_tokens (bool): Whether to return audio-in tokens. + use_delay_pattern (bool): Whether to use delay pattern. + disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes. + chunk_size_seconds (int): The chunk size in seconds. + add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks. + mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>. 
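+
+    Example (the token ids below are the defaults from ``HiggsAudioConfig``; the Whisper
+    checkpoint name and ``list_of_chatml_dataset_samples`` are placeholders used only for
+    illustration):
+
+        >>> from transformers import WhisperProcessor
+        >>> collator = HiggsAudioSampleCollator(
+        ...     whisper_processor=WhisperProcessor.from_pretrained("openai/whisper-large-v3"),
+        ...     audio_in_token_id=128015,
+        ...     audio_out_token_id=128016,
+        ...     pad_token_id=128001,
+        ...     audio_stream_bos_id=1024,
+        ...     audio_stream_eos_id=1025,
+        ... )
+        >>> batch = collator(list_of_chatml_dataset_samples)  # -> HiggsAudioBatchInput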
+ + """ + + def __init__( + self, + whisper_processor: WhisperProcessor, + audio_in_token_id, + audio_out_token_id, + pad_token_id, + audio_stream_bos_id, + audio_stream_eos_id, + round_to=8, + pad_left=False, + encode_whisper_embed=True, + return_audio_in_tokens=True, + audio_num_codebooks=None, + use_delay_pattern=False, + disable_audio_codes_transform=False, + chunk_size_seconds=30, # Maximum duration for each chunk + add_new_bos_eos_for_long_chunk=True, + mask_audio_out_token_label=True, + ): + self.whisper_processor = whisper_processor + self.round_to = round_to + self.pad_left = pad_left + self.audio_in_token_id = audio_in_token_id + self.audio_out_token_id = audio_out_token_id + self.audio_stream_bos_id = audio_stream_bos_id + self.audio_stream_eos_id = audio_stream_eos_id + self.pad_token_id = pad_token_id + self.encode_whisper_embed = encode_whisper_embed + self.return_audio_in_tokens = return_audio_in_tokens + self.audio_num_codebooks = audio_num_codebooks + self.use_delay_pattern = use_delay_pattern + if encode_whisper_embed: + self.chunk_size_seconds = chunk_size_seconds + self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate) + else: + self.chunk_size_seconds = None + self.chunk_size_samples = None + self.disable_audio_codes_transform = disable_audio_codes_transform + self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk + self.mask_audio_out_token_label = mask_audio_out_token_label + + def _process_and_duplicate_audio_tokens( + self, + input_ids: torch.Tensor, + audio_idx: int, + wv: torch.Tensor, + sr: int, + labels: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, int]: + """Process long audio and duplicate corresponding audio tokens. + + Args: + input_ids: Input token ids + audio_idx: Index of the audio token in the sequence + wv: Audio waveform + sr: Sample rate + labels: Optional label ids to be duplicated alongside input ids + + Returns: + Tuple of: + - New input ids with duplicated audio tokens + - New label ids (if labels were provided) or None + - Number of chunks created + """ + # Calculate number of chunks needed + total_samples = len(wv) + num_chunks = math.ceil(total_samples / self.chunk_size_samples) + + if num_chunks <= 1: + return input_ids, labels, 1 + + # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|> + audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2] + # Duplicate sequence for each chunk + duplicated_sequence = audio_token_seq.repeat(num_chunks) + + # Create new input_ids with duplicated tokens + new_input_ids = torch.cat( + [ + input_ids[: audio_idx - 1], + duplicated_sequence, + input_ids[audio_idx + 2 :], + ] + ) + + # If labels are provided, duplicate them as well + new_labels = None + if labels is not None: + label_seq = labels[audio_idx - 1 : audio_idx + 2] + duplicated_labels = label_seq.repeat(num_chunks) + new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]]) + + return new_input_ids, new_labels, num_chunks + + def __call__(self, batch: List[ChatMLDatasetSample]): + """Collate the input data with support for long audio processing.""" + + label_ids = None + label_audio_ids = None + if all([ele.label_ids is None for ele in batch]): + return_labels = False + else: + return_labels = True + + if self.encode_whisper_embed: + # Process each sample in the batch to handle long audio + # TODO(?) The implementation here can be optimized. 
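+            # Long-audio handling (summary of the loop below): every <|audio_bos|><|AUDIO|><|audio_eos|>
+            # triple whose waveform exceeds `chunk_size_samples` (chunk_size_seconds * Whisper sampling
+            # rate, i.e. 30 s by default) is duplicated once per chunk via
+            # `_process_and_duplicate_audio_tokens`, and the waveform is split so that each <|AUDIO|>
+            # token ends up paired with exactly one chunk of at most `chunk_size_seconds` seconds.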
+ processed_batch = [] + for i in range(len(batch)): + sample = batch[i] + audio_in_mask = sample.input_ids == self.audio_in_token_id + audio_in_indices = torch.where(audio_in_mask)[0] + audio_out_mask = sample.input_ids == self.audio_out_token_id + + # Process each audio token and duplicate if needed + modified_input_ids = sample.input_ids + modified_labels = sample.label_ids if return_labels else None + modified_waveforms_concat = [] + modified_waveforms_start = [] + modified_sample_rate = [] + offset = 0 # Track position changes from duplicating tokens + curr_wv_offset = 0 + + # Process input audio tokens + for idx, audio_idx in enumerate(audio_in_indices): + # Get the audio for this token + wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index + if sr != self.whisper_processor.feature_extractor.sampling_rate: + resampled_wv = librosa.resample( + wv.cpu().numpy(), + orig_sr=sr, + target_sr=self.whisper_processor.feature_extractor.sampling_rate, + ) + else: + resampled_wv = wv.cpu().numpy() + wv = torch.tensor(resampled_wv, device=wv.device) + sr = self.whisper_processor.feature_extractor.sampling_rate + + # Process and duplicate tokens if necessary + token_pos = audio_idx + offset + modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens( + modified_input_ids, token_pos, wv, sr, modified_labels + ) + + # Update audio data + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * self.chunk_size_samples + chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv)) + chunk_wv = wv[chunk_start:chunk_end] + modified_waveforms_concat.append(chunk_wv) + modified_waveforms_start.append(curr_wv_offset) + curr_wv_offset += len(chunk_wv) + modified_sample_rate.append(sr) + + # Update offset for next iteration + offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens + + # Create new sample with modified tokens and audio data + processed_sample = ChatMLDatasetSample( + input_ids=modified_input_ids, + label_ids=modified_labels if return_labels else sample.label_ids, + audio_ids_concat=sample.audio_ids_concat, + audio_ids_start=sample.audio_ids_start, + audio_waveforms_concat=torch.cat(modified_waveforms_concat) + if modified_waveforms_concat + else sample.audio_waveforms_concat, + audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long) + if modified_waveforms_start + else sample.audio_waveforms_start, + audio_sample_rate=torch.tensor(modified_sample_rate) + if modified_sample_rate + else sample.audio_sample_rate, + audio_speaker_indices=torch.tensor([]), + # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat. 
+ audio_label_ids_concat=sample.audio_label_ids_concat, + ) + # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0]) + # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}" + processed_batch.append(processed_sample) + else: + processed_batch = batch + + # Get the max sequence length based on processed batch + max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to) + + # Get the ids for audio-in and audio-out for each batch + audio_in_wv_l = [] + audio_in_ids_l = [] + audio_out_ids_l = [] + audio_out_ids_group_loc_l = [] + audio_in_label_ids_l = None + audio_out_label_ids_l = None + reward_l = [] + + if return_labels: + audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not. + + # Process the audio inputs and outputs + for i in range(len(processed_batch)): + audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id + audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id + audio_ids = torch.ones_like(processed_batch[i].input_ids) + audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1 + audio_in_ids = audio_ids[audio_in_mask] + audio_out_ids = audio_ids[audio_out_mask] + + if return_labels: + audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0) + if self.mask_audio_out_token_label: + processed_batch[i].label_ids[audio_out_mask] = -100 + + # Process audio inputs + if self.return_audio_in_tokens: + audio_in_ids_l.extend( + [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids] + ) + if processed_batch[i].audio_label_ids_concat is not None: + if audio_in_label_ids_l is None: + audio_in_label_ids_l = [] + audio_in_label_ids_l.extend( + [ + processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :] + for idx in audio_in_ids + ] + ) + + audio_out_ids_l.extend( + [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids] + ) + audio_out_ids_group_loc_l.append(i) + if processed_batch[i].reward is not None: + reward_l.append(processed_batch[i].reward) + + if processed_batch[i].audio_label_ids_concat is not None: + if audio_out_label_ids_l is None: + audio_out_label_ids_l = [] + audio_out_label_ids_l.extend( + [ + processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :] + for idx in audio_out_ids + ] + ) + + if self.encode_whisper_embed: + for idx in audio_in_ids: + wv, sr = processed_batch[i].get_wv(idx) + resampled_wv = wv.cpu().numpy() + # Split long audio into chunks + total_samples = len(resampled_wv) + for chunk_start in range(0, total_samples, self.chunk_size_samples): + chunk_end = min(chunk_start + self.chunk_size_samples, total_samples) + chunk = resampled_wv[chunk_start:chunk_end] + audio_in_wv_l.append(chunk) + # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \ + # f"Assertion failed: Mismatch in number of audios. " \ + # f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}." 
+ + if return_labels: + audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0) + + # Process all audio features + if len(audio_in_wv_l) > 0: + feature_ret = self.whisper_processor.feature_extractor( + audio_in_wv_l, + sampling_rate=self.whisper_processor.feature_extractor.sampling_rate, + return_attention_mask=True, + padding="max_length", + ) + audio_features = torch.from_numpy(feature_ret["input_features"]) + audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"]) + else: + if self.encode_whisper_embed: + audio_features = torch.zeros( + ( + 0, + self.whisper_processor.feature_extractor.feature_size, + self.whisper_processor.feature_extractor.nb_max_frames, + ), + dtype=torch.float32, + ) + audio_feature_attention_mask = torch.zeros( + (0, self.whisper_processor.feature_extractor.nb_max_frames), + dtype=torch.int32, + ) + else: + audio_features = None + audio_feature_attention_mask = None + + # Process audio input tokens + if len(audio_in_ids_l) > 0: + # Append audio-stream-bos and eos tokens + new_audio_in_ids_l = [] + for ele in audio_in_ids_l: + if self.disable_audio_codes_transform: + # Do not add audio-stream-bos or eos tokens. + # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer. + audio_codes = ele + else: + audio_codes = torch.cat( + [ + torch.full( + (ele.shape[0], 1), + self.audio_stream_bos_id, + dtype=torch.long, + ), + ele, + torch.full( + (ele.shape[0], 1), + self.audio_stream_eos_id, + dtype=torch.long, + ), + ], + dim=1, + ) + if self.use_delay_pattern: + audio_codes = build_delay_pattern_mask( + audio_codes.unsqueeze(0), + bos_token_id=self.audio_stream_bos_id, + pad_token_id=self.audio_stream_eos_id, + )[0].squeeze(0) + new_audio_in_ids_l.append(audio_codes) + audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long() + audio_in_ids_start = torch.cumsum( + torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]), + dim=0, + ) + else: + audio_in_ids = torch.zeros((0, 0), dtype=torch.long) + audio_in_ids_start = torch.zeros(0, dtype=torch.long) + + # Process audio output tokens + audio_out_ids_start_group_loc = None + if len(audio_out_ids_l) > 0: + new_audio_out_ids_l = [] + label_audio_ids_l = [] + for idx, ele in enumerate(audio_out_ids_l): + if self.disable_audio_codes_transform: + # Do not add audio-stream-bos or eos tokens. + # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer. 
+ audio_codes = ele + if return_labels: + label_audio_ids = audio_out_label_ids_l[idx] + else: + audio_codes = torch.cat( + [ + torch.full( + (ele.shape[0], 1), + self.audio_stream_bos_id, + dtype=torch.long, + ), + ele, + torch.full( + (ele.shape[0], 1), + self.audio_stream_eos_id, + dtype=torch.long, + ), + ], + dim=1, + ) + if return_labels: + label_audio_ids = torch.cat( + [ + torch.full((ele.shape[0], 1), -100, dtype=torch.long), + ele, + torch.full( + (ele.shape[0], 1), + self.audio_stream_eos_id, + dtype=torch.long, + ), + ], + dim=1, + ) + if self.use_delay_pattern: + audio_codes = build_delay_pattern_mask( + audio_codes.unsqueeze(0), + bos_token_id=self.audio_stream_bos_id, + pad_token_id=self.audio_stream_eos_id, + )[0].squeeze(0) + if return_labels: + label_audio_ids = build_delay_pattern_mask( + label_audio_ids.unsqueeze(0), + bos_token_id=-100, + pad_token_id=-100, + )[0].squeeze(0) + new_audio_out_ids_l.append(audio_codes) + + if return_labels: + if audio_out_no_train_flag[idx]: + label_audio_ids[:] = -100 + label_audio_ids_l.append(label_audio_ids) + + audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long() + if return_labels: + label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long() + audio_out_ids_start = torch.cumsum( + torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]), + dim=0, + ) + audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long) + else: + audio_out_ids = torch.zeros((0, 0), dtype=torch.long) + audio_out_ids_start = torch.zeros(0, dtype=torch.long) + if return_labels: + label_audio_ids = torch.zeros((0, 0), dtype=torch.long) + + reward = torch.tensor(reward_l, dtype=torch.float32) + + # Handle padding for input ids and attention mask + if self.pad_left: + input_ids = torch.stack( + [ + F.pad( + ele.input_ids, + (max_seq_length - len(ele.input_ids), 0), + value=self.pad_token_id, + ) + for ele in processed_batch + ] + ) + if return_labels: + label_ids = torch.stack( + [ + F.pad( + ele.label_ids, + (max_seq_length - len(ele.label_ids), 0), + value=-100, + ) + for ele in processed_batch + ] + ) + attention_mask = torch.stack( + [ + F.pad( + torch.ones_like(ele.input_ids), + (max_seq_length - len(ele.input_ids), 0), + value=0, + ) + for ele in processed_batch + ] + ) + else: + input_ids = torch.stack( + [ + F.pad( + ele.input_ids, + (0, max_seq_length - len(ele.input_ids)), + value=self.pad_token_id, + ) + for ele in processed_batch + ] + ) + if return_labels: + label_ids = torch.stack( + [ + F.pad( + ele.label_ids, + (0, max_seq_length - len(ele.label_ids)), + value=-100, + ) + for ele in processed_batch + ] + ) + attention_mask = torch.stack( + [ + F.pad( + torch.ones_like(ele.input_ids), + (0, max_seq_length - len(ele.input_ids)), + value=0, + ) + for ele in processed_batch + ] + ) + + if not self.return_audio_in_tokens: + audio_in_ids = None + audio_in_ids_start = None + + # Apply audio_num_codebooks limit if specified + if self.audio_num_codebooks is not None: + if audio_in_ids is not None: + audio_in_ids = audio_in_ids[: self.audio_num_codebooks] + if audio_out_ids is not None: + audio_out_ids = audio_out_ids[: self.audio_num_codebooks] + if label_audio_ids is not None: + label_audio_ids = label_audio_ids[: self.audio_num_codebooks] + + return HiggsAudioBatchInput( + input_ids=input_ids, + attention_mask=attention_mask, + audio_features=audio_features, + audio_feature_attention_mask=audio_feature_attention_mask, + audio_out_ids=audio_out_ids, + 
audio_out_ids_start=audio_out_ids_start, + audio_out_ids_start_group_loc=audio_out_ids_start_group_loc, + audio_in_ids=audio_in_ids, + audio_in_ids_start=audio_in_ids_start, + label_ids=label_ids, + label_audio_ids=label_audio_ids, + reward=reward, + ) + + +class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput: + # flatten ranked chatml samples + chosen = [] + rejected = [] + + for sample in batch: + chosen.append(sample.max_score_sample()) + rejected.append(sample.min_score_sample()) + + merged = chosen + merged.extend(rejected) + + return super().__call__(batch=merged) diff --git a/higgs_audio/data_types.py b/higgs_audio/data_types.py new file mode 100644 index 0000000000000000000000000000000000000000..2b86089d48d6d3e25e307575e44a2c287ccae6fa --- /dev/null +++ b/higgs_audio/data_types.py @@ -0,0 +1,38 @@ +"""Basic data types for multimodal ChatML format.""" + +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + + +@dataclass +class AudioContent: + audio_url: str + # Base64 encoded audio bytes + raw_audio: Optional[str] = None + offset: Optional[float] = None + duration: Optional[float] = None + row_id: Optional[int] = None + type: str = "audio" + + +@dataclass +class TextContent: + text: str + type: str = "text" + + +@dataclass +class Message: + role: str + content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]] + recipient: Optional[str] = None + + +@dataclass +class ChatMLSample: + """Dataclass to hold multimodal ChatML data.""" + + messages: List[Message] + start_index: Optional[int] = None # We will mask the messages[:start_index] when finetuning the LLM. + misc: Optional[Dict] = None + speaker: Optional[str] = None diff --git a/higgs_audio/dataset/__init__.py b/higgs_audio/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/higgs_audio/dataset/chatml_dataset.py b/higgs_audio/dataset/chatml_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0f550a5f378f12b81477fc2aeed6ed6bf6d63a28 --- /dev/null +++ b/higgs_audio/dataset/chatml_dataset.py @@ -0,0 +1,554 @@ +import dacite +import pandas as pd +import torch +import json + +import numpy as np +import multiprocessing as mp + +from dataclasses import dataclass, fields +from abc import ABC, abstractmethod +from typing import Union, List, Dict, Optional + +from ..data_types import ChatMLSample, TextContent, AudioContent +from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN + +from loguru import logger + +# Whisper processor, 30 sec -> 3000 features +# Then we divide 4 in the audio towker, we decrease 3000 features to 750, which gives 25 Hz +WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25 + + +@dataclass +class ChatMLDatasetSample: + input_ids: torch.LongTensor # Shape (seq_len,): The input text tokens. + label_ids: torch.LongTensor # Shape (seq_len,): The label ids. + audio_ids_concat: torch.LongTensor # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated. + # Here `audio_seq_len` is the length of the concatenated audio tokens.` + audio_ids_start: ( + torch.LongTensor + ) # Shape (num_audios,): The start index of each audio token in the concatenated audio tokens. 
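+    # Ragged-array convention (illustrative): two clips with 80 and 120 codec frames are stored as
+    # audio_ids_concat of shape (num_codebooks, 200) with audio_ids_start = [0, 80]; the last segment
+    # implicitly runs to the end of the concatenation (see `get_audio_codes`). The waveform fields
+    # below follow the same start-offset convention.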
+ audio_waveforms_concat: ( + torch.Tensor + ) # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features. + audio_waveforms_start: ( + torch.LongTensor + ) # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms. + audio_sample_rate: torch.Tensor # Shape (num_audios,): The sampling rate of the audio waveforms. + audio_speaker_indices: ( + torch.LongTensor + ) # Shape (num_audios,) -1 means unknown speaker: The speaker indices for each audio. + audio_label_ids_concat: Optional[torch.LongTensor] = ( + None # Shape (num_codebooks, audio_seq_len): The audio tokens that are concatenated. + ) + # Here `audio_seq_len` is the length of the concatenated audio tokens.` + reward: Optional[float] = None + + def num_audios(self): + return max(len(self.audio_waveforms_start), len(self.audio_ids_start)) + + def get_audio_codes(self, idx): + code_start = self.audio_ids_start[idx] + if idx < len(self.audio_ids_start) - 1: + code_end = self.audio_ids_start[idx + 1] + else: + code_end = self.audio_ids_concat.shape[-1] + + return self.audio_ids_concat[:, code_start:code_end] + + def get_audio_codes_labels(self, idx): + if self.audio_label_ids_concat is None: + return None + code_start = self.audio_ids_start[idx] + if idx < len(self.audio_ids_start) - 1: + code_end = self.audio_ids_start[idx + 1] + else: + code_end = self.audio_ids_concat.shape[-1] + + return self.audio_label_ids_concat[:, code_start:code_end] + + def get_wv(self, idx): + wv_start = self.audio_waveforms_start[idx] + sr = self.audio_sample_rate[idx] + if idx < len(self.audio_waveforms_start) - 1: + wv_end = self.audio_waveforms_start[idx + 1] + else: + wv_end = self.audio_waveforms_concat.shape[-1] + return self.audio_waveforms_concat[wv_start:wv_end], sr + + def cal_num_tokens( + self, + encode_whisper_embed: bool = True, + encode_audio_in_tokens: bool = False, + encode_audio_out_tokens: bool = True, + audio_in_token_id: int = 128015, + audio_out_token_id: int = 128016, + ) -> int: + # we firstly exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those position with actual audio features and audio token ids + # It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa) + num_tokens = len(self.input_ids) - len(self.audio_ids_start) + + if encode_whisper_embed and len(self.audio_waveforms_concat) > 0: + audio_lengths = torch.diff(self.audio_waveforms_start) + if len(audio_lengths): + # Sum before calling .item() + num_tokens += ( + ( + np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1]) + ).sum() + ).item() + # add the last audio's token estimation + num_tokens += ( + np.ceil( + WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC + * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1]) + / self.audio_sample_rate[-1] + ) + ).item() + + if self.audio_ids_concat.size(1) > 0: + audio_io_ids = self.input_ids[ + (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id) + ] + audio_io_id_lengths = torch.concat( + [ + torch.diff(self.audio_ids_start), + torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]), + ] + ) + if encode_audio_in_tokens: + num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item() + + if encode_audio_out_tokens: + num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item() + + return int(num_tokens) + + @classmethod + def merge( + cls, + samples: 
List["ChatMLDatasetSample"], + eos_token_id: int, + ignore_index: int, + padding_size: Optional[int] = None, + ) -> "ChatMLDatasetSample": + """Merges a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them, and adjusting offsets for audio_ids_start and audio_waveforms_start. + + Args: + samples (List[ChatMLDatasetSample]): List of samples to merge. + eos_token_id (int): Tokens to be inserted into input_ids between samples. + ignore_index (int): Default label for padding. + padding_size (Optional[int]): If provided, pad the sequence to with this length. + + Returns: + ChatMLDatasetSample: Merged and potentially padded sample. + """ + if not samples: + logger.fatal("The samples list is empty and cannot be merged.") + raise ValueError("The samples list is empty and cannot be merged.") + + # Initialize empty lists for concatenation + input_ids_list = [] + label_ids_list = [] + audio_ids_concat_list = [] + audio_ids_start_list = [] + audio_waveforms_concat_list = [] + audio_waveforms_start_list = [] + audio_sample_rate_list = [] + audio_speaker_indices_list = [] + + # Track offsets + audio_ids_offset = 0 + audio_waveforms_offset = 0 + + for sample in samples: + # Add input_ids and label_ids with padding + if input_ids_list: + input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long)) + label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long)) + input_ids_list.append(sample.input_ids) + label_ids_list.append(sample.label_ids) + + # Add audio_ids_concat and handle empty audio ids + if sample.audio_ids_concat.size(1) > 0: + audio_ids_concat_list.append(sample.audio_ids_concat) + + # Offset and add audio_ids_start + audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset) + audio_ids_offset += sample.audio_ids_concat.size( + 1 + ) # (num_codebooks, seq_len): Update offset by audio_seq_len + + # Add audio_waveforms_concat + if sample.audio_waveforms_concat.size(0) > 0: + # Check dimensions of the audio waveform to ensure consistency + if ( + audio_waveforms_concat_list + and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim() + ): + logger.warning( + f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D" + ) + continue + + audio_waveforms_concat_list.append(sample.audio_waveforms_concat) + audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset) + audio_waveforms_offset += sample.audio_waveforms_concat.size(0) + + # Add audio_sample_rate and audio_speaker_indices + audio_sample_rate_list.append(sample.audio_sample_rate) + + audio_speaker_indices_list.append(sample.audio_speaker_indices) + + # Concatenate all tensors + input_ids = torch.cat(input_ids_list, dim=0) + label_ids = torch.cat(label_ids_list, dim=0) + + # Apply padding if padding_size is specified + if padding_size is not None and padding_size > 0: + input_ids = torch.cat( + [ + input_ids, + torch.full((padding_size,), eos_token_id, dtype=torch.long), + ], + dim=0, + ) + label_ids = torch.cat( + [ + label_ids, + torch.full((padding_size,), ignore_index, dtype=torch.long), + ], + dim=0, + ) + + # Safely concatenate audio tensors with proper error handling + try: + audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]]) + audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([]) + + # Check for dimensional consistency in 
audio waveforms + if audio_waveforms_concat_list: + dims = [t.dim() for t in audio_waveforms_concat_list] + if not all(d == dims[0] for d in dims): + # If dimensions don't match, log warning and filter out the problematic tensors + logger.warning( + f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones." + ) + expected_dim = max(set(dims), key=dims.count) # Most common dimension + audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim] + + # Recalculate audio_waveforms_start with the filtered list + if audio_waveforms_concat_list: + audio_waveforms_offset = 0 + audio_waveforms_start_list = [] + for waveform in audio_waveforms_concat_list: + audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset])) + audio_waveforms_offset += waveform.size(0) + + audio_waveforms_concat = ( + torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([]) + ) + audio_waveforms_start = ( + torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([]) + ) + audio_sample_rate = ( + torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([]) + ) + audio_speaker_indices = ( + torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([]) + ) + + except RuntimeError as e: + logger.error(f"Error during tensor concatenation: {str(e)}") + logger.warning("Falling back to empty audio tensors") + # Fall back to empty tensors + audio_ids_concat = torch.tensor([[]]) + audio_ids_start = torch.tensor([]) + audio_waveforms_concat = torch.tensor([]) + audio_waveforms_start = torch.tensor([]) + audio_sample_rate = torch.tensor([]) + audio_speaker_indices = torch.tensor([]) + + # Create the merged sample + merged_sample = cls( + input_ids=input_ids, + label_ids=label_ids, + audio_ids_concat=audio_ids_concat, + audio_ids_start=audio_ids_start, + audio_waveforms_concat=audio_waveforms_concat, + audio_waveforms_start=audio_waveforms_start, + audio_sample_rate=audio_sample_rate, + audio_speaker_indices=audio_speaker_indices, + ) + + return merged_sample + + +@dataclass +class RankedChatMLDatasetSampleTuple: + samples: List[ChatMLDatasetSample] + scores: List[float] + + def max_score_sample(self) -> ChatMLDatasetSample: + idx = self.scores.index(max(self.scores)) + self.samples[idx].reward = self.scores[idx] + return self.samples[idx] + + def min_score_sample(self) -> ChatMLDatasetSample: + idx = self.scores.index(min(self.scores)) + self.samples[idx].reward = self.scores[idx] + return self.samples[idx] + + +@dataclass +class ChatMLDatasetStorageSample: + input_tokens: torch.LongTensor + label_tokens: torch.LongTensor + audio_bytes_cache_dir_index: int + audio_codes_cache_dir_index: int + audio_bytes_indices: torch.LongTensor + audio_codes_indices: torch.LongTensor + speaker_indices: torch.LongTensor + file_index: int + original_sample_index: int + + +# TODO(sxjscience): We need to revist the logic about parsing speaker ids. +# Currently, we assume that the speaker id is stored at the "misc" field in ChatMLSample. +def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer): + """Preprocess the ChatML sample to get the tokens for the text part. + + Args: + sample (ChatMLSample): The ChatML sample to preprocess. + tokenizer: The tokenizer to use for encoding the text. 
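+
+    Returns:
+        A 4-tuple ``(input_tokens, label_tokens, audio_contents, speaker_id)``:
+        ``input_tokens`` and ``label_tokens`` are lists of token ids of the same length
+        (label positions that should not contribute to the loss are set to -100),
+        ``audio_contents`` collects the ``AudioContent`` objects in message order, and
+        ``speaker_id`` is taken from ``sample.speaker`` or ``sample.misc["speaker"]`` when
+        available. On a parsing failure the function returns ``(None, None, None, None)``.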
+ + """ + + try: + if not isinstance(sample, ChatMLSample): + # Handle all fields that could be NaN + if "speaker" in sample and pd.isna(sample["speaker"]): + sample["speaker"] = None + if "start_index" in sample and pd.isna(sample["start_index"]): + sample["start_index"] = None + if "content" in sample and pd.isna(sample["content"]): + sample["content"] = "" + + # Convert any other potential NaN values in nested structures + def convert_nan_to_none(obj): + import numpy as np + + if isinstance(obj, (pd.Series, np.ndarray)): + return obj.tolist() + elif pd.api.types.is_scalar(obj) and pd.isna(obj): + return None + elif isinstance(obj, dict): + return {k: convert_nan_to_none(v) for k, v in obj.items()} + elif isinstance(obj, (list, tuple)): # Fixed: Handle both list and tuple + return [convert_nan_to_none(item) for item in obj] + return obj + + # Clean the sample data + clean_sample = convert_nan_to_none(sample) + + val_keys = [] + for field in fields(ChatMLSample): + if field.name in clean_sample: + val_keys.append(field.name) + clean_sample = {k: clean_sample[k] for k in val_keys} + + try: + sample = dacite.from_dict( + data_class=ChatMLSample, + data=clean_sample, + config=dacite.Config(strict=True, check_types=True), + ) + except Exception as e: + print(f"Failed to convert to ChatMLSample: {e}") + print(f"Clean sample: {json.dumps(clean_sample, indent=2)}") + return None, None, None, None + + input_tokens = [] + label_tokens = [] + audio_contents = [] + speaker_id = None + if sample.speaker is not None: + speaker_id = sample.speaker + elif sample.misc is not None: + if "speaker" in sample.misc: + speaker_id = sample.misc["speaker"] + + total_m = len(sample.messages) + for turn_id, message in enumerate(sample.messages): + role = message.role + recipient = message.recipient + content = message.content + content_l = [] + + if isinstance(content, str): + content_l.append(TextContent(text=content)) + elif isinstance(content, TextContent): + content_l.append(content) + elif isinstance(content, AudioContent): + content_l.append(content) + elif isinstance(content, list): + for ele in content: + if isinstance(ele, str): + content_l.append(TextContent(text=ele)) + else: + content_l.append(ele) + if turn_id == 0: + prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n" + else: + prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n" + eot_postfix = "<|eot_id|>" + eom_postfix = "<|eom_id|>" + + prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False) + input_tokens.extend(prefix_tokens) + label_tokens.extend([-100 for _ in prefix_tokens]) + + if recipient: + assert role == "assistant", "Recipient is only available for assistant role." 
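+                # The recipient is serialized as "{recipient}<|recipient|>" right after the assistant
+                # header; unlike the header prefix above, these tokens are also kept in the labels.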
+ recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False) + input_tokens.extend(recipient_tokens) + label_tokens.extend(recipient_tokens) + + for content in content_l: + if content.type == "text": + text_tokens = tokenizer.encode(content.text, add_special_tokens=False) + input_tokens.extend(text_tokens) + if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index): + label_tokens.extend(text_tokens) + else: + label_tokens.extend([-100 for _ in text_tokens]) + + elif content.type == "audio": + # Generate the text-part of the audio tokens + audio_contents.append(content) + if role == "user" or role == "system": + # Add the text tokens + text_tokens = tokenizer.encode( + f"<|audio_bos|><|AUDIO|><|audio_eos|>", + add_special_tokens=False, + ) + input_tokens.extend(text_tokens) + label_tokens.extend([-100 for _ in text_tokens]) + elif role == "assistant": + # Add the text tokens for audio-out part. + text_tokens = tokenizer.encode( + f"<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>", + add_special_tokens=False, + ) + input_tokens.extend(text_tokens) + if sample.start_index is None or turn_id >= sample.start_index: + label_tokens.extend(text_tokens) + else: + label_tokens.extend([-100 for _ in text_tokens]) + next_id = turn_id + 1 + if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant": + postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False) + input_tokens.extend(postfix_tokens) + else: + postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False) + input_tokens.extend(postfix_tokens) + if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index): + label_tokens.extend(postfix_tokens) + else: + label_tokens.extend([-100 for _ in postfix_tokens]) + + return input_tokens, label_tokens, audio_contents, speaker_id + + except Exception as e: + print(f"Error in prepare_chatml_sample: {str(e)}") + print(f"Sample data: {json.dumps(sample, indent=2)}") + return None, None, None, None + + +def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer): + """Extract the generation prompt and reference answer from the input tokens. + + For example: + + Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n + What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|> + <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>' + + --> + + Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n + What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|> + <|start_header_id|>assistant<|end_header_id|>\n\n', + Reference = 'At first they went by quick, too quick to even get.' + + Args: + input_tokens: The input tokens. + audio_contents: The audio contents. + tokenizer: The tokenizer to use for decoding the text. + + Returns: + prompt_tokens: The tokens for the prompt. + reference_answer: The reference answer. + num_audios_in_reference: The number of audios in the reference answer. 
+ + """ + input_text = tokenizer.decode(input_tokens) + generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n" + postfix = "<|eot_id|>" + assert generation_prefix in input_text + generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix) + generation_prompt = input_text[:generation_prompt_end_loc] + reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)] + num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN) + return ( + tokenizer.encode(generation_prompt, add_special_tokens=False), + reference_answer, + num_audios_in_reference, + ) + + +def prepare_chatml_dataframe_single_process(df, tokenizer): + """Prepare the ChatML DataFrame.""" + ret = [] + for _, row in df.iterrows(): + input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer) + ret.append((input_tokens, label_tokens, audio_contents, speaker_id)) + return ret + + +def prepare_chatml_dataframe(df, tokenizer, num_process=16): + if num_process is None: + return prepare_chatml_dataframe_single_process(df, tokenizer) + else: + num_process = max(min(len(df) // 1000, num_process), 1) + workloads = np.array_split(df, num_process) + with mp.Pool(num_process) as pool: + ret = pool.starmap( + prepare_chatml_dataframe_single_process, + [(workload, tokenizer) for workload in workloads], + ) + return sum(ret, []) + + +class DatasetInterface(ABC): + @abstractmethod + def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]: + """Retrieve a dataset sample by index.""" + raise NotImplementedError + + +class IterableDatasetInterface(ABC): + @abstractmethod + def __iter__( + self, + ) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]: + """Retrieve a sample by iterating through the dataset.""" + raise NotImplementedError + + +@dataclass +class DatasetInfo: + dataset_type: str + group_type: Optional[str] = None + mask_text: Optional[bool] = None # Whether to mask the text tokens for pretraining samples. 
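A minimal sketch of how the chatml_dataset.py pieces above fit together, for orientation only: tokenize a ChatML dict with `prepare_chatml_sample`, wrap the result in a text-only `ChatMLDatasetSample`, and pass a list of such samples to the `HiggsAudioSampleCollator` shown earlier. The tokenizer checkpoint, the message text, and the codebook count are illustrative assumptions, not part of this diff:

    import torch
    from transformers import AutoTokenizer
    from higgs_audio.dataset.chatml_dataset import ChatMLDatasetSample, prepare_chatml_sample

    # Assumed checkpoint: Higgs-Audio reuses reserved special tokens of the Llama-3.1 tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    chatml = {
        "messages": [
            {"role": "user", "content": "Please read this sentence out loud."},
            {"role": "assistant", "content": "Sure, here it is."},
        ]
    }
    input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(chatml, tokenizer)

    # Text-only toy sample: all audio fields stay empty (12 codebooks is the HiggsAudioConfig default).
    sample = ChatMLDatasetSample(
        input_ids=torch.tensor(input_tokens, dtype=torch.long),
        label_ids=torch.tensor(label_tokens, dtype=torch.long),
        audio_ids_concat=torch.zeros((12, 0), dtype=torch.long),
        audio_ids_start=torch.zeros(0, dtype=torch.long),
        audio_waveforms_concat=torch.zeros(0),
        audio_waveforms_start=torch.zeros(0, dtype=torch.long),
        audio_sample_rate=torch.zeros(0),
        audio_speaker_indices=torch.zeros(0, dtype=torch.long),
    )
    # batch = collator([sample])  # where `collator` is a HiggsAudioSampleCollator instance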
diff --git a/higgs_audio/model/__init__.py b/higgs_audio/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad77c28104c694c79e3588903eba7fdb2051a8e --- /dev/null +++ b/higgs_audio/model/__init__.py @@ -0,0 +1,9 @@ +from transformers import AutoConfig, AutoModel + +from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig +from .modeling_higgs_audio import HiggsAudioModel + + +AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig) +AutoConfig.register("higgs_audio", HiggsAudioConfig) +AutoModel.register(HiggsAudioConfig, HiggsAudioModel) diff --git a/higgs_audio/model/audio_head.py b/higgs_audio/model/audio_head.py new file mode 100644 index 0000000000000000000000000000000000000000..8cb41a275f596ec5c1826b217ab0679faf7ad8c5 --- /dev/null +++ b/higgs_audio/model/audio_head.py @@ -0,0 +1,139 @@ +"""Projector that maps hidden states from the LLM component to multimodal logits.""" + +import torch +from torch import nn + +from dataclasses import dataclass +from typing import Optional, Tuple + +from .common import HiggsAudioPreTrainedModel +from .configuration_higgs_audio import HiggsAudioConfig + + +@dataclass +class HiggsAudioDecoderLayerOutput: + logits: torch.FloatTensor + audio_logits: torch.FloatTensor + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + + +class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel): + """Projection layers that map hidden states from the LLM component to audio / text logits. + + We support two type of audio head: + - Basic Audio Head: + Directly map the hidden states to audio logits for all the codebooks. + """ + + def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None): + super().__init__(config) + self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.audio_lm_head = nn.Linear( + config.text_config.hidden_size, + config.audio_num_codebooks * (config.audio_codebook_size + 2), + bias=False, + ) + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + hidden_states, + audio_out_mask, + label_audio_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + output_audio_hidden_states=False, + cache_position=None, + ): + """ + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): + Hidden states from the LLM component + audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): + Mask for identifying the audio out tokens. + label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`): + Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used. + attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): + Mask to avoid performing attention on padding token indices + position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`): + Position ids for the input tokens + + Returns: + logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`): + Logits for text tokens + audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`): + Logits for audio tokens. 
We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len` + """ + logits = self.text_lm_head(hidden_states) + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input. + if self.config.audio_decoder_proj_num_layers > 0: + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for decoder_layer in self.transformer_layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + hidden_states = layer_outputs[0] + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + next_cache = next_decoder_cache if use_cache else None + + audio_logits = self.audio_lm_head(hidden_states[audio_out_mask]) + + if output_audio_hidden_states: + audio_hidden_states = hidden_states[audio_out_mask] + else: + audio_hidden_states = None + + return ( + logits, + audio_logits, + all_self_attns, + all_hidden_states, + audio_hidden_states, + next_cache, + ) diff --git a/higgs_audio/model/common.py b/higgs_audio/model/common.py new file mode 100644 index 0000000000000000000000000000000000000000..e01ba869e2a5a10ab942730411e54bc0f55f8e2e --- /dev/null +++ b/higgs_audio/model/common.py @@ -0,0 +1,27 @@ +from torch import nn + +from transformers.modeling_utils import PreTrainedModel + +from .configuration_higgs_audio import HiggsAudioConfig + + +class HiggsAudioPreTrainedModel(PreTrainedModel): + config_class = HiggsAudioConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = [] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std + + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() diff --git a/higgs_audio/model/configuration_higgs_audio.py b/higgs_audio/model/configuration_higgs_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..1783d029d77a9599b0a7e75a2f4dbaf192431da1 --- /dev/null +++ b/higgs_audio/model/configuration_higgs_audio.py @@ -0,0 +1,235 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING + + +class HiggsAudioEncoderConfig(PretrainedConfig): + """Configuration of 
the Audio encoder in Higgs-Audio.""" + + model_type = "higgs_audio_encoder" + + def __init__( + self, + num_mel_bins=128, + encoder_layers=32, + encoder_attention_heads=20, + encoder_ffn_dim=5120, + encoder_layerdrop=0.0, + d_model=1280, + dropout=0.0, + attention_dropout=0.0, + activation_function="gelu", + activation_dropout=0.0, + scale_embedding=False, + init_std=0.02, + max_source_positions=1500, + pad_token_id=128001, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_mel_bins = num_mel_bins + self.d_model = d_model + self.encoder_layers = encoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation_function = activation_function + self.activation_dropout = activation_dropout + self.encoder_layerdrop = encoder_layerdrop + self.num_hidden_layers = encoder_layers + self.init_std = init_std + self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True + self.max_source_positions = max_source_positions + self.pad_token_id = pad_token_id + + +class HiggsAudioConfig(PretrainedConfig): + r""" + This is the configuration class for the HiggsAudioModel. + + Args: + text_config (`Union[AutoConfig, dict]`): + The config object or dictionary of the text backbone. + audio_encoder_config (`Union[AutoConfig, dict]`): + The config object or dictionary of the whisper encoder. + The audio encoder will be bidirectional and will be only available for audio understanding. + audio_tokenizer_config + The config object or dictionary of the audio tokenizer. + audio_adapter_type + The type of audio adapter to use. We support two types of adapter: + - stack: + We stack additional Transformer layers after the main LLM backbone for audio generation. + - dual_ffn: + For selected part of the LLM backbone, we replace the text FFN with a dual FFN architecture + that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens. + - dual_ffn_fast_forward: + We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers, + the audio hidden states will be directly fast-forward to the next layer. + This reduces the computational cost for audio generation. + audio_embed_avg (`bool`, *optional*, defaults to False): + Whether to average the audio embeddings before sending them to the text attention layer. + audio_ffn_hidden_size + The hidden size of the audio feedforward network in dual-path FFN + audio_ffn_intermediate_size + The intermediate size of the audio feedforward network in dual-path FFN + audio_dual_ffn_layers + The layers in the LLM backbone to plug-in the dual FFN layer (mixture of audio FFN and text FFN). + audio_decoder_proj_num_attention (`int`, *optional*, defaults to 0): + The number of attention heads in the audio decoder projection layer. + use_delay_pattern (`bool`, *optional*, defaults to False): + Whether to use delay pattern in the audio decoder. + skip_audio_tower (`bool`, *optional*, defaults to False): + Whether to skip the audio tower in the audio encoder. + use_audio_out_embed_projector (`bool`, *optional*, defaults to False): + Whether to use an embedding projector to map audio out embeddings. + use_audio_out_self_attention (`bool`, *optional*, defaults to False): + Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer. 
+ audio_num_codebooks (`int`, *optional*, defaults to 12): + The number of codebooks in RVQGAN. + audio_codebook_size (`int`, *optional*, defaults to 1024): + The size of each codebook in RVQGAN. + audio_stream_bos_id + The id of the bos in the audio stream + audio_stream_eos_id + The id of the eos in the audio stream + audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"): + The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011, + which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer. + audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"): + The special `<|audio_eos|>` token. We use 128012 as the default value, + which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer. + audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"): + The special `<|audio_out_bos|>` token. We use 128013 as the default value, + which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer. + audio_token (`str`, *optional*, defaults to "<|AUDIO|>"): + The special `<|AUDIO|>` token. We use 128015 as the default value, + which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer. + This token indicates that the location should be filled in with whisper features. + audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"): + The special `<|AUDIO_OUT|>` token. We use 128016 as the default value, + which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer. + This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer. + """ + + model_type = "higgs_audio" + is_composition = True + + def __init__( + self, + text_config=None, + audio_encoder_config=None, + audio_tokenizer_config=None, + audio_adapter_type="stack", + audio_embed_avg=False, + audio_ffn_hidden_size=4096, + audio_ffn_intermediate_size=14336, + audio_dual_ffn_layers=None, + audio_decoder_proj_num_layers=0, + encode_whisper_embed=True, + encode_audio_in_tokens=False, + use_delay_pattern=False, + skip_audio_tower=False, + use_audio_out_embed_projector=False, + use_audio_out_self_attention=False, + use_rq_transformer=False, + rq_transformer_hidden_size=None, + rq_transformer_intermediate_size=None, + rq_transformer_num_attention_heads=None, + rq_transformer_num_key_value_heads=None, + rq_transformer_num_hidden_layers=3, + audio_num_codebooks=12, + audio_codebook_size=1024, + audio_stream_bos_id=1024, + audio_stream_eos_id=1025, + audio_bos_token="<|audio_bos|>", + audio_eos_token="<|audio_eos|>", + audio_out_bos_token="<|audio_out_bos|>", + audio_in_token="<|AUDIO|>", + audio_out_token="<|AUDIO_OUT|>", + audio_in_token_idx=128015, + audio_out_token_idx=128016, + pad_token_id=128001, + audio_out_bos_token_id=128013, + audio_eos_token_id=128012, + **kwargs, + ): + if isinstance(audio_encoder_config, dict): + audio_encoder_config["model_type"] = ( + audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder" + ) + audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config) + elif audio_encoder_config is None: + audio_encoder_config = HiggsAudioEncoderConfig() + + if isinstance(text_config, dict): + text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama" + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = 
CONFIG_MAPPING["llama"]() + + assert audio_adapter_type in [ + "stack", + "dual_ffn", + "dual_ffn_fast_forward", + ], f"Invalid audio adapter type: {audio_adapter_type}" + if audio_adapter_type.startswith("dual_ffn"): + assert audio_dual_ffn_layers is not None, ( + "audio_dual_ffn_layers must be specified when using dual_ffn adapter." + ) + self.text_config = text_config + self.audio_encoder_config = audio_encoder_config + self.audio_tokenizer_config = audio_tokenizer_config + self.audio_adapter_type = audio_adapter_type + self.audio_embed_avg = audio_embed_avg + self.audio_ffn_hidden_size = audio_ffn_hidden_size + self.audio_ffn_intermediate_size = audio_ffn_intermediate_size + self.audio_dual_ffn_layers = audio_dual_ffn_layers + self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers + self.encode_whisper_embed = encode_whisper_embed + self.encode_audio_in_tokens = encode_audio_in_tokens + self.use_delay_pattern = use_delay_pattern + self.skip_audio_tower = skip_audio_tower + self.use_audio_out_embed_projector = use_audio_out_embed_projector + self.use_audio_out_self_attention = use_audio_out_self_attention + + self.use_rq_transformer = use_rq_transformer + + if self.use_rq_transformer: + assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!" + self.rq_transformer_hidden_size = rq_transformer_hidden_size + self.rq_transformer_intermediate_size = rq_transformer_intermediate_size + self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads + self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads + self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers + + if use_rq_transformer: + # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified. 
+ if self.rq_transformer_hidden_size is None: + self.rq_transformer_hidden_size = text_config.hidden_size + assert self.rq_transformer_hidden_size % 128 == 0 + if self.rq_transformer_intermediate_size is None: + self.rq_transformer_intermediate_size = text_config.intermediate_size + if self.rq_transformer_num_attention_heads is None: + self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128 + if self.rq_transformer_num_key_value_heads is None: + self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4 + assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0 + assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0 + + self.audio_num_codebooks = audio_num_codebooks + self.audio_codebook_size = audio_codebook_size + self.audio_bos_token = audio_bos_token + self.audio_eos_token = audio_eos_token + self.audio_out_bos_token = audio_out_bos_token + self.audio_in_token = audio_in_token + self.audio_out_token = audio_out_token + self.audio_in_token_idx = audio_in_token_idx + self.audio_out_token_idx = audio_out_token_idx + self.audio_stream_bos_id = audio_stream_bos_id + self.audio_stream_eos_id = audio_stream_eos_id + self.audio_out_bos_token_id = audio_out_bos_token_id + self.audio_eos_token_id = audio_eos_token_id + + super().__init__(**kwargs) + self.pad_token_id = pad_token_id diff --git a/higgs_audio/model/cuda_graph_runner.py b/higgs_audio/model/cuda_graph_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..a99507cb6c3a17c5414c46e09398b942af1f4004 --- /dev/null +++ b/higgs_audio/model/cuda_graph_runner.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +from typing import Optional, List, Dict, Tuple, Union +import gc + +from transformers.cache_utils import Cache + + +_NUM_WARMUP_ITERS = 2 + + +class CUDAGraphRunner(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + self._graph: Optional[torch.cuda.CUDAGraph] = None + + @property + def graph(self): + assert self._graph is not None + return self._graph + + def capture( + self, + hidden_states: torch.Tensor, + causal_mask: torch.Tensor, + position_ids: torch.Tensor, + audio_discrete_codes_mask: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Union[Cache, List[torch.FloatTensor]], + use_cache: bool, + audio_attention_mask: torch.Tensor, + fast_forward_attention_mask: torch.Tensor, + output_attentions: bool, + output_hidden_states: bool, + is_decoding_audio_token: Optional[bool] = None, + is_using_cuda_graph: Optional[bool] = False, + stream: torch.cuda.Stream = None, + memory_pool: Optional[Tuple[int, int]] = None, + ): + assert self._graph is None + # Run warmup iterations + for _ in range(_NUM_WARMUP_ITERS): + self.model( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=use_cache, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + is_decoding_audio_token=is_decoding_audio_token, + is_using_cuda_graph=is_using_cuda_graph, + ) + + torch.cuda.synchronize() + + # Capture the graph + self._graph = torch.cuda.CUDAGraph() + with 
torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + out_hidden_states, all_hidden_states, all_self_attns = self.model( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + cache_position=cache_position, + past_key_values=past_key_values, + use_cache=use_cache, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + is_decoding_audio_token=is_decoding_audio_token, + is_using_cuda_graph=is_using_cuda_graph, + ) + # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0]) + # del outputs + gc.collect() + torch.cuda.synchronize() + + # Save input and output buffers + self.input_buffers = { + "hidden_states": hidden_states, + "causal_mask": causal_mask, + "position_ids": position_ids, + "audio_discrete_codes_mask": audio_discrete_codes_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "audio_attention_mask": audio_attention_mask, + "fast_forward_attention_mask": fast_forward_attention_mask, + } + self.output_buffers = { + "hidden_states": out_hidden_states, + "all_hidden_states": all_hidden_states, + "all_self_attns": all_self_attns, + } + + def forward( + self, + hidden_states: torch.Tensor, + causal_mask: torch.Tensor, + position_ids: torch.Tensor, + audio_discrete_codes_mask: torch.Tensor, + cache_position: torch.Tensor, + audio_attention_mask: torch.Tensor, + fast_forward_attention_mask: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + # Copy input tensors to buffers + self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True) + self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True) + self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True) + self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True) + self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True) + self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True) + self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True) + + # Run the captured graph + self.graph.replay() + + return self.output_buffers["hidden_states"], None, None diff --git a/higgs_audio/model/custom_modules.py b/higgs_audio/model/custom_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..eb585c8cc8edb6be7762cbc5ccd149e77079ecb3 --- /dev/null +++ b/higgs_audio/model/custom_modules.py @@ -0,0 +1,155 @@ +import torch +import torch.nn as nn + + +class PartiallyFrozenEmbedding(nn.Module): + """Split an existing `nn.Embedding` module that splits the embedding into: + + - A frozen embedding for indices [0..freeze_until_idx]. + - A trainable embedding for indices [freeze_until_idx+1..vocab_size-1]. + + This should work with both Zero-2 and Zero-3 seamlessly + """ + + def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int): + """ + :param original_embedding: An instance of nn.Embedding (the original embedding layer). + :param freeze_until_idx: The index up to which the embedding is frozen (excluding). The freeze_until_idx is not frozen. 
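+
+ Illustrative example (hypothetical sizes, not taken from the original code): wrapping a
+ 128256 x 4096 `nn.Embedding` with `freeze_until_idx=128000` keeps rows 0..127999 frozen and
+ leaves rows 128000..128255 (e.g., newly added audio special tokens) trainable:
+
+     base = nn.Embedding(128256, 4096)
+     wrapped = PartiallyFrozenEmbedding(base, freeze_until_idx=128000)
+     out = wrapped(torch.tensor([[1, 42, 128016]]))  # -> shape (1, 3, 4096)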
+ """ + super().__init__() + self.freeze_until_idx = freeze_until_idx + self.original_vocab_size = original_embedding.num_embeddings + self.embedding_dim = original_embedding.embedding_dim + + # Split the original embedding into frozen and trainable parts + self.embedding_frozen = nn.Embedding( + freeze_until_idx, + self.embedding_dim, + dtype=original_embedding.weight.dtype, + device=original_embedding.weight.device, + ) + self.embedding_trainable = nn.Embedding( + self.original_vocab_size - freeze_until_idx, + self.embedding_dim, + dtype=original_embedding.weight.dtype, + device=original_embedding.weight.device, + ) + + # Copy weights from the original embedding into the frozen and trainable parts + with torch.no_grad(): + self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx]) + self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:]) + + # Freeze the frozen embedding + self.embedding_frozen.weight.requires_grad = False + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the split embedding wrapper. + :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1]. + """ + # Masks to separate frozen and trainable indices + # (bsz, seq_len) + mask_frozen = input_ids < self.freeze_until_idx + mask_trainable = ~mask_frozen + + # Output tensor for embedding results + batch_size, seq_len = input_ids.shape + embeddings = torch.zeros( + batch_size, + seq_len, + self.embedding_dim, + device=input_ids.device, + dtype=self.embedding_frozen.weight.dtype, + ) + + # Handle frozen embedding + if mask_frozen.any(): + frozen_ids = input_ids[mask_frozen] + frozen_emb = self.embedding_frozen(frozen_ids) + embeddings[mask_frozen] = frozen_emb + + # Handle trainable embedding + if mask_trainable.any(): + # Adjust trainable IDs to the local index space of the trainable embedding + trainable_ids = input_ids[mask_trainable] - (self.freeze_until_idx) + trainable_emb = self.embedding_trainable(trainable_ids) + embeddings[mask_trainable] = trainable_emb + + return embeddings + + def to_unsplit(self) -> nn.Embedding: + unsplit_embedding = nn.Embedding( + self.original_vocab_size, + self.embedding_dim, + dtype=self.embedding_frozen.weight.dtype, + device=self.embedding_frozen.weight.device, + ) + + with torch.no_grad(): + unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight) + unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight) + + return unsplit_embedding + + +class PartiallyFrozenLinear(nn.Module): + """A wrapper around nn.Linear to partially freeze part of the weight matrix.""" + + def __init__(self, original_linear: nn.Linear, freeze_until_idx: int): + """ + :param original_linear: The original nn.Linear layer. + :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen. 
+ """ + super().__init__() + assert original_linear.bias is None, "Currently only support linear module without bias" + + self.freeze_until_idx = freeze_until_idx + self.input_dim = original_linear.in_features + self.output_dim = original_linear.out_features + + # Create frozen and trainable linear layers + self.linear_frozen = nn.Linear( + self.input_dim, + freeze_until_idx, + bias=False, + dtype=original_linear.weight.dtype, + device=original_linear.weight.device, + ) + self.linear_trainable = nn.Linear( + self.input_dim, + self.output_dim - freeze_until_idx, + bias=False, + dtype=original_linear.weight.dtype, + device=original_linear.weight.device, + ) + + # Copy weights from the original linear layer + with torch.no_grad(): + self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx]) + self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:]) + + # Freeze the frozen linear layer + self.linear_frozen.weight.requires_grad = False + + def forward(self, input_tensor): + # input_tensor: (bsz, seq_len, hidden_state_dim) + frozen_output = self.linear_frozen(input_tensor) + trainable_output = self.linear_trainable(input_tensor) + return torch.cat((frozen_output, trainable_output), dim=-1) + + def to_unsplit(self) -> nn.Linear: + unsplit_linear = nn.Linear( + self.input_dim, + self.output_dim, + bias=False, + dtype=self.linear_frozen.weight.dtype, + device=self.linear_frozen.weight.device, + ) + + # Copy weights from the frozen and trainable layers into the unsplit linear layer + with torch.no_grad(): + unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight) + unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight) + + return unsplit_linear diff --git a/higgs_audio/model/modeling_higgs_audio.py b/higgs_audio/model/modeling_higgs_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..44b15df580cb3b28a61c142eb4d4135a27f8f55c --- /dev/null +++ b/higgs_audio/model/modeling_higgs_audio.py @@ -0,0 +1,2388 @@ +"""Higgs-Audio is an end-to-end multimodal model with the capability to understand and generate text / audio.""" + +import torch +import torch.nn as nn +import math +import glob +import functools +import os +from collections import defaultdict, OrderedDict +from dataclasses import dataclass +from enum import Enum +from safetensors.torch import load_file +from typing import Optional, Tuple, Union, List, Dict, Any + +from transformers import AutoTokenizer +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LLAMA_ATTENTION_CLASSES, + LlamaMLP, + LlamaRMSNorm, +) +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.generation import ( + GenerationMixin, + GenerationConfig, + LogitsProcessorList, + StoppingCriteriaList, +) +from transformers.generation.utils import GenerateNonBeamOutput +from transformers.utils import logging, ModelOutput + +from .common import HiggsAudioPreTrainedModel +from .utils import ( + merge_input_ids_with_audio_features, + count_parameters, +) +from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig +from .custom_modules import PartiallyFrozenLinear, PartiallyFrozenEmbedding +from .cuda_graph_runner import CUDAGraphRunner +from 
.audio_head import HiggsAudioDecoderProjector + +logger = logging.get_logger(__name__) + + +class GenerationMode(Enum): + """Enum for different generation modes in HiggsAudio model.""" + + TEXT = 0 # Text generation mode + AUDIO_INIT = 1 # Audio generation mode initialization + AUDIO_IN_PROGRESS = 2 # Audio generation mode in progress + + +def _whisper_encoder_zero_shape_forward(whisper_encoder, *args, **kwargs): + """The whisper encoder does not support zero-shape tensor by default due to the following implementations + + key_states = self._shape(self.k_proj(current_states), -1, bsz) + + If `bsz` is 0, the "-1" dimension will be ambiguous and triggers error in the shape inference pass. + + See also: https://github.com/huggingface/transformers/blob/30335093276212ce74938bdfd85bfd5df31a668a/src/transformers/models/whisper/modeling_whisper.py#L306-L307 + + This function monkey-patches all `_shape` functions in the whisper encoder's self-attention layers to ensure function supports zero-shape tensor. + + #FIXME!!!! This is a temporary workaround and should be removed once the upstream issue is resolved. + + """ + + global _higgs_flash_attention_forward + + def _patched_shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int): + if seq_len == -1: + return tensor.view(bsz, tensor.shape[1], num_heads, head_dim).transpose(1, 2).contiguous() + else: + return tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous() + + def _patched_scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + ) -> torch.Tensor: + # IMPORTANT! Implementation here is wrong and is only for the purpose of obtaining the correct attn_weight shape + if enable_gqa: + key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) + value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) + return attn_weight @ value + + # Apply monkey-patch + if whisper_encoder.config._attn_implementation != "flash_attention_2": + old_shape_functions = [] + for layer in whisper_encoder.layers: + old_shape_functions.append(getattr(layer.self_attn, "_shape")) + layer.self_attn._shape = functools.partial( + _patched_shape, + num_heads=layer.self_attn.num_heads, + head_dim=layer.self_attn.head_dim, + ) + + original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention + torch.nn.functional.scaled_dot_product_attention = _patched_scaled_dot_product_attention + + out = whisper_encoder(*args, **kwargs) + torch.nn.functional.scaled_dot_product_attention = original_scaled_dot_product_attention + + # Restore the original shape functions + if whisper_encoder.config._attn_implementation != "flash_attention_2": + for layer, old_shape_function in zip(whisper_encoder.layers, old_shape_functions): + layer.self_attn._shape = old_shape_function + + return out + + +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
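+
+ For example (illustrative numbers): when decoding a single token against a static cache of
+ length 8 with batch_size=2, `sequence_length=1`, `target_length=8` and
+ `cache_position=torch.tensor([5])`, the returned mask has shape `(2, 1, 1, 8)`; key positions
+ 0..5 are attendable (0.0) and the unfilled cache positions 6..7 are set to `min_dtype`.
+ Padding positions from a 2D `attention_mask` are additionally masked with `min_dtype`.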
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class HiggsAudioFeatureProjector(nn.Module): + """Projector that maps audio features extracted by Whisper to hidden state of the text model.""" + + def __init__(self, config: HiggsAudioConfig): + super().__init__() + self.linear = nn.Linear( + config.audio_encoder_config.d_model, + config.text_config.hidden_size, + bias=True, + ) + + def forward(self, audio_features): + hidden_states = self.linear(audio_features) + return hidden_states + + +# Revised on top of transformers.models.qwen2_audio.modeling_qwen2_audio with Qwen2AudioEncoder --> HiggsAudioEncoder +# The code was originally borrowed from WhisperEncoder +class HiggsAudioEncoder(HiggsAudioPreTrainedModel): + """ + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`WhisperEncoderLayer`]. 
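+ On top of the Whisper-style convolutional front-end (the second convolution has stride 2),
+ the encoder average-pools the hidden states with stride 2 before the final layer norm, so
+ each output frame summarizes four input mel frames.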
+ + Args: + config: HiggsAudioEncoderConfig + """ + + # Ignore copy + config_class = HiggsAudioEncoderConfig + main_input_name = "input_features" + _no_split_modules = ["WhisperEncoderLayer"] + + def __init__(self, config: HiggsAudioEncoderConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + + embed_dim = config.d_model + self.num_mel_bins = config.num_mel_bins + self.padding_idx = config.pad_token_id + self.max_source_positions = config.max_source_positions + self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + + self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim) + self.embed_positions.requires_grad_(False) + + # Flash Attention 2 does not support zero shape tensor, so we have to use sdpa implementation for the Whisper component. + self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layer_norm = nn.LayerNorm(config.d_model) + # Ignore copy + self.avg_pooler = nn.AvgPool1d(2, stride=2) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def get_input_embeddings(self) -> nn.Module: + return self.conv1 + + def set_input_embeddings(self, value: nn.Module): + self.conv1 = value + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + check_seq_length=True, + ): + r""" + Args: + input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be + obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a + `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding + and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`torch.Tensor`)`, *optional*): + HiggsAudio does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
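+
+ Illustrative shape flow with the default config (num_mel_bins=128, d_model=1280,
+ max_source_positions=1500):
+
+     input_features: (batch_size, 128, 3000)
+     after conv1 (stride 1) and conv2 (stride 2): (batch_size, 1280, 1500)
+     after the encoder layers, 2x average pooling and the layer norm:
+     last_hidden_state: (batch_size, 750, 1280)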
+ """ + + expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] + if check_seq_length and (input_features.shape[-1] != expected_seq_length): + raise ValueError( + f"HiggsAudio expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Ignore copy + input_features = input_features.to(dtype=self.conv1.weight.dtype, device=self.conv1.weight.device) + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." + ) + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + # Ignore copy + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Ignore copy + hidden_states = hidden_states.permute(0, 2, 1) + # If the sequence length after average pooling is not divisible by the sequence parallel size, we would duplicate it across the sequence parallel ranks. + # In this case, gradients need to be scaled up because the subsequent scaling up in the function _apply_audio_tower is skipped. 
+ hidden_states = self.avg_pooler(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1) + + hidden_states = self.layer_norm(hidden_states) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + # Ignore copy + def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): + """ + Computes the output length of the convolutional layers and the output length of the audio encoder + """ + # TODO(sxjscience) Double confirm the formula + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class HiggsAudioDualFFNDecoderLayer(nn.Module): + """We implement a dual-path FFN decoder layer where the audio tokens and text tokens go through separate FFN layers. + + The audio and text tokens share the text-attention layer, but will be encoded with separate feedforward layers. + In addition, the audio tokens can be configured to go through separate attention layer. + + Following is an illustration: + + t t t a a a t t t + | + | (audio self-attention layer) + v + t t t h'_a h'_a h'_a t t t + | + | (shared attention layer) + v + h_t h_t h_t h_a h_a h_a h_t h_t h_t + | + | (separate text/audio hidden states) + v + [h_t h_t h_t h_t h_t h_t], [h_a, h_a, h_a] + | | + | (separate FFNs) | + v v + [o_t o_t o_t o_t o_t o_t], [o_a, o_a, o_a] + | + | (reorder) + v + o_t o_t o_t o_a o_a o_a o_t o_t o_t + + This has a few advantages: + 1) We are able to use a smaller FFN, or even bypass the FFN for audio tokens. This accelerates the inference speed. + 2) The Audio-FFN introduces more trainable parameters to the model. + This should have the same effect as the mixture-of-expert layer and we may expect better performance due to the scaling law. + 3) We can replace the original FFN in LLMs with the dual-path FFN without changing the model architecture. + + + """ + + def __init__( + self, + config: HiggsAudioConfig, + layer_idx: int, + fast_forward: bool = False, + use_audio_attention: bool = False, + ): + super().__init__() + text_config = config.text_config + self.hidden_size = text_config.hidden_size + self.layer_idx = layer_idx + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=text_config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(text_config) + + if not fast_forward: + if use_audio_attention: + self.audio_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation]( + config=text_config, layer_idx=layer_idx + 1 + ) + self.audio_post_audio_attn_layer_norm = LlamaRMSNorm( + text_config.hidden_size, eps=text_config.rms_norm_eps + ) + + self.audio_mlp = LlamaMLP(text_config) + self.audio_input_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + self.audio_post_attention_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + + self.use_audio_attention = use_audio_attention + self.fast_forward = fast_forward + if self.fast_forward: + assert not self.use_audio_attention, ( + "We cannot use audio_attention if the layer is marked as fast-forward." 
+ ) + self.input_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(text_config.hidden_size, eps=text_config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + audio_attention_mask: Optional[torch.Tensor] = None, + fast_forward_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + audio_out_mask: Optional[torch.BoolTensor] = None, + is_decoding_audio_token: Optional[bool] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + is_using_cuda_graph: Optional[bool] = False, + **kwargs, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + position_ids + IDs of positions in the input sequence + audio_out_mask + Mask for identifying the audio tokens. Size (batch_size, sequence_length) + 1 --> location contains audio_out + 0 --> location does not contain audio_out + + When use_cache is True and not in torch compile mode, the audio_out_mask contains audio_out masks for + all tokens up to the current token. That means, it has size (batch_size, sequence_length) while + hidden_states will have size (batch_size, 1). In the torch compile mode, the audio_out_mask will have + size (batch_size, 1). + is_decoding_audio_token + Used in the torch compile mode to determine if the current token is an audio token or not. + past_key_value (`Cache`, *optional*): cached past key and value projection states. We fetch the corresponding cached key/value via the layer_idx. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + is_using_cuda_graph (`bool`, *optional*): + Indicates whether the model is running by cuda graph. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + target_length = hidden_states.shape[1] + use_static_cache = isinstance(past_key_value, StaticCache) + decode_stage = hidden_states.shape[1] == 1 + if is_using_cuda_graph: + assert decode_stage and use_static_cache, ( + "The CUDA graph mode should only be used in the decoding stage with static cache." + ) + + # If we are decoding an audio token and the layer is marked as fast-forward, + # we can skip it. 
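+ # "Skip" here means the residual stream is returned unchanged; this matches the
+ # `dual_ffn_fast_forward` description in `HiggsAudioConfig`, where audio hidden states are
+ # fast-forwarded through layers that carry no audio FFN.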
+ if is_decoding_audio_token and self.fast_forward: + return (hidden_states,) + + has_audio_out = audio_out_mask is not None and audio_out_mask.shape[0] > 0 + + audio_out_mask_sq = audio_out_mask + + if self.fast_forward and has_audio_out: + original_hidden_states = hidden_states.clone() + min_dtype = torch.finfo(hidden_states.dtype).min + if attention_mask is None: + attention_mask = ~audio_out_mask + + if self.self_attn.config._attn_implementation != "flash_attention_2": + sequence_length = audio_out_mask.shape[1] + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask=attention_mask, + sequence_length=sequence_length, + target_length=sequence_length, + dtype=hidden_states.dtype, + min_dtype=min_dtype, + device=hidden_states.device, + cache_position=cache_position, + batch_size=hidden_states.shape[0], + ) + if use_cache: + attention_mask = attention_mask[:, :, -target_length:, :] + elif len(attention_mask.shape) == 2: + # Attention mask has shape (batch_size, sequence_length) + # We should be using flash attention 2 + attention_mask = attention_mask * ~audio_out_mask + elif len(attention_mask.shape) == 4: + # When using static cache, the attention mask was already preprocessed in the previous layer + if use_static_cache: + attention_mask = fast_forward_attention_mask + else: + if use_cache: + # Attention mask has shape (batch_size, 1, query_length, key_length) + # In addition, the attention mask should be inverted, that means "1" (attend_to) --> "0", and "0" --> minimal dtype value. + attention_mask = attention_mask.masked_fill( + audio_out_mask[:, -target_length:].reshape(audio_out_mask.shape[0], 1, target_length, 1) + | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]), + min_dtype, + ) + else: + attention_mask = attention_mask.masked_fill( + audio_out_mask.reshape(audio_out_mask.shape[0], 1, audio_out_mask.shape[1], 1) + | audio_out_mask.reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]), + min_dtype, + ) + else: + raise NotImplementedError(f"Unsupported attention_mask format, attention_mask={attention_mask}") + + if ( + self.self_attn.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype) + + if has_audio_out and not self.fast_forward: + # Apply separate layernorm layers for audio tokens and text tokens + if use_cache: + hidden_states = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), + self.audio_input_layernorm(hidden_states), + self.input_layernorm(hidden_states), + ) + else: + hidden_states = torch.where( + audio_out_mask_sq.unsqueeze(-1), + self.audio_input_layernorm(hidden_states), + self.input_layernorm(hidden_states), + ) + else: + hidden_states = self.input_layernorm(hidden_states) + + # Audio Attention + if self.use_audio_attention and has_audio_out: + if use_static_cache: + assert audio_attention_mask is not None, ( + "audio_attention_mask should not be None when using static cache." 
+ ) + + if audio_attention_mask is None: + no_audio_out_mask = (~audio_out_mask)[:, -target_length:].reshape( + audio_out_mask.shape[0], 1, target_length, 1 + ) | (~audio_out_mask).reshape(audio_out_mask.shape[0], 1, 1, audio_out_mask.shape[1]) + min_dtype = torch.finfo(hidden_states.dtype).min + + if attention_mask is None: + audio_attention_mask = audio_out_mask + + if self.audio_attn.config._attn_implementation != "flash_attention_2": + sequence_length = audio_out_mask.shape[1] + audio_attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask=audio_attention_mask, + sequence_length=sequence_length, + target_length=sequence_length, + dtype=hidden_states.dtype, + min_dtype=min_dtype, + device=hidden_states.device, + cache_position=cache_position, + batch_size=hidden_states.shape[0], + ) + if use_cache: + audio_attention_mask = audio_attention_mask[:, :, -target_length:, :] + audio_attention_mask = audio_attention_mask.masked_fill(no_audio_out_mask, min_dtype) + elif len(attention_mask.shape) == 2: + # Attention mask has shape (batch_size, sequence_length) + audio_attention_mask = attention_mask * audio_out_mask + elif len(attention_mask.shape) == 4: + # Attention mask has shape (batch_size, 1, query_length, key_length) + # In addition, the attention mask should be inverted. This means "1" (attend_to) --> "0", and "0" --> minimal dtype value. + audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype) + else: + raise NotImplementedError(f"Unsupported attention_mask format, attention_mask={attention_mask}") + + if ( + self.audio_attn.config._attn_implementation == "sdpa" + and audio_attention_mask is not None + and audio_attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + audio_attention_mask = AttentionMaskConverter._unmask_unattended(audio_attention_mask, min_dtype) + + audio_attention_mask = audio_attention_mask.contiguous() + + audio_hidden_states, audio_self_attn_weights, audio_present_key_value = self.audio_attn( + hidden_states=hidden_states, + attention_mask=audio_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + audio_hidden_states = residual + audio_hidden_states + if use_cache: + residual = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), + audio_hidden_states, + residual, + ) + else: + residual = torch.where(audio_out_mask_sq.unsqueeze(-1), audio_hidden_states, residual) + audio_hidden_states = self.audio_post_audio_attn_layer_norm(audio_hidden_states) + if use_cache: + hidden_states = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), + audio_hidden_states, + hidden_states, + ) + else: + hidden_states = torch.where(audio_out_mask_sq.unsqueeze(-1), audio_hidden_states, hidden_states) + + # Text Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Apply Dual-path FFN + residual = hidden_states + + if has_audio_out and not self.fast_forward: + if use_cache: + real_audio_out_mask = audio_out_mask_sq[:, -target_length:] + else: + real_audio_out_mask = audio_out_mask_sq + + # Make whole graph in decode stage + if decode_stage and is_using_cuda_graph: + assert is_decoding_audio_token is not None, ( + "is_decoding_audio_token should be present in the decoding stage." 
+ ) + if is_decoding_audio_token: + hidden_states = self.audio_post_attention_layernorm(hidden_states) + hidden_states = self.audio_mlp(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + residual = residual + hidden_states + else: + text_hidden_states = self.post_attention_layernorm(hidden_states[~real_audio_out_mask]) + audio_hidden_states = self.audio_post_attention_layernorm(hidden_states[real_audio_out_mask]) + + text_hidden_states = self.mlp(text_hidden_states) + residual[~real_audio_out_mask] += text_hidden_states + + audio_hidden_states = self.audio_mlp(audio_hidden_states) + residual[real_audio_out_mask] += audio_hidden_states + + hidden_states = residual + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.fast_forward and has_audio_out: + if use_cache: + hidden_states = torch.where( + audio_out_mask_sq[:, -target_length:].unsqueeze(-1), + original_hidden_states, + hidden_states, + ) + else: + hidden_states = torch.where( + audio_out_mask_sq.unsqueeze(-1), + original_hidden_states, + hidden_states, + ) + + outputs = (hidden_states,) + + if output_attentions: + if self.use_audio_attention: + # The returned attn weights have shape (batch_size, num_heads + num_audio_attn_heads, seq_length, seq_length) + outputs += (torch.concat([self_attn_weights, audio_self_attn_weights], dim=1),) + else: + # The returned attn weights have shape (batch_size, num_heads, seq_length, seq_length) + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@dataclass +class HiggsAudioModelOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + llm_loss: Optional[torch.FloatTensor] = None + audio_loss: Optional[torch.FloatTensor] = None + codebook_losses: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + expanded_input_ids: Optional[torch.LongTensor] = None + expanded_labels: Optional[torch.LongTensor] = None + audio_in_mask: Optional[torch.BoolTensor] = None + audio_in_discrete_codes_mask: Optional[torch.BoolTensor] = None + audio_out_mask: Optional[torch.BoolTensor] = None + attention_mask: Optional[torch.BoolTensor] = None + audio_logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[Cache] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + audio_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class HiggsAudioGenerationOutput(ModelOutput): + """ + Outputs of HiggsAudio generation models, when using non-beam methods. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + audio_sequences (`tuple(torch.LongTensor)` *optional*): + The generated discrete audio codes. These codes can be used to fill-in related locations of <|AUDIO_OUT|> at input sequences. + scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token). 
+ If the generated token is a text token, the tensor will have shape `(batch_size, config.vocab_size)`. + If the generated token is an audio token, the tensor will have shape `(config.audio_num_codebooks, self.audio_codebook_size)` + logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`): + Unprocessed prediction scores of the language modeling head or the audio head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for + each generated token). + If the generated token is a text token, the tensor will have shape `(batch_size, config.vocab_size)`. + If the generated token is an audio token, the tensor will have shape `(config.audio_num_codebooks, self.audio_codebook_size)` + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`. + past_key_values (`tuple(tuple(torch.FloatTensor)))`, *optional*, returned when `use_cache=True`): + Returns the model cache, used to speed up decoding. Different models have a different cache format, check + the model's documentation. Usually, a [`~cache_utils.Cache`] instance. + """ + + sequences: torch.LongTensor = None + audio_sequences: Optional[List[torch.LongTensor]] = None + scores: Optional[Tuple[torch.FloatTensor]] = None + logits: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None + + +class HiggsAudioModel(HiggsAudioPreTrainedModel, GenerationMixin): + """Higgs-Audio is an end-to-end multimodal model with the capability to understand and generate text / audio. + + Consider the following example for mixed text/audio understanding / generation: + + - input_tokens: <|audio_bos|>[AUDIO]<|audio_eos|><|audio_bos|>[AUDIO]<|audio_eos|> + - input_tokens: <|audio_bos|>[AUDIO]<|audio_eos|><|audio_out_bos|>[AUDIO_OUT]<|audio_eos|> + + We will fill [AUDIO] with the audio features extracted by Whisper and fill [AUDIO_OUT] with the audio tokens. 
+ + Consider the following example for mixed text/audio generation: + + text: <|audio_out_bos|> MASK MASK MASK MASK MASK <|audio_eos|> [text_token1] + audio: MASK <|audio_stream_bos|> [audio_token1] [audio_token2] [audio_token3] <|audio_stream_eos|> MASK MASK + token_type: 0 1 1 1 1 1 0 0 + + """ + + _supports_cache_class = True + _supports_static_cache = True + + def __init__(self, config: HiggsAudioConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.audio_in_token_idx = config.audio_in_token_idx + self.audio_out_token_idx = config.audio_out_token_idx + self.audio_out_bos_token_id = config.audio_out_bos_token_id if "audio_out_bos_token_id" in config else None + self.audio_eos_token_id = config.audio_eos_token_id if "audio_eos_token_id" in config else None + self.vocab_size = config.text_config.vocab_size + self.audio_num_codebooks = config.audio_num_codebooks + self.use_delay_pattern = config.use_delay_pattern + self.use_audio_out_embed_projector = config.use_audio_out_embed_projector + self.use_audio_out_self_attention = config.use_audio_out_self_attention + + self.embed_tokens = nn.Embedding(self.vocab_size, config.text_config.hidden_size, self.padding_idx) + + if config.audio_adapter_type == "dual_ffn": + layer_idx = 0 + layers = [] + for j in range(config.text_config.num_hidden_layers): + if j in config.audio_dual_ffn_layers: + layers.append( + HiggsAudioDualFFNDecoderLayer( + config, + layer_idx, + use_audio_attention=self.use_audio_out_self_attention, + ) + ) + layer_idx += 2 if self.use_audio_out_self_attention else 1 + else: + layers.append(LlamaDecoderLayer(config.text_config, layer_idx)) + layer_idx += 1 + self.layers = nn.ModuleList(layers) + elif config.audio_adapter_type == "dual_ffn_fast_forward": + layer_idx = 0 + layers = [] + for j in range(config.text_config.num_hidden_layers): + if j in config.audio_dual_ffn_layers: + layers.append( + HiggsAudioDualFFNDecoderLayer( + config, + layer_idx, + fast_forward=False, + use_audio_attention=self.use_audio_out_self_attention, + ) + ) + layer_idx += 2 if self.use_audio_out_self_attention else 1 + else: + layers.append( + HiggsAudioDualFFNDecoderLayer( + config, + layer_idx, + fast_forward=True, + use_audio_attention=False, + ) + ) + layer_idx += 1 + self.layers = nn.ModuleList(layers) + elif config.audio_adapter_type == "stack": + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config.text_config, layer_idx) + for layer_idx in range(config.text_config.num_hidden_layers) + ] + ) + layer_idx = config.text_config.num_hidden_layers + else: + raise NotImplementedError(f"Audio adapter type {config.audio_adapter_type} not implemented.") + + self.num_activation_checkpointing_layers = len(self.layers) + + self.decode_graph_runners = defaultdict(dict[bool, CUDAGraphRunner]) + self.norm = LlamaRMSNorm(config.text_config.hidden_size, eps=config.text_config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config.text_config) + + if not config.skip_audio_tower: + self.audio_tower = HiggsAudioEncoder(config.audio_encoder_config) + self.audio_encoder_proj = HiggsAudioFeatureProjector(config) + else: + self.audio_tower = None + self.audio_encoder_proj = None + self.audio_decoder_proj = HiggsAudioDecoderProjector(config, layer_idx=layer_idx) + self.audio_codebook_size = ( + config.audio_codebook_size + 2 + ) # We add 1 for the audio_stream_bos token and 1 for the audio_stream_eos token + + if config.use_audio_out_embed_projector: + self.audio_out_embed_projector = nn.Linear( + 
config.text_config.hidden_size, + config.text_config.hidden_size, + bias=False, + ) + + self.audio_codebook_embeddings = nn.Embedding( + config.audio_num_codebooks * self.audio_codebook_size, + config.text_config.hidden_size, + ) + + self.audio_codebook_weights = ( + torch.ones(config.audio_num_codebooks) / config.audio_num_codebooks + ) # default to equal weights + self.post_init() + + def set_num_activation_checkpointing_layers(self, num_layers): + self.num_activation_checkpointing_layers = num_layers + + def set_delay_pattern(self): + self.config.use_delay_pattern = True + self.use_delay_pattern = True + + def set_audio_special_tokens(self, tokenizer: AutoTokenizer): + self.audio_out_bos_token_id = tokenizer.convert_tokens_to_ids("<|audio_out_bos|>") + self.audio_eos_token_id = tokenizer.convert_tokens_to_ids("<|audio_eos|>") + + def _embed_audio_ids(self, audio_ids): + """Embed the audio ids + + Args: + audio_ids: torch.LongTensor of shape (num_codebooks, audio_in_total_length) + + Returns: + audio_embed: torch.LongTensor of shape (audio_in_total_length, hidden_size) + """ + codebook_shift = ( + torch.arange(self.config.audio_num_codebooks, device=audio_ids.device) * self.audio_codebook_size + ) + audio_embed = self.audio_codebook_embeddings(audio_ids + codebook_shift.unsqueeze(-1)) + if self.config.audio_embed_avg: + audio_embed = torch.mean(audio_embed, dim=0) + else: + audio_embed = torch.sum(audio_embed, dim=0) + if self.use_audio_out_embed_projector: + audio_embed = self.audio_out_embed_projector(audio_embed) + return audio_embed + + def _apply_audio_tower(self, audio_features, audio_feature_attention_mask): + """Apply the audio tower to the audio features""" + + if audio_features.shape[0] == 0: + if torch.is_grad_enabled(): + # FIXME!!!!!!!! + # This is a hack to ensure that the forward+backward pass of audio_tower and audio_encoder_proj get triggered. + # The monkey patch won't overwrite the backward pass of nn.Module. 
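+ # Presumably this keeps the (otherwise idle) audio tower and projector parameters in the
+ # autograd graph for batches without audio, so that distributed training (e.g. DeepSpeed
+ # ZeRO-3 / DDP gradient synchronization) does not hang waiting for their gradients.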
+ audio_outputs = _whisper_encoder_zero_shape_forward( + self.audio_tower, + audio_features, + attention_mask=None, + check_seq_length=False, + ) + selected_audio_feature = audio_outputs.last_hidden_state + audio_features_embed = self.audio_encoder_proj(selected_audio_feature) + audio_feat_out_lengths = None + return audio_features_embed, audio_feat_out_lengths + else: + return None, None + + audio_feat_lengths, audio_feat_out_lengths = self.audio_tower._get_feat_extract_output_lengths( + audio_feature_attention_mask.sum(-1) + ) + batch_size, _, max_mel_seq_len = audio_features.shape + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + # Create a sequence tensor of shape (batch_size, max_seq_len) + seq_range = ( + torch.arange( + 0, + max_seq_len, + dtype=audio_feat_lengths.dtype, + device=audio_feat_lengths.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + lengths_expand = audio_feat_lengths.unsqueeze(1).expand(batch_size, max_seq_len) + # Create mask + padding_mask = seq_range < lengths_expand + + if self.config._attn_implementation != "flash_attention_2": + audio_attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) + else: + audio_attention_mask = padding_mask + + audio_outputs = self.audio_tower(audio_features, attention_mask=audio_attention_mask) + selected_audio_feature = audio_outputs.last_hidden_state + audio_features_embed = self.audio_encoder_proj(selected_audio_feature) + + return audio_features_embed, audio_feat_out_lengths + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
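+ # The helper below returns a 4D mask of shape (batch_size, 1, sequence_length, target_length),
+ # with disallowed positions filled with min_dtype and allowed positions left at zero.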
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + def _prepare_all_static_kv_cache_masks(self, hidden_states, attention_mask, audio_out_mask, past_key_values): + target_length = hidden_states.shape[1] + cur_pos = audio_out_mask.shape[1] + min_dtype = torch.finfo(hidden_states.dtype).min + assert len(attention_mask.shape) == 4, "Only support SDPA for now" + kv_cache_len = past_key_values.get_max_cache_shape() + audio_out_mask_padded = torch.nn.functional.pad(audio_out_mask, (0, kv_cache_len - cur_pos), value=True) + fast_forward_attention_mask = attention_mask.masked_fill( + audio_out_mask_padded[:, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1]].reshape( + audio_out_mask_padded.shape[0], 1, target_length, 1 + ) + | audio_out_mask_padded.reshape(audio_out_mask_padded.shape[0], 1, 1, audio_out_mask_padded.shape[1]), + min_dtype, + ) + + no_audio_out_mask = ~audio_out_mask + no_audio_out_mask = torch.nn.functional.pad( + no_audio_out_mask, (0, kv_cache_len - audio_out_mask.shape[1]), value=False + ) + no_audio_out_mask = no_audio_out_mask[ + :, audio_out_mask.shape[1] - target_length : audio_out_mask.shape[1] + ].reshape(audio_out_mask.shape[0], 1, target_length, 1) | no_audio_out_mask.reshape( + audio_out_mask.shape[0], 1, 1, kv_cache_len + ) + audio_attention_mask = attention_mask.masked_fill(no_audio_out_mask, min_dtype) + return fast_forward_attention_mask, audio_attention_mask + + def _forward_core( + self, + hidden_states: torch.Tensor, + causal_mask: torch.Tensor, + position_ids: torch.Tensor, + audio_discrete_codes_mask: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]], + use_cache: bool, + audio_attention_mask: torch.Tensor, + fast_forward_attention_mask: torch.Tensor, + output_attentions: bool, + output_hidden_states: bool, + is_decoding_audio_token: Optional[bool] = None, + is_using_cuda_graph: Optional[bool] = False, + ): + # create position embeddings to be shared across the decoder layers + # When past_key_values is passed in, we need to offset the position ids when calculating the position embeddings. + # Therefore, cache_position is used. 
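+ # With an incremental KV cache the incoming position_ids are local to the new chunk, so
+ # cache_position[0] (the number of tokens already cached) is added back to recover the
+ # absolute positions used by the rotary embeddings.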
+ position_id_offset = cache_position[0] if use_cache else 0 + position_embeddings = self.rotary_emb(hidden_states, position_ids + position_id_offset) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + if isinstance(decoder_layer, HiggsAudioDualFFNDecoderLayer): + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + position_ids=position_ids, + audio_out_mask=audio_discrete_codes_mask, + is_decoding_audio_token=is_decoding_audio_token, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + is_using_cuda_graph=is_using_cuda_graph, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + return hidden_states, all_hidden_states, all_self_attns + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + audio_features: Optional[torch.FloatTensor] = None, + audio_feature_attention_mask: Optional[torch.BoolTensor] = None, + audio_in_ids: Optional[torch.LongTensor] = None, + audio_in_ids_start: Optional[torch.LongTensor] = None, + audio_out_ids: Optional[torch.LongTensor] = None, + audio_out_ids_start: Optional[torch.LongTensor] = None, + audio_out_ids_start_group_loc: Optional[torch.LongTensor] = None, + label_ids: Optional[torch.LongTensor] = None, + label_audio_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_audio_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + cache_audio_discrete_codes_mask: Optional[torch.LongTensor] = None, + past_key_values_buckets: Optional[OrderedDict[int, Cache]] = None, + reward: Optional[torch.FloatTensor] = None, + ): + """Forward pass for the Higgs-Audio model. + + Args: + input_ids (:obj:`torch.LongTensor`): + The input ids of the prompt. It will have shape (bsz, seq_len). + When use_cache is enabled, the input_ids will have + shape (bsz, 1) for incremental decode or None + inputs_embeds: + Input embeddings. This flag won't be used. + attention_mask (:obj:`torch.LongTensor`): + The attention mask of the prompt. It will have shape (bsz, seq_len). + audio_features (:obj:`torch.FloatTensor`): + The audio features extracted by Whisper. It will have shape (num_audio_in, feature_dim, max_mel_seq_len). + audio_feature_attention_mask (:obj:`torch.LongTensor`): + The attention mask of the audio features. It will have shape (num_audio_in, max_mel_seq_len). + audio_in_ids (:obj:`torch.LongTensor`): + The discretized audio tokens. It will have shape (num_codebooks, audio_in_total_length). 
+ audio_in_ids_start (:obj:`torch.LongTensor`): + The start indices for each audio in audio_in_ids. It will have shape (num_audio_in,) + audio_out_ids (:obj:`torch.LongTensor`): + The discretized audio tokens. It will have shape (num_codebooks, audio_out_total_length). + audio_out_ids_start (:obj:`torch.LongTensor`): + The start indices for each audio in audio_out_ids. It will have shape (num_audio_out,) + audio_out_ids_start_group_loc (:obj:`torch.LongTensor`): + The sample indices in a batch that map to each element in the audio_out_ids_start. It will have shape (num_audio_out,) + label_text_ids (:obj:`torch.LongTensor`): + The labels of the prompt. It will have shape (bsz, seq_len). + label_audio_ids (:obj:`torch.LongTensor`): + The labels of the audio tokens. It will have the same shape as audio_out_ids, i.e., (num_codebooks, audio_out_total_length) + past_key_values (:obj:`Tuple`): + Tuple of past key values. + use_cache (:obj:`bool`): + Whether to use cache. + output_attentions (:obj:`bool`): + Whether to output attentions. + output_hidden_states (:obj:`bool`): + Whether to output hidden states. + output_audio_hidden_states (:obj:`bool`): + Whether to output audio hidden states. + return_dict (:obj:`bool`): + Whether to return a dictionary. + cache_position (:obj:`torch.LongTensor`): + The position of the cache. + cache_audio_discrete_codes_mask (:obj:`torch.LongTensor`): + The cached audio discrete codes mask. It will only be used when use_cache is turned on. + past_key_values_buckets (:obj:`OrderedDict`): + The buckets of past key values. + """ + target_device = input_ids.device + + # not used + del inputs_embeds + + if audio_features is not None: + audio_features = audio_features.to(target_device) + audio_feature_attention_mask = audio_feature_attention_mask.to(target_device) + + # 1. Extract the input embeddings + inputs_embeds = self.embed_tokens(input_ids) + + # 2. Extract audio embeddings + if self.config.skip_audio_tower: + audio_features_embed = audio_features_length = None + else: + audio_features_embed, audio_features_length = self._apply_audio_tower( + audio_features, audio_feature_attention_mask + ) + + if self.config.encode_audio_in_tokens: + if audio_in_ids is not None and audio_in_ids.shape[-1] > 0: + audio_in_ids = audio_in_ids.to(target_device) + else: + audio_in_ids = torch.zeros( + (self.audio_num_codebooks, 0), + device=target_device, + dtype=torch.long, + ) + audio_in_embed = self._embed_audio_ids(audio_in_ids) + else: + audio_in_embed = None + + if audio_out_ids is not None and audio_out_ids.shape[-1] > 0: + audio_out_ids = audio_out_ids.to(target_device) + else: + audio_out_ids = torch.zeros((self.audio_num_codebooks, 0), device=target_device, dtype=torch.long) + audio_out_embed = self._embed_audio_ids(audio_out_ids) + + # 3. Merge text, audio-in embeddings, and audio-out embeddings + + # use_cache is turned on during inference time, we should set round_to to 1 to avoid extra padding in the end. 
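+ # In training mode the merged sequence is padded up to a multiple of 8 (e.g. 37 -> 40),
+ # which is typically friendlier for GPU kernels; at inference time (use_cache) no extra
+ # padding is added so the KV cache only ever stores real tokens.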
+ round_to = 1 if use_cache else 8 + left_padding = True if use_cache or input_ids.shape[0] == 1 else False + ( + inputs_embeds, + attention_mask, + labels, + position_ids, + input_ids, + audio_in_mask, + audio_in_discrete_codes_mask, + audio_out_mask, + ) = merge_input_ids_with_audio_features( + audio_features_embed, + audio_features_length, + audio_in_embed, + audio_in_ids_start, + audio_out_embed, + audio_out_ids_start, + self.audio_in_token_idx, + self.audio_out_token_idx, + inputs_embeds, + input_ids, + attention_mask, + label_ids, + pad_token_id=self.padding_idx, + round_to=round_to, + left_padding=left_padding, + ) + + # re-check if we use the correct kv cache bucket after + # the input_embeds has been merged with audio features + if past_key_values_buckets is not None and inputs_embeds.shape[1] > past_key_values.get_max_cache_shape(): + past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache( + inputs_embeds.shape[1], None, past_key_values_buckets + ) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if isinstance(past_key_values, StaticCache) and past_seen_tokens >= past_key_values.get_max_cache_shape(): + raise ValueError( + f"The current sequence length ({past_seen_tokens}) exceeds " + f"the maximum cache shape. " + f"Please consider increasing the cache size." + ) + + # Use torch compile + use_static_cache = isinstance(past_key_values, StaticCache) + + # Apply the LLM component + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + audio_discrete_codes_mask = audio_in_discrete_codes_mask | audio_out_mask + if cache_audio_discrete_codes_mask is not None and use_cache: + audio_discrete_codes_mask = torch.concat( + [cache_audio_discrete_codes_mask, audio_discrete_codes_mask], dim=1 + ) + + # Generate the audio attention mask outside the layer to avoid recompilation + if use_static_cache: + fast_forward_attention_mask, audio_attention_mask = self._prepare_all_static_kv_cache_masks( + hidden_states, + causal_mask, + audio_discrete_codes_mask, + past_key_values, + ) + # Set the audio out mask to the last token + if hidden_states.shape[1] == 1: + audio_discrete_codes_mask = audio_discrete_codes_mask[:, -1:] + audio_discrete_codes_mask = audio_discrete_codes_mask.reshape((-1, 1)).contiguous() + is_decoding_audio_token = audio_discrete_codes_mask.item() + else: + is_decoding_audio_token = False + + # Use the captured cuda graph runner for decoding + # if it exists, otherwise use the normal forward pass + if ( + past_key_values is not None + and past_key_values.get_max_cache_shape() in self.decode_graph_runners + and (input_ids.shape[-1] == 1) + ): + _forward_core = self.decode_graph_runners[past_key_values.get_max_cache_shape()][is_decoding_audio_token] + is_using_cuda_graph = True + else: + _forward_core = self._forward_core + is_using_cuda_graph = False + + hidden_states, all_hidden_states, all_self_attns = _forward_core( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + is_decoding_audio_token=is_decoding_audio_token if use_static_cache else None, + 
cache_position=cache_position, + past_key_values=past_key_values, + use_cache=use_cache, + audio_attention_mask=audio_attention_mask if use_static_cache else None, + fast_forward_attention_mask=fast_forward_attention_mask if use_static_cache else None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + is_using_cuda_graph=is_using_cuda_graph, + ) + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # Apply the audio decoder projector + ( + logits, + audio_logits, + decoder_all_self_attns, + decoder_all_hidden_states, + audio_hidden_states, + _, + ) = self.audio_decoder_proj( + hidden_states, + audio_out_mask, + label_audio_ids=label_audio_ids, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_audio_hidden_states=output_audio_hidden_states, + cache_position=cache_position, + ) + + if audio_logits is not None: + audio_logits = audio_logits.view( + audio_logits.shape[0], + self.audio_num_codebooks, + self.audio_codebook_size, + ).float() + + if output_hidden_states: + if decoder_all_hidden_states is not None and len(decoder_all_hidden_states) > 1: + all_hidden_states += decoder_all_hidden_states[1:] + + if output_attentions: + all_self_attns += decoder_all_self_attns + + next_cache = past_key_values if use_cache else None + + ret = HiggsAudioModelOutputWithPast( + logits=logits, + audio_logits=audio_logits, + expanded_input_ids=input_ids, + expanded_labels=labels, + audio_in_mask=audio_in_mask, + audio_in_discrete_codes_mask=audio_in_discrete_codes_mask, + audio_out_mask=audio_out_mask, + attention_mask=attention_mask, + past_key_values=next_cache, + hidden_states=all_hidden_states, + audio_hidden_states=audio_hidden_states, + attentions=all_self_attns, + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if not return_dict: + outputs = ret.to_tuple() + return outputs + + return ret + + # Overwrite GenerationMixin._update_model_kwargs_for_generation + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + extend_attention_mask: bool = True, + ) -> Dict[str, Any]: + """Update the model kwargs for each step.""" + model_kwargs["past_key_values"] = outputs.past_key_values + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if extend_attention_mask: + model_kwargs["attention_mask"] = torch.cat( + [ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], 1)), + ], + dim=-1, + ) + if "cache_audio_discrete_codes_mask" in model_kwargs: + if model_kwargs["cache_audio_discrete_codes_mask"] is None: + model_kwargs["cache_audio_discrete_codes_mask"] = ( + outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask + ) + else: + model_kwargs["cache_audio_discrete_codes_mask"] = torch.concat( + [ + model_kwargs["cache_audio_discrete_codes_mask"], + outputs.audio_in_discrete_codes_mask | outputs.audio_out_mask, + ], + 1, + ) + + return model_kwargs + + def _copy_kv_cache(self, from_cache: Cache, to_cache: Cache): + num_layers = self.config.text_config.num_hidden_layers + if self.config.audio_dual_ffn_layers is not None: + num_layers += len(self.config.audio_dual_ffn_layers) + """ Copy the key-value pairs from one 
cache to another. """ + for layer_idx in range(num_layers): + from_cache_size = from_cache.get_max_cache_shape() + assert to_cache.get_max_cache_shape() >= from_cache_size, ( + f"The target cache size {to_cache.get_max_cache_shape()} is smaller than the source cache size {from_cache_size}." + ) + to_cache.key_cache[layer_idx][:, :, :from_cache_size, :] = from_cache.key_cache[layer_idx] + to_cache.value_cache[layer_idx][:, :, :from_cache_size, :] = from_cache.value_cache[layer_idx] + + def _prepare_kv_cache( + self, + current_sequence_length: int, + current_past_key_values_bucket: Optional[int], + past_key_values_buckets: OrderedDict[int, Cache], + ) -> Tuple[Optional[Cache], Optional[int]]: + """Prepare the KV cache for the current sequence length.""" + for cache_length in past_key_values_buckets.keys(): + if cache_length >= current_sequence_length: + # Promote to the next KV cache bucket, copy the current KV cache bucket + # to the new one. + if current_past_key_values_bucket is not None and cache_length != current_past_key_values_bucket: + self._copy_kv_cache( + past_key_values_buckets[current_past_key_values_bucket], + past_key_values_buckets[cache_length], + ) + + return past_key_values_buckets[cache_length], cache_length + + raise ValueError( + f"The current sequence length {current_sequence_length} is larger than " + f"all past key values buckets {past_key_values_buckets.keys()}." + ) + + def _sample_audio_tokens( + self, + hidden_states: torch.Tensor, + audio_logits: torch.Tensor, + audio_out_ids: torch.Tensor, + do_sample: bool, + logits_processor: LogitsProcessorList, + device: torch.device, + torch_generator: Optional[torch.Generator], + generation_config: GenerationConfig, + num_delay: int, + num_remaining_delays: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[int]]: + """Sample audio tokens and its corresponding text tokens from the logits""" + + # parameters related to repetition aware sampling + ras_win_len = generation_config.generation_kwargs.get("ras_win_len", None) + ras_win_max_num_repeat = generation_config.generation_kwargs.get("ras_win_max_num_repeat", 2) + audio_eos_token_id = generation_config.generation_kwargs.get("audio_eos_token_id", None) + # In the audio generation mode, we sample from audio_logits and keep updating audio_out_ids. + next_audio_token_logits = audio_logits.clone()[-1, :, :].float().to(device) + # TopP, TopK logits processor supports empty input_ids + next_audio_token_scores = logits_processor(None, next_audio_token_logits) + + # token selection + if do_sample: + # next_audio_token_scores has been applied top_p, top_k, and temperature. + probs = nn.functional.softmax(next_audio_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_audio_tokens = torch.multinomial(probs, num_samples=1, generator=torch_generator).squeeze(1) + else: + next_audio_tokens = torch.argmax(next_audio_token_scores, dim=-1) + + # next_tokens: (num_codebooks, ) + if ras_win_len is not None: + # check if there are repetitions over a window of tokens. + rep_num = (audio_out_ids[:, -ras_win_len:] == next_audio_tokens.unsqueeze(1)).sum(dim=1) + + # if we saw repeated tokens in the most recent window of tokens, resample without temperature. 
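+ # "Resample without temperature" means: for codebooks whose sampled token already appeared at
+ # least ras_win_max_num_repeat times in the last ras_win_len frames, redraw directly from
+ # softmax(raw logits), bypassing the top-k / top-p / temperature processors.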
+ row_indices = torch.nonzero(rep_num >= ras_win_max_num_repeat).squeeze(1) + resampled_next_tokens = ( + next_audio_token_logits[row_indices] + .softmax(dim=-1) + .multinomial(1, replacement=True, generator=torch_generator) + .squeeze(1) + ) + next_audio_tokens[row_indices] = resampled_next_tokens + + # Force the next text tokens to be <|AUDIO_OUT|> in audio generation mode + next_tokens = torch.full( + (audio_logits.shape[0],), + self.config.audio_out_token_idx, + dtype=torch.long, + device=device, + ) + + # Handle delay_pattern + if self.use_delay_pattern: + if num_delay + 1 < next_audio_tokens.shape[0]: + next_audio_tokens[(num_delay + 1) :] = self.config.audio_stream_bos_id + num_delay += 1 + if num_remaining_delays is not None: + next_audio_tokens[: (self.audio_num_codebooks - num_remaining_delays)] = ( + self.config.audio_stream_eos_id + ) + num_remaining_delays -= 1 + else: + all_eos_indices = (next_audio_tokens == self.config.audio_stream_eos_id).nonzero() + if torch.numel(all_eos_indices) > 0: + all_eos_indices = all_eos_indices[0] + last_eos_idx = all_eos_indices[-1] + next_audio_tokens[:last_eos_idx] = self.config.audio_stream_eos_id + num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1 + if num_remaining_delays is not None and num_remaining_delays <= 0: + next_tokens[...] = audio_eos_token_id + num_delay = 0 + num_remaining_delays = None + + return ( + next_tokens, + next_audio_tokens, + next_audio_token_logits, + next_audio_token_scores, + num_delay, + num_remaining_delays, + ) + + def _sample_text_tokens( + self, + logits: torch.Tensor, + input_ids: torch.Tensor, + do_sample: bool, + logits_processor: LogitsProcessorList, + device: torch.device, + generation_mode: GenerationMode, + torch_generator: Optional[torch.Generator], + ) -> torch.Tensor: + """Sample text tokens from the logits""" + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = logits.clone()[:, -1, :].float() + next_token_logits = next_token_logits.to(input_ids.device) + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + + if generation_mode == GenerationMode.AUDIO_INIT: + # See the audio bos token, we should start generating audio tokens + next_tokens = torch.full( + (input_ids.shape[0],), + self.audio_out_token_idx, + dtype=torch.long, + device=device, + ) + next_audio_tokens = torch.full( + (self.config.audio_num_codebooks,), + self.config.audio_stream_bos_id, + dtype=torch.long, + device=device, + ) + else: + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1, generator=torch_generator).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + next_audio_tokens = None + + return next_tokens, next_audio_tokens, next_token_logits, next_token_scores + + # Built on top of GenerationMixin._sample. + # We revise the implementation to support generating both audio / text. 
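+ # The sampling loop below alternates between two modes: text decoding (sampling from the text
+ # logits) and audio decoding (sampling one token per codebook from the audio logits). The mode
+ # is chosen from the last generated text token: <|audio_out_bos|> starts an audio segment,
+ # <|AUDIO_OUT|> means an audio segment is in progress, anything else is plain text decoding.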
+ def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + past_key_values_buckets: Optional[OrderedDict[int, Cache]], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for joint text/audio models using **multinomial sampling**. + + This function may also be revised to support generating samples from HiggsAudio-like end-to-end text/audio models built on top of LLMs. + If the input_ids ends with <|audio_out_bos|>, we will switch to the audio-generation mode. + + ``` + ...<|start_header_id|>assistant<|end_header_id|>\n\n<|audio_out_bos|> + ``` + + Otherwise, we will keep generating the text tokens. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. + + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. 
+ """
+ assert input_ids.shape[0] == 1, "Only support batch_size=1 in _sample()"
+ audio_out_bos_token_id = generation_config.generation_kwargs.get("audio_out_bos_token_id", None)
+
+ # torch generator for sampling
+ seed = generation_config.generation_kwargs.get("seed", None)
+ if seed is not None:
+ torch_generator = torch.Generator(device=input_ids.device).manual_seed(seed)
+ else:
+ torch_generator = None
+
+ # init values
+ pad_token_id = generation_config._pad_token_tensor
+ output_attentions = generation_config.output_attentions
+ output_hidden_states = generation_config.output_hidden_states
+ output_scores = generation_config.output_scores
+ output_logits = generation_config.output_logits
+ return_dict_in_generate = generation_config.return_dict_in_generate
+ max_length = generation_config.max_length
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+ do_sample = generation_config.do_sample
+ # Used to track which past_key_values bucket is currently in use
+ self.current_past_key_values_bucket = None
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # keep track of which sequences are already finished
+ batch_size, cur_len = input_ids.shape
+ this_peer_finished = False
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+ if generation_config.use_cache:
+ model_kwargs["cache_audio_discrete_codes_mask"] = None
+
+ init_model_input = True
+ num_delay = 0
+ num_remaining_delays = None
+ audio_sequences = []
+ # A tensor to keep track of all the audio placeholder tokens.
+ input_ids_full = input_ids.clone()
+
+ # Initialize the audio variables based on the input prompt. 
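+ # If the prompt already ends with an <|AUDIO_OUT|> placeholder, generation resumes inside an
+ # audio segment, so the delay-pattern counters (when the delay pattern is enabled) are
+ # reconstructed from the last column of the provided audio_out_ids.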
+ if input_ids[0][-1] == self.config.audio_out_token_idx: + audio_sequences = [model_kwargs["audio_out_ids"][:, model_kwargs["audio_out_ids_start"][-1] :]] + if self.use_delay_pattern: + num_delay = ( + self.audio_num_codebooks + - (model_kwargs["audio_out_ids"][:, -1] == self.config.audio_stream_bos_id).sum() + ) + all_eos_indices = (model_kwargs["audio_out_ids"][:, -1] == self.config.audio_stream_eos_id).nonzero() + if torch.numel(all_eos_indices) > 0: + all_eos_indices = all_eos_indices[0] + last_eos_idx = all_eos_indices[-1] + num_remaining_delays = self.audio_num_codebooks - last_eos_idx - 1 + + while self._has_unfinished_sequences( + this_peer_finished, + synced_gpus, + device=input_ids.device, + cur_len=cur_len, + max_length=max_length, + ): + # Check which multimodal stage we are in + # FIXME: Assume single input generation + if input_ids[0][-1] == audio_out_bos_token_id: + generation_mode = GenerationMode.AUDIO_INIT + elif input_ids[0][-1] == self.audio_out_token_idx: + generation_mode = GenerationMode.AUDIO_IN_PROGRESS + else: + generation_mode = GenerationMode.TEXT + + is_audio_generation_mode = generation_mode == GenerationMode.AUDIO_IN_PROGRESS + + if init_model_input or not generation_config.use_cache: + model_inputs = {"input_ids": input_ids, **model_kwargs} + else: + model_inputs = {"input_ids": input_ids[:, -1:], **model_kwargs} + + if is_audio_generation_mode and generation_config.use_cache: + model_inputs["audio_out_ids"] = model_kwargs["audio_out_ids"][:, -1:] + model_inputs["audio_out_ids_start"] = torch.tensor([0], dtype=torch.long, device=input_ids.device) + elif not is_audio_generation_mode: + del model_inputs["audio_out_ids"] + del model_inputs["audio_out_ids_start"] + + if generation_config.use_cache: + if "audio_features" in model_inputs and model_inputs["audio_features"] is not None: + model_inputs["audio_features"] = model_inputs["audio_features"][:0, ...] + model_inputs["audio_feature_attention_mask"] = model_inputs["audio_feature_attention_mask"][ + :0, ... + ] + + if "audio_in_ids" in model_inputs and model_inputs["audio_in_ids"] is not None: + model_inputs["audio_in_ids"] = None + model_inputs["audio_in_ids_start"] = None + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if past_key_values_buckets is not None: + past_key_values, self.current_past_key_values_bucket = self._prepare_kv_cache( + cur_len, + self.current_past_key_values_bucket, + past_key_values_buckets, + ) + if past_key_values is not None: + model_inputs.update({"past_key_values": past_key_values}) + model_inputs["past_key_values_buckets"] = past_key_values_buckets + + # forward pass to get next token + outputs = self(**model_inputs, return_dict=True) + + # Update the actual sequence length after the first forward pass + if init_model_input and past_key_values_buckets is not None: + cur_len = past_key_values_buckets[self.current_past_key_values_bucket].get_seq_length().item() + + # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + extend_attention_mask=True, + ) + + # After the first forward pass, we can set init_model_input to False. 
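+ # From here on, with use_cache enabled, only the newly generated token is fed back in;
+ # the prefix is already stored in the KV cache.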
+ init_model_input = False + + if synced_gpus and this_peer_finished: + continue + + if is_audio_generation_mode: + # In audio generation mode, we sample the audio tokens from audio logits. + # It might also generate the audio eos token to end the audio generation. + ( + next_tokens, + next_audio_tokens, + next_audio_token_logits, + next_audio_token_scores, + num_delay, + num_remaining_delays, + ) = self._sample_audio_tokens( + hidden_states=outputs.audio_hidden_states, + audio_logits=outputs.audio_logits, + audio_out_ids=model_kwargs["audio_out_ids"], + do_sample=do_sample, + logits_processor=logits_processor, + device=input_ids.device, + torch_generator=torch_generator, + generation_config=generation_config, + num_delay=num_delay, + num_remaining_delays=num_remaining_delays, + ) + + # update generated ids, model inputs, and length for next step + model_kwargs["audio_out_ids"] = torch.cat( + [model_kwargs["audio_out_ids"], next_audio_tokens[:, None]], dim=-1 + ) + audio_sequences[-1] = torch.cat([audio_sequences[-1], next_audio_tokens[:, None]], dim=-1) + + if streamer is not None: + streamer.put(next_audio_tokens.cpu()) + else: + # In text generation mode, we sample the text tokens from text logits. + # It might also generate the audio placeholder token to start the audio generation. + next_tokens, next_audio_tokens, next_token_logits, next_token_scores = self._sample_text_tokens( + input_ids=input_ids, + logits=outputs.logits, + do_sample=do_sample, + logits_processor=logits_processor, + device=input_ids.device, + generation_mode=generation_mode, + torch_generator=torch_generator, + ) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + if next_audio_tokens is not None: + # If the token is audio bos token, we will generate the audio placeholder token + # and the corrensponding audio stream bos token to start the audio generation. 
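+ # At this point next_audio_tokens is a (num_codebooks,) vector filled with the
+ # <|audio_stream_bos|> id, produced by _sample_text_tokens for the AUDIO_INIT case.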
+ audio_sequences.append(next_audio_tokens[:, None]) + if streamer is not None: + streamer.put(next_audio_tokens.cpu()) + if model_kwargs["audio_out_ids"] is None or model_kwargs["audio_out_ids"].shape[0] == 0: + # Initialize audio_out_ids + model_kwargs["audio_out_ids"] = next_audio_tokens[:, None] + model_kwargs["audio_out_ids_start"] = torch.tensor( + [0], dtype=torch.long, device=input_ids.device + ) + else: + model_kwargs["audio_out_ids_start"] = torch.concat( + [ + model_kwargs["audio_out_ids_start"], + torch.tensor( + [model_kwargs["audio_out_ids"].shape[1]], + dtype=torch.long, + device=input_ids.device, + ), + ], + dim=0, + ) + model_kwargs["audio_out_ids"] = torch.concat( + [model_kwargs["audio_out_ids"], next_audio_tokens[:, None]], + dim=1, + ) + + if return_dict_in_generate: + if output_scores: + if is_audio_generation_mode: + scores += (next_audio_token_scores,) + else: + scores += (next_token_scores,) + if output_logits: + if is_audio_generation_mode: + raw_logits += (next_audio_token_logits,) + else: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += (outputs.attentions,) + if output_hidden_states: + decoder_hidden_states += (outputs.hidden_states,) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + if "tokenizer_length" in generation_config.generation_kwargs: + tokenizer_length = generation_config.generation_kwargs["tokenizer_length"] + if torch.max(next_tokens) >= tokenizer_length: + raise ValueError( + f"Next generated token has max value {torch.max(next_tokens)} which is greater than the tokenizer's vocabulary size {tokenizer_length}, this is undesired behavior." + ) + + # update generated ids, model inputs, and length for next step + if not is_audio_generation_mode or next_tokens[0] != self.audio_out_token_idx: + # We only add one <|AUDIO_OUT|> token to the input_ids for simplicity. 
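+ # In other words, while an audio segment is being generated the text stream is not extended for
+ # every audio frame; input_ids grows again once a text token (e.g. <|audio_eos|>) is emitted.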
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + input_ids_full = torch.cat([input_ids_full, next_tokens[:, None]], dim=-1) + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids_full, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + return HiggsAudioGenerationOutput( + sequences=input_ids, + audio_sequences=audio_sequences, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids, audio_sequences + + @torch.inference_mode() + def generate( + self, + input_ids: Optional[torch.LongTensor] = None, + audio_features: Optional[torch.FloatTensor] = None, + audio_feature_attention_mask: Optional[torch.BoolTensor] = None, + audio_in_ids: Optional[torch.LongTensor] = None, + audio_in_ids_start: Optional[torch.LongTensor] = None, + audio_out_ids: Optional[torch.LongTensor] = None, + audio_out_ids_start: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + audio_out_bos_token_id: int = None, + audio_eos_token_id: int = None, + past_key_values_buckets: Optional[OrderedDict[int, Cache]] = None, + seed: Optional[int] = None, + **kwargs, + ): + """ + The generate function in huggingface generally follows these steps: + + for sample_step in 1, 2, 3, 4, 5, ... + ... + + """ + # Right now, it's a very simplified version of generate, we should revisit this after our model architecture stabilizes. + assert input_ids.shape[0] == 1, ( + "Currently HiggsAudioModel.generate() only supports batch_size=1. 
See the implementation of " + ) + generation_config, kwargs = self._prepare_generation_config(kwargs.pop("generation_config", None), **kwargs) + if audio_out_bos_token_id is not None: + generation_config.generation_kwargs["audio_out_bos_token_id"] = audio_out_bos_token_id + else: + try: + generation_config.generation_kwargs["audio_out_bos_token_id"] = self.audio_out_bos_token_id + except: + generation_config.generation_kwargs["audio_out_bos_token_id"] = None + + if audio_eos_token_id is not None: + generation_config.generation_kwargs["audio_eos_token_id"] = audio_eos_token_id + else: + try: + generation_config.generation_kwargs["audio_eos_token_id"] = self.audio_eos_token_id + except: + generation_config.generation_kwargs["audio_eos_token_id"] = None + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + + generation_config.generation_kwargs["ras_win_len"] = kwargs.pop("ras_win_len", None) + generation_config.generation_kwargs["ras_win_max_num_repeat"] = kwargs.pop("ras_win_max_num_repeat", 2) + # Set generation seed if determinstic generation is required + if seed is not None: + generation_config.generation_kwargs["seed"] = seed + + # Store tokenizer in generation config if it is in kwargs without popping it + if "tokenizer" in kwargs: + generation_config.generation_kwargs["tokenizer_length"] = len(kwargs["tokenizer"]) + + # input_ids: [bsz, seq_len] + # The merging of audio features happens inside the forward path. The input_ids does not need to change. + # TODO: prepare the final input embeddings to improve generation performance + input_ids_length = input_ids.shape[-1] + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=None, + inputs_tensor=None, + input_ids_length=input_ids_length, + ) + assert generation_config.num_beams == 1, "Currently, we only support beam search with num_beams=1" + return_dict_in_generate = generation_config.return_dict_in_generate + output_scores = generation_config.output_scores + + # When attn_implement is spda or flash-attention, it will create causal mask automatically. + attention_mask = kwargs.pop("attention_mask", None) + return super().generate( + input_ids=input_ids, + attention_mask=attention_mask, + audio_features=audio_features, + audio_feature_attention_mask=audio_feature_attention_mask, + audio_in_ids=audio_in_ids, + audio_in_ids_start=audio_in_ids_start, + audio_out_ids=audio_out_ids, + audio_out_ids_start=audio_out_ids_start, + past_key_values=past_key_values, + generation_config=generation_config, + output_scores=output_scores, + return_dict_in_generate=return_dict_in_generate, + past_key_values_buckets=past_key_values_buckets, + **kwargs, + ) + + def parameter_count_per_component(self): + """Count the number of parameters per component in the model. 
+ + HiggsAudio has the following main components: + audio_tower: For mapping audio features to hidden states), + llm_embed: The size of embedding layer of the LLM + llm_non_embed: The size of non-embedding layer of the LLM + audio_adapter: The overall size of additional layers for audio generation + + """ + trainable_stats = { + "audio_tower": 0, + "llm_embed": 0, + "llm_non_embed": 0, + "audio_embed": 0, + "audio_adapter": 0, + "overall": 0, + } + total_stats = { + "audio_tower": 0, + "llm_embed": 0, + "llm_non_embed": 0, + "audio_embed": 0, + "audio_adapter": 0, + "overall": 0, + } + + total_stats["overall"] = count_parameters(self, trainable_only=False) + trainable_stats["overall"] = count_parameters(self, trainable_only=True) + + for mod in [self.audio_tower]: + if mod is not None: + total_stats["audio_tower"] += count_parameters(mod, trainable_only=False) + trainable_stats["audio_tower"] += count_parameters(mod, trainable_only=True) + + total_stats["llm_embed"] = count_parameters(self.embed_tokens, trainable_only=False) + trainable_stats["llm_embed"] = count_parameters(self.embed_tokens, trainable_only=True) + + total_stats["audio_embed"] = count_parameters(self.audio_codebook_embeddings, trainable_only=False) + trainable_stats["audio_embed"] = count_parameters(self.audio_codebook_embeddings, trainable_only=True) + + # Calculate number of parameters for LLM + for layer in self.layers: + if isinstance(layer, HiggsAudioDualFFNDecoderLayer): + total_param_count = count_parameters(layer, trainable_only=False) + total_trainable_param_count = count_parameters(layer, trainable_only=True) + total_stats["llm_non_embed"] += total_param_count + trainable_stats["llm_non_embed"] += total_trainable_param_count + if not layer.fast_forward: + audio_mlp_param_count = count_parameters(layer.audio_mlp, trainable_only=False) + audio_mlp_trainable_param_count = count_parameters(layer.audio_mlp, trainable_only=True) + + audio_norm_param_count = count_parameters( + layer.audio_post_attention_layernorm, trainable_only=False + ) + count_parameters(layer.audio_input_layernorm, trainable_only=False) + audio_norm_trainable_param_count = count_parameters( + layer.audio_post_attention_layernorm, trainable_only=True + ) + count_parameters(layer.audio_input_layernorm, trainable_only=True) + total_stats["llm_non_embed"] -= audio_mlp_param_count + audio_norm_param_count + trainable_stats["llm_non_embed"] -= ( + audio_mlp_trainable_param_count + audio_norm_trainable_param_count + ) + total_stats["audio_adapter"] += audio_mlp_param_count + audio_norm_param_count + trainable_stats["audio_adapter"] += ( + audio_mlp_trainable_param_count + audio_norm_trainable_param_count + ) + + if layer.use_audio_attention: + audio_attn_param_count = count_parameters( + layer.audio_attn, trainable_only=False + ) + count_parameters(layer.audio_post_audio_attn_layer_norm, trainable_only=False) + audio_attn_trainable_param_count = count_parameters( + layer.audio_attn, trainable_only=True + ) + count_parameters(layer.audio_post_audio_attn_layer_norm, trainable_only=True) + total_stats["llm_non_embed"] -= audio_attn_param_count + trainable_stats["llm_non_embed"] -= audio_attn_trainable_param_count + total_stats["audio_adapter"] += audio_attn_param_count + trainable_stats["audio_adapter"] += audio_attn_trainable_param_count + else: + total_stats["llm_non_embed"] += count_parameters(layer, trainable_only=False) + trainable_stats["llm_non_embed"] += count_parameters(layer, trainable_only=True) + total_stats["llm_non_embed"] += 
count_parameters(self.norm, trainable_only=False) + trainable_stats["llm_non_embed"] += count_parameters(self.norm, trainable_only=True) + + total_stats["audio_adapter"] += count_parameters(self.audio_decoder_proj.audio_lm_head, trainable_only=False) + trainable_stats["audio_adapter"] += count_parameters( + self.audio_decoder_proj.audio_lm_head, trainable_only=True + ) + total_stats["llm_embed"] += count_parameters(self.audio_decoder_proj.text_lm_head, trainable_only=False) + trainable_stats["llm_embed"] += count_parameters(self.audio_decoder_proj.text_lm_head, trainable_only=True) + + other_audio_modules = [self.audio_encoder_proj] + if self.use_audio_out_embed_projector: + other_audio_modules.append(self.audio_out_embed_projector) + + for mod in other_audio_modules: + if mod is not None: + total_stats["audio_adapter"] += count_parameters(mod, trainable_only=False) + trainable_stats["audio_adapter"] += count_parameters(mod, trainable_only=True) + return {"trainable": trainable_stats, "total": total_stats} + + def set_skip_audio_tower(self): + self.config.skip_audio_tower = True + self.config.encode_whisper_embed = False + + def set_encode_audio_in_tokens(self): + self.config.encode_audio_in_tokens = True + + def freeze_audio_tower(self): + if self.audio_tower is not None: + for param in self.audio_tower.parameters(): + param.requires_grad = False + + def freeze_audio_encoder_proj(self): + if self.audio_encoder_proj is not None: + for param in self.audio_encoder_proj.parameters(): + param.requires_grad = False + + def freeze_llm(self, freeze_embed=True, freeze_embed_until_idx: Optional[int] = None): + for layer in self.layers: + if isinstance(layer, HiggsAudioDualFFNDecoderLayer): + for param in layer.self_attn.parameters(): + param.requires_grad = False + for param in layer.mlp.parameters(): + param.requires_grad = False + + for param in layer.post_attention_layernorm.parameters(): + param.requires_grad = False + + for param in layer.input_layernorm.parameters(): + param.requires_grad = False + else: + for param in layer.parameters(): + param.requires_grad = False + + for param in self.norm.parameters(): + param.requires_grad = False + + if freeze_embed: + if freeze_embed_until_idx is None: + for param in self.embed_tokens.parameters(): + param.requires_grad = False + else: + assert isinstance(self.embed_tokens, nn.Embedding) + self.embed_tokens = PartiallyFrozenEmbedding( + original_embedding=self.embed_tokens, + freeze_until_idx=freeze_embed_until_idx, + ) + + def freeze_text_head(self, freeze_text_head_until_idx: Optional[int] = None): + """Freeze the final text head""" + if freeze_text_head_until_idx is None: + for param in self.audio_decoder_proj.text_lm_head.parameters(): + param.requires_grad = False + + else: + assert isinstance(self.audio_decoder_proj.text_lm_head, nn.Linear) + self.audio_decoder_proj.text_lm_head = PartiallyFrozenLinear( + original_linear=self.audio_decoder_proj.text_lm_head, + freeze_until_idx=freeze_text_head_until_idx, + ) + + @classmethod + def merge_weights_from_checkpoint(cls, checkpoint_dir: str, merged_output_dir: str, *model_args, **kwargs): + # For users' convenience, we merge back embedding and text_lm_head if they are splitted + splitted_model = super().from_pretrained( + checkpoint_dir, + *model_args, + torch_dtype=torch.bfloat16, + device_map="cpu", + **{**kwargs, "state_dict": None}, # Prevent auto-loading state_dict + ) + + # Load all safetensor shards + state_dict = {} + shard_paths = sorted(glob.glob(os.path.join(checkpoint_dir, 
"*.safetensors"))) + + for shard_path in shard_paths: + shard_dict = load_file(shard_path) # Load each shard + state_dict.update(shard_dict) # Merge into a single dict + + # Merge weights + if ( + "audio_decoder_proj.text_lm_head.linear_frozen.weight" in state_dict + and "audio_decoder_proj.text_lm_head.linear_trainable.weight" in state_dict + ): + state_dict["audio_decoder_proj.text_lm_head.weight"] = torch.cat( + [ + state_dict["audio_decoder_proj.text_lm_head.linear_frozen.weight"], + state_dict["audio_decoder_proj.text_lm_head.linear_trainable.weight"], + ], + dim=0, + ) + + del state_dict["audio_decoder_proj.text_lm_head.linear_frozen.weight"] + del state_dict["audio_decoder_proj.text_lm_head.linear_trainable.weight"] + + if ( + "embed_tokens.embedding_frozen.weight" in state_dict + and "embed_tokens.embedding_trainable.weight" in state_dict + ): + state_dict["embed_tokens.weight"] = torch.cat( + [ + state_dict["embed_tokens.embedding_frozen.weight"], + state_dict["embed_tokens.embedding_trainable.weight"], + ], + dim=0, + ) + + del state_dict["embed_tokens.embedding_frozen.weight"] + del state_dict["embed_tokens.embedding_trainable.weight"] + + # Load the final state_dict + splitted_model.load_state_dict(state_dict, strict=True) + + if merged_output_dir: + splitted_model.save_pretrained(merged_output_dir, is_main_process=True, state_dict=state_dict) + + @torch.inference_mode() + def capture_model(self, past_key_values: list[Union[Cache, List[torch.FloatTensor]]]) -> None: + """Capture CUDA graphs for the model's forward pass with different KV cache lengths. + + Args: + past_key_values: List of KV caches to capture graphs for + """ + for past_key_value in past_key_values: + kv_cache_length = past_key_value.get_max_cache_shape() + # We capture two graphs, one for decoding audio tokens and one for decoding text tokens + for is_decoding_audio_token in [True, False]: + runner = CUDAGraphRunner(self._forward_core) + + # Create dummy inputs for graph capture + batch_size = 1 + hidden_dim = self.config.hidden_size + + hidden_states = torch.zeros( + (batch_size, 1, hidden_dim), + dtype=self.config.torch_dtype, + device="cuda", + ) + causal_mask = torch.ones( + (batch_size, 1, 1, kv_cache_length), + dtype=self.config.torch_dtype, + device="cuda", + ) + position_ids = torch.zeros((batch_size, 1), dtype=torch.long, device="cuda") + audio_discrete_codes_mask = torch.tensor([[is_decoding_audio_token]], dtype=torch.bool, device="cuda") + cache_position = torch.tensor([kv_cache_length - 1], dtype=torch.long, device="cuda") + audio_attention_mask = torch.ones_like(causal_mask) + fast_forward_attention_mask = torch.ones_like(causal_mask) + + runner.capture( + hidden_states=hidden_states, + causal_mask=causal_mask, + position_ids=position_ids, + audio_discrete_codes_mask=audio_discrete_codes_mask, + cache_position=cache_position, + past_key_values=past_key_value, + use_cache=True, + audio_attention_mask=audio_attention_mask, + fast_forward_attention_mask=fast_forward_attention_mask, + output_attentions=False, + output_hidden_states=False, + is_decoding_audio_token=is_decoding_audio_token, + is_using_cuda_graph=True, + ) + + self.decode_graph_runners[kv_cache_length][is_decoding_audio_token] = runner diff --git a/higgs_audio/model/utils.py b/higgs_audio/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ee073043f64ba4dde2e400f9a11052eed64735d6 --- /dev/null +++ b/higgs_audio/model/utils.py @@ -0,0 +1,778 @@ +import contextlib +from contextlib import contextmanager +from 
functools import wraps +import torch +from transformers.integrations import is_deepspeed_available + +if is_deepspeed_available(): + from deepspeed.utils import groups as deepspeed_groups + from deepspeed.sequence.layer import _SeqAllToAll +else: + deepspeed_groups = None + _SeqAllToAll = None + + +def _ceil_to_nearest(n, round_to): + return (n + round_to - 1) // round_to * round_to + + +def count_parameters(model, trainable_only=True): + if trainable_only: + return sum(p.numel() for p in model.parameters() if p.requires_grad) + else: + return sum(p.numel() for p in model.parameters()) + + +# TODO(sxjscience) Consider to move the function to audio_processing/utils.py +def build_delay_pattern_mask( + input_ids: torch.LongTensor, + bos_token_id: int, + pad_token_id: int, +): + """Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284 + + In the delay pattern, each codebook is offset by the previous codebook by + one. We insert a special delay token at the start of the sequence if its delayed, and append pad token once the sequence finishes. + + Take the example where there are 4 codebooks and audio sequence length=5. After shifting, the output should have length seq_len + num_codebooks - 1 + + - [ *, *, *, *, *, P, P, P] + - [ B, *, *, *, *, *, P, P] + - [ B, B, *, *, *, *, *, P] + - [ B, B, B, *, *, *, *, *] + + where B indicates the delay token id, P is the special padding token id and `*` indicates that the original audio token. + + Now let's consider the case where we have a sequence of audio tokens to condition on. + The audio tokens were originally in the following non-delayed form: + + - [a, b] + - [c, d] + - [e, f] + - [g, h] + + After conversion, we get the following delayed form: + - [a, b, -1, -1, -1] + - [B, c, d, -1, -1] + - [B, B, e, f, -1] + - [B, B, B, g, h] + + Note that we have a special token `-1` that indicates it should be replaced by a new token we see in the generation phase. + In that case, we should override the `-1` tokens in auto-regressive generation. + + Args: + input_ids (:obj:`torch.LongTensor`): + The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len). + bos_token_id (:obj:`int`): + The id of the special delay token + pad_token_id (:obj:`int`): + The id of the padding token. Should be the same as eos_token_id. + + Returns: + input_ids (:obj:`torch.LongTensor`): + The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1). + input_ids_with_gen_mask (:obj:`torch.LongTensor`): + The transformed input ids with delay pattern applied. The -1 in the output indicates new tokens that should be generated. + + """ + bsz, num_codebooks, seq_len = input_ids.shape + + new_seq_len = seq_len + num_codebooks - 1 + input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device) + bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0 + eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0 + input_ids_with_gen_mask[bos_mask] = bos_token_id + input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1) + input_ids = input_ids_with_gen_mask.clone() + input_ids[eos_mask] = pad_token_id + input_ids_with_gen_mask[eos_mask] = -1 + return input_ids, input_ids_with_gen_mask + + +def revert_delay_pattern(data): + """Convert samples encoded with delay pattern back to the original form. + + Args: + data (:obj:`torch.Tensor`): + The data with delay pattern applied. 
It will have shape (num_codebooks, seq_len + num_codebooks - 1). + + Returns: + ret (:obj:`torch.Tensor`): + Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len). + """ + assert len(data.shape) == 2 + out_l = [] + num_codebooks = data.shape[0] + for i in range(num_codebooks): + out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)]) + return torch.cat(out_l, dim=0) + + +def merge_input_ids_with_audio_features( + audio_features_embed, + audio_features_length, + audio_in_embed, + audio_in_ids_start, + audio_out_embed, + audio_out_ids_start, + audio_in_token_idx, + audio_out_token_idx, + inputs_embeds, + input_ids, + attention_mask, + label_ids, + pad_token_id, + ignore_index=-100, + round_to=8, + left_padding=True, +): + """ + Merge input_ids with audio features into final embeddings. + + Args: + audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`): + Encoded vectors of all audios in the batch (obtained from the semantic encoder) + audio_features_length (`torch.LongTensor` of shape `(num_audios,)`): + The length of audio embeddings of each audio as stacked in `audio_features_embed` + audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`): + The embeddings of audio-in tokens + audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`): + The start index of the audio-in tokens for each audio + audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`): + The embeddings of audio-out tokens + audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`): + The start index of the audio-out tokens for each audio + audio_in_token_idx + The index of the audio-in token in the vocabulary + audio_out_token_idx + The index of the audio-out token in the vocabulary + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): + Token embeddings before merging with audio embeddings + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Input_ids of tokens, possibly filled with audio token + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. + label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) + labels need to be recalculated to support training (if provided) + pad_token_id (`int`): + The index of the pad token in the vocabulary + ignore_index + The index to ignore in the loss calculation + round_to + The number to round to for padding + left_padding + Whether to apply left padding + + Returns: + final_embedding + The final embeddings after merging audio embeddings with text embeddings. + final_attention_mask + The final attention mask after merging audio embeddings with text embeddings. + final_labels + The labels for the text stream + position_ids + Positional ids for the merged data + final_input_ids + The final input_ids after merging audio embeddings with text embeddings. 
+ final_audio_in_mask + Mask for audio-in embeddings + final_audio_in_discrete_codes_mask + Mask for audio-in discrete tokens + final_audio_out_mask + Mask for audio-out embeddings + + Explanation: + each audio has variable length embeddings, with length specified by + - audio_features_length + - audio_in_ids_start + - audio_out_ids_start + + Task: + - fill each <|AUDIO|> with audio embeddings (it can be the combination of embeddings extracted by WhisperEncoder and embeddings from audio codebooks) + - fill each <|AUDIO_OUT|> with the audio-out embeddings + + Example: + <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens) + <|AUDIO|>: Z (8 tokens) + + X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding). + if right padding + input_ids: [ + a b c d e f X g h i j k Y l m + o p q r Z s t u v _ _ _ _ _ _ + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ + ] + elif left padding + input_ids: [ + a b c d e f X g h i j k Y l m + _ _ _ _ _ _ o p q r Z s t u v + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v + ] + + """ + if label_ids is None: + skip_labels = True + else: + skip_labels = False + if audio_features_embed is not None and audio_features_embed.shape[0] == 0: + audio_features_embed = None + if audio_in_embed is not None and audio_in_embed.shape[0] == 0: + audio_in_embed = None + if audio_out_embed is not None and audio_out_embed.shape[0] == 0: + audio_out_embed = None + + batch_size, sequence_length, embed_dim = inputs_embeds.shape + + target_device = inputs_embeds.device + if left_padding is None: + left_padding = torch.any(attention_mask[:, 0] == 0) + + audio_in_token_mask = input_ids == audio_in_token_idx + audio_out_token_mask = input_ids == audio_out_token_idx + text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx) + + # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]). 
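+    # Text tokens keep a single slot each. An <|AUDIO|> placeholder is expanded to the number of
+    # Whisper-encoder frames (audio_features_length) and/or the number of discrete audio-in codes
+    # (derived from audio_in_ids_start), depending on which embeddings are provided, while an
+    # <|AUDIO_OUT|> placeholder is expanded to the number of audio-out codes (from audio_out_ids_start).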
+ token_placeholder_num = torch.ones_like(input_ids) + + if audio_features_embed is not None: + num_audios, max_audio_tokens, _ = audio_features_embed.shape + audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to( + audio_features_length.device + ) < audio_features_length.unsqueeze(1) + masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim) + token_placeholder_num[audio_in_token_mask] = audio_features_length.long() + + if audio_in_embed is not None: + audio_in_codes_length = torch.concat( + [ + audio_in_ids_start[1:] - audio_in_ids_start[:-1], + torch.tensor( + [audio_in_embed.shape[0] - audio_in_ids_start[-1]], + device=audio_in_ids_start.device, + dtype=torch.long, + ), + ], + dim=0, + ) + if audio_features_embed is not None: + token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long() + else: + token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long() + + if audio_out_embed is not None: + audio_out_codes_length = torch.concat( + [ + audio_out_ids_start[1:] - audio_out_ids_start[:-1], + torch.tensor( + [audio_out_embed.shape[0] - audio_out_ids_start[-1]], + device=audio_out_ids_start.device, + dtype=torch.long, + ), + ], + dim=0, + ) + token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long() + + new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1 + max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to) + nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1] + + if left_padding: + new_token_positions += nb_audio_pad[:, None] # offset for left padding + + # 2. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + (batch_size, max_token_num, embed_dim), + dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + final_attention_mask = torch.zeros( + (batch_size, max_token_num), + dtype=attention_mask.dtype, + device=inputs_embeds.device, + ) + final_input_ids = torch.full( + (batch_size, max_token_num), + pad_token_id, + dtype=input_ids.dtype, + device=inputs_embeds.device, + ) + if skip_labels: + final_labels = None + else: + final_labels = torch.full( + (batch_size, max_token_num), + ignore_index, + dtype=label_ids.dtype, + device=inputs_embeds.device, + ) + + final_audio_in_mask = torch.full( + (batch_size, max_token_num), + False, + dtype=torch.bool, + device=inputs_embeds.device, + ) + final_audio_in_discrete_codes_mask = torch.full( + (batch_size, max_token_num), + False, + dtype=torch.bool, + device=inputs_embeds.device, + ) + final_audio_out_mask = torch.full( + (batch_size, max_token_num), + False, + dtype=torch.bool, + device=inputs_embeds.device, + ) + # 3. 
Get the audio-in token positions and audio-out token positions + batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length) + audio_in_batch_id = batch_id[audio_in_token_mask] # Shape (num_audio_in,) + audio_out_batch_id = batch_id[audio_out_token_mask] # Shape (num_audio_out,) + audio_features_token_ends = new_token_positions[audio_in_token_mask] # Shape (num_audio_in,) + audio_out_embed_ends = new_token_positions[audio_out_token_mask] # Shape (num_audio_out,) + + if audio_in_embed is not None: + # Fill in the audio-in embeddings + seq_indices = ( + torch.arange(max_token_num, device=target_device) + .unsqueeze(0) + .expand(audio_in_ids_start.shape[0], max_token_num) + ) + audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1 + batch_indices, col_indices = torch.where( + (seq_indices >= audio_in_embed_token_starts.unsqueeze(1)) + & (seq_indices <= audio_features_token_ends.unsqueeze(1)) + ) + batch_indices = audio_in_batch_id[batch_indices] + final_embedding[batch_indices, col_indices] = audio_in_embed + final_input_ids[batch_indices, col_indices] = audio_in_token_idx + if not skip_labels: + final_labels[batch_indices, col_indices] = ignore_index + final_audio_in_mask[batch_indices, col_indices] = True + final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True + audio_features_token_ends = audio_features_token_ends - audio_in_codes_length + + if audio_features_embed is not None: + # Fill in the audio features + seq_indices = ( + torch.arange(max_token_num, device=target_device) + .unsqueeze(0) + .expand(audio_features_embed.shape[0], max_token_num) + ) + audio_features_token_starts = audio_features_token_ends - audio_features_length + 1 + batch_indices, col_indices = torch.where( + (seq_indices >= audio_features_token_starts.unsqueeze(1)) + & (seq_indices <= audio_features_token_ends.unsqueeze(1)) + ) + batch_indices = audio_in_batch_id[batch_indices] + final_embedding[batch_indices, col_indices] = masked_audio_in_features + final_input_ids[batch_indices, col_indices] = audio_in_token_idx + if not skip_labels: + final_labels[batch_indices, col_indices] = ignore_index + final_audio_in_mask[batch_indices, col_indices] = True + + if audio_out_embed is not None: + # Fill in the audio-out embeddings + seq_indices = ( + torch.arange(max_token_num, device=target_device) + .unsqueeze(0) + .expand(audio_out_ids_start.shape[0], max_token_num) + ) + audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1 + batch_indices, col_indices = torch.where( + (seq_indices >= audio_out_embed_token_starts.unsqueeze(1)) + & (seq_indices <= audio_out_embed_ends.unsqueeze(1)) + ) + batch_indices = audio_out_batch_id[batch_indices] + final_embedding[batch_indices, col_indices] = audio_out_embed + final_input_ids[batch_indices, col_indices] = audio_out_token_idx + if not skip_labels: + final_labels[batch_indices, col_indices] = ignore_index + final_audio_out_mask[batch_indices, col_indices] = True + + # Fill in the original text embeddings and labels + batch_indices, non_audio_indices = torch.where(text_token_mask) + text_to_overwrite = new_token_positions[batch_indices, non_audio_indices] + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices] + if not skip_labels: + final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices] + final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices] + 
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
+    final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask
+
+    # Trim the tensor if there are redundant padding tokens
+    if left_padding:
+        first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
+        first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
+        if first_non_zero_loc > 0:
+            final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
+            final_embedding = final_embedding[:, first_non_zero_loc:]
+            if not skip_labels:
+                final_labels = final_labels[:, first_non_zero_loc:]
+            final_input_ids = final_input_ids[:, first_non_zero_loc:]
+            final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
+            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
+            final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
+    else:
+        # We have done right padding, so we need to trim the mask
+        last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
+        last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
+        if last_non_zero_loc < max_token_num:
+            final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
+            final_embedding = final_embedding[:, :last_non_zero_loc]
+            if not skip_labels:
+                final_labels = final_labels[:, :last_non_zero_loc]
+            final_input_ids = final_input_ids[:, :last_non_zero_loc]
+            final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
+            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
+            final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]
+
+    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+    return (
+        final_embedding,
+        final_attention_mask,
+        final_labels,
+        position_ids,
+        final_input_ids,
+        final_audio_in_mask,
+        final_audio_in_discrete_codes_mask,
+        final_audio_out_mask,
+    )
+
+
+def is_deepspeed_ulysses_enabled():
+    """Check if sequence parallelism is enabled."""
+    if deepspeed_groups is None:
+        return False
+
+    return deepspeed_groups._get_sequence_parallel_world_size() > 1
+
+
+def support_deepspeed_ulysses(module):
+    """A decorator around a PyTorch module.
It is needed for the module that needs access to sequence parallel info.""" + module._sp_size = None + module._sp_rank = None + module._sp_group = None + + @property + def sp_size(self): + if self._sp_size is None: + self._sp_size = 1 + if is_deepspeed_ulysses_enabled(): + self._sp_size = deepspeed_groups._get_sequence_parallel_group().size() + return self._sp_size + + @property + def sp_rank(self): + if self._sp_rank is None: + self._sp_rank = 0 + if is_deepspeed_ulysses_enabled(): + self._sp_rank = deepspeed_groups._get_sequence_parallel_rank() + return self._sp_rank + + @property + def sp_group(self): + if self._sp_group is None and is_deepspeed_ulysses_enabled(): + self._sp_group = deepspeed_groups._get_sequence_parallel_group() + return self._sp_group + + module.sp_size = sp_size + module.sp_rank = sp_rank + module.sp_group = sp_group + + return module + + +def deepspeed_ulysses_attention(seq_dim=1, head_dim=2): + """Perform all-to-all before and after the attention function.""" + + def attention_decorator(attn_func=None): + def wrapped(*args, **kwargs): + if is_deepspeed_ulysses_enabled(): + sp_group = deepspeed_groups._get_sequence_parallel_group() + scatter_idx = head_dim # Scatter on num_heads dimension + gather_idx = seq_dim # Gather on seq_len dimension + batch_dim_idx = 0 + args = list(args) + args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx) + args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx) + args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx) + args = tuple(args) + + attn_output = attn_func(*args, **kwargs) + + if is_deepspeed_ulysses_enabled(): + scatter_idx = seq_dim # Scatter back on seq_len dimension + gather_idx = head_dim # Gather on num_heads dimension + batch_dim_idx = 0 + attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx) + + return attn_output + + return wrapped + + return attention_decorator + + +def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1): + """Slice the corresponding cos and sin chunks for rope.""" + + def rope_decorator(rope_func=None): + def wrapped(*args, **kwargs): + if is_deepspeed_ulysses_enabled(): + sp_rank = deepspeed_groups._get_sequence_parallel_rank() + args = list(args) + seq_chunk_size = args[0].size(state_seq_dim) + args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size) + args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size) + args = tuple(args) + + return rope_func(*args, **kwargs) + + return wrapped + + return rope_decorator + + +def _gather_tensors(input_, group=None): + """Gather tensors and concatenate them along a dimension.""" + input_ = input_.contiguous() + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + tensor_shapes = [ + torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size) + ] + input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device) + torch.distributed.all_gather(tensor_shapes, input_size, group=group) + gathered_buffers = [ + torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size) + ] + torch.distributed.all_gather(gathered_buffers, input_, group=group) + return gathered_buffers + + +def _scatter_tensors(input_, group=None): + """Scatter tensors.""" + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + 
return input_ + rank = torch.distributed.get_rank(group) + return input_[rank] + + +class _GatherTensors(torch.autograd.Function): + """All gather tensors among the ranks.""" + + @staticmethod + def symbolic(graph, input_, group): + return _gather_tensors(input_, group) + + @staticmethod + def forward(ctx, input_, group): + ctx.group = group + return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged) + + @staticmethod + def backward(ctx, grad_output): + return _scatter_tensors(grad_output, ctx.group), None + + +def all_gather_tensors(input_, size=None, dim=0, group=None): + if torch.distributed.get_world_size(group) == 1: + # no sequence parallelism + return input_ + gathered_tensors = _GatherTensors.apply(input_, group) + + if size: + split_gathered_tensors = [] + for s, gathered_tensor in zip(size, gathered_tensors): + split_gathered_tensor = torch.split(gathered_tensor, s.tolist()) + split_gathered_tensors.append(split_gathered_tensor) + + gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x] + + return torch.cat(gathered_tensors, dim).contiguous() + + +def get_sequence_data_parallel_world_size(): + return torch.distributed.get_world_size() + + +def get_sequence_data_parallel_rank(): + return torch.distributed.get_rank() + + +def get_sequence_data_parallel_group(): + return torch.distributed.group.WORLD + + +if is_deepspeed_available(): + deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size + deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank + deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group + + +def _gather_tokens(input_, dim=0, group=None): + """Gather tensors and concatenate them along a dimension""" + input_ = input_.contiguous() + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + + gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device) + torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group) + if dim == 0: + shape = list(input_.size()) + shape[0] = shape[0] * world_size + output = gather_buffer.view(shape) + else: + tensor_list = [ + gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size) + ] + # Note: torch.cat already creates a contiguous tensor. 
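+        # all_gather_into_tensor lays the ranks' chunks out along the flat leading dimension, so for
+        # dim != 0 each chunk is viewed back to the input shape and the chunks are concatenated
+        # along the requested dimension.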
+ output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +def _drop_tokens(input_, dim=0, group=None): + """Divide a tensor among the sequence parallel ranks""" + world_size = torch.distributed.get_world_size(group) + if world_size == 1: + return input_ + this_rank = torch.distributed.get_rank(group) + assert input_.shape[dim] % world_size == 0, ( + f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})" + ) + chunk_size = input_.shape[dim] // world_size + + return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size) + + +class _DropTokens(torch.autograd.Function): + "Divide tokens equally among the sequence parallel ranks" + + @staticmethod + def symbolic(graph, input_, dim, group, grad_scale): + return _drop_tokens(input_, dim, group) + + @staticmethod + def forward(ctx, input_, dim, group, grad_scale): + ctx.dim = dim + ctx.group = group + ctx.grad_scale = grad_scale + return _drop_tokens(input_, dim, group) + + @staticmethod + def backward(ctx, grad_output): + grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group) + if ctx.grad_scale != 1: + grad_input /= ctx.grad_scale + return grad_input, None, None, None + + +class _GatherTokens(torch.autograd.Function): + "Gather tokens among the sequence parallel ranks" + + @staticmethod + def symbolic(graph, input_, dim, group, grad_scale): + return _gather_tokens(input_, dim, group) + + @staticmethod + def forward(ctx, input_, dim, group, grad_scale): + ctx.dim = dim + ctx.group = group + ctx.grad_scale = grad_scale + return _gather_tokens(input_, dim, group) + + @staticmethod + def backward(ctx, grad_output): + grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group) + if ctx.grad_scale != 1: + grad_input *= ctx.grad_scale + return grad_input, None, None, None + + +def drop_tokens(input_, dim=0, group=None, grad_scale=1): + if torch.distributed.get_world_size(group) == 1: + # no sequence parallelism + return input_ + return _DropTokens.apply(input_, dim, group, grad_scale) + + +def gather_tokens(input_, dim=0, group=None, grad_scale=1): + if torch.distributed.get_world_size(group) == 1: + # no sequence parallelism + return input_ + return _GatherTokens.apply(input_, dim, group, grad_scale) + + +def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1): + """ + Slice the inputs to create chuncks per the sequence parallel rank. This is used for the context parallel training. + + Args: + sp_size (`int`): + Sequence parallel size. + sp_rank (`int`): + Sequence parallel rank for the current process. 
+ dim (`int`): + The dimension to slice + """ + if sp_size == 1: + return args[0] if len(args) == 1 else args + + seq_length = args[0].size(dim) + for arg in args[1:]: + assert arg.size(dim) == seq_length, ( + f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}" + ) + assert seq_length % sp_size == 0, ( + f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})" + ) + + sub_seq_length = seq_length // sp_size + sub_seq_start = sp_rank * sub_seq_length + + output = [] + for ind in args: + ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length) + output.append(ind) + + return tuple(output) if len(output) > 1 else output[0] + + +@contextmanager +def disable_deepspeed_ulysses(): + """Disable deepspeed ulysses (sequence parallelism) if it is enabled""" + if is_deepspeed_ulysses_enabled(): + _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size + + def _get_sequence_parallel_world_size(): + return 1 + + deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size + try: + yield + finally: + deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size + else: + context = contextlib.nullcontext + with context(): + yield diff --git a/higgs_audio/serve/serve_engine.py b/higgs_audio/serve/serve_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..f624a5c5cd96421bd8c4148b2940e6f3a553b8be --- /dev/null +++ b/higgs_audio/serve/serve_engine.py @@ -0,0 +1,424 @@ +import asyncio +import base64 +import torch +import numpy as np +from io import BytesIO +from dataclasses import dataclass +from typing import List, Optional, Union +from copy import deepcopy +from transformers import AutoTokenizer, AutoProcessor +from transformers.cache_utils import StaticCache +from transformers.generation.streamers import BaseStreamer +from transformers.generation.stopping_criteria import StoppingCriteria +from dataclasses import asdict +from loguru import logger +import threading +import librosa + + +from ..dataset.chatml_dataset import ( + ChatMLSample, + ChatMLDatasetSample, + prepare_chatml_sample, +) +from ..model import HiggsAudioModel +from ..model.utils import revert_delay_pattern +from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator +from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer + + +@dataclass +class HiggsAudioStreamerDelta: + """Represents a chunk of generated content, either text or audio tokens.""" + + text: Optional[str] = None + text_tokens: Optional[torch.Tensor] = None + audio_tokens: Optional[torch.Tensor] = None + finish_reason: Optional[str] = None + + +class AsyncHiggsAudioStreamer(BaseStreamer): + """ + Async streamer that handles both text and audio token generation from Higgs-Audio model. + Stores chunks in a queue to be consumed by downstream applications. + + Parameters: + tokenizer (`AutoTokenizer`): + The tokenizer used to decode text tokens. + skip_prompt (`bool`, *optional*, defaults to `False`): + Whether to skip the prompt tokens in generation. + timeout (`float`, *optional*): + The timeout for the queue. If `None`, the queue will block indefinitely. + decode_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the tokenizer's `decode` method. 
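+        audio_num_codebooks (`int`, *optional*, defaults to 1):
+            Number of audio codebooks. Multi-row token chunks are expected to have this many rows
+            and are emitted as audio tokens instead of being decoded as text.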
+ + Examples: + ```python + >>> from transformers import AutoTokenizer + >>> from threading import Thread + >>> import asyncio + + >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer") + >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model") + >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt") + + >>> async def main(): + ... streamer = AsyncHiggsAudioStreamer(tokenizer) + ... generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20) + ... thread = Thread(target=model.generate, kwargs=generation_kwargs) + ... thread.start() + ... + ... async for delta in streamer: + ... if delta.text is not None: + ... print("Text:", delta.text) + ... if delta.audio_tokens is not None: + ... print("Audio tokens shape:", delta.audio_tokens.shape) + >>> asyncio.run(main()) + ``` + """ + + def __init__( + self, + tokenizer: "AutoTokenizer", + skip_prompt: bool = False, + timeout: Optional[float] = None, + audio_num_codebooks: int = 1, + **decode_kwargs, + ): + self.tokenizer = tokenizer + self.skip_prompt = skip_prompt + self.timeout = timeout + self.decode_kwargs = decode_kwargs + self.audio_num_codebooks = audio_num_codebooks + + # Queue to store generated chunks + self.queue = asyncio.Queue() + self.stop_signal = None + + # Get running event loop + self.loop = asyncio.get_running_loop() + self.has_asyncio_timeout = hasattr(asyncio, "timeout") + + # State tracking + self.next_tokens_are_prompt = True + + def put(self, value: torch.Tensor): + """ + Receives tokens and processes them as either text or audio tokens. + For text tokens, decodes and caches them until complete words are formed. + For audio tokens, directly queues them. + """ + if value.shape[0] > 1 and not self.next_tokens_are_prompt: + # This is likely audio tokens (shape: [audio_num_codebooks]) + assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch" + delta = HiggsAudioStreamerDelta(audio_tokens=value) + self.loop.call_soon_threadsafe(self.queue.put_nowait, delta) + return + + # Skip prompt tokens if configured + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + # Process as text tokens + if len(value.shape) > 1: + value = value[0] + + text = self.tokenizer.decode(value, **self.decode_kwargs) + delta = HiggsAudioStreamerDelta(text=text, text_tokens=value) + self.loop.call_soon_threadsafe(self.queue.put_nowait, delta) + + def end(self): + """Flushes any remaining text tokens and signals the end of generation.""" + self.next_tokens_are_prompt = True + self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + if self.has_asyncio_timeout: + async with asyncio.timeout(self.timeout): + value = await self.queue.get() + else: + value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout) + except asyncio.TimeoutError: + raise TimeoutError() + else: + if value == self.stop_signal: + raise StopAsyncIteration() + else: + return value + + +class AsyncStoppingCriteria(StoppingCriteria): + """ + Stopping criteria that checks for stop signal from a threading event. + + Args: + stop_signal (threading.Event): Event that will receive stop signals + """ + + def __init__(self, stop_signal: threading.Event): + self.stop_signal = stop_signal + + def __call__(self, input_ids, scores, **kwargs) -> bool: + if self.stop_signal.is_set(): + logger.info(f"Stop signal received. 
Can be caused by client disconnection.") + return True + return False + + +@dataclass +class HiggsAudioResponse: + audio: Optional[np.ndarray] = None + generated_audio_tokens: Optional[np.ndarray] = None + sampling_rate: Optional[int] = None + generated_text: str = "" + generated_text_tokens: np.ndarray = np.array([]) + usage: Optional[dict] = None + + +class HiggsAudioServeEngine: + def __init__( + self, + model_name_or_path: str, + audio_tokenizer_name_or_path: str, + tokenizer_name_or_path: Optional[str] = None, + device: str = "cuda", + torch_dtype: Union[torch.dtype, str] = "auto", + kv_cache_lengths: List[int] = [1024, 4096, 8192], # Multiple KV cache sizes + ): + """ + Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel. + The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local. + + Args: + model_name_or_path (str): + The name or path of the model to load. + audio_tokenizer_name_or_path (str): + The name or path of the audio tokenizer to load. + tokenizer_name_or_path (str): + The name or path of the tokenizer to load. + device (str): + The device to use for the model. + kv_cache_lengths (List[int]): + The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda. + torch_dtype (Union[torch.dtype, str]): + The dtype to use for the model. + """ + self.device = device + self.model_name_or_path = model_name_or_path + self.torch_dtype = torch_dtype + + # Initialize model and tokenizer + self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device) + logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}") + + if tokenizer_name_or_path is None: + tokenizer_name_or_path = model_name_or_path + logger.info(f"Loading tokenizer from {tokenizer_name_or_path}") + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) + + logger.info(f"Initializing Higgs Audio Tokenizer") + self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device) + + self.audio_num_codebooks = self.model.config.audio_num_codebooks + self.audio_codebook_size = self.model.config.audio_codebook_size + self.audio_tokenizer_tps = self.audio_tokenizer.tps + self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps) + self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token + # Set the audio special tokens + self.model.set_audio_special_tokens(self.tokenizer) + + # Prepare KV caches for different lengths + cache_config = deepcopy(self.model.config.text_config) + cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers + if self.model.config.audio_dual_ffn_layers: + cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers) + # A list of KV caches for different lengths + self.kv_caches = { + length: StaticCache( + config=cache_config, + max_batch_size=1, + max_cache_len=length, + device=self.model.device, + dtype=self.model.dtype, + ) + for length in sorted(kv_cache_lengths) + } + + if self.model.config.encode_whisper_embed: + logger.info(f"Loading whisper processor") + whisper_processor = AutoProcessor.from_pretrained( + "openai/whisper-large-v3-turbo", + trust_remote=True, + device=self.device, + ) + else: + whisper_processor = None + + # Reuse collator to prepare inference samples + self.collator = HiggsAudioSampleCollator( + whisper_processor=whisper_processor, + 
encode_whisper_embed=self.model.config.encode_whisper_embed, + audio_in_token_id=self.model.config.audio_in_token_idx, + audio_out_token_id=self.model.config.audio_out_token_idx, + audio_stream_bos_id=self.model.config.audio_stream_bos_id, + audio_stream_eos_id=self.model.config.audio_stream_eos_id, + pad_token_id=self.model.config.pad_token_id, + return_audio_in_tokens=False, + use_delay_pattern=self.model.config.use_delay_pattern, + audio_num_codebooks=self.model.config.audio_num_codebooks, + round_to=1, + ) + + # Lock to prevent multiple generations from happening at the same time + self.generate_lock = threading.Lock() + + # Capture CUDA graphs for each KV cache length + if device == "cuda": + logger.info(f"Capturing CUDA graphs for each KV cache length") + self.model.capture_model(self.kv_caches.values()) + + def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False): + input_tokens, _, audio_contents, _ = prepare_chatml_sample( + chat_ml_sample, + self.tokenizer, + ) + + postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n" + if force_audio_gen: + postfix += "<|audio_out_bos|>" + postfix = self.tokenizer.encode(postfix, add_special_tokens=False) + input_tokens.extend(postfix) + + # Configure the audio inputs + audio_ids_l = [] + for audio_content in audio_contents: + if audio_content.audio_url not in ["placeholder", ""]: + raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate) + elif audio_content.raw_audio is not None: + raw_audio, _ = librosa.load( + BytesIO(base64.b64decode(audio_content.raw_audio)), + sr=self.audio_tokenizer.sampling_rate, + ) + else: + raw_audio = None + + if raw_audio is not None: + audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate) + audio_ids_l.append(audio_ids.squeeze(0).cpu()) + + if len(audio_ids_l) > 0: + audio_ids_start = torch.tensor( + np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])), + dtype=torch.long, + device=self.device, + )[0:-1] + audio_ids_concat = torch.cat(audio_ids_l, dim=1) + else: + audio_ids_start = None + audio_ids_concat = None + + sample = ChatMLDatasetSample( + input_ids=torch.LongTensor(input_tokens), + label_ids=None, + audio_ids_concat=audio_ids_concat, + audio_ids_start=audio_ids_start, + audio_waveforms_concat=None, + audio_waveforms_start=None, + audio_sample_rate=None, + audio_speaker_indices=None, + ) + data = self.collator([sample]) + inputs = asdict(data) + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + inputs[k] = v.to(self.model.device) + + return inputs + + def _prepare_kv_caches(self): + for kv_cache in self.kv_caches.values(): + kv_cache.reset() + + def generate( + self, + chat_ml_sample: ChatMLSample, + max_new_tokens: int, + temperature: float = 0.7, + top_k: Optional[int] = None, + top_p: float = 0.95, + stop_strings: Optional[List[str]] = None, + force_audio_gen: bool = False, + ras_win_len: Optional[int] = None, + ras_win_max_num_repeat: int = 2, + ): + """ + Generate audio from a chatml sample. + Args: + chat_ml_sample: A chatml sample. + max_new_tokens: The maximum number of new tokens to generate. + temperature: The temperature to use for the generation. + top_p: The top p to use for the generation. + Returns: + A dictionary with the following keys: + audio: The generated audio. + sampling_rate: The sampling rate of the generated audio. 
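+
+        Example:
+            A minimal sketch of calling the engine. It assumes `ChatMLSample` wraps a list of
+            `Message(role=..., content=...)` entries and uses placeholder checkpoint paths;
+            adapt the fields and paths to your setup.
+
+            ```python
+            >>> engine = HiggsAudioServeEngine(
+            ...     model_name_or_path="path/to/higgs/model",
+            ...     audio_tokenizer_name_or_path="path/to/higgs/audio_tokenizer",
+            ... )
+            >>> sample = ChatMLSample(messages=[Message(role="user", content="Tell me a short story.")])
+            >>> response = engine.generate(sample, max_new_tokens=512, temperature=0.7, force_audio_gen=True)
+            >>> if response.audio is not None:
+            ...     print(response.audio.shape, response.sampling_rate)
+            ```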
+ """ + # Default stop strings + if stop_strings is None: + stop_strings = ["<|end_of_text|>", "<|eot_id|>"] + + with torch.no_grad(), self.generate_lock: + inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen) + prompt_token_ids = inputs["input_ids"][0].cpu().numpy() + + self._prepare_kv_caches() + + outputs = self.model.generate( + **inputs, + max_new_tokens=max_new_tokens, + use_cache=True, + stop_strings=stop_strings, + tokenizer=self.tokenizer, + do_sample=False if temperature == 0.0 else True, + temperature=temperature, + top_k=top_k, + top_p=top_p, + past_key_values_buckets=self.kv_caches, + ras_win_len=ras_win_len, + ras_win_max_num_repeat=ras_win_max_num_repeat, + ) + + if len(outputs[1]) > 0: + wv_list = [] + for output_audio in outputs[1]: + vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1] + wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0] + wv_list.append(wv_numpy) + wv_numpy = np.concatenate(wv_list) + else: + wv_numpy = None + + # We only support one request at a time now + generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :] + generated_text = self.tokenizer.decode(generated_text_tokens) + generated_audio_tokens = outputs[1][0].cpu().numpy() + return HiggsAudioResponse( + audio=wv_numpy, + generated_audio_tokens=generated_audio_tokens, + sampling_rate=self.audio_tokenizer.sampling_rate, + generated_text=generated_text, + generated_text_tokens=generated_text_tokens, + usage={ + "prompt_tokens": prompt_token_ids.shape[0], + "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1], + "total_tokens": ( + prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1] + ), + "cached_tokens": 0, + }, + ) diff --git a/higgs_audio/serve/utils.py b/higgs_audio/serve/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..204d195a85c1de44060c88a9552640e07a6ba1b5 --- /dev/null +++ b/higgs_audio/serve/utils.py @@ -0,0 +1,254 @@ +import uuid +import base64 +import re +import regex +from typing import AsyncGenerator, Union +import io +from pydub import AudioSegment +import torch +import numpy as np +from functools import lru_cache + +from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +async def async_generator_wrap(first_element, gen: AsyncGenerator): + """Wrap an async generator with the first element.""" + yield first_element + async for item in gen: + yield item + + +@lru_cache(maxsize=50) +def encode_base64_content_from_file(file_path: str) -> str: + """Encode a content from a local file to base64 format.""" + # Read the MP3 file as binary and encode it directly to Base64 + with open(file_path, "rb") as audio_file: + audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8") + return audio_base64 + + +def pcm16_to_target_format( + np_audio: np.ndarray, + sample_rate: int, + bit_depth: int, + channels: int, + format: str, + target_rate: int, +): + wav_audio = AudioSegment( + np_audio.tobytes(), + frame_rate=sample_rate, + sample_width=bit_depth // 8, + channels=channels, + ) + if target_rate is not None and target_rate != sample_rate: + wav_audio = wav_audio.set_frame_rate(target_rate) + + # Convert WAV to MP3 + target_io = io.BytesIO() + wav_audio.export(target_io, format=format) + target_io.seek(0) + + return target_io + + +chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+") + + +def 
contains_chinese(text: str): + return bool(chinese_char_pattern.search(text)) + + +# remove blank between chinese character +def replace_blank(text: str): + out_str = [] + for i, c in enumerate(text): + if c == " ": + if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "): + out_str.append(c) + else: + out_str.append(c) + return "".join(out_str) + + +def replace_corner_mark(text: str): + text = text.replace("²", "平方") + text = text.replace("³", "立方") + return text + + +# remove meaningless symbol +def remove_bracket(text: str): + text = text.replace("(", "").replace(")", "") + text = text.replace("【", "").replace("】", "") + text = text.replace("`", "").replace("`", "") + text = text.replace("——", " ") + return text + + +# split paragrah logic: +# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len +# 2. cal sentence len according to lang +# 3. split sentence according to puncatation +def split_paragraph( + text: str, + tokenize, + lang="zh", + token_max_n=80, + token_min_n=60, + merge_len=20, + comma_split=False, +): + def calc_utt_length(_text: str): + if lang == "zh": + return len(_text) + else: + return len(tokenize(_text)) + + def should_merge(_text: str): + if lang == "zh": + return len(_text) < merge_len + else: + return len(tokenize(_text)) < merge_len + + if lang == "zh": + pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"] + else: + pounc = [".", "?", "!", ";", ":"] + if comma_split: + pounc.extend([",", ","]) + + if text[-1] not in pounc: + if lang == "zh": + text += "。" + else: + text += "." + + st = 0 + utts = [] + for i, c in enumerate(text): + if c in pounc: + if len(text[st:i]) > 0: + utts.append(text[st:i] + c) + if i + 1 < len(text) and text[i + 1] in ['"', "”"]: + tmp = utts.pop(-1) + utts.append(tmp + text[i + 1]) + st = i + 2 + else: + st = i + 1 + + final_utts = [] + cur_utt = "" + for utt in utts: + if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: + final_utts.append(cur_utt) + cur_utt = "" + cur_utt = cur_utt + utt + if len(cur_utt) > 0: + if should_merge(cur_utt) and len(final_utts) != 0: + final_utts[-1] = final_utts[-1] + cur_utt + else: + final_utts.append(cur_utt) + + return final_utts + + +def is_only_punctuation(text: str): + # Regular expression: Match strings that consist only of punctuation marks or are empty. 
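+    # \p{P} (punctuation) and \p{S} (symbols) are Unicode property classes, which require the
+    # third-party `regex` module used below; the standard-library `re` module does not support them.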
+ punctuation_pattern = r"^[\p{P}\p{S}]*$" + return bool(regex.fullmatch(punctuation_pattern, text)) + + +# spell Arabic numerals +def spell_out_number(text: str, inflect_parser): + new_text = [] + st = None + for i, c in enumerate(text): + if not c.isdigit(): + if st is not None: + num_str = inflect_parser.number_to_words(text[st:i]) + new_text.append(num_str) + st = None + new_text.append(c) + else: + if st is None: + st = i + if st is not None and st < len(text): + num_str = inflect_parser.number_to_words(text[st:]) + new_text.append(num_str) + return "".join(new_text) + + +def remove_emoji(text: str): + # Pattern to match emojis and their modifiers + # - Standard emoji range + # - Zero-width joiners (U+200D) + # - Variation selectors (U+FE0F, U+FE0E) + # - Skin tone modifiers (U+1F3FB to U+1F3FF) + emoji_pattern = re.compile( + r"[" + r"\U00010000-\U0010FFFF" # Standard emoji range + r"\u200D" # Zero-width joiner + r"\uFE0F\uFE0E" # Variation selectors + r"\U0001F3FB-\U0001F3FF" # Skin tone modifiers + r"]+", + flags=re.UNICODE, + ) + return emoji_pattern.sub(r"", text) + + +def remove_repeated_punctuations(text, punctuations): + if len(punctuations) == 0: + return text + pattern = f"[{re.escape(''.join(punctuations))}]" # Create regex pattern for given punctuations + return re.sub(rf"({pattern})\1+", r"\1", text) + + +def full_to_half_width(text: str) -> str: + """Convert full-width punctuation to half-width in a given string.""" + full_width = "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~" + half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" + trans_table = str.maketrans(full_width, half_width) + return text.translate(trans_table) + + +def split_interleaved_delayed_audios( + audio_data: Union[list[list[int]], torch.Tensor], + audio_tokenizer: HiggsAudioTokenizer, + audio_stream_eos_id: int, +) -> list[tuple[list[list[int]], torch.Tensor]]: + separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks + + # Convert separator to numpy array if audio_data is numpy array + if isinstance(audio_data, torch.Tensor): + audio_data = audio_data.transpose(1, 0) + separator = torch.tensor(separator) + # Find the indices where the rows equal the separator + split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0] + start = 0 + groups = [] + for idx in split_indices: + groups.append(audio_data[start:idx].transpose(1, 0)) + start = idx + 1 + if start < len(audio_data): + groups.append(audio_data[start:].transpose(1, 0)) + else: + groups = [] + current = [] + for row in audio_data: + current.append(row) + + if row == separator: + groups.append(current) + current = [] + + # Don't forget the last group if there's no trailing separator + if current: + groups.append(current) + + return groups diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..83f40b6116f1459a5191911a90491a3921be5af6 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,100 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.ruff] +line-length = 119 +target-version = "py310" +indent-width = 4 +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", + "external", + "third_party", +] + +[tool.ruff.lint] +preview = true 
+ignore-init-module-imports = true +extend-select = [ + "B009", # static getattr + "B010", # static setattr + "CPY", # Copyright + "E", # PEP8 errors + "F", # PEP8 formatting + "I", # Import sorting + "TID251", # Banned API + "UP", # Pyupgrade + "W", # PEP8 warnings +] +ignore = [ + "E501", # Line length (handled by ruff-format) + "E741", # Ambiguous variable name + "W605", # Invalid escape sequence + "UP007", # X | Y type annotations +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = [ + "F401", # Ignore seemingly unused imports (they're meant for re-export) +] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["character_tuning"] + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +# Enable auto-formatting of code examples in docstrings. Markdown, +# reStructuredText code/literal blocks and doctests are all supported. +# +# This is currently disabled by default, but it is planned for this +# to be opt-out in the future. +docstring-code-format = false + +# Set the line length limit used when formatting code snippets in +# docstrings. +# +# This only has an effect when the `docstring-code-format` setting is +# enabled. +docstring-code-line-length = "dynamic" + +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"os.getenv".msg = "Use os.environ instead" +"os.putenv".msg = "Use os.environ instead" +"os.unsetenv".msg = "Use os.environ instead" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c6b5d66b21ada748ba05c12df30178ab48aa2f7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +descript-audio-codec +torch==2.5.1 +torchaudio==2.5.1 +transformers>=4.45.1,<4.47.0 +librosa +dacite +boto3==1.35.36 +s3fs +json_repair +pandas +pydantic +vector_quantize_pytorch +loguru +pydub +ruff==0.12.2 +omegaconf +click