Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # @Time : 2022/2/17 6:05 下午 | |
| # @Author : JianingWang | |
| # @File : loss | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
class FocalLoss(nn.Module):
    """Multi-class focal loss.

    Every class log-probability is scaled by ``(1 - p) ** gamma`` before being
    passed to ``F.nll_loss``, so examples that are already predicted with high
    probability contribute less to the loss. With ``gamma=0`` the scaling
    factor is 1 and the result equals ``F.cross_entropy``.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor.
        weight: Optional per-class rescaling weights (one value per class),
            forwarded to ``F.nll_loss``. May be a tensor or a plain sequence
            of numbers; it is stored as given.
        ignore_index: Target value whose rows are excluded from the loss.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Compute the focal loss for a batch.

        Args:
            input: Unnormalised class scores (logits) of shape ``[N, C]``.
            target: Class indices of shape ``[N]``.

        Returns:
            The loss reduced with ``F.nll_loss``'s default reduction.
        """
        log_probs = F.log_softmax(input, dim=1)
        probs = torch.exp(log_probs)
        # Down-weight well-classified entries: (1 - p)^gamma * log(p).
        focal_log_probs = (1 - probs) ** self.gamma * log_probs

        weight = self.weight
        if weight is not None:
            # ``self.weight`` is a plain attribute, so ``module.to(...)`` never
            # moves or casts it. Align it with the log-probabilities here so
            # nll_loss does not fail on a device/dtype mismatch; this also lets
            # callers pass the weights as a list or tuple.
            weight = torch.as_tensor(
                weight,
                dtype=focal_log_probs.dtype,
                device=focal_log_probs.device,
            )

        return F.nll_loss(
            focal_log_probs,
            target,
            weight,
            ignore_index=self.ignore_index,
        )