ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

在PyTorch中使用自己的数据集

2021-11-01 16:58:49  阅读:265  来源: 互联网

标签:val 数据 PyTorch train transforms 使用 images path data


太累了 看了一上午CSDN还是没搞明白

看的下面的up主的讲解  做一下笔记 免得忘记 

在pytorch中自定义dataset读取数据_哔哩哔哩_bilibili

主要内容:如何划分训练集 验证集 数据读取 预处理 

代码在github上 pytorch_classification文件夹下custom_dataset文件夹中,内有main.py my_dataset.py utils.py三个py文件

 先看main.py文件

import os

import torch
from torchvision import transforms

from my_dataset import MyDataSet
from utils import read_split_data, plot_data_loader_image

# http://download.tensorflow.org/example_images/flower_photos.tgz
root = "/home/wz/my_github/data_set/flower_data/flower_photos"  # 数据集所在根目录


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    batch_size = 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw,
                                               collate_fn=train_data_set.collate_fn)

    # plot_data_loader_image(train_loader)

    for step, data in enumerate(train_loader):
        images, labels = data


if __name__ == '__main__':
    main()

1需要更改第11行root为自己数据集的位置 且文件夹下包含的文件夹名字即为他们的标签

 

 2用read_split_data划分训练集和 验证集

main文件中查看read_split_data,跳转到utils第九行

#val_rate划分验证集占所有样本的比例 默认值是0.2
def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # 保证随机结果可复现 不管在谁的电脑上划分的数据集都一样
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # 遍历文件夹,一个文件夹对应一个类别
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # 排序,保证顺序一致
    flower_class.sort()
    # 生成类别名称以及对应的数字索引
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # 存储训练集的所有图片路径
    train_images_label = []  # 存储训练集图片对应索引信息
    val_images_path = []  # 存储验证集的所有图片路径
    val_images_label = []  # 存储验证集图片对应索引信息
    every_class_num = []  # 存储每个类别的样本总数
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # 支持的文件后缀类型
    # 遍历每个文件夹下的文件

 utils第54行 设置为True即可以将样本的数量可视化,在第71行return语句返回四个值到main函数中,在return语句设置断点 debug一下main函数

运行结果

Connected to pydev debugger (build 212.5284.44)
using cuda device.
3670 images were found in the dataset.
2939 images for training.

 3main.py中预处理图片

my_dataset.py 中第20行如果使用的不是RGB图像可以自行更改 

 使用PIL来预处理图片(也可以用opencv 但pytorch中用pil的预处理较多)

 

标签:val,数据,PyTorch,train,transforms,使用,images,path,data
来源: https://blog.csdn.net/m0_60673782/article/details/121075259

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有