# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

from .help_layers import (TransformerEncoderLayer, GAL, GraphFusionLayer,
                          GraphFusionLayerAtt, MambaBlock, RMSNorm)


class PredictionsFusion(nn.Module):
    def __init__(self, num_matrices=2, num_classes=7):
        super(PredictionsFusion, self).__init__()
        self.weights = nn.Parameter(torch.rand(num_matrices, num_classes))

    def forward(self, pred):
        normalized_weights = torch.softmax(self.weights, dim=0)
        weighted_matrix = sum(mat * normalized_weights[i] for i, mat in enumerate(pred))
        return weighted_matrix
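

# Usage sketch (illustrative only, not part of the training pipeline):
# PredictionsFusion learns per-class softmax weights over a fixed number of
# prediction matrices, so fusing three (batch, num_classes) logit tensors is a
# learned element-wise weighted sum.
def _demo_predictions_fusion():
    fusion = PredictionsFusion(num_matrices=3, num_classes=7)
    preds = [torch.randn(8, 7) for _ in range(3)]  # e.g. audio, text and fused logits
    fused = fusion(preds)
    assert fused.shape == (8, 7)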


class MultiModalTransformer_v3(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512,
                 num_transformer_heads=2, num_graph_heads=2, seg_len=44, positional_encoding=True,
                 dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(MultiModalTransformer_v3, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        # Projection layers
        # self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
        # self.audio_proj = nn.Sequential(
        #     nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
        #     nn.LayerNorm(hidden_dim),
        #     nn.Dropout(dropout)
        # )
        self.audio_proj = nn.Sequential(
            nn.Conv1d(audio_dim, hidden_dim, 1),
            nn.GELU(),
        )
        self.text_proj = nn.Sequential(
            nn.Conv1d(text_dim, hidden_dim, 1),
            nn.GELU(),
        )
        # self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
        # self.text_proj = nn.Sequential(
        #     nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
        #     nn.LayerNorm(hidden_dim),
        #     nn.Dropout(dropout)
        # )
        # Attention mechanisms
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        # Classifier
        # self.classifier = nn.Sequential(
        #     nn.Linear(hidden_dim*2, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*4, out_features),
        #     nn.ReLU(),
        #     nn.Linear(out_features, num_classes)
        # )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim*2, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*4, out_features),
            # nn.LayerNorm(out_features),
            # nn.GELU(),
            # nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        # self._init_weights()

    def forward(self, audio_features, text_features):
        # Dtype conversion and projection
        audio_features = audio_features.float()
        text_features = text_features.float()
        # audio_features = self.audio_proj(audio_features)
        # text_features = self.text_proj(text_features)
        # Conv1d projections expect (batch, channels, seq), hence the permutes
        audio_features = self.audio_proj(audio_features.permute(0, 2, 1)).permute(0, 2, 1)
        text_features = self.text_proj(text_features.permute(0, 2, 1)).permute(0, 2, 1)
        # Adaptive pooling down to the shorter sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        # Transformer blocks (cross-modal attention with residual connections;
        # out-of-place adds, since in-place += can break autograd on saved tensors)
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text
        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)
        # Classification
        if self.mode == 'mean':
            # Fixed: the original concatenated mean_audio with itself, which left
            # the text branch unused in 'mean' mode.
            return self.classifier(torch.cat([mean_audio, mean_text], dim=1))
        else:
            return self.classifier(torch.cat([mean_audio, std_audio, mean_text, std_text], dim=1))

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
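

# Minimal smoke test (a sketch; it assumes the relative import of help_layers
# resolves, i.e. this module runs as part of its package): unequal-length dummy
# sequences are pooled to the shorter length inside forward().
def _demo_multimodal_transformer_v3():
    model = MultiModalTransformer_v3(audio_dim=1024, text_dim=1024, num_classes=7)
    audio = torch.randn(2, 50, 1024)  # (batch, seq_audio, audio_dim)
    text = torch.randn(2, 44, 1024)   # (batch, seq_text, text_dim)
    assert model(audio, text).shape == (2, 7)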


class MultiModalTransformer_v4(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512,
                 num_transformer_heads=2, num_graph_heads=2, seg_len=44, positional_encoding=True,
                 dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(MultiModalTransformer_v4, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        # Projection layers
        self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
        self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
        # Attention mechanisms
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        # Graph fusion instead of GAL
        if self.mode == 'mean':
            self.graph_fusion = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
        else:
            self.graph_fusion = GraphFusionLayer(hidden_dim*2, heads=num_graph_heads)
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*2, out_features),
            nn.ReLU(),
            nn.Linear(out_features, num_classes)
        )

    def forward(self, audio_features, text_features):
        # Dtype conversion and projection
        audio_features = audio_features.float()
        text_features = text_features.float()
        audio_features = self.audio_proj(audio_features)
        text_features = self.text_proj(text_features)
        # Adaptive pooling down to the shorter sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        # Transformer blocks (cross-modal attention with residual connections)
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text
        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)
        # Graph fusion of the pooled statistics
        if self.mode == 'mean':
            h_ta = self.graph_fusion(mean_audio, mean_text)
        else:
            h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1),
                                     torch.cat([mean_text, std_text], dim=1))
        # Classification
        return self.classifier(h_ta)


class MultiModalTransformer_v5(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, hidden_dim_gated=512,
                 num_transformer_heads=2, num_graph_heads=2, seg_len=44, tr_layer_number=1,
                 positional_encoding=True, dropout=0, mode='mean', device="cuda", out_features=128, num_classes=7):
        super(MultiModalTransformer_v5, self).__init__()
        self.hidden_dim = hidden_dim
        self.mode = mode
        # Projection to a common dimensionality (adaptive projections)
        self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
        self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
        # Attention mechanisms
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_transformer_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        # Gated attention
        if self.mode == 'mean':
            self.gal = GAL(hidden_dim, hidden_dim, hidden_dim_gated)
        else:
            self.gal = GAL(hidden_dim*2, hidden_dim*2, hidden_dim_gated)
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, out_features),
            nn.ReLU(),
            nn.Linear(out_features, num_classes)
        )

    def forward(self, audio_features, text_features):
        bs, seq_audio, audio_feat_dim = audio_features.shape
        bs, seq_text, text_feat_dim = text_features.shape
        text_features = text_features.to(torch.float32)
        audio_features = audio_features.to(torch.float32)
        # Project to the common dimensionality
        audio_features = self.audio_proj(audio_features)  # (bs, seq_audio, hidden_dim)
        text_features = self.text_proj(text_features)     # (bs, seq_text, hidden_dim)
        # Determine the minimum sequence length
        min_seq_len = min(seq_audio, seq_text)
        # Average down to the minimum length (the feature axis already equals
        # hidden_dim, so this 2D pooling only resizes the temporal axis)
        audio_features = F.adaptive_avg_pool2d(audio_features.permute(0, 2, 1), (self.hidden_dim, min_seq_len)).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool2d(text_features.permute(0, 2, 1), (self.hidden_dim, min_seq_len)).permute(0, 2, 1)
        # Transformer blocks
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text
        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)
        # # Gated attention (earlier hand-rolled variant, kept for reference)
        # h_audio = torch.tanh(self.Wa(torch.cat([min_audio, std_audio], dim=1)))
        # h_text = torch.tanh(self.Wt(torch.cat([min_text, std_text], dim=1)))
        # z_ta = torch.sigmoid(self.W_at(torch.cat([min_audio, std_audio, min_text, std_text], dim=1)))
        # h_ta = z_ta * h_text + (1 - z_ta) * h_audio
        if self.mode == 'mean':
            h_ta = self.gal(mean_audio, mean_text)
        else:
            h_ta = self.gal(torch.cat([mean_audio, std_audio], dim=1), torch.cat([mean_text, std_text], dim=1))
        # Classification
        output = self.classifier(h_ta)
        return output


class MultiModalTransformer_v7(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, hidden_dim=512, num_heads=2, positional_encoding=True,
                 dropout=0, mode='mean', device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(MultiModalTransformer_v7, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        # Projection layers
        self.audio_proj = nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity()
        self.text_proj = nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity()
        # Attention mechanisms
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(input_dim=hidden_dim, num_heads=num_heads,
                                    positional_encoding=positional_encoding, dropout=dropout)
            for _ in range(tr_layer_number)
        ])
        # Graph fusion instead of GAL
        if self.mode == 'mean':
            self.graph_fusion = GraphFusionLayerAtt(hidden_dim, heads=num_heads)
        else:
            self.graph_fusion = GraphFusionLayerAtt(hidden_dim*2, heads=num_heads)
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, out_features) if self.mode == 'mean' else nn.Linear(hidden_dim*2, out_features),
            nn.ReLU(),
            nn.Linear(out_features, num_classes)
        )

    def forward(self, audio_features, text_features):
        # Dtype conversion and projection
        audio_features = audio_features.float()
        text_features = text_features.float()
        audio_features = self.audio_proj(audio_features)
        text_features = self.text_proj(text_features)
        # Adaptive pooling down to the shorter sequence length
        min_seq_len = min(audio_features.size(1), text_features.size(1))
        audio_features = F.adaptive_avg_pool1d(audio_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        text_features = F.adaptive_avg_pool1d(text_features.permute(0, 2, 1), min_seq_len).permute(0, 2, 1)
        # Transformer blocks (cross-modal attention with residual connections)
        for i in range(len(self.audio_to_text_attn)):
            attn_audio = self.audio_to_text_attn[i](text_features, audio_features, audio_features)
            attn_text = self.text_to_audio_attn[i](audio_features, text_features, text_features)
            audio_features = audio_features + attn_audio
            text_features = text_features + attn_text
        # Statistics over the temporal axis
        std_audio, mean_audio = torch.std_mean(attn_audio, dim=1)
        std_text, mean_text = torch.std_mean(attn_text, dim=1)
        # Graph fusion of the pooled statistics
        if self.mode == 'mean':
            h_ta = self.graph_fusion(mean_audio, mean_text)
        else:
            # Fixed: the second argument originally concatenated mean_audio with
            # std_text; it should be the text statistics.
            h_ta = self.graph_fusion(torch.cat([mean_audio, std_audio], dim=1),
                                     torch.cat([mean_text, std_text], dim=1))
        # Classification
        return self.classifier(h_ta)


class BiFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
                 num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
                 device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(BiFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number
        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # self.audio_proj = nn.Sequential(
        #     nn.Conv1d(audio_dim, hidden_dim, 1),
        #     nn.GELU(),
        # )
        # self.text_proj = nn.Sequential(
        #     nn.Conv1d(text_dim, hidden_dim, 1),
        #     nn.GELU(),
        # )
        # Transformer layers (keeping the original implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        # Automatically compute the classifier input dimension
        self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier."""
        # Dry run through the pooling with dummy data
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool = self._pool_features(dummy_audio)
        text_pool = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool, text_pool], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return torch.cat([mean_temp, mean_feat], dim=1)

    def forward(self, audio_features, text_features):
        # Feature projection
        # audio = self.audio_proj(audio_features.permute(0, 2, 1)).permute(0, 2, 1)
        # text = self.text_proj(text_features.permute(0, 2, 1)).permute(0, 2, 1)
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())
        # Adaptive pooling
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)
        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text
        # Feature aggregation
        audio_pool = self._pool_features(audio)
        text_pool = self._pool_features(text)
        # Classification
        features = torch.cat([audio_pool, text_pool], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
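

# Smoke test for BiFormer (a sketch under the default seg_len=44). Note that
# _calculate_classifier_input_dim fixes the classifier input size from seg_len
# at construction time, and _pool_features includes a (batch, seq_len)
# feature-axis mean, so the shorter of the two input sequences must have
# exactly seg_len steps at forward time.
def _demo_biformer():
    model = BiFormer(seg_len=44, hidden_dim=512)
    audio = torch.randn(2, 44, 1024)
    text = torch.randn(2, 60, 1024)  # pooled down to min_len = 44
    assert model(audio, text).shape == (2, 7)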


class BiGraphFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
                 num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
                 device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGraphFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number
        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Transformer layers (keeping the original implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
        # Automatically compute the classifier input dimension
        self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        # Final graph projections
        self.fc_feat = nn.Sequential(
            nn.Linear(self.seg_len, self.seg_len),
            nn.LayerNorm(self.seg_len),
            nn.Dropout(dropout)
        )
        self.fc_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier."""
        # Dry run through the pooling with dummy data
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        # text_pool_temp, _ = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())
        # Adaptive pooling
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)
        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text
        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
        # graph_feat = self.fc_feat(graph_feat)
        # graph_temp = self.fc_temp(graph_temp)
        # Classification
        features = torch.cat([graph_feat, graph_temp], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class BiGatedGraphFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
                 num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
                 device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGatedGraphFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number
        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Transformer layers (keeping the original implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
        self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
        self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
        # Automatically compute the classifier input dimension
        # (result currently unused; the classifier below takes hidden_dim_gated*2)
        self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim_gated*2, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        # Final graph projections
        self.fc_graph_feat = nn.Sequential(
            nn.Linear(self.seg_len, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_graph_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        # Final gated projections
        self.fc_gated_feat = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_gated_temp = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier."""
        # Dry run through the pooling with dummy data
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        # text_pool_temp, _ = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())
        # Adaptive pooling
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)
        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text
        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
        gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
        gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])
        fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
        # Fixed: the temporal branch originally reused fc_gated_feat, leaving
        # fc_gated_temp defined but never applied.
        fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)
        # Classification
        features = torch.cat([fused_feat, fused_temp], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class BiFormerWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
                 num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
                 device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(BiFormerWithProb, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number
        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # self.audio_proj = nn.Sequential(
        #     nn.Conv1d(audio_dim, hidden_dim, 1),
        #     nn.GELU(),
        # )
        # self.text_proj = nn.Sequential(
        #     nn.Conv1d(text_dim, hidden_dim, 1),
        #     nn.GELU(),
        # )
        # Transformer layers (keeping the original implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        # Automatically compute the classifier input dimension
        self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier."""
        # Dry run through the pooling with dummy data
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool = self._pool_features(dummy_audio)
        text_pool = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool, text_pool], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return torch.cat([mean_temp, mean_feat], dim=1)

    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        # audio = self.audio_proj(audio_features.permute(0, 2, 1)).permute(0, 2, 1)
        # text = self.text_proj(text_features.permute(0, 2, 1)).permute(0, 2, 1)
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())
        # Adaptive pooling
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)
        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text
        # Feature aggregation
        audio_pool = self._pool_features(audio)
        text_pool = self._pool_features(text)
        # Classification, then weighted fusion with the unimodal predictions
        features = torch.cat([audio_pool, text_pool], dim=1)
        out = self.classifier(features)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
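

# The *WithProb variants additionally take per-modality prediction matrices and
# blend them with the bimodal logits via PredictionsFusion. A hedged sketch:
def _demo_biformer_with_prob():
    model = BiFormerWithProb(seg_len=44, num_classes=7)
    audio = torch.randn(2, 44, 1024)
    text = torch.randn(2, 44, 1024)
    audio_pred = torch.randn(2, 7)  # e.g. logits from a unimodal audio model
    text_pred = torch.randn(2, 7)   # e.g. logits from a unimodal text model
    assert model(audio, text, audio_pred, text_pred).shape == (2, 7)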


class BiGraphFormerWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
                 num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
                 device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGraphFormerWithProb, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number
        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Transformer layers (keeping the original implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads)
        # Automatically compute the classifier input dimension
        self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.classifier_input_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        # Final graph projections
        self.fc_feat = nn.Sequential(
            nn.Linear(self.seg_len, self.seg_len),
            nn.LayerNorm(self.seg_len),
            nn.Dropout(dropout)
        )
        self.fc_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier."""
        # Dry run through the pooling with dummy data
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        # text_pool_temp, _ = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())
        # Adaptive pooling
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)
        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text
        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
        # graph_feat = self.fc_feat(graph_feat)
        # graph_temp = self.fc_temp(graph_temp)
        # Classification, then weighted fusion with the unimodal predictions
        features = torch.cat([graph_feat, graph_temp], dim=1)
        out = self.classifier(features)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class BiGatedGraphFormerWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
                 num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
                 device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGatedGraphFormerWithProb, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number
        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Transformer layers (keeping the original implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
        self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
        self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
        self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
        # Automatically compute the classifier input dimension
        # (result currently unused; the classifier below takes hidden_dim_gated*2)
        self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim_gated*2, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        # Final graph projections
        self.fc_graph_feat = nn.Sequential(
            nn.Linear(self.seg_len, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_graph_temp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        # Final gated projections
        self.fc_gated_feat = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_gated_temp = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier."""
        # Dry run through the pooling with dummy data
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        # text_pool_temp, _ = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())
        # Adaptive pooling
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)
        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text
        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)
        graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
        gated_feat = self.gated_feat(graph_feat[:, 0, :], graph_feat[:, 1, :])
        gated_temp = self.gated_temp(graph_temp[:, 0, :], graph_temp[:, 1, :])
        fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
        # Fixed: the temporal branch originally reused fc_gated_feat, leaving
        # fc_gated_temp defined but never applied.
        fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)
        # Classification, then weighted fusion with the unimodal predictions
        features = torch.cat([fused_feat, fused_temp], dim=1)
        out = self.classifier(features)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class BiGatedFormer(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, hidden_dim_gated=128,
                 num_transformer_heads=2, num_graph_heads=2, positional_encoding=True, dropout=0.1, mode='mean',
                 device="cuda", tr_layer_number=1, out_features=128, num_classes=7):
        super(BiGatedFormer, self).__init__()
        self.mode = mode
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.tr_layer_number = tr_layer_number
        # Projection layers with normalization
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim) if audio_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim) if text_dim != hidden_dim else nn.Identity(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Transformer layers (keeping the original implementation)
        self.audio_to_text_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        self.text_to_audio_attn = nn.ModuleList([
            TransformerEncoderLayer(
                input_dim=hidden_dim,
                num_heads=num_transformer_heads,
                dropout=dropout,
                positional_encoding=positional_encoding
            ) for _ in range(tr_layer_number)
        ])
        # self.graph_fusion_feat = GraphFusionLayer(self.seg_len, heads=num_graph_heads, out_mean=False)
        # self.graph_fusion_temp = GraphFusionLayer(hidden_dim, heads=num_graph_heads, out_mean=False)
        self.gated_feat = GAL(self.seg_len, self.seg_len, hidden_dim_gated, dropout_rate=dropout)
        self.gated_temp = GAL(hidden_dim, hidden_dim, hidden_dim_gated, dropout_rate=dropout)
        # Automatically compute the classifier input dimension
        # (result currently unused; the classifier below takes hidden_dim_gated*2)
        self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim_gated*2, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        # Final graph projections
        # self.fc_graph_feat = nn.Sequential(
        #     nn.Linear(self.seg_len, hidden_dim_gated),
        #     nn.LayerNorm(hidden_dim_gated),
        #     nn.Dropout(dropout)
        # )
        # self.fc_graph_temp = nn.Sequential(
        #     nn.Linear(hidden_dim, hidden_dim_gated),
        #     nn.LayerNorm(hidden_dim_gated),
        #     nn.Dropout(dropout)
        # )
        # Final gated projections
        self.fc_gated_feat = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self.fc_gated_temp = nn.Sequential(
            nn.Linear(hidden_dim_gated, hidden_dim_gated),
            nn.LayerNorm(hidden_dim_gated),
            nn.Dropout(dropout)
        )
        self._init_weights()

    def _calculate_classifier_input_dim(self):
        """Computes the input feature size for the classifier."""
        # Dry run through the pooling with dummy data
        dummy_audio = torch.randn(1, self.seg_len, self.hidden_dim)
        dummy_text = torch.randn(1, self.seg_len, self.hidden_dim)
        audio_pool_temp, audio_pool_feat = self._pool_features(dummy_audio)
        # text_pool_temp, _ = self._pool_features(dummy_text)
        combined = torch.cat([audio_pool_temp, audio_pool_feat], dim=1)
        self.classifier_input_dim = combined.size(1)

    def _pool_features(self, x):
        # Statistics over the temporal axis (seq_len)
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        # Statistics over the feature axis (hidden_dim)
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        return mean_temp, mean_feat

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())
        text = self.text_proj(text_features.float())
        # Adaptive pooling
        min_len = min(audio.size(1), text.size(1))
        audio = self.adaptive_temporal_pool(audio, min_len)
        text = self.adaptive_temporal_pool(text, min_len)
        # Cross-modal interaction
        for i in range(self.tr_layer_number):
            attn_audio = self.audio_to_text_attn[i](text, audio, audio)
            attn_text = self.text_to_audio_attn[i](audio, text, text)
            audio = audio + attn_audio
            text = text + attn_text
        # Feature aggregation
        audio_pool_temp, audio_pool_feat = self._pool_features(audio)
        text_pool_temp, text_pool_feat = self._pool_features(text)
        # graph_feat = self.graph_fusion_feat(audio_pool_feat, text_pool_feat)
        # graph_temp = self.graph_fusion_temp(audio_pool_temp, text_pool_temp)
        gated_feat = self.gated_feat(audio_pool_feat, text_pool_feat)
        gated_temp = self.gated_temp(audio_pool_temp, text_pool_temp)
        # fused_feat = self.fc_graph_feat(torch.mean(graph_feat, dim=1)) + self.fc_gated_feat(gated_feat)
        # fused_temp = self.fc_graph_temp(torch.mean(graph_temp, dim=1)) + self.fc_gated_temp(gated_temp)
        # Classification
        features = torch.cat([gated_feat, gated_temp], dim=1)
        return self.classifier(features)

    def adaptive_temporal_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class BiMamba(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, mamba_d_state=16,
                 d_discr=None, mamba_ker_size=4, mamba_layer_number=2, dropout=0.1, mode='', positional_encoding=False,
                 out_features=128, num_classes=7, device="cuda"):
        super(BiMamba, self).__init__()
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.num_mamba_layers = mamba_layer_number
        self.device = device
        # Projection layers for each modality
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Modality fusion layer
        self.fusion_proj = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Mamba blocks for processing the fused features
        mamba_params = {
            'd_input': hidden_dim,
            'd_model': hidden_dim,
            'd_state': mamba_d_state,
            'd_discr': d_discr,
            'ker_size': mamba_ker_size
        }
        self.mamba_blocks = nn.ModuleList([
            nn.Sequential(
                MambaBlock(**mamba_params),
                RMSNorm(hidden_dim)
            )
            for _ in range(self.num_mamba_layers)
        ])
        # Automatic computation of the classifier input dimension
        # self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.seg_len + self.hidden_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        self._init_weights()

    # def _calculate_classifier_input_dim(self):
    #     """Computes the input feature size for the classifier."""
    #     dummy = torch.randn(1, self.seg_len, self.hidden_dim)
    #     pooled = self._pool_features(dummy)
    #     self.classifier_input_dim = pooled.size(1)

    def _pool_features(self, x):
        """Concatenates temporal and feature statistics."""
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        full_feature = torch.cat([mean_temp, mean_feat], dim=1)
        if full_feature.shape[-1] == self.seg_len + self.hidden_dim:
            return full_feature
        else:
            # Zero-pad up to the fixed classifier input size when the pooled
            # sequence is shorter than seg_len
            pad_size = self.seg_len + self.hidden_dim - full_feature.shape[-1]
            return F.pad(full_feature, (0, pad_size), mode="constant", value=0)

    def forward(self, audio_features, text_features):
        # Feature projection
        audio = self.audio_proj(audio_features.float())  # [B, T, D]
        text = self.text_proj(text_features.float())     # [B, T, D]
        # Adaptive pooling to the minimum length
        min_len = min(audio.size(1), text.size(1))
        audio = self._adaptive_pool(audio, min_len)
        text = self._adaptive_pool(text, min_len)
        # Fuse modalities
        fused = torch.cat([audio, text], dim=-1)  # [B, T, 2*D]
        fused = self.fusion_proj(fused)           # [B, T, D]
        # Process the fused features through Mamba
        for mamba_block in self.mamba_blocks:
            out, _ = mamba_block[0](fused, None)
            out = mamba_block[1](out)
            fused = fused + out  # Residual connection
        # Feature aggregation and classification
        pooled = self._pool_features(fused)
        return self.classifier(pooled)

    def _adaptive_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
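

# Smoke test for BiMamba (a sketch; it assumes MambaBlock(x, cache) returns
# (out, cache), as used in forward above). Unlike BiFormer, _pool_features
# zero-pads the pooled vector up to seg_len + hidden_dim, so sequences shorter
# than seg_len are tolerated.
def _demo_bimamba():
    model = BiMamba(seg_len=44, hidden_dim=512)
    audio = torch.randn(2, 30, 1024)  # shorter than seg_len is tolerated here
    text = torch.randn(2, 44, 1024)
    assert model(audio, text).shape == (2, 7)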


class BiMambaWithProb(nn.Module):
    def __init__(self, audio_dim=1024, text_dim=1024, seg_len=44, hidden_dim=512, mamba_d_state=16,
                 d_discr=None, mamba_ker_size=4, mamba_layer_number=2, dropout=0.1, mode='', positional_encoding=False,
                 out_features=128, num_classes=7, device="cuda"):
        super(BiMambaWithProb, self).__init__()
        self.hidden_dim = hidden_dim
        self.seg_len = seg_len
        self.num_mamba_layers = mamba_layer_number
        self.device = device
        # Projection layers for each modality
        self.audio_proj = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        self.text_proj = nn.Sequential(
            nn.Linear(text_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Modality fusion layer
        self.fusion_proj = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout)
        )
        # Mamba blocks for processing the fused features
        mamba_params = {
            'd_input': hidden_dim,
            'd_model': hidden_dim,
            'd_state': mamba_d_state,
            'd_discr': d_discr,
            'ker_size': mamba_ker_size
        }
        self.mamba_blocks = nn.ModuleList([
            nn.Sequential(
                MambaBlock(**mamba_params),
                RMSNorm(hidden_dim)
            )
            for _ in range(self.num_mamba_layers)
        ])
        # Automatic computation of the classifier input dimension
        # self._calculate_classifier_input_dim()
        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.seg_len + self.hidden_dim, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, num_classes)
        )
        self.pred_fusion = PredictionsFusion(num_matrices=3, num_classes=num_classes)
        self._init_weights()

    # def _calculate_classifier_input_dim(self):
    #     """Computes the input feature size for the classifier."""
    #     dummy = torch.randn(1, self.seg_len, self.hidden_dim)
    #     pooled = self._pool_features(dummy)
    #     self.classifier_input_dim = pooled.size(1)

    def _pool_features(self, x):
        """Concatenates temporal and feature statistics."""
        mean_temp = x.mean(dim=1)   # [batch, hidden_dim]
        mean_feat = x.mean(dim=-1)  # [batch, seq_len]
        full_feature = torch.cat([mean_temp, mean_feat], dim=1)
        if full_feature.shape[-1] == self.seg_len + self.hidden_dim:
            return full_feature
        else:
            # Zero-pad up to the fixed classifier input size when the pooled
            # sequence is shorter than seg_len
            pad_size = self.seg_len + self.hidden_dim - full_feature.shape[-1]
            return F.pad(full_feature, (0, pad_size), mode="constant", value=0)

    def forward(self, audio_features, text_features, audio_pred, text_pred):
        # Feature projection
        audio = self.audio_proj(audio_features.float())  # [B, T, D]
        text = self.text_proj(text_features.float())     # [B, T, D]
        # Adaptive pooling to the minimum length
        min_len = min(audio.size(1), text.size(1))
        audio = self._adaptive_pool(audio, min_len)
        text = self._adaptive_pool(text, min_len)
        # Fuse modalities
        fused = torch.cat([audio, text], dim=-1)  # [B, T, 2*D]
        fused = self.fusion_proj(fused)           # [B, T, D]
        # Process the fused features through Mamba
        for mamba_block in self.mamba_blocks:
            out, _ = mamba_block[0](fused, None)
            out = mamba_block[1](out)
            fused = fused + out  # Residual connection
        # Feature aggregation and classification, then weighted fusion with the
        # unimodal predictions
        pooled = self._pool_features(fused)
        out = self.classifier(pooled)
        w_out = self.pred_fusion([audio_pred, text_pred, out])
        return w_out

    def _adaptive_pool(self, x, target_len):
        """Adaptively resizes the temporal length."""
        if x.size(1) == target_len:
            return x
        return F.interpolate(
            x.permute(0, 2, 1),
            size=target_len,
            mode='linear',
            align_corners=False
        ).permute(0, 2, 1)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
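

if __name__ == "__main__":
    # Run the hypothetical smoke tests defined above (requires help_layers to be
    # importable, e.g. via `python -m <package>.<this_module>`; shape checks only).
    _demo_predictions_fusion()
    _demo_multimodal_transformer_v3()
    _demo_biformer()
    _demo_biformer_with_prob()
    _demo_bimamba()
    print("all smoke tests passed")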