多头潜在注意力:手把手用数学公式推导

发布于 2025-8-6 06:26
浏览
0收藏

多头潜在注意力,这可是个大家伙,得好好聊聊!

想象一下,你正在处理一堆复杂的数据,就像是面对着一座混乱不堪的宝藏山,每个数据点都是一块闪闪发光的宝石,但它们却杂乱无章地堆放在一起。这时,多头潜在注意力机制就像是一位身怀绝技的探险家,它带着一堆神奇的分身(也就是那些“头”),准备深入这座宝藏山,寻找那些隐藏的宝藏。

这些分身,哦不,这些“头”们,每个都有自己独特的视角和技能。它们会将这座混乱的宝藏山分割成多个小块,每个小块都由一个“头”负责探索。每个“头”都会在自己的潜在空间里自由翱翔,寻找那些与当前任务最相关的特征信息。这就像是在玩一场寻宝游戏,每个“头”都在努力寻找自己的线索。

当所有的“头”都找到了自己的宝藏后,它们就会将这些宝藏(也就是那些加权后的输出)收集起来,然后通过一个神奇的整合仪式(也就是那个线性变换层),将这些宝藏融合成一个完整的宝藏图。这张宝藏图不仅包含了所有“头”们的发现,还以一种更加有序和易于理解的方式呈现了出来。

所以,你看,多头潜在注意力机制就像是那位身怀绝技的探险家和它的分身们一起,通过分工合作和集体智慧,成功地探索了那座混乱的宝藏山,找到了一份珍贵的宝藏图。而这份宝藏图,就是我们所需要的、更加准确和丰富的数据表示。

怎么样?通过这样幽默风趣的语言和深入浅出的讲解,你是不是对多头潜在注意力机制有了更加清晰的认识呢?

减兵而不减势:多头潜在注意力

想象一下,你正在看一场盛大的魔术表演,魔术师手里拿着一叠扑克牌,准备给你展示一场令人惊叹的魔术。这时,多头潜在注意力机制就像那位魔术师,而扑克牌就是我们的输入序列。魔术师(多头潜在注意力机制)会将这叠扑克牌(输入序列)分成好几摞(小块),每摞都由他的一个助手(专家网络)来处理。

这些助手(专家网络)啊,可都是魔术师(多头潜在注意力机制)精心挑选的,每个都有自己的独门绝技。他们会仔细观察自己手里的那摞扑克牌(小块输入序列),找出其中最关键的几张牌(特征信息),然后施展魔法(进行处理),将这些牌变成更加炫酷的魔术道具(加权后的输出)。

当所有的助手(专家网络)都完成自己的任务后,魔术师(多头潜在注意力机制)就会将这些魔术道具(加权后的输出)收集起来,然后通过一个神奇的魔法阵(线性变换层),将这些道具融合成一个超级炫酷的魔术表演(最终的输出)。

你看,多头潜在注意力机制就像那位魔术师一样,通过分工合作和集体智慧,成功地将一堆普通的扑克牌变成了一场令人惊叹的魔术表演。而这场魔术表演的背后,其实是多个专家网络在共同工作,通过动态地选择和加权,将输入序列中的关键信息提取出来,并融合成一个更加丰富和准确的数据表示。

所以,下次当你再看到多头潜在注意力机制时,不妨想象一下那位魔术师和他的助手们正在为你上演一场精彩的魔术表演吧!这样,你就能更加轻松地理解这个复杂而有趣的机制了。

多头注意力

我们在聊多头潜在注意力的时候,当然不能忘了多头注意力。

  • 一、多头注意力:让模型“多管齐下”的绝技

我们先设想一下你和三位朋友走进一场相亲大会,目标是帮你找对象(当然我不是催婚啦)。

在这个看脸的时代,你当然首先专注的是颜值啦。不过聪明如你,当然不希望仅仅被颜值吸引,所以聪明的你请来了三个朋友:

朋友A看学历和家庭背景;

朋友B看性格和兴趣爱好;

朋友C看长远发展。

这样,每个人都有自己负责关注一个维度,你们就这样信心百倍的走进了相亲现场。这就是 多头注意力(Multi-Head Attention) 的核心思路:“不同的注意力头,关注不同的重点。最后,你们坐下来,把各自的观察结果一合计——完美!这比你一个人单打独斗靠谱多了!

多头潜在注意力:手把手用数学公式推导-AI.x社区

IMG_256

🧮换成模型的说法:

