API Reference#
Mixin classes for VAE variants.
- class vae_mixin.VAEMixin[source]#
Mixin class for basic VAE.
https://arxiv.org/abs/1312.6114
- encode(x)[source]#
- Parameters:
  x (torch.Tensor) – The input data.
- Returns:
  z (torch.Tensor) – The latent variables when training, otherwise the mean.
- decode(z)[source]#
- Parameters:
  z (torch.Tensor) – The latent variables.
- Returns:
  x_rec (torch.Tensor) – The reconstructed data.
- rec_loss(x_rec, x, **kwargs)[source]#
Compute the reconstruction loss when the output distribution \(p(x|z)\) is Bernoulli.
- Parameters:
  x_rec (torch.Tensor) – The reconstructed data.
  x (torch.Tensor) – The input data.
- Returns:
  bce (torch.Tensor) – The binary cross entropy.
- kl_loss(mu=None, logvar=None, **kwargs)[source]#
Compute the KL divergence loss when the prior \(p(z)=\mathcal{N}(0,\mathbf{I})\) and the posterior approximation \(q(z|x)\) are Gaussian.
- Parameters:
  mu (torch.Tensor, optional) – Encoder output \(\mu\).
  logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).
- Returns:
  kld (torch.Tensor) – The KL divergence.
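A minimal PyTorch sketch of these two losses, assuming flattened inputs in \([0, 1]\) and a diagonal-Gaussian posterior (function names are illustrative, not the mixin's actual internals):

```python
import torch
import torch.nn.functional as F

def rec_loss(x_rec, x):
    # Bernoulli likelihood -> binary cross entropy, summed over
    # features and averaged over the batch.
    return F.binary_cross_entropy(x_rec, x, reduction="sum") / x.size(0)

def kl_loss(mu, logvar):
    # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian:
    # 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1), batch-averaged.
    return 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1) / mu.size(0)
```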
- class vae_mixin.BetaVAEMixin(beta=1, **kwargs)[source]#
Mixin class for \(\beta\)-VAE.
https://openreview.net/forum?id=Sy2fzU9gl
- Parameters:
  beta (float, default: 1) – Regularisation coefficient \(\beta\).
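The \(\beta\)-VAE objective simply reweights the KL term of the basic VAE. A one-line sketch (hypothetical helper name; `rec` and `kld` are assumed to come from the reconstruction and KL losses above):

```python
import torch

def beta_vae_loss(rec, kld, beta=4.0):
    # beta > 1 upweights the KL term, encouraging disentanglement
    # at the cost of some reconstruction quality.
    return rec + beta * kld
```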
- class vae_mixin.InfoVAEMixin(lamb=1, alpha=0, **kwargs)[source]#
Mixin class for InfoVAE.
https://arxiv.org/abs/1706.02262
- Parameters:
  lamb (float, default: 1) – Scaling coefficient \(\lambda\) of the MMD term.
  alpha (float, default: 0) – Information preference coefficient \(\alpha\).
- mmd_loss(z, **kwargs)[source]#
Compute the MMD loss of \(q(z)\) and \(p(z)\).
- Parameters:
  z (torch.Tensor) – The latent variables.
- Returns:
  mmd (torch.Tensor) – The squared maximum mean discrepancy.
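A sketch of a biased squared-MMD estimator between samples of \(q(z)\) and a standard-normal prior. The RBF kernel and its bandwidth heuristic are assumptions here; `rbf_kernel` is an illustrative helper, not part of the mixin:

```python
import torch

def rbf_kernel(a, b, scale=1.0):
    # Gaussian kernel matrix; bandwidth scaled by latent dimension.
    sq = torch.cdist(a, b).pow(2)
    return torch.exp(-sq / (2 * scale * a.size(1)))

def mmd_loss(z):
    # Biased estimate of MMD^2(q(z), p(z)) with p(z) = N(0, I),
    # averaging the kernel over all sample pairs.
    prior = torch.randn_like(z)
    return (rbf_kernel(prior, prior).mean()
            + rbf_kernel(z, z).mean()
            - 2 * rbf_kernel(prior, z).mean())
```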
- class vae_mixin.DIPVAEMixin(lambda_od=10, lambda_d=10, dip=1, **kwargs)[source]#
Mixin class for DIP-VAE.
https://arxiv.org/abs/1711.00848
- Parameters:
  lambda_od (float, default: 10) – Weight \(\lambda_{od}\) of the off-diagonal term.
  lambda_d (float, default: 10) – Weight \(\lambda_{d}\) of the diagonal term.
  dip (int, default: 1) – DIP-VAE variant, 1 or 2.
- dip_regularizer(mu=None, logvar=None, **kwargs)[source]#
Compute the DIP regularization terms when \(q(z|x)\) is Gaussian.
- Parameters:
  mu (torch.Tensor, optional) – Encoder output \(\mu\).
  logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).
- Returns:
  off_diag_loss (torch.Tensor) – The off-diagonal loss.
  diag_loss (torch.Tensor) – The diagonal loss.
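A sketch of the DIP-VAE-I variant, which penalises the covariance of \(\mu\) over the batch for deviating from the identity (illustrative function name; DIP-VAE-II, which uses the expected covariance of \(q(z|x)\), is omitted):

```python
import torch

def dip_regularizer(mu, lambda_od=10.0, lambda_d=10.0):
    # Batch covariance of the encoder means.
    mu_c = mu - mu.mean(dim=0, keepdim=True)
    cov = mu_c.t() @ mu_c / mu.size(0)
    diag = torch.diagonal(cov)
    off_diag = cov - torch.diag(diag)
    # Push off-diagonal entries to 0 and diagonal entries to 1.
    off_diag_loss = lambda_od * off_diag.pow(2).sum()
    diag_loss = lambda_d * (diag - 1).pow(2).sum()
    return off_diag_loss, diag_loss
```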
- class vae_mixin.BetaTCVAEMixin(alpha=1, beta=1, gamma=1, **kwargs)[source]#
Mixin class for \(\beta\)-TCVAE.
https://arxiv.org/abs/1802.04942
- Parameters:
  alpha (float, default: 1) – Weight \(\alpha\) of the index-code mutual information.
  beta (float, default: 1) – Weight \(\beta\) of the total correlation.
  gamma (float, default: 1) – Weight \(\gamma\) of the dimension-wise KL.
- decompose_kl(n, z, mu=None, logvar=None, **kwargs)[source]#
Compute the decomposed KL terms.
- Parameters:
  n (int) – Dataset size \(N\).
  z (torch.Tensor) – The latent variables.
  mu (torch.Tensor, optional) – Encoder output \(\mu\).
  logvar (torch.Tensor, optional) – Encoder output \(\log\sigma^2\).
- Returns:
  mi_loss (torch.Tensor) – The index-code mutual information.
  tc_loss (torch.Tensor) – The total correlation.
  kl_loss (torch.Tensor) – The dimension-wise KL.
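A sketch of the decomposition using the minibatch-weighted-sampling estimator from the \(\beta\)-TCVAE paper, assuming a diagonal-Gaussian posterior (names are illustrative, not the mixin's actual internals):

```python
import math
import torch

def log_gauss(z, mu, logvar):
    # Elementwise log N(z; mu, diag(exp(logvar))).
    return -0.5 * (math.log(2 * math.pi) + logvar
                   + (z - mu).pow(2) / logvar.exp())

def decompose_kl(n, z, mu, logvar):
    b = z.size(0)
    # log q(z_i | x_i) and log p(z_i), summed over dimensions.
    log_qzx = log_gauss(z, mu, logvar).sum(1)
    log_pz = log_gauss(z, torch.zeros_like(z), torch.zeros_like(z)).sum(1)
    # Pairwise log q(z_i | x_j), shape (b, b, d).
    mat = log_gauss(z.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    # Minibatch-weighted-sampling estimates of log q(z) and log prod_k q(z_k).
    log_qz = torch.logsumexp(mat.sum(2), dim=1) - math.log(b * n)
    log_prod_qzk = (torch.logsumexp(mat, dim=1) - math.log(b * n)).sum(1)
    mi_loss = (log_qzx - log_qz).mean()       # index-code mutual information
    tc_loss = (log_qz - log_prod_qzk).mean()  # total correlation
    kl_loss = (log_prod_qzk - log_pz).mean()  # dimension-wise KL
    return mi_loss, tc_loss, kl_loss
```

The three terms telescope: their sum equals the estimate of \(\mathbb{E}[\log q(z|x) - \log p(z)]\).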
- class vae_mixin.VQVAEMixin(embedding_dim, num_embeddings, commitment_cost, **kwargs)[source]#
Mixin class for VQ-VAE.
https://arxiv.org/abs/1711.00937
- Parameters:
  embedding_dim (int) – Dimensionality of the codebook embedding vectors.
  num_embeddings (int) – Number of embedding vectors in the codebook.
  commitment_cost (float) – Weight \(\beta\) of the commitment loss.
- encode(x)[source]#
- Parameters:
  x (torch.Tensor) – The input data.
- Returns:
  z_e (torch.Tensor) – The encoder outputs.
- quantize(inputs)[source]#
Quantize the input vectors.
- Parameters:
  inputs (torch.Tensor) – The encoder outputs.
- Returns:
  z_q (torch.Tensor) – The quantized vectors of the inputs.
  z (torch.Tensor) – The discrete latent variables.
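A sketch of nearest-neighbour quantization with the straight-through gradient estimator; the explicit `codebook` argument is an assumption (the mixin presumably holds the embedding table internally):

```python
import torch

def quantize(inputs, codebook):
    # inputs: (..., D); codebook: (K, D) embedding table.
    flat = inputs.reshape(-1, codebook.size(1))
    # Nearest codebook entry per input vector.
    dist = torch.cdist(flat, codebook)
    z = dist.argmin(dim=1)
    z_q = codebook[z].reshape(inputs.shape)
    # Straight-through estimator: the forward value is the codebook
    # vector, but gradients flow back to the encoder unchanged.
    z_q = inputs + (z_q - inputs).detach()
    return z_q, z.reshape(inputs.shape[:-1])
```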
- vq_loss(z_e, z_q, **kwargs)[source]#
Compute the vector quantization loss terms.
- Parameters:
  z_e (torch.Tensor) – The encoder outputs.
  z_q (torch.Tensor) – The quantized vectors.
- Returns:
  codebook_loss (torch.Tensor) – The codebook loss.
  commitment_loss (torch.Tensor) – The commitment loss.
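A sketch of the two VQ loss terms, using stop-gradients as in the VQ-VAE paper (illustrative function name):

```python
import torch
import torch.nn.functional as F

def vq_loss(z_e, z_q, commitment_cost=0.25):
    # Codebook loss: pull embeddings toward the (frozen) encoder outputs.
    codebook_loss = F.mse_loss(z_q, z_e.detach())
    # Commitment loss: pull the encoder toward its (frozen) codebook match.
    commitment_loss = commitment_cost * F.mse_loss(z_e, z_q.detach())
    return codebook_loss, commitment_loss
```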