ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

深度学习-训练MNIST数据集Demo

2021-01-15 12:31:41  阅读:314  来源: 互联网

标签:img nn Demo torch train 深度 test data MNIST


一、安装Anacoda

  1. 下载相应版本并进行安装,本文下载 64-bit(x86) Installer (Python 3.8)
    在这里插入图片描述
  2. 参考Anacoda安装指导官方文档进行安装, 切换至下载目录,使用命令安装
$ sh Anaconda3-2020.11-Linux-x86_64.sh
  1. 添加Anaconda环境变量
  • 打开环境变量文件
$ vi ~/.bashrc 
  • 添加一行
export PATH=/home/[User]/anaconda3/bin:$PATH
  • 使环境变量生效
$ source ~/.bashrc 
  • 验证conda命令
    在这里插入图片描述

二、安装PyTorch

访问PyTorch官网, 按照实际环境选择对应版本安装命令
在这里插入图片描述

本文选择的命令组合如上图所示(无cuda版本)

conda install pytorch torchvision torchaudio cpuonly -c pytorch

三、MNIST数据集

MNIST数据集包含四个文件,如下
在这里插入图片描述

四、训练过程

4.1 代码解析

1. 导入相关包

import torch
from torchvision import datasets, transforms
import cv2
import os
import torchvision
import numpy as np
from torch.autograd import Variable

2. 获取数据集和测试集

transform指定对数据集进行的变换操作
root指定数据集下载存放目录
train指定数据集下载完成后需要载入类型
True:代表载入训练集
False:代表载入测试集

transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5],std=[0.5])])
data_train = datasets.MNIST(root = "./data/",
                            transform=transform,
                            train = True,
                            download = True)

data_test = datasets.MNIST(root="./data/",
                           transform = transform,
                           train = False)

3. 数据预览及装载

dataset指定载入的数据集
batch_size指定每个包的图像数据个数
shuffle指定是否在装载时随机打乱顺序

data_loader_train = torch.utils.data.DataLoader(dataset=data_train,
                                                batch_size = 64,
                                                shuffle = True,
                                                 num_workers=2)

data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                               batch_size = 64,
                                               shuffle = True,
                                                num_workers=2)

next和iter获取一个批次(64)的图片及其对应标签
torchvision.utils.make_grid将该批次图片构建成网格模式
网格模式的图片维度是(channel,height,weight)(即色彩通道,图片高度,宽度),经过transpose(1,2,0)后变为(height,weight,channel)

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1,2,0)
std = [0.5]
mean = [0.5]
img = img*std+mean
print(labels)
cv2.imshow('win',img)
key_pressed=cv2.waitKey(0)

4. 构建模型和设置参数

卷积神经网络CNN的结构一般包含这几层:

  • 输入层:用于数据的输入

  • 卷积层:使用卷积核进行特征提取和特征映射

  • 激励层:由于卷积也是一种线性运算,因此需要增加非线性映射

  • 池化层:进行下采样,对特征图稀疏处理,减少特征信息的损失

  • 输出层:用于输出结果

  • torch.nn.Conv2d类构建卷积层
  • torch.nn.ReLU类构建激活层
  • torch.nn.MaxPool2d类构建池化层
  • torch.nn.Linear类构建全连接层

前向传播forward()函数

  • self.conv1(): 卷积处理

  • x.view(-1,1414128): 参数扁平化(保证全连接层的实际输出的参数维度和其定义输入的维度相匹配)

  • self.dense: 全连接层进行最后的分类

class Model(torch.nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),
                                         torch.nn.ReLU(),
                                         torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
                                         torch.nn.ReLU(),
                                         torch.nn.MaxPool2d(stride=2,kernel_size=2))
        self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),
                                         torch.nn.ReLU(),
                                         torch.nn.Dropout(p=0.5),
                                         torch.nn.Linear(1024, 10))
    def forward(self, x):
        x = self.conv1(x)
        #x = self.conv2(x)
        x = x.view(-1, 14*14*128)
        x = self.dense(x)
        return x

训练模型和参数

  • 实例化模型

  • 定义损失函数

  • 定义优化函数

model = Model()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

5. 训练

训练代码

n_epochs = 5 #训练次数
for epoch in range(n_epochs):
    running_loss = 0.0
    running_correct = 0
    print("Epoch {}/{}".format(epoch, n_epochs))
    print("-"*10)
    for data in data_loader_train:
        X_train, y_train = data
        X_train, y_train = Variable(X_train), Variable(y_train)
        outputs = model(X_train)
        _,pred = torch.max(outputs.data, 1)
        optimizer.zero_grad()
        loss = cost(outputs, y_train)
        
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_correct += torch.sum(pred == y_train.data)
    testing_correct = 0
    for data in data_loader_test:
        X_test, y_test = data
        X_test, y_test = Variable(X_test), Variable(y_test)
        outputs = model(X_test)
        _, pred = torch.max(outputs.data, 1)
        testing_correct += torch.sum(pred == y_test.data)
    print("Loss is:{:.4f}, Train Accuracy is:{:.4f}%, Test Accuracy is:{:.4f}".format(running_loss/len(data_train),
                                                                                      100*running_correct/len(data_train),
                                                                                      100*testing_correct/len(data_test)))

训练结果

在这里插入图片描述

6. 测试

data_loader_test = torch.utils.data.DataLoader(dataset=data_test,
                                          batch_size = 4,
                                          shuffle = True)
X_test, y_test = next(iter(data_loader_test))
inputs = Variable(X_test)
pred = model(inputs)
_,pred = torch.max(pred, 1)

print("Predict Label is:", pred.data)
print("Real Label is:",y_test)

img = torchvision.utils.make_grid(X_test)
img = img.numpy().transpose(1,2,0)

std = [0.5]
mean = [0.5]
img = img*std+mean
cv2.imshow('test',img)

4.2 完整代码

请关注公众号【考拉技术研究所】

–>“资源分享”
在这里插入图片描述

标签:img,nn,Demo,torch,train,深度,test,data,MNIST
来源: https://blog.csdn.net/KoalaZB/article/details/111460765

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有