欢迎光临 x-algo
关注算法在工业界应用
Hi, 这是一个关注大数据算法在工业界应用的网站

RNN(Recurrent Neural Networks)公式推导和实现

本文主要参考wildml的博客所写,所有的代码都是python实现。没有使用任何深度学习的工具,公式推导虽然枯燥,但是推导一遍之后对RNN的理解会更加的深入。看本文之前建议对传统的神经网络的基本知识已经了解,如果不了解的可以看此文:『神经网络(Neural Network)实现』。

所有可执行代码:Code to follow along is on Github.

语言模型

熟悉NLP的应该会比较熟悉,就是将自然语言的一句话『概率化』。具体的,如果一个句子有m个词,那么这个句子生成的概率就是:

\(P(w_1,...,w_m) = \prod_{i=1}^{m} P(w_i \mid w_1,..., w_{i-1})\)

其实就是假设下一次词生成的概率和只和句子前面的词有关,例如句子『He went to buy some chocolate』生成的概率可以表示为:  P(他喜欢吃巧克力) = P(他喜欢吃) * P(巧克力|他喜欢吃) 。

数据预处理

训练模型总需要语料,这里语料是来自google big query的reddit的评论数据,语料预处理会去掉一些低频词从而控制词典大小,低频词使用一个统一标识替换(这里是UNKNOWN_TOKEN),预处理之后每一个词都会使用一个唯一的编号替换;为了学出来哪些词常常作为句子开始和句子结束,引入SENTENCE_START和SENTENCE_END两个特殊字符。具体就看代码吧:

网络结构

和传统的nn不同,但是也很好理解,rnn的网络结构如下图:

rnn

A recurrent neural network and the unfolding in time of the computation involved in its forward computation.

不同之处就在于rnn是一个『循环网络』,并且有『状态』的概念。

如上图,t表示的是状态, \(x_t\) 表示的状态t的输入, \(s_t\) 表示状态t时隐层的输出, \(o_t\) 表示输出。特别的地方在于,隐层的输入有两个来源,一个是当前的 \(x_t\) 输入、一个是上一个状态隐层的输出 \(s_{t-1}\)\(W,U,V\) 为参数。使用公式可以将上面结构表示为:

\(\begin{aligned}s_t &= \tanh(Ux_t + Ws_{t-1}) \\ \hat{y}_t &= \mathrm{softmax}(Vs_t)\end{aligned}\)

 如果隐层节点个数为100,字典大小C=8000,参数的维度信息为:

\(\begin{aligned} x_t & \in \mathbb{R}^{8000} \\ o_t & \in \mathbb{R}^{8000} \\ s_t & \in \mathbb{R}^{100} \\ U & \in \mathbb{R}^{100 \times 8000} \\ V & \in \mathbb{R}^{8000 \times 100} \\ W & \in \mathbb{R}^{100 \times 100} \\ \end{aligned}\)

初始化

参数的初始化有很多种方法,都初始化为0将会导致『symmetric calculations 』(我也不懂),如何初始化其实是和具体的激活函数有关系,我们这里使用的是tanh,一种推荐的方式是初始化为 \(\left[-\frac{1}{\sqrt{n}}, \frac{1}{\sqrt{n}}\right]\) ,其中n是前一层接入的链接数。更多信息请点击查看更多

 

前向传播

类似传统的nn的方法,计算几个矩阵乘法即可:

预测函数可以写为:

 

损失函数

类似nn方法,使用交叉熵作为损失函数,如果有N个样本,损失函数可以写为:

\(\begin{aligned} L(y,o) = - \frac{1}{N} \sum_{n \in N} y_{n}\log o_{n}\end{aligned}\)

下面两个函数用来计算损失:

BPTT学习参数

BPTT( Backpropagation Through Time)是一种非常直观的方法,和传统的BP类似,只不过传播的路径是个『循环』,并且路径上的参数是共享的。

损失是交叉熵,损失可以表示为:

