ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

PyTorch笔记——FX

2021-10-22 19:59:45  阅读:353  来源: 互联网

标签:None torch FX self args 笔记 PyTorch kwargs 节点


官方文档链接:https://pytorch.org/docs/master/fx.html#

概述

FX是供开发人员用于转换nn.Module实例的工具包。FX由三个主要组件组成:符号追踪:symbolic tracer, 中间层表示:intermediate representation, Python代码生成:Python code generation。这些组件的运行演示:

import torch
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

# Code generation - valid Python code
print(symbolic_traced.code)
"""
def forward(self, x):
    param = self.param
    add = x + param;  x = param = None
    linear = self.linear(add);  add = None
    clamp = linear.clamp(min = 0.0, max = 1.0);  linear = None
    return clamp
"""

符号追踪(symbolic tracer): 对Python代码进行"符号执行"。它以构造的值(也叫作:代理Proxies)为输入,贯穿运行所有代码。记录下对这些Proxie的操作。更多的符号追踪的信息可见 symbolic_trace()Tracer的相关文档。

**中间层表示(intermediate representation): ** 它里面保存了在符号追中中记录下的运算操作。它由表示函数输入、调用哪些对象(函数、方法或torch.nn.Module实例)和返回值的节点列表组成。关于IR的更多信息可以在Graph的文档中找到。IR是应用转换的格式。

**Python代码生成(Python code generation): ** Python代码生成使FX成为Python代码到Python代码(或模块到模块)转换工具包。对于每个 Graph IR,我们可以创建与图的语义匹配的有效Python代码。此功能包含在GraphModule中,GraphModule是一个torch.nn.Module实例,它包含一个图以及从该图生成的正向方法。

综合起来,这个组件的流水线(符号跟踪 -> 中间表示 -> 转换 -> Python 代码生成)构成了 FX 的 Python-to-Python 转换通道。 此外,这些组件可以单独使用。 例如,可以单独使用符号跟踪来捕获代码形式以用于分析(而不是转换)。 代码生成可用于以编程方式生成模型,例如从配置文件生成模型。 FX 有很多用途!

在示例库中有几个转换的样例

API

symbolic_trace

torch.fx.symbolic_trace(root, concrete_args=None, enable_cpatching=False)

符号追踪的函数,以nn.Module或者函数实例为输入,然后将追踪过程中记录的操作记录下来构造一个GraphModule对象并返回。

concrete_args的作用是根据函数中的分支和参数进行定制化,无论是删除控制流还是数据结构。

例如:

def f(a, b):
    if b == True:
        return a
    else:
        return a*2

由于控制流的存在,FX通常无法正常的追踪。但是,我们可以使用concrete_args指定b的值来解决该问题。

f = fx.symbolic_trace(f, concrete_args={‘b’: False}) assert f(3, False) == 6

注意,虽然你仍然可以给b传不同的值,但是这些值都会被忽略掉。

我们还可以使用concrete_args来消除函数中的数据结构处理。这将使用pytrees将输入展开。为避免过度定制,请为不应指定固定值的传入fx.PH。例如:

def f(x):
    out = 0
    for v in x.values():
        out += v
    return out
f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}})
assert f({'a': 1, 'b': 2, 'c': 4}) == 7

参数

  • **root(torch.nn.Module或者可调用对象): ** 要跟踪并转换为Graph的Module或函数。
  • **concrete_args (可选[Dict[str, any]]): **定制化部分输入
  • **enable_cpatching: ** 启用C级功能补全(捕获类似torch.randn的内容)

返回值

通过遍历root获得的相关计算操作创建出的Module。

返回值类型

GraphModule

注意:保证此 API 的向后兼容性。

CLASS torch.fx.Graph

CLASS torch.fx.Graph(owning_module=None, tracer_cls=None)

Graph是FX中间层表示的主要数据结构。它包含一组Node,每个Node表示了一个调用关系(或其他语法结构)。这些Node组合在一起构成了完整的Python功能。

样例:

import torch
import torch.fx

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3)

m = MyModule()
gm = torch.fx.symbolic_trace(m)

这样我们就构造了下面的Graph。