一个“注意力头”关注词和词的关系(比如主谓关系);

一个头专门看情感走向(这句话开心还是生气?);

一个头盯着长远关系(比如“前男友”这个词和结尾的“后悔”有关);

最后,把这些不同的视角综合起来,模型就能给出更聪明的判断。

多头潜在注意力

还记得刚才的相亲大会吗?

这次你带的朋友都太“聪明”了,他们不直接告诉你观察结果,而是自己先偷偷开了个小会,先互相交流一下谁的观察结果更有用。

朋友A说:“我觉得性格最重要,这一轮就盯性格好了。”

朋友B说:“算了,家庭背景这次先不管,我们看性格和颜值。”

他们根据现场的情况动态调整每个人的关注重点,而不是一开始就分工固定。

这就是所谓的 潜在多头注意力:“我不直接分配任务,而是让每个注意力头自己决定最值得关注的方向。”

多头潜在注意力:手把手用数学公式推导-AI.x社区

换成模型的说法:每个注意力头自己“思考”当前任务下,应该重点观察哪些方面,而不是固定关注点。

这样模型变得更加灵活,不会死板地每次都盯着同一类信息。

多头潜在注意力:手把手用数学公式推导-AI.x社区

小结

让我们对比一下,如:

概念

生活比喻

模型里的作用

多头注意力

带朋友去相亲,每人盯一个重点

每个“头”关注不同特征,提高模型看问题的全面性

潜在多头注意力

朋友们先开小会,再决定谁关注什么

模型动态调整注意力,更灵活、更聪明

一句话总结就是:

多头注意力让模型“分头行动”,潜在多头注意力让模型“会看场合调整战术”!

多头潜在注意力:手把手用数学公式推导-AI.x社区

纸上推演:多头潜在注意力的数学推演

在探索多头潜在注意力的数学推演之路上,我们即将启程,深入理解这一复杂而迷人的领域。多头潜在注意力模型,作为一种先进的深度学习架构,它在处理序列数据、图像识别、自然语言处理等众多领域中展现出了卓越的性能。我们通过结合 Excel 表格进行推演,揭开其背后的原理,理解其如何通过多头机制捕捉数据中的丰富信息,以及如何通过潜在空间的变换来增强模型的表达能力。

下图来自于多头潜在注意力论文中,描述的是多头潜在注意力的机制。

多头潜在注意力:手把手用数学公式推导-AI.x社区

下面让我们先看看我们的任务背景:

我们继续依据之前的例子,输入序列中包括6 个 Token,每个是 5 维向量。潜在向量有4 个 ,可学习潜在向量有4个,每个是 5 维,头的数量是2个。

每头维度 dk=dv=2 (因为我们把 5 维分成 2 个头)

步骤1: 输入隐藏状态

输入: h_t ∈ ℝ^{T × d}

多头潜在注意力:手把手用数学公式推导-AI.x社区

步骤 2: 初始化计算latent的权重

多头潜在注意力:手把手用数学公式推导-AI.x社区

计算 

多头潜在注意力:手把手用数学公式推导-AI.x社区

多头潜在注意力:手把手用数学公式推导-AI.x社区

wpsoffice

多头潜在注意力:手把手用数学公式推导-AI.x社区


同样,我们计算出

多头潜在注意力:手把手用数学公式推导-AI.x社区

多头潜在注意力:手把手用数学公式推导-AI.x社区

步骤 3: RoPE计算

对latent进行线性变换,得到注意力组件:

RoPE的计算如下:

多头潜在注意力:手把手用数学公式推导-AI.x社区

wpsoffice

为了计算

多头潜在注意力:手把手用数学公式推导-AI.x社区

,我们采用下面的公式:

多头潜在注意力:手把手用数学公式推导-AI.x社区

wpsoffice

我们先计算R1-R6:

多头潜在注意力:手把手用数学公式推导-AI.x社区

/Users/i/Library/Containers/com.kingsoft.wpsoffice.mac/Data/tmp/wpsoffice.KixErtwpsoffice

R1 的上下左右4个脚正是上图中对应的值:

多头潜在注意力:手把手用数学公式推导-AI.x社区

θ我们这分别取10,20,30,40.50,60。

计算Query 的潜在向量,公式如下:

多头潜在注意力:手把手用数学公式推导-AI.x社区

wpsoffice

下面是我们在 Excel中计算

多头潜在注意力:手把手用数学公式推导-AI.x社区

多头潜在注意力:手把手用数学公式推导-AI.x社区

