反向传播算法


在介绍反向传播算法前,先看看矩阵微分的概念。

矩阵微积分

为了书写简便,我们通常把单个函数对多个变量或者多元函数对单个变量的偏导数写成向量和矩阵的形式,使其可以被当成一个整体处理.

标量关于向量的偏导数

对于 /(M/) 维向量 /(/boldsymbol{x} /in /mathbb{R}^{M}/) 和函数 /(y=f(/boldsymbol{x}) /in /mathbb{R}/) , 则 /(y/) 关 于 /(/boldsymbol{x}/) 的偏导数为:

/[/frac{/partial y}{/partial /boldsymbol{x}}=/left[/frac{/partial y}{/partial x_{1}}, /cdots, /frac{/partial y}{/partial x_{M}}/right]^{/top}
/]

向量关于标量的偏导数

对于标量 /(x /in /mathbb{R}/) 和函数 /(/boldsymbol{y}=f(x) /in /mathbb{R}^{N}/) , 则 /(/boldsymbol{y}/) 关于 /(x/) 的 偏导数为:

/[/frac{/partial y}{/partial x}=/left[/frac{/partial y_{1}}{/partial x}, /cdots, /frac{/partial y_{N}}{/partial x}/right]
/]

向量关于向量的偏导数

对于 /(M/) 维向量 /(/boldsymbol{x} /in /mathbb{R}^{M}/) 和函数 /(/boldsymbol{y}=f(/boldsymbol{x}) /in /mathbb{R}^{N}/) , 则 /(f(/boldsymbol{x})/) 关于 /(/boldsymbol{x}/) 的偏导数为:

/[/frac{/partial f(/boldsymbol{x})}{/partial /boldsymbol{x}}=/left[/begin{array}{ccc}
/frac{/partial y_{1}}{/partial x_{1}} & /cdots & /frac{/partial y_{N}}{/partial x_{1}} //
/vdots & /ddots & /vdots //
/frac{/partial y_{1}}{/partial x_{M}} & /cdots & /frac{/partial y_{N}}{/partial x_{M}}
/end{array}/right]
/]

前向传播算法

根据之前的介绍,第/(l/)层的输出为:

/[/boldsymbol{a}^{(l)}=f_{l}/left(/boldsymbol{W}^{(l)} /boldsymbol{a}^{(l-1)}+/boldsymbol{b}^{(l)}/right)
/]

其中:

/[/boldsymbol{z}^{(l)}=/boldsymbol{W}^{(l)} /boldsymbol{a}^{(l-1)}+/boldsymbol{b}^{(l)}
/]

反向传播算法

假设采用随机梯度下降进行神经网络参数学习, 给定一个样本 /((/boldsymbol{x}, /boldsymbol{y})/) , 将其输入到神经网络模型中, 得到网络输出为 /(/hat{/boldsymbol{y}}/) . 假设损失函数为 /(/mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})/) , 要进行参数学习就需要计算损失函数关于每个参数的导数.

不失一般性, 对第 /(l/) 层中的参数 /(/boldsymbol{W}^{(l)}/) 和 /(/boldsymbol{b}^{(l)}/) 计算偏导数. 因为 /(/frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{y})}{/partial /boldsymbol{W}^{(l)}}/) 的计算 涉及向量对矩阵的微分, 十分繁琐, 因此我们先计算 /(/mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})/) 关于参数矩阵中每个元素的偏导数 /(/frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{y})}{/partial w_{i j}^{(l)}}/) . 根据链式法则:

/[/begin{array}{l}
/frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial w_{i j}^{(l)}}=/frac{/partial /boldsymbol{z}^{(l)}}{/partial w_{i j}^{(l)}} /frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial /boldsymbol{z}^{(l)}}, //
/frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial /boldsymbol{b}^{(l)}}=/frac{/partial /boldsymbol{z}^{(l)}}{/partial /boldsymbol{b}^{(l)}} /frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial /boldsymbol{z}^{(l)}}
/end{array}
/]

