Source code for vae_mixin

"""Mixin classes for VAE variants."""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class VAEMixin:
    """Mixin class for basic VAE.

    https://arxiv.org/abs/1312.6114
    """

    @staticmethod
    def reparameterize(mu: Tensor, logvar: Tensor) -> Tensor:
        """Reparameterization trick."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def encode(self, x: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            The input data.

        Returns
        -------
        z : torch.Tensor
            The latent variables if training, otherwise the mean.
        """
        mu, logvar = self.encoder(x)
        self.mu_, self.logvar_ = mu, logvar
        if self.training:
            z = self.reparameterize(mu, logvar)
            return z
        return mu

    def decode(self, z: Tensor) -> Tensor:
        """
        Parameters
        ----------
        z : torch.Tensor
            The latent variables.

        Returns
        -------
        x_rec : torch.Tensor
            The reconstructed data.
        """
        return self.decoder(z)

    def rec_loss(self, x_rec: Tensor, x: Tensor, **kwargs) -> Tensor:
        """Compute the reconstruction loss when the output :math:`p(x|z)` is
        Bernoulli.

        Parameters
        ----------
        x_rec : torch.Tensor
            The reconstructed data.
        x : torch.Tensor
            The input data.

        Returns
        -------
        bce : torch.Tensor
            The binary cross entropy.
        """
        bce = F.binary_cross_entropy(x_rec, x, reduction="none").sum(dim=1).mean()
        self.loss_rec_ = bce
        return bce

    def kl_loss(
        self,
        mu: Optional[Tensor] = None,
        logvar: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        r"""Compute the KL divergence loss when the prior
        :math:`p(z)=\mathcal{N}(0,\mathbf{I})` and the posterior approximation
        :math:`q(z|x)` are Gaussian.

        Parameters
        ----------
        mu : torch.Tensor, optional
            Encoder output :math:`\mu`.
        logvar : torch.Tensor, optional
            Encoder output :math:`\log\sigma^2`.

        Returns
        -------
        kld : torch.Tensor
            The KL divergence.
        """
        mu = self.mu_ if mu is None else mu
        logvar = self.logvar_ if logvar is None else logvar
        var = torch.exp(logvar)
        kld = -0.5 * (1 + logvar - mu**2 - var).sum(dim=1).mean()
        self.loss_kl_ = kld
        return kld

    def loss(self, **kwargs) -> Tensor:
        """Compute the VAE objective function.

        Parameters
        ----------
        **kwargs
            Extra arguments to loss terms.

        Returns
        -------
        torch.Tensor
        """
        return self.rec_loss(**kwargs) + self.kl_loss(**kwargs)
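

# --- Illustrative usage sketch (not part of the original module) ---
# ``VAEMixin`` assumes the host class provides ``self.encoder`` returning
# ``(mu, logvar)`` and ``self.decoder`` returning values in ``[0, 1]`` so the
# Bernoulli reconstruction loss is well defined. The tiny MLP below and the
# ``_ExampleVAE`` name are arbitrary choices made only for this sketch.
def _example_vae_usage() -> None:
    class _ExampleEncoder(nn.Module):
        def __init__(self, in_dim: int, latent_dim: int) -> None:
            super().__init__()
            self.hidden = nn.Linear(in_dim, 64)
            self.mu = nn.Linear(64, latent_dim)
            self.logvar = nn.Linear(64, latent_dim)

        def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
            h = torch.relu(self.hidden(x))
            return self.mu(h), self.logvar(h)

    class _ExampleVAE(VAEMixin, nn.Module):
        def __init__(self, in_dim: int = 784, latent_dim: int = 8) -> None:
            super().__init__()
            self.encoder = _ExampleEncoder(in_dim, latent_dim)
            self.decoder = nn.Sequential(
                nn.Linear(latent_dim, 64),
                nn.ReLU(),
                nn.Linear(64, in_dim),
                nn.Sigmoid(),
            )

        def forward(self, x: Tensor) -> Tensor:
            return self.decode(self.encode(x))

    model = _ExampleVAE()
    x = torch.rand(16, 784)              # toy batch scaled to [0, 1]
    x_rec = model(x)                     # encode -> reparameterize -> decode
    elbo = model.loss(x_rec=x_rec, x=x)  # rec_loss + kl_loss
    elbo.backward()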


class BetaVAEMixin:
    r"""Mixin class for :math:`\beta`-VAE.

    https://openreview.net/forum?id=Sy2fzU9gl

    Parameters
    ----------
    beta : float, default 1
        Regularisation coefficient :math:`\beta`.
    """

    def __init__(self, beta: float = 1, **kwargs) -> None:
        self.beta = beta
        super().__init__(**kwargs)

    def loss(self, **kwargs) -> Tensor:
        r"""Compute the :math:`\beta`-VAE objective function.

        Parameters
        ----------
        **kwargs
            Extra arguments to loss terms.

        Returns
        -------
        torch.Tensor
        """
        return self.rec_loss(**kwargs) + self.beta * self.kl_loss(**kwargs)
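

# --- Illustrative sketch (not part of the original module) ---
# ``BetaVAEMixin`` only overrides ``loss``, so it is combined with
# ``VAEMixin`` (or a full model built on it) with the beta mixin placed first
# in the MRO. The toy tensors below stand in for encoder/decoder outputs; the
# ``_BetaVAELosses`` name is hypothetical.
def _example_beta_vae_loss() -> None:
    class _BetaVAELosses(BetaVAEMixin, VAEMixin):
        pass

    losses = _BetaVAELosses(beta=4.0)   # beta is consumed by BetaVAEMixin.__init__
    x = torch.rand(16, 10)              # "data" in [0, 1]
    x_rec = torch.rand(16, 10)          # "reconstruction" in [0, 1]
    mu, logvar = torch.randn(16, 4), torch.randn(16, 4)
    total = losses.loss(x_rec=x_rec, x=x, mu=mu, logvar=logvar)
    # total == losses.loss_rec_ + 4.0 * losses.loss_kl_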


class InfoVAEMixin:
    r"""Mixin class for InfoVAE.

    https://arxiv.org/abs/1706.02262

    Parameters
    ----------
    lamb : float, default 1
        Scaling coefficient :math:`\lambda`.
    alpha : float, default 0
        Information preference :math:`\alpha`.
    """

    def __init__(self, lamb: float = 1, alpha: float = 0, **kwargs) -> None:
        self.lamb = lamb
        self.alpha = alpha
        super().__init__(**kwargs)

    @staticmethod
    def kernel(x1: Tensor, x2: Tensor) -> Tensor:
        """Compute the RBF kernel."""
        n1 = x1.size(0)
        n2 = x2.size(0)
        dim = x1.size(1)
        x1 = x1.view(n1, 1, dim)
        x2 = x2.view(1, n2, dim)
        return torch.exp(-1.0 / dim * (x1 - x2).pow(2).sum(dim=2))

    def mmd_loss(self, z: Tensor, **kwargs) -> Tensor:
        """Compute the MMD loss of :math:`q(z)` and :math:`p(z)`.

        Parameters
        ----------
        z : torch.Tensor
            The latent variables.

        Returns
        -------
        mmd : torch.Tensor
            The squared maximum mean discrepancy.
        """
        qz = z
        pz = torch.randn_like(z)
        mmd = (
            self.kernel(pz, pz).mean()
            - 2 * self.kernel(qz, pz).mean()
            + self.kernel(qz, qz).mean()
        )
        self.loss_mmd_ = mmd
        return mmd

    def loss(self, **kwargs) -> Tensor:
        """Compute the InfoVAE objective function.

        Parameters
        ----------
        **kwargs
            Extra arguments to loss terms.

        Returns
        -------
        torch.Tensor
        """
        return (
            self.rec_loss(**kwargs)
            + (1 - self.alpha) * self.kl_loss(**kwargs)
            + (self.alpha + self.lamb - 1) * self.mmd_loss(**kwargs)
        )
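

# --- Illustrative sketch (not part of the original module) ---
# ``mmd_loss`` compares latent samples against draws from the standard-normal
# prior using the RBF kernel above, so it can be exercised directly on toy
# latents without an encoder.
def _example_info_vae_mmd() -> None:
    info = InfoVAEMixin(lamb=10.0, alpha=0.0)
    z_close = torch.randn(256, 8)          # already distributed like the prior
    z_far = torch.randn(256, 8) + 3.0      # shifted away from the prior
    small_mmd = info.mmd_loss(z=z_close)   # close to zero
    large_mmd = info.mmd_loss(z=z_far)     # clearly positive
    assert large_mmd > small_mmd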


class DIPVAEMixin:
    r"""Mixin class for DIP-VAE.

    https://arxiv.org/abs/1711.00848

    Parameters
    ----------
    lambda_od : float, default 10
        Hyperparameter :math:`\lambda_{od}` for the loss on the off-diagonal
        entries.
    lambda_d : float, default 10
        Hyperparameter :math:`\lambda_d` for the loss on the diagonal entries.
    dip : int, 1 or 2
        DIP type.
    """

    def __init__(
        self,
        lambda_od: float = 10,
        lambda_d: float = 10,
        dip: int = 1,
        **kwargs,
    ) -> None:
        self.lambda_od = lambda_od
        self.lambda_d = lambda_d
        self.dip = dip
        super().__init__(**kwargs)

    def dip_regularizer(
        self,
        mu: Optional[Tensor] = None,
        logvar: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor]:
        r"""Compute the DIP regularization terms when :math:`q(z|x)` is
        Gaussian.

        Parameters
        ----------
        mu : torch.Tensor, optional
            Encoder output :math:`\mu`.
        logvar : torch.Tensor, optional
            Encoder output :math:`\log\sigma^2`.

        Returns
        -------
        off_diag_loss : torch.Tensor
            The off-diagonal loss.
        diag_loss : torch.Tensor
            The diagonal loss.
        """
        mu = self.mu_ if mu is None else mu
        logvar = self.logvar_ if logvar is None else logvar
        mu_mean = mu - mu.mean(dim=0)
        cov_mu = mu_mean.t() @ mu_mean / mu.size(0)
        if self.dip == 1:
            cov = cov_mu
        elif self.dip == 2:
            var = torch.diag_embed(torch.exp(logvar))
            cov = var.mean(dim=0) + cov_mu
        else:
            raise ValueError(f"dip must be 1 or 2, got {self.dip}")
        cov_diag = torch.diag(cov)
        cov_off_diag = cov - torch.diag(cov_diag)
        off_diag_loss = cov_off_diag.pow(2).sum()
        diag_loss = (cov_diag - 1).pow(2).sum()
        self.loss_off_diag_, self.loss_diag_ = off_diag_loss, diag_loss
        return off_diag_loss, diag_loss

    def loss(self, **kwargs) -> Tensor:
        """Compute the DIP-VAE objective function.

        Parameters
        ----------
        **kwargs
            Extra arguments to loss terms.

        Returns
        -------
        torch.Tensor
        """
        off_diag_loss, diag_loss = self.dip_regularizer(**kwargs)
        return (
            self.rec_loss(**kwargs)
            + self.kl_loss(**kwargs)
            + self.lambda_od * off_diag_loss
            + self.lambda_d * diag_loss
        )
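

# --- Illustrative sketch (not part of the original module) ---
# ``dip_regularizer`` only needs the encoder statistics, so it can be probed
# with toy ``mu``/``logvar`` tensors: DIP-VAE-I (``dip=1``) penalises the
# covariance of ``mu`` alone, DIP-VAE-II (``dip=2``) the covariance of q(z).
def _example_dip_regularizer() -> None:
    dip = DIPVAEMixin(lambda_od=10.0, lambda_d=10.0, dip=2)
    mu = torch.randn(128, 6)
    logvar = torch.zeros(128, 6)           # unit variances
    off_diag_loss, diag_loss = dip.dip_regularizer(mu=mu, logvar=logvar)
    # off_diag_loss drives latent dimensions towards decorrelation,
    # diag_loss drives their variances towards 1.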


class BetaTCVAEMixin:
    r"""Mixin class for :math:`\beta`-TCVAE.

    https://arxiv.org/abs/1802.04942

    Parameters
    ----------
    alpha : float, default 1
        Weight :math:`\alpha` of the index-code mutual information.
    beta : float, default 1
        Weight :math:`\beta` of the total correlation.
    gamma : float, default 1
        Weight :math:`\gamma` of the dimension-wise KL.
    """

    def __init__(
        self,
        alpha: float = 1,
        beta: float = 1,
        gamma: float = 1,
        **kwargs,
    ) -> None:
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        super().__init__(**kwargs)

    @staticmethod
    def gaussian_log_density(x: Tensor, mu: Tensor, logvar: Tensor) -> Tensor:
        """Compute the log density of a Gaussian."""
        log2pi = math.log(2 * math.pi)
        inv_var = torch.exp(-logvar)
        return -0.5 * (log2pi + logvar + (x - mu).pow(2) * inv_var)

    def decompose_kl(
        self,
        n: int,
        z: Tensor,
        mu: Optional[Tensor] = None,
        logvar: Optional[Tensor] = None,
        **kwargs,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Compute the decomposed KL terms.

        Parameters
        ----------
        n : int
            Dataset size :math:`N`.
        z : torch.Tensor
            The latent variables.
        mu : torch.Tensor, optional
            Encoder output :math:`\mu`.
        logvar : torch.Tensor, optional
            Encoder output :math:`\log\sigma^2`.

        Returns
        -------
        mi_loss : torch.Tensor
            The index-code mutual information.
        tc_loss : torch.Tensor
            The total correlation.
        kl_loss : torch.Tensor
            The dimension-wise KL.
        """
        mu = self.mu_ if mu is None else mu
        logvar = self.logvar_ if logvar is None else logvar
        m, dim = z.shape

        logqz_x = self.gaussian_log_density(z, mu, logvar).sum(dim=1)
        zeros = torch.zeros_like(z)
        logpz = self.gaussian_log_density(z, zeros, zeros).sum(dim=1)

        # minibatch weighted sampling
        logqzxi_xj = self.gaussian_log_density(
            z.view(m, 1, dim), mu.view(1, m, dim), logvar.view(1, m, dim)
        )
        logqz = torch.logsumexp(logqzxi_xj.sum(dim=2), dim=1) - math.log(n * m)
        logprodqz = (torch.logsumexp(logqzxi_xj, dim=1) - math.log(n * m)).sum(dim=1)

        mi_loss = (logqz_x - logqz).mean()
        tc_loss = (logqz - logprodqz).mean()
        kl_loss = (logprodqz - logpz).mean()
        self.loss_mi_, self.loss_tc_, self.loss_kl_ = mi_loss, tc_loss, kl_loss
        return mi_loss, tc_loss, kl_loss

    def loss(self, **kwargs) -> Tensor:
        r"""Compute the :math:`\beta`-TCVAE objective function.

        Parameters
        ----------
        **kwargs
            Extra arguments to loss terms.

        Returns
        -------
        torch.Tensor
        """
        mi_loss, tc_loss, kl_loss = self.decompose_kl(**kwargs)
        return (
            self.rec_loss(**kwargs)
            + self.alpha * mi_loss
            + self.beta * tc_loss
            + self.gamma * kl_loss
        )
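

# --- Illustrative sketch (not part of the original module) ---
# The three decomposed terms sum back to the Monte Carlo estimate of
# KL(q(z|x) || p(z)) = E_q[log q(z|x) - log p(z)], which gives a quick sanity
# check on toy tensors. ``n`` is the dataset size used by the minibatch
# weighted sampling estimator; the value below is arbitrary.
def _example_decomposed_kl() -> None:
    btc = BetaTCVAEMixin(alpha=1.0, beta=6.0, gamma=1.0)
    mu, logvar = torch.randn(64, 5), torch.randn(64, 5)
    z = VAEMixin.reparameterize(mu, logvar)
    mi_loss, tc_loss, kl_loss = btc.decompose_kl(n=50_000, z=z, mu=mu, logvar=logvar)
    zeros = torch.zeros_like(z)
    mc_kl = (
        btc.gaussian_log_density(z, mu, logvar)
        - btc.gaussian_log_density(z, zeros, zeros)
    ).sum(dim=1).mean()
    # equal up to float32 rounding
    assert torch.allclose(mi_loss + tc_loss + kl_loss, mc_kl, atol=1e-3)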


class VQVAEMixin:
    r"""Mixin class for VQ-VAE.

    https://arxiv.org/abs/1711.00937

    Parameters
    ----------
    embedding_dim : int
        The dimensionality of the latent embedding vectors.
    num_embeddings : int
        The size of the discrete latent space.
    commitment_cost : float
        Scalar :math:`\beta` which controls the weighting of the commitment
        loss.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: int,
        commitment_cost: float,
        **kwargs,
    ) -> None:
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost
        super().__init__(**kwargs)
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)

    def encode(self, x: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            The input data.

        Returns
        -------
        z_e : torch.Tensor
            The encoder outputs.
        """
        z_e = self.encoder(x)
        return z_e

    def quantize(self, inputs: Tensor) -> Tuple[Tensor, Tensor]:
        """Quantize the input vectors.

        Parameters
        ----------
        inputs : torch.Tensor
            The encoder outputs.

        Returns
        -------
        z_q : torch.Tensor
            The quantized vectors of the inputs.
        z : torch.Tensor
            The discrete latent variables.
        """
        assert inputs.size(-1) == self.embedding_dim
        flat_inputs = inputs.reshape(-1, self.embedding_dim)
        distances = (
            (flat_inputs**2).sum(dim=1, keepdim=True)
            - 2 * flat_inputs @ self.embeddings.weight.t()
            + (self.embeddings.weight**2).sum(dim=1)
        )
        encoding_indices = distances.argmin(dim=1)
        encoding_indices = encoding_indices.reshape(inputs.shape[:-1])
        quantized = self.embeddings(encoding_indices)
        self.loss_codebook_, self.loss_commitment_ = self.vq_loss(inputs, quantized)
        # straight-through estimator
        quantized = inputs + (quantized - inputs).detach()
        return quantized, encoding_indices

    def vq_loss(self, z_e: Tensor, z_q: Tensor, **kwargs) -> Tuple[Tensor, Tensor]:
        """Compute the vector quantization loss terms.

        Parameters
        ----------
        z_e : torch.Tensor
            The encoder outputs.
        z_q : torch.Tensor
            The quantized vectors.

        Returns
        -------
        codebook_loss : torch.Tensor
        commitment_loss : torch.Tensor
        """
        codebook_loss = (z_e.detach() - z_q).pow(2).sum(dim=-1).mean()
        commitment_loss = (z_e - z_q.detach()).pow(2).sum(dim=-1).mean()
        return codebook_loss, commitment_loss

    def loss(self, **kwargs) -> Tensor:
        """Compute the VQ-VAE objective function.

        Parameters
        ----------
        **kwargs
            Extra arguments to loss terms.

        Returns
        -------
        torch.Tensor
        """
        beta = self.commitment_cost
        return (
            self.rec_loss(**kwargs)
            + self.loss_codebook_
            + beta * self.loss_commitment_
        )
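

# --- Illustrative sketch (not part of the original module) ---
# ``quantize`` snaps encoder outputs to their nearest codebook vectors and
# records the codebook/commitment losses, while the straight-through
# estimator passes gradients back to the encoder outputs unchanged. The
# feature-map shape below is an arbitrary choice for this sketch.
def _example_vector_quantization() -> None:
    vq = VQVAEMixin(embedding_dim=16, num_embeddings=64, commitment_cost=0.25)
    z_e = torch.randn(8, 4, 4, 16, requires_grad=True)  # channels-last feature map
    z_q, indices = vq.quantize(z_e)
    assert z_q.shape == z_e.shape and indices.shape == (8, 4, 4)
    z_q.sum().backward()
    assert torch.allclose(z_e.grad, torch.ones_like(z_e))  # straight-through gradient
    vq_terms = vq.loss_codebook_ + vq.commitment_cost * vq.loss_commitment_
    # vq_terms is what ``loss`` adds on top of the reconstruction loss.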