import librosa
import torch
import torch.nn.functional as F
import math
import numpy as np
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

from transformers.models.whisper.processing_whisper import WhisperProcessor

from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
from ..model.utils import build_delay_pattern_mask

def _ceil_to_nearest(n, round_to):
    return (n + round_to - 1) // round_to * round_to
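# For example, _ceil_to_nearest(13, 8) == 16 and _ceil_to_nearest(16, 8) == 16. The collator uses this helper to
# round the padded sequence length up to a multiple of `round_to`.
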
@dataclass
class HiggsAudioBatchInput:
    input_ids: torch.LongTensor  # shape (bsz, seq_len)
    attention_mask: torch.Tensor  # shape (bsz, seq_len)
    audio_features: Optional[torch.Tensor]  # shape (num_audio_in, feature_dim, max_mel_seq_len)
    audio_feature_attention_mask: Optional[torch.Tensor]  # shape (num_audio_in, max_mel_seq_len)
    audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
    # audio_out_ids_start_group_loc has the same length as audio_out_ids_start. Audio segments are concatenated
    # along dim 0 to handle variable segment lengths, so during the alignment stage we need to recover which sample
    # in the batch each audio segment came from.
    # For example, if audio_out_ids_start = [0, 2, 4, 8], the first two audio segments come from the same sample,
    # and the other two come from two different samples (a batch of 3 samples), then the group locations are:
    # audio_out_ids_start_group_loc = [0, 0, 1, 2]
    audio_out_ids_start_group_loc: Optional[torch.LongTensor]  # shape (num_audio_out,), the sample's group location in the batch
    audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
    audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)
    label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
    label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
    reward: Optional[float] = None
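# A minimal sketch (not part of the original module) of how the concatenated audio-out layout can be sliced back
# into per-segment code tensors, assuming `batch` is a HiggsAudioBatchInput:
#
#     starts = batch.audio_out_ids_start.tolist()
#     ends = starts[1:] + [batch.audio_out_ids.shape[1]]
#     segments = [batch.audio_out_ids[:, s:e] for s, e in zip(starts, ends)]  # one (num_codebooks, T_i) tensor per segment
#     # batch.audio_out_ids_start_group_loc[i] tells which sample in the batch produced segments[i].
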
class HiggsAudioSampleCollator:
    """Sample collator for the Higgs-Audio model.

    Args:
        whisper_processor (WhisperProcessor): The whisper processor.
        audio_in_token_id (int): The token id for audio-in (<|AUDIO|>).
        audio_out_token_id (int): The token id for audio-out (<|AUDIO_OUT|>).
        pad_token_id (int): The token id used for padding.
        audio_stream_bos_id (int): The token id for the audio-stream beginning-of-stream token.
        audio_stream_eos_id (int): The token id for the audio-stream end-of-stream token.
        round_to (int): Round the padded sequence length up to a multiple of this value.
        pad_left (bool): Whether to pad on the left instead of the right.
        encode_whisper_embed (bool): Whether to extract Whisper features for the audio-in waveforms.
        return_audio_in_tokens (bool): Whether to return audio-in tokens.
        audio_num_codebooks (Optional[int]): If set, keep only the first `audio_num_codebooks` codebooks of the audio codes.
        use_delay_pattern (bool): Whether to apply the delay pattern to the audio codes.
        disable_audio_codes_transform (bool): Whether to skip adding bos and eos tokens to the audio codes.
        chunk_size_seconds (int): The maximum chunk size in seconds when splitting long audio.
        add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
        mask_audio_out_token_label (bool): Whether to always mask the label associated with the <|AUDIO_OUT|> token.
            Since <|AUDIO_OUT|> always comes after <|audio_bos|>, it can safely be masked.
    """
    def __init__(
        self,
        whisper_processor: WhisperProcessor,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        encode_whisper_embed=True,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        chunk_size_seconds=30,  # Maximum duration for each chunk
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
    ):
        self.whisper_processor = whisper_processor
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.encode_whisper_embed = encode_whisper_embed
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        if encode_whisper_embed:
            self.chunk_size_seconds = chunk_size_seconds
            self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
        else:
            self.chunk_size_seconds = None
            self.chunk_size_samples = None
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label
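
    # A rough construction sketch (illustrative only): the token ids below are placeholders; the real values come
    # from the Higgs-Audio tokenizer / model config.
    #
    #     whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
    #     collator = HiggsAudioSampleCollator(
    #         whisper_processor=whisper_processor,
    #         audio_in_token_id=128015,      # placeholder <|AUDIO|> id
    #         audio_out_token_id=128016,     # placeholder <|AUDIO_OUT|> id
    #         pad_token_id=128001,           # placeholder pad id
    #         audio_stream_bos_id=1024,      # placeholder audio-stream bos id
    #         audio_stream_eos_id=1025,      # placeholder audio-stream eos id
    #     )
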
    def _process_and_duplicate_audio_tokens(
        self,
        input_ids: torch.Tensor,
        audio_idx: int,
        wv: torch.Tensor,
        sr: int,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], int]:
        """Process long audio and duplicate the corresponding audio tokens.

        Args:
            input_ids: Input token ids
            audio_idx: Index of the audio token in the sequence
            wv: Audio waveform
            sr: Sample rate
            labels: Optional label ids to be duplicated alongside the input ids

        Returns:
            Tuple of:
                - New input ids with duplicated audio tokens
                - New label ids (if labels were provided) or None
                - Number of chunks created
        """
        # Calculate the number of chunks needed
        total_samples = len(wv)
        num_chunks = math.ceil(total_samples / self.chunk_size_samples)

        if num_chunks <= 1:
            return input_ids, labels, 1

        # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
        audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]

        # Duplicate the sequence for each chunk
        duplicated_sequence = audio_token_seq.repeat(num_chunks)

        # Create new input_ids with the duplicated tokens
        new_input_ids = torch.cat(
            [
                input_ids[: audio_idx - 1],
                duplicated_sequence,
                input_ids[audio_idx + 2 :],
            ]
        )

        # If labels are provided, duplicate them as well
        new_labels = None
        if labels is not None:
            label_seq = labels[audio_idx - 1 : audio_idx + 2]
            duplicated_labels = label_seq.repeat(num_chunks)
            new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])

        return new_input_ids, new_labels, num_chunks
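
    # Illustrative example (assuming the default 30 s chunks at 16 kHz): a 70 s waveform gives num_chunks == 3,
    # so the single <|audio_bos|><|AUDIO|><|audio_eos|> triple around `audio_idx` is replaced by three consecutive
    # triples, and the caller later pairs each <|AUDIO|> token with one 30 s (or shorter, for the tail) chunk.
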
    def __call__(self, batch: List[ChatMLDatasetSample]):
        """Collate the input data with support for long audio processing."""
        label_ids = None
        label_audio_ids = None

        if all([ele.label_ids is None for ele in batch]):
            return_labels = False
        else:
            return_labels = True

        if self.encode_whisper_embed:
            # Process each sample in the batch to handle long audio
            # TODO(?) The implementation here can be optimized.
            processed_batch = []
            for i in range(len(batch)):
                sample = batch[i]
                audio_in_mask = sample.input_ids == self.audio_in_token_id
                audio_in_indices = torch.where(audio_in_mask)[0]
                audio_out_mask = sample.input_ids == self.audio_out_token_id

                # Process each audio token and duplicate if needed
                modified_input_ids = sample.input_ids
                modified_labels = sample.label_ids if return_labels else None
                modified_waveforms_concat = []
                modified_waveforms_start = []
                modified_sample_rate = []
                offset = 0  # Track position changes from duplicating tokens
                curr_wv_offset = 0

                # Process input audio tokens
                for idx, audio_idx in enumerate(audio_in_indices):
                    # Get the audio for this token
                    wv, sr = sample.get_wv(idx)  # Use idx since we want the original audio index
                    if sr != self.whisper_processor.feature_extractor.sampling_rate:
                        resampled_wv = librosa.resample(
                            wv.cpu().numpy(),
                            orig_sr=sr,
                            target_sr=self.whisper_processor.feature_extractor.sampling_rate,
                        )
                    else:
                        resampled_wv = wv.cpu().numpy()
                    wv = torch.tensor(resampled_wv, device=wv.device)
                    sr = self.whisper_processor.feature_extractor.sampling_rate

                    # Process and duplicate tokens if necessary
                    token_pos = audio_idx + offset
                    modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
                        modified_input_ids, token_pos, wv, sr, modified_labels
                    )

                    # Update audio data
                    for chunk_idx in range(num_chunks):
                        chunk_start = chunk_idx * self.chunk_size_samples
                        chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
                        chunk_wv = wv[chunk_start:chunk_end]
                        modified_waveforms_concat.append(chunk_wv)
                        modified_waveforms_start.append(curr_wv_offset)
                        curr_wv_offset += len(chunk_wv)
                        modified_sample_rate.append(sr)

                    # Update offset for next iteration
                    offset += (num_chunks - 1) * 3  # Each new chunk adds 3 more tokens

                # Create new sample with modified tokens and audio data
                processed_sample = ChatMLDatasetSample(
                    input_ids=modified_input_ids,
                    label_ids=modified_labels if return_labels else sample.label_ids,
                    audio_ids_concat=sample.audio_ids_concat,
                    audio_ids_start=sample.audio_ids_start,
                    audio_waveforms_concat=torch.cat(modified_waveforms_concat)
                    if modified_waveforms_concat
                    else sample.audio_waveforms_concat,
                    audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
                    if modified_waveforms_start
                    else sample.audio_waveforms_start,
                    audio_sample_rate=torch.tensor(modified_sample_rate)
                    if modified_sample_rate
                    else sample.audio_sample_rate,
                    audio_speaker_indices=torch.tensor([]),
                    # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
                    audio_label_ids_concat=sample.audio_label_ids_concat,
                )
                # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
                # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
                processed_batch.append(processed_sample)
        else:
            processed_batch = batch
        # Get the max sequence length based on the processed batch
        max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)

        # Get the ids for audio-in and audio-out for each sample in the batch
        audio_in_wv_l = []
        audio_in_ids_l = []
        audio_out_ids_l = []
        audio_out_ids_group_loc_l = []
        audio_in_label_ids_l = None
        audio_out_label_ids_l = None
        reward_l = []
        if return_labels:
            audio_out_no_train_flag = []  # Whether the audio-out data should be trained on or not.

        # Process the audio inputs and outputs
        for i in range(len(processed_batch)):
            audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
            audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
            audio_ids = torch.ones_like(processed_batch[i].input_ids)
            audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
            audio_in_ids = audio_ids[audio_in_mask]
            audio_out_ids = audio_ids[audio_out_mask]

            if return_labels:
                audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
                if self.mask_audio_out_token_label:
                    processed_batch[i].label_ids[audio_out_mask] = -100

            # Process audio inputs
            if self.return_audio_in_tokens:
                audio_in_ids_l.extend(
                    [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
                )
                if processed_batch[i].audio_label_ids_concat is not None:
                    if audio_in_label_ids_l is None:
                        audio_in_label_ids_l = []
                    audio_in_label_ids_l.extend(
                        [
                            processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                            for idx in audio_in_ids
                        ]
                    )
            audio_out_ids_l.extend(
                [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
            )
            audio_out_ids_group_loc_l.extend([i] * len(audio_out_ids))  # one entry per audio-out segment in this sample
            if processed_batch[i].reward is not None:
                reward_l.append(processed_batch[i].reward)
            if processed_batch[i].audio_label_ids_concat is not None:
                if audio_out_label_ids_l is None:
                    audio_out_label_ids_l = []
                audio_out_label_ids_l.extend(
                    [
                        processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                        for idx in audio_out_ids
                    ]
                )
            if self.encode_whisper_embed:
                for idx in audio_in_ids:
                    wv, sr = processed_batch[i].get_wv(idx)
                    resampled_wv = wv.cpu().numpy()
                    # Split long audio into chunks
                    total_samples = len(resampled_wv)
                    for chunk_start in range(0, total_samples, self.chunk_size_samples):
                        chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
                        chunk = resampled_wv[chunk_start:chunk_end]
                        audio_in_wv_l.append(chunk)
                # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
                #     f"Assertion failed: Mismatch in number of audios. " \
                #     f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."

        if return_labels:
            audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)
        # Process all audio features
        if len(audio_in_wv_l) > 0:
            feature_ret = self.whisper_processor.feature_extractor(
                audio_in_wv_l,
                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                return_attention_mask=True,
                padding="max_length",
            )
            audio_features = torch.from_numpy(feature_ret["input_features"])
            audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
        else:
            if self.encode_whisper_embed:
                audio_features = torch.zeros(
                    (
                        0,
                        self.whisper_processor.feature_extractor.feature_size,
                        self.whisper_processor.feature_extractor.nb_max_frames,
                    ),
                    dtype=torch.float32,
                )
                audio_feature_attention_mask = torch.zeros(
                    (0, self.whisper_processor.feature_extractor.nb_max_frames),
                    dtype=torch.int32,
                )
            else:
                audio_features = None
                audio_feature_attention_mask = None
        # Process audio input tokens
        if len(audio_in_ids_l) > 0:
            # Append audio-stream bos and eos tokens
            new_audio_in_ids_l = []
            for ele in audio_in_ids_l:
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_bos_id,
                                dtype=torch.long,
                            ),
                            ele,
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_eos_id,
                                dtype=torch.long,
                            ),
                        ],
                        dim=1,
                    )
                if self.use_delay_pattern:
                    audio_codes = build_delay_pattern_mask(
                        audio_codes.unsqueeze(0),
                        bos_token_id=self.audio_stream_bos_id,
                        pad_token_id=self.audio_stream_eos_id,
                    )[0].squeeze(0)
                new_audio_in_ids_l.append(audio_codes)
            audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
            audio_in_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
                dim=0,
            )
        else:
            audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_in_ids_start = torch.zeros(0, dtype=torch.long)
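
        # At this point `audio_in_ids` holds all audio-in code segments concatenated along dim 1 and
        # `audio_in_ids_start` holds the starting offset of each segment. For example, with three segments of
        # lengths 5, 7, and 3 codes, audio_in_ids_start == tensor([0, 5, 12]).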
        # Process audio output tokens
        audio_out_ids_start_group_loc = None
        if len(audio_out_ids_l) > 0:
            new_audio_out_ids_l = []
            label_audio_ids_l = []
            for idx, ele in enumerate(audio_out_ids_l):
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                    if return_labels:
                        label_audio_ids = audio_out_label_ids_l[idx]
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_bos_id,
                                dtype=torch.long,
                            ),
                            ele,
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_eos_id,
                                dtype=torch.long,
                            ),
                        ],
                        dim=1,
                    )
                    if return_labels:
                        label_audio_ids = torch.cat(
                            [
                                torch.full((ele.shape[0], 1), -100, dtype=torch.long),
                                ele,
                                torch.full(
                                    (ele.shape[0], 1),
                                    self.audio_stream_eos_id,
                                    dtype=torch.long,
                                ),
                            ],
                            dim=1,
                        )
                if self.use_delay_pattern:
                    audio_codes = build_delay_pattern_mask(
                        audio_codes.unsqueeze(0),
                        bos_token_id=self.audio_stream_bos_id,
                        pad_token_id=self.audio_stream_eos_id,
                    )[0].squeeze(0)
                    if return_labels:
                        label_audio_ids = build_delay_pattern_mask(
                            label_audio_ids.unsqueeze(0),
                            bos_token_id=-100,
                            pad_token_id=-100,
                        )[0].squeeze(0)
                new_audio_out_ids_l.append(audio_codes)
                if return_labels:
                    if audio_out_no_train_flag[idx]:
                        label_audio_ids[:] = -100
                    label_audio_ids_l.append(label_audio_ids)
            audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
            if return_labels:
                label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
            audio_out_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
                dim=0,
            )
            audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
        else:
            audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_out_ids_start = torch.zeros(0, dtype=torch.long)
            if return_labels:
                label_audio_ids = torch.zeros((0, 0), dtype=torch.long)

        reward = torch.tensor(reward_l, dtype=torch.float32)
        # Handle padding for input ids and attention mask
        if self.pad_left:
            input_ids = torch.stack(
                [
                    F.pad(
                        ele.input_ids,
                        (max_seq_length - len(ele.input_ids), 0),
                        value=self.pad_token_id,
                    )
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(
                            ele.label_ids,
                            (max_seq_length - len(ele.label_ids), 0),
                            value=-100,
                        )
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(
                        torch.ones_like(ele.input_ids),
                        (max_seq_length - len(ele.input_ids), 0),
                        value=0,
                    )
                    for ele in processed_batch
                ]
            )
        else:
            input_ids = torch.stack(
                [
                    F.pad(
                        ele.input_ids,
                        (0, max_seq_length - len(ele.input_ids)),
                        value=self.pad_token_id,
                    )
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(
                            ele.label_ids,
                            (0, max_seq_length - len(ele.label_ids)),
                            value=-100,
                        )
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(
                        torch.ones_like(ele.input_ids),
                        (0, max_seq_length - len(ele.input_ids)),
                        value=0,
                    )
                    for ele in processed_batch
                ]
            )
        if not self.return_audio_in_tokens:
            audio_in_ids = None
            audio_in_ids_start = None

        # Apply the audio_num_codebooks limit if specified
        if self.audio_num_codebooks is not None:
            if audio_in_ids is not None:
                audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
            if audio_out_ids is not None:
                audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
            if label_audio_ids is not None:
                label_audio_ids = label_audio_ids[: self.audio_num_codebooks]

        return HiggsAudioBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            audio_features=audio_features,
            audio_feature_attention_mask=audio_feature_attention_mask,
            audio_out_ids=audio_out_ids,
            audio_out_ids_start=audio_out_ids_start,
            audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
            audio_in_ids=audio_in_ids,
            audio_in_ids_start=audio_in_ids_start,
            label_ids=label_ids,
            label_audio_ids=label_audio_ids,
            reward=reward,
        )
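
# A rough usage sketch (illustrative only): `train_dataset` is assumed to yield ChatMLDatasetSample objects and
# `collator` to be a configured HiggsAudioSampleCollator.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(train_dataset, batch_size=4, collate_fn=collator)
#     for batch in loader:  # each `batch` is a HiggsAudioBatchInput
#         ...
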
class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
        # Flatten the ranked ChatML samples into chosen samples followed by rejected samples.
        chosen = []
        rejected = []
        for sample in batch:
            chosen.append(sample.max_score_sample())
            rejected.append(sample.min_score_sample())
        merged = chosen
        merged.extend(rejected)
        return super().__call__(batch=merged)
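
# Note: with B ranked tuples in the input batch, rows [0, B) of the collated `input_ids` correspond to the chosen
# (max-score) samples and rows [B, 2B) to the rejected (min-score) samples, following the chosen-then-rejected
# ordering of `merged` above; a downstream DPO loss would typically split the batch back into these two halves.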