Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class BCELoss(nn.Module):
    """Binary cross-entropy computed directly on raw logits.

    ``forward`` follows the ``(loss, log_dict)`` return convention; this loss
    has nothing extra to log, so the dict is always empty.
    """

    def forward(self, prediction, target):
        """Return the BCE between logits ``prediction`` and ``target``.

        Args:
            prediction: Unnormalized logits (sigmoid is applied internally by
                ``F.binary_cross_entropy_with_logits``).
            target: Target values; PyTorch requires the same size as
                ``prediction``.

        Returns:
            A tuple ``(loss, {})`` where ``loss`` uses PyTorch's default
            (mean) reduction.
        """
        empty_logs = {}
        bce = F.binary_cross_entropy_with_logits(prediction, target)
        return bce, empty_logs
class BCELossWithQuant(nn.Module):
    """BCE on logits plus a weighted quantization (codebook) loss term.

    The total loss is ``bce(prediction, target) + codebook_weight * qloss``.
    ``forward`` returns ``(loss, log_dict)``; the dict holds detached scalar
    summaries keyed by ``"<split>/<name>"``.
    """

    def __init__(self, codebook_weight=1.):
        super().__init__()
        # Multiplier applied to ``qloss`` before it is added to the BCE term.
        self.codebook_weight = codebook_weight

    def forward(self, qloss, target, prediction, split):
        """Combine the BCE term with the weighted quantization loss.

        Args:
            qloss: Quantization/codebook loss. Presumably a tensor, since
                ``.detach()`` is called on it; TODO confirm with callers.
            target: Target values for the BCE term.
            prediction: Unnormalized logits for the BCE term.
            split: Label prepended to every log key (e.g. a dataset split
                name); formatted via ``str.format``.

        Returns:
            ``(loss, log_dict)`` where ``loss`` still carries gradients and
            ``log_dict`` holds detached means of the total, BCE, and
            quantization losses.
        """
        bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
        weighted_qloss = self.codebook_weight * qloss
        loss = bce_loss + weighted_qloss

        # Detached copies so logging never keeps the autograd graph alive.
        total_logged = loss.clone().detach().mean()
        bce_logged = bce_loss.detach().mean()
        quant_logged = qloss.detach().mean()

        log = {
            "{}/total_loss".format(split): total_logged,
            "{}/bce_loss".format(split): bce_logged,
            "{}/quant_loss".format(split): quant_logged,
        }
        return loss, log