import librosa import torch import torch.nn.functional as F import math import numpy as np from typing import List, Tuple, Dict from dataclasses import dataclass from typing import List, Optional from transformers.models.whisper.processing_whisper import WhisperProcessor from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple from ..model.utils import build_delay_pattern_mask def _ceil_to_nearest(n, round_to): return (n + round_to - 1) // round_to * round_to @dataclass class HiggsAudioBatchInput: input_ids: torch.LongTensor # shape (bsz, seq_len). attention_mask: torch.Tensor # shape (bsz, seq_len). audio_features: Optional[torch.Tensor] # shape (num_audio_in, feature_dim, max_mel_seq_len). audio_feature_attention_mask: Optional[torch.Tensor] # shape (num_audio_in, max_mel_seq_len). audio_out_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length) audio_out_ids_start: Optional[torch.LongTensor] # shape (num_audio_out,) # The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover group location in a batch for an audio segment # Currently, we concatenante audio segments along dim 0 to handle variadic audio segment length. However, in the alignment stage, we need the location information # For example, # audio_out_ids_start = [0, 2, 4, 8]; and the first two audio segments come from the same sample in a batch, and other two come from different samples. # This is a batch of 3 samples, then we will have the group location as: # audio_out_ids_start_group_loc = [0, 0, 1, 2] audio_out_ids_start_group_loc: Optional[ torch.LongTensor ] # shape (num_audio_out,), specify which a sample's group location in the batch audio_in_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_in_total_length) audio_in_ids_start: Optional[torch.LongTensor] # shape (num_audio_in,) label_ids: Optional[torch.LongTensor] # shape (bsz, seq_len) label_audio_ids: Optional[torch.LongTensor] # shape (num_codebooks, audio_out_total_length) reward: Optional[float] = None class HiggsAudioSampleCollator: """Sample collator for Higgs-Audio model. Args: whisper_processor (WhisperProcessor): The whisper processor. audio_in_token_id (int): The token id for audio-in. audio_out_token_id (int): The token id for audio-out. pad_token_id (int): The token id for padding. audio_stream_bos_id (int): The token id for audio-stream beginning of sentence. audio_stream_eos_id (int): The token id for audio-stream end of sentence. round_to (int): The round-to value. pad_left (bool): Whether to pad left. return_audio_in_tokens (bool): Whether to return audio-in tokens. use_delay_pattern (bool): Whether to use delay pattern. disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes. chunk_size_seconds (int): The chunk size in seconds. add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks. mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>. """ def __init__( self, whisper_processor: WhisperProcessor, audio_in_token_id, audio_out_token_id, pad_token_id, audio_stream_bos_id, audio_stream_eos_id, round_to=8, pad_left=False, encode_whisper_embed=True, return_audio_in_tokens=True, audio_num_codebooks=None, use_delay_pattern=False, disable_audio_codes_transform=False, chunk_size_seconds=30, # Maximum duration for each chunk add_new_bos_eos_for_long_chunk=True, mask_audio_out_token_label=True, ): self.whisper_processor = whisper_processor self.round_to = round_to self.pad_left = pad_left self.audio_in_token_id = audio_in_token_id self.audio_out_token_id = audio_out_token_id self.audio_stream_bos_id = audio_stream_bos_id self.audio_stream_eos_id = audio_stream_eos_id self.pad_token_id = pad_token_id self.encode_whisper_embed = encode_whisper_embed self.return_audio_in_tokens = return_audio_in_tokens self.audio_num_codebooks = audio_num_codebooks self.use_delay_pattern = use_delay_pattern if encode_whisper_embed: self.chunk_size_seconds = chunk_size_seconds self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate) else: self.chunk_size_seconds = None self.chunk_size_samples = None self.disable_audio_codes_transform = disable_audio_codes_transform self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk self.mask_audio_out_token_label = mask_audio_out_token_label def _process_and_duplicate_audio_tokens( self, input_ids: torch.Tensor, audio_idx: int, wv: torch.Tensor, sr: int, labels: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, int]: """Process long audio and duplicate corresponding audio tokens. Args: input_ids: Input token ids audio_idx: Index of the audio token in the sequence wv: Audio waveform sr: Sample rate labels: Optional label ids to be duplicated alongside input ids Returns: Tuple of: - New input ids with duplicated audio tokens - New label ids (if labels were provided) or None - Number of chunks created """ # Calculate number of chunks needed total_samples = len(wv) num_chunks = math.ceil(total_samples / self.chunk_size_samples) if num_chunks <= 1: return input_ids, labels, 1 # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|> audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2] # Duplicate sequence for each chunk duplicated_sequence = audio_token_seq.repeat(num_chunks) # Create new input_ids with duplicated tokens new_input_ids = torch.cat( [ input_ids[: audio_idx - 1], duplicated_sequence, input_ids[audio_idx + 2 :], ] ) # If labels are provided, duplicate them as well new_labels = None if labels is not None: label_seq = labels[audio_idx - 1 : audio_idx + 2] duplicated_labels = label_seq.repeat(num_chunks) new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]]) return new_input_ids, new_labels, num_chunks def __call__(self, batch: List[ChatMLDatasetSample]): """Collate the input data with support for long audio processing.""" label_ids = None label_audio_ids = None if all([ele.label_ids is None for ele in batch]): return_labels = False else: return_labels = True if self.encode_whisper_embed: # Process each sample in the batch to handle long audio # TODO(?) The implementation here can be optimized. processed_batch = [] for i in range(len(batch)): sample = batch[i] audio_in_mask = sample.input_ids == self.audio_in_token_id audio_in_indices = torch.where(audio_in_mask)[0] audio_out_mask = sample.input_ids == self.audio_out_token_id # Process each audio token and duplicate if needed modified_input_ids = sample.input_ids modified_labels = sample.label_ids if return_labels else None modified_waveforms_concat = [] modified_waveforms_start = [] modified_sample_rate = [] offset = 0 # Track position changes from duplicating tokens curr_wv_offset = 0 # Process input audio tokens for idx, audio_idx in enumerate(audio_in_indices): # Get the audio for this token wv, sr = sample.get_wv(idx) # Use idx since we want the original audio index if sr != self.whisper_processor.feature_extractor.sampling_rate: resampled_wv = librosa.resample( wv.cpu().numpy(), orig_sr=sr, target_sr=self.whisper_processor.feature_extractor.sampling_rate, ) else: resampled_wv = wv.cpu().numpy() wv = torch.tensor(resampled_wv, device=wv.device) sr = self.whisper_processor.feature_extractor.sampling_rate # Process and duplicate tokens if necessary token_pos = audio_idx + offset modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens( modified_input_ids, token_pos, wv, sr, modified_labels ) # Update audio data for chunk_idx in range(num_chunks): chunk_start = chunk_idx * self.chunk_size_samples chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv)) chunk_wv = wv[chunk_start:chunk_end] modified_waveforms_concat.append(chunk_wv) modified_waveforms_start.append(curr_wv_offset) curr_wv_offset += len(chunk_wv) modified_sample_rate.append(sr) # Update offset for next iteration offset += (num_chunks - 1) * 3 # Each new chunk adds 3 more tokens # Create new sample with modified tokens and audio data processed_sample = ChatMLDatasetSample( input_ids=modified_input_ids, label_ids=modified_labels if return_labels else sample.label_ids, audio_ids_concat=sample.audio_ids_concat, audio_ids_start=sample.audio_ids_start, audio_waveforms_concat=torch.cat(modified_waveforms_concat) if modified_waveforms_concat else sample.audio_waveforms_concat, audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long) if modified_waveforms_start else sample.audio_waveforms_start, audio_sample_rate=torch.tensor(modified_sample_rate) if modified_sample_rate else sample.audio_sample_rate, audio_speaker_indices=torch.tensor([]), # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat. audio_label_ids_concat=sample.audio_label_ids_concat, ) # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0]) # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}" processed_batch.append(processed_sample) else: processed_batch = batch # Get the max sequence length based on processed batch max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to) # Get the ids for audio-in and audio-out for each batch audio_in_wv_l = [] audio_in_ids_l = [] audio_out_ids_l = [] audio_out_ids_group_loc_l = [] audio_in_label_ids_l = None audio_out_label_ids_l = None reward_l = [] if return_labels: audio_out_no_train_flag = [] # Whether the audio-out data should be trained on or not. # Process the audio inputs and outputs for i in range(len(processed_batch)): audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id audio_ids = torch.ones_like(processed_batch[i].input_ids) audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1 audio_in_ids = audio_ids[audio_in_mask] audio_out_ids = audio_ids[audio_out_mask] if return_labels: audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0) if self.mask_audio_out_token_label: processed_batch[i].label_ids[audio_out_mask] = -100 # Process audio inputs if self.return_audio_in_tokens: audio_in_ids_l.extend( [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids] ) if processed_batch[i].audio_label_ids_concat is not None: if audio_in_label_ids_l is None: audio_in_label_ids_l = [] audio_in_label_ids_l.extend( [ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids ] ) audio_out_ids_l.extend( [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids] ) audio_out_ids_group_loc_l.append(i) if processed_batch[i].reward is not None: reward_l.append(processed_batch[i].reward) if processed_batch[i].audio_label_ids_concat is not None: if audio_out_label_ids_l is None: audio_out_label_ids_l = [] audio_out_label_ids_l.extend( [ processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids ] ) if self.encode_whisper_embed: for idx in audio_in_ids: wv, sr = processed_batch[i].get_wv(idx) resampled_wv = wv.cpu().numpy() # Split long audio into chunks total_samples = len(resampled_wv) for chunk_start in range(0, total_samples, self.chunk_size_samples): chunk_end = min(chunk_start + self.chunk_size_samples, total_samples) chunk = resampled_wv[chunk_start:chunk_end] audio_in_wv_l.append(chunk) # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \ # f"Assertion failed: Mismatch in number of audios. " \ # f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}." if return_labels: audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0) # Process all audio features if len(audio_in_wv_l) > 0: feature_ret = self.whisper_processor.feature_extractor( audio_in_wv_l, sampling_rate=self.whisper_processor.feature_extractor.sampling_rate, return_attention_mask=True, padding="max_length", ) audio_features = torch.from_numpy(feature_ret["input_features"]) audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"]) else: if self.encode_whisper_embed: audio_features = torch.zeros( ( 0, self.whisper_processor.feature_extractor.feature_size, self.whisper_processor.feature_extractor.nb_max_frames, ), dtype=torch.float32, ) audio_feature_attention_mask = torch.zeros( (0, self.whisper_processor.feature_extractor.nb_max_frames), dtype=torch.int32, ) else: audio_features = None audio_feature_attention_mask = None # Process audio input tokens if len(audio_in_ids_l) > 0: # Append audio-stream-bos and eos tokens new_audio_in_ids_l = [] for ele in audio_in_ids_l: if self.disable_audio_codes_transform: # Do not add audio-stream-bos or eos tokens. # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer. audio_codes = ele else: audio_codes = torch.cat( [ torch.full( (ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long, ), ele, torch.full( (ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long, ), ], dim=1, ) if self.use_delay_pattern: audio_codes = build_delay_pattern_mask( audio_codes.unsqueeze(0), bos_token_id=self.audio_stream_bos_id, pad_token_id=self.audio_stream_eos_id, )[0].squeeze(0) new_audio_in_ids_l.append(audio_codes) audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long() audio_in_ids_start = torch.cumsum( torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]), dim=0, ) else: audio_in_ids = torch.zeros((0, 0), dtype=torch.long) audio_in_ids_start = torch.zeros(0, dtype=torch.long) # Process audio output tokens audio_out_ids_start_group_loc = None if len(audio_out_ids_l) > 0: new_audio_out_ids_l = [] label_audio_ids_l = [] for idx, ele in enumerate(audio_out_ids_l): if self.disable_audio_codes_transform: # Do not add audio-stream-bos or eos tokens. # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer. audio_codes = ele if return_labels: label_audio_ids = audio_out_label_ids_l[idx] else: audio_codes = torch.cat( [ torch.full( (ele.shape[0], 1), self.audio_stream_bos_id, dtype=torch.long, ), ele, torch.full( (ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long, ), ], dim=1, ) if return_labels: label_audio_ids = torch.cat( [ torch.full((ele.shape[0], 1), -100, dtype=torch.long), ele, torch.full( (ele.shape[0], 1), self.audio_stream_eos_id, dtype=torch.long, ), ], dim=1, ) if self.use_delay_pattern: audio_codes = build_delay_pattern_mask( audio_codes.unsqueeze(0), bos_token_id=self.audio_stream_bos_id, pad_token_id=self.audio_stream_eos_id, )[0].squeeze(0) if return_labels: label_audio_ids = build_delay_pattern_mask( label_audio_ids.unsqueeze(0), bos_token_id=-100, pad_token_id=-100, )[0].squeeze(0) new_audio_out_ids_l.append(audio_codes) if return_labels: if audio_out_no_train_flag[idx]: label_audio_ids[:] = -100 label_audio_ids_l.append(label_audio_ids) audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long() if return_labels: label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long() audio_out_ids_start = torch.cumsum( torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]), dim=0, ) audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long) else: audio_out_ids = torch.zeros((0, 0), dtype=torch.long) audio_out_ids_start = torch.zeros(0, dtype=torch.long) if return_labels: label_audio_ids = torch.zeros((0, 0), dtype=torch.long) reward = torch.tensor(reward_l, dtype=torch.float32) # Handle padding for input ids and attention mask if self.pad_left: input_ids = torch.stack( [ F.pad( ele.input_ids, (max_seq_length - len(ele.input_ids), 0), value=self.pad_token_id, ) for ele in processed_batch ] ) if return_labels: label_ids = torch.stack( [ F.pad( ele.label_ids, (max_seq_length - len(ele.label_ids), 0), value=-100, ) for ele in processed_batch ] ) attention_mask = torch.stack( [ F.pad( torch.ones_like(ele.input_ids), (max_seq_length - len(ele.input_ids), 0), value=0, ) for ele in processed_batch ] ) else: input_ids = torch.stack( [ F.pad( ele.input_ids, (0, max_seq_length - len(ele.input_ids)), value=self.pad_token_id, ) for ele in processed_batch ] ) if return_labels: label_ids = torch.stack( [ F.pad( ele.label_ids, (0, max_seq_length - len(ele.label_ids)), value=-100, ) for ele in processed_batch ] ) attention_mask = torch.stack( [ F.pad( torch.ones_like(ele.input_ids), (0, max_seq_length - len(ele.input_ids)), value=0, ) for ele in processed_batch ] ) if not self.return_audio_in_tokens: audio_in_ids = None audio_in_ids_start = None # Apply audio_num_codebooks limit if specified if self.audio_num_codebooks is not None: if audio_in_ids is not None: audio_in_ids = audio_in_ids[: self.audio_num_codebooks] if audio_out_ids is not None: audio_out_ids = audio_out_ids[: self.audio_num_codebooks] if label_audio_ids is not None: label_audio_ids = label_audio_ids[: self.audio_num_codebooks] return HiggsAudioBatchInput( input_ids=input_ids, attention_mask=attention_mask, audio_features=audio_features, audio_feature_attention_mask=audio_feature_attention_mask, audio_out_ids=audio_out_ids, audio_out_ids_start=audio_out_ids_start, audio_out_ids_start_group_loc=audio_out_ids_start_group_loc, audio_in_ids=audio_in_ids, audio_in_ids_start=audio_in_ids_start, label_ids=label_ids, label_audio_ids=label_audio_ids, reward=reward, ) class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput: # flatten ranked chatml samples chosen = [] rejected = [] for sample in batch: chosen.append(sample.max_score_sample()) rejected.append(sample.min_score_sample()) merged = chosen merged.extend(rejected) return super().__call__(batch=merged)