Multi-Head Attention
Book Excerpt from "Generative AI in C++"
by David Spuler, Ph.D.
The standard Transformer architecture splits the attention computation into parallel pieces using Multi-Head Attention (MHA). Typically, the attention calculation is parallelized across 16 heads. This is why the “model dimension” of a Transformer is usually an exact multiple of 16: the work can then be split evenly across the heads without any “extra” cases (i.e. it avoids needing any padding).
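As a quick illustration of the divisibility point, here is a minimal C++ sketch (with hypothetical dimensions chosen for illustration) showing the model dimension being split evenly across the attention heads, so that each head gets a whole slice and no padding is required:

    #include <cassert>
    #include <cstdio>

    int main() {
        // Hypothetical dimensions, for illustration only.
        const int d_model = 512;   // model dimension (embedding size)
        const int n_heads = 16;    // number of attention heads

        // The model dimension must divide evenly across the heads,
        // so that each head gets a whole slice with no padding.
        assert(d_model % n_heads == 0);
        const int d_head = d_model / n_heads;  // 512 / 16 = 32

        printf("Each of %d heads processes a slice of %d elements\n", n_heads, d_head);
        return 0;
    }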
When I read about the vanilla Transformer in 2017 having “Multi-Head Attention” with 16 attention heads acting in parallel, I had a major question about it: why? Or more specifically, why 16?
Is splitting the attention mechanism into 16 different “heads” just a code optimization? Is it only to make it faster? Or does it affect the model's output results to make it “smarter” because of changes to the distribution of weights and its processing algorithm? In other words, would the model act any differently in terms of outputs if we merged these 16 heads into one, or if we doubled it to 32 attention heads? Or would changing the number of attention heads only affect the speed?
I'm like a three-year-old: why? why? why?
The whole thing is very complicated, and I could be wrong, but here's my best answer: it's just a code optimization and there's nothing special about 16.
To be clear, I'm not saying that attention itself is just a code optimization, only that the “multi-head attention” modification is. The attention mechanism makes the model smarter; splitting it into multiple heads does not. The innovation was attention, and being able to run it fast enough mattered for that reason.
In more detail, multi-head attention computes exactly the same numbers as if the whole thing were done with one head and one massive matrix computation, or with 32 heads instead of 16, using smaller slices and more parallelism. I don't see anywhere in the multi-head attention algorithm where the final calculations depend on the number of heads or the size of each head. Yes, it's a nice piece of parallelism, where it does 16 slices of the same computation and then merges them back together at the end. It seems to be purely a parallelization coding optimization, and does not change the accuracy or intelligence of the output.
Maybe I'm wrong. I've seen a lot of articles that say that MHA allows each attention head to focus on different features of the inputs, and gives a more intricate understanding of the text. The implication is that splitting the computations across multiple attention heads makes it smarter than one big tensor operation. But I can't see that in the code.
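For reference, here is a simplified C++ sketch of that slice-and-merge structure, written for a single query position attending over a sequence of key and value vectors. It is only an illustrative sketch under simplifying assumptions, not code from the book: the per-head linear projections of a full Transformer layer (the query, key, value, and output weight matrices) are omitted, and the function and variable names are hypothetical. The point is just to show each head running scaled dot-product attention on its own contiguous slice, with the per-head outputs merged back into one vector at the end:

    #include <cassert>
    #include <cmath>
    #include <vector>

    // Scaled dot-product attention for one head at one query position.
    // q is [d_head]; keys and values are [seq_len][d_head]; out is [d_head].
    static void attention_one_head(const std::vector<float>& q,
                                   const std::vector<std::vector<float>>& keys,
                                   const std::vector<std::vector<float>>& values,
                                   std::vector<float>& out)
    {
        const int seq_len = (int)keys.size();
        const int d_head = (int)q.size();
        assert(seq_len > 0 && d_head > 0);
        std::vector<float> scores(seq_len);

        // Dot product of the query with each key, scaled by sqrt(d_head).
        for (int t = 0; t < seq_len; ++t) {
            float dot = 0.0f;
            for (int i = 0; i < d_head; ++i) dot += q[i] * keys[t][i];
            scores[t] = dot / std::sqrt((float)d_head);
        }

        // Softmax over the scores (numerically stabilized by the maximum).
        float maxval = scores[0];
        for (int t = 1; t < seq_len; ++t) if (scores[t] > maxval) maxval = scores[t];
        float sum = 0.0f;
        for (int t = 0; t < seq_len; ++t) { scores[t] = std::exp(scores[t] - maxval); sum += scores[t]; }
        for (int t = 0; t < seq_len; ++t) scores[t] /= sum;

        // Weighted sum of the value vectors.
        out.assign(d_head, 0.0f);
        for (int t = 0; t < seq_len; ++t)
            for (int i = 0; i < d_head; ++i)
                out[i] += scores[t] * values[t][i];
    }

    // Multi-head attention for one query position: split the d_model-sized
    // vectors into n_heads contiguous slices, run attention on each slice
    // (each head is independent, so the heads could run in parallel),
    // then merge the per-head outputs back into one d_model-sized vector.
    // Note: the per-head linear projections of a full Transformer are omitted.
    void multi_head_attention(const std::vector<float>& q,                   // [d_model]
                              const std::vector<std::vector<float>>& keys,   // [seq_len][d_model]
                              const std::vector<std::vector<float>>& values, // [seq_len][d_model]
                              int n_heads,
                              std::vector<float>& out)                       // [d_model]
    {
        const int d_model = (int)q.size();
        assert(d_model % n_heads == 0);   // divides evenly, so no padding needed
        const int d_head = d_model / n_heads;
        out.assign(d_model, 0.0f);

        for (int h = 0; h < n_heads; ++h) {
            const int offset = h * d_head;

            // Slice out this head's portion of the query, keys, and values.
            std::vector<float> qh(q.begin() + offset, q.begin() + offset + d_head);
            std::vector<std::vector<float>> kh, vh;
            for (const auto& k : keys) kh.emplace_back(k.begin() + offset, k.begin() + offset + d_head);
            for (const auto& v : values) vh.emplace_back(v.begin() + offset, v.begin() + offset + d_head);

            // Run single-head attention on this slice.
            std::vector<float> oh;
            attention_one_head(qh, kh, vh, oh);

            // Merge: copy this head's output into its slice of the result.
            for (int i = 0; i < d_head; ++i) out[offset + i] = oh[i];
        }
    }

In this simplified form, the head count only determines how the vectors are sliced up for the loop; a real Transformer layer also applies separate learned projections per head, which this sketch leaves out.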