Spaces:
Running
on
Zero
Running
on
Zero
| from abc import abstractmethod | |
| from typing import Any, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ....modules.distributions.distributions import \ | |
| DiagonalGaussianDistribution | |
| from .base import AbstractRegularizer | |
class DiagonalGaussianRegularizer(AbstractRegularizer):
    """Regularizer that wraps its input in a ``DiagonalGaussianDistribution``.

    The forward pass builds a posterior from the incoming tensor, returns
    either a sample or the mode of that posterior, and reports a KL term
    in the accompanying log dictionary.
    """

    def __init__(self, sample: bool = True):
        """Store the sampling switch.

        Args:
            sample: If true, ``forward`` returns ``posterior.sample()``;
                otherwise it returns ``posterior.mode()``.
        """
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        """Yield nothing: this regularizer exposes no trainable parameters."""
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Regularize ``z`` and report the KL term.

        Args:
            z: Tensor handed unchanged to ``DiagonalGaussianDistribution``.
                NOTE(review): presumably the concatenated distribution
                parameters -- confirm against that class.

        Returns:
            A ``(latent, log)`` tuple where ``latent`` is the posterior
            sample or mode, and ``log["kl_loss"]`` is the total of
            ``posterior.kl()`` divided by the size of its first dimension.
        """
        posterior = DiagonalGaussianDistribution(z)
        latent = posterior.sample() if self.sample else posterior.mode()

        # Sum every KL entry, then normalize by the length of the first
        # dimension (presumably the batch size -- verify against kl()).
        kl_terms = posterior.kl()
        kl_value = torch.sum(kl_terms) / kl_terms.shape[0]

        return latent, {"kl_loss": kl_value}