
Featured: Implementing MNIST Handwritten Digit Recognition in PyTorch

2020-04-26 17:38:23  Source: the Internet

Tags: Loss 60000 loss Average Epoch PyTorch Train handwritten-digits MNIST


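@This article originally appeared on the WeChat official account csdn2299 (程序员学府).
Today I am sharing a PyTorch implementation of MNIST handwritten digit recognition. I hope you find it a useful reference.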


Environment

Windows 10 + Anaconda + Jupyter Notebook

PyTorch 1.1.0

Python 3.7

GPU (optional)

The MNIST dataset

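MNIST contains 60,000 training images and 10,000 test images, each a 28x28 grayscale image of a handwritten digit; it is often called the "Hello World" of computer vision. The CNN used in this article reaches about 99% accuracy on the MNIST test set. Let's get started.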

Imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
print(torch.__version__)  # check the installed PyTorch version

Hyperparameters

BATCH_SIZE=512  # samples per batch
EPOCHS=20  # full passes over the training set
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU if available, else CPU
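BATCH_SIZE fixes how many optimizer steps make up one epoch. As a quick sanity check in plain Python (no PyTorch needed), the sketch below computes the number of batches for the 60,000 training images and reproduces the sample counts and percentages that the training function prints every 30 batches. It assumes the DataLoader keeps the final partial batch, which is PyTorch's default (drop_last=False).

```python
import math

BATCH_SIZE = 512
N_TRAIN, N_TEST = 60000, 10000  # MNIST split sizes

train_batches = math.ceil(N_TRAIN / BATCH_SIZE)           # what len(train_loader) reports
last_batch = N_TRAIN - (train_batches - 1) * BATCH_SIZE   # size of the final, partial batch
test_batches = math.ceil(N_TEST / BATCH_SIZE)             # what len(test_loader) reports

# train() prints when (batch_idx + 1) % 30 == 0, using batch_idx * len(data) as the count;
# the logged batches are all full, so len(data) == BATCH_SIZE there
progress = []
for batch_idx in range(train_batches):
    if (batch_idx + 1) % 30 == 0:
        progress.append('[{}/{} ({:.0f}%)]'.format(
            batch_idx * BATCH_SIZE, N_TRAIN, 100. * batch_idx / train_batches))

print(train_batches, last_batch, test_batches)
print(progress)
```

The three progress strings match the markers in the training log shown later. Note that the printed count is batch_idx * len(data), so it lags one batch behind the number of samples actually processed at that point.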

Dataset

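We use the MNIST dataset that ships with torchvision and load the training and test sets with two separate DataLoaders. If you have already downloaded the data, you can set download=False.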

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),  # uint8 image in [0, 255] -> float tensor in [0.0, 1.0], shape (1, 28, 28)
                       transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # same normalization as the training set
    ])),
    batch_size=BATCH_SIZE, shuffle=False)  # order does not matter for evaluation
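The two constants passed to transforms.Normalize are the mean (0.1307) and standard deviation (0.3081) of the MNIST training pixels after ToTensor has scaled them to [0, 1]; Normalize then computes (x - mean) / std for each channel. The plain-Python sketch below applies the same two steps to a single raw pixel value, to show the range of values the network actually receives. The helper name to_network_input is hypothetical.

```python
MEAN, STD = 0.1307, 0.3081  # the values passed to transforms.Normalize

def to_network_input(raw_pixel):
    """Hypothetical helper: mimic ToTensor() followed by Normalize() for one pixel."""
    if not 0 <= raw_pixel <= 255:
        raise ValueError("MNIST pixels are 8-bit values in [0, 255]")
    x = raw_pixel / 255.0       # ToTensor: scale to [0.0, 1.0]
    return (x - MEAN) / STD     # Normalize: roughly zero mean, unit std over the dataset

for p in (0, 128, 255):
    print(p, round(to_network_input(p), 4))
```

A black background pixel becomes about -0.42 and a fully white stroke about 2.82. Using the same constants for both loaders keeps the training and test inputs on the same scale.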

Defining the network

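The network consists of two convolutional layers and two fully connected (linear) layers. The last layer outputs 10 values, one per digit 0-9, which forward() turns into log-probabilities with log_softmax.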

class ConvNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Conv2d(1,10,5) # input:(1,28,28) output:(10,24,24)
    self.conv2=nn.Conv2d(10,20,3) # input:(10,12,12) output:(20,10,10)
    self.fc1 = nn.Linear(20*10*10,500) # flattened conv2 output -> 500 hidden units
    self.fc2 = nn.Linear(500,10) # one output per digit class
  def forward(self,x):
    in_size = x.size(0) # batch size
    out = self.conv1(x)
    out = F.relu(out)
    out = F.max_pool2d(out, 2, 2) # (10,24,24) -> (10,12,12)
    out = self.conv2(out)
    out = F.relu(out)
    out = out.view(in_size,-1) # flatten to (batch, 20*10*10)
    out = self.fc1(out)
    out = F.relu(out)
    out = self.fc2(out)
    out = F.log_softmax(out,dim=1) # log-probabilities over the 10 classes
    return out
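The shape comments in ConvNet follow from the output-size formula for an unpadded convolution or pooling window, out = (in - kernel) // stride + 1. The plain-Python sketch below (helper names are hypothetical) re-derives those shapes, confirms the 20*10*10 input size of fc1, and counts trainable parameters layer by layer with the standard formulas: nn.Conv2d has out·in·k·k weights plus out biases, and nn.Linear has out·in weights plus out biases.

```python
def conv_out(size, kernel, stride=1):
    """Spatial output size of an unpadded convolution or pooling window."""
    return (size - kernel) // stride + 1

def conv_params(c_in, c_out, k):
    return c_out * c_in * k * k + c_out   # weights + biases

def linear_params(n_in, n_out):
    return n_out * n_in + n_out           # weights + biases

h = conv_out(28, 5)       # conv1: 28 -> 24
h = conv_out(h, 2, 2)     # max_pool2d(out, 2, 2): 24 -> 12
h = conv_out(h, 3)        # conv2: 12 -> 10
flat = 20 * h * h         # channels * height * width fed to fc1

params = {
    'conv1': conv_params(1, 10, 5),
    'conv2': conv_params(10, 20, 3),
    'fc1': linear_params(flat, 500),
    'fc2': linear_params(500, 10),
}
print(h, flat)
print(params, sum(params.values()))
```

With PyTorch installed, sum(p.numel() for p in model.parameters()) should report the same total. The breakdown shows that fc1 alone holds over 99% of the weights, so it is the first place to look when shrinking the model.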

Instantiating the network

model = ConvNet().to(DEVICE) # move the network to DEVICE (the GPU if one is available)
optimizer = optim.Adam(model.parameters()) # Adam optimizer with default settings (lr=1e-3)
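optim.Adam is created here with no hyperparameters, so it uses PyTorch's defaults: lr=1e-3, betas=(0.9, 0.999), eps=1e-8. The pure-Python sketch below implements the Adam update rule for a single scalar parameter (no weight decay, no AMSGrad) to show one property of these defaults: thanks to bias correction, the first step moves the parameter by about lr regardless of the gradient's magnitude. The function adam_step is a hypothetical illustration, not PyTorch's own code.

```python
import math

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a scalar parameter; returns (theta, m, v)."""
    m = beta1 * m + (1 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction (t starts at 1)
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# First step from theta = 1.0 with a large and a small gradient of the same sign
big, _, _ = adam_step(1.0, grad=50.0, m=0.0, v=0.0, t=1)
small, _, _ = adam_step(1.0, grad=0.05, m=0.0, v=0.0, t=1)
print(big, small)
```

Both calls land at roughly 0.999. On later steps the running averages of past gradients shape the step size, so it no longer equals lr exactly.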

Training function

def train(model, device, train_loader, optimizer, epoch):
  model.train() # switch to training mode
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad() # clear gradients left over from the previous step
    output = model(data) # log-probabilities, shape (batch, 10)
    loss = F.nll_loss(output, target) # mean negative log-likelihood over the batch
    loss.backward() # backpropagate
    optimizer.step() # update the weights
    if (batch_idx + 1) % 30 == 0: # log every 30 batches
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
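train() uses F.nll_loss, which expects log-probabilities; that is why forward() ends with F.log_softmax. Together the two compute the cross-entropy loss, so returning raw logits and calling F.cross_entropy instead would be equivalent. The pure-Python sketch below (function names are hypothetical) spells out that arithmetic for a toy batch with the default mean reduction.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def nll_loss(log_prob_rows, targets):
    """Mean negative log-likelihood of the target classes (reduction='mean')."""
    losses = [-row[t] for row, t in zip(log_prob_rows, targets)]
    return sum(losses) / len(losses)

logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]]   # toy batch: 2 samples, 3 classes
targets = [0, 1]
loss = nll_loss([log_softmax(row) for row in logits], targets)

# Cross-entropy written directly as -log(softmax probability of the target class)
def softmax_prob(row, t):
    return math.exp(row[t]) / sum(math.exp(z) for z in row)

expected = -sum(math.log(softmax_prob(r, t)) for r, t in zip(logits, targets)) / len(targets)
print(loss, expected)
```

Subtracting the row maximum before exponentiating keeps log_softmax finite even for very large logits, where the direct softmax formula would overflow.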

Test function

def test(model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, reduction='sum').item() # sum the losses over the batch
      pred = output.max(1, keepdim=True)[1] # index of the largest log-probability, i.e. the predicted class
      correct += pred.eq(target.view_as(pred)).sum().item()
 
  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
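test() sums per-sample losses with reduction='sum' and divides once by len(test_loader.dataset). That yields the true per-sample average even though the last test batch is smaller (10,000 = 19 × 512 + 272); averaging the per-batch means instead would over-weight that last batch. The plain-Python mirror of the function below is a hypothetical sketch: each batch is a list of log-probability rows plus their targets. It shows the difference between the two averages, along with the argmax accuracy that output.max(1, keepdim=True)[1] computes.

```python
import math

def evaluate(batches):
    """Return (average loss per sample, mean of per-batch losses, accuracy)."""
    total_loss, correct, n = 0.0, 0, 0
    batch_means = []
    for rows, targets in batches:
        batch_loss = sum(-row[t] for row, t in zip(rows, targets))  # reduction='sum'
        batch_means.append(batch_loss / len(targets))
        total_loss += batch_loss
        for row, t in zip(rows, targets):
            pred = max(range(len(row)), key=lambda i: row[i])  # argmax, like output.max(1)[1]
            correct += int(pred == t)
        n += len(targets)
    return total_loss / n, sum(batch_means) / len(batch_means), correct / n

log = math.log
batches = [
    # full batch: 3 samples, each correct with probability 0.9
    ([[log(0.9), log(0.1)]] * 3, [0, 0, 0]),
    # smaller final batch: 1 sample, wrong (true class has probability 0.2)
    ([[log(0.8), log(0.2)]], [1]),
]
avg_loss, mean_of_means, accuracy = evaluate(batches)
print(avg_loss, mean_of_means, accuracy)
```

Here the mean of batch means is noticeably larger than the true per-sample loss because the single hard sample in the short batch counts as much as the whole full batch.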

Training

for epoch in range(1, EPOCHS + 1):
  train(model, DEVICE, train_loader, optimizer, epoch)
  test(model, DEVICE, test_loader)

Results

Train Epoch: 1 [14848/60000 (25%)]  Loss: 0.375058
Train Epoch: 1 [30208/60000 (50%)]  Loss: 0.255248
Train Epoch: 1 [45568/60000 (75%)]  Loss: 0.128060
 
Test set: Average loss: 0.0992, Accuracy: 9690/10000 (97%)
 
Train Epoch: 2 [14848/60000 (25%)]  Loss: 0.093066
Train Epoch: 2 [30208/60000 (50%)]  Loss: 0.087888
Train Epoch: 2 [45568/60000 (75%)]  Loss: 0.068078
 
Test set: Average loss: 0.0599, Accuracy: 9816/10000 (98%)
 
Train Epoch: 3 [14848/60000 (25%)]  Loss: 0.043926
Train Epoch: 3 [30208/60000 (50%)]  Loss: 0.037321
Train Epoch: 3 [45568/60000 (75%)]  Loss: 0.068404
 
Test set: Average loss: 0.0416, Accuracy: 9859/10000 (99%)
 
Train Epoch: 4 [14848/60000 (25%)]  Loss: 0.031654
Train Epoch: 4 [30208/60000 (50%)]  Loss: 0.041341
Train Epoch: 4 [45568/60000 (75%)]  Loss: 0.036493
 
Test set: Average loss: 0.0361, Accuracy: 9873/10000 (99%)
 
Train Epoch: 5 [14848/60000 (25%)]  Loss: 0.027688
Train Epoch: 5 [30208/60000 (50%)]  Loss: 0.019488
Train Epoch: 5 [45568/60000 (75%)]  Loss: 0.018023
 
Test set: Average loss: 0.0344, Accuracy: 9875/10000 (99%)
 
Train Epoch: 6 [14848/60000 (25%)]  Loss: 0.024212
Train Epoch: 6 [30208/60000 (50%)]  Loss: 0.018689
Train Epoch: 6 [45568/60000 (75%)]  Loss: 0.040412
 
Test set: Average loss: 0.0350, Accuracy: 9879/10000 (99%)
 
Train Epoch: 7 [14848/60000 (25%)]  Loss: 0.030426
Train Epoch: 7 [30208/60000 (50%)]  Loss: 0.026939
Train Epoch: 7 [45568/60000 (75%)]  Loss: 0.010722
 
Test set: Average loss: 0.0287, Accuracy: 9892/10000 (99%)
 
Train Epoch: 8 [14848/60000 (25%)]  Loss: 0.021109
Train Epoch: 8 [30208/60000 (50%)]  Loss: 0.034845
Train Epoch: 8 [45568/60000 (75%)]  Loss: 0.011223
 
Test set: Average loss: 0.0299, Accuracy: 9904/10000 (99%)
 
Train Epoch: 9 [14848/60000 (25%)]  Loss: 0.011391
Train Epoch: 9 [30208/60000 (50%)]  Loss: 0.008091
Train Epoch: 9 [45568/60000 (75%)]  Loss: 0.039870
 
Test set: Average loss: 0.0341, Accuracy: 9890/10000 (99%)
 
Train Epoch: 10 [14848/60000 (25%)] Loss: 0.026813
Train Epoch: 10 [30208/60000 (50%)] Loss: 0.011159
Train Epoch: 10 [45568/60000 (75%)] Loss: 0.024884
 
Test set: Average loss: 0.0286, Accuracy: 9901/10000 (99%)
 
Train Epoch: 11 [14848/60000 (25%)] Loss: 0.006420
Train Epoch: 11 [30208/60000 (50%)] Loss: 0.003641
Train Epoch: 11 [45568/60000 (75%)] Loss: 0.003402
 
Test set: Average loss: 0.0377, Accuracy: 9894/10000 (99%)
 
Train Epoch: 12 [14848/60000 (25%)] Loss: 0.006866
Train Epoch: 12 [30208/60000 (50%)] Loss: 0.012617
Train Epoch: 12 [45568/60000 (75%)] Loss: 0.008548
 
Test set: Average loss: 0.0311, Accuracy: 9908/10000 (99%)
 
Train Epoch: 13 [14848/60000 (25%)] Loss: 0.010539
Train Epoch: 13 [30208/60000 (50%)] Loss: 0.002952
Train Epoch: 13 [45568/60000 (75%)] Loss: 0.002313
 
Test set: Average loss: 0.0293, Accuracy: 9905/10000 (99%)
 
Train Epoch: 14 [14848/60000 (25%)] Loss: 0.002100
Train Epoch: 14 [30208/60000 (50%)] Loss: 0.000779
Train Epoch: 14 [45568/60000 (75%)] Loss: 0.005952
 
Test set: Average loss: 0.0335, Accuracy: 9897/10000 (99%)
 
Train Epoch: 15 [14848/60000 (25%)] Loss: 0.006053
Train Epoch: 15 [30208/60000 (50%)] Loss: 0.002559
Train Epoch: 15 [45568/60000 (75%)] Loss: 0.002555
 
Test set: Average loss: 0.0357, Accuracy: 9894/10000 (99%)
 
Train Epoch: 16 [14848/60000 (25%)] Loss: 0.000895
Train Epoch: 16 [30208/60000 (50%)] Loss: 0.004923
Train Epoch: 16 [45568/60000 (75%)] Loss: 0.002339
 
Test set: Average loss: 0.0400, Accuracy: 9893/10000 (99%)
 
Train Epoch: 17 [14848/60000 (25%)] Loss: 0.004136
Train Epoch: 17 [30208/60000 (50%)] Loss: 0.000927
Train Epoch: 17 [45568/60000 (75%)] Loss: 0.002084
 
Test set: Average loss: 0.0353, Accuracy: 9895/10000 (99%)
 
Train Epoch: 18 [14848/60000 (25%)] Loss: 0.004508
Train Epoch: 18 [30208/60000 (50%)] Loss: 0.001272
Train Epoch: 18 [45568/60000 (75%)] Loss: 0.000543
 
Test set: Average loss: 0.0380, Accuracy: 9894/10000 (99%)
 
Train Epoch: 19 [14848/60000 (25%)] Loss: 0.001699
Train Epoch: 19 [30208/60000 (50%)] Loss: 0.000661
Train Epoch: 19 [45568/60000 (75%)] Loss: 0.000275
 
Test set: Average loss: 0.0339, Accuracy: 9905/10000 (99%)
 
Train Epoch: 20 [14848/60000 (25%)] Loss: 0.000441
Train Epoch: 20 [30208/60000 (50%)] Loss: 0.000695
Train Epoch: 20 [45568/60000 (75%)] Loss: 0.000467
 
Test set: Average loss: 0.0396, Accuracy: 9894/10000 (99%)
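Because the log prints accuracy with {:.0f}, every epoch from 3 onward shows "(99%)", which hides the differences between epochs. The sketch below copies the 20 test-set results from the log above into a list and finds the epoch with the lowest test loss and the one with the highest accuracy. This is only an after-the-fact reading of a single run: training loss keeps falling toward zero while test loss bottoms out around epoch 10. That pattern is consistent with mild overfitting, so fewer epochs, or keeping the best checkpoint, may work as well as 20.

```python
# (test loss, correct out of 10,000) per epoch, copied from the log above
results = [
    (0.0992, 9690), (0.0599, 9816), (0.0416, 9859), (0.0361, 9873),
    (0.0344, 9875), (0.0350, 9879), (0.0287, 9892), (0.0299, 9904),
    (0.0341, 9890), (0.0286, 9901), (0.0377, 9894), (0.0311, 9908),
    (0.0293, 9905), (0.0335, 9897), (0.0357, 9894), (0.0400, 9893),
    (0.0353, 9895), (0.0380, 9894), (0.0339, 9905), (0.0396, 9894),
]
N_TEST = 10000

# epochs are numbered from 1 in the log, list indices from 0
best_loss_epoch = min(range(len(results)), key=lambda i: results[i][0]) + 1
best_acc_epoch = max(range(len(results)), key=lambda i: results[i][1]) + 1

for epoch, (loss, correct) in enumerate(results, start=1):
    print('Epoch {:2d}: loss {:.4f}, accuracy {:.2f}%'.format(
        epoch, loss, 100. * correct / N_TEST))
print('lowest test loss at epoch', best_loss_epoch)
print('highest accuracy at epoch', best_acc_epoch)
```

The final epoch (9894/10000) is not the best one. One option is to have test() return its accuracy and save model.state_dict() with torch.save whenever it improves.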

Summary

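The workflow of a real project: find a dataset, preprocess the data, define the model, set the hyperparameters, train and test, then use the results to adjust the hyperparameters or the model.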

Thank you for reading.

Source: https://blog.csdn.net/daidaiweng/article/details/105757653
