
Attention Mechanisms in Large Models: MHA, GQA, MQA, and MLA in Theory and Practice
The attention mechanism is the soul of the Transformer architecture and the key to balancing performance against efficiency in large models. From the original Multi-Head Attention (MHA) to the more recent Multi-head Latent Attention (MLA), researchers have kept exploring the trade-off between expressive power and computational cost by refining how Keys, Values, and Queries interact. This article systematically reviews the theoretical foundations of four mainstream attention mechanisms (MHA, MQA, GQA, and MLA), analyzing their design motivations and core principles, and walks through code implementations of each.
1. Multi-Head Attention (MHA): The Foundation of Parallel Feature Capture
1.1 Design Motivation: Breaking the Expressiveness Bottleneck of Single-Head Attention
Before the Transformer, traditional attention mechanisms (such as Bahdanau attention) modeled sequence dependencies with a single set of Query, Key, and Value, which made it hard to capture feature patterns along different dimensions (e.g., syntactic structure and semantic association) at the same time. MHA's core innovation is to project the input into multiple subspaces and compute attention in parallel, so the model can simultaneously attend to multi-dimensional features at different positions in the sequence, strengthening its ability to model complex patterns.
1.2 Core Principle: Project, Split, Attend, Aggregate
MHA's computation can be summarized in four steps: linear projection, multi-head splitting, attention computation, and aggregation with output projection:
- Linear projection: the input sequence passes through three learnable matrices to produce Query (Q), Key (K), and Value (V), each of shape (batch_size, seq_len, hidden_size).
- Multi-head splitting: Q, K, and V are split across num_heads heads, each of dimension head_dim = hidden_size / num_heads, and reshaped to (batch_size, num_heads, seq_len, head_dim).
- Scaled dot-product attention: each head independently computes attention as softmax(QKᵀ / √d_k)·V, where the scaling factor √d_k mitigates the vanishing-gradient problem caused by overly large dot products (the full multi-head formulation is written out after this list).
- Aggregation and projection: the outputs of all heads are concatenated and mapped back to the original dimension (hidden_size) by a linear layer.
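Written out, the last three steps correspond to the standard formulation from the original Transformer paper (a standard rendering of the well-known equations, not reproduced from this article's source):

$$
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{Q W_i^{Q}\,\left(K W_i^{K}\right)^{\top}}{\sqrt{d_k}}\right) V W_i^{V}
$$

$$
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}\!\left(\mathrm{head}_1, \ldots, \mathrm{head}_{h}\right) W^{O}, \qquad d_k = \texttt{head\_dim}
$$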
1.3 Theoretical Strengths and Limitations
- Strengths: the parallel multi-head design lets the model capture multi-scale features (e.g., local syntax alongside global semantics) and is the core source of large models' strong expressive power.
- Limitations: the Q/K/V projection layers alone account for 3·hidden_size² parameters, and inference must cache K and V for every head, so the KV cache grows large quickly (2·num_heads·head_dim = 2·hidden_size elements per token per layer), which limits long sequences and large-scale deployment.
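To get a feel for these costs, here is a back-of-the-envelope sketch assuming a hypothetical 4096-wide, 32-head, 32-layer model stored in fp16; the configuration is an assumption for illustration only, not taken from this article.

hidden_size = 4096        # assumed model width
num_heads = 32            # assumed head count
num_layers = 32           # assumed layer count
bytes_per_elem = 2        # fp16
head_dim = hidden_size // num_heads

qkv_params_per_layer = 3 * hidden_size * hidden_size        # Q/K/V projections: 3 * hidden_size^2
kv_elems_per_token_per_layer = 2 * num_heads * head_dim     # = 2 * hidden_size
kv_bytes_per_token = kv_elems_per_token_per_layer * num_layers * bytes_per_elem

print(f"Q/K/V projection params per layer: {qkv_params_per_layer / 1e6:.1f}M")    # 50.3M
print(f"KV cache per token across layers:  {kv_bytes_per_token / 1024:.0f} KiB")  # 512 KiB

With such a cache, a single 32k-token context already consumes on the order of 16 GB of KV cache, which is why the later mechanisms focus on shrinking it.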
2. Multi-Query Attention (MQA): Parameter Sharing for Maximum Efficiency
2.1 Design Motivation: Breaking the KV-Cache Memory Bottleneck
MHA's KV cache grows linearly with the number of heads, and in long-sequence scenarios (such as document understanding) it can easily exhaust GPU memory. MQA's core idea is to share the K and V parameters across heads, cutting redundant computation and storage and trading some expressive power for efficiency.
2.2 Core Principle: A Single Set of K/V, Multi-Head Queries
MQA changes MHA through its parameter-sharing strategy:
- Queries stay multi-headed: Q is still produced by a multi-head linear projection, so each head keeps its own distinct querying behavior.
- Keys and Values are shared globally: all heads share a single set of K and V, i.e., the K/V projection outputs have shape (batch_size, seq_len, head_dim) rather than MHA's (batch_size, seq_len, hidden_size).
- Broadcast expansion: the shared K and V are expanded to all heads via unsqueeze and expand, enabling the usual multi-head attention computation (a shape sketch follows this list).
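The broadcast step in isolation looks like the minimal sketch below; the shapes match the toy configuration used in section 6.2, and this is an illustration rather than the full module.

import torch

batch_size, seq_len, num_heads, head_dim = 2, 10, 8, 32

# One shared K for all heads: [batch_size, seq_len, head_dim]
key = torch.randn(batch_size, seq_len, head_dim)

# unsqueeze adds a head axis; expand broadcasts it across num_heads without copying memory
key = key.unsqueeze(1).expand(-1, num_heads, -1, -1)
print(key.shape)  # torch.Size([2, 8, 10, 32])

Because expand returns a view, the shared K/V is stored only once no matter how many heads read it, which is exactly where the cache savings come from.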
2.3 Theoretical Strengths and Limitations
- Strengths: the parameter count drops sharply (the K/V projection parameters shrink from 2·hidden_size² to 2·hidden_size·head_dim), the KV cache is only 1/num_heads of MHA's, and inference speed improves noticeably.
- Limitations: sharing K and V globally reduces the differentiation between heads and can hurt expressive power, especially the ability to capture subtle differences in long-sequence tasks.
3. Grouped-Query Attention (GQA): A Middle Ground Between Performance and Efficiency
3.1 Design Motivation: Balancing Expressiveness Against Efficiency
MQA is efficient but gives up too much expressive power; MHA is expressive but too costly. GQA strikes a balance by sharing K/V within groups: the query heads are divided into several groups, each group sharing one set of K and V, which reduces redundancy while preserving some diversity across heads.
3.2 Core Principle: Share Within Groups, Stay Independent Across Groups
- Grouping strategy: with num_heads total heads and group_size heads per group, the number of groups is num_groups = num_heads / group_size.
- Per-group K/V: the K/V projection outputs have shape (batch_size, seq_len, num_groups, head_dim), one independent set of K and V per group.
- Expansion: each group's K and V are expanded to all heads within that group via unsqueeze and expand, enabling grouped attention computation (a shape sketch follows this list).
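The per-group expansion in isolation can be sketched as below; it uses repeat_interleave, which produces the same result as the unsqueeze/expand/view chain in section 6.3 and is shown here only as a compact alternative, not as the article's exact code.

import torch

batch_size, seq_len, num_heads, group_size, head_dim = 2, 10, 8, 2, 32
num_groups = num_heads // group_size  # 4

# One K per group: [batch_size, num_groups, seq_len, head_dim]
key = torch.randn(batch_size, num_groups, seq_len, head_dim)

# Repeat each group's K for the group_size heads that share it
key = key.repeat_interleave(group_size, dim=1)
print(key.shape)  # torch.Size([2, 8, 10, 32])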
3.3 Theoretical Strengths and Limitations
- Strengths: the K/V parameters and KV cache shrink to num_groups / num_heads of MHA's (e.g., 8 heads in 4 groups cuts the cost to 50%), while differences between groups are preserved, so expressiveness is better than MQA.
- Limitations: performance depends on the choice of group_size; too small and GQA approaches MHA (less efficient), too large and it approaches MQA (less expressive), so the setting must be tuned per task.
4. Multi-Head Latent Attention (MLA): Fusing Low-Rank Compression with Position Decoupling
4.1 Design Motivation: Jointly Optimizing Low-Rank Decomposition and Positional Encoding
MHA, MQA, and GQA all stay within the paradigm of explicitly generating and caching K and V. MLA steps outside it: it compresses the KV dimensions through low-rank parameterization and decouples content from positional information, improving efficiency and performance at the same time; it is particularly well suited to long sequences and large-scale model deployment.
4.2 Core Principle: Low-Rank Compression Plus Position Decoupling, Two Paths in Parallel
MLA's innovation lies in two key designs:
- Low-rank KV compression: the hidden state is first down-projected into a compact latent vector, from which K and V are reconstructed by up-projection; only this small latent (plus a small decoupled positional key) needs to be cached.
- Decoupled RoPE: positional information is carried by separate, low-dimensional rotary-encoded query/key components that are concatenated with the content components, so RoPE does not interfere with the low-rank compression, and the up-projection matrices can be absorbed into neighboring projections at inference time.
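A compact sketch of these two designs, following the DeepSeek-V2 formulation (per-head indices are omitted for brevity; this is the paper's notation rather than code from this article):

$$
c_t^{KV} = W^{DKV} h_t,\qquad k_t^{C} = W^{UK} c_t^{KV},\qquad v_t^{C} = W^{UV} c_t^{KV}
$$

$$
k_t^{R} = \mathrm{RoPE}\!\left(W^{KR} h_t\right),\qquad k_t = \left[k_t^{C};\, k_t^{R}\right],\qquad
q_t = \left[\,W^{UQ} c_t^{Q};\ \mathrm{RoPE}\!\left(W^{QR} c_t^{Q}\right)\right],\quad c_t^{Q} = W^{DQ} h_t
$$

Only the low-dimensional $c_t^{KV}$ and the shared $k_t^{R}$ need to be cached; $W^{UK}$ and $W^{UV}$ can be absorbed into the query and output projections at inference time, which is the "absorption" trick mentioned in 4.3 below.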
4.3 Theoretical Strengths and Limitations
- Strengths: both the parameter count (roughly 42% of MHA's) and the KV cache (compressed to about 1/5 to 1/10 of GQA's) drop substantially, while the low-rank decomposition retains the key features, keeping expressiveness close to MHA.
- Limitations: the low-rank projections and position decoupling add complexity; MLA is harder to implement than the other three mechanisms, and the matrix merging (the "absorption" operation) must be optimized carefully to avoid redundant computation.
5. Theoretical Comparison of the Four Mechanisms: Weighing Parameters Against Capability
| Mechanism | Core Innovation | Parameters (relative) | KV Cache (relative) | Expressiveness | Typical Scenarios |
| --- | --- | --- | --- | --- | --- |
| MHA | Parallel multi-head attention | 1.0 | 1.0 | Strongest | Pre-training, high-performance workloads |
| MQA | Globally shared K/V | 0.56 | 1/num_heads | Weaker | Edge deployment, high-concurrency inference |
| GQA | Group-shared K/V | 0.75 | num_groups/num_heads | Strong | General-purpose LLMs, balanced requirements |
| MLA | Low-rank compression + position decoupling | 0.42 | ~0.1 | Strong | Long sequences, large-scale deployment |
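To make the KV-cache column concrete, here is a rough per-token, per-layer estimate under an assumed configuration (a 4096-wide model with 32 heads, 8 GQA groups, and an MLA latent dimension of 512 plus 64 decoupled RoPE dimensions, in fp16); all of these numbers are assumptions chosen to resemble common setups, not values from the article.

hidden_size, num_heads = 4096, 32              # assumed model width and head count
head_dim = hidden_size // num_heads            # 128
num_groups = 8                                 # assumed GQA group count
d_latent, d_rope = 512, 64                     # assumed MLA compressed dim and decoupled RoPE dim
bytes_per_elem = 2                             # fp16

cache_elems_per_token = {
    "MHA": 2 * num_heads * head_dim,           # K and V for every head
    "MQA": 2 * head_dim,                       # one shared K and V
    "GQA": 2 * num_groups * head_dim,          # one K and V per group
    "MLA": d_latent + d_rope,                  # compressed latent + shared decoupled RoPE key
}
for name, elems in cache_elems_per_token.items():
    print(f"{name}: {elems * bytes_per_elem / 1024:5.2f} KiB per token per layer")
# MHA: 16.00, GQA: 4.00, MLA: 1.12, MQA: 0.50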
6. Code Implementations in Practice
The code below is adapted from https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA (to be removed at the original author's request).
6.1 Multi-Head Attention (MHA) Implementation
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # dimension of each head
        # Linear projections for Q, K, V
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()
        # Generate Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, hidden_size]
        value = self.value(hidden_state)  # [batch_size, seq_len, hidden_size]
        # Split into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Apply mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        # Softmax over the key dimension: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Context vectors
        context = torch.matmul(attention_weights, value)
        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        # Output projection: [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output

# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    mha = MultiHeadAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions
    output = mha(hidden_state, attention_mask)
    print("MHA output shape:", output.shape)  # torch.Size([2, 10, 256])
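Note that the class above recomputes K and V for the entire sequence on every call. During autoregressive decoding, an inference engine would instead cache the per-head K/V tensors and append one time step per generated token; reduced to its bare minimum, the bookkeeping looks like the following illustrative sketch (an assumption about typical practice, not part of the original article or of the class above).

import torch

batch_size, num_heads, head_dim = 2, 8, 32
k_cache = torch.empty(batch_size, num_heads, 0, head_dim)
v_cache = torch.empty(batch_size, num_heads, 0, head_dim)

for step in range(3):
    # In a real decoder these would come from the key/value projections of the newest token
    new_k = torch.randn(batch_size, num_heads, 1, head_dim)
    new_v = torch.randn(batch_size, num_heads, 1, head_dim)
    k_cache = torch.cat([k_cache, new_k], dim=2)
    v_cache = torch.cat([v_cache, new_v], dim=2)

print(k_cache.shape)  # torch.Size([2, 8, 3, 32]); the cache grows linearly with generated tokens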
6.2 Multi-Query Attention (MQA) Implementation
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Q stays multi-headed; K and V are shared by all heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.head_dim)
        self.value = nn.Linear(hidden_size, self.head_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()
        # Generate Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, head_dim]
        # Split Q into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Broadcast the shared K and V to every head: [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        value = value.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Apply mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Context vectors: [batch_size, num_heads, seq_len, head_dim]
        context = torch.matmul(attention_weights, value)
        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        # Output projection: [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output

# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    mqa = MultiQueryAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions
    output = mqa(hidden_state, attention_mask)
    print("MQA output shape:", output.shape)  # torch.Size([2, 10, 256])
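As a quick sanity check of the parameter savings, one can count the parameters of the two modules side by side; this snippet assumes the MultiHeadAttention class from section 6.1 is available in the same script or session (an assumption made for illustration).

def count_params(module):
    return sum(p.numel() for p in module.parameters())

mha = MultiHeadAttention(hidden_size=256, num_heads=8)
mqa = MultiQueryAttention(hidden_size=256, num_heads=8)
print(count_params(mha), count_params(mqa), count_params(mqa) / count_params(mha))
# roughly 263k vs. 148k parameters, a ratio of about 0.56, consistent with the table in section 5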
6.3 Grouped-Query Attention (GQA) Implementation
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):
        super(GroupedQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % group_size == 0, "num_heads must be divisible by group_size"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads
        # Q stays multi-headed; K and V are shared within each group
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()
        # Generate Q, K, V
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, group_num * head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, group_num * head_dim]
        # Split Q into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Split K and V into groups, then expand each group to its heads
        # [batch_size, group_num, seq_len, head_dim]
        key = key.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)
        # [batch_size, group_num, seq_len, head_dim]
        value = value.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        value = value.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)
        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Apply mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)
        # Context vectors
        context = torch.matmul(attention_weights, value)
        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        # Output projection
        output = self.out_projection(context)
        return output

# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    group_size = 2  # 2 heads per group, 4 groups in total
    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0
    output = gqa(hidden_state, attention_mask)
    print("GQA output shape:", output.shape)  # torch.Size([2, 10, 256])
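GroupedQueryAttention sits on a spectrum between the other two mechanisms. A quick check of the shared K-projection width for different group sizes (using the class defined above) shows that group_size=1 reproduces MHA's per-head K/V while group_size=num_heads reproduces MQA's single shared K/V:

hidden_size, num_heads = 256, 8
for group_size in (1, 2, 4, 8):
    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    print(f"group_size={group_size}: key projection out_features = {gqa.key.out_features}")
# group_size=1 -> 256 (one K/V per head, MHA-like); group_size=8 -> 32 (one shared K/V, MQA-like)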
6.4 Multi-Head Latent Attention (MLA) Implementation
import torch
import torch.nn as nn
import math

class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        # RoPE frequencies and position table
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)
        # cos/sin tables, each value repeated for the (even, odd) pair it rotates
        cos_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.sin().repeat_interleave(2, dim=-1)
        return cos_pos, sin_pos

    def forward(self, q):
        bs, seq_len = q.shape[0], q.shape[2]
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]
        cos_pos = cos_pos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
        sin_pos = sin_pos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
        # RoPE rotation: pair up (even, odd) dims as (-x_odd, x_even)
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape).contiguous()
        return q * cos_pos + q2 * sin_pos

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26, dropout_prob=0.0):
        super(MultiHeadLatentAttention, self).__init__()
        self.d_model = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads
        # Down-projections (low-rank compression)
        self.down_proj_kv = nn.Linear(hidden_size, down_dim)
        self.down_proj_q = nn.Linear(hidden_size, down_dim)
        # Up-projections
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)
        self.up_proj_q = nn.Linear(down_dim, up_dim)
        # Decoupled Q/K projections (carry positional information)
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim)
        # RoPE positional encodings
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)
        # Output layers
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(num_heads * self.v_head_dim, hidden_size)
        self.res_dropout = nn.Dropout(dropout_prob)

    def forward(self, h, mask=None):
        bs, seq_len, _ = h.size()
        # Low-rank compression
        c_t_kv = self.down_proj_kv(h)   # [bs, seq_len, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)  # [bs, seq_len, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)  # [bs, seq_len, up_dim]
        c_t_q = self.down_proj_q(h)     # [bs, seq_len, down_dim]
        q_t_c = self.up_proj_q(c_t_q)   # [bs, seq_len, up_dim]
        # Decoupled Q: [bs, num_heads, seq_len, rope_head_dim]
        q_t_r = self.proj_qr(c_t_q)
        q_t_r = q_t_r.view(bs, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)
        q_t_r = self.rope_q(q_t_r)  # apply RoPE
        # Decoupled K, shared by all heads: [bs, 1, seq_len, rope_head_dim]
        k_t_r = self.proj_kr(h)
        k_t_r = k_t_r.unsqueeze(1)
        k_t_r = self.rope_k(k_t_r)  # apply RoPE
        # Assemble content + positional parts of Q and K
        q_t_c = q_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)  # [bs, num_heads, seq_len, up_dim/num_heads]
        q = torch.cat([q_t_c, q_t_r], dim=-1)  # [bs, num_heads, seq_len, up_dim/num_heads + rope_head_dim]
        k_t_c = k_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)  # [bs, num_heads, seq_len, up_dim/num_heads]
        k_t_r = k_t_r.expand(bs, self.num_heads, seq_len, -1)  # [bs, num_heads, seq_len, rope_head_dim]
        k = torch.cat([k_t_c, k_t_r], dim=-1)  # [bs, num_heads, seq_len, up_dim/num_heads + rope_head_dim]
        # Attention scores: [bs, num_heads, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))
        if mask is not None:
            scores = scores.masked_fill(mask[:, None, None, :] == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Reshape V: [bs, num_heads, seq_len, v_head_dim]
        v_t_c = v_t_c.view(bs, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        # Context vectors: [bs, num_heads, seq_len, v_head_dim]
        context = torch.matmul(attn_weights, v_t_c)
        # Merge heads: [bs, seq_len, num_heads * v_head_dim]
        context = context.transpose(1, 2).contiguous().view(bs, seq_len, -1)
        # Output projection: [bs, seq_len, d_model]
        output = self.fc(context)
        output = self.res_dropout(output)
        return output

# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    h = torch.randn(batch_size, seq_len, hidden_size)
    mla = MultiHeadLatentAttention(hidden_size=hidden_size)
    mask = torch.ones(batch_size, seq_len)
    mask[:, 5:] = 0
    output = mla(h, mask)
    print("MLA output shape:", output.shape)  # torch.Size([2, 10, 256])
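With the toy hyperparameters above (down_dim=64, rope_head_dim=26), what MLA would actually need to cache per token is the compressed c_t_kv plus the single shared decoupled RoPE key, rather than full per-head K and V. A quick illustrative comparison against the MHA baseline follows; in production-scale configurations the ratio is far smaller, since hidden_size dwarfs the latent dimension (this estimate is an illustration, not a measurement from the code above, which does not implement caching).

hidden_size, down_dim, rope_head_dim = 256, 64, 26   # hyperparameters used in the example above
mha_cache_per_token = 2 * hidden_size                # K and V for all heads (MHA baseline)
mla_cache_per_token = down_dim + rope_head_dim       # compressed latent + shared RoPE key
print(mla_cache_per_token, mha_cache_per_token, mla_cache_per_token / mha_cache_per_token)
# 90 vs. 512 elements per token per layer, a ratio of about 0.18 with these toy settings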
7. Summary and Recommendations
The evolution from MHA to MLA is, at its core, the art of trading expressive power against computational efficiency: MHA laid the foundation of parallel multi-head attention, MQA and GQA improved efficiency through parameter sharing, and MLA achieved a qualitative step forward through low-rank decomposition and position decoupling.
Each of the four mechanisms has its strengths and weaknesses, and the right choice depends on the scenario:
- MHA: for performance-critical, resource-rich settings such as pre-training.
- MQA: for resource-constrained, latency-sensitive settings such as edge deployment and high-concurrency inference.
- GQA: the default choice in most cases, balancing performance and efficiency; well suited to general-purpose large models.
- MLA: for long-sequence tasks and large-scale deployment; it shines when GPU memory is limited.
As large models move toward larger parameter counts and longer sequences, work on attention mechanisms will keep advancing. Developers should pick the mechanism that fits their actual needs and follow the latest research to keep improving model performance and efficiency.
References
- 宋志学, "手撕大模型 Attention:MLA、MHA、MQA 与 GQA (含实现代码)", WeChat Official Accounts, 2025-05-20, https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA
- 苏剑林, "Transformer 升级之路:多头潜在注意力机制 (MLA) 究竟好在哪里?", WeChat Official Accounts, 2025-05-22, https://mp.weixin.qq.com/s/KdOjWF4n5gNtQxKKvkG5Mw
- 姜富春, "DeepSeek 技术解读 1: 彻底理解 MLA", WeChat Official Accounts, 2025-01-15, https://mp.weixin.qq.com/s/yL_Z8zcAfWDcviZwApdL_w
- 算法狗, "DeepSeek MLA: 高效推理的省钱之道,全流程剖析", WeChat Official Accounts, 2025-02-19, https://mp.weixin.qq.com/s/yNxjgQMl2LKzpGOoCWRRcw
- 羽说 AI 研究圈, "从 MHA→MQA→GQA→MLA", WeChat Official Accounts, 2025-02-12, https://mp.weixin.qq.com/s/S9dfOCrWeru6zGjOjchV7Q
This article is reposted from the blog 鸿煊的学习笔记 (Hongxuan's Study Notes); author: 乘风破浪jxj.