位置向量(rotary):

多头潜在注意力:手把手用数学公式推导-AI.x社区

wpsoffice

多头潜在注意力:手把手用数学公式推导-AI.x社区

多头潜在注意力:手把手用数学公式推导-AI.x社区

RoPE(Key)的计算方式类似,我们就不一一计算了,在此一并给出计算结果:

多头潜在注意力:手把手用数学公式推导-AI.x社区

步骤4:: 单头注意力计算

多头潜在注意力:手把手用数学公式推导-AI.x社区

我们先计算头1的值。首先,我们计算头1的

多头潜在注意力:手把手用数学公式推导-AI.x社区

,Excel中的公式为:

=MMULT(Q47:S50,V16#)

多头潜在注意力:手把手用数学公式推导-AI.x社区

同理,我们计算出

多头潜在注意力:手把手用数学公式推导-AI.x社区

多头潜在注意力:手把手用数学公式推导-AI.x社区

,最终结果如图:

多头潜在注意力:手把手用数学公式推导-AI.x社区

步骤5: 

多头潜在注意力:手把手用数学公式推导-AI.x社区

多头潜在注意力:手把手用数学公式推导-AI.x社区

联合起来,得到以下矩阵:

多头潜在注意力:手把手用数学公式推导-AI.x社区

步骤6: 计算注意力

多头潜在注意力:手把手用数学公式推导-AI.x社区

/Users/i/Library/Containers/com.kingsoft.wpsoffice.mac/Data/tmp/wpsoffice.JiORRtwpsoffice

这里的dk指的是:单个注意力头(head)中,Query 和 Key 向量的特征维度大小。

也就是说:我们整体模型维度d是6,注意力有3个头h(heads),

多头潜在注意力:手把手用数学公式推导-AI.x社区

/Users/i/Library/Containers/com.kingsoft.wpsoffice.mac/Data/tmp/wpsoffice.rWNXPIwpsoffice

在我们这个例子中,dk = 2

多头潜在注意力:手把手用数学公式推导-AI.x社区

用同样的方法,我们计算出其他头:

多头潜在注意力:手把手用数学公式推导-AI.x社区

多头潜在注意力:手把手用数学公式推导-AI.x社区

步骤7: 连接所有头

多头潜在注意力:手把手用数学公式推导-AI.x社区

这是连接所有头后,计算输出的公式:

多头潜在注意力:手把手用数学公式推导-AI.x社区

wpsoffice

这是我们 Excel 中的计算:

多头潜在注意力:手把手用数学公式推导-AI.x社区

如此就完成了我们多头潜在注意里的计算。怎么样,是不是感觉收获满满?

以少胜多:多头潜在注意力的代码实现

接下来用 PyTorch 实现多头潜在注意力。

多头潜在注意力的代码实现逻辑其实并不复杂,主要步骤包括初始化参数、计算潜在向量、应用RoPE变换、计算单头注意力、将多个头的输出进行拼接等。下面,我们就来一步步揭开它的神秘面纱。

首先,我们需要定义一个类来实现多头潜在注意力机制,比如叫​​MultiHeadLatentAttention​​。在这个类中,我们需要初始化一些必要的参数,比如头的数量、输入和输出的维度、可学习的潜在向量等。

然后,我们来实现前向传播函数。在这个函数中,我们首先需要对输入进行线性变换,得到Query、Key和Value。接着,我们计算每个头的潜在向量,并应用RoPE变换。然后,我们按照标准的多头注意力机制来计算每个头的注意力得分,并将这些得分进行softmax归一化。

接下来用归一化后的得分对Value进行加权求和,得到每个头的输出。最后,我们将所有头的输出进行拼接,并通过一个线性变换层得到最终的输出。

在代码中,我们还需要注意一些细节,比如如何保持维度的一致性、如何高效地计算等。不过,只要理解了多头潜在注意力机制的基本原理,这些细节问题就迎刃而解了。

现在,你是不是已经迫不及待地想要看看具体的代码实现了呢?别急,下面我们就来给出完整的代码实现,并逐行进行解释。相信通过这段代码,你一定能更加深入地理解多头潜在注意力机制的实现原理。

代码实现

我们先来修改 MultiHeadLatentAttention ,用它来输出注意力权重:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, input_dim, latent_dim, num_latents, num_heads=1, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        assert latent_dim % num_heads == 0, "latent_dim must be divisible by num_heads"

        self.latent = nn.Parameter(torch.randn(num_latents, latent_dim))  # 改为英文命名
        self.q_proj = nn.Linear(latent_dim, latent_dim)
        self.k_proj = nn.Linear(input_dim, latent_dim)
        self.v_proj = nn.Linear(input_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch_size = x.size(0)
        latent = self.latent.unsqueeze(0).expand(batch_size, -1, -1)  # 修改为英文变量名

        Q = self.q_proj(latent)
        K = self.k_proj(x)
        V = self.v_proj(x)

        def reshape(t):
            return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        Q, K, V = map(reshape, (Q, K, V))

        # 计算注意力分数
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attended = torch.matmul(self.dropout(attn_weights), V)

        # 合并头
        attended = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim)
        output = self.out_proj(attended)

        # 返回输出和注意力权重(去除 batch 维度)
        return output, attn_weights.mean(dim=0)  # 或者返回 attn_weights.detach().cpu() 用于可视化

MultiHeadLatentAttention 逐行解释

✅ 类定义和初始化

​class MultiHeadLatentAttention(nn.Module):​

定义一个继承自 torch.nn.Module 的神经网络模块:

def __init__(self, input_dim, latent_dim, num_latents, num_heads=1, dropout=0.1):  
    super().__init__()

参数含义:

  • input_dim: 输入特征维度(序列的特征大小)。
  • latent_dim: 潜在变量特征维度(通常等于隐藏层维度)。
  • num_latents: 潜在向量的数量,用于注意力摘要。
  • num_heads: 注意力头的数量。
  • dropout: Dropout 比例,防止过拟合。

self.num_heads = num_heads  
self.head_dim = latent_dim // num_heads  
assert latent_dim % num_heads == 0

计算每个注意力头的维度 head_dim。

保证 latent_dim 能被 num_heads 整除,以便均匀拆分维度:

​self.latent = nn.Parameter(torch.randn(num_latents, latent_dim))​

初始化潜在向量,形状 [num_latents, latent_dim]。

关键点:这些是可学习参数,用于从输入中提取摘要信息。

self.q_proj = nn.Linear(latent_dim, latent_dim)

self.k_proj = nn.Linear(input_dim, latent_dim)  
self.v_proj = nn.Linear(input_dim, latent_dim)  
self.out_proj = nn.Linear(latent_dim, latent_dim)  
self.dropout = nn.Dropout(dropout)

定义线性投影层:

  • q_proj: 将潜在向量映射为 Query。
  • k_proj: 将输入映射为 Key。
  • v_proj: 将输入映射为 Value。
  • out_proj: 将多头注意力结果映射回潜在空间。

Dropout 防止过拟合。

✅ 前向传播 Forward

def forward(self, x):  
     batch_size = x.size(0)

获取输入的 batch 大小。​​latent = self.latent.unsqueeze(0).expand(batch_size, -1, -1)​

将潜在向量扩展到当前 batch 大小,形状变为 [batch_size, num_latents, latent_dim]。

Q = self.q_proj(latent)  
K = self.k_proj(x)  
V = self.v_proj(x)

计算 Query、Key、Value:

Q: 从潜在向量中得到。

K 和 V: 从输入序列中得到。

def reshape(t):  
     return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

将张量 reshape 为多头注意力所需的形状:

从 [batch_size, seq_len, latent_dim] →  [batch_size, num_heads, seq_len, head_dim]

Q, K, V = map(reshape, (Q, K, V))

应用 reshape 到 Q、K、V。

attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)