> print(gm.graph)
> graph(x):
    %linear_weight : [#users=1] = self.linear.weight
    %add_1 : [#users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {})
    %linear_1 : [#users=1] = call_module[target=linear](args = (%add_1,), kwargs = {})
    %relu_1 : [#users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {})
    %sum_1 : [#users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1})
    %topk_1 : [#users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {})
    return topk_1

call_function(the_function, args=None, kwargs=None, type_expr=None)

在Graph中插入一个call_function的节点。call_function 节点表示对 Python 可调用对象的调用,被调用对象由 the_function 指定。
插入的位置选择同create_node。

call_method(method_name, args=None, kwargs=None, type_expr=None)

在Graph中插入一个call_method的节点。call_method 节点表示对 args 的第 0 个元素上的给定方法的调用。
插入的位置选择同create_node。

call_module(module_name, args=None, kwargs=None, type_expr=None)

在Graph中插入call_module 节点。 call_module 节点表示对模块层次结构中模块的 forward()函数的调用。
插入的位置选择同create_node。

create_node(op, target, args=None, kwargs=None, name=None, type_expr=None)

创建一个节点并将其添加到当前插入点的图形中。 请注意,当前插入点可以通过 Graph.inserting_before() 和 Graph.inserting_after() 设置。

eliminate_dead_code()

根据每个节点的用户数以及删除节点是否有任何其他影响,从图中删除所有无用代码。调用之前,必须对图形进行拓扑排序。

erase_node(to_erase)

从图形中删除节点。如果图中仍在使用该节点,则抛出异常。

flatten_inps(*args)

get_attr(qualified_name, type_expr=None)

在Graph中插入一个get_attr节点。get_attr节点表示从模块层次结构中获取属性。
插入的位置选择同create_node。

graph_copy(g, val_map, return_output_node=False)

将所给的Graph中所有节点拷贝一份。

inserting_after(n=None)

设置 create_node 和配套方法将插入图中的点。
使用with语句时,临时设置插入点,然后在 with 语句退出时恢复原来的值:

with g.inserting_after(n):
    ... # inserting after node n
... # insert point restored to what it was previously
g.inserting_after(n) #  set the insert point permanently

inserting_before(n=None)

类似after

lint()

对此图运行各种检查以确保其格式正确。 特别是: - 检查节点是否具有正确的所有权(由该图拥有) - 检查节点是否按拓扑顺序出现 - 如果该图有一个拥有的 GraphModule,则检查该 GraphModule 中是否存在目标

node_copy(node, arg_transform=<function Graph.>)

将一个节点从一个图中复制到另一个图中。 arg_transform 需要将参数从源Graph转换到目的Graph。 例子:

# Copying all the nodes in `g` into `new_graph`
g : torch.fx.Graph = ...
new_graph = torch.fx.graph()
value_remap = {}
for node in g.nodes:
    value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n])

PROPERTY nodes

获取Graph的Node列表。
请注意,此节点列表表示是一个双向链表。 迭代期间的突变(例如删除节点、添加节点)是安全的。

output(result, type_expr=None)

在Graph中插入一个output节点。output节点表示 Python 代码中的返回语句。 result是应该返回的值。

PROPERTY owning_module

如果有拥有此 GraphModule 的模块,则返回该模块,如果没有或有多个则返回 None。

placeholder(name, type_expr=None)

在Graph中插入一个placeholder。placeholder表示函数的输入。

print_tabular()

以表格格式打印图形的IR。 请注意,此 API 需要安装 tabulate 模块。

python_code(root_module)

将该Graph转换为有效的Python代码。

unflatten_outs(out)

CLASS torch.fx.Node

CLASStorch.fx.Node(graph, name, op, target, args, kwargs, return_type=None)

Node是Graph中表示一个独立计算单元的数据结构。在大多数情况下,Node代表对各种entity的调用关系,例如运算符、方法和模块(一些例外包括指定函数输入和输出的节点)。 每个节点都有一个由其 op 属性指定的函数。 op的每个值的Node语义如下:

  • **placeholder : ** 表示函数输入。 name 属性指定此值将采用的名称。 target 同样是参数的名称。 args 包含:1) 什么都没有,或 2) 表示函数输入的默认参数的单个参数。 kwargs 不用关心的。 占位符对应于图形打印输出中的函数参数(例如 x)。
  • **get_attr: ** 从模块层次结构中检索参数。 name 与提取结果的名称类似。 target 是模块层次结构中参数位置的完全限定名称。 args 和 kwargs不用关心.
  • **call_function: ** 对某些值应用的自由函数也就是非成员函数。 name 同样是要分配给的值的名称。 target 是要应用的函数。 args 和 kwargs 表示函数的参数,遵循 Python 调用约定
  • **call_module: ** 将模块层次结构的forward()方法中的模块应用于给定参数。name和前面一样。target是要调用的模块层次结构中模块的完全限定名。args和kwargs表示要在其上调用模块的参数,包括self参数。
  • **call_method: ** 对值调用方法。name的含义一样。target是要应用于自参数的方法的字符串名称。args和kwargs表示要在其上调用模块的参数,包括self参数
  • **output: ** 在其args[0]属性中包含跟踪函数的输出。这对应于图形打印输出中的"return"语句。

PROPERTY all_input_nodes

获取该节点的所有输入节点。这相当于找出 args 和 kwargs中值为Node的参数。

append(x)

在图中的节点列表中,在此节点后插入x。与self.next.prepend(x)功能相同。

PROPERTY args

此节点的参数元组。参数的解释取决于节点的操作码。有关详细信息,请参阅节点docstring。
允许对此属性进行赋值。使用和用户的所有记帐在分配时自动更新。

format_node(placeholder_names=None, maybe_return_typename=None)

返回描述本Node的字符串。

is_impure()

返回此op是否是纯操作,即其op是否为占位符或输出,或者是否是纯的call_module或call_function。

PROPERTY kwargs

此节点的关键字参数的dict。参数的解释取决于节点的op代码实现。有关详细信息,请参阅node docstring。

PROPERTY next

返回该节点的下一个Node

normalized_arguments(root, arg_types=None, kwarg_types=None, normalize_to_only_use_kwargs=False)

将规范化参数返回到Python目标。这意味着,如果normalize_to_only_use_kwargs为true,则args/kwargs将与模块/函数的签名匹配,并以位置顺序专门返回kwargs。还填充默认值。不支持仅位置参数或varargs参数。
支持模块调用。
可能需要arg_类型和kwarg_类型以消除重载的歧义。

prepend(x)

在该节点前插入x节点

replace_all_uses_with(replace_with)

将图中所有用本节点的地方替换为节点replace_with。

replace_input_with(old_input, new_input)

遍历Node的所有输入,并将old_input都替换为new_input。

PROPERTY stack_trace

返回跟踪期间记录的Python堆栈跟踪(如果有)。此属性通常由Tracer.create_proxy填充。要在跟踪过程中记录堆栈跟踪以进行调试,请在跟踪程序实例上设置record_stack_traces=True。

update_arg(idx, arg)

更新现有参数使第inx个参数值为arg。调用后,self.args[idx]==arg。

update_kwarg(key, arg)

更新现有kwarg参数新增键值为key对应值为arg的参数。调用后,self.kwargs[key]==arg。

torch.fx.replace_pattern(gm, pattern, replacement)

找到GraphModule中符合pattern匹配规则的所有运算符集,然后用replacement替换掉。

参数

  • gm: 要操作的GraphModule
  • **pattern: ** 匹配的模式
  • **replacement: ** 要替换成的目的子图

返回值

匹配对象列表,表示模式匹配到的原始图形中的位置。如果没有匹配项,则列表为空。匹配定义为:

class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]

返回值

List[Match]

例子:

import torch
from torch.fx import symbolic_trace, subgraph_rewriter

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w1, w2]).sum()
        return x + torch.max(m1) + torch.max(m2)

