ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pytorch卷积神经网络对MNIST数据集的手写数字识别

2022-07-16 13:33:38  阅读:150  来源: 互联网

标签:__ labels Tensor 卷积 torch Pytorch images import MNIST


这个程序由两个文件组成,一个训练脚本,一个测试脚本。安装好相应依赖环境之后即可进行训练,MNIST数据集使用torchvision.datasets.mnist包自动下载。

mnistTrain.py

# -*- coding: utf-8 -*-
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
from tqdm import tqdm


EPOCHS = 25                     # 训练轮数
BATCH_SIZE = 64                 # 每组数据多少张图片
DATA_FOLDER = 'dataset'         # 数据集保存目录
MODEL_FILE = 'MNIST_CNN.pkl'    # 模型文件路径
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class CNN(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2)
        )
        self.fc = torch.nn.Linear(14 * 14 * 32, 10)

    def forward(self, feature: torch.Tensor) -> torch.Tensor:
        out: torch.Tensor = self.conv(feature)
        out = out.flatten(1)
        out = self.fc(out)
        return out


if __name__ == '__main__':
    torch.set_num_threads(cpu_count())

    trainData = MNIST(DATA_FOLDER, train=True, transform=ToTensor(), download=True)
    testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    trainLoader = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True)
    testLoader = DataLoader(testData, batch_size=128, shuffle=True)

    cnn = CNN().to(DEVICE)
    lossFunc = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=0.005)

    bestAccuracy = 0
    for epoch in range(EPOCHS):
        # Train
        for images, labels in tqdm(trainLoader, desc=f'Epoch {epoch + 1}/{EPOCHS}'):
            images: torch.Tensor = images.to(DEVICE)
            labels: torch.Tensor = labels.to(DEVICE)
            predictions: torch.Tensor = cnn(images)

            loss: torch.Tensor = lossFunc(predictions, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        accuracy = 0
        for images, labels in testLoader:
            images: torch.Tensor = images.to(DEVICE)
            labels: torch.Tensor = labels.to(DEVICE)
            predictions: torch.Tensor = cnn(images)
            pred: torch.Tensor = predictions.max(dim=1)[1]
            accuracy += (pred == labels).sum().item()

        accuracy /= len(testData.targets)

        if bestAccuracy < accuracy:
            bestAccuracy = accuracy
            torch.save(cnn, MODEL_FILE)

        print(f'Accuracy: {accuracy * 100}%    Best Accuracy: {bestAccuracy * 100}%')

mnistTest.py

# -*- coding: utf-8 -*-

from mnistTrain import CNN, BATCH_SIZE, DATA_FOLDER, DEVICE, MODEL_FILE
import torch
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from tqdm import tqdm

if __name__ == '__main__':
    testData = MNIST(DATA_FOLDER, train=False, transform=ToTensor(), download=True)
    testLoader = DataLoader(testData, batch_size=BATCH_SIZE, shuffle=True)
    cnn: CNN = torch.load(MODEL_FILE).to(DEVICE)

    accuracy = 0
    for images, labels in tqdm(testLoader):
        images: torch.Tensor = images.to(DEVICE)
        labels: torch.Tensor = labels.to(DEVICE)
        predictions: torch.Tensor = cnn.forward(images)
        pred: torch.Tensor = predictions.max(dim=1)[1]
        accuracy += (pred == labels).sum().item()

    accuracy /= len(testData.targets)
    print(f'Accuracy: {accuracy * 100}%')

标签:__,labels,Tensor,卷积,torch,Pytorch,images,import,MNIST
来源: https://www.cnblogs.com/fang-d/p/Pytorch_MNIST_CNN.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有