API Reference

Mixin classes for VAE variants.

class vae_mixin.VAEMixin

Mixin class for basic VAE.

https://arxiv.org/abs/1312.6114

static reparameterize(mu, logvar)

Reparameterization trick.
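
As a minimal sketch of the trick, assuming `logvar` holds \(\log\sigma^2\) as in the encoder outputs documented on this page: a sample is drawn as \(z = \mu + \sigma \odot \epsilon\) with \(\epsilon \sim \mathcal{N}(0, \mathbf{I})\), so gradients can flow through \(\mu\) and \(\sigma\). The function name `reparameterize_sketch` is hypothetical and the code is not taken from the library source.

```python
import torch


def reparameterize_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Draw z = mu + sigma * eps with eps ~ N(0, I) (illustrative only)."""
    std = torch.exp(0.5 * logvar)   # sigma = exp(0.5 * log sigma^2)
    eps = torch.randn_like(std)     # noise that does not depend on the parameters
    return mu + eps * std


torch.manual_seed(0)
mu = torch.ones(4, 2)
logvar = torch.zeros(4, 2)
z = reparameterize_sketch(mu, logvar)  # shape (4, 2)
```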

encode(x)
Parameters:

x (torch.Tensor) – The input data.

Returns:

z (torch.Tensor) – The sampled latent variables during training; otherwise the posterior mean.

decode(z)
Parameters:

z (torch.Tensor) – The latent variables.

Returns:

x_rec (torch.Tensor) – The reconstructed data.

rec_loss(x_rec, x, **kwargs)

Compute the reconstruction loss when the output distribution \(p(x|z)\) is Bernoulli.

Parameters:
  • x_rec (torch.Tensor) – The reconstructed data.

  • x (torch.Tensor) – The input data.

Returns:

bce (torch.Tensor) – The binary cross entropy.
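
For a Bernoulli output, the reconstruction loss is the binary cross entropy between `x_rec` and `x`. The sketch below assumes `x_rec` contains probabilities in (0, 1) and that the loss is summed over features and averaged over the batch; the reduction actually used by the library is not stated on this page, and `bernoulli_rec_loss_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def bernoulli_rec_loss_sketch(x_rec: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Binary cross entropy, summed over features and averaged over the batch.

    The reduction is an assumption made for this example.
    """
    return F.binary_cross_entropy(x_rec, x, reduction="sum") / x.shape[0]


x = torch.tensor([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
x_rec = torch.full_like(x, 0.5)          # a maximally uncertain reconstruction
bce = bernoulli_rec_loss_sketch(x_rec, x)
```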

kl_loss(mu=None, logvar=None, **kwargs)

Compute the KL divergence loss when the prior \(p(z)=\mathcal{N}(0,\mathbf{I})\) and the approximate posterior \(q(z|x)\) are both Gaussian.

Parameters:
  • mu (torch.Tensor, optional) – Encoder output \(\mu\).

  • logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).

Returns:

kld (torch.Tensor) – The KL divergence.
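
With a diagonal Gaussian posterior and a standard normal prior, the KL divergence has the closed form \(-\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)\). A minimal sketch follows; averaging over the batch is an assumption, and `gaussian_kl_sketch` is a hypothetical name rather than the library's implementation.

```python
import torch


def gaussian_kl_sketch(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """KL(N(mu, diag(sigma^2)) || N(0, I)), averaged over the batch (assumed)."""
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    return kl_per_sample.mean()


kld_zero = gaussian_kl_sketch(torch.zeros(3, 2), torch.zeros(3, 2))  # posterior equals prior
kld_unit = gaussian_kl_sketch(torch.ones(3, 2), torch.zeros(3, 2))   # mean shifted by 1 per dimension
```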

loss(**kwargs)

Compute the VAE objective function.

Parameters:

**kwargs – Extra arguments to loss terms.

Returns:

torch.Tensor

class vae_mixin.BetaVAEMixin(beta=1, **kwargs)

Mixin class for \(\beta\)-VAE.

https://openreview.net/forum?id=Sy2fzU9gl

Parameters:

beta (float, default 1) – Regularisation coefficient \(\beta\).

loss(**kwargs)

Compute the \(\beta\)-VAE objective function.

Parameters:

**kwargs – Extra arguments to loss terms.

Returns:

torch.Tensor
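
As a rough illustration, the \(\beta\)-VAE objective weights the KL term by \(\beta\), so that `beta=1` recovers the basic VAE objective. The helper below is a hypothetical sketch that takes already-computed loss terms; it does not reproduce how the mixin gathers them from `**kwargs`.

```python
import torch


def beta_vae_loss_sketch(rec: torch.Tensor, kld: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """Reconstruction loss plus beta-weighted KL divergence (illustrative only)."""
    return rec + beta * kld


total = beta_vae_loss_sketch(torch.tensor(2.0), torch.tensor(0.5), beta=4.0)
```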

class vae_mixin.InfoVAEMixin(lamb=1, alpha=0, **kwargs)

Mixin class for InfoVAE.

https://arxiv.org/abs/1706.02262

Parameters:
  • lamb (float, default 1) – Scaling coefficient \(\lambda\).

  • alpha (float, default 0) – Information preference \(\alpha\).

static kernel(x1, x2)

Compute the RBF kernel.
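
One common form of the RBF kernel is \(k(x_1, x_2) = \exp\left(-\lVert x_1 - x_2\rVert^2 / (2\sigma^2)\right)\). The sketch below evaluates it for every pair of rows; the bandwidth parameter `sigma` and the name `rbf_kernel_sketch` are assumptions, since the bandwidth used by the library is not documented here.

```python
import torch


def rbf_kernel_sketch(x1: torch.Tensor, x2: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Pairwise RBF kernel between rows of x1 (N, D) and x2 (M, D); returns (N, M)."""
    sq_dist = (x1.unsqueeze(1) - x2.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.exp(-sq_dist / (2 * sigma ** 2))


points = torch.tensor([[0.0, 0.0], [1.0, 0.0]])
gram = rbf_kernel_sketch(points, points)  # 2 x 2 Gram matrix with ones on the diagonal
```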

mmd_loss(z, **kwargs)

Compute the MMD loss between \(q(z)\) and \(p(z)\).

Parameters:

z (torch.Tensor) – The latent variables.

Returns:

mmd (torch.Tensor) – The squared maximum mean discrepancy.
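
The squared MMD can be estimated from samples as \(\mathbb{E}[k(z, z')] + \mathbb{E}[k(p, p')] - 2\,\mathbb{E}[k(z, p)]\), where \(z\) comes from the encoder and \(p\) from the prior. The sketch below uses the simple biased estimator with an RBF kernel of unit bandwidth; the estimator, the bandwidth, and passing prior samples explicitly are all assumptions made for illustration.

```python
import torch


def rbf_kernel_sketch(x1: torch.Tensor, x2: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    """Pairwise RBF kernel between rows of x1 and x2."""
    sq_dist = (x1.unsqueeze(1) - x2.unsqueeze(0)).pow(2).sum(dim=-1)
    return torch.exp(-sq_dist / (2 * sigma ** 2))


def mmd_sketch(z: torch.Tensor, z_prior: torch.Tensor) -> torch.Tensor:
    """Biased estimate of the squared MMD between two sample sets."""
    k_qq = rbf_kernel_sketch(z, z).mean()
    k_pp = rbf_kernel_sketch(z_prior, z_prior).mean()
    k_qp = rbf_kernel_sketch(z, z_prior).mean()
    return k_qq + k_pp - 2 * k_qp


torch.manual_seed(0)
z = torch.randn(64, 2) + 3.0   # latent codes that sit far from the prior
z_prior = torch.randn(64, 2)   # samples from p(z) = N(0, I)
mmd = mmd_sketch(z, z_prior)   # positive; shrinks as the two sample sets align
```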

loss(**kwargs)

Compute the InfoVAE objective function.

Parameters:

**kwargs – Extra arguments to loss terms.

Returns:

torch.Tensor
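
Assuming the objective follows the form given in the InfoVAE paper, the loss combines the reconstruction term, the KL term weighted by \(1-\alpha\), and the MMD term weighted by \(\alpha+\lambda-1\). Under that assumption the defaults `lamb=1`, `alpha=0` reduce to the basic VAE objective. The helper name is hypothetical and the library's exact weighting is not confirmed by this page.

```python
import torch


def infovae_loss_sketch(
    rec: torch.Tensor,
    kld: torch.Tensor,
    mmd: torch.Tensor,
    lamb: float = 1.0,
    alpha: float = 0.0,
) -> torch.Tensor:
    """rec + (1 - alpha) * KL + (alpha + lamb - 1) * MMD (assumed form)."""
    return rec + (1 - alpha) * kld + (alpha + lamb - 1) * mmd


rec, kld, mmd_term = torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)
default_loss = infovae_loss_sketch(rec, kld, mmd_term)                    # MMD weight is 0
mmd_only = infovae_loss_sketch(rec, kld, mmd_term, lamb=10.0, alpha=1.0)  # KL weight is 0
```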

class vae_mixin.DIPVAEMixin(lambda_od=10, lambda_d=10, dip=1, **kwargs)

Mixin class for DIP-VAE.

https://arxiv.org/abs/1711.00848

Parameters:
  • lambda_od (float, default 10) – Hyperparameter \(\lambda_{od}\) for the loss on the off-diagonal entries.

  • lambda_d (float, default 10) – Hyperparameter \(\lambda_d\) for the loss on the diagonal entries.

  • dip (int, 1 or 2, default 1) – DIP type: 1 for DIP-VAE-I, 2 for DIP-VAE-II.

dip_regularizer(mu=None, logvar=None, **kwargs)

Compute the DIP regularization terms when \(q(z|x)\) is Gaussian.

Parameters:
  • mu (torch.Tensor, optional) – Encoder output \(\mu\).

  • logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).

Returns:

loss(**kwargs)

Compute the DIP-VAE objective function.

Parameters:

**kwargs – Extra arguments to loss terms.

Returns:

torch.Tensor
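
The DIP regularization terms added to this objective penalize the covariance of the latent code for deviating from the identity: off-diagonal entries are pushed toward 0 and diagonal entries toward 1. The sketch below follows the DIP-VAE paper, using the batch covariance of \(\mu\) for type 1 and adding the mean of \(\sigma^2\) on the diagonal for type 2. How the library estimates the covariance or returns the two terms is not stated on this page, so treat `dip_regularizer_sketch` as a hypothetical illustration.

```python
import torch


def dip_regularizer_sketch(
    mu: torch.Tensor,
    logvar: torch.Tensor,
    lambda_od: float = 10.0,
    lambda_d: float = 10.0,
    dip: int = 1,
) -> torch.Tensor:
    """Penalty on the latent covariance, estimated over the batch (illustrative)."""
    centered = mu - mu.mean(dim=0, keepdim=True)
    cov = centered.t() @ centered / mu.shape[0]              # Cov[mu(x)]
    if dip == 2:
        cov = cov + torch.diag(logvar.exp().mean(dim=0))     # add E[diag(sigma^2)]
    diag = torch.diagonal(cov)
    off_diag = cov - torch.diag(diag)
    return lambda_od * off_diag.pow(2).sum() + lambda_d * (diag - 1).pow(2).sum()


mu = torch.tensor([[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
logvar = torch.full((4, 2), 0.5).log()
reg_type1 = dip_regularizer_sketch(mu, logvar, dip=1)
reg_type2 = dip_regularizer_sketch(mu, logvar, dip=2)
```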

class vae_mixin.BetaTCVAEMixin(alpha=1, beta=1, gamma=1, **kwargs)

Mixin class for \(\beta\)-TCVAE.

https://arxiv.org/abs/1802.04942

Parameters:
  • alpha (float, default 1) – Weight \(\alpha\) of the index-code mutual information.

  • beta (float, default 1) – Weight \(\beta\) of the total correlation.

  • gamma (float, default 1) – Weight \(\gamma\) of the dimension-wise KL.

static gaussian_log_density(x, mu, logvar)

Compute the log density of a Gaussian.
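
For a univariate Gaussian parameterized by `mu` and `logvar`, the log density is \(-\tfrac{1}{2}\left(\log 2\pi + \log\sigma^2 + (x-\mu)^2/\sigma^2\right)\). The sketch below evaluates it elementwise; whether the library sums over latent dimensions is not stated here, so the function name and the elementwise return are assumptions.

```python
import math

import torch


def gaussian_log_density_sketch(x: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Elementwise log N(x; mu, exp(logvar))."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu).pow(2) * torch.exp(-logvar))


log_p = gaussian_log_density_sketch(torch.zeros(1), torch.zeros(1), torch.zeros(1))
```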

decompose_kl(n, z, mu=None, logvar=None, **kwargs)

Compute the decomposed KL terms.

Parameters:
  • n (int) – Dataset size \(N\).

  • z (torch.Tensor) – The latent variables.

  • mu (torch.Tensor, optional) – Encoder output \(\mu\).

  • logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).

Returns:

loss(**kwargs)

Compute the \(\beta\)-TCVAE objective function.

Parameters:

**kwargs – Extra arguments to loss terms.

Returns:

torch.Tensor
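
The decomposition behind this objective splits the KL term into index-code mutual information, total correlation, and dimension-wise KL, weighted by \(\alpha\), \(\beta\), and \(\gamma\). The sketch below estimates \(\log q(z)\) with the minibatch weighted sampling estimator from the \(\beta\)-TCVAE paper. The choice of estimator, the three-value return, and all function names are assumptions; this page does not document what `decompose_kl` returns.

```python
import math

import torch


def log_normal(x, mu, logvar):
    """Elementwise log N(x; mu, exp(logvar))."""
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu).pow(2) * torch.exp(-logvar))


def decompose_kl_sketch(n, z, mu, logvar):
    """Return (mutual information, total correlation, dimension-wise KL)."""
    m = z.shape[0]
    log_qz_x = log_normal(z, mu, logvar).sum(dim=1)                               # log q(z_i | x_i)
    log_pz = log_normal(z, torch.zeros_like(z), torch.zeros_like(z)).sum(dim=1)  # log p(z_i)
    # pairwise[i, j, d] = log q(z_{i,d} | x_j)
    pairwise = log_normal(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    log_nm = math.log(n * m)
    log_qz = torch.logsumexp(pairwise.sum(dim=2), dim=1) - log_nm                # log q(z_i)
    log_prod_qz = (torch.logsumexp(pairwise, dim=1) - log_nm).sum(dim=1)         # sum_d log q(z_{i,d})
    mi = (log_qz_x - log_qz).mean()
    tc = (log_qz - log_prod_qz).mean()
    dwkl = (log_prod_qz - log_pz).mean()
    return mi, tc, dwkl


def beta_tcvae_loss_sketch(rec, mi, tc, dwkl, alpha=1.0, beta=1.0, gamma=1.0):
    """Reconstruction loss plus the weighted KL decomposition."""
    return rec + alpha * mi + beta * tc + gamma * dwkl


torch.manual_seed(0)
mu = torch.randn(8, 3)
logvar = 0.1 * torch.randn(8, 3)
z = mu + torch.exp(0.5 * logvar) * torch.randn(8, 3)
mi, tc, dwkl = decompose_kl_sketch(1000, z, mu, logvar)
```

By construction the three terms telescope, so their sum equals the batch mean of \(\log q(z|x) - \log p(z)\) evaluated at the sampled `z`.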

class vae_mixin.VQVAEMixin(embedding_dim, num_embeddings, commitment_cost, **kwargs)

Mixin class for VQ-VAE.

https://arxiv.org/abs/1711.00937

Parameters:
  • embedding_dim (int) – The dimensionality of each latent embedding vector.

  • num_embeddings (int) – The size of the discrete latent space.

  • commitment_cost (float) – Scalar \(\beta\) that weights the commitment loss.

encode(x)
Parameters:

x (torch.Tensor) – The input data.

Returns:

z_e (torch.Tensor) – The encoder outputs.

quantize(inputs)

Quantize the input vectors.

Parameters:

inputs (torch.Tensor) – The encoder outputs.

Returns:

vq_loss(z_e, z_q, **kwargs)

Compute the vector quantization loss terms.

Parameters:
  • z_e (torch.Tensor) – The encoder outputs.

  • z_q (torch.Tensor) – The quantized vectors.

Returns:

loss(**kwargs)

Compute the VQ-VAE objective function.

Parameters:

**kwargs – Extra arguments to loss terms.

Returns:

torch.Tensor
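
The quantization step and the vector quantization loss terms that enter this objective can be sketched as follows, based on the VQ-VAE paper: each encoder output is replaced by its nearest codebook vector, the codebook loss pulls embeddings toward the encoder outputs, and the commitment loss (weighted by `commitment_cost`) does the reverse. The sketch assumes encoder outputs flattened to shape `(B, embedding_dim)` and a codebook passed in explicitly; the tensor layout, the return values, and the function names are assumptions rather than the library's interface.

```python
import torch
import torch.nn.functional as F


def quantize_sketch(inputs: torch.Tensor, codebook: torch.Tensor):
    """Map each row of inputs (B, D) to its nearest codebook row (K, D)."""
    distances = torch.cdist(inputs, codebook)   # (B, K) Euclidean distances
    indices = distances.argmin(dim=1)
    return codebook[indices], indices


def vq_loss_sketch(z_e: torch.Tensor, z_q: torch.Tensor, commitment_cost: float) -> torch.Tensor:
    """Codebook loss plus weighted commitment loss, using stop-gradients."""
    codebook_loss = F.mse_loss(z_q, z_e.detach())     # moves embeddings toward encoder outputs
    commitment_loss = F.mse_loss(z_e, z_q.detach())   # keeps encoder outputs near embeddings
    return codebook_loss + commitment_cost * commitment_loss


codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
z_e = torch.tensor([[0.1, 0.0], [0.9, 1.2]])
z_q, indices = quantize_sketch(z_e, codebook)
vq = vq_loss_sketch(z_e, z_q, commitment_cost=0.25)

# Straight-through estimator: forward pass uses z_q, gradients pass to z_e.
z_q_st = z_e + (z_q - z_e).detach()
```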