def pattern(w1, w2):
    return torch.cat([w1, w2]).sum()

def replacement(w1, w2):
    return torch.stack([w1, w2])

traced_module = symbolic_trace(M())

subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

上面的代码将首先在traced_module的forward方法中对pattern进行匹配。模式匹配是基于use-def关系而不是节点名称完成的。例如,如果模式中有p=torch.cat([a,b]),则可以在原始forward函数中匹配m=torch.cat([a,b]),尽管变量名称不同(p与m)。

模式中的return语句仅基于其值进行匹配;它可能与较大图形中的return语句匹配,也可能不匹配。换句话说,模式不必延伸到较大图形的末尾。

当pattern匹配成功时,它将从较大的函数中删除,并用replacement来替换。如果在较大的函数中有匹配成功多个,则将替换每个不重叠的匹配。在匹配重叠的情况下,将替换重叠匹配集中找到的第一个匹配。(“第一个”在这里被定义为节点use-def关系拓扑顺序中的第一个。在大多数情况下,第一个节点是直接出现在self之后的参数,而最后一个节点是函数返回的任何值。)

如果pattern是可调用的,则它的参数必须在pattern里面使用,replacement的参数必须与pattern的参数匹配。第一条规则,为什么在上面的代码块中,forward函数有参数x、w1、w2,而pattern函数只有参数w1、w2。因为pattern不使用x,所以不应该将x指定为参数。作为第二条规则的一个例子:

def pattern(x, y):
    return torch.neg(x) + torch.relu(y)

替换为

def replacement(x, y):
    return torch.relu(x)

在这种情况下,替换需要与pattern相同数量的参数(x和y),即使在替换中没有使用参数y。
调用subgraph_rewriter.replace_pattern后,生成的Python代码如下所示:

def forward(self, x, w1, w2):
    stack_1 = torch.stack([w1, w2])
    sum_1 = stack_1.sum()
    stack_2 = torch.stack([w1, w2])
    sum_2 = stack_2.sum()
    max_1 = torch.max(sum_1)
    add_1 = x + max_1
    max_2 = torch.max(sum_2)
    add_2 = add_1 + max_2
    return add_2

标签:None,torch,FX,self,args,笔记,PyTorch,kwargs,节点
来源: https://blog.csdn.net/itlilyer/article/details/120893638

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有