import torch.nn as nn

from modeling import DocFormerEncoder, ResNetFeatureExtractor, DocFormerEmbeddings, LanguageFeatureExtractor


class DocFormerForClassification(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.resnet = ResNetFeatureExtractor(hidden_dim=config['max_position_embeddings'])
        self.embeddings = DocFormerEmbeddings(config)
        self.lang_emb = LanguageFeatureExtractor()
        self.config = config
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])
        self.linear_layer = nn.Linear(in_features=config['hidden_size'], out_features=16)  ## Number of classes
        self.encoder = DocFormerEncoder(config)
    def forward(self, batch_dict):
        # Spatial (bounding-box) features, token ids and the resized document image
        x_feat = batch_dict['x_features']
        y_feat = batch_dict['y_features']
        token = batch_dict['input_ids']
        img = batch_dict['resized_scaled_img']

        # Spatial embeddings for the visual and language branches
        v_bar_s, t_bar_s = self.embeddings(x_feat, y_feat)
        # Visual features from the ResNet backbone, language features from the text embedder
        v_bar = self.resnet(img)
        t_bar = self.lang_emb(token)

        # Multi-modal encoding, then classification from the first ([CLS]) position;
        # the slice is taken before the projection so the head runs on a single token,
        # with the dropout from __init__ applied ahead of the classifier
        out = self.encoder(t_bar, v_bar, t_bar_s, v_bar_s)
        out = self.dropout(out[:, 0, :])
        out = self.linear_layer(out)
        return out
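## A minimal smoke-test sketch of the inputs this model expects. The config keys
## are the ones read above; the tensor shapes are assumptions for illustration only,
## since the real shapes depend on the modeling.py implementations (not shown here).
if __name__ == "__main__":
    config = {
        'hidden_size': 768,               # assumed encoder width
        'max_position_embeddings': 512,   # assumed sequence length
        'hidden_dropout_prob': 0.1,       # assumed dropout rate
        # ...plus whatever keys DocFormerEmbeddings/DocFormerEncoder read internally
    }
    batch_dict = {
        'x_features': ...,          # per-token x bounding-box features
        'y_features': ...,          # per-token y bounding-box features
        'input_ids': ...,           # tokenized text, e.g. shape (batch, seq_len)
        'resized_scaled_img': ...,  # image tensor, e.g. shape (batch, 3, 224, 224)
    }
    # model = DocFormerForClassification(config)
    # logits = model(batch_dict)    # expected shape: (batch, 16)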
## Defining the PyTorch Lightning model
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, confusion_matrix
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torchmetrics
import wandb
import torch
class DocFormer(pl.LightningModule):

    def __init__(self, config, lr=5e-5):
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        self.docformer = DocFormerForClassification(config)
        self.num_classes = 16

        # torchmetrics >= 0.11 requires the task to be stated explicitly;
        # older versions accepted e.g. torchmetrics.Accuracy() with no arguments
        self.train_accuracy_metric = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)
        self.val_accuracy_metric = torchmetrics.Accuracy(task="multiclass", num_classes=self.num_classes)
        self.f1_metric = torchmetrics.F1Score(task="multiclass", num_classes=self.num_classes)
        self.precision_macro_metric = torchmetrics.Precision(
            task="multiclass", average="macro", num_classes=self.num_classes
        )
        self.recall_macro_metric = torchmetrics.Recall(
            task="multiclass", average="macro", num_classes=self.num_classes
        )
        self.precision_micro_metric = torchmetrics.Precision(
            task="multiclass", average="micro", num_classes=self.num_classes
        )
        self.recall_micro_metric = torchmetrics.Recall(
            task="multiclass", average="micro", num_classes=self.num_classes
        )
    def forward(self, batch_dict):
        logits = self.docformer(batch_dict)
        return logits

    def training_step(self, batch, batch_idx):
        logits = self.forward(batch)
        loss = nn.functional.cross_entropy(logits, batch['label'])
        preds = torch.argmax(logits, dim=1)

        ## Calculating the accuracy score
        train_acc = self.train_accuracy_metric(preds, batch['label'])

        ## Logging
        self.log('train/loss', loss, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log('train/acc', train_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        return loss
    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch)
        loss = nn.functional.cross_entropy(logits, batch['label'])
        preds = torch.argmax(logits, dim=1)
        labels = batch['label']

        # Metrics
        valid_acc = self.val_accuracy_metric(preds, labels)
        precision_macro = self.precision_macro_metric(preds, labels)
        recall_macro = self.recall_macro_metric(preds, labels)
        precision_micro = self.precision_micro_metric(preds, labels)
        recall_micro = self.recall_micro_metric(preds, labels)
        f1 = self.f1_metric(preds, labels)

        # Logging metrics
        self.log("valid/loss", loss, prog_bar=True, on_step=True, on_epoch=True, logger=True)
        self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
        self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
        return {"label": batch['label'], "logits": logits}
    def validation_epoch_end(self, outputs):
        # Lightning 1.x hook; in Lightning >= 2.0 this was replaced by on_validation_epoch_end
        labels = torch.cat([x["label"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
        preds = torch.argmax(logits, dim=1)

        # Log a confusion matrix and ROC curve to Weights & Biases
        # (assumes the Trainer was given a WandbLogger);
        # wandb.plot.roc_curve expects class probabilities, so the logits are softmaxed
        probs = torch.softmax(logits, dim=1)
        wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})
        self.logger.experiment.log(
            {"roc": wandb.plot.roc_curve(labels.cpu().numpy(), probs.cpu().numpy())}
        )
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])
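## A minimal training sketch (left commented out), assuming `config` is the dict
## described above and `train_loader`/`val_loader` are DataLoaders yielding the
## expected batch_dict. WandbLogger and Trainer are standard pytorch_lightning
## APIs; the project name and epoch count are placeholders, not repo values.
# from pytorch_lightning.loggers import WandbLogger
#
# wandb_logger = WandbLogger(project="docformer-classification")  # hypothetical project name
# model = DocFormer(config, lr=5e-5)
# trainer = pl.Trainer(max_epochs=5, logger=wandb_logger)
# trainer.fit(model, train_loader, val_loader)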