什么是神经网络---LSTM模型实例讲解

发布于 2025-6-19 06:57
浏览
0收藏

LSTM的关键在于它的“记忆单元”,能够选择性地记住或者忘记信息。其核心组件包括三个门和一个记忆单元:

1. 遗忘门(Forget Gate):决定应该丢弃哪些信息。

2. 输入门(Input Gate):决定更新哪些新的信息。

3. 输出门(Output Gate):决定当前状态如何影响输出。

什么是神经网络---LSTM模型实例讲解-AI.x社区

数学公式解释

  • 遗忘门
    ft=σ(Wf⋅[ht−1,xt]+bf)
    遗忘门决定了上一时刻的状态 Ct−1中,哪些信息需要保留,哪些需要丢弃。值域为 [0, 1],1表示完全保留,0表示完全丢弃。
  • 输入门
    it=σ(Wi⋅[ht−1,xt]+bi)
    输入门决定了当前时刻的输入 xt
  • 候选记忆单元
    C~t=tanh⁡(WC⋅[ht−1,xt]+bC)
    这是当前时刻的候选记忆内容。
  • 更新记忆单元
    Ct=ft⋅Ct−1+it⋅C~t
    记忆单元通过遗忘门和输入门结合,更新当前的记忆状态。
  • 输出门
    ot=σ(Wo⋅[ht−1,xt]+bo)
    输出门决定了当前记忆单元 Ct
  • 隐藏状态更新
    ht=ot⋅tanh⁡(Ct)
    ‍隐藏状态通过输出门和当前记忆单元来更新。

有同学比较疑惑LSTM中的sigmoid函数和tanh的作用是什么?下来我来为大家解惑:‍

  • Sigmoid函数(遗忘门、输入门、输出门):用于控制信息流。由于sigmoid的输出值在 0 和 1 之间,表示“选择”的强度。0表示完全不通过,1表示完全通过。它起到类似开关的作用。
  • Tanh函数(候选记忆单元、隐藏状态):用于将输入值缩放到 -1 到 1 之间,确保信息在网络中不会增长过大或过小,帮助模型捕捉数据中的正负变化。同时,tanh可以引入非线性特征,增加网络表达能力。

    也就是说Sigmoid函数作为**“开关”**,在LSTM的各个门(遗忘门、输入门、输出门)中使用,决定信息流的多少。Tanh函数用于将数值范围缩放到-1到1之间,帮助控制记忆单元的值,确保信息的平衡和稳定性,并用于生成隐藏状态。

下面让我们来用一个例子来辅助大家对模型的理解:

为了详细讲解LSTM如何工作,我们通过一个具体的例子一步步剖析每个步骤。

例子:预测序列中的下一个数

我们有一个简单的序列数据:1,2,3,4,5,6,7,8,9,10
我们的目标是训练一个LSTM模型,让它能够根据之前的数字预测下一个数字。例如,输入[1, 2, 3]时,模型应该输出4。

1. 输入表示

LSTM处理的是时间序列数据,我们可以将每个数字视为一个时间步(time step)。对于输入序列1,2,3,我们需要在每个时间步都输入一个数字:

  • 时间步1:输入1
  • 时间步2:输入2
  • 时间步3:输入3

在每一个时间步,LSTM会使用前一步的隐藏状态以及当前输入来更新它的记忆单元C隐藏状态h

2. LSTM的核心机制

LSTM有三个关键的门:遗忘门输入门输出门。这三个门控制了信息如何在LSTM单元中流动。让我们看一下当我们输入1,2,3时,LSTM内部发生了什么。

时间步1:输入1

  • 遗忘门
    遗忘门的作用是决定上一个时间步的信息要保留多少。因为这是第一个时间步,之前没有信息,所以LSTM的记忆单元初始为0。假设此时遗忘门计算出的值为0.8,这意味着LSTM会保留80%的之前的记忆状态(虽然此时没有实际的历史状态)。
  • 输入门
    输入门决定新信息要多少被写入记忆单元。假设输入门给出的值为0.9,意味着我们会把当前输入的信息的90%加入到记忆单元。
  • 更新记忆单元
    LSTM单元会计算候选记忆内容。假设此时的候选内容为1(通过激活函数计算得到),结合遗忘门和输入门,更新记忆单元:
    C1=0.8⋅0+0.9⋅1=0.9
  • 输出门
    输出门决定记忆单元如何影响隐藏状态。假设输出门给出的值为0.7,隐藏状态通过以下公式计算:
    h1=0.7⋅tanh⁡(0.9)≈0.63
  • 这个隐藏状态会作为下一时间步的输入。

时间步2:输入2

  • 遗忘门
    遗忘门决定如何处理前一个时间步的记忆单元。假设遗忘门的值为0.7,意味着70%的上一步记忆将被保留。
  • 输入门
    假设此时输入门的值为0.8,意味着会将当前输入2的80%加入到记忆单元。
  • 更新记忆单元
    候选内容通过激活函数计算,假设此时候选内容为2。结合遗忘门和输入门:
    C2=0.7⋅0.9+0.8⋅2=2.03
  • 输出门
    假设输出门的值为0.6,隐藏状态为:
    h2=0.6⋅tanh⁡(2.03)≈0.56

时间步3:输入3

  • 遗忘门
    假设遗忘门的值为0.75,保留75%的上一个记忆。
  • 输入门
    假设输入门的值为0.85,意味着会将当前输入3的85%加入到记忆单元。
  • 更新记忆单元
    假设候选内容为3,更新记忆单元:
    C3=0.75⋅2.03+0.85⋅3=3.5
  • 输出门
    假设输出门的值为0.65,隐藏状态为:
    h3=0.65⋅tanh⁡(3.5)≈0.58

3. LSTM如何预测

经过这三个时间步,LSTM得到了一个隐藏状态h3,代表了模型对序列1,2,3的理解。接下来,隐藏状态会通过一个全连接层或线性层,输出预测值。

假设线性层的输出是4,这意味着LSTM模型根据序列1,2,3,预测接下来是4。

总结

  • 遗忘门决定了LSTM是否保留之前的记忆。
  • 输入门决定了新输入如何影响当前的记忆单元。
  • 输出门决定了隐藏状态如何影响输出。

通过这些门的组合,LSTM可以有效地在长序列数据中学习到哪些信息是重要的,哪些是可以忽略的,从而解决传统RNN的长依赖问题

接下来会继续深入讲解LSTM模型,比如如果输入一串中文,LSTM该如何处理怎么预测,比如各个门的权重如何更新,如何决定遗忘多少,输入多少。

本文转载自​​人工智能训练营​​,作者:人工智能训练营

收藏
回复
举报
回复
相关推荐