from transformers import BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import logging
# layers_ours supplies relevance-propagation-aware drop-in replacements for
# standard layers (e.g. the Linear and Dropout used below).
from BERT_explainability.modules.layers_ours import *
from BERT_explainability.modules.BERT.BERT import BertModel
from torch.nn import CrossEntropyLoss, MSELoss
import torch.nn as nn
from typing import List, Any
import torch
from BERT_rationale_benchmark.models.model_utils import PaddedSequence


class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = Dropout(config.hidden_dropout_prob)
        self.classifier = Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def relprop(self, cam=None, **kwargs):
        # Propagate relevance backwards through the classification head and into the encoder.
        cam = self.classifier.relprop(cam, **kwargs)
        cam = self.dropout.relprop(cam, **kwargs)
        cam = self.bert.relprop(cam, **kwargs)
        # print("conservation: ", cam.sum())
        return cam
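

# A minimal, illustrative sketch (not part of the original module) of how the
# relprop call above is typically seeded: run a forward pass, build a one-hot
# relevance vector for the target class, backpropagate it, and then hand that
# vector to relprop. The helper name and the alpha value are assumptions for
# illustration, not an upstream API guarantee.
def example_seed_relprop(model, input_ids, attention_mask, class_index=None):
    logits = model(input_ids=input_ids, attention_mask=attention_mask)[0]
    if class_index is None:
        # default to the predicted class of the first (and only) example
        class_index = logits[0].argmax().item()
    one_hot = torch.zeros_like(logits)
    one_hot[0, class_index] = 1.0
    model.zero_grad()
    (one_hot * logits).sum().backward(retain_graph=True)
    # propagate relevance for the chosen class back through the network
    return model.relprop(one_hot, alpha=1)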


# this is the actual classifier we will be using
class BertClassifier(nn.Module):
    """Thin wrapper around BertForSequenceClassification"""

    def __init__(self,
                 bert_dir: str,
                 pad_token_id: int,
                 cls_token_id: int,
                 sep_token_id: int,
                 num_labels: int,
                 max_length: int = 512,
                 use_half_precision=True):
        super(BertClassifier, self).__init__()
        bert = BertForSequenceClassification.from_pretrained(bert_dir, num_labels=num_labels)
        if use_half_precision:
            import apex
            bert = bert.half()
        self.bert = bert
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.max_length = max_length

    def forward(self,
                query: List[torch.tensor],
                docids: List[Any],
                document_batch: List[torch.tensor]):
        assert len(query) == len(document_batch)
        print(query)
        # note about device management:
        # since distributed training is enabled, the inputs to this module can be on *any* device
        # (preferably cpu, since we wrap and unwrap the module)
        # we want to keep these params on the input device (assuming CPU) for as long as possible for cheap memory access
        target_device = next(self.parameters()).device
        cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
        sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
        input_tensors = []
        position_ids = []
        for q, d in zip(query, document_batch):
            if len(q) + len(d) + 2 > self.max_length:
                d = d[:(self.max_length - len(q) - 2)]
            input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
            position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
        bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
                                            device=target_device)
        positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
        (classes,) = self.bert(bert_input.data,
                               attention_mask=bert_input.mask(on=0.0, off=float('-inf'), device=target_device),
                               position_ids=positions.data)
        assert torch.all(classes == classes)  # for nans
        print(input_tensors[0])
        print(self.relprop()[0])
        return classes

    def relprop(self, cam=None, **kwargs):
        return self.bert.relprop(cam, **kwargs)
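

# An illustrative construction sketch (the default checkpoint name and label
# count below are assumptions, not values shipped with this file): the special
# token ids are taken from the tokenizer so that the [CLS]/[SEP]/padding
# handling in BertClassifier.forward lines up with the model's vocabulary.
def example_build_classifier(bert_dir='bert-base-uncased', num_labels=2):
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained(bert_dir)
    classifier = BertClassifier(bert_dir=bert_dir,
                                pad_token_id=tokenizer.pad_token_id,
                                cls_token_id=tokenizer.cls_token_id,
                                sep_token_id=tokenizer.sep_token_id,
                                num_labels=num_labels,
                                use_half_precision=False)  # avoid the apex dependency
    return tokenizer, classifier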


if __name__ == '__main__':
    from transformers import BertTokenizer
    import os

    class Config:
        def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, num_labels,
                     hidden_dropout_prob):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.num_labels = num_labels
            self.hidden_dropout_prob = hidden_dropout_prob

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    x = tokenizer.encode_plus("In this movie the acting is great. The movie is perfect! [sep]",
                              add_special_tokens=True,
                              max_length=512,
                              return_token_type_ids=False,
                              return_attention_mask=True,
                              pad_to_max_length=True,
                              return_tensors='pt',
                              truncation=True)
    print(x['input_ids'])

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    model_save_file = os.path.join('./BERT_explainability/output_bert/movies/classifier/', 'classifier.pt')
    model.load_state_dict(torch.load(model_save_file))

    # x = torch.randint(100, (2, 20))
    # x = torch.tensor([[101, 2054, 2003, 1996, 15792, 1997, 2023, 3319, 1029, 102,
    #                    101, 4079, 102, 101, 6732, 102, 101, 2643, 102, 101,
    #                    2038, 102, 101, 1037, 102, 101, 2933, 102, 101, 2005,
    #                    102, 101, 2032, 102, 101, 1010, 102, 101, 1037, 102,
    #                    101, 3800, 102, 101, 2005, 102, 101, 2010, 102, 101,
    #                    2166, 102, 101, 1010, 102, 101, 1998, 102, 101, 2010,
    #                    102, 101, 4650, 102, 101, 1010, 102, 101, 2002, 102,
    #                    101, 2074, 102, 101, 2515, 102, 101, 1050, 102, 101,
    #                    1005, 102, 101, 1056, 102, 101, 2113, 102, 101, 2054,
    #                    102, 101, 1012, 102]])
    # x.requires_grad_()

    model.eval()
    y = model(x['input_ids'], x['attention_mask'])
    print(y)

    cam, _ = model.relprop()
    # print(cam.shape)

    cam = cam.sum(-1)
    # print(cam)