3. RNN及其变体_LSTMGUR
1. LSTM 模型
图中每个黄色块都是一个 linear全连接层、3个σ代表3个门值,值域(0,1),它是黄色的,所以每个门的门值对应一个全连接层;
(LSTM的输入包含三部分:当前时间步输入Xt、上一时间步的隐藏层张量输出结果 hidden、上一时间步的C(细胞状态);细胞状态:图中下面部分进行复制:一个output输出、一个 ht传入下一层,所以上面没有交叉的部分称为细胞状态。)
1️⃣ 遗忘门结构分析:将当前时间步输入 xt与上一时间步隐藏层状态 ht-1进行 concat拼接得到 [xt, ht-1],然后通过一个 linear全连接层做变换,最后通过 sigmoid函数进行激活得到一个遗忘门门值 ft,(sigmoid值域(0,1)则 ft值(0,1)),好比一扇门开合的大小程度,门值都将作用在通过该扇门的张量,遗忘门门值将作用的上一层的细胞状态上,代表遗忘过去的多少信息,又因为遗忘门门值是由 xt,ht-1计算得来的,因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态 ht-1来决定遗忘多少上一层的细胞状态所携带的过往信息;
(① concat之后一定要经过一个全连接层,全连接层的核心目标是为了进行形状的转化。② 对于xt, ht-1作用得到一个门值,作用的地方是上一层的细胞状态,上一层的细胞状态包含前文信息,遗忘门会选择性部分遗忘即 选择性部分记忆,并非全部记忆;对于 RNN它的缺点是:链式法则需要考虑到全部,每一个词都要记忆,每次都要对最前面的词进行求导,链式法则时会乘很多元素;现在遗忘门会将前文信息选择性的遗忘删除,使得在连乘时元素个数减少,因此遗忘门可以缓解梯度消失现象。如何判断遗忘?:假设开始时遗忘门的门值都是 1即都进行了保留,但模型本身有损失,损失大效果差,黄色部分是一个 linear层,权重也会更新,导致门值一定也会更新。遗忘门门值获取:由当前时间步的 xt与上一时间步隐藏层张量的结果 ht-1拼接后,经过一个全连接层,再经过一个 sigmoid激活函数,得到一个遗忘门门值 ft;f即 forget)
2️⃣ 输入门结构分析:我们看到输入门的计算公式有两个:第一个就是产生输入门门值的公式,它和遗忘门公式几乎相同,区别只是在于它们之后要作用的目标上,这个公式意味着输入信息有多少需要进行过滤;输入门的第二个公式是与传统 RNN的内部结构计算相同,对于 LSTM来讲,它得到的是当前的细胞状态,而不是像经典 RNN一样得到的是隐含状态。
(① xt与 ht-1拼接后经过 linear全连接层线性变换后,经过 sigmoid激活函数后得到一个输入门的门值 it;这个门值乘以 一个xt与 ht-1拼接后经过 linear全连接层线性变换后,经过tanh得到的结果 Ct~(此结果类似传统 RNN输出的结果,对此结果做了输入门的选择:Ct~ 可看作是临时的细胞状态或者说是加了新的 xt之后,当前时间步得到的一个新的临时的细胞状态,但需要通过输入门对其进行选择性的记忆(即经过输入门进行一次过滤));② 输入门包含两部分:获取输入门门值、选择输入的对象(即输入门即将作用的对象):第一部分:输入门门值获取:由 xt、ht-1拼接后送给 linear全连接层,sigmoid后得到输入门门值 it;第二部分:作用的对象:xt、ht-1拼接后送给 linear全连接层,再经过一个 tanh激活函数,得到一个结果 Ct(Ct是加了新的 xt之后,当前时间步得到的一个新的临时的细胞状态,但需要通过输入门对其进行选择性的记忆);
3️⃣ 细胞状态更新分析:细胞更新的结构与计算公式非常容易理解,这里没有全连接层,只是将刚刚得到的遗忘门门值与上一个时间步得到的 C(t-1)相乘,再加上输入门门值与当前时间步得到的未更新 C(t)相乘的结果,最终得到更新后的C(t)作为下一个时间步输入的一部分,整个细胞状态更新过程就是对遗忘门和输入门的应用。
(细胞状态更新用到了遗忘门门值 ft和输入门门值 it,公式::
选择性遗忘:遗忘门 ft作用于上一时间步的细胞状态 Ct-1,ft * Ct-1这个过程是选择性的遗忘’历史的’一些消息 ➕️ 选择性输入:输入门 it作用于’当前’真实输入的一部分 Ct~,哪些重要则记忆,不重要的赋值给小的权重值 = 两者融合:既包含了以前的历史信息、又包含了当前输入的新的信息,最终得到当前时间步新的细胞状态 Ct。得到新的细胞状态 Ct后可直接进行输出,给下一个时间步进行使用;此时当前时间步的细胞状态已经更新完毕。)
对于 ht-1(隐藏状态)和 Ct-1(细胞状态):两者都包含了历史消息:ht-1是一个短期记忆,临时的上下文;Ct-1是一个长期记忆。
