Attention Mechanism: Benefits and Applications

Recurrent Neural Networks (RNNs) are powerful neural network architectures used for modeling sequences. RNNs built from LSTM (Long Short-Term Memory) cells [1] are surprisingly good at capturing long-term dependencies in sequences. A barebones sequence-to-sequence (encoder-decoder) architecture [2] performs remarkably well in tasks like Machine Translation.

A typical sequence-to-sequence architecture consists of an encoder RNN and a decoder RNN. The encoder processes a source sequence and reduces it into a fixed-length vector, the context. The decoder then generates a target sequence token by token, conditioned on that context. The context is usually the final state of the encoder RNN. Consider the following example, in which a source sequence in English is mapped to a target sequence in French.

We can observe that the final state of the encoder RNN ultimately decides the decoding process, or at least heavily influences it, while the previous states of the RNN, $\{h_0, h_1, \ldots, h_5\}$, have no direct influence over the decoding process. The second observation is that the final encoder state, a fixed-length vector, represents the essence or meaning of the source sequence in the context of translation.

When the source sequence is long and contains multiple information-rich phrases far apart from each other, the burden of capturing and condensing all that information into a single fixed-length vector becomes too much for the encoder to bear. The result is a loss of information, which shows up in the generated target sequence.

What if we had a mechanism that allowed the decoder to selectively (dynamically) focus on the information-rich phrases in the source sequence? That would be great. The decoder would take a peek at the encoder states that are relevant to the current decoding step. This is precisely what the Attention Mechanism does: it allows the decoder to pay attention to different parts of the source sequence at different decoding steps.

But what does it mean to pay attention? How does the decoder decide which parts of the source sequence need focus?

Attention Mechanism as a Black Box

Let’s play a game. Consider a scenario where the attention mechanism is presented to you as a black box module. The inputs to the module are the current decoder state $s_j$ and the set of encoder states $\{h_i\}$. The output of the module is a list of floating-point numbers.

Without going through any logical reasoning, your System 1 would have figured out that the number of floating-point values matches the number of tokens in the source sequence and, by extension, the number of encoder states. We can hypothesize that the floating-point numbers are weights associated with the corresponding encoder states. But what do the weights signify? It is safe to assume that the weight associated with a state represents the degree of attention (focus) the state deserves. Let’s call these values attention weights.

The second inference is a bit harder to make unless you have worked with probabilities. If you sum up the attention weights, they add up to 1. The weights are normalized, which means that the list of floating-point numbers represents a probability distribution of attention over the encoder states. What can we do with that? Since the weights are normalized, we can combine the states by taking a weighted sum of all the encoder states. For each decoder step, we get a new set of attention weights from the black box, leading to a new weighted combination of encoder states at every step: a truly dynamic context that depends on the current decoder state (and hence on the current input to the decoder and its previous step).
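To make the weighted sum concrete, here is a minimal sketch in plain Python (no NumPy, so it runs without dependencies). The encoder states and the two sets of weights are hypothetical toy values standing in for what the black box might return; the point is only that different weights at different steps produce different contexts from the same encoder states.

```python
def context_vector(weights, encoder_states):
    """Attention-weighted sum of encoder states: c_j = sum_i alpha_ji * h_i."""
    if abs(sum(weights) - 1.0) > 1e-6:
        raise ValueError("attention weights must sum to 1")
    dim = len(encoder_states[0])
    return [
        sum(w * h[k] for w, h in zip(weights, encoder_states))
        for k in range(dim)
    ]


# Hypothetical toy data: three 2-dimensional encoder states h_0, h_1, h_2.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Weights the black box might return at two different decoding steps.
weights_step_1 = [0.7, 0.2, 0.1]  # mostly attends to h_0
weights_step_2 = [0.1, 0.1, 0.8]  # a later step attends mostly to h_2

context_1 = context_vector(weights_step_1, encoder_states)  # approximately [0.8, 0.3]
context_2 = context_vector(weights_step_2, encoder_states)  # approximately [0.9, 0.9]
```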

Pretty cool, right? But aren’t we curious about what exactly happens inside the black box? Yes, we are. How do we map $(\{h_i\}, s_j)$ to the attention weights $\{\alpha_{ji}\}$ [4]? This happens in two steps. First, we calculate attention energies $e_{ji}$. These are unnormalized scores of alignment between the decoder state and each encoder hidden state $h_i$. The mapping from $(\{h_i\}, s_j)$ to the attention energies is known as the alignment model. The alignment model proposed by Bahdanau et al. [3] (the paper that introduced the attention mechanism to neural machine translation) is a simple linear combination ($ax + by$) of the encoder state $h_i$ and the previous decoder state $s_{j-1}$, followed by a non-linearity: $e_{ji} = v_a^\top \tanh(W_a s_{j-1} + U_a h_i)$. To calculate the attention weights $\alpha_{ji}$, we simply normalize the attention energies by applying softmax: $\alpha_{ji} = \exp(e_{ji}) / \sum_k \exp(e_{jk})$.
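The two steps inside the black box can be sketched as follows, assuming the additive form $e_{ji} = v_a^\top \tanh(W_a s_{j-1} + U_a h_i)$ followed by a softmax. This is an illustrative plain-Python sketch, not the authors' implementation; the parameter values are hand-picked toy numbers, whereas in a real model $W_a$, $U_a$ and $v_a$ are learned jointly with the rest of the network.

```python
import math


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def softmax(scores):
    """Normalize energies into weights; subtracting the max avoids overflow."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def additive_energy(s_prev, h, W_a, U_a, v_a):
    """Bahdanau-style energy: e_ji = v_a^T tanh(W_a s_{j-1} + U_a h_i)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W_a, s_prev), matvec(U_a, h))]
    return sum(v * z for v, z in zip(v_a, hidden))


def attention(s_prev, encoder_states, W_a, U_a, v_a):
    """Return the energies, the attention weights and the resulting context."""
    energies = [additive_energy(s_prev, h, W_a, U_a, v_a) for h in encoder_states]
    weights = softmax(energies)
    dim = len(encoder_states[0])
    context = [sum(w * h[k] for w, h in zip(weights, encoder_states)) for k in range(dim)]
    return energies, weights, context


# Toy, hand-picked parameters (hypothetical; normally learned).
W_a = [[0.5, -0.3], [0.1, 0.8]]  # projects the previous decoder state
U_a = [[1.0, 0.0], [0.0, 1.0]]   # projects each encoder state
v_a = [1.0, -1.0]                # reduces the hidden layer to a scalar energy
s_prev = [0.2, 0.4]
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

energies, weights, context = attention(s_prev, encoder_states, W_a, U_a, v_a)
```

Softmax preserves the ordering of the energies, so the encoder state with the highest energy always receives the largest attention weight.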

Looking back

Now, let’s take another look at the questions we asked.

What does it mean to pay attention? 

To pay attention to a part of the source sequence is to assign it a high attention weight relative to the rest of the sequence. It means that the dynamic context built by the decoder contains more information from the encoder states corresponding to the parts of the source sequence that are relevant to the current decoding step.

How does the decoder decide to select which parts of the source sequence to focus on? 

The alignment model calculates energies (unnormalized scores) corresponding to each encoder state. The energy of encoder state $h_i$ and decoder state $s_j$, given by $e_{ji}$, is a scalar measure of alignment, or match, between $h_i$ and $s_j$. The decoder uses the alignment model to choose which parts of the source sequence to focus on. The alignment model is parameterized by $\{W_a, U_a, v_a\}$, which are learnable parameters trained jointly with the rest of the network. Hence the title of [3]: Neural Machine Translation by Jointly Learning to Align and Translate.

Integrating Context

We have seen how to dynamically build a context by soft-searching the source sequence for the parts that are essential to predict a target word. We have not yet discussed how we condition the decoder on this context.

In a simple sequence-to-sequence model, we condition the decoder on the final state of the encoder, the context $h_n$, by setting the initial state $s_0$ of the decoder to $h_n$. In [3], Bahdanau et al. force the decoder to depend on the dynamic context $c_t$ by making the hidden state of the decoder $s_t$ a function of $c_t$, together with the previous state and the previously generated word: $s_t = f(s_{t-1}, y_{t-1}, c_t)$.

Luong et al. [5] suggest a simpler function that calculates the attentional hidden state $\tilde{s}_t$ by concatenating the context $c_t$ with the current hidden state $s_t$ of the RNN cell and applying a linear transformation followed by a tanh non-linearity: $\tilde{s}_t = \tanh(W_c [c_t; s_t])$. The attentional hidden state is used for predicting the target word.
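Here is a minimal sketch of that composition step, assuming the form $\tilde{s}_t = \tanh(W_c [c_t; s_t])$ from [5] (a linear map of the concatenation followed by tanh) and a softmax output layer $\mathrm{softmax}(W_s \tilde{s}_t)$ over the target vocabulary. The sizes, the weights $W_c$ and $W_s$, and the three-word vocabulary are hypothetical toy values.

```python
import math


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attentional_hidden_state(context, hidden, W_c):
    """s~_t = tanh(W_c [c_t; s_t]): concatenate, project linearly, then squash."""
    concatenated = context + hidden  # list concatenation gives [c_t; s_t]
    return [math.tanh(z) for z in matvec(W_c, concatenated)]


def predict_word(attentional_state, W_s):
    """Distribution over a toy target vocabulary: softmax(W_s s~_t)."""
    return softmax(matvec(W_s, attentional_state))


# Hypothetical sizes: context and decoder state of size 2, attentional state
# of size 2, target vocabulary of 3 words.
c_t = [0.8, 0.3]
s_t = [0.1, -0.5]
W_c = [[0.2, 0.4, -0.1, 0.3], [0.5, -0.2, 0.7, 0.1]]
W_s = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]

s_tilde = attentional_hidden_state(c_t, s_t, W_c)
p_next = predict_word(s_tilde, W_s)
```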

Alignment Models

The alignment model (also known as the scoring or score function) proposed by Bahdanau et al. performs a linear combination of the encoder states and the decoder state to calculate attention energies. A few other methods for calculating alignment are proposed in [5].

Multiplicative Attention reduces the encoder states $\{h_i\}$ and the decoder state $s_j$ into attention scores by applying simple matrix multiplications. The score can be a parametric function with learnable parameters, $s_j^\top W_a h_i$ (the "general" score in [5]), or a simple dot product, $s_j^\top h_i$.
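Both multiplicative scores fit in a few lines of plain Python. The sketch below also shows why the parametric form is cheap in practice: since $s_j^\top W_a h_i = (W_a^\top s_j)^\top h_i$, the projection $q = W_a^\top s_j$ can be computed once per decoder step and reused for every encoder state. All values are hypothetical toy numbers.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [dot(row, x) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def dot_score(s, h):
    """Non-parametric multiplicative score: s_j^T h_i."""
    return dot(s, h)


def general_score(s, h, W_a):
    """Parametric multiplicative score: s_j^T W_a h_i."""
    return dot(s, matvec(W_a, h))


def general_scores_all(s, encoder_states, W_a):
    """Score all encoder states at once: q = W_a^T s_j is computed a single
    time, after which each score is one dot product q^T h_i."""
    q = matvec(transpose(W_a), s)
    return [dot(q, h) for h in encoder_states]


s_j = [0.5, -1.0]
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_a = [[1.0, 2.0], [0.0, 1.0]]

one_by_one = [general_score(s_j, h, W_a) for h in encoder_states]  # [0.5, 0.0, 0.5]
all_at_once = general_scores_all(s_j, encoder_states, W_a)          # [0.5, 0.0, 0.5]
dot_scores = [dot_score(s_j, h) for h in encoder_states]            # [0.5, -1.0, -0.5]
```

With $W_a$ set to the identity matrix, the general score reduces to the plain dot product.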


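Additive Attention performs a linear combination of the encoder state and the decoder state, followed by a non-linearity. It can be written with separate projections, $v_a^\top \tanh(W_a s_j + U_a h_i)$, or, as in [5], with a single projection of the concatenated vector, $v_a^\top \tanh(W_a [s_j; h_i])$.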
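The two ways of writing the additive score are the same function: a projection of the concatenation $[s_j; h_i]$ by a matrix $W = [W_a \mid U_a]$ (the two matrices placed side by side) equals $W_a s_j + U_a h_i$. The plain-Python sketch below checks this numerically; the matrices and vectors are hypothetical toy values.

```python
import math


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def additive_sum_form(s, h, W_a, U_a, v_a):
    """v_a^T tanh(W_a s_j + U_a h_i): separate projections, then summed."""
    z = [math.tanh(a + b) for a, b in zip(matvec(W_a, s), matvec(U_a, h))]
    return sum(v * t for v, t in zip(v_a, z))


def additive_concat_form(s, h, W, v_a):
    """v_a^T tanh(W [s_j; h_i]): one projection of the concatenated vector."""
    z = [math.tanh(a) for a in matvec(W, s + h)]
    return sum(v * t for v, t in zip(v_a, z))


W_a = [[0.5, -0.3], [0.1, 0.8]]
U_a = [[1.0, 0.2], [-0.4, 1.0]]
W = [wa + ua for wa, ua in zip(W_a, U_a)]  # each row is [row of W_a | row of U_a]
v_a = [1.0, -1.0]
s_j = [0.2, 0.4]
h_i = [1.0, 1.0]

sum_form = additive_sum_form(s_j, h_i, W_a, U_a, v_a)
concat_form = additive_concat_form(s_j, h_i, W, v_a)  # equal up to rounding
```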

In different configurations of the attention mechanism (which we will get into in a later section), the multiplicative and additive functions perform slightly differently. In general, their performance is similar, but the multiplicative function is faster and more space-efficient in practice, because it can be implemented with highly optimized matrix multiplication [6].

Soft and Hard Attention

Before exploring the different configurations of the attention mechanism, we must understand the difference between soft and hard attention. After calculating the attention weights, we could select the encoder state with the highest attention weight as the context, instead of taking an attention-weighted sum of states as we did above. This selection of the single most relevant part of the source sequence is called hard attention. Soft attention allows the decoder to consider all the states in the source sequence, weighted by relevance. The distinction between soft and hard attention has nothing to do with how the attention weights are calculated; it is purely based on the selection (search) method.
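Given the same attention weights, the two selection methods differ only in the last step. A minimal plain-Python sketch with hypothetical toy values:

```python
def soft_attention_context(weights, states):
    """Soft attention: every state contributes in proportion to its weight."""
    dim = len(states[0])
    return [sum(w * h[k] for w, h in zip(weights, states)) for k in range(dim)]


def hard_attention_context(weights, states):
    """Hard attention: keep only the state with the highest weight (argmax)."""
    best = max(range(len(weights)), key=lambda i: weights[i])
    return list(states[best])


weights = [0.1, 0.7, 0.2]  # identical weights feed both variants
states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

soft = soft_attention_context(weights, states)  # approximately [0.3, 0.9]
hard = hard_attention_context(weights, states)  # [0.0, 1.0]
```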

The Curse of Hard Attention

Hard attention is difficult to incorporate as a component in a neural network. A neural network is fundamentally a mathematical function: a heterogeneous composition of multiple smaller functions (layers). If we represent our network as a function $f$, then $f$ is expected to be differentiable (at least almost everywhere) with respect to its parameters. Differentiability is essential for calculating the error gradients, backpropagating them through the layers, and updating the parameters.

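In other words, to train a network with gradient descent, it needs to be a differentiable function. Coming back to hard attention, the $\arg\max$ operation used to select the most relevant encoder state is not continuous: a small change in the attention energies either leaves the selected state unchanged (a gradient of zero) or makes the selection jump to a different state. Hence it is not differentiable, and no useful gradient flows back to the alignment model.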
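A small numerical experiment illustrates the problem. The sketch below estimates the derivative of the attention output with respect to one energy using central finite differences. For soft attention the estimate matches the analytic softmax gradient, while for hard attention it is exactly zero until the argmax flips, at which point the output jumps. One-dimensional "states" and the energy values are hypothetical, chosen to keep the arithmetic easy to follow.

```python
import math


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


values = [1.0, 0.0, 2.0]  # one-dimensional encoder "states"


def soft_output(energies):
    """Soft attention: a smooth function of the energies."""
    return sum(w * v for w, v in zip(softmax(energies), values))


def hard_output(energies):
    """Hard attention: values[argmax], a step function of the energies."""
    best = max(range(len(energies)), key=lambda i: energies[i])
    return values[best]


def finite_difference(f, energies, i, eps=1e-5):
    """Central-difference estimate of df/de_i."""
    up, down = list(energies), list(energies)
    up[i] += eps
    down[i] -= eps
    return (f(up) - f(down)) / (2 * eps)


energies = [0.5, 0.2, 0.1]
soft_grad = finite_difference(soft_output, energies, 2)  # positive: raising e_2 pulls the output toward 2.0
hard_grad = finite_difference(hard_output, energies, 2)  # 0.0: the argmax does not move
jump = hard_output([0.5, 0.2, 0.6])                      # 2.0: once e_2 wins, the output jumps
```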

Global and Local Configurations

In the context of Machine Translation, there are two configurations of the attention mechanism, global and local [5]. The global configuration is what we have discussed above, where all the encoder states are considered while calculating attention weights. The attention weights form a probability distribution over all the encoder states, which we use to compose the states into a context.

Since we consider all the encoder states, we call this configuration global.

The local configuration, in contrast, allows the decoder to peek at a small segment of the source sequence, so that it can selectively focus on the tokens in that segment. But which segment should we select?

As you may have already guessed, there are two levels of selection (search) here. First we select which segment to pay attention to, then which tokens within the segment to pay attention to. But how long should the segment be? We select (empirically) a hyperparameter $D$ such that our window of focus (segment length) is $2D + 1$. The number of attention weights is now fixed and depends on $D$, instead of depending on the source sequence length, which is variable.
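A sketch of the window itself, assuming an integer center and assuming that positions falling outside the source sequence are simply dropped (how boundaries are handled is our assumption here, not something stated above). Away from the edges, the window always holds $2D + 1$ positions regardless of the source length.

```python
def window_indices(center, D, source_length):
    """Positions in the window [center - D, center + D], clipped to the source.

    Inside the sequence the window holds 2D + 1 positions no matter how long
    the source is; near the edges it is cut short (an assumption of this sketch).
    """
    start = max(0, center - D)
    end = min(source_length - 1, center + D)
    return list(range(start, end + 1))


D = 2
short_source = window_indices(5, D, 50)   # [3, 4, 5, 6, 7]
long_source = window_indices(5, D, 500)   # [3, 4, 5, 6, 7], same size for a longer source
at_start = window_indices(0, D, 50)       # [0, 1, 2], clipped at the start
```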

Local attention proposed in [5] is sometimes also known as window-based attention. Let’s drill down a bit. In order to select a window of interest, we need to place the center of the window correctly. In other words, we need to select a token (location) in the source sequence to be the center of attention, the center of the window. What is the problem here? We have run into hard attention again. If we use an $\arg\max$ to select the most interesting token (location) in the source sequence, the network becomes non-differentiable. Luong et al. [5] proposed a clever hack to circumvent this issue.

Predictive Alignment

We first calculate Predictive Alignment by reducing the current decoder state $s_t$ to a “rough location” on the source sequence.

Note: To be compatible with the conventions used in [5], I’ve used $t$ to index the decoder steps instead of $j$.

The aligned position is given by $p_t = S \cdot \mathrm{sigmoid}(v_p^\top \tanh(W_p s_t))$, where $W_p$ and $v_p$ are learnable parameters. The term $v_p^\top \tanh(W_p s_t)$ reduces the current decoder state $s_t$ into a scalar. We apply the sigmoid to squash this scalar into the range $(0, 1)$, and then multiply by the source sequence length $S$ to arrive at the aligned position $p_t \in [0, S]$. Notice that while calculating the predictive alignment, we have not peeked at the source sequence: $p_t$ is not directly conditioned on the encoder states (although the decoder state $s_t$ is loosely conditioned on the source sequence).
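The predictive alignment is a tiny computation. Here is a plain-Python sketch of $p_t = S \cdot \mathrm{sigmoid}(v_p^\top \tanh(W_p s_t))$; the decoder state and the parameter values are hypothetical toy numbers. Note that no encoder state appears anywhere in the function.

```python
import math


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def predicted_position(s_t, W_p, v_p, source_length):
    """p_t = S * sigmoid(v_p^T tanh(W_p s_t)); uses only the decoder state."""
    hidden = [math.tanh(z) for z in matvec(W_p, s_t)]
    scalar = sum(v * z for v, z in zip(v_p, hidden))
    return source_length * sigmoid(scalar)


s_t = [0.3, -0.7]
W_p = [[1.0, 0.5], [-0.2, 0.4]]
v_p = [2.0, 1.0]
S = 20

p_t = predicted_position(s_t, W_p, v_p, S)                          # roughly 7.9
p_zero = predicted_position(s_t, [[0.0, 0.0], [0.0, 0.0]], v_p, S)  # 10.0, i.e. S / 2
```

Because the sigmoid never reaches 0 or 1 exactly, $p_t$ stays strictly inside $(0, S)$, and it varies smoothly with $W_p$ and $v_p$, so gradients can flow into them.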

So, we have a rough location on the source sequence to focus our attention on. This is where it gets really interesting. Without explicitly selecting a window on the source sequence, we want to place emphasis on the tokens around $p_t$ while ignoring the ones outside the window. How do we do that without using $\arg\max$?

The Gaussian Trick

Let us calculate the alignment between the encoder states and the current decoder state $s_t$ using one of the alignment models defined above, $\mathrm{align}(s_t, h_s)$. Now we need to construct a function around the location $p_t$ that is differentiable, i.e. smooth over $(-\infty, +\infty)$. Luckily, the Gaussian function is smooth and can be centered at $p_t$. The Gaussian, given by $\exp\left(-\frac{(s - p_t)^2}{2\sigma^2}\right)$ in the figure below, is a symmetric bell curve centered at $p_t$ and is a function of $s$, the index along the source sequence; [5] sets $\sigma = D/2$. Instead of a non-differentiable sharp peak at $p_t$, we now have a differentiable curve with its area spread over $(-\infty, +\infty)$.

Multiplied into the alignment, $a_t(s) = \mathrm{align}(s_t, h_s) \exp\left(-\frac{(s - p_t)^2}{2\sigma^2}\right)$, the Gaussian scales down the attention values so that tokens near $p_t$ are favored over the rest of the tokens in the source sequence. We now have two levels of attention: $\mathrm{align}$ computes the global alignment, while $a_t$ fine-tunes it by shifting the focus toward $p_t$.
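A plain-Python sketch of the Gaussian trick, assuming $\sigma = D/2$ as in [5] and, for simplicity, computing the softmax alignment over every source position (the original formulation restricts it to the window). With uniform alignment scores, the Gaussian alone decides where the focus lands. The scores, $p_t$ and $D$ are hypothetical toy values.

```python
import math


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def local_p_weights(scores, p_t, D):
    """a_t(s) = align(s) * exp(-(s - p_t)^2 / (2 sigma^2)), with sigma = D / 2.

    `scores` are raw alignment scores for every source position s; the softmax
    turns them into the global alignment, and the Gaussian (smooth in p_t)
    then scales down positions far from p_t. The result is not renormalized.
    """
    sigma = D / 2.0
    align = softmax(scores)
    return [a * math.exp(-((s - p_t) ** 2) / (2 * sigma ** 2)) for s, a in enumerate(align)]


scores = [0.0] * 10  # hypothetical: every source position aligns equally well
p_t, D = 6.3, 2
weights = local_p_weights(scores, p_t, D)
best = max(range(len(weights)), key=lambda s: weights[s])  # 6, the position nearest p_t
```

Since the Gaussian factor never exceeds 1, each local weight is at most the corresponding global alignment value.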

By careful application of the concepts we have discussed above, any flavor of attention mechanism used in the deep NLP literature can be understood and reproduced. Let us look at two flavors of attention that have been used extensively in the literature.

Hierarchical Attention

In [7], Yang et al. propose a Hierarchical Attention Network (HAN) to model the hierarchical structure of documents: words form sentences, and sentences form a document. HAN learns sentence-level representations and then composes them to learn a document representation. The model uses two levels of attention, word-level and sentence-level, to pay attention to individual words and to sentences while constructing the document representation.

HAN consists of a word encoder, a bidirectional GRU-based RNN that encodes a document word by word to produce word-level representations $\{h_{i1}, \ldots, h_{iT}\}$ for each sentence $i$. We use the global attention mechanism discussed above to reduce each sentence (list of words) into a sentence vector, an attention-weighted sum (with weights $\{\alpha_{it}\}$) of the individual word representations. Applying this Word Attention to every sentence in a document yields an array of sentence vectors. The sentence vectors are further processed by another bidirectional GRU-based RNN (the sentence encoder) to incorporate document-level contextual information. Once again, we apply the global attention mechanism to the sentence representations $\{h_1, \ldots, h_L\}$ (Sentence Attention) to produce the final representation $v$ of the document: an attention-weighted sum (with sentence-level weights $\{\alpha_i\}$) of the sentence representations.
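Both levels of HAN use the same pooling operation. The sketch below assumes the form from [7]: $u_k = \tanh(W h_k + b)$, $\alpha_k = \mathrm{softmax}(u_k^\top u_{ctx})$ with a learned context vector $u_{ctx}$, and output $\sum_k \alpha_k h_k$. To keep it short and dependency-free, the two bidirectional GRU encoders are omitted: the toy word vectors stand in for word-encoder outputs, and sentence vectors are pooled directly. All parameter values are hypothetical.

```python
import math


def matvec(M, x):
    """Multiply matrix M (a list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(states, W, b, u_ctx):
    """u_k = tanh(W h_k + b); alpha_k = softmax(u_k . u_ctx); return (sum_k alpha_k h_k, alphas)."""
    keys = [[math.tanh(z + bias) for z, bias in zip(matvec(W, h), b)] for h in states]
    alphas = softmax([sum(k * c for k, c in zip(key, u_ctx)) for key in keys])
    dim = len(states[0])
    pooled = [sum(a * h[d] for a, h in zip(alphas, states)) for d in range(dim)]
    return pooled, alphas


# Toy stand-ins for word-encoder outputs: a document of two sentences.
document = [
    [[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]],  # sentence 1: three words
    [[0.3, 0.3], [0.9, 0.7]],              # sentence 2: two words
]
W_w, b_w, u_w = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0], [1.0, -1.0]   # word level
W_s, b_s, u_s = [[0.5, 0.5], [-0.5, 0.5]], [0.1, 0.0], [1.0, 1.0]   # sentence level

# Word Attention: one vector per sentence.
sentence_vectors = [attention_pool(words, W_w, b_w, u_w)[0] for words in document]
# Sentence Attention (the sentence-level GRU is skipped in this sketch).
doc_vector, sentence_alphas = attention_pool(sentence_vectors, W_s, b_s, u_s)
```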

Multi-hop Attention

bAbI [8] is a synthetic RC (Reading Comprehension) dataset created by researchers at FAIR in 2015. By synthetic, I mean that the data is not extracted from a book or from the internet; it is generated using a few rules that simulate natural language. This characteristic of bAbI places the weight of the task on reasoning rather than language understanding. The dataset is organized into 20 tasks, based on the type and complexity of reasoning necessary to answer the questions.

Each sample in a task consists of a list of facts (statements) $\{x_i\}$ and a question $q$ based on the facts. The objective is to understand the question, acquire clues from the facts, perform appropriate reasoning on that information, and predict the answer $a$ to the question.

Sukhbaatar et al. [9] propose the famous End-to-End Memory Networks, which consist of memory slots $\{m_i\}$, one for representing each fact. The model uses an iterative attention mechanism to score the importance of each fact and reduce the facts into a final joint representation of the question and facts.

The sentences $\{x_i\}$ are encoded into memory vectors $\{m_i\}$ by a simple embedding lookup using the embedding matrix $A$. The question $q$ is encoded into the initial internal state $u$ using the embedding matrix $B$. The attention mechanism is applied to the memory vectors $\{m_i\}$, conditioned on $u$, to produce attention weights $\{p_i\}$. Another set of memory vectors $\{c_i\}$ is created using the embedding matrix $C$. An output memory representation $o$ is produced by a weighted sum of $\{c_i\}$, weighted by $\{p_i\}$. A new internal state is formed by summing $u$ and $o$. This is one iteration (hop).

Observe the repetitive structure in the figure above. The attention mechanism is applied in three iterations, and the memory representations $u$ and $o$ are iteratively updated; the number of hops is 3 in this case. After 3 hops, we condition the answer selection step on $o^3$ and $u^3$.

By iteratively applying the attention mechanism to the facts, we refine the internal representation of the facts based on their relevance to the question. By aggregating the relevant information into the final representation $(o^3 + u^3)$, we provide the answer selection layer with the information it needs to predict the answer. This form of multi-hop “search” over the facts allows the model to learn to perform fairly sophisticated reasoning on certain challenging tasks.
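The hop loop can be sketched in plain Python as follows. This is a simplified illustration, not the reference implementation: sentences and the question are embedded as bags of words (sums of word vectors), the same $A$ and $C$ are reused at every hop (one of the weight-tying options in [9]), and the embeddings are random and untrained, so the predicted answer is meaningless. The vocabulary, facts and question are hypothetical.

```python
import math
import random


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def embed_bow(words, table):
    """Bag-of-words embedding: the sum of the word vectors in `table`."""
    dim = len(next(iter(table.values())))
    out = [0.0] * dim
    for w in words:
        out = [o + e for o, e in zip(out, table[w])]
    return out


def memn2n(facts, question, A, B, C, hops=3):
    """Return the final state (u^K + o^K) and the attention weights of each hop."""
    m = [embed_bow(f, A) for f in facts]  # input memory vectors m_i (matrix A)
    c = [embed_bow(f, C) for f in facts]  # output memory vectors c_i (matrix C)
    u = embed_bow(question, B)            # initial internal state u (matrix B)
    per_hop_weights = []
    for _ in range(hops):
        p = softmax([sum(ui * mi for ui, mi in zip(u, mem)) for mem in m])    # p_i
        o = [sum(pi * ci[d] for pi, ci in zip(p, c)) for d in range(len(u))]  # o = sum_i p_i c_i
        u = [ui + oi for ui, oi in zip(u, o)]                                 # next state: u + o
        per_hop_weights.append(p)
    return u, per_hop_weights


def random_table(vocab, dim=4):
    """Hypothetical untrained embeddings: one random vector per word."""
    return {w: [random.uniform(-1, 1) for _ in range(dim)] for w in vocab}


random.seed(0)
vocab = ["john", "mary", "went", "kitchen", "garden", "where", "is"]
A, B, C, W = (random_table(vocab) for _ in range(4))

facts = [["john", "went", "kitchen"], ["mary", "went", "garden"]]
question = ["where", "is", "john"]

u_final, weights = memn2n(facts, question, A, B, C, hops=3)
# Answer selection: softmax over the vocabulary of scores W[word] . u_final.
answer_probs = softmax([sum(wv * uv for wv, uv in zip(W[word], u_final)) for word in vocab])
```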

The Attention Mechanism has become an integral part of neural network-based models designed for NLP. Both sequence-to-sequence tasks, like Machine Translation and Document Summarization, and classification tasks, like Sentiment Analysis, PoS Tagging and Document Classification, benefit from this ability to selectively focus on segments of the sequence.

In the article Embed, encode, attend, predict: The new deep learning formula for state-of-the-art NLP models [10], Matthew Honnibal identifies and summarizes the EEAP framework, a recurring pattern in the deep NLP literature. The state-of-the-art (SoTA) neural network models for the NLP tasks mentioned above are most likely variants of the EEAP framework. Attention plays the critical role of selective reduction (aggregation of information) in the EEAP framework and is hence an essential component of SoTA models.

In this article, I have tried to throw light on the mysterious fragments of the attention mechanism. It is no longer mysterious: we have grounded it in some basic mathematics. I’ve also thrown in some of my own metaphors that helped me shape my intuitive understanding. Understanding the concepts we’ve discussed above should enable you to appreciate the role and significance of the attention mechanism, which is used extensively in the deep NLP literature.

References

  1. Sepp Hochreiter, Jürgen Schmidhuber, Long Short-Term Memory
  2. Ilya Sutskever, Oriol Vinyals, Quoc V. Le, Sequence to Sequence Learning with Neural Networks. "Our method uses a multilayered Long Short-Term Memory (LSTM) to map the input sequence to a vector of a fixed dimensionality, and then another deep LSTM to decode the target sequence from the vector."
  3. Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio, Neural Machine Translation by Jointly Learning to Align and Translate. "In this paper, we conjecture that the use of a fixed-length vector is a bottleneck in improving the performance of this basic encoder-decoder architecture, and propose to extend this by allowing a model to automatically (soft-)search for parts of a source sentence that are relevant to predicting a target word, without having to form these parts as a hard segment explicitly."
  4. Instead of restricting the decoder’s access to just the last state of the encoder, we allow the decoder to peek at the whole array of encoder states. Imagine a person manually entering a long number into a portal; say, an OTP code received via SMS to log in to a Gmail account. I personally need to peek at it at least twice: I break the number into two smaller sequences and enter it in two iterations.
  5. Minh-Thang Luong, Hieu Pham, Christopher D. Manning, Effective Approaches to Attention-based Neural Machine Translation. "This paper examines two simple and effective classes of attentional mechanism: a global approach which always attends to all source words and a local one that only looks at a subset of source words at a time."
  6. Sebastian Ruder, Deep Learning for NLP Best Practices (blog post). "Additive and multiplicative attention are similar in complexity, although multiplicative attention is faster and more space-efficient in practice as it can be implemented more efficiently using matrix multiplication."
  7. Zichao Yang, Diyi Yang, Chris Dyer, Xiaodong He, Alex Smola, Eduard Hovy, Hierarchical Attention Networks for Document Classification
  8. Facebook AI Research, The bAbI project. "This page gather resources related to the bAbI project of Facebook AI Research which is organized towards the goal of automatic text understanding and reasoning."
  9. Sainbayar Sukhbaatar, Arthur Szlam, Jason Weston, Rob Fergus, End-To-End Memory Networks
  10. Matthew Honnibal, Embed, encode, attend, predict: The new deep learning formula for state-of-the-art NLP models

About Suriyadeepan Ramamoorthy

Suriyadeepan Ramamoorthy is an AI researcher and engineer from Puducherry. His primary areas of research are Natural Language Understanding and Reasoning. He actively blogs about Deep Learning at suriyadeepan.github.io

At SAAMA, he's applying advanced Deep Learning techniques for Biomedical Text Analysis. He is a Free Software Evangelist who is actively involved in community development activities at FSFTN. His other interests include Community Networks, Data Visualisation and Creative Coding.

