import torch
import copy
from torch.nn import functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.container import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

from .attention import MultiheadAttention
from .transformer import _get_activation_fn

class TransformerEncoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerEncoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayerImproved, self).__setstate__(state)

    def forward(self, src, memory2=None, src_mask=None, src_key_padding_mask=None):
        # Pre-norm self-attention block with residual connection
        src1 = self.norm1(src)
        src2 = self.self_attn(src1, src1, src1, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)

        # Optional global conditioning, added as a residual (broadcast over the sequence dimension)
        if memory2 is not None:
            src2_2 = self.linear_global2(memory2)
            src = src + self.dropout2_2(src2_2)

        # Pre-norm feed-forward block with residual connection
        src1 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src1))))
        src = src + self.dropout2(src2)
        return src
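
# Minimal usage sketch (illustration only, not part of the original module). It
# assumes the local MultiheadAttention follows the (seq_len, batch, d_model)
# interface of torch.nn.MultiheadAttention, and that the global feature `memory2`
# is shaped (1, batch, d_global2) so the residual add broadcasts over seq_len.
def _example_encoder_layer():
    d_model, nhead, d_global2 = 256, 8, 64
    layer = TransformerEncoderLayerImproved(d_model, nhead, dim_feedforward=512, d_global2=d_global2)
    src = torch.randn(10, 4, d_model)       # (seq_len, batch, d_model)
    memory2 = torch.randn(1, 4, d_global2)  # global conditioning vector per batch element
    out = layer(src, memory2=memory2)       # same shape as src: (10, 4, d_model)
    return out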

class TransformerDecoderLayerImproved(Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayerImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerImproved, self).__setstate__(state)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # Pre-norm self-attention block with residual connection
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Pre-norm cross-attention over the encoder memory
        tgt1 = self.norm2(tgt)
        tgt2 = self.multihead_attn(tgt1, memory, memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)

        # Pre-norm feed-forward block with residual connection
        tgt1 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
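
# Minimal usage sketch (illustration only, not part of the original module). It
# assumes the same (seq_len, batch, d_model) tensor layout as torch.nn.Transformer;
# the causal mask below is the standard additive upper-triangular -inf mask.
def _example_decoder_layer():
    d_model, nhead = 256, 8
    layer = TransformerDecoderLayerImproved(d_model, nhead, dim_feedforward=512)
    tgt = torch.randn(12, 4, d_model)     # (tgt_len, batch, d_model)
    memory = torch.randn(20, 4, d_model)  # encoder output, (src_len, batch, d_model)
    tgt_mask = torch.triu(torch.full((12, 12), float("-inf")), diagonal=1)
    out = layer(tgt, memory, tgt_mask=tgt_mask)  # same shape as tgt
    return out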

class TransformerDecoderLayerGlobalImproved(Module):
    def __init__(self, d_model, d_global, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", d_global2=None):
        super(TransformerDecoderLayerGlobalImproved, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)

        self.linear_global = Linear(d_global, d_model)

        if d_global2 is not None:
            self.linear_global2 = Linear(d_global2, d_model)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout2_2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerDecoderLayerGlobalImproved, self).__setstate__(state)

    def forward(self, tgt, memory, memory2=None, tgt_mask=None, tgt_key_padding_mask=None, *args, **kwargs):
        # Pre-norm self-attention block with residual connection
        tgt1 = self.norm1(tgt)
        tgt2 = self.self_attn(tgt1, tgt1, tgt1, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)

        # Global conditioning: project the global memory to d_model and add it as a residual
        tgt2 = self.linear_global(memory)
        tgt = tgt + self.dropout2(tgt2)  # implicit broadcast over the target sequence dimension

        if memory2 is not None:
            tgt2_2 = self.linear_global2(memory2)
            tgt = tgt + self.dropout2_2(tgt2_2)

        # Pre-norm feed-forward block with residual connection
        tgt1 = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt1))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt
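
# Minimal usage sketch (illustration only, not part of the original module). Based
# on the "implicit broadcast" comment above, it assumes `memory` is a global latent
# shaped (1, batch, d_global) that is projected to d_model and added to every
# target position.
def _example_global_decoder_layer():
    d_model, d_global, nhead = 256, 128, 8
    layer = TransformerDecoderLayerGlobalImproved(d_model, d_global, nhead, dim_feedforward=512)
    tgt = torch.randn(12, 4, d_model)     # (tgt_len, batch, d_model)
    memory = torch.randn(1, 4, d_global)  # global latent per batch element
    out = layer(tgt, memory)              # same shape as tgt
    return out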