ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

PyTorch 1.0 中文文档:torch.utils.checkpoint

2019-03-07 21:54:32  阅读:408  来源: 互联网

标签:function 1.0 utils torch RNG checkpoint 模型 输入


译者: belonHan

注意

checkpointing的实现方法是在向后传播期间重新运行已被checkpint的前向传播段。 所以会导致像RNG这类(模型)的持久化的状态比实际更超前。默认情况下,checkpoint包含了使用RNG状态的逻辑(例如通过dropout),与non-checkpointed传递相比,checkpointed具有更确定的输出。RNG状态的存储逻辑可能会导致一定的性能损失。如果不需要确定的输出,设置全局标志(global flag) torch.utils.checkpoint.preserve_rng_state=False 忽略RNG状态在checkpoint时的存取。

torch.utils.checkpoint.checkpoint(function, *args)

checkpoint模型或模型的一部分

checkpoint通过计算换内存空间来工作。与向后传播中存储整个计算图的所有中间激活不同的是,checkpoint不会保存中间激活部分,而是在反向传递中重新计算它们。它被应用于模型的任何部分。

具体来说,在正向传播中,function将以torch.no_grad()方式运行 ,即不存储中间激活,但保存输入元组和 function的参数。在向后传播中,保存的输入变量以及 function会被取回,并且function在正向传播中被重新计算.现在跟踪中间激活,然后使用这些激活值来计算梯度。

Warning
警告

Checkpointing 在 torch.autograd.grad()中不起作用, 仅作用于 torch.autograd.backward().

警告

如果function在向后执行和前向执行不同,例如,由于某个全局变量,checkpoint版本将会不同,并且无法被检测到。

参数:

  • function - 描述在模型的正向传递或模型的一部分中运行的内容。它也应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过 ,应正确使用第一个输入作为第二个输入(activation, hidden)functionactivationhidden
  • args – 包含输入的元组function

阅读全文/改进本文

标签:function,1.0,utils,torch,RNG,checkpoint,模型,输入
来源: https://www.cnblogs.com/wizardforcel/p/10492605.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有