深度学习中序列模型(RNN、GRU、LSTM、Transformer)的本质理解
本文详细描述了RNN、GRU、LSTM模型架构的设计理念,希望能帮助大家更深入的理解这几大模型。模型的intuition搞清楚了,复杂的数学公式也就不再显得复杂了,更像是水到渠成的结果。
不成熟想法
我们如何理解序列模型计算出来的隐藏状态向量,它究竟表示什么呢?比如有购买序列:男士外套、男士外套、男士T恤、男士T恤、男士T恤,人类看到这个序列会得出怎样的结论?至少有下面几条:
1、 他购买女士衣服的概率比较低
2、 他购买男士裤子的概率比较低
3、 他喜欢购买男士外套和男士T恤
4、 最近他购买男士T恤较多,所以下一个购买的商品,男士T恤会高于男士外套
5、 …
所以我的理解是,隐藏状态就是对上面所有这些结论的一个综合向量表示。
上面是对于从左到右的序列模型,那bidirectionalRNN和transformer呢?
这俩又可以进一步细分为sentence level 和 token level。对于sentence level,会丢失位置信息,隐藏状态表示的结论就是:
1、 他购买女士衣服的概率比较低
2、 他购买男士裤子的概率比较低
3、 他喜欢购买男士外套和男士T恤
4、 他购买男士T恤较多,所以更喜欢男士T恤
5、 。。。
对于token level,比如我们处于第三个购买商品上,左边儿是两件男士外套,右边儿是两件男士T恤,自己是男士T恤,该商品的隐藏状态表示的结论就是:
1、 他购买女士衣服的概率比较低
2、 他购买男士裤子的概率比较低
3、 他喜欢购买男士外套和男士T恤
4、 我自己是男士T恤,所以他更喜欢男士T恤
5、 …
写得很不严谨,一些不成熟的想法,供大家选择算法时能够有些intuition~~
正餐
开始严谨了。r2rt大神的blog探讨了RNN、GRU和LSTM模型架构设计背后隐藏的intuition,告诉你模型中的每一部分是怎么来的,所以特别推荐读一读:https://r2rt.com/written-memories-understanding-deriving-and-extending-the-lstm.html
本章节将要点进行总结。
RNN
RNNcell中的状态信息就是对输入序列的向量表示(截止到所在时刻)。它的计算方式决定了状态向量state一直在变换,也就是文中说的information morphs,所以模型不够稳定。举个例子,“这条裤子款式很好,价格合理”,根据评论描述按照商品类别对评论归类,那这条评论就属于裤子的。我们读到裤子其实就已经可以判断类别了,此时state就是我们需要的。但是裤子后面的输入文本对这个state进行了信息变换,最后rnn的输出状态已经不是我们想要的了。
RNN的第二个问题是,gradient vanishing,也就是梯度消失问题。RNN的梯度消失和CNN的不同,指的是较早时刻的梯度发生了消失,RNN的梯度近似为各个时刻的梯度计算结果之和,较早时刻的梯度计算包含的小于1的连乘项较多,因而梯度基本为0了。结果就是只有最近时刻的输入参与了误差的后向传播过程,长度为100的序列,模型就只能利用最后的10个元素。
Prototype LSTM
information morphs发生的原因是对于上个时刻的state,我们不加选择的全部拿来计算当前时刻的state,这是读的问题。另一个原因是,计算出来的新状态,不加选择的将上个时刻的state完全覆盖。所以我们要有选择的进行读和写操作。这样将RNN的结构修改如下:
i
t
i_t
it和
o
t
o_t
ot分别是控制写和读的门。
s
t
~
=
ϕ
(
W
(
o
t
⊙
s
t
−
1
)
+
U
x
t
+
b
)
\tilde{s_t}=\phi(W(o_t\odot s_{t-1})+Ux_t+b)
st~=ϕ(W(ot⊙st−1)+Uxt+b)
s
t
=
i
t
⊙
s
t
~
s_t=i_t\odot\tilde{s_t}
st=it⊙st~
对于RNN的第二个问题,我们需要保证较早时刻的梯度可以流动到当前时刻,实现方法也很简单,写的时候把
s
t
−
1
s_{t-1}
st−1也考虑进来就可以了,即
s
t
=
s
t
−
1
+
i
t
⊙
s
t
~
s_t=s_{t-1}+i_t\odot\tilde{s_t}
st=st−1+it⊙st~。看着是不是眼熟,大名鼎鼎的残差网络ResNet就是借鉴了LSTM。但是这样设计也有问题,远距离的所有输入都不加选择的保留下来。我们希望的是,远距离的有些输入留下来,有些就遗忘吧。所以可以设计成
f
t
f_t
ft是遗忘门,
这样就得到了文中的Prototype LSTM:
看看是不是和GRU很像了。
GRU
Prototype LSTM的主要问题是state缺少bound,所以可能会无限增长直到模型崩溃,无法再学习任何东西,有点儿像梯度爆炸。于是乎我们对
f
t
f_t
ft和
i
t
i_t
it做下修改,使得
s
t
s_t
st无法膨胀,这样我们就得到了GRU:
LSTM
除了GRU中的那种bound方法,还有第二个bound方法,就是在计算门和状态的时候,对state进行normalization,
ϕ
(
s
t
−
1
)
\phi(s_{t-1})
ϕ(st−1) ,
ϕ
\phi
ϕ可以是tanh()。这样模型结构如下,文中称为Pseudo LSTM:
这个模型的问题就是最后的
r
n
n
o
u
t
rnn_out
rnnout不够完美,我们希望输出的状态也应该是有选择的,那改成这样可行么:
r
n
n
o
u
t
=
o
t
⊙
ϕ
(
s
t
)
rnn_{out}=o_t\odot\phi(s_t)
rnnout=ot⊙ϕ(st)
不够完美,因为
o
t
o_t
ot是使用上个时刻的state计算的,所以可以改成:
o
t
′
=
σ
(
W
o
(
ϕ
(
s
t
)
)
+
U
o
x
t
+
b
o
)
o_t'=\sigma(W_o(\phi(s_t))+U_ox_t+b_o)
ot′=σ(Wo(ϕ(st))+Uoxt+bo)
r
n
n
o
u
t
=
o
t
′
⊙
ϕ
(
s
t
)
rnn_{out}=o_t'\odot\phi(s_t)
rnnout=ot′⊙ϕ(st)
这是一个读的过程,这样我们经历了读-写-读三步。如果我们把第一个读的步骤去掉,写步骤使用上一时刻读的结果,就得到了标准的LSTM了。
Peephole LSTM
Peephole LSTM是比较有名的LSTM的变体了。LSTM的问题是,使用
h
t
−
1
h_{t-1}
ht−1计算门和候选状态,但是
h
t
−
1
h_{t-1}
ht−1并不是主状态,即memory cell的状态信息,而是我们经过筛选的输出信息。所以我们在计算门和候选状态时,应该将
c
t
−
1
c_{t-1}
ct−1也考虑进来。这样模型结构如下:
注意计算
o
t
o_t
ot的时候使用的是
c
t
c_{t}
ct而不是
c
t
−
1
c_{t-1}
ct−1。
总结
RNN、GRU、LSTM模型虽然复杂,公式繁多,但是每个符号都有它存在的意义,绝不是随便就写在那里的,当然更不是凭空猜出来,而是基于一定的设计思路和直觉想法。对我的启示时,对AI算法的理解一定要上升到intuition层面,搞清楚算法每一步后面的动机。相比于直接去记忆抽象的公式符号,动机其实是更容易记忆的。当动机都掌握清楚了,公式推导自然也是水到渠成的事情了。这应该才是掌握一个算法正确的路子。
更多推荐
所有评论(0)