上面两个公式中的第二项都是目标函数关于第 /(l/) 层的神经元 /(/boldsymbol{z}^{(l)}/) 的偏导数, 称为误差项,可以一次计算得到.且记/(/delta^{(l)} /triangleq /frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial /boldsymbol{z}^{(l)}}/).它的大小间接反应了其神经元对整个网络能力的贡献. 这样我们只需要计算三个偏导数, 分别为 /(/frac{/partial /boldsymbol{z}^{(l)}}{/partial w_{i j}^{(l)}}, /frac{/partial /boldsymbol{z}^{(l)}}{/partial /boldsymbol{b}^{(l)}}/) 和 /(/frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial /boldsymbol{z}^{(l)}}/) .

下面分别来计算这三个偏导数

因 /(/boldsymbol{z}^{(l)}=/boldsymbol{W}^{(l)} /boldsymbol{a}^{(l-1)}+/boldsymbol{b}^{(l)}/) ,且/(w_{i j}^{(l)}/)为/(l-1/)层的第/(j/)个元素到/(l/)层的第/(i/)个元素连接的权重,所以:

/[/begin{aligned}
/frac{/partial /boldsymbol{z}^{(l)}}{/partial w_{i j}^{(l)}} &=/left[/frac{/partial z_{1}^{(l)}}{/partial w_{i j}^{(l)}}, /cdots, /frac{/partial z_{i}^{(l)}}{/partial w_{i j}^{(l)}}, /cdots, /frac{/partial z_{M_{l}}^{(l)}}{/partial w_{i j}^{(l)}}/right] //
&=/left[0, /cdots, /frac{/partial/left(/boldsymbol{w}_{i:}^{(l)} /boldsymbol{a}^{(l-1)}+b_{i}^{(l)}/right)}{/partial w_{i j}^{(l)}}, /cdots, 0/right] //
&=/left[0, /cdots, a_{j}^{(l-1)}, /cdots, 0/right]/in /mathbb{R}^{1 /times M_{l}}
/end{aligned}
/]

/[/frac{/partial /boldsymbol{z}^{(l)}}{/partial /boldsymbol{b}^{(l)}}=/boldsymbol{I}_{M_{l}}
/]

其中,/(/boldsymbol{I}_{M_{l}}/)是/({M_{l}}/times{M_{l}}/)的单位阵。

因/(/boldsymbol{z}^{(l+1)}=/boldsymbol{W}^{(l+1)} /boldsymbol{a}^{(l)}+/boldsymbol{b}^{(l+1)}/),/(/boldsymbol{a}^{(l)}=f_{l}/left(/boldsymbol{z}^{(l)}/right)/)所以:

/[/frac{/partial /boldsymbol{z}^{(l+1)}}{/partial /boldsymbol{a}^{(l)}}=/left(/boldsymbol{W}^{(l+1)}/right)^{/top}
/]

/[/begin{aligned}
/frac{/partial /boldsymbol{a}^{(l)}}{/partial /boldsymbol{z}^{(l)}} &=/frac{/partial f_{l}/left(/boldsymbol{z}^{(l)}/right)}{/partial /boldsymbol{z}^{(l)}} //
&=/operatorname{diag}/left(f_{l}^{/prime}/left(/boldsymbol{z}^{(l)}/right)/right)
/end{aligned}
/]

根据链式法则得:

/[/begin{aligned}
/delta^{(l)} & /triangleq /frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial /boldsymbol{z}^{(l)}} //
&=/frac{/partial /boldsymbol{a}^{(l)}}{/partial /boldsymbol{z}^{(l)}} /cdot /frac{/partial /boldsymbol{z}^{(l+1)}}{/partial /boldsymbol{a}^{(l)}} /cdot /frac{/partial /mathcal{L}(/boldsymbol{y}, /hat{/boldsymbol{y}})}{/partial /boldsymbol{z}^{(l+1)}} //
&=/operatorname{diag}/left(f_{l}^{/prime}/left(/boldsymbol{z}^{(l)}/right)/right) /cdot/left(/boldsymbol{W}^{(l+1)}/right)^{/top} /cdot /delta^{(l+1)}
/end{aligned}
/]

从上式可以看出,第

原创文章,作者:ItWorker,如若转载,请注明出处:https://blog.ytso.com/288425.html

(0)
上一篇 2022年9月9日
下一篇 2022年9月9日

相关推荐

发表回复

登录后才能评论