VAE: Variational Autoencoder

22 Nov 2022

Variational Inference: A Review for Statisticians

Blei, David M and Kucukelbir, Alp and McAuliffe, Jon D

Journal of the American Statistical Association, 2017

Summary:

The motivation for the VAE comes from improving the quality of the latent representation (the latent space, z) of the traditional autoencoder, so that we can draw random samples from the latent space and pass them to the decoder to generate images similar to those in the training dataset. The traditional autoencoder struggles to achieve this because its latent space contains many gaps and overlapping characteristics, so sampling inside a latent gap does not produce a meaningful generation in image space. The VAE alleviates this problem by regularizing the latent space toward a standard Gaussian distribution.

For each input, rather than encoding it directly into the latent space as an autoencoder does, the VAE uses a neural network to predict a particular Gaussian distribution for that input, parameterized by a mean and a variance. The VAE then samples a latent variable from that Gaussian; this sample is the output of the encoder and is passed to a decoder network to produce the final reconstruction. To make the reconstruction match the desired output while regularizing the latent space toward a standard normal distribution, the loss function contains two parts: a reconstruction loss and a KL-divergence regularization loss. By training this network on the whole dataset with gradient descent, we expect to obtain an optimized latent space that follows a standard Gaussian distribution.

The reasoning is that p(z) can be written as the marginalization of p(z|x) over the data, p(z) = ∫ p(z|x) p(x) dx. Since p(x) integrates to 1, if every p(z|x) is close to the standard Gaussian, then p(z) will also be close to a standard Gaussian. Because the latent space fits a standard normal distribution, we expect it to be a continuous space rather than a discrete one: to generate new images, we can simply draw a standard Gaussian sample and decode it. And since the per-input distributions overlap near the center of the normal distribution, we can do characteristic mixing in those overlapping areas.
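As a concrete illustration, here is a minimal PyTorch sketch of this pipeline; the MLP architecture, layer sizes, and Bernoulli (binary cross-entropy) reconstruction term are my own illustrative assumptions rather than anything specified above. The sampling step is written with the standard reparameterization trick, z = mu + sigma * eps, so that gradient descent can flow through the encoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    """Minimal VAE sketch; the MLP sizes are illustrative assumptions."""

    def __init__(self, x_dim=784, h_dim=400, z_dim=20):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.enc_mu = nn.Linear(h_dim, z_dim)      # predicted mean of q(z|x)
        self.enc_logvar = nn.Linear(h_dim, z_dim)  # predicted log-variance of q(z|x)
        self.dec = nn.Sequential(
            nn.Linear(z_dim, h_dim), nn.ReLU(),
            nn.Linear(h_dim, x_dim), nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # so the sampling step stays differentiable for gradient descent.
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)
        return self.dec(z), mu, logvar


def vae_loss(x_hat, x, mu, logvar):
    # Reconstruction term (Bernoulli/BCE likelihood over pixels, an assumption here).
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ): pulls every q(z|x) toward N(0, I).
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kl
```

At generation time, sampling is just `model.dec(torch.randn(n, 20))`; training amounts to calling `vae_loss` on each batch and backpropagating.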

Adversarial generation is another insightful angle from which to rethink the VAE training process, in other words the relationship between the variance-predicting encoder and the VAE decoder. Without the regularization term, the encoder might encode similar inputs to distant places in the latent space, since the loss function only rewards better reconstruction; this is effectively equivalent to adding a regularization term with zero variance. Especially at the beginning of training, the decoder has not learned well yet, which pushes the encoder to predict a smaller variance to make the reconstruction loss easier. However, pushing too far in this direction is limited by the regularization on the encoder side. Later in training, the decoder is good enough to make the reconstruction loss small; at that point the regularization loss becomes the dominant part of the objective and pushes the predicted variance back toward 1, thereby forcing the decoder to learn better. This dynamic is similar to the relationship between the discriminator and generator networks in a GAN, but the VAE achieves it in a more implicit way.
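This push-and-pull can be read directly off the closed-form KL term used in the loss above (the standard expression for the KL divergence between a diagonal Gaussian and the standard normal, written here per latent dimension):

```latex
% KL between the encoder's Gaussian q(z|x) = N(mu, sigma^2) and the prior N(0, 1)
\mathrm{KL}\!\left(\mathcal{N}(\mu,\sigma^{2}) \,\|\, \mathcal{N}(0,1)\right)
  = \tfrac{1}{2}\left(\sigma^{2} + \mu^{2} - 1 - \log\sigma^{2}\right)
```

The expression is minimized at mu = 0, sigma = 1, and the -log sigma^2 term blows up as sigma shrinks toward 0; so once the reconstruction loss is small, this penalty dominates and drags the predicted variance back toward 1, which is exactly the dynamic described above.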

Strength:

Compared with the autoencoder, the VAE has a more continuous and interpretable latent space to sample from. This is due to the overall standard-normal constraint on p(z), and to the fact that similar images tend to produce similar means and variances (and thus tend to cluster together). Its training process is relatively simple compared with GAN methods, since convergence is controlled automatically by the model itself rather than by the hyperparameters a GAN requires. Compared with the diffusion process, generation is fast: it takes only a single inference pass through the decoder.
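To make the sampling and speed points concrete: generation is a single decoder forward pass over a standard-normal draw, and the continuity of the latent space is what lets linear interpolation between two codes produce plausible in-between images. The snippet below reuses the illustrative VAE sketch from the summary and assumes it has already been trained.

```python
import torch

model = VAE()  # illustrative sketch from the summary above; weights assumed trained
model.eval()

with torch.no_grad():
    # Generation: one standard-Gaussian draw, one decoder forward pass per image.
    z = torch.randn(16, 20)              # 20 = z_dim used in the sketch
    samples = model.dec(z)

    # Characteristic mixing: linearly interpolate between two latent codes.
    z0, z1 = torch.randn(20), torch.randn(20)
    alphas = torch.linspace(0.0, 1.0, steps=8).unsqueeze(1)   # shape (8, 1)
    z_path = (1 - alphas) * z0 + alphas * z1                  # 8 points from z0 to z1
    blended = model.dec(z_path)
```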

Critique:

The VAE regularization term actually degrades image quality, since it competes with the reconstruction loss. There seems to be an inherent tradeoff between image quality and latent-space quality, which appears to differ from GAN and diffusion methods.
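One common way to make this tradeoff explicit (a standard practice beyond what this summary discusses, e.g. the beta-VAE family) is to put a tunable weight on the KL term:

```latex
\mathcal{L}(x) \;=\;
  \underbrace{\mathbb{E}_{q(z\mid x)}\!\left[-\log p(x\mid z)\right]}_{\text{reconstruction}}
  \;+\; \beta \,\underbrace{\mathrm{KL}\!\left(q(z\mid x)\,\|\,p(z)\right)}_{\text{latent regularization}}
```

With beta = 1 this is the usual VAE objective; raising beta trades reconstruction fidelity for a smoother latent space, and lowering it does the reverse, which is precisely the competition described above.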