Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
class BaseLosses(nn.Module):
    """Accumulate named, weighted losses and report their averages for logging.

    Each loss name gets a scalar buffer that accumulates the detached,
    unweighted value returned by its loss function. The module also registers
    a ``"total"`` accumulator and a ``"count"`` buffer. ``compute`` divides
    every accumulator by ``count``, formats the results for logging, and
    resets all buffers. Incrementing ``count`` and ``total`` is left to
    subclasses; this class never does it.

    Args:
        cfg: Configuration object; not used by this base class.
        losses: Iterable of loss names. It is NOT mutated. ``"total"`` is
            appended to the stored copy, and moved to the end if the caller
            already listed it.
        params: Mapping from loss name to the weight applied in
            ``_update_loss``.
        losses_func: Mapping from loss name to a loss class/factory that
            accepts a ``reduction='mean'`` keyword argument.
        num_joints: Stored as ``self.num_joints``; not used by this class.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, cfg, losses, params, losses_func, num_joints, **kwargs):
        super().__init__()
        # Save parameters
        self.num_joints = num_joints
        self._params = params

        # Work on a copy so the caller's list is never mutated. "total" must be
        # the LAST entry: every other name needs a loss function, while "total"
        # is only an accumulator (see the losses[:-1] slice below).
        losses = [loss for loss in losses if loss != "total"] + ["total"]

        # One scalar accumulator per loss (including "total"), plus a counter.
        for loss in losses:
            self.register_buffer(loss, torch.tensor(0.0))
        self.register_buffer("count", torch.tensor(0.0))
        self.losses = losses

        # Instantiate a loss function for every real loss (all but "total").
        self._losses_func = {
            loss: losses_func[loss](reduction="mean") for loss in losses[:-1]
        }

    def _update_loss(self, loss: str, outputs, inputs):
        """Accumulate one loss term and return its weighted value.

        Args:
            loss: Name of the loss; one of ``self.losses`` other than
                ``"total"``.
            outputs: First argument passed to the loss function.
            inputs: Second argument passed to the loss function.

        Returns:
            ``self._params[loss] * val``, where ``val`` is the loss function's
            result. It is not detached, so it can be back-propagated.
        """
        val = self._losses_func[loss](outputs, inputs)
        # Accumulate a detached copy so the running sum holds no autograd graph.
        getattr(self, loss).add_(val.detach())
        return self._params[loss] * val

    def reset(self):
        """Set every loss accumulator and the counter back to 0."""
        for name in [*self.losses, "count"]:
            # nn.Module.__setattr__ stores a tensor assigned to a registered
            # buffer name back into the module's buffers. Creating it on the
            # old buffer's device keeps it where the module lives.
            device = getattr(self, name).device
            setattr(self, name, torch.tensor(0.0, device=device))

    def compute(self, split):
        """Average the accumulated losses, reset them, and return log values.

        Args:
            split: Label appended to every log name (for example a dataset
                split name).

        Returns:
            Dict mapping log names (see ``loss2logname``) to Python floats.
            Each value is the accumulator divided by ``count``. Entries whose
            average is NaN (e.g. 0/0 when nothing was accumulated) are
            dropped.
        """
        count = self.count
        log_dict = {}
        for loss in self.losses:
            value = getattr(self, loss) / count
            if not torch.isnan(value):
                log_dict[self.loss2logname(loss, split)] = value.item()
        # Start a fresh accumulation window after reporting.
        self.reset()
        return log_dict

    def loss2logname(self, loss: str, split: str):
        """Convert a loss name into a hierarchical log name.

        ``"total"``, and any name without an underscore, becomes
        ``"{loss}/{split}"``. Other names are split at their FIRST underscore:
        ``"recons_feature"`` -> ``"recons/feature/{split}"`` and
        ``"recons_feature_x"`` -> ``"recons/feature_x/{split}"``.
        """
        loss_type, sep, name = loss.partition("_")
        if loss == "total" or not sep:
            return f"{loss}/{split}"
        return f"{loss_type}/{name}/{split}"