ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

PyTorch中卷积的简单操作

2022-01-26 00:02:14  阅读:237  来源: 互联网

标签:kernel output 卷积 import torch PyTorch 简单 input 默认值


卷积

Conv1d

Conv1d
Conv1d

  • input:形状的输入张量

  • weight: 形状过滤器

  • bias:形状的可选偏置张量( out_channels ). 默认:None

  • stride:卷积核的步长。可以是单个数字或元组(sH, sW)。默认值:1

  • padding:输入两侧的隐式填充。可以是字符串 {‘valid’, ‘same’}、单个数字或元组(padH, padW)。默认值:0 padding='valid'与无填充相同。padding='same'填充输入,使输出具有作为输入的形状。但是,此模式不支持除 1 以外的任何步幅值。

    对于padding='same',如果weight是偶数长度并且 dilation在任何维度上都是奇数,则pad()内部可能需要完整操作。降低性能。

  • dilation:内核元素之间的间距。可以是单个数字或元组(dH, dW)。默认值:1

  • groups:将输入分成组,\text{in_channels}输入频道应该可以被组数整除。默认值:1

import torch
import torch.nn.functional as F

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]])

kernel = torch.tensor([[1, 2, 1],
                       [0, 1, 0],
                       [2, 1, 0]])

input = torch.reshape(input, (1, 1, 5, 5))
kernel = torch.reshape(kernel, (1, 1, 3, 3))

print(input.shape)
print(kernel.shape)

output = F.conv2d(input, kernel, stride=1)
print(output)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-89FYZhgK-1643125675914)(H:\codes\pytorch\note\卷积操作.assets\微信图片_20220107023303.jpg)]

输出结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-YrHwl6D2-1643125675914)(H:\codes\pytorch\note\卷积操作.assets\image-20220107023504127.png)]

output2 = F.conv2d(input, kernel, stride=2)
print(output2)

output3 = F.conv2d(input, kernel, stride=1, padding=1)
print(output3)

输出结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ypEwUHX4-1643125675915)(H:\codes\pytorch\note\卷积操作.assets\image-20220107023616457.png)]

Conv2d

Conv2d
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nqgFrCnR-1643125675916)(H:\codes\pytorch\note\卷积操作.assets\image-20220107024928073.png)]

  • in_channels (int):输入图像中的通道数
  • out_channels(int):卷积产生的通道数
  • kernel_size(int or tuple):卷积核的大小
  • stride(int or tuple, optional):卷积的步幅。默认值:1
  • padding(int, tuple or str, optional):填充添加到输入的所有四个边。默认值:0
  • padding_mode(string, optional):‘zeros’,‘reflect’,‘replicate’‘circular’。默认:‘zeros’
  • dilation(int or tuple,optional):内核元素之间的间距。默认值:1
  • groups(int, optional):从输入通道到输出通道的阻塞连接数。默认值:1
  • bia(*bool, optional):如果True,则向输出添加可学习的偏差。默认:True
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(root = "./dataset", train = False,
                                       transform = torchvision.transforms.ToTensor(), download = True)
dataloader = DataLoader(dataset, batch_size = 64)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(in_channels = 3, out_channels = 6, kernel_size = 3, stride = 1, padding = 0)

    def forward(self, x):
        x = self.conv1(x)
        return x


model = Model()

writer = SummaryWriter("logs")
step = 0
for data in dataloader:
    imgs, targets = data
    output = model(imgs)
    print(imgs.shape)
    print(output.shape)
    # torch.Size([64, 3, 32, 32])
    writer.add_images("input", imgs, step)
    # torch.Size([64, 6, 30, 30])  -> [xxx, 3, 30, 30]
    output = torch.reshape(output, (-1, 3, 30, 30))
    writer.add_images("output", output, step)

    step = step + 1

writer.close()

运行结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NQFErpAF-1643125675917)(H:\codes\pytorch\note\卷积操作.assets\image-20220107041336233.png)][外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SUz9zVOb-1643125675918)(H:\codes\pytorch\note\卷积操作.assets\image-20220107041402544.png)]

标签:kernel,output,卷积,import,torch,PyTorch,简单,input,默认值
来源: https://blog.csdn.net/weixin_51296032/article/details/122693992

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有