ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

pytorch复习与总结

2021-09-20 22:01:25  阅读:163  来源: 互联网

标签:总结 __ 复习 self label pytorch images import name


今天来复习pytorch的数据读取机制
torch.utils.data.DataLoader();构建可迭代的数据装载器,每一个for 循环,每一个iteration,都是从DataLoader中获取一个Batch_size大小的数据。
有没有好奇过,就加载这几个类,然后就可以把数据读取,而且还能以批量的形式加载,这是怎样的一个过程呢?今天我们就来慢慢的深入学习,学到哪是哪。
在这里插入图片描述

其中DataLoader大概有几个重要的参数,分别为:
1、dataset:属于DataSet类,决定数据从哪读取,怎么读取。
2、num_works:是否多进程读取
3、shuffle:每个epoch是否是乱序
4、batchsize:批量大小
5、drop_last:组成批量是,多余的是不是要剔除掉,
先来理解epoch
所有训练样本都已经输入到模型中,称为一个epoch,iteration:一批样本输入到模型中
再来理解Batchsize:
批量的大小,决定了一个epoch有多少个iteration
Dataset复习
torch.utils.data.Dataset(): Dataset抽象类,所有自定义的Dataset都必须继承他,并且还要复习__getitem__()和__len__()这两个函数,那么第一个函数是干啥用的呢?第二个是干啥用的呢,我这里通过学习查资料,理解了这么一个过程:
getitem:这个函数主要是来收集并返回图片和标签信息的,这个函数有两个参数,

def __getitem__(self, item):

其中item是干啥的呢?这个就是一个索引,很重的一个参数,我们在这个函数里读取信息的时候就是根据这个item 参数来寻找每张图片的信息的,其中过程可以在他的父类中看到
在这里插入图片描述
那么这个重写的函数是要收集哪些信息呢?
他的作用是收集我们训练集或者测试集的图片和图片所对应的标签号码,而且把图片信息转换为张量信息也是在这个函数里面发生的,转换完之后会返回出去
在这里插入图片描述
他把信息返回到哪里了呢?
那就是返回到你创建的这个自定义的数据类里面了
在这里插入图片描述
只有这样,你才能实例化对象把这写数据打包成批量或者单个。
**len_()**返回的就是数据集的长度,这个是很简单的return len(self.image)
我们在自定义数据集的时候还有一点很重要
我们怎么收集图片呢?
在这里插入图片描述
来看这张图片,这长图片表达的意思是:我要获取图片的具体位置(自己在__init___已经设置好了)和图片对应的标签,把获取后的信息随机打乱一下,放入到两个列表里面,并返回出去,返回给谁呢?
在这里插入图片描述
返回给__init__里面自定义的两个变量,这两个变量负责将数据包里面的内容根据你设定的训练测试返回来剪辑数据的大小。是不是再想剪辑后干啥呢?放到哪里呢?
还记得上边我们提到的__getitem__(self, item)?
在这里插入图片描述
返回给了他(红线圈起来的),然后继续往下执行,张量化。这就是一整个过程
下边张贴一下所有的代码过程。

import csv
import glob
import os
import random

import torch
import torchvision
import visdom
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader

class myData(Dataset):
    def __init__(self,file,size,mode):
        self.file=file
        self.size=size
        self.label_name={}#存放文件名称和标号
        for name in (os.listdir(os.path.join(file))):
            if os.path.isdir(os.path.join(self.file,name)):
                self.label_name[name]=len(self.label_name.keys())
        #print(self.label_name)
        self.imgaes,self.labels=self.get_img_info('image_csv')
        #--------------划分训练集和测试集范围--------------
        if mode=='train':
            self.images=self.imgaes[:int(0.8*len(self.imgaes))]
            self.labels=self.labels[:int(0.8*len(self.labels))]
        else:
            self.images=self.imgaes[int(0.8*len(self.imgaes)):]
            self.labels=self.labels[int(0.8*len(self.labels)):]
        #---------------划分训练集和测试集范围-------------
        pass
    def __len__(self):
        return len(self.images)
    pass
    def get_img_info(self,filename):#这个时获得图片信息的函数
        images=[]
        labels=[]
        for name in self.label_name.keys():#取出字典的键
            images+=glob.glob(os.path.join(self.file,name,'*.jpg'))
            images+=glob.glob(os.path.join(self.file,name,'*png'))
            #print(images)
            pass
        random.shuffle(images)#把这里面所有的地址和都打乱
        with open(os.path.join(self.file,filename),mode='w',newline='') as f:
            writer=csv.writer(f)
            for file in images:
                img=file.split(os.sep)[-2]
                label=self.label_name[img]
                labels.append(label)
                writer.writerow([file,label])
                pass
            pass
        #print(images)
        return images,labels
    def __getitem__(self, item):
        image=self.images[item]
        label=self.labels[item]
        tf=torchvision.transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),
            transforms.Resize((int(self.size),int(self.size))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(64),
            transforms.ToTensor()#这个放在最后操作,前边那几个是在图片的基础上修改的,这个把修改好的再转化为张量
        ])
        img=tf(image)
        label=torch.tensor(label)
        return img,label


def main():
    #viz=visdom.Visdom()
    mydata_train = myData('traindata', 64, 'train')
    mydata_test=myData('traindata', 64, 'test')
    #x,y=next(iter(mydata))
    #viz.image(x,win='sample_x',opts=dict(title='sample_x'))
    train=DataLoader(mydata_train,batch_size=32,shuffle=True)#把数据打包成批量Batchsize
    test=DataLoader(mydata_test,batch_size=32)
    '''
        print(train)
    for x,y in train:
        viz.images(x,nrow=8,win='batch',opts=dict(title='bacht'))
    '''

if __name__=='__main__':
    main()

这里还有几个知识点要记录:
数据增强:
对数据集进行变换,让模型更具有泛化能力,比如

       transforms.RandomRotation(15),
       transforms.CenterCrop(64),

上边这两个操作,具体的可以去网上查找

transforms.ToTensor()

把图像转为张量,同时进行归一化操作,将张量从0-255转到0-1之间

transforms.Normalize()

加快模型的收敛速度
总结:
今天复习了自定义数据的收集过程,到底是一个怎么收集的过程,然后就是一步一步的介绍了整体收集的过程。

标签:总结,__,复习,self,label,pytorch,images,import,name
来源: https://blog.csdn.net/weixin_52646021/article/details/120395430

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有