# NOTE: the original "Spaces: / Runtime error" lines were page-header residue
# from the hosting site scrape, not part of this module.
import torch
import torch.nn as nn

from .base import BaseLosses
class CommitLoss(nn.Module):
    """Identity wrapper for a pre-computed loss.

    Lets a commitment (or LM) loss that was already computed upstream be
    registered and weighted like any other loss function in BaseLosses.
    """

    def __init__(self, **kwargs):
        # Stateless; config kwargs are accepted and ignored.
        super().__init__()

    def forward(self, commit, commit2, **kwargs):
        # Pass the already-computed loss straight through; the second
        # argument exists only to match the (pred, target) call shape.
        return commit
class GPTLosses(BaseLosses):
    """Per-stage loss collection.

    Stage "vae": feature reconstruction, joint-velocity reconstruction and
    VQ commitment losses. Stages "lm_pretrain"/"lm_instruct": the language
    model's own cross-entropy loss ("gpt_loss") passed through unchanged.
    """

    def __init__(self, cfg, stage, num_joints, **kwargs):
        """Build the stage's loss names, weights and loss functions.

        Args:
            cfg: config node; reads the cfg.LOSS.LAMBDA_* weights and
                cfg.LOSS.ABLATION.RECONS_LOSS ("l1" | "l2" | "l1_smooth").
            stage: one of "vae", "lm_pretrain", "lm_instruct".
            num_joints: joint count, used for the velocity slice in update().

        Raises:
            NotImplementedError: for an unknown reconstruction-loss name or
                an unknown loss name.
        """
        self.stage = stage
        recons_loss = cfg.LOSS.ABLATION.RECONS_LOSS

        # Per-stage loss names and their scalar weights.
        losses = []
        params = {}
        if stage == "vae":
            losses.append("recons_feature")
            params['recons_feature'] = cfg.LOSS.LAMBDA_FEATURE
            losses.append("recons_velocity")
            params['recons_velocity'] = cfg.LOSS.LAMBDA_VELOCITY
            losses.append("vq_commit")
            params['vq_commit'] = cfg.LOSS.LAMBDA_COMMIT
        elif stage in ["lm_pretrain", "lm_instruct"]:
            losses.append("gpt_loss")
            params['gpt_loss'] = cfg.LOSS.LAMBDA_CLS

        # Map each loss name to the class implementing it.
        losses_func = {}
        for loss in losses:
            if loss.split('_')[0] == 'recons':
                if recons_loss == "l1":
                    losses_func[loss] = nn.L1Loss
                elif recons_loss == "l2":
                    losses_func[loss] = nn.MSELoss
                elif recons_loss == "l1_smooth":
                    losses_func[loss] = nn.SmoothL1Loss
                else:
                    # Fail fast here instead of a confusing KeyError later.
                    raise NotImplementedError(
                        f"Reconstruction loss {recons_loss} not implemented.")
            elif loss.split('_')[1] in [
                    'commit', 'loss', 'gpt', 'm2t2m', 't2m2t'
            ]:
                # These losses arrive pre-computed; wrap with the identity.
                losses_func[loss] = CommitLoss
            elif loss.split('_')[1] in ['cls', 'lm']:
                losses_func[loss] = nn.CrossEntropyLoss
            else:
                raise NotImplementedError(f"Loss {loss} not implemented.")

        super().__init__(cfg, losses, params, losses_func, num_joints,
                         **kwargs)

    def update(self, rs_set):
        '''Accumulate the weighted losses for one result set.

        Args:
            rs_set: dict of stage outputs; for "vae" it must hold 'm_rst',
                'm_ref' and 'loss_commit', for the LM stages 'outputs' with
                a `.loss` attribute — TODO confirm against the caller.

        Returns:
            The weighted total loss for this batch (tensor, or 0.0 when the
            stage registers no losses).
        '''
        total = 0.0
        if self.stage in ["vae"]:
            total += self._update_loss("recons_feature", rs_set['m_rst'],
                                       rs_set['m_ref'])
            nfeats = rs_set['m_rst'].shape[-1]
            # Feature layouts handled: 263, or 135 + 263; presumably the
            # joint velocities start right after the 4 root channels and
            # span (num_joints - 1) * 3 values — verify against the dataset.
            if nfeats in [263, 135 + 263]:
                if nfeats == 135 + 263:
                    vel_start = 135 + 4
                elif nfeats == 263:
                    vel_start = 4
                vel_end = vel_start + (self.num_joints - 1) * 3
                total += self._update_loss(
                    "recons_velocity",
                    rs_set['m_rst'][..., vel_start:vel_end],
                    rs_set['m_ref'][..., vel_start:vel_end])
            else:
                # Unknown layout: only an error if velocity loss is enabled.
                if self._params['recons_velocity'] != 0.0:
                    raise NotImplementedError(
                        "Velocity not implemented for nfeats = {}".format(
                            nfeats))
            total += self._update_loss("vq_commit", rs_set['loss_commit'],
                                       rs_set['loss_commit'])
        if self.stage in ["lm_pretrain", "lm_instruct"]:
            total += self._update_loss("gpt_loss", rs_set['outputs'].loss,
                                       rs_set['outputs'].loss)

        # Running total for epoch averaging; detach only when `total` is a
        # tensor (it stays the float 0.0 for stages that add no losses,
        # which would otherwise crash on .detach()).
        self.total += total.detach() if torch.is_tensor(total) else total
        self.count += 1
        return total