标签:__ 自定义 bprop torch pytorch import grad mindspore
要迁移的项目为图像压缩算法https://github.com/ywz978020607/HESIC
1.自定义算子迁移--LowerBoundFunction类
为了能够准确迁移底层封装的类,需要详细测试原版类以及迁移测试
pytorch中自定义的算子有torch.autograd.Function
import torch
import torch.nn as nn
class LowerBoundFunction(torch.autograd.Function):
"""Autograd function for the `LowerBound` operator.
"""
@staticmethod
def forward(ctx, input_, bound):
ctx.save_for_backward(input_, bound)
return torch.max(input_, bound)
@staticmethod
def backward(ctx, grad_output):
input_, bound = ctx.saved_tensors
pass_through_if = (input_ >= bound) | (grad_output < 0)
print(pass_through_if)
print(pass_through_if.type(grad_output.dtype) * grad_output)
return pass_through_if.type(grad_output.dtype) * grad_output, None
if __name__=="__main__":
a = torch.Tensor([1,2,3])
b = torch.Tensor([0,1,5])
a.requires_grad_(True)
b.requires_grad_(True)
c = a*b
m = LowerBoundFunction.apply(a,b)
m.backward(c)
通过两行print测试后发现,这个类用于阻断梯度,有点类似Relu的感觉
而mindspore的自定义算子在昇腾、GPU、CPU下定义不同且过于复杂,咨询hw工程师后,准备继承nn.Cell并重载bprop函数实现,测试bprop反向梯度传播如下
# https://gitee.com/mindspore/mindspore/blob/master/tests/ut/python/pynative_mode/test_hook.py#
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context, Tensor, ParameterTuple
from mindspore.common.initializer import TruncatedNormal
from mindspore.nn import WithLossCell, Momentum
from mindspore.ops import composite as C
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
grad_all = C.GradOperation(get_all=True)
bprop_debug = False
class MulAdd(nn.Cell):
def __init__(self):
super(MulAdd, self).__init__()
def construct(self, x, y):
return 2 * x * x + y * y
def bprop(self, x, y, out, dout):
global bprop_debug
bprop_debug = True
return dout, 2 * y
def test_custom_bprop():
mul_add = MulAdd()
mul_add.bprop_debug = True
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([2, 3, 4]).astype(np.int32))
ret = grad_all(mul_add)(x, y)
print(ret) #(Tensor(shape=[3], dtype=Int32, value= [1, 1, 1]), Tensor(shape=[3], dtype=Int32, value= [4, 6, 8]))
assert bprop_debug
##############
#ywz
test_custom_bprop()
print(bprop_debug)
标签:__,自定义,bprop,torch,pytorch,import,grad,mindspore 来源: https://www.cnblogs.com/sharklet/p/14983398.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。