Attention Mechanisms in Large Language Models: MHA, GQA, MQA, and MLA in Theory and Practice

The attention mechanism is the soul of the Transformer architecture and the key lever for balancing capability and efficiency in large models. From the original multi-head attention (MHA) to the more recent multi-head latent attention (MLA), researchers have kept refining how queries, keys, and values interact, searching for the best trade-off between expressive power and computational cost. This article walks through the theoretical foundations of four mainstream attention mechanisms, MHA, MQA, GQA, and MLA, analyzing their design motivations and core principles, and provides working code for each.

1. Multi-Head Attention (MHA): The Foundation of Parallel Feature Capture

1.1 Design motivation: breaking the expressiveness bottleneck of single-head attention

Before the Transformer, traditional attention mechanisms (such as Bahdanau attention) modeled sequence dependencies with a single set of queries, keys, and values, which made it hard to capture different kinds of patterns (say, syntactic structure and semantic association) at the same time. MHA's core innovation is to project the input into multiple subspaces and compute attention in parallel, so the model can attend to different positions and different feature dimensions simultaneously, strengthening its ability to model complex patterns.


1.2 Core principle: project, split, attend, then merge

The MHA computation can be summarized in four steps: linear projection, head splitting, attention computation, and output projection.

  • Linear projection: the input sequence is mapped by three learnable matrices into Query (Q), Key (K), and Value (V), each of shape (batch_size, seq_len, hidden_size).
  • Head splitting: Q, K, and V are split into num_heads heads, each of dimension head_dim = hidden_size / num_heads, and reshaped to (batch_size, num_heads, seq_len, head_dim).
  • Scaled dot-product attention: each head computes its attention weights independently using the formula below, where √d_k is a scaling factor that keeps the dot products from growing so large that the softmax saturates and gradients vanish.

Attention(Q, K, V) = softmax(Q·Kᵀ / √d_k) · V

  • Merge and project: the outputs of all heads are concatenated and mapped back to the original dimension (hidden_size) by a linear layer; the full formulation is given below.
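Putting the four steps together, this is the standard multi-head attention formulation from the original Transformer paper, restated here for reference:

head_i = Attention(Q·W_i^Q, K·W_i^K, V·W_i^V) = softmax((Q·W_i^Q)(K·W_i^K)ᵀ / √d_k) · (V·W_i^V)
MultiHead(Q, K, V) = Concat(head_1, …, head_h) · W^O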

1.3 Theoretical strengths and limitations

  • Strengths: the parallel multi-head design lets the model capture features at multiple scales (local syntax as well as global semantics) and is the main source of large models' expressive power.
  • Limitations: parameters and compute grow with the number of heads (the Q/K/V projection layers alone account for roughly 3 × hidden_size² parameters), and at inference time the K and V of every head must be cached, so the KV cache per token is about 2 × num_heads × head_dim = 2 × hidden_size values per layer. This makes long sequences and large-scale deployment expensive; a concrete calculation is sketched below.
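To make these costs concrete, here is a small back-of-the-envelope script. The hidden_size, num_heads, num_layers, and seq_len values are arbitrary example settings chosen for illustration, not taken from any particular model.

# Illustrative MHA cost calculation (example values only)
hidden_size = 4096
num_heads = 32
num_layers = 32
seq_len = 8192
bytes_per_value = 2  # fp16 / bf16

head_dim = hidden_size // num_heads
qkv_params = 3 * hidden_size * hidden_size        # Q/K/V projection weights per layer (biases ignored)
kv_cache_per_token = 2 * num_heads * head_dim     # = 2 * hidden_size values per layer
kv_cache_bytes = kv_cache_per_token * num_layers * seq_len * bytes_per_value

print(f"QKV projection parameters per layer: {qkv_params:,}")
print(f"KV cache for one {seq_len}-token sequence: {kv_cache_bytes / 1024**3:.2f} GiB")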

2. Multi-Query Attention (MQA): Parameter Sharing for Maximum Efficiency

2.1 Design motivation: breaking the KV-cache memory bottleneck

MHA's KV cache grows linearly with the number of heads, which easily exhausts GPU memory in long-sequence settings such as document understanding. MQA's core idea is to share the K and V parameters across heads, trading some expressive power for lower compute and storage.


2.2 Core principle: one shared set of K/V, multiple query heads

MQA changes MHA's parameter-sharing strategy as follows:

  • Queries stay per-head: Q is still produced by a multi-head linear projection, so each head keeps its own query behavior.
  • Keys and values are shared globally: all heads share a single set of K/V parameters, so the K/V projections output (batch_size, seq_len, head_dim) instead of MHA's (batch_size, seq_len, hidden_size).
  • Broadcast expansion: the shared K and V are expanded to all heads with unsqueeze and expand, so the multi-head attention computation goes through unchanged (see the minimal sketch below).
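A minimal sketch of this broadcast step with example shapes (the numbers are illustrative; the full module is in section 6.2):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 10, 8, 32

# One shared K for all heads, as produced by MQA's single K projection
key = torch.randn(batch_size, seq_len, head_dim)

# Broadcast the shared K to every query head without copying memory
key = key.unsqueeze(1).expand(-1, num_heads, -1, -1)
print(key.shape)  # torch.Size([2, 8, 10, 32])

The same expansion is applied to V; only the query tensor actually differs per head.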

2.3 Theoretical strengths and limitations

  • Strengths: the parameter count drops substantially (the K/V projection parameters shrink from about 2 × hidden_size² to 2 × hidden_size × head_dim), the KV cache is only 1/num_heads of MHA's, and inference speed improves markedly.
  • Limitations: because K and V are shared globally, the heads become less differentiated, which can hurt expressive power, especially the ability to pick up subtle distinctions in long-sequence tasks.

3. Grouped-Query Attention (GQA): A Middle Ground Between Performance and Efficiency

3.1 Design motivation: balancing expressiveness and efficiency

MQA is efficient but gives up too much expressive power; MHA is powerful but costly. GQA finds a balance by sharing KV parameters within groups: the query heads are divided into several groups, and each group shares one set of K/V, reducing redundancy while preserving some diversity across heads.


3.2 Core principle: shared within groups, independent across groups

  • Grouping strategy: with num_heads heads in total and group_size heads per group, the number of groups is num_groups = num_heads / group_size.
  • KV generated per group: the K/V projections output (batch_size, seq_len, num_groups × head_dim), which is reshaped so that each group has its own K and V of shape (batch_size, num_groups, seq_len, head_dim).
  • Expansion for computation: each group's K and V are expanded to all heads within the group via unsqueeze and expand, and attention then proceeds as usual (see the sketch after this list).
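A minimal sketch of the group-to-head expansion with example shapes (illustrative only; repeat_interleave is used here as an equivalent alternative to the unsqueeze/expand pattern, and the full module is in section 6.3):

import torch

batch_size, seq_len, num_heads, head_dim = 2, 10, 8, 32
group_size = 2
num_groups = num_heads // group_size  # 4 groups of 2 heads

# One K per group, already reshaped to [batch, num_groups, seq_len, head_dim]
key = torch.randn(batch_size, num_groups, seq_len, head_dim)

# Repeat each group's K for every head inside that group
key = key.repeat_interleave(group_size, dim=1)
print(key.shape)  # torch.Size([2, 8, 10, 32])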

3.3 Theoretical strengths and limitations

  • Strengths: the KV projection parameters and the KV cache shrink to num_groups / num_heads of MHA's (for example, 8 heads in 4 groups cuts the KV cost to 50%), while the remaining differences between groups keep expressiveness above MQA's.
  • Limitations: performance depends on the choice of group_size; too small and GQA approaches MHA (low efficiency), too large and it approaches MQA (weaker expressiveness), so it must be tuned per task.

4. Multi-Head Latent Attention (MLA): Fusing Low-Rank Compression with Position Decoupling

4.1 Design motivation: jointly optimizing low-rank factorization and positional encoding

MHA, MQA, and GQA all stay within the paradigm of explicitly generating and caching K and V. MLA breaks out of it by compressing the KV dimensions through low-rank parameterization and by decoupling content from positional information, gaining both efficiency and quality; this makes it especially suitable for long sequences and large-scale deployment.


4.2 Core principle: low-rank compression plus position decoupling, two parallel paths

MLA's innovation rests on two key designs:

  • Low-rank KV compression: instead of producing and caching full per-head K/V, the input is first projected down into a small latent vector, from which the content parts of K, V (and Q) are reconstructed by up-projections; only the compact latent needs to be cached.
  • Decoupled rotary position encoding: a separate, small RoPE branch carries positional information (a per-head rotary component for Q and a single shared rotary component for K), while the low-rank content branch stays position-free; the two branches are concatenated before the attention scores are computed.
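In symbols, using the same variable names as the implementation in section 6.4 (the W matrices denote the corresponding down-, up-, and RoPE-projection weights; this is a sketch of the computation, not a full derivation):

c_t^KV = W^DKV · h_t                          (compressed latent; this is what gets cached)
k_t^C = W^UK · c_t^KV,    v_t^C = W^UV · c_t^KV
c_t^Q = W^DQ · h_t,       q_t^C = W^UQ · c_t^Q
q_t^R = RoPE(W^QR · c_t^Q),   k_t^R = RoPE(W^KR · h_t)
q_t = [q_t^C ; q_t^R],    k_t = [k_t^C ; k_t^R]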

4.3 Theoretical strengths and limitations

  • Strengths: both the parameter count (roughly 42% of MHA's) and the KV cache (compressed to about 1/5–1/10 of GQA's) drop sharply, while the low-rank factorization preserves the key features, so expressiveness stays close to MHA (a rough cache comparison is sketched below).
  • Limitations: the low-rank projections and position decoupling make the model more complex, so MLA is harder to implement than the other three mechanisms, and the projection matrices need to be merged carefully (the "absorption" trick) to avoid redundant computation.
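As a rough illustration of the cache saving, here is a per-token comparison using the hyperparameters of the demo implementation in section 6.4 (hidden_size=256, num_heads=8, down_dim=64, rope_head_dim=26; toy values, not a production configuration):

# Per-token, per-layer KV-cache comparison, counting cached values (dtype ignored)
hidden_size, num_heads = 256, 8
down_dim, rope_head_dim = 64, 26

mha_cache = 2 * hidden_size            # full K and V across all heads
mla_cache = down_dim + rope_head_dim   # compressed latent + shared RoPE key
print(mha_cache, mla_cache, round(mha_cache / mla_cache, 1))  # 512 90 5.7

In this toy setting MLA caches roughly 5–6× fewer values per token than MHA; real deployments pick the latent and RoPE dimensions to fit their own memory budget.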

5. Comparing the Four Mechanisms: Trading Off Parameters, Cache, and Capability

| Mechanism | Core innovation | Parameters (relative) | KV cache (relative) | Expressiveness | Typical scenarios |
| --- | --- | --- | --- | --- | --- |
| MHA | Parallel multi-head attention | 1.0 | 1.0 | Strongest | Pretraining, high-performance workloads |
| MQA | Globally shared K/V | 0.56 | 1/num_heads | Weaker | Edge deployment, high-concurrency inference |
| GQA | Group-shared K/V | 0.75 | num_groups/num_heads | Stronger | General-purpose LLMs, balanced requirements |
| MLA | Low-rank compression + position decoupling | 0.42 | ~0.1 | Close to MHA | Long sequences, large-scale deployment |

6. Code Implementations

The code below is adapted from https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA (credited to the original author; to be removed upon request).

6.1 Multi-Head Attention (MHA) implementation

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # dimension of each head

        # Linear projections for Q, K and V
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()

        # Project to Q, K, V: [batch_size, seq_len, hidden_size]
        query = self.query(hidden_state)
        key = self.key(hidden_state)
        value = self.value(hidden_state)

        # Split into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply the padding mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        # Normalize to probabilities: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        context = torch.matmul(attention_weights, value)

        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # Output projection: [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8

    mha = MultiHeadAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions

    output = mha(hidden_state, attention_mask)
    print("MHA output shape:", output.shape)  # torch.Size([2, 10, 256])

6.2 Multi-Query Attention (MQA) implementation

import torch
import torch.nn as nn


class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super(MultiQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        # Q keeps one projection per head; K and V are shared by all heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.head_dim)
        self.value = nn.Linear(hidden_size, self.head_dim)

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()

        # Project inputs
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, head_dim]

        # Split Q into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Broadcast the shared K and V to all heads: [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        value = value.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply the padding mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values: [batch_size, num_heads, seq_len, head_dim]
        context = torch.matmul(attention_weights, value)

        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # Output projection: [batch_size, seq_len, hidden_size]
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8

    mqa = MultiQueryAttention(hidden_size, num_heads)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0  # mask out the last 5 positions

    output = mqa(hidden_state, attention_mask)
    print("MQA output shape:", output.shape)  # torch.Size([2, 10, 256])

6.3 Grouped-Query Attention (GQA) implementation

import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, dropout=0.0):
        super(GroupedQueryAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        assert num_heads % group_size == 0, "num_heads must be divisible by group_size"

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        self.group_num = num_heads // group_size
        self.head_dim = hidden_size // num_heads

        # Q keeps one projection per head; K and V are shared within each group
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, self.group_num * self.head_dim)
        self.value = nn.Linear(hidden_size, self.group_num * self.head_dim)

        self.dropout = nn.Dropout(dropout)
        self.out_projection = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_state, attention_mask=None):
        batch_size, seq_len, _ = hidden_state.size()

        # Project inputs
        query = self.query(hidden_state)  # [batch_size, seq_len, hidden_size]
        key = self.key(hidden_state)      # [batch_size, seq_len, group_num * head_dim]
        value = self.value(hidden_state)  # [batch_size, seq_len, group_num * head_dim]

        # Split Q into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Split K and V into groups, then expand each group to its heads
        # [batch_size, group_num, seq_len, head_dim]
        key = key.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        key = key.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)

        # [batch_size, group_num, seq_len, head_dim]
        value = value.view(batch_size, seq_len, self.group_num, self.head_dim).transpose(1, 2)
        # [batch_size, num_heads, seq_len, head_dim]
        value = value.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(batch_size, -1, seq_len, self.head_dim)

        # Attention scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Apply the padding mask
        if attention_mask is not None:
            attention_weights = attention_weights.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Weighted sum of values
        context = torch.matmul(attention_weights, value)

        # Merge heads: [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # Output projection
        output = self.out_projection(context)
        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256
    num_heads = 8
    group_size = 2  # 2 heads per group, 4 groups in total

    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size)
    hidden_state = torch.randn(batch_size, seq_len, hidden_size)
    attention_mask = torch.ones(batch_size, seq_len)
    attention_mask[:, 5:] = 0

    output = gqa(hidden_state, attention_mask)
    print("GQA output shape:", output.shape)  # torch.Size([2, 10, 256])

6.4 Multi-Head Latent Attention (MLA) implementation

import torch
import torch.nn as nn
import math


class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        # Rotation frequencies and the position/frequency outer product
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)

        # Interleave so that each adjacent pair of channels shares the same angle
        cos_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.sin().repeat_interleave(2, dim=-1)

        return cos_pos, sin_pos

    def forward(self, q):
        seq_len = q.shape[2]
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device)  # [seq_len, head_dim]

        # Broadcast over batch and heads: [1, 1, seq_len, head_dim]
        cos_pos = cos_pos.unsqueeze(0).unsqueeze(0)
        sin_pos = sin_pos.unsqueeze(0).unsqueeze(0)

        # RoPE rotation: pair up channels (x1, x2) -> (-x2, x1)
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape).contiguous()

        return q * cos_pos + q2 * sin_pos


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26, dropout_prob=0.0):
        super(MultiHeadLatentAttention, self).__init__()
        self.d_model = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads

        # Down-projections to the compressed latent space
        self.down_proj_kv = nn.Linear(hidden_size, down_dim)
        self.down_proj_q = nn.Linear(hidden_size, down_dim)

        # Up-projections back to the content attention space
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)
        self.up_proj_q = nn.Linear(down_dim, up_dim)

        # Decoupled Q/K projections for the RoPE branch
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim)

        # RoPE position encodings (per-head for Q, a single shared head for K)
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)

        # Output layers
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(num_heads * self.v_head_dim, hidden_size)
        self.res_dropout = nn.Dropout(dropout_prob)

    def forward(self, h, mask=None):
        bs, seq_len, _ = h.size()

        # Low-rank content path
        c_t_kv = self.down_proj_kv(h)   # [bs, seq_len, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)  # [bs, seq_len, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)  # [bs, seq_len, up_dim]
        c_t_q = self.down_proj_q(h)     # [bs, seq_len, down_dim]
        q_t_c = self.up_proj_q(c_t_q)   # [bs, seq_len, up_dim]

        # Decoupled RoPE path for Q
        q_t_r = self.proj_qr(c_t_q)  # [bs, seq_len, rope_head_dim * num_heads]
        q_t_r = q_t_r.view(bs, seq_len, self.num_heads, self.rope_head_dim).transpose(1, 2)  # [bs, num_heads, seq_len, rope_head_dim]
        q_t_r = self.rope_q(q_t_r)

        # Decoupled RoPE path for K (a single head shared by all query heads)
        k_t_r = self.proj_kr(h)      # [bs, seq_len, rope_head_dim]
        k_t_r = k_t_r.unsqueeze(1)   # [bs, 1, seq_len, rope_head_dim]
        k_t_r = self.rope_k(k_t_r)

        # Assemble per-head queries and keys: content part concatenated with RoPE part
        q_t_c = q_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)  # [bs, num_heads, seq_len, up_dim/num_heads]
        q = torch.cat([q_t_c, q_t_r], dim=-1)  # [bs, num_heads, seq_len, up_dim/num_heads + rope_head_dim]

        k_t_c = k_t_c.view(bs, seq_len, self.num_heads, -1).transpose(1, 2)  # [bs, num_heads, seq_len, up_dim/num_heads]
        k_t_r = k_t_r.expand(bs, self.num_heads, seq_len, -1)                # [bs, num_heads, seq_len, rope_head_dim]
        k = torch.cat([k_t_c, k_t_r], dim=-1)  # [bs, num_heads, seq_len, up_dim/num_heads + rope_head_dim]

        # Attention scores: [bs, num_heads, seq_len, seq_len]
        scores = torch.matmul(q, k.transpose(-1, -2))
        # Scale down the scores (the denominator combines the content and RoPE head dimensions)
        scores = scores / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))

        # Apply the padding mask
        if mask is not None:
            scores = scores.masked_fill(mask[:, None, None, :] == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Reshape V and compute the weighted sum: [bs, num_heads, seq_len, v_head_dim]
        v_t_c = v_t_c.view(bs, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2)
        context = torch.matmul(attn_weights, v_t_c)

        # Merge heads: [bs, seq_len, num_heads * v_head_dim]
        context = context.transpose(1, 2).contiguous().view(bs, seq_len, -1)

        # Output projection: [bs, seq_len, d_model]
        output = self.fc(context)
        output = self.res_dropout(output)

        return output


# Example usage
if __name__ == '__main__':
    batch_size = 2
    seq_len = 10
    hidden_size = 256

    h = torch.randn(batch_size, seq_len, hidden_size)
    mla = MultiHeadLatentAttention(hidden_size=hidden_size)

    mask = torch.ones(batch_size, seq_len)
    mask[:, 5:] = 0

    output = mla(h, mask)
    print("MLA output shape:", output.shape)  # torch.Size([2, 10, 256])

7. Summary and Recommendations

The evolution from MHA to MLA is essentially the art of trading expressiveness against computational efficiency: MHA laid the foundation of parallel multi-head attention, MQA and GQA improved efficiency through parameter sharing, and MLA achieved a qualitative leap through low-rank factorization and position decoupling.

Each of the four mechanisms has its place; choose according to the scenario:

  • MHA: best when quality matters most and resources are plentiful, e.g. during pretraining.
  • MQA: best when resources are tight and inference speed is critical, e.g. edge deployment.
  • GQA: the default choice in most cases, balancing quality and efficiency for general-purpose large models.
  • MLA: best for long-sequence tasks and large-scale deployment, and excels when GPU memory is limited.

As large models grow in parameter count and context length, attention mechanisms will keep being optimized. Developers should pick the mechanism that matches their actual requirements and follow new research to keep improving both quality and efficiency.

References

  1. 宋志学, "手撕大模型 Attention:MLA、MHA、MQA 与 GQA(含实现代码)", WeChat official account, 2025-05-20. https://mp.weixin.qq.com/s/j5J2qRCNDa7NTOHirx4kvA
  2. 苏剑林, "Transformer 升级之路:多头潜在注意力机制(MLA)究竟好在哪里?", WeChat official account, 2025-05-22. https://mp.weixin.qq.com/s/KdOjWF4n5gNtQxKKvkG5Mw
  3. 姜富春, "DeepSeek 技术解读 1:彻底理解 MLA", WeChat official account, 2025-01-15. https://mp.weixin.qq.com/s/yL_Z8zcAfWDcviZwApdL_w
  4. 算法狗, "DeepSeek MLA:高效推理的省钱之道,全流程剖析", WeChat official account, 2025-02-19. https://mp.weixin.qq.com/s/yNxjgQMl2LKzpGOoCWRRcw
  5. 羽说 AI 研究圈, "从 MHA→MQA→GQA→MLA", WeChat official account, 2025-02-12. https://mp.weixin.qq.com/s/S9dfOCrWeru6zGjOjchV7Q


This article is reposted from the WeChat column 鸿煊的学习笔记 (author: 乘风破浪jxj).

© Copyright belongs to the author. Please credit the source when reposting; unauthorized use may be pursued legally.