Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| import torch.nn as nn | |
| from ..utils import log_sum_exp | |
class EncoderBase(nn.Module):
    """Abstract interface for an encoder that parameterizes q(z | x).

    Every method below is meant to be overridden by a subclass; the base
    implementations only raise ``NotImplementedError`` naming the concrete
    class and the missing method. The tensor shapes stated in the
    docstrings are the contract subclasses are expected to follow.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Compute the parameters of the approximate posterior q(z | x).

        Args:
            x: input tensor with shape (batch_size, *).

        Returns:
            The tensors required to parameterize a distribution, e.g. the
            mean and variance tensors for a Gaussian encoder.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward()"
        )

    def sample(self, input, nsamples):
        """Draw latent samples from the encoder.

        Args:
            input: the encoder input (same role as ``x`` in ``forward``).
            nsamples: number of latent samples to draw per input.

        Returns:
            Tensor1: the latent tensor z with shape [batch, nsamples, nz].

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement sample()"
        )

    def encode(self, input, nsamples):
        """Encode the input and compute the KL term.

        Args:
            input: the encoder input (same role as ``x`` in ``forward``).
            nsamples: number of latent samples to draw per input.

        Returns:
            (Tensor1, Tensor2):
                Tensor1: the latent tensor z with shape [batch, nsamples, nz].
                Tensor2: the tensor of KL values for each x, shape [batch].

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement encode()"
        )

    def eval_inference_dist(self, x, z, param=None):
        """Compute log q(z | x).

        Args:
            x: the encoder input.
            z: the latent points to evaluate, shape [batch, nsamples, nz].
            param: optional precomputed distribution parameters; presumably
                the output of ``forward(x)`` so it need not be recomputed.
                NOTE(review): exact semantics are defined by subclasses.

        Returns:
            Tensor1: log q(z | x) with shape [batch, nsamples].

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement eval_inference_dist()"
        )

    def calc_mi(self, x):
        """Approximate the mutual information between x and z.

        I(x, z) = E_x E_{q(z|x)} log q(z|x) - E_x E_{q(z|x)} log q(z)

        Args:
            x: the encoder input.

        Returns:
            Float: the estimated mutual information.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement calc_mi()"
        )