```python
import torch
from torch import nn
import torch.nn.functional as F


class MLPProberBase(nn.Module):
    """MLP probe over frozen transformer features.

    When `layer == "all"`, features from every transformer layer are
    combined with a learned softmax-weighted sum before the MLP.
    """

    def __init__(self, d=768, layer='all', num_outputs=87):
        super().__init__()
        self.hidden_layer_sizes = [512]  # eval(self.cfg.hidden_layer_sizes)
        self.num_layers = len(self.hidden_layer_sizes)
        self.layer = layer
        # Stack of hidden linear layers: d -> 512 -> ... -> num_outputs
        for i, ld in enumerate(self.hidden_layer_sizes):
            setattr(self, f"hidden_{i}", nn.Linear(d, ld))
            d = ld
        self.output = nn.Linear(d, num_outputs)
        self.n_transformer_layer = 12
        self.init_aggregator()

    def init_aggregator(self):
        """Initialize the aggregator for a weighted sum over different layers of features."""
        if self.layer == "all":
            # Learned per-layer weights. The two trailing singleton dims let the
            # weights broadcast against the (B, L, T, H) input in forward();
            # a (1, L, 1) parameter would misalign against the chunk dim T.
            self.aggregator = nn.Parameter(
                torch.randn((1, self.n_transformer_layer, 1, 1))
            )

    def forward(self, x):
        """
        x: (B, L, T, H)
        L = #transformer layers, T = #chunks (can be 1 or several), H = hidden size
        """
        if self.layer == "all":
            # Softmax over the layer dimension, then collapse it: (B, T, H)
            weights = F.softmax(self.aggregator, dim=1)
            x = (x * weights).sum(dim=1)
        for i in range(self.num_layers):
            x = getattr(self, f"hidden_{i}")(x)
            # x = self.dropout(x)
            x = F.relu(x)
        output = self.output(x)
        return output
```
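A quick smoke test, continuing from the snippet above. The batch size and chunk count here are made up for illustration; the 12 layers and hidden size 768 match the defaults in `MLPProberBase` (a base-sized transformer):

```python
# Hypothetical example: random values, only the shapes matter.
probe = MLPProberBase(d=768, layer="all", num_outputs=87)

B, L, T, H = 2, 12, 3, 768  # batch, transformer layers, chunks, hidden size
features = torch.randn(B, L, T, H)

logits = probe(features)
print(logits.shape)  # torch.Size([2, 3, 87]): one 87-way prediction per chunk
```

After training, `F.softmax(probe.aggregator, dim=1)` gives the learned contribution of each transformer layer, a common way to inspect which layers carry task-relevant information.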