梯度是如何计算的
引言
深度学习模型的训练本质上是一个优化问题,而常采用的优化算法是梯度下降法(SGD)。对于SGD算法,最重要的就是如何计算梯度。此时,估计跟多人会告诉你:采用BP(backpropagation)算法,这没有错,因为神经网络曾经的一大进展就是使用BP算法计算梯度提升训练速度。但是从BP的角度,很多人陷入了推导公式的深渊。如果你学过微积分,我相信你一定知道如何计算梯度,或者说计算导数。对于深度网络来说,其可以看成多层非线性函数的堆积,即:
而我们知道深度学习模型的优化目标L一般是output的函数,如果要你求L关于各个参数的导数,你会不假思索地想到:链式法则。因为output是一个复合函数。在微积分里面,求解复合函数的导数采用链式法则再合适不过了。其实本质上BP算法就是链式法则的一个调用。让我们先忘记BP算法,从链式法则开始说起。
链式法则
链式法则无非是将一个复杂的复合函数从上到下逐层求导,比如你要求导式子:f(x,y,z)=(x+y)(x+z)。当然这个例子是足够简单的,但是我们要使用链式法则的方式来求导。首先可以将f(x,y,z)看成两个函数p(x,y)=(x+y)与q(x,z)=(x+z)的复合:f=pq。假如你要求df/dz,首先我们先要求出df/dp与df/dq。很显然,df/dp=q,df/dq=p。然后要求dp/dx与dq/dx,显然,dp/dx=1.0,dq/dx=1.0。这个时候已经求到最底层了,可以利用链式法则求出最终的结果了:df/dx=(df/dp)(dp/dx)+(df/dq)(dq/dx)=q+p。同样的方法,可以求出:df/dy=q,df/dz=p。如果大家细致观察的话,可以看到要求出最终的导数,你需要计算出中间结果:p与q。计算中间结果的过程一般是前向(forward)过程,然后再反向(backward)计算出最终的导数。过程如下:
#程序1#
#输入
x, y, z = -3, 2, 5
#执行前向过程
p = x + y #p = -1
q = x + z #q = 2
#执行反向过程计算梯度
#第一个层反向:f = pq
dfdp = q #df / dp = 2
dfdq = p #df / dq = -1
#第二个层反向,并累积第一层梯度:
#p = x + y, q = x + z
dfdx = 1.0 * dfdp + 1.0 * dfdq #df / dx = 1
dfdy = 1.0 * dfdp #df / dy = 2
dfdz = 1.0 * dfdq #df / dz = -1
上面的一个过程就是BP算法,包含两个过程:前向forward)过程与反向(backword)过程。前向过程是从输入计算得到输出,而反向过程就是一个梯度累积的过程,或者说是BP,即误差反向传播。这就是BP的思想。上面的例子应该是比较简单的,而对于深度学习模型来说,其只不过是函数复杂一点罢了,但是如果你严格按照链式法则来去推导,只要你会基本求导方法,应该都不是什么难事了。
矩阵运算
其实对于深度学习模型来说,其运算都是基于矩阵运算的。对于新手来说,矩阵运算的求导可能会是一件比较头疼的事。其实矩阵运算求导是一个纸老虎。对于元素级的矩阵运算来说,比如激活函数这种,你完全可以把看成普通的求导。但是对于矩阵乘法,你需要特别注意,这里先抛出例子:
#程序2#
import numpy as np
#前向过程
W = np.random.randn(5,10)
X = np.random.randn(10,3)
D = W.dot(X)
#反向过程
#假定dD是后面传播过来的梯度项
dD = np.random.randn(*D.shape)
dW = dD.dot(X.T)
dX = W.T.dot(dD)
如果你认真推导的话,是可以得到上面的结果的。但是这里有其它捷径。对于两个矩阵相乘的话,在反向传播时反正是另外一个项与传播过来的梯度项相乘。差别就在于位置以及翻转。这里有个小窍门,就是最后计算出梯度肯定要与原来的矩阵是同样的shape。那么这就容易了,反正组合不多。比如你要计算dW,你知道要用dD与X两个矩阵相乘就可以得到。W的shape是[5,10],而dD的shape是[5,3],X的shape是[10,3]。要保证dW与W的shape一致,好吧,此时只能用dD.dot(X.T),真的没有其它选择了,那这就是对了。
活学活用:
实现一个简单的神经网络
上面我们讲了链式法则,也讲了BP的思想,并且也讲了如何对矩阵运算求梯度。下面我们基于Python中的Numpy库实现一个简单的神经网络模型,代码如下:
#程序3#
"""
一个简单两层神经网络回归模型
"""
import numpy as np
# batch size
N = 32
# 输入维度
D = 100
# 隐含层单元数
H = 200
# 输出维度
O = 10
# 训练样本(这里随机生成)
X = np.random.randn(N, D)
y = np.random.randn(N, O)
# 初始化参数
W1 = np.random.randn(D, H)
b1 = np.zeros((H,))
W2 = np.random.randn(H, O)
b2 = np.zeros((O,))
# 训练参数
learning_rate = 1e-02
iterations = 200
# 训练过程
for t in range(iterations): # 前向过程
h = X.dot(W1) + b1
h_relu = np.maximum(h, 0)
pred = h_relu.dot(W2) + b2
#定义loss,采用均方差
loss = np.sum(np.square(y - pred))
print("Iteration %d loss: %f" % (t, loss))
# 反向过程计算梯度
dpred = 2.0 * (pred - y)
db2 = np.sum(dpred, axis=0)
dW2 = h_relu.T.dot(db2)
dh_relu = db2.dot(W2.T)
dh = (h > 0) * dh_relu
db1 = np.sum(dh, axis=0)
dW1 = X.T.dot(dh)
# SGD更新梯度
params = [W1, b1, W2, b2]
grads = [dW1, db1, dW2, db2]
for p, g in zip(params, grads):
p += -learning_rate * g
总结
这里我们简单介绍了梯度下降法中最重要的一部分,就是如何计算梯度。相信通过本文,大家对BP算法以及链式法则有更深刻的理解。
参考资料
cs231n教程:http://cs231n.github.io/optimization-2/
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 一次完整的JVM堆外内存泄漏故障排查记录
- Python 爬虫进阶必备 | 某视频数据分析平台加密参数分析(终于我还是手把手扣了代码)
- Python 爬虫进阶必备 | 某视频平台 sign 加密参数分析
- 进击吧!Pythonista(3/100)
- begin backup导致的故障恢复全过程
- 通过历史控制文件恢复Oracle数据库,只需这10步
- python应用(2):写个python程序给自己用
- 基于Prometheus+Grafana监控SQL Server数据库
- 手把手教你用R语言读取CSV文件
- 6个案例手把手教你用Python和OpenCV进行图像处理
- Android 10(Q)/11(R) 分区存储适配
- Usual*** CMS 8.0代码审计
- 由一条like语句引发的SQL注入新玩法
- 《黑神话:悟空》B站弹幕、知乎回答分析
- 12岁小读者使用Python暴力破解Wi-Fi密码