计算注意力分数:

公式:Q @ K^T / sqrt(head_dim)

缩放因子防止梯度爆炸。

​attn_weights = F.softmax(attn_scores, dim=-1)​

通过 softmax 获取注意力权重。

​attended = torch.matmul(self.dropout(attn_weights), V)​

根据注意力权重计算上下文向量(信息聚合)。

​attended = attended.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)​

将多头结果重新拼接回 [batch_size, num_latents, latent_dim]。

​output = self.out_proj(attended)​

最终线性映射到 latent_dim 空间。

​return output, attn_weights.squeeze(0)​

返回:

  • output: 潜在向量的更新结果。
  • attn_weights: 注意力权重 [num_heads, num_latents, seq_len],可用于可视化注意力分布。

核心原理总结

用固定数量的潜在向量代替直接对长序列计算注意力,减少计算量。

每个潜在向量通过注意力机制从输入序列中摘要关键信息。

类似于 Perceiver、Set Transformer 中的 Inducing Points 思路,高效且适用于大规模输入。

代码应用示例

为 Latent指定语义标签

latent_labels = ["Subject", "Verb", "Object", "Time", "Emotion", "Action"]

可视化 注意力权重(以热图显示)

导入绘图库 matplotlib 和 seaborn,用于绘制热力图

