ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

torch中的mask:masked_fill, masked_select, masked_scatter

2022-06-14 21:37:25  阅读:291  来源: 互联网

标签:tensor torch mask source masked input


1. 简介

  pytorch提供mask机制用来提取数据中“感兴趣”的部分。过程如下:左边的矩阵是原数据,中间的mask是遮罩矩阵,标记为1的表明对这个位置的数据“感兴趣”-保留,反之舍弃。整个过程可以视作是在原数据上盖了一层mask,只有感兴趣的部分(值为1)显露出来,而其他部分则背遮住。(matlab中也有mask操作)

  mask为一个和元数据size相匹配的tensor-bool,相匹配: broadcastable-广播机制。如一个2*3*3的原数据可以由一个3*3的mask来提取。

  mask一般是先建立0/1矩阵,然后通过tensor.bool()来转为bool类型的tensor,其他true表示原数据被遮住或者被选中,false表示原数据没有被遮住或者未被选中:这句话在下面的演示中更容易理解。

2. 程序演示

  这里涉及的是torch中的三个常见mask函数:masked_fill, masked_select, masked_scatter。

  先构造好input和mask矩阵:

imgs = torch.randint(0, 255, [2, 3, 3], dtype=torch.float32)
"""
tensor([[[182., 242.,  11.],
         [163.,  92., 183.],
         [222.,  54.,  86.]],
        [[157., 139., 254.],
         [158., 148.,  46.],
         [  1.,  13.,  56.]]])
"""
mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).bool()
"""
tensor([[ True, False, False],
        [False,  True, False],
        [False, False,  True]])
"""

1)torch.masked_fill(input, mask, value)

  参数:  

  • input:输入的原数据
  • mask:遮罩矩阵
  • value:被“遮住的”部分填充的数据,可以取0、1等值,数据类型不限,int、float均可

  返回值:一个和input相同size的masked-tensor

  使用:

  • output = torch.masked_fill(input, mask, value)
  • output = input.masked_fill(mask, value)
imgs_masked = torch.masked_fill(input=imgs, mask=~mask, value=0) # 这里mask取反:true表示被“遮住的”
"""
tensor([[[182.,   0.,   0.],
         [  0.,  92.,   0.],
         [  0.,   0.,  86.]],
        [[157.,   0.,   0.],
         [  0., 148.,   0.],
         [  0.,   0.,  56.]]])
"""

2)torch.masked_select(input, mask, out)

  参数:  

  • input:输入的原数据
  • mask:遮罩矩阵
  • out:输出的结果,和原tensor不共用内存,一般在左侧接收,而不在形参中赋值

  返回值:一维tensor,数据为“选中”的数据

  使用:

  • torch.masked_select(input, mask, out)
  • output = input.masked_select(mask)
selected_ele = torch.masked_select(input=imgs, mask=mask)  # true表示selected,false则未选中,所以这里没有取反
# tensor([182., 92., 86., 157., 148., 56.])

3)torch.masked_scatter(input, mask, source)

  说明:将从input中mask得到的数据赋值到source-tensor中

  参数:  

  • input:输入的原数据
  • mask:遮罩矩阵
  • source:遮罩矩阵的”样子“(全零还是全一或是其他),true表示遮住了

  返回值:一个和source相同size的masked-tensor

  使用:

  • output = torch.masked_scatter(input, mask, source)
  • output = input.masked_scatter(mask, source)
source = torch.zeros_like(imgs)
imgs_masked_copied = torch.masked_scatter(input=imgs, mask=~mask, source=source)
"""
tensor([[[173.,   0.,   0.],
         [  0.,  77.,   0.],
         [  0.,   0., 159.]],
        [[ 85.,   0.,   0.],
         [  0., 184.,   0.],
         [  0.,   0., 223.]]])
"""

3. 参考链接

PyTorch documentation — PyTorch 1.11.0 documentation

深度学习中的mask操作

python中的三个mask

 

标签:tensor,torch,mask,source,masked,input
来源: https://www.cnblogs.com/YuanShiRenY/p/torch_mask.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有