TL;DR This post reviews Multi-head Self-attention (MHA), Group Query Attention (GQA), and Multi-head Latent Attention (MLA).

This blog post reviews Multi-head Self-attention (MHA), Group Query Attention (GQA), and Multi-head Latent Attention (MLA) side by side. Similar discussions appear in [1, 2]; our explanation draws on these works while making each step more explicit.

For clarity, we initially omit positional embeddings and add them back later.

Multi-head Self-attention

Let us denote the number of tokens as $n \in \mathbb{N}$, the hidden dimension as $d \in \mathbb{N}$, the number of heads as $h \in \mathbb{N}$, and the dimension of each head as $d_h = \frac{d}{h} \in \mathbb{N}$.

Multi-head Self-attention (MHA) [3] computes the output via the following process. For an input $\mathbf{X} \in \mathbb{R}^{n \times d}$, it first computes the query, key, and value matrices by:

\[\mathbf{Q} = \mathbf{X} \mathbf{W}_Q, \quad \mathbf{K} = \mathbf{X} \mathbf{W}_K, \quad \mathbf{V} = \mathbf{X} \mathbf{W}_V,\]

where $\mathbf{W}_Q \in \mathbb{R}^{d \times d}$, $\mathbf{W}_K \in \mathbb{R}^{d \times d}$, and $\mathbf{W}_V \in \mathbb{R}^{d \times d}$ are the weight matrices.

Afterwards, it splits the computed $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ into $h$ heads:

\[\forall j \in [h], \quad \mathbf{Q}_j = \mathbf{Q}[:, d_h(j-1):d_h j], \quad \mathbf{K}_j = \mathbf{K}[:, d_h(j-1):d_h j], \quad \mathbf{V}_j = \mathbf{V}[:, d_h(j-1):d_h j],\]

where $\mathbf{Q}_j \in \mathbb{R}^{n \times d_h}$, $\mathbf{K}_j \in \mathbb{R}^{n \times d_h}$, and $\mathbf{V}_j \in \mathbb{R}^{n \times d_h}$.

Then, it computes the attention output for each head:

\[\forall j \in [h], \quad \mathbf{Z}_j = \text{softmax}\left(\frac{\mathbf{Q}_j \mathbf{K}_j^\top}{\sqrt{d_h}}\right) \mathbf{V}_j ,\]

where $\mathbf{Z}_j \in \mathbb{R}^{n \times d_h}$. Finally, it concatenates the results from all heads in a column-wise manner, and then applies the output projection linear layer:

\[\mathbf{Y}_{\text{MHA}} = \text{MHA}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \begin{bmatrix} \mathbf{Z}_1 & \mathbf{Z}_2 & \cdots & \mathbf{Z}_h \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{n \times d} \mathbf{W}_O,\]

where $\mathbf{W}_O \in \mathbb{R}^{d \times d}$ is the weight matrix of the output projection linear layer.
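The MHA computation above can be sketched in NumPy as follows. This is a minimal illustration with toy, hypothetical sizes rather than a reference implementation: the helper names `softmax` and `mha` are our own, and, like the equations, it applies no causal mask.

```python
import numpy as np

def softmax(s: np.ndarray) -> np.ndarray:
    """Softmax over the last axis, shifted by the row max for numerical stability."""
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def mha(Q: np.ndarray, K: np.ndarray, V: np.ndarray, W_O: np.ndarray, h: int) -> np.ndarray:
    """MHA(Q, K, V): split into h heads, attend per head, concatenate, project with W_O."""
    n, d = Q.shape
    d_h = d // h
    Z = []
    for j in range(h):                                         # 0-indexed head j
        cols = slice(j * d_h, (j + 1) * d_h)                   # columns of head j
        A = softmax(Q[:, cols] @ K[:, cols].T / np.sqrt(d_h))  # (n, n) attention weights
        Z.append(A @ V[:, cols])                               # Z_j: (n, d_h)
    return np.concatenate(Z, axis=1) @ W_O                     # [Z_1 ... Z_h] W_O: (n, d)

rng = np.random.default_rng(0)
n, d, h = 5, 16, 4                                             # hypothetical toy sizes
X = rng.standard_normal((n, d))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d, d)) for _ in range(4))
Y_mha = mha(X @ W_Q, X @ W_K, X @ W_V, W_O, h)
print(Y_mha.shape)  # (5, 16)
```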

Group Query Attention

During inference, the key ($\mathbf{K}$) and value ($\mathbf{V}$) matrices are stored in a cache (the KV cache) so that they do not have to be recomputed for previously processed tokens at every decoding step. However, because this cache grows linearly with the sequence length, it can become very large when generating long sequences and end up being a bottleneck of the inference process.

Group Query Attention (GQA) [4, 5] addresses this issue by compressing the $\mathbf{K}$ and $\mathbf{V}$ matrices, retaining only $\frac{1}{g}$ of their original columns (where $g \in \mathbb{N}$ is a divisor of $h$). These compressed matrices are then reused across different query heads in the multi-head self-attention mechanism. When $g=h$, this approach becomes equivalent to the multi-query attention described in [4], and when $g=1$, it becomes equivalent to the original MHA.

Concretely, for $d_g = \frac{d}{g} \in \mathbb{N}$, GQA defines smaller projection matrices $\widetilde{\mathbf{W}}_K \in \mathbb{R}^{d \times d_g}$ and $\widetilde{\mathbf{W}}_V \in \mathbb{R}^{d \times d_g}$, and then computes the key and value matrices through:

\[\widetilde{\mathbf{K}} = \mathbf{X} \widetilde{\mathbf{W}}_K, \quad \widetilde{\mathbf{V}} = \mathbf{X} \widetilde{\mathbf{W}}_V,\]

where $\widetilde{\mathbf{K}} \in \mathbb{R}^{n \times d_g}$ and $\widetilde{\mathbf{V}} \in \mathbb{R}^{n \times d_g}$. This reduces the KV cache size by a factor of $g$.

Afterwards, GQA reuses this $\widetilde{\mathbf{K}}$ and $\widetilde{\mathbf{V}}$ for all query heads, which can be written as:

\[\mathbf{K}_{\text{GQA}} = \widetilde{\mathbf{K}} \begin{bmatrix} \mathbf{I}_{d_g} & \mathbf{I}_{d_g} & \cdots & \mathbf{I}_{d_g} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{d_g \times d} = \widetilde{\mathbf{K}} \mathbf{J}_{d_g \times d}, \\ \mathbf{V}_{\text{GQA}} = \widetilde{\mathbf{V}} \begin{bmatrix} \mathbf{I}_{d_g} & \mathbf{I}_{d_g} & \cdots & \mathbf{I}_{d_g} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{d_g \times d} = \widetilde{\mathbf{V}} \mathbf{J}_{d_g \times d},\]

where \(\mathbf{I}_{d_g} \in \mathbb{R}^{d_g \times d_g}\) is the identity matrix, and \(\mathbf{J}_{d_g \times d} \in \mathbb{R}^{d_g \times d}\) is the matrix obtained by concatenating $g$ copies of $\mathbf{I}_{d_g}$ along the column dimension.
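A small NumPy sketch of \(\mathbf{J}_{d_g \times d}\), with hypothetical toy sizes, illustrates that multiplying by it simply tiles \(\widetilde{\mathbf{K}}\) $g$ times. Under this construction, query head $j$ (0-indexed) reads KV head $j \bmod (h/g)$. Some implementations instead let adjacent query heads share a KV head; the two layouts differ only by a permutation of query heads, which can be absorbed into \(\mathbf{W}_Q\) and \(\mathbf{W}_O\).

```python
import numpy as np

d, h, g = 16, 4, 2          # hypothetical toy sizes; g must divide h
d_h, d_g = d // h, d // g   # d_h = 4, d_g = 8 (two KV heads of width d_h)

# J_{d_g x d}: g copies of the d_g x d_g identity, concatenated along columns.
J = np.tile(np.eye(d_g), (1, g))
print(J.shape)  # (8, 16)

rng = np.random.default_rng(0)
K_tilde = rng.standard_normal((5, d_g))   # compressed keys for n = 5 tokens
K_gqa = K_tilde @ J                        # (5, d): K_tilde repeated g times

# Right-multiplying by J only tiles K_tilde, so no real matmul is needed in practice.
assert np.allclose(K_gqa, np.tile(K_tilde, (1, g)))

# Query head j reads columns j*d_h:(j+1)*d_h of K_gqa, i.e. KV head j mod (h // g).
for j in range(h):
    kv = j % (d_g // d_h)
    assert np.allclose(K_gqa[:, j * d_h:(j + 1) * d_h], K_tilde[:, kv * d_h:(kv + 1) * d_h])
```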

Then, GQA computes the attention output via MHA with respect to this \(\mathbf{K}_{\text{GQA}}\) and $\mathbf{V}_{\text{GQA}}$.

\[\mathbf{Y}_{\text{GQA}} = \text{MHA}(\mathbf{Q}, \mathbf{K}_{\text{GQA}}, \mathbf{V}_{\text{GQA}}),\]

where the output projection $\mathbf{W}_O \in \mathbb{R}^{d \times d}$ inside MHA is the same as before. In practice, $\widetilde{\mathbf{K}}$ and $\widetilde{\mathbf{V}}$ are broadcast across the query heads that share them, so $\mathbf{K}_{\text{GQA}}$ and $\mathbf{V}_{\text{GQA}}$ need not be explicitly materialized, which keeps the attention computation efficient.
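To make the broadcasting remark concrete, the sketch below (NumPy, toy sizes, helper names `mha` and `gqa_broadcast` are hypothetical) computes GQA once through the definition $\text{MHA}(\mathbf{Q}, \widetilde{\mathbf{K}}\mathbf{J}, \widetilde{\mathbf{V}}\mathbf{J})$ and once by reshaping the queries into groups and broadcasting the $h/g$ KV heads with `np.einsum`, so that $\mathbf{K}_{\text{GQA}}$ and $\mathbf{V}_{\text{GQA}}$ are never formed. The head grouping follows the $\mathbf{J}$ construction above, and no causal mask is applied.

```python
import numpy as np

def softmax(s):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def mha(Q, K, V, W_O, h):
    """Reference MHA: per-head scaled dot-product attention, concatenate, project."""
    n, d = Q.shape
    d_h = d // h
    Z = [softmax(Q[:, j*d_h:(j+1)*d_h] @ K[:, j*d_h:(j+1)*d_h].T / np.sqrt(d_h))
         @ V[:, j*d_h:(j+1)*d_h] for j in range(h)]
    return np.concatenate(Z, axis=1) @ W_O

def gqa_broadcast(Q, K_t, V_t, W_O, h):
    """GQA computed from the compressed K_t, V_t (n x d_g) without building K_t @ J."""
    n, d = Q.shape
    d_h = d // h
    m = K_t.shape[1] // d_h            # number of KV heads, h / g
    reps = h // m                      # query heads per KV head, g
    # With the J construction, query head j = r*m + k uses KV head k,
    # so view Q as (n, reps, m, d_h) and broadcast the m KV heads over r.
    Qh = Q.reshape(n, reps, m, d_h)
    Kh, Vh = K_t.reshape(n, m, d_h), V_t.reshape(n, m, d_h)
    A = softmax(np.einsum('qrkc,tkc->rkqt', Qh, Kh) / np.sqrt(d_h))  # (reps, m, n, n)
    Z = np.einsum('rkqt,tkc->qrkc', A, Vh)                         # (n, reps, m, d_h)
    return Z.reshape(n, d) @ W_O

rng = np.random.default_rng(0)
n, d, h, g = 5, 16, 4, 2               # hypothetical toy sizes
d_g = d // g
X = rng.standard_normal((n, d))
W_Q, W_O = rng.standard_normal((d, d)), rng.standard_normal((d, d))
W_K_t, W_V_t = rng.standard_normal((d, d_g)), rng.standard_normal((d, d_g))
J = np.tile(np.eye(d_g), (1, g))
Q, K_t, V_t = X @ W_Q, X @ W_K_t, X @ W_V_t

Y_ref = mha(Q, K_t @ J, V_t @ J, W_O, h)   # definition: MHA(Q, K_GQA, V_GQA)
Y_fast = gqa_broadcast(Q, K_t, V_t, W_O, h)
print(np.allclose(Y_ref, Y_fast))  # True
```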

Multi-head Latent Attention

Multi-head Latent Attention (MLA) [6] generalizes GQA. As we show below, GQA can be viewed as a low-rank factorization of the projections to $\mathbf{K}$ and $\mathbf{V}$ in which one factor is fixed; MLA instead learns the low-rank transformation directly, which makes GQA a special case.

Let us revisit the MHA and GQA processes. We first denote the concatenation of \(\mathbf{W}_K \in \mathbb{R}^{d \times d}\) and \(\mathbf{W}_V \in \mathbb{R}^{d \times d}\) as \(\mathbf{W}_{KV} = \begin{bmatrix} \mathbf{W}_K \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} & \mathbf{W}_V \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix} \in \mathbb{R}^{d \times 2d}\). Recall that MHA computes the key and value matrices by:

\[\begin{bmatrix} \mathbf{K} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} & \mathbf{V} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{n \times 2d} = \begin{bmatrix} \mathbf{X} \mathbf{W}_K \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} & \mathbf{X} \mathbf{W}_V \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{n \times 2d} = \mathbf{X} \mathbf{W}_{KV}\]

GQA, instead, employs smaller projection matrices \(\widetilde{\mathbf{W}}_K \in \mathbb{R}^{d \times d_g}\) and \(\widetilde{\mathbf{W}}_V \in \mathbb{R}^{d \times d_g}\), and then computes the key and value matrices used for attention by:

\[\begin{bmatrix} \mathbf{K}_{\text{GQA}} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} & \mathbf{V}_{\text{GQA}} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{n \times 2d} = \begin{bmatrix} \mathbf{X} \widetilde{\mathbf{W}}_K \mathbf{J}_{d_g \times d} \vphantom{\begin{pmatrix} a \\ b \\ c \end{pmatrix}} & \mathbf{X} \widetilde{\mathbf{W}}_V \mathbf{J}_{d_g \times d} \end{bmatrix}_{n \times 2d} = \mathbf{X} \widetilde{\mathbf{W}}_{KV} \mathbf{J}_{KV},\]

where \(\widetilde{\mathbf{W}}_{KV} \in \mathbb{R}^{d \times 2d_g}\) and \(\mathbf{J}_{KV} \in \mathbb{R}^{2d_g \times 2d}\) are defined as:

\[\widetilde{\mathbf{W}}_{KV} = \begin{bmatrix} \widetilde{\mathbf{W}}_K \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} & \widetilde{\mathbf{W}}_V \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{d \times 2d_g}, \quad \mathbf{J}_{KV} = \text{diag}\big(\mathbf{J}_{d_g \times d}, \mathbf{J}_{d_g \times d}\big).\]

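It becomes clear that GQA replaces the original \(d \times 2d\) projection matrix $\mathbf{W}_{KV}$ with the product of two matrices: a learned \(d \times 2d_g\) matrix \(\widetilde{\mathbf{W}}_{KV}\) and a fixed \(2d_g \times 2d\) matrix \(\mathbf{J}_{KV}\). Since this product has rank at most \(2d_g\), and \(2d_g < 2d\) whenever \(g > 1\), GQA constrains the key-value projection to a low-rank factorization whose second factor is fixed rather than learned.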
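The low-rank structure can be checked numerically. This sketch uses hypothetical toy sizes with $g = h$ (multi-query attention), builds \(\mathbf{J}_{KV} = \text{diag}(\mathbf{J}, \mathbf{J})\) with `np.block`, and confirms that the effective \(d \times 2d\) projection has rank \(2d_g\); random matrices stand in for learned weights.

```python
import numpy as np

d, h, g = 16, 4, 4                        # hypothetical toy sizes; g = h is multi-query attention
d_g = d // g                              # 4
rng = np.random.default_rng(0)

W_KV_t = rng.standard_normal((d, 2 * d_g))   # learned factor [W_K_tilde, W_V_tilde]
J = np.tile(np.eye(d_g), (1, g))              # fixed tiling matrix J_{d_g x d}
Zero = np.zeros_like(J)
J_KV = np.block([[J, Zero], [Zero, J]])       # diag(J, J): (2 d_g) x (2 d)

W_eff = W_KV_t @ J_KV                         # effective d x 2d key-value projection
print(W_eff.shape, np.linalg.matrix_rank(W_eff))  # (16, 32) 8
```

With $g = 1$, \(\mathbf{J}_{KV}\) becomes the \(2d \times 2d\) identity and the full-rank MHA projection is recovered.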

Building on this view, MLA learns both factors directly. Specifically, it learns a down-projection matrix \(\mathbf{W}_{DKV} \in \mathbb{R}^{d \times d_c}\) and an up-projection matrix \(\mathbf{W}_{UKV} \in \mathbb{R}^{d_c \times 2d}\) to approximate the original MHA key-value projection:

\[\begin{bmatrix} \mathbf{K} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} & \mathbf{V} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{n \times 2d} = \mathbf{X} \mathbf{W}_{KV} \approx \mathbf{X} \mathbf{W}_{DKV} \mathbf{W}_{UKV} = \begin{bmatrix} \mathbf{K}_{\text{MLA}} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} & \mathbf{V}_{\text{MLA}} \vphantom{\begin{pmatrix} a \\ b \end{pmatrix}} \end{bmatrix}_{n \times 2d},\]

where \(d_c \in \mathbb{N}\) is the dimension of the compressed latent KV cache, typically chosen much smaller than \(2d\). Finally, MLA computes the attention output via MHA with respect to this \(\mathbf{K}_{\text{MLA}}\) and \(\mathbf{V}_{\text{MLA}}\).

\[\mathbf{Y}_{\text{MLA}} = \text{MHA}(\mathbf{Q}, \mathbf{K}_{\text{MLA}}, \mathbf{V}_{\text{MLA}}).\]
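A short NumPy sketch, again with hypothetical toy sizes, shows the sense in which GQA is a special case of MLA: choosing \(\mathbf{W}_{DKV} = \widetilde{\mathbf{W}}_{KV}\), \(\mathbf{W}_{UKV} = \mathbf{J}_{KV}\), and \(d_c = 2d_g\) reproduces \(\mathbf{K}_{\text{GQA}}\) and \(\mathbf{V}_{\text{GQA}}\) exactly, while general MLA learns both factors (random matrices stand in for learned weights). `mla_kv` is a hypothetical helper name.

```python
import numpy as np

def mla_kv(X, W_DKV, W_UKV):
    """Return (K_MLA, V_MLA): X @ W_DKV @ W_UKV split into its key and value halves."""
    KV = X @ W_DKV @ W_UKV          # (n, 2d)
    d = KV.shape[1] // 2
    return KV[:, :d], KV[:, d:]

rng = np.random.default_rng(0)
n, d, h, g = 5, 16, 4, 2            # hypothetical toy sizes
d_g = d // g
X = rng.standard_normal((n, d))

# GQA as a special case: down-projection = [W_K_tilde, W_V_tilde], up-projection = fixed diag(J, J).
W_K_t, W_V_t = rng.standard_normal((d, d_g)), rng.standard_normal((d, d_g))
J = np.tile(np.eye(d_g), (1, g))
Zero = np.zeros_like(J)
K_mla, V_mla = mla_kv(X, np.hstack([W_K_t, W_V_t]), np.block([[J, Zero], [Zero, J]]))
print(np.allclose(K_mla, X @ W_K_t @ J), np.allclose(V_mla, X @ W_V_t @ J))  # True True

# General MLA: both factors are learned, with latent width d_c (hypothetical value).
d_c = 6
K_gen, V_gen = mla_kv(X, rng.standard_normal((d, d_c)), rng.standard_normal((d_c, 2 * d)))
print(K_gen.shape, V_gen.shape)  # (5, 16) (5, 16)
```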

During inference, MLA caches only the compressed latent \(\mathbf{C} = \mathbf{X} \mathbf{W}_{DKV} \in \mathbb{R}^{n \times d_c}\), i.e., $d_c$ values per token instead of the $2d$ values that MHA stores.
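As a rough back-of-the-envelope comparison, the snippet below counts cached elements per token per layer ($2d$ for MHA, $2d/g$ for GQA, $d_c$ for MLA) and the resulting cache size. The configuration is hypothetical and chosen only to give round numbers; it assumes 2 bytes per element and ignores the extra RoPE key discussed later.

```python
# Hypothetical configuration, chosen only for illustration.
n_tokens, n_layers, d = 32_768, 32, 4096
g, d_c = 4, 512
bytes_per_elem = 2  # e.g., fp16 / bf16

per_token = {
    "MHA": 2 * d,        # full K and V rows
    "GQA": 2 * d // g,   # K_tilde and V_tilde rows (2 d_g)
    "MLA": d_c,          # one row of the latent C
}
for name, elems in per_token.items():
    gib = n_tokens * n_layers * elems * bytes_per_elem / 2**30
    print(f"{name}: {elems} elements/token/layer, {gib:.1f} GiB")
# MHA: 8192 elements/token/layer, 16.0 GiB
# GQA: 2048 elements/token/layer, 4.0 GiB
# MLA: 512 elements/token/layer, 1.0 GiB
```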

However, reconstructing \(\mathbf{K}_{\text{MLA}}\) and \(\mathbf{V}_{\text{MLA}}\) from \(\mathbf{C}\) with the learned \(\mathbf{W}_{UKV}\) adds computational overhead at inference time. Concretely, let us define \(\mathbf{W}_{UK} \in \mathbb{R}^{d_c \times d}\) and \(\mathbf{W}_{UV} \in \mathbb{R}^{d_c \times d}\) as

\[\mathbf{W}_{UK} = \mathbf{W}_{UKV}[:,:d], \quad \mathbf{W}_{UV} = \mathbf{W}_{UKV}[:,d:2d].\]

At inference time, we would then need to compute

\[\mathbf{K}_{\text{MLA}} = \mathbf{C} \mathbf{W}_{UK}, \quad \mathbf{V}_{\text{MLA}} = \mathbf{C} \mathbf{W}_{UV},\]

to reconstruct \(\mathbf{K}_{\text{MLA}}\) and \(\mathbf{V}_{\text{MLA}}\) from the compressed latent KV cache \(\mathbf{C}\).

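To avoid this extra computation, MLA uses a simple trick: it merges \(\mathbf{W}_{UK}\) into \(\mathbf{W}_Q\) and \(\mathbf{W}_{UV}\) into \(\mathbf{W}_O\). Because attention is computed head by head, the merge is also done per head. Writing \(\mathbf{W}_Q^{(j)}\), \(\mathbf{W}_{UK}^{(j)}\), and \(\mathbf{W}_{UV}^{(j)}\) for the $d_h$ columns of \(\mathbf{W}_Q\), \(\mathbf{W}_{UK}\), and \(\mathbf{W}_{UV}\) that belong to head $j$, and \(\mathbf{W}_O^{(j)}\) for the corresponding $d_h$ rows of \(\mathbf{W}_O\), the merged weights are:

\[\forall j \in [h], \quad \mathbf{W}_Q^{(j)} \leftarrow \mathbf{W}_Q^{(j)} {\mathbf{W}_{UK}^{(j)}}^\top \in \mathbb{R}^{d \times d_c}, \quad \mathbf{W}_O^{(j)} \leftarrow \mathbf{W}_{UV}^{(j)} \mathbf{W}_O^{(j)} \in \mathbb{R}^{d_c \times d}.\]

This integration removes the additional computational overhead at inference time because of the following equivalences, shown with the head index dropped for readability: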

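\[\mathbf{Q}_{\text{MLA}} {\mathbf{K}_{\text{MLA}}}^\top = (\mathbf{X}\mathbf{W}_Q) (\mathbf{X} \mathbf{W}_{DKV} \mathbf{W}_{UK})^\top = \mathbf{X} \big( \mathbf{W}_Q {\mathbf{W}_{UK}}^\top \big)\mathbf{C}^\top, \\ \mathbf{V}_{\text{MLA}} \mathbf{W}_O = \mathbf{X} \mathbf{W}_{DKV} \mathbf{W}_{UV} \mathbf{W}_O = \mathbf{C} \big(\mathbf{W}_{UV} \mathbf{W}_O \big).\]

Because the attention weights multiply \(\mathbf{V}_{\text{MLA}}\) from the left, the merged matrix \(\mathbf{W}_{UV} \mathbf{W}_O\) can be applied after attending over \(\mathbf{C}\), so neither \(\mathbf{K}_{\text{MLA}}\) nor \(\mathbf{V}_{\text{MLA}}\) needs to be materialized during inference.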
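The sketch below (NumPy, hypothetical toy sizes, random weights standing in for trained ones, no causal mask) checks the per-head merge numerically: the naive path reconstructs each head's keys and values from \(\mathbf{C}\), while the merged path only ever touches \(\mathbf{C}\). Note that each head's merged query has width $d_c$ rather than $d_h$, so every score becomes a $d_c$-dimensional dot product; this is the price paid for never materializing the keys and values.

```python
import numpy as np

def softmax(s):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n, d, h, d_c = 5, 16, 4, 6      # hypothetical toy sizes
d_h = d // h
X = rng.standard_normal((n, d))
W_Q, W_O = rng.standard_normal((d, d)), rng.standard_normal((d, d))
W_DKV = rng.standard_normal((d, d_c))
W_UK, W_UV = rng.standard_normal((d_c, d)), rng.standard_normal((d_c, d))
C = X @ W_DKV                   # the only thing that is cached: (n, d_c)

Y_naive = np.zeros((n, d))
Y_absorbed = np.zeros((n, d))
for j in range(h):
    cols = slice(j * d_h, (j + 1) * d_h)
    # Naive: reconstruct this head's keys and values from the cache.
    K_j, V_j = C @ W_UK[:, cols], C @ W_UV[:, cols]
    A = softmax((X @ W_Q[:, cols]) @ K_j.T / np.sqrt(d_h))
    Y_naive += A @ V_j @ W_O[cols, :]              # [Z_1 ... Z_h] W_O = sum_j Z_j W_O^(j)
    # Merged: per-head weights act on C directly.
    Wq_merged = W_Q[:, cols] @ W_UK[:, cols].T     # (d, d_c)
    Wo_merged = W_UV[:, cols] @ W_O[cols, :]       # (d_c, d)
    A2 = softmax((X @ Wq_merged) @ C.T / np.sqrt(d_h))
    Y_absorbed += A2 @ C @ Wo_merged

print(np.allclose(Y_naive, Y_absorbed))  # True
```

In a deployed model the merged matrices would be computed once, ahead of time, rather than inside the loop as in this sketch.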

MLA with RoPE

RoPE [7] is a positional encoding technique that injects positional information into the attention mechanism. RoPE post-processes the computed query and key matrices before the attention computation by:

\[\forall i \in [n]: \quad \mathbf{Q}[i, :] \leftarrow \mathbf{Q}[i, :] \mathcal{R}_i, \quad \mathbf{K}[i, :] \leftarrow \mathbf{K}[i, :] \mathcal{R}_i,\]

where \(\mathcal{R}_i \in \mathbb{R}^{d \times d}\) is the block-diagonal rotation matrix for the $i$-th position. Because a different \(\mathcal{R}_i\) multiplies each row of \(\mathbf{Q}\) and \(\mathbf{K}\), the attention score between positions $i$ and $t$ becomes \(\mathbf{X}[i, :] \, \mathbf{W}_Q \mathcal{R}_i \mathcal{R}_t^\top {\mathbf{W}_{UK}}^\top \, \mathbf{C}[t, :]^\top\). The position-dependent factor \(\mathcal{R}_i \mathcal{R}_t^\top\) sits between \(\mathbf{W}_Q\) and \({\mathbf{W}_{UK}}^\top\), so their product is no longer a fixed matrix, and merging \(\mathbf{W}_{UK}\) into \(\mathbf{W}_Q\) is not possible anymore.
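A minimal NumPy sketch of this argument for a single head: `rope_matrix` (a hypothetical helper) builds a block-diagonal \(\mathcal{R}_i\) that rotates adjacent coordinate pairs with base 10000 as in [7]; implementations differ in how they pair coordinates. It verifies \(\mathcal{R}_i \mathcal{R}_t^\top = \mathcal{R}_{i-t}\) and shows that the matrix between a query row and a cached latent row changes with the offset $i - t$, so no single merged matrix can replace it.

```python
import numpy as np

def rope_matrix(pos: int, dim: int, base: float = 10000.0) -> np.ndarray:
    """Block-diagonal RoPE matrix R_pos; for a row vector x, x @ R_pos rotates each
    pair (x[2k], x[2k+1]) by the angle pos * base**(-2k/dim)."""
    R = np.zeros((dim, dim))
    for k in range(dim // 2):
        theta = pos * base ** (-2 * k / dim)
        c, s = np.cos(theta), np.sin(theta)
        R[2*k:2*k+2, 2*k:2*k+2] = [[c, s], [-s, c]]
    return R

rng = np.random.default_rng(0)
d, dim, d_c = 16, 8, 6                   # hypothetical: a single head of width dim
W_Q = rng.standard_normal((d, dim))      # this head's query projection
W_UK = rng.standard_normal((d_c, dim))   # this head's key up-projection

# Relative-position property: R_i R_t^T = R_{i-t}.
print(np.allclose(rope_matrix(7, dim) @ rope_matrix(3, dim).T, rope_matrix(4, dim)))  # True

def between(i: int, t: int) -> np.ndarray:
    """Matrix M such that score(i, t) = X[i, :] @ M @ C[t, :]."""
    return W_Q @ rope_matrix(i, dim) @ rope_matrix(t, dim).T @ W_UK.T   # (d, d_c)

print(np.allclose(between(5, 5), between(9, 9)))  # True: same offset i - t = 0
print(np.allclose(between(5, 5), between(9, 2)))  # False: offset 0 vs 7
```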

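DeepSeek-V2 [6] sidesteps this issue with a decoupled RoPE strategy that splits the query and key of each head into two parts. The first part carries no positional encoding and is computed with MLA exactly as above, so it can still be reconstructed (or merged) from the cached latent \(\mathbf{C}\) of dimension \(d_c\). The second part, of dimension \(d_r\), carries RoPE, and its key is shared across all heads, i.e., it is computed with GQA with $g=h$ (multi-query attention). The KV cache therefore stores \(d_c + d_r\) values per token.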
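The decoupled scheme can be sketched as follows. This is a simplified illustration, not DeepSeek-V2's exact architecture: it omits the query compression used in [6] and causal masking, uses toy sizes, and the names `W_QR`, `W_KR`, and `rope_rows` are our own (the first two loosely mirror the notation in [6], but here they act on \(\mathbf{X}\) directly). The scale \(1/\sqrt{d_h + d_r}\) is an assumption based on our reading of [6]. The sketch checks that attending with merged weights over the cache \((\mathbf{C}, \mathbf{K}_r)\) matches the naive computation with explicitly concatenated queries and keys.

```python
import numpy as np

def softmax(s):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    s = s - s.max(axis=-1, keepdims=True)
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def rope_rows(M, base=10000.0):
    """Apply RoPE to each row of M (row i sits at position i), rotating adjacent column pairs."""
    n, dim = M.shape
    ang = np.arange(n)[:, None] * base ** (-np.arange(0, dim, 2) / dim)  # (n, dim/2)
    cos, sin = np.cos(ang), np.sin(ang)
    x, y = M[:, 0::2], M[:, 1::2]
    out = np.empty_like(M)
    out[:, 0::2] = x * cos - y * sin
    out[:, 1::2] = x * sin + y * cos
    return out

rng = np.random.default_rng(0)
n, d, h, d_c, d_r = 5, 16, 4, 6, 4           # hypothetical toy sizes
d_h = d // h
X = rng.standard_normal((n, d))
W_Q, W_O = rng.standard_normal((d, d)), rng.standard_normal((d, d))
W_DKV = rng.standard_normal((d, d_c))
W_UK, W_UV = rng.standard_normal((d_c, d)), rng.standard_normal((d_c, d))
W_QR = rng.standard_normal((d, h * d_r))     # per-head RoPE query projection (hypothetical)
W_KR = rng.standard_normal((d, d_r))         # RoPE key projection shared by all heads (hypothetical)

C = X @ W_DKV                                # cached latent: d_c values per token
K_r = rope_rows(X @ W_KR)                    # cached shared RoPE key: d_r values per token
scale = 1 / np.sqrt(d_h + d_r)               # assumed scaling

Y, Y_ref = np.zeros((n, d)), np.zeros((n, d))
for j in range(h):
    cols, rcols = slice(j * d_h, (j + 1) * d_h), slice(j * d_r, (j + 1) * d_r)
    q_r = rope_rows(X @ W_QR[:, rcols])      # this head's RoPE query, (n, d_r)
    # Merged path: position-free scores use W_Q^(j) W_UK^(j)T and attend over C.
    S = (X @ (W_Q[:, cols] @ W_UK[:, cols].T)) @ C.T + q_r @ K_r.T
    Y += softmax(S * scale) @ C @ (W_UV[:, cols] @ W_O[cols, :])
    # Naive path: build the concatenated per-head query and key explicitly.
    q = np.hstack([X @ W_Q[:, cols], q_r])
    k = np.hstack([C @ W_UK[:, cols], K_r])
    Y_ref += softmax(q @ k.T * scale) @ (C @ W_UV[:, cols]) @ W_O[cols, :]

print(np.allclose(Y, Y_ref))  # True
```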

References

[1] Su, J. (2024). The ultimate tug-of-war between cache and capacity from MHA, MQA, GQA to MLA. Link: https://yuxi-liu-wired.github.io/docs/posts/2024-05-13-multi-latent-attention/.

[2] Meng, F., Yao, Z., Zhang, M. (2025). TransMLA: Multi-Head Latent Attention Is All You Need. arXiv preprint arXiv:2502.07864.

[3] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., Polosukhin, I. (2017). Attention is all you need. In Advances in neural information processing systems (pp. 5998-6008).

[4] Shazeer, N. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv preprint arXiv:1911.02150.

[5] Ainslie, J., Lee-Thorp, J., De Jong, M., Zemlyanskiy, Y., Lebron, F., & Sanghai, S. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv preprint arXiv:2305.13245.

[6] DeepSeek-AI. (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv preprint arXiv:2405.04434.

[7] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv preprint arXiv:2104.09864.