\(\begin{aligned}E_t(y_t, \hat{y}_t) &= - y_{t} \log \hat{y}_{t} \\ E(y, \hat{y}) &=\sum\limits_{t} E_t(y_t,\hat{y}_t) \\ & = -\sum\limits_{t} y_{t} \log \hat{y}_{t} \end{aligned}\)

其中 \(y_t\) 是真实值, \(\hat(y_t)\) 是预估值,将误差展开可以用图表示为:

rnn-bptt1

所以对所有误差求W的偏导数为:

\(\frac{\partial E}{\partial W} = \sum\limits_{t} \frac{\partial E_t}{\partial W}\)

进一步可以将 \(E_t\) 表示为:

\(\begin{aligned}\frac{\partial E_3}{\partial V} &=\frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial V}\\&=\frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial z_3}\frac{\partial z_3}{\partial V}\\ &=(\hat{y}_3 - y_3) \otimes s_3 \\ \end{aligned}\)

根据链式法则和RNN中W权值共享,可以得到:

\(\begin{aligned}\frac{\partial E_3}{\partial W} &= \sum\limits_{k=0}^{3} \frac{\partial E_3}{\partial \hat{y}_3}\frac{\partial\hat{y}_3}{\partial s_3}\frac{\partial s_3}{\partial s_k}\frac{\partial s_k}{\partial W}\\ \end{aligned}\)

下图将这个过程表示的比较形象

rnn-bptt-with-gradients

BPTT更新梯度的代码:

梯度弥散现象

tanh和sigmoid函数和导数的取值返回如下图,可以看到导数取值是[0-1],用几次链式法则就会将梯度指数级别缩小,所以传播不了几层就会出现梯度非常弱。克服这个问题的LSTM是一种最近比较流行的解决方案。

tanh

Gradient Checking

梯度检验是非常有用的,检查的原理是一个点的『梯度』等于这个点的『斜率』,估算一个点的斜率可以通过求极限的方式:

\(\begin{aligned} \frac{\partial L}{\partial \theta} \approx \lim_{h \to 0} \frac{J(\theta + h) - J(\theta -h)}{2h} \end{aligned}\)

通过比较『斜率』和『梯度』的值,我们就可以判断梯度计算的是否有问题。需要注意的是这个检验成本还是很高的,因为我们的参数个数是百万量级的。

梯度检验的代码:

 

SGD实现

这个公式应该非常熟悉:

\(W = W - \lambda \Delta {W}\)

其中 \(\Delta{W}\) 就是梯度,具体代码:

生成文本

生成过程其实就是模型的应用过程,只需要反复执行预测函数即可:

参考文献

Recurrent Neural Networks Tutorial, Part 2 – Implementing a RNN with Python, Numpy and Theano

Recurrent Neural Networks Tutorial, Part 3 – Backpropagation Through Time and Vanishing Gradients

未经允许不得转载:大数据算法 » RNN(Recurrent Neural Networks)公式推导和实现

分享到:更多 ()

评论 7

*

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址
  1. #4

    交叉熵的损失函数为什么只写了一半?

    Duckbill1年前 (2017-02-26)回复
    • 交叉熵不就是sum(y_i log(o_i))吗? 能不能描述的具体点?

      leihao1年前 (2017-02-27)回复
  2. #3
    • (1)y_n和o_n是一维向量,维度是字典大小(2)E_3是使用s_3前向传播计算出来的,反向的时候也应该包括它;这里不好理解的是求和的形式,可以看做是对不同的W求偏导数,然后合并到一起的。

      leihao1年前 (2017-02-27)回复
  3. #2

    第一个问题已经明白了。第二个问题是在算E_3的偏导的时候,为什么k是从0--3,而不是0--2, 如果上限是3的话,那么后面的式子展开就会有s_3对s_3的偏导。这是为什么呢?

    Duckbill1年前 (2017-02-28)回复
    • 能否赐教,第一个问题是为什么吗

      aaaa12个月前 (06-07)回复
  4. #1

    我明白了,s_3对s_3的偏导就直接简化成1了吧,哈哈哈。

    Duckbill1年前 (2017-02-28)回复

关注大数据算法在工业界应用

本站的GitHub关于本站