VQ-VAE

23 Nov 2022

VQ-VAE: Neural Discrete Representation Learning

Van den Oord, Aaron; Vinyals, Oriol; and others

NIPS 2017

Summary:

VQ-VAE aims to provide a discrete latent representation space, rather than the continuous one of a standard VAE, so that it can better model inherently discrete modalities such as language, images, and reasoning. To be honest, I don’t see much connection between VQ-VAE and VAE models beyond the fact that they share a high-level idea of regularizing the latent space; from my point of view, VQ-VAE is more like a data-compression tool. The goal of VQ-VAE is not to provide a ‘discrete’ version of the VAE with an interpretable, gap-free latent space; rather, it is designed to provide a discrete and compressed representation.

The encoder of VQ-VAE is a CNN that converts the image-space input x into a 3D tensor z_e(x) of shape W×H×D. VQ-VAE then maps each 1×D vector to whichever of the K 1×D embedding vectors has the smallest L2 distance to it. That K×D embedding space is called the codebook, and it can be treated as a collection of discrete tokens that serve as the building blocks of the decoder’s input. Through this mapping, the W×H×D encoder output z_e(x) is converted into a W×H×D tensor z_q(x) whose elements all come from the K×D codebook. The latent representation can therefore be compressed to a W×H grid of indices, each pointing to the closest codebook vector, which costs only W·H·log2(K) bits. Finally, z_q(x) is passed to the decoder to reconstruct the desired output.
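A minimal sketch of this nearest-neighbour lookup, assuming PyTorch (the tensor names z_e and codebook are mine, not from the paper):

```python
import torch

def quantize(z_e, codebook):
    """Map each D-dim encoder vector to its nearest codebook entry.

    z_e:      (B, W, H, D) encoder output
    codebook: (K, D) learnable embedding vectors
    returns:  indices (B, W, H) and quantized tensor z_q (B, W, H, D)
    """
    B, W, H, D = z_e.shape
    flat = z_e.reshape(-1, D)                         # (B*W*H, D)
    # squared L2 distance between every encoder vector and every code
    dist = (flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ codebook.t()
            + codebook.pow(2).sum(1))                 # (B*W*H, K)
    indices = dist.argmin(dim=1)                      # nearest code id
    z_q = codebook[indices].reshape(B, W, H, D)       # look up codes
    return indices.reshape(B, W, H), z_q
```

The W×H index map alone is enough to recover z_q from the codebook, which is where the W·H·log2(K)-bit compression figure comes from.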

The training process of VQ-VAE is also worth learning from. Because the nearest-neighbour lookup is not differentiable, the authors rely on a stop-gradient operator sg[·], which acts as the identity in the forward pass and has zero gradient in the backward pass, so the network does not update whatever tensor it wraps. Since z_e(x) and z_q(x) have the same shape and are assumed to be close, the network simply copies the gradient of z_q(x) onto z_e(x) (a straight-through estimator), so the encoder can still be updated to lower the reconstruction loss. The loss function of VQ-VAE is also quite novel. As in a VAE, the first term is a reconstruction loss, log p(x|z_q(x)); because the gradient copy skips the codebook, this term only affects the encoder and decoder. To make the codebook better represent the data, a second term ||sg[z_e(x)] − e||² stops the gradient of z_e(x) and pulls the selected codebook vector e toward z_e(x). Also, since the range of z_e(x) is not restricted, to make sure it stays close to values the codebook can match, a weighted third term β||z_e(x) − sg[e]||² stops the gradient of the codebook and pulls z_e(x) toward its code. We can imagine that without this third term the network would be unstable early in training if z_e(x) predicted large values that the codebook could not easily match. The scalar β balances this encoder commitment loss against the reconstruction loss, helping z_e(x) trade off reconstructing a better result against committing to the latent space.
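A sketch of how the straight-through trick and the three loss terms fit together in one training step (again assuming PyTorch; the MSE reconstruction loss and β value are illustrative choices, and `quantize` is the lookup sketched above):

```python
import torch
import torch.nn.functional as F

def vqvae_step(x, encoder, decoder, codebook, beta=0.25):
    """One VQ-VAE training step: forward pass plus the three-term loss."""
    z_e = encoder(x)                                  # (B, W, H, D)
    _, z_q = quantize(z_e, codebook)

    # Straight-through estimator: the forward pass uses z_q, but the
    # backward pass copies the decoder's gradient from z_q onto z_e,
    # because the detached difference contributes no gradient.
    z_q_st = z_e + (z_q - z_e).detach()
    x_recon = decoder(z_q_st)

    recon_loss = F.mse_loss(x_recon, x)               # trains encoder + decoder
    codebook_loss = F.mse_loss(z_q, z_e.detach())     # pulls codes toward z_e
    commit_loss = F.mse_loss(z_e, z_q.detach())       # keeps z_e near its code
    return recon_loss + codebook_loss + beta * commit_loss
```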

During the generation stage, we can’t directly sample a W×H grid of latents at random: under such a sampling scheme the latent codes would be independent of each other, the spatial relationship is only what the W×H shape imposes, and the sampling distribution carries no code-level semantic information. To bring that sequential structure back, we need to train another autoregressive model (PixelCNN for images, WaveNet for raw audio) on the latent codes of a large encoded training set. The benefit of doing so is that the autoregressive model operates on a much shorter sequence than the raw pixels or audio samples.
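A sketch of the resulting two-stage generation, assuming a PixelCNN-style prior with a hypothetical `sample()` interface (not an API from the paper):

```python
import torch

@torch.no_grad()
def generate(prior, codebook, decoder, batch_size=1):
    """Two-stage sampling: autoregressive prior over code indices, then decode.

    `prior.sample` stands for a hypothetical method returning a (B, W, H)
    grid of codebook indices sampled one position at a time, e.g. from a
    PixelCNN trained on the encoded training set.
    """
    indices = prior.sample(batch_size)        # (B, W, H) discrete codes
    z_q = codebook[indices]                   # (B, W, H, D) embeddings
    return decoder(z_q)                       # samples in pixel space
```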

If we think more deeply about the meaning of the codebook: it comes from convolutions over the input images, so each code carries a representation of a local part of the image. These elements might correspond to recurring image patterns; we might find that some code captures a ‘dog’ element and another a ‘cat’ element. So rather than representing a dog-on-grass image at the pixel level, we can represent it with a ‘dog’ latent code and a ‘grassland’ latent code, which drastically shortens the otherwise long image sequence. VQ-VAE is effectively doing k-means so that z_e(x) and the codes e form such element clusters, while also pulling z_e(x) toward those clusters, which can be thought of as a form of regularization.
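To make the k-means analogy concrete: with the assignments held fixed, the codebook loss ||sg[z_e(x)] − e||² is exactly the k-means objective, and its minimizer moves each code to the mean of the encoder vectors assigned to it. A rough NumPy illustration of that view (illustrative only, not the paper’s actual update rule):

```python
import numpy as np

def kmeans_style_codebook_update(z_e_flat, indices, codebook, lr=1.0):
    """For fixed assignments, minimizing sum ||z_e - e_k||^2 over the codes
    sets each code to the mean of its assigned encoder vectors; lr=1.0
    recovers that full k-means centroid step, smaller lr a partial one."""
    new_codebook = codebook.copy()
    for k in range(codebook.shape[0]):
        assigned = z_e_flat[indices == k]
        if len(assigned) > 0:
            new_codebook[k] = (1 - lr) * codebook[k] + lr * assigned.mean(axis=0)
    return new_codebook
```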

Strength:

Compresses the image into a latent-code embedding rather than a pixel-space embedding. This makes sequence modelling over images much more tractable and much faster. It would be great if they could share the codebook, so that we would have an alternative representation to reuse in later image applications.

Critique:

Resolution is limited and blur can appear if the codebook size K is not large enough. However, if K is made very large, we lose the point of applying VQ-VAE for compression. It would be great if K could be set as a trainable parameter.