ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

MindSpore网络自定义反向报错:TypeError: The params of function 'bprop' of

2022-07-16 21:37:17  阅读:142  来源: 互联网

标签:function 自定义 bprop self py Cell 报错 mindspore out


1. 报错描述

1.1 系统环境

Hardware Environment(Ascend/GPU/CPU): GPU
Software Environment:

  • MindSpore version (source or binary): 1.7.0
  • Python version (e.g., Python 3.7.5): 3.7.5
  • OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04.4 LTS
  • GCC/Compiler version (if compiled from source): 7.5.0

1.2 基本信息

1.2.1 源码

import mindspore as ms
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops import composite as C

grad_all = C.GradOperation(get_all=True)

class MulAdd(nn.Cell):
    def construct(self, x, y):
        return 2 * x + y

    def bprop(self, x, y, out):
        return 2 * x, 2 * y
mul_add = MulAdd()
x = Tensor(1, dtype=ms.int32)
y = Tensor(2, dtype=ms.int32)
output = grad_all(mul_add)(x, y)

1.2.2 报错

TypeError: The params of function 'bprop' of Primitive or Cell requires the forward inputs as well as the 'out' and 'dout'

Traceback (most recent call last):
  File "test_grad.py", line 20, in <module>
    output = grad_all(mul_add)(x, y)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 522, in staging_specialize
    out = _MindsporeFunctionExecutor(func, hash_obj, input_signature, process_obj)(*args)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 93, in wrapper
    results = fn(*arg, **kwargs)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 353, in __call__
    phase = self.compile(args_list, self.fn.__name__)
  File "/home/liangzhibo/mindspore/build/package/mindspore/common/api.py", line 321, in compile
    is_compile = self._graph_executor.compile(self.fn, compile_args, phase, True)
TypeError: The params of function 'bprop' of Primitive or Cell requires the forward inputs as well as the 'out' and 'dout'.
In file test_grad.py(13)
    def bprop(self, x, y, out):
    ^

----------------------------------------------------
- The Traceback of Net Construct Code:
----------------------------------------------------

# In file test_grad.py(13)
    def bprop(self, x, y, out):
    ^

----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc:651 BuildOutput

2. 原因分析与解决方法

在这个用例中, 我们使用了Cell的自定义反向规则。 而报错信息也提示了我们是自定义规则的输入, 即

def bprop(self, x, y, out):

这句话存在错误。 

在自定义Cell的反向规则bprop时, 需要接受三类输入, 分别是Cell的正向输入(在本用例中为x, y), Cell的正向输出(在本用例中为out),以及输入网络反向的累加梯度(dout)。本用例中正式因为缺少了dout输入, 因此运行失败。 因此我们只需要将代码更改为:

def bprop(self, x, y, out, dout):
    return 2 * x, 2 * y

 程序即可正常运行。

下图表示了三类输入分别的意义, dout为反向图前一个节点输出的梯度, bprop函数需要此输入来对计算的梯度进行继承与使用。

Untitled Diagram.png

另外, bprop的三类输入是构图的时候需要使用的, 因此即使某些输入在bprop函数中没有被使用, 也是需要传入bprop中的。

3. 参考文档

https://www.mindspore.cn/tutorials/zh-CN/master/advanced/network/derivation.html#%E8%87%AA%E5%AE%9A%E4%B9%89%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E5%87%BD%E6%95%B0

标签:function,自定义,bprop,self,py,Cell,报错,mindspore,out
来源: https://www.cnblogs.com/skytier/p/16485291.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有