import matplotlib.pyplot as plt

import seaborn as sns

def visualize_attention(attn_weights, Token_ids, latent_labels, id_to_Token):

可视化多头注意力权重的热力图。

参数说明:

  • attn_weights: 注意力权重张量,形状 [num_heads, num_latents, seq_len]
  • Token_ids: 输入序列的 Token ID 列表,长度为 seq_len
  • latent_labels: 潜在向量的标签列表,长度为 num_latents
  • id_to_Token: 字典,用于将 Token_id 转换成实际的 Token 文本

"""

# 获取注意力头的数量

num_heads = attn_weights.shape[0]

# 将 Token_ids 转换为对应的可读 Token 文本

Tokens = [id_to_Token[idx] for idx in Token_ids]

# 遍历每一个注意力头,分别绘制热力图

for h in range(num_heads):

# 新建一个图像窗口,设置大小

plt.figure(figsize=(10, 6))

# 绘制当前注意力头的热力图

sns.heatmap(

attn_weights[h].detach().numpy(), # 将当前注意力头的张量转成 NumPy 数组

xticklabels=Tokens, # X 轴标签为输入的 Token 文本

yticklabels=latent_labels, # Y 轴标签为潜在向量标签

cmap="viridis", # 配色风格为 "viridis"

annot=True, # 在热力图上显示具体数值

fmt=".2f"# 数值保留两位小数

)

# 设置图表标题,标明当前是第几个注意力头

plt.title(f"Attention Heatmap (Head {h})")

# 设置 X 轴和 Y 轴的标签

plt.xlabel("Input Tokens")

plt.ylabel("潜在Vectors")

# 显示图表

plt.show()

将这一切连起来:

# 定义输入句子

sentence = ["i", "drink", "and", "know", "things"]

# 将单词转为对应的 Token ID(假设 Token_to_id 是一个字典)

Token_ids = [Token_to_id[t] for t in sentence]

# 将 Token ID 列表转换成张量,并增加 batch 维度(形状变为 [1, seq_len])

Token_tensor = torch.tensor(Token_ids).unsqueeze(0)

# 将 Token_tensor 中的 Token IDs 映射到对应的嵌入向量(假设 embedding_tensor 已经预定义)

# 结果 embedded 的形状为 [1, seq_len, embedding_dim]

embedded = embedding_tensor[Token_tensor]

# 创建 MultiHeadLatentAttention 模型实例

# input_dim = embedding_dim(单词嵌入的维度)

# latent_dim = 2(潜在空间维度,通常不这么低,这里是演示用)

# num_latents = 6(使用 6 个潜在向量)

# num_heads = 1(单头注意力)

mhla = MultiHeadLatentAttention(input_dim=embedding_dim, latent_dim=2, num_latents=6, num_heads=1)

# 执行前向传播

# embedded 输入形状:[batch_size=1, seq_len, embedding_dim]

# output: 更新后的潜在向量表示 [1, num_latents, latent_dim]

# attn_weights: 注意力权重 [num_heads, num_latents, seq_len]

output, attn_weights = mhla(embedded)

# 使用可视化函数展示注意力分布

# latent_labels 是潜在向量的名称或编号(如 ["L1", "L2", ..., "L6"])

# id_to_Token 是 Token ID 到单词的映射字典

visualize_attention(attn_weights, Token_ids, latent_labels, id_to_Token)

小结

多头潜在注意力的创新在于它结合了多头注意力和潜在向量的思想,实现了对长序列的高效处理。通过固定数量的潜在向量来代表输入序列的关键信息,多头注意力机制能够捕捉不同方面的依赖关系。

这种方法不仅减少了计算量,还提高了模型的泛化能力,使其能够处理更复杂的任务。

此外,通过为潜在向量指定语义标签,可以进一步增强模型的可解释性,使得我们能够更好地理解模型是如何从输入序列中提取关键信息的。

总之,多头潜在注意力是一种高效且强大的注意力机制,为自然语言处理等领域的研究和应用提供了新的思路和方法。

文本转载自 ​AI大模型世界​,作者:roclv

已于2025-8-6 10:11:20修改
收藏
回复
举报
回复
相关推荐