ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

超详细的softmax的反向传播梯度计算推导

2020-03-10 13:44:29  阅读:1432  来源: 互联网

标签:frac zi 推导 ai 梯度 ak softmax partial sum


Softmax及其求导

正向传播

为了方便理解,若输入,输出只有3个变量(下面的普通性情况不太理解的可以带入此特殊情况帮助理解)。

  • 输入:输出层神经元 Z = [z1z_1z1​, z2z_2z2​, z3z_3z3​],分类标签 Y = [y1,y2,y3y_1, y_2, y_3y1​,y2​,y3​](Y是one-hot标签,只有一个yiy_iyi​值为1,其他全为0)
  • 输出:A = softmax(Z) = [a1,a2,a3a_1, a_2, a_3a1​,a2​,a3​]

更一般的,假设有n个神经元(或者说n类),softmax公式为:
ai=ezij=1nezj a_i = \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} ai​=∑j=1n​ezj​ezi​​
损失函数采用交叉熵,公式为:
L=i=1nyilnai L = -\sum_{i=1}^n y_ilna_i L=−i=1∑n​yi​lnai​

求导

先说结论,按向量形式表示LZ=AY\frac{\partial L}{\partial Z} = A - Y∂Z∂L​=A−Y

证明:在这里我们假设第k个神经元为正确标签,即在Y = [y1,y2,y3y_1, y_2, y_3y1​,y2​,y3​]中yky_kyk​ = 1,其他yiy_iyi​都为0。

首先求L对A的导数
Lai=i=1nyilnaiai=yiai \frac{\partial L}{\partial a_i} = \frac{\partial -\sum_{i=1}^n y_ilna_i}{\partial a_i}=-\frac{y_i}{a_i} ∂ai​∂L​=∂ai​∂−∑i=1n​yi​lnai​​=−ai​yi​​

再求L对Z的导数,这里需要注意,在正向传播时,每一个aia_iai​的计算都有所有的zjz_jzj​参加(请看softmax的公式的分母,是求和)
Lzi=j=1nLajajzi \frac{\partial L}{\partial z_i} = \sum_{j=1}^n \frac{\partial L}{a_j} \cdot \frac{\partial a_j}{\partial z_i} ∂zi​∂L​=j=1∑n​aj​∂L​⋅∂zi​∂aj​​

再次强调,因为每一个aia_iai​的计算都由所有的zjz_jzj​参加,换句话说,每一个zjz_jzj​都包含在每一个aia_iai​中,因此对zjz_jzj​求偏导要先对所有的aia_iai​求偏导。

