
什么是神经网络---LSTM模型实例讲解
LSTM的关键在于它的“记忆单元”,能够选择性地记住或者忘记信息。其核心组件包括三个门和一个记忆单元:
1. 遗忘门(Forget Gate):决定应该丢弃哪些信息。
2. 输入门(Input Gate):决定更新哪些新的信息。
3. 输出门(Output Gate):决定当前状态如何影响输出。
数学公式解释
- 遗忘门:
ft=σ(Wf⋅[ht−1,xt]+bf)
遗忘门决定了上一时刻的状态 Ct−1中,哪些信息需要保留,哪些需要丢弃。值域为 [0, 1],1表示完全保留,0表示完全丢弃。 - 输入门:
it=σ(Wi⋅[ht−1,xt]+bi)
输入门决定了当前时刻的输入 xt - 候选记忆单元:
C~t=tanh(WC⋅[ht−1,xt]+bC)
这是当前时刻的候选记忆内容。 - 更新记忆单元:
Ct=ft⋅Ct−1+it⋅C~t
记忆单元通过遗忘门和输入门结合,更新当前的记忆状态。 - 输出门:
ot=σ(Wo⋅[ht−1,xt]+bo)
输出门决定了当前记忆单元 Ct - 隐藏状态更新:
ht=ot⋅tanh(Ct)
隐藏状态通过输出门和当前记忆单元来更新。
有同学比较疑惑LSTM中的sigmoid函数和tanh的作用是什么?下来我来为大家解惑:
- Sigmoid函数(遗忘门、输入门、输出门):用于控制信息流。由于sigmoid的输出值在 0 和 1 之间,表示“选择”的强度。0表示完全不通过,1表示完全通过。它起到类似开关的作用。
- Tanh函数(候选记忆单元、隐藏状态):用于将输入值缩放到 -1 到 1 之间,确保信息在网络中不会增长过大或过小,帮助模型捕捉数据中的正负变化。同时,tanh可以引入非线性特征,增加网络表达能力。
也就是说Sigmoid函数作为**“开关”**,在LSTM的各个门(遗忘门、输入门、输出门)中使用,决定信息流的多少。Tanh函数用于将数值范围缩放到-1到1之间,帮助控制记忆单元的值,确保信息的平衡和稳定性,并用于生成隐藏状态。
下面让我们来用一个例子来辅助大家对模型的理解:
为了详细讲解LSTM如何工作,我们通过一个具体的例子一步步剖析每个步骤。
例子:预测序列中的下一个数
我们有一个简单的序列数据:1,2,3,4,5,6,7,8,9,10
我们的目标是训练一个LSTM模型,让它能够根据之前的数字预测下一个数字。例如,输入[1, 2, 3]时,模型应该输出4。
1. 输入表示
LSTM处理的是时间序列数据,我们可以将每个数字视为一个时间步(time step)。对于输入序列1,2,3,我们需要在每个时间步都输入一个数字:
- 时间步1:输入1
- 时间步2:输入2
- 时间步3:输入3
在每一个时间步,LSTM会使用前一步的隐藏状态以及当前输入来更新它的记忆单元C和隐藏状态h。
2. LSTM的核心机制
LSTM有三个关键的门:遗忘门、输入门和输出门。这三个门控制了信息如何在LSTM单元中流动。让我们看一下当我们输入1,2,3时,LSTM内部发生了什么。
时间步1:输入1
- 遗忘门
遗忘门的作用是决定上一个时间步的信息要保留多少。因为这是第一个时间步,之前没有信息,所以LSTM的记忆单元初始为0。假设此时遗忘门计算出的值为0.8,这意味着LSTM会保留80%的之前的记忆状态(虽然此时没有实际的历史状态)。 - 输入门
输入门决定新信息要多少被写入记忆单元。假设输入门给出的值为0.9,意味着我们会把当前输入的信息的90%加入到记忆单元。 - 更新记忆单元
LSTM单元会计算候选记忆内容。假设此时的候选内容为1(通过激活函数计算得到),结合遗忘门和输入门,更新记忆单元:
C1=0.8⋅0+0.9⋅1=0.9 - 输出门
输出门决定记忆单元如何影响隐藏状态。假设输出门给出的值为0.7,隐藏状态通过以下公式计算:
h1=0.7⋅tanh(0.9)≈0.63 - 这个隐藏状态会作为下一时间步的输入。
时间步2:输入2
- 遗忘门
遗忘门决定如何处理前一个时间步的记忆单元。假设遗忘门的值为0.7,意味着70%的上一步记忆将被保留。 - 输入门
假设此时输入门的值为0.8,意味着会将当前输入2的80%加入到记忆单元。 - 更新记忆单元
候选内容通过激活函数计算,假设此时候选内容为2。结合遗忘门和输入门:
C2=0.7⋅0.9+0.8⋅2=2.03 - 输出门
假设输出门的值为0.6,隐藏状态为:
h2=0.6⋅tanh(2.03)≈0.56
时间步3:输入3
- 遗忘门
假设遗忘门的值为0.75,保留75%的上一个记忆。 - 输入门
假设输入门的值为0.85,意味着会将当前输入3的85%加入到记忆单元。 - 更新记忆单元
假设候选内容为3,更新记忆单元:
C3=0.75⋅2.03+0.85⋅3=3.5 - 输出门
假设输出门的值为0.65,隐藏状态为:
h3=0.65⋅tanh(3.5)≈0.58
3. LSTM如何预测
经过这三个时间步,LSTM得到了一个隐藏状态h3,代表了模型对序列1,2,3的理解。接下来,隐藏状态会通过一个全连接层或线性层,输出预测值。
假设线性层的输出是4,这意味着LSTM模型根据序列1,2,3,预测接下来是4。
总结
- 遗忘门决定了LSTM是否保留之前的记忆。
- 输入门决定了新输入如何影响当前的记忆单元。
- 输出门决定了隐藏状态如何影响输出。
通过这些门的组合,LSTM可以有效地在长序列数据中学习到哪些信息是重要的,哪些是可以忽略的,从而解决传统RNN的长依赖问题。
接下来会继续深入讲解LSTM模型,比如如果输入一串中文,LSTM该如何处理怎么预测,比如各个门的权重如何更新,如何决定遗忘多少,输入多少。
本文转载自人工智能训练营,作者:人工智能训练营
