API Reference#
Mixin classes for VAE variants.
- class vae_mixin.VAEMixin[source]#
Mixin class for basic VAE.
https://arxiv.org/abs/1312.6114
- encode(x)[source]#
Encode the input data into the latent space.
- Parameters:
- x (torch.Tensor) – The input data.
- Returns:
- z (torch.Tensor) – The sampled latent variables during training, or the posterior mean \(\mu\) in evaluation mode.
- decode(z)[source]#
Decode the latent variables into a reconstruction of the input.
- Parameters:
- z (torch.Tensor) – The latent variables.
- Returns:
- x_rec (torch.Tensor) – The reconstructed data.
- rec_loss(x_rec, x, **kwargs)[source]#
Compute the reconstruction loss when the output distribution \(p(x|z)\) is Bernoulli.
- Parameters:
- x_rec (torch.Tensor) – The reconstructed data.
- x (torch.Tensor) – The input data.
- Returns:
- bce (torch.Tensor) – The binary cross-entropy.
- kl_loss(mu=None, logvar=None, **kwargs)[source]#
Compute the KL divergence loss when the prior \(p(z)=\mathcal{N}(0,\mathbf{I})\) and the posterior approximation \(q(z|x)\) are Gaussian.
- Parameters:
- mu (torch.Tensor, optional) – Encoder output \(\mu\).
- logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).
- Returns:
- kld (torch.Tensor) – The KL divergence.
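With both distributions Gaussian, the divergence has the familiar closed form \(-\tfrac{1}{2}\sum_j (1+\log\sigma_j^2-\mu_j^2-\sigma_j^2)\). A standalone sketch of that formula (plain Python lists instead of tensors; the function name `gaussian_kl` is illustrative, not the mixin's API):

```python
import math

def gaussian_kl(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ),
    summed over latent dimensions."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar)
    )

# Shifting the posterior mean away from 0 costs mu^2 / 2 nats per dimension:
print(gaussian_kl([1.0], [0.0]))  # → 0.5
```

A posterior that matches the prior exactly (`mu=0`, `logvar=0`) gives zero divergence, which is a handy sanity check for any implementation.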
- class vae_mixin.BetaVAEMixin(beta=1, **kwargs)[source]#
Mixin class for \(\beta\)-VAE.
https://openreview.net/forum?id=Sy2fzU9gl
- Parameters:
- beta (float, default 1) – Regularisation coefficient \(\beta\).
</gr-replace>
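The \(\beta\)-VAE objective simply reweights the KL term of the ELBO: \(\mathcal{L} = \mathcal{L}_{rec} + \beta\,\mathrm{KL}\). A minimal sketch of that combination (the function name `beta_vae_loss` is illustrative; the mixin composes the terms internally):

```python
def beta_vae_loss(rec_loss, kl_loss, beta=1.0):
    """beta-VAE objective: reconstruction term plus beta-weighted KL.
    beta=1 recovers the plain VAE ELBO; beta>1 pushes the posterior
    toward the factorised prior, encouraging disentangled latents."""
    return rec_loss + beta * kl_loss

print(beta_vae_loss(10.0, 2.0, beta=4.0))  # → 18.0
```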
- class vae_mixin.InfoVAEMixin(lamb=1, alpha=0, **kwargs)[source]#
Mixin class for InfoVAE.
https://arxiv.org/abs/1706.02262
- Parameters:
- lamb (float, default 1) – Scaling coefficient \(\lambda\) of the divergence term.
- alpha (float, default 0) – Information preference \(\alpha\).
- mmd_loss(z, **kwargs)[source]#
Compute the MMD loss between \(q(z)\) and \(p(z)\).
- Parameters:
- z (torch.Tensor) – The latent variables.
- Returns:
- mmd (torch.Tensor) – The squared maximum mean discrepancy.
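The squared MMD compares the aggregated posterior against the prior through kernel evaluations: \(\mathrm{MMD}^2 = \mathbb{E}[k(z,z')] + \mathbb{E}[k(p,p')] - 2\,\mathbb{E}[k(z,p)]\). A minimal sketch with an RBF kernel (plain Python; `rbf`, `mmd_sq`, and the bandwidth `bw` are illustrative names, and the mixin's kernel choice may differ):

```python
import math

def rbf(a, b, bw=1.0):
    """RBF kernel between two vectors."""
    return math.exp(-sum((x - y) ** 2 for x, y in zip(a, b)) / (2 * bw * bw))

def mmd_sq(zs, ps, bw=1.0):
    """Biased estimator of squared MMD between latent samples zs ~ q(z)
    and prior samples ps ~ p(z)."""
    n, m = len(zs), len(ps)
    kzz = sum(rbf(a, b, bw) for a in zs for b in zs) / (n * n)
    kpp = sum(rbf(a, b, bw) for a in ps for b in ps) / (m * m)
    kzp = sum(rbf(a, b, bw) for a in zs for b in ps) / (n * m)
    return kzz + kpp - 2.0 * kzp
```

Identical sample sets give zero, and the biased estimator is non-negative, so the value grows as \(q(z)\) drifts from the prior.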
- class vae_mixin.DIPVAEMixin(lambda_od=10, lambda_d=10, dip=1, **kwargs)[source]#
Mixin class for DIP-VAE.
https://arxiv.org/abs/1711.00848
- Parameters:
- lambda_od (float, default 10) – Weight \(\lambda_{od}\) of the off-diagonal regularizer.
- lambda_d (float, default 10) – Weight \(\lambda_{d}\) of the diagonal regularizer.
- dip (int, default 1) – The DIP-VAE variant (1 or 2).
- dip_regularizer(mu=None, logvar=None, **kwargs)[source]#
Compute the DIP regularization terms when \(q(z|x)\) is Gaussian.
- Parameters:
- mu (torch.Tensor, optional) – Encoder output \(\mu\).
- logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).
- Returns:
- off_diag_loss (torch.Tensor) – The off-diagonal loss.
- diag_loss (torch.Tensor) – The diagonal loss.
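DIP-VAE decorrelates the latents by penalizing the covariance of the inferred posterior: off-diagonal entries are driven to zero and diagonal entries toward one. A DIP-VAE-I style sketch over a batch of posterior means (plain Python; `dip_terms` is an illustrative name, and variant II additionally adds the mean posterior covariance):

```python
def dip_terms(mus):
    """DIP-VAE-I style regularizer terms from a batch of posterior means.
    Penalizes squared off-diagonal entries of Cov[mu] and squared
    deviation of diagonal entries from 1."""
    n, d = len(mus), len(mus[0])
    mean = [sum(m[j] for m in mus) / n for j in range(d)]
    cov = [[sum((m[i] - mean[i]) * (m[j] - mean[j]) for m in mus) / n
            for j in range(d)] for i in range(d)]
    off_diag = sum(cov[i][j] ** 2 for i in range(d) for j in range(d) if i != j)
    diag = sum((cov[i][i] - 1.0) ** 2 for i in range(d))
    return off_diag, diag
```

For a batch whose first latent has unit variance and second is constant, the off-diagonal loss is zero while the diagonal loss penalizes the collapsed dimension.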
- class vae_mixin.BetaTCVAEMixin(alpha=1, beta=1, gamma=1, **kwargs)[source]#
Mixin class for \(\beta\)-TCVAE.
https://arxiv.org/abs/1802.04942
- Parameters:
- alpha (float, default 1) – Weight \(\alpha\) of the index-code mutual information.
- beta (float, default 1) – Weight \(\beta\) of the total correlation.
- gamma (float, default 1) – Weight \(\gamma\) of the dimension-wise KL.
- decompose_kl(n, z, mu=None, logvar=None, **kwargs)[source]#
Compute the decomposed KL terms.
- Parameters:
- n (int) – Dataset size \(N\).
- z (torch.Tensor) – The latent variables.
- mu (torch.Tensor, optional) – Encoder output \(\mu\).
- logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).
- Returns:
- mi_loss (torch.Tensor) – The index-code mutual information.
- tc_loss (torch.Tensor) – The total correlation.
- kl_loss (torch.Tensor) – The dimension-wise KL.
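\(\beta\)-TCVAE splits the aggregate KL term into mutual information, total correlation, and dimension-wise KL, estimating \(\log q(z)\) and \(\log \prod_k q(z_k)\) with minibatch-weighted sampling over the batch. A standalone sketch under those assumptions (plain Python lists; `decompose_kl_sketch`, `log_normal`, and `logsumexp` are illustrative helpers, not the mixin's API):

```python
import math

def logsumexp(xs):
    """Numerically stable log-sum-exp."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_normal(z, mu, logvar):
    """Diagonal-Gaussian log density, summed over dimensions."""
    return sum(
        -0.5 * (math.log(2 * math.pi) + lv + (x - m) ** 2 / math.exp(lv))
        for x, m, lv in zip(z, mu, logvar)
    )

def decompose_kl_sketch(n, zs, mus, logvars):
    """Minibatch-weighted estimate of the beta-TCVAE decomposition
    KL = MI + TC + dimension-wise KL, for a batch of latent samples
    zs with per-sample posterior parameters (mus, logvars)."""
    b, d = len(zs), len(zs[0])
    mi = tc = dkl = 0.0
    for i, z in enumerate(zs):
        log_qzx = log_normal(z, mus[i], logvars[i])  # log q(z_i | x_i)
        log_qz = logsumexp(                          # estimate of log q(z_i)
            [log_normal(z, mus[j], logvars[j]) for j in range(b)]
        ) - math.log(b * n)
        log_prod = sum(                              # log prod_k q(z_ik)
            logsumexp(
                [log_normal([z[k]], [mus[j][k]], [logvars[j][k]])
                 for j in range(b)]
            ) - math.log(b * n)
            for k in range(d)
        )
        log_pz = log_normal(z, [0.0] * d, [0.0] * d)  # standard-normal prior
        mi += (log_qzx - log_qz) / b
        tc += (log_qz - log_prod) / b
        dkl += (log_prod - log_pz) / b
    return mi, tc, dkl
```

By construction the three terms telescope: their sum equals the batch average of \(\log q(z|x) - \log p(z)\), which is a useful invariant to test against.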
- class vae_mixin.VQVAEMixin(embedding_dim, num_embeddings, commitment_cost, **kwargs)[source]#
Mixin class for VQ-VAE.
https://arxiv.org/abs/1711.00937
- Parameters:
- embedding_dim (int) – Dimensionality \(D\) of the codebook vectors.
- num_embeddings (int) – Number \(K\) of codebook vectors.
- commitment_cost (float) – Weight \(\beta\) of the commitment loss.
- encode(x)[source]#
Encode the input data into continuous encoder outputs.
- Parameters:
- x (torch.Tensor) – The input data.
- Returns:
- z_e (torch.Tensor) – The encoder outputs.
- quantize(inputs)[source]#
Quantize the input vectors.
- Parameters:
- inputs (torch.Tensor) – The encoder outputs.
- Returns:
- z_q (torch.Tensor) – The quantized vectors of the inputs.
- z (torch.Tensor) – The discrete latent variables.
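Quantization in VQ-VAE is a nearest-neighbour lookup: each encoder output is replaced by the closest codebook vector, and its index becomes the discrete latent. A minimal sketch with plain Python lists (the standalone `quantize(inputs, codebook)` signature is illustrative; the mixin holds the codebook as learned embeddings):

```python
def quantize(inputs, codebook):
    """Map each encoder output to its nearest codebook vector (L2 distance).
    Returns the quantized vectors z_q and the discrete indices z."""
    z_q, z = [], []
    for v in inputs:
        dists = [sum((a - b) ** 2 for a, b in zip(v, e)) for e in codebook]
        k = min(range(len(codebook)), key=dists.__getitem__)
        z.append(k)
        z_q.append(codebook[k])
    return z_q, z

codebook = [[0.0, 0.0], [1.0, 1.0]]
print(quantize([[0.9, 1.1], [0.1, -0.2]], codebook))
# → ([[1.0, 1.0], [0.0, 0.0]], [1, 0])
```

In the real model a straight-through estimator copies gradients from `z_q` back to `z_e`, since the argmin itself is not differentiable.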
- vq_loss(z_e, z_q, **kwargs)[source]#
Compute the vector quantization loss terms.
- Parameters:
- z_e (torch.Tensor) – The encoder outputs.
- z_q (torch.Tensor) – The quantized vectors.
- Returns:
- codebook_loss (torch.Tensor) – The codebook loss.
- commitment_loss (torch.Tensor) – The commitment loss.
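Both VQ terms measure the same squared distance between encoder outputs and their codes; they differ only in where the stop-gradient sits. The codebook loss \(\lVert \mathrm{sg}[z_e] - e \rVert^2\) moves the codebook toward the encoder outputs, while the commitment loss \(\lVert z_e - \mathrm{sg}[e] \rVert^2\) (scaled by `commitment_cost`) commits the encoder to its chosen code. A numerical sketch without autograd (`vq_loss_terms` is an illustrative name; in a torch implementation the two terms use `.detach()` on opposite arguments):

```python
def vq_loss_terms(z_e, z_q):
    """VQ-VAE loss terms as mean squared error between encoder outputs
    and their quantized codes. Numerically identical; in training they
    differ by which side the stop-gradient is applied to."""
    mse = sum(
        sum((a - b) ** 2 for a, b in zip(ve, vq))
        for ve, vq in zip(z_e, z_q)
    ) / len(z_e)
    codebook_loss = mse    # ||sg[z_e] - e||^2: updates the codebook
    commitment_loss = mse  # ||z_e - sg[e]||^2: updates the encoder
    return codebook_loss, commitment_loss
```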