而我们事先假设好了yk=1y_k = 1yk​=1,其他yiy_iyi​都为0,因此只有Lak\frac{\partial L}{\partial a_k}∂ak​∂L​不为0,其他都为0,可以进一步将上式化简为:
Lzi=j=1nLajajzi=Lakakzi \frac{\partial L}{\partial z_i} = \sum_{j=1}^n \frac{\partial L}{a_j} \cdot \frac{\partial a_j}{\partial z_i} = \frac{\partial L}{\partial a_k} \cdot \frac{\partial a_k}{\partial z_i} ∂zi​∂L​=j=1∑n​aj​∂L​⋅∂zi​∂aj​​=∂ak​∂L​⋅∂zi​∂ak​​
由此可见,我们的重点就是要求akzi\frac{\partial a_k}{\partial z_i}∂zi​∂ak​​了。对于akzi\frac{\partial a_k}{\partial z_i}∂zi​∂ak​​的求解,要分两种情况。

  • 若i = k
    akzi=akzk=(ezkj=1nezj)zk=ezk(j=1nezj)(ezk)2(j=1nezj)2=ezkj=1nezj(ezkj=1nezj)2=akak2=ak(1ak) \frac{\partial a_k}{\partial z_i} = \frac{\partial a_k}{\partial z_k} = \frac{\partial (\frac{e^{z_k}}{\sum_{j=1}^n e^{z_j}})}{\partial z_k} = \frac{e^{z_k}(\sum_{j=1}^n e^{z_j}) - (e^{z_k})^2}{(\sum_{j=1}^n e^{z_j})^2} = \frac{e^{z_k}}{\sum_{j=1}^n e^{z_j}} - (\frac{e^{z_k}}{\sum_{j=1}^n e^{z_j}})^2 \\= a_k - a_k^2 = a_k(1 - a_k) ∂zi​∂ak​​=∂zk​∂ak​​=∂zk​∂(∑j=1n​ezj​ezk​​)​=(∑j=1n​ezj​)2ezk​(∑j=1n​ezj​)−(ezk​)2​=∑j=1n​ezj​ezk​​−(∑j=1n​ezj​ezk​​)2=ak​−ak2​=ak​(1−ak​)

  • 若i \neq​= k
    akzi=(ezkj=1nezj)zi=ezkezi(j=1nezj)2=ezkj=1nezjezij=1nezj=akai \frac{\partial a_k}{\partial z_i} = \frac{\partial (\frac{e^{z_k}}{\sum_{j=1}^n e^{z_j}})}{\partial z_i} = \frac{-e^{z_k} \cdot e^{z_i}}{(\sum_{j=1}^n e^{z_j})^2} = -\frac{e^{z_k}}{\sum_{j=1}^n e^{z_j}} \cdot \frac{e^{z_i}}{\sum_{j=1}^n e^{z_j}} \\= -a_k \cdot a_i ∂zi​∂ak​​=∂zi​∂(∑j=1n​ezj​ezk​​)​=(∑j=1n​ezj​)2−ezk​⋅ezi​​=−∑j=1n​ezj​ezk​​⋅∑j=1n​ezj​ezi​​=−ak​⋅ai​

结合我们求出的Lai\frac{\partial L}{\partial a_i}∂ai​∂L​和aizj\frac{\partial a_i}{\partial z_j}∂zj​∂ai​​,这里直接写出L对整个Z的导数,我们可得
LZ=j=1ni=1nLajajzi=1aki=1nakzi=1aki=1,ikn(akai)1aki=k(ai(1ai))=i=1,iknai+i=kai1=i=1nai1 \frac{\partial L}{\partial Z} = \sum_{j=1}^n \sum_{i=1}^n \frac{\partial L}{\partial a_j} \cdot \frac{\partial a_j}{\partial z_i} = -\frac{1}{a_k} \sum_{i=1}^n \frac{\partial a_k}{\partial z_i} = -\frac{1}{a_k} \cdot \sum_{i=1, i\neq k}^n (-a_ka_i) - \frac{1}{a_k} \cdot \sum_{i=k} (a_i(1 - a_i)) \\= \sum_{i=1, i\neq k}^n a_i + \sum_{i=k} a_i - 1 = \sum_{i=1}^n a_i - 1 ∂Z∂L​=j=1∑n​i=1∑n​∂aj​∂L​⋅∂zi​∂aj​​=−ak​1​i=1∑n​∂zi​∂ak​​=−ak​1​⋅i=1,i​=k∑n​(−ak​ai​)−ak​1​⋅i=k∑​(ai​(1−ai​))=i=1,i​=k∑n​ai​+i=k∑​ai​−1=i=1∑n​ai​−1
我们又知道yky_kyk​ = 1,除了yky_kyk​其他yiy_iyi​都等于0,因此上式的1可以写成yky_kyk​,在进一步写成i=1nyi\sum_{i=1}^n y_i∑i=1n​yi​,带入上式
LZ=i=1nai1=i=1naii=1nyi=AY \frac{\partial L}{\partial Z} = \sum_{i=1}^n a_i - 1 = \sum_{i=1}^n a_i - \sum_{i=1}^n y_i = A - Y ∂Z∂L​=i=1∑n​ai​−1=i=1∑n​ai​−i=1∑n​yi​=A−Y

证明完毕~

标签:frac,zi,推导,ai,梯度,ak,softmax,partial,sum
来源: https://blog.csdn.net/weixin_43217928/article/details/104772424

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有