Attention Is All You Need
Ashish Vaswani et al., 31st Conference on Neural Information Processing Systems (NIPS 2017)
Summary:
This paper proposes the Transformer, now a famous neural machine translation architecture. It replaces the conventional RNN structure with scaled dot-product attention and positional encoding, which lowers the computational cost and allows much better parallelization during training. Unlike the multi-layer RNNs of previous work, the Transformer encoder consists of a stack of identical layers, each made of a self-attention sublayer followed by a fully connected (FC) feed-forward sublayer.
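To make the core operation concrete, here is a minimal sketch of scaled dot-product attention, softmax(QKᵀ/√d_k)·V, in plain Python. It is a single head working on lists of row vectors; the learned projection matrices and the multiple heads used in the paper are left out, and the function names are my own, chosen for illustration only.

```python
import math
from typing import List

Matrix = List[List[float]]  # a list of row vectors


def softmax(row: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Return softmax(Q K^T / sqrt(d_k)) V, one output row per query row."""
    d_k = len(k[0])
    out: Matrix = []
    for q_row in q:
        # One score per key: the dot product with the query, scaled by 1/sqrt(d_k).
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in k]
        weights = softmax(scores)
        # The output is the attention-weighted sum of the value vectors.
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out


# Self-attention in its simplest form: queries, keys and values are all the word vectors.
words = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(scaled_dot_product_attention(words, words, words))
```

Because each row of weights comes from a softmax, every output row is a convex combination of the value vectors; when all keys are identical the weights are uniform and the output is simply the average of the values.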
For the encoder stack, the intuition behind self-attention comes from the coreference phenomenon in language. To capture the meaning of a word in a given sentence, one can build a better, context-aware representation of it by taking a weighted combination of the representations of all the words in the sentence, where the weights come from softmax-normalized, scaled dot products between that word and every word in the sentence. After a self-attention representation has been computed for every word, each word representation is fed separately through the same position-wise feed-forward FC layer to obtain a richer non-linear representation. Note that these word-wise dot products carry no information about word order. To make up for the lost sequence information, the encoder first adds a positional encoding to each input word embedding (the two vectors have the same dimension, so they are summed rather than concatenated). The positional encoding is made of sine and cosine functions of different frequencies evaluated at the word's absolute position in the sentence. (My intuition: if the raw position index were used instead of a bounded sinusoid, the encoding would grow with position, and the last words of a sentence would tend to dominate the dot products in the self-attention operation.) Besides, every sublayer is wrapped in a residual connection followed by layer normalization; in my reading, this lets a layer stay close to the identity mapping when it is not needed, so the model can effectively settle on how many of the stacked layers it actually uses.
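The sinusoidal positional encoding can be sketched as follows, using the paper's definition PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The helper names are hypothetical; the point is that the encoding has the same width as the embedding, is added to it element-wise, and stays within [-1, 1] however long the sentence is.

```python
import math
from typing import List

Matrix = List[List[float]]


def positional_encoding(n_positions: int, d_model: int) -> Matrix:
    """Sinusoidal encoding: sine on even dimensions, cosine on odd dimensions."""
    pe: Matrix = []
    for pos in range(n_positions):
        row = []
        for j in range(d_model):
            i = j // 2  # dimensions 2i and 2i+1 share one frequency
            angle = pos / (10000 ** (2 * i / d_model))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        pe.append(row)
    return pe


def add_positions(embeddings: Matrix) -> Matrix:
    # Same width as the embeddings, so the encoding is summed, not concatenated.
    pe = positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(emb, enc)] for emb, enc in zip(embeddings, pe)]


pe = positional_encoding(100, 8)
print(max(abs(x) for row in pe for x in row))  # bounded by 1, unlike a raw position index
```

A raw index would instead reach 99 at the last of these 100 positions, which is the unbounded growth the intuition above warns about.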
For the decoder stack, each layer has 3 sublayers rather than the 2 in the encoder. The first sublayer is masked self-attention over the output: each position attends only to the output words that have already been decoded, and later positions are masked out. As in the encoder, positional encodings are added to the output word embeddings before they enter the decoder stack. This sublayer therefore brings in the semantic information of the previously decoded words. The second sublayer is called encoder-decoder attention. Taking the output of the masked self-attention sublayer as the queries, it performs scaled dot-product attention against the encoder outputs for all source words, which serve as the keys and values. The goal here is to find which source-sentence words the current decoding context should pay more attention to. The attention output is a weighted mix of the source words' representations, which the residual connection then adds to the current decoding context. The result is passed to the third sublayer, a position-wise feed-forward FC layer that non-linearly transforms this mix; after the last decoder layer, a linear projection followed by a softmax maps it to a probability distribution over target-language words. As in the encoder, every sublayer of the decoder has its own residual connection, which in my reading again lets the model settle on how many layers it effectively needs.
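The two attention sublayers of a decoder layer can be sketched as below: a causal mask hides words that have not been decoded yet, and in encoder-decoder attention the queries come from the decoder while the keys and values come from the encoder output. This is a simplified single-head sketch under my own assumptions: the learned projections, layer normalization and the feed-forward sublayer are omitted, and all function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]


def attend(q: Matrix, k: Matrix, v: Matrix, causal: bool = False) -> Matrix:
    """Single-head scaled dot-product attention; causal=True hides later positions."""
    d_k = len(k[0])
    out: Matrix = []
    for t, q_row in enumerate(q):
        scores = []
        for s, k_row in enumerate(k):
            if causal and s > t:
                scores.append(float("-inf"))  # word s has not been decoded yet at step t
            else:
                scores.append(sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k))
        m = max(scores)
        exps = [math.exp(x - m) for x in scores]  # exp(-inf) is 0.0, so masked words get no weight
        total = sum(exps)
        weights = [e / total for e in exps]
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out


def add(a: Matrix, b: Matrix) -> Matrix:
    # Element-wise sum, used here as the residual connection.
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def decoder_attention_sublayers(decoded: Matrix, encoder_out: Matrix) -> Matrix:
    # Sublayer 1: masked self-attention over the words decoded so far, plus residual.
    h = add(decoded, attend(decoded, decoded, decoded, causal=True))
    # Sublayer 2: queries from the decoder, keys and values from the encoder, plus residual.
    return add(h, attend(h, encoder_out, encoder_out))


encoder_out = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # toy source-word representations
decoded = [[0.2, 0.1], [0.0, 0.3]]                  # toy embeddings of words decoded so far
print(decoder_attention_sublayers(decoded, encoder_out))
```

Changing a later decoded word never changes the masked self-attention output at an earlier position, which is what lets all positions be trained in parallel while still respecting left-to-right decoding.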
Strength:
The strength of the paper lies in giving later researchers a standard recipe for replacing the RNN structure with self-attention + FC layers, which can save a lot of computational resources and make better use of parallelization.
The computational saving comes from reducing the number of linear (multiply-add) operations. In a simple RNN unit, one d×d weight matrix multiplies the d-dimensional input vector and another d×d weight matrix multiplies the previous hidden state to produce the new hidden state, i.e. about 2·d² multiplications per time step. The number of encoder time steps equals the length n of the input sentence, so a single-layer RNN encoder needs about 2·d²·n multiplications, i.e. O(n·d²). In self-attention, by contrast, each word's query vector is dot-multiplied with the n key vectors (the n×d key matrix) to obtain its attention weights, which costs n·d multiplications, and its output is the weighted sum of the n value vectors (the n×d value matrix), another n·d. Over all n words this gives about 2·n²·d multiplications, i.e. O(n²·d), not counting the linear projections that produce the queries, keys and values. Self-attention is therefore cheaper whenever n < d, which is the usual case in sentence-level translation (the base model uses d = 512, while a sentence is typically much shorter than 512 tokens). Besides, since the whole input sentence is available at once, all n positions can be processed in parallel instead of going through n sequential time steps as in an RNN.
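The operation counts can be checked with a few lines of arithmetic. This rough sketch assumes a simple RNN with two d×d matrix-vector products per step and, for self-attention, counts only the query-key dot products and the weighted sum of values (n·d multiplications each per word), ignoring biases, nonlinearities and the query/key/value projections. d = 512 matches the paper's base model width, while n = 20 is just an illustrative sentence length I picked.

```python
def rnn_multiplications(n: int, d: int) -> int:
    # Per time step: a d x d matrix times the input and a d x d matrix times the previous hidden state.
    return 2 * d * d * n


def self_attention_multiplications(n: int, d: int) -> int:
    # Per word: n query-key dot products (n * d) plus a weighted sum of n value vectors (n * d).
    return 2 * n * n * d


n, d = 20, 512  # illustrative sentence length and model width
rnn = rnn_multiplications(n, d)
attn = self_attention_multiplications(n, d)
print(rnn, attn, rnn / attn)  # the ratio is d / n, so self-attention wins whenever n < d
```

The same functions show the flip side: for inputs longer than d (say n = 1000 with d = 512) the quadratic term makes self-attention the more expensive of the two.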
Critique:
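In my opinion, the FC layers in the encoder are less important than those in the decoder, because the encoder only has to produce a representation of the source sentence. There is therefore no obvious need to use the same number of layers for the encoder and the decoder (the paper uses N = 6 for both).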