ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

TVM量化代码解析

2021-10-31 06:31:51  阅读:307  来源: 互联网

标签:relay cfg TVM quantize pass 量化 解析 mod


TVM量化代码解析

TVM量化,非常方便,即插即用。使用加入了伪量化后的pass,替代原来的pass,一个官方提供的示例:

def test_mul_rewrite():

    """a test case where rhs of mul is not constant"""

    data=relay.var("data",shape=(1,16,64,64))

    multiplier=relay.sigmoid(relay.var("data",shape=(1,16,1,1)))

    conv=relay.nn.conv2d(data,relay.var("weight"),

                           kernel_size=(3,3),

                           padding=(1,1),

                           channels=16)

    act=relay.nn.relu(data=conv)

    quantize_and_build(act * multiplier)

    pool=relay.nn.global_avg_pool2d(data=act)

    quantize_and_build(act * pool)

入口就是函数:

def quantize_and_build(out):

    f=relay.Function(relay.analysis.free_vars(out),out)

    mod,params=testing.create_workload(f)

    with relay.quantize.qconfig(skip_conv_layers=[]):

        qmod=relay.quantize.quantize(mod,params)

    relay.build(qmod,"llvm",params=params)

    return qmod

调用relay.quantize.quantize函数,这个函数实在太长了,只放上主体部分。

 1. mod=prerequisite_optimize(mod,params)

 2. calibrate_pass=tvm.transform.module_pass(

        calibrate(dataset),opt_level=1,

        name="QuantizeCalibrate")

    quant_passes=[partition(),

                    annotate(),

                    calibrate_pass]

    if not current_qconfig().do_simulation:

        quant_passes.append(realize())

    quant_passes.append(_transform.FoldConstant())

    quantize_seq=tvm.transform.Sequential(quant_passes)

    with tvm.transform.PassContext(opt_level=3,

                                   required_pass=["QuantizeAnnotate",

                                                  "QuantizeCalibrate",

                                                  "QuantizeRealize"]):

 3. with quantize_context():

            mod=quantize_seq(mod)

 4. q_cfg=current_qconfig()

    assert q_cfg.partition_conversions in ['disabled','enabled','fully_integral']

    if q_cfg.partition_conversions != 'disabled':

        quantized_dtypes={q_cfg.dtype_input,q_cfg.dtype_weight,q_cfg.dtype_activation}

        ensure_fully_integral=q_cfg.partition_conversions == 'fully_integral'

        return partition_conversions(mod,quantized_dtypes,ensure_fully_integral)

从代码中,可看到,TVM量化需要做的就是

l  标号1,图优化部分,具体做哪些图优化,就可自己选了,如算子融合,常量折叠。

l  标号2,整个量化的步骤,包括定义quant_passes,如果发现config设置,不需要伪量化,就是inference阶段了,就把realize加进去,否则,只需要annotate及calibrate,优化量化参数。

l  标号3,开始做量化了,将一个fp32的inference graph,转成int类型的inference graph,可参照第一张图。

l  标号4,把realize的graph,或者说对于一个op的前向推理的步骤,分成前中后三部分:

比如,conv2d,input_quantization -> input_quantization*weight_quantization(core function) -> ouput_dequantization,

每一个算子计算完后,都要dequant回去,很有可能某些算子没量化,还得用fp32。

最优解肯定是全部都量化掉,直接int32跑到底,TVM搞了个参数ensure_fully_integral,保证所有的算子都量化了。

 

 

参考链接:

https://blog.csdn.net/Artyze/article/details/108776522

https://www.freesion.com/article/3155559638/

https://discuss.tvm.apache.org/t/rfc-search-based-automated-quantization/5483

标签:relay,cfg,TVM,quantize,pass,量化,解析,mod
来源: https://www.cnblogs.com/wujianming-110117/p/15488221.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有