标签:__ log 自定义 self act Pytorch alpha 激活 out
(每个人错误情况可能不一样,仅供参考。以复现一篇含自定义激活函数论文为例。)
1:拟设计激活函数:
2:代码编辑:
class log_act(nn.Module):
def __init__(self, alpha, beta, positive_flag=True):
super(log_act, self).__init__()
self.positive_flag = positive_flag
self.a = alpha
self.C1 = beta
def forward(self, x):
x = x.detach().numpy()
C0 = np.exp(-1)
if self.positive_flag:
alpha = self.a
else:
alpha = -self.a
out = alpha * np.log(np.greater(x, 0)+C0) + self.C1
out = torch.tensor(out)
return out
```
#设置自定义输出值条件定义:
class log_act_helper(nn.Module):
def __init__(self, alpha, beta):
super(log_act_helper, self).__init__()
self.a = alpha
self.C1 = beta
def forward(self, x):
x0 = x.clone().detach()
f1 = log_act(alpha=self.a, beta=self.C1, positive_flag=True)
out1 = f1(x0)
f2 = log_act(alpha=self.a, beta=self.C1, positive_flag=False)
out2 = f2(x0)
out = torch.where(np.greater(x0, 0), out1, out2)
# out = torch.tensor(out)
return out
激活函数设计完毕。
问题1:
查找资料:带转换PyTorch张量带有梯度,直接转换为numpy数据会破坏梯度图。转换数据不需要保留梯度信息,x=x.detach().numpy()
(不适合我的错误,尝试之后未解决)
**解决方案:**对卷积之后得到张量x进行克隆,传入激活函数计算:x=x.clone().detach()
或者使用x=x.clone().detach().requires_grad_(True)
其中clone()复制张量,梯度流仍流向原来的张量。detach()张量脱离计算图,不牵扯梯度计算。requires_grad_是否需要梯度。
问题2:
**解决方案:**问题是:日志中遇到了无效值。log函数图像如下:
基于log(对数)函数log(A) 中当A<0时,对数函数不成立,必须取A>0,更改为:out = alpha * (np.log(np.greater(x, 0)+C0)) + 1
比较x与0大小,确保A取正值。
3:测试:
if __name__=="__main__":
x = torch.rand(3, 3, 448, 448)
conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
bias=False)
x = conv1(x)
bn1 = nn.BatchNorm2d(64)
x = bn1(x)
activation = log_act_helper(alpha=0.2, beta=1)
out = activation(x)
print(out.shape)
输出:
以上是激活函数不需要梯度情况。需要梯度情况,需要重定义forward和backward参考: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html添加链接描述
标签:__,log,自定义,self,act,Pytorch,alpha,激活,out 来源: https://blog.csdn.net/qq_44631242/article/details/122223274
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。