import torch

class PrefixEncoder(torch.nn.Module):
    r"""
    The torch.nn module that encodes the prefix.

    Input shape: (batch_size, prefix_length)
    Output shape: (batch_size, prefix_length, 2 * num_hidden_layers * hidden_size)
    """
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size),
            )
        else:
            # Learn the flat per-position past_key_values embedding directly
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)
    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)  # [batch_size, prefix_length, hidden_size]
            past_key_values = self.trans(prefix_tokens)  # [batch_size, prefix_length, 2 * num_hidden_layers * hidden_size]
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
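
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of driving PrefixEncoder end to end. The config fields and
# their values below are hypothetical placeholders, and the final view() mirrors
# how prefix-tuning implementations commonly split the flat output into per-layer
# key/value tensors; the exact downstream wiring depends on the host model.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        prefix_projection=True,    # reparameterize with the two-layer MLP
        pre_seq_len=16,            # number of prefix tokens (hypothetical value)
        hidden_size=768,           # model hidden size (hypothetical value)
        prefix_hidden_size=512,    # MLP bottleneck width (hypothetical value)
        num_hidden_layers=12,      # number of transformer layers (hypothetical value)
        num_attention_heads=12,    # attention heads, used only for the reshape below
    )

    encoder = PrefixEncoder(config)

    batch_size = 4
    # Prefix token ids 0..pre_seq_len-1, repeated for every example in the batch.
    prefix = torch.arange(config.pre_seq_len).unsqueeze(0).expand(batch_size, -1)

    past_key_values = encoder(prefix)
    print(past_key_values.shape)  # torch.Size([4, 16, 18432]) = (batch, prefix_len, 2 * layers * hidden)

    # Commonly the flat output is then reshaped into per-layer key/value tensors:
    head_dim = config.hidden_size // config.num_attention_heads
    past_key_values = past_key_values.view(
        batch_size,
        config.pre_seq_len,
        config.num_hidden_layers * 2,
        config.num_attention_heads,
        head_dim,
    )
    print(past_key_values.shape)  # torch.Size([4, 16, 24, 12, 64])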