| from typing import Dict, Optional, Tuple, Union | |
| from transformers import PretrainedConfig | |
| from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import ( | |
| Qwen2_5OmniTextConfig, | |
| ) | |
class DashengConfig(PretrainedConfig):
    """Configuration for the Dasheng audio encoder used by MiDashengLM.

    Bundles the transformer hyperparameters (embedding width, depth, heads,
    dropout) together with the mel-spectrogram front-end settings (STFT window,
    hop, sample rate, mel bins). Presumably a ViT-style audio encoder over
    log-mel patches — confirm against the model implementation.
    """

    model_type = "midashenglm_dasheng_encoder"

    def __init__(
        self,
        embed_dim: int = 768,
        outputdim: int = 527,
        patch_size: Union[int, Tuple[int, int]] = 16,
        patch_stride: Union[int, Tuple[int, int]] = 16,
        input_channels: int = 1,
        target_length: int = 1012,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        init_values: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        f_min: float = 0.0,
        f_max: float = 8000.0,
        center: bool = True,
        win_length: int = 512,
        hop_length: int = 160,
        sample_rate: int = 16000,
        n_fft: int = 512,
        n_mels: int = 64,
        **kwargs,
    ):
        """Store every hyperparameter on the instance, then defer to the base.

        Args:
            embed_dim: Transformer embedding dimension.
            outputdim: Size of the output projection (527 matches the
                AudioSet label count — presumably; verify against usage).
            patch_size: Patch extent, scalar or (freq, time) pair.
            patch_stride: Patch stride, scalar or (freq, time) pair.
            input_channels: Number of input channels of the spectrogram.
            target_length: Target number of time frames — assumed; confirm.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_ratio: MLP hidden size as a multiple of ``embed_dim``.
            qkv_bias: Whether QKV projections carry a bias term.
            init_values: Optional layer-scale init value (``None`` disables).
            drop_rate: General dropout probability.
            attn_drop_rate: Attention dropout probability.
            f_min: Lowest mel filterbank frequency (Hz).
            f_max: Highest mel filterbank frequency (Hz).
            center: Whether the STFT is center-padded.
            win_length: STFT window length in samples.
            hop_length: STFT hop length in samples.
            sample_rate: Expected audio sample rate (Hz).
            n_fft: FFT size.
            n_mels: Number of mel bins.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        # setattr is behaviorally identical to direct attribute assignment
        # (both route through PretrainedConfig.__setattr__); the dict keeps
        # the original assignment order via insertion order.
        hyperparams = {
            "embed_dim": embed_dim,
            "outputdim": outputdim,
            "patch_size": patch_size,
            "patch_stride": patch_stride,
            "input_channels": input_channels,
            "target_length": target_length,
            "depth": depth,
            "num_heads": num_heads,
            "mlp_ratio": mlp_ratio,
            "qkv_bias": qkv_bias,
            "init_values": init_values,
            "drop_rate": drop_rate,
            "attn_drop_rate": attn_drop_rate,
            "f_min": f_min,
            "f_max": f_max,
            "center": center,
            "win_length": win_length,
            "hop_length": hop_length,
            "sample_rate": sample_rate,
            "n_fft": n_fft,
            "n_mels": n_mels,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)
        # Base-class init runs last, consuming any remaining HF config kwargs,
        # matching the original statement order.
        super().__init__(**kwargs)
class MiDashengLMConfig(PretrainedConfig):
    """Top-level configuration for MiDashengLM.

    Composes a :class:`DashengConfig` audio encoder with a
    ``Qwen2_5OmniTextConfig`` text model, plus the glue parameters
    (frame subsampling factor and the placeholder audio token id).
    """

    model_type = "midashenglm"

    def __init__(
        self,
        audio_encoder_config: Optional[Dict] = None,
        subsample_factor: int = 5,
        text_config: Optional[Dict] = None,
        audio_token_id: Optional[int] = None,
        **kwargs,
    ):
        """Build the composite config.

        Args:
            audio_encoder_config: Keyword arguments for :class:`DashengConfig`;
                ``None`` (or an empty dict) yields a default-initialized
                encoder config.
            subsample_factor: Presumably the temporal subsampling applied to
                encoder output frames before the text model — confirm against
                the model code.
            text_config: Keyword arguments for ``Qwen2_5OmniTextConfig``;
                ``None`` (or an empty dict) yields the default text config.
            audio_token_id: Vocabulary id of the audio placeholder token,
                if any.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Note:
            The dict parameters default to ``None`` sentinels instead of the
            original mutable ``{}`` defaults (shared-state anti-pattern);
            behavior for all callers is unchanged since empty dicts and
            ``None`` both produce default sub-configs.
        """
        self.audio_encoder_config = DashengConfig(**(audio_encoder_config or {}))
        self.subsample_factor = subsample_factor
        self.text_config = Qwen2_5OmniTextConfig(**(text_config or {}))
        self.audio_token_id = audio_token_id
        super().__init__(**kwargs)
