ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

深度学习MNIST代码(2022)

2022-01-26 09:05:57  阅读:264  来源: 互联网

标签:loss 代码 transforms 2022 test model image self MNIST


这里只是一个走全程的代码,重在体验,如果想学习深度学习建议看

官方文档

大体步骤是:

1.先处理数据,分训练集和测试集

2.构建模型

3.优化模型参数

4.保存模型

5.加载模型,测试

训练代码

# -*- coding: utf-8 -*-
# day day study day day up
# create by a Man
import torch
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
from tqdm import tqdm
from torch import save

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
#1.数据加载
my_transforms=transforms.Compose(
    [transforms.ToTensor(),#将图片变成张量
     transforms.Normalize(mean=(0.1307,),std=(0.3081,)) #标准化处理
     ])

batch_size=64##batch_size每批要加载多少样本(默认值:1)
mnist_train = MNIST(root="../MNIST_data",
                    train=True, #训练集
                    download=True, #如果下了就设置为False
                    transform=my_transforms)
mnist_test=MNIST(root="../MNIST_test",
                 train=False,#测试集
                 download=True,#如果下了就设置为False
                 transform=my_transforms)
#2.构建模型
# 全连接层
class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.fc1 = nn.Linear(1*28*28, 100)  # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
        self.relu = nn.ReLU()
        self.fc2=nn.Linear(100,10)

    def forward(self, image):
        image_viwed = image.view(-1, 1*28*28)  # 此处需要拍平
        out = self.fc1(image_viwed)
        fc1_out = self.relu(out)
        out2=self.fc2(fc1_out)
        return out2
#3.优化模型参数
def train(train_dataloader, model, loss_function, optimizer):
    '''
    训练
    :param train_dataloader:
    :param model:
    :param loss_function:
    :param optimizer:
    :return:
    '''
    model.train()#必写,需要训练一次,不然报错
    for (images, labels) in tqdm(train_dataloader,total=len(train_dataloader)):
        images, labels = images.to(device), labels.to(device)
        #梯度置零
        optimizer.zero_grad()
        #前向传播
        output=model(images)
        #通过结果计算损失
        loss=loss_function(output,labels)
        #反向传播
        loss.backward()
        #单次优化,优化器更新
        optimizer.step()


def test(test_dataloader, model, loss_function):
    '''
    测试
    :param test_dataloader:
    :param model:
    :param loss_function:
    :return:
    '''
    model.eval()
    size = len(test_dataloader.dataset)
    num_batches = len(test_dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)  #
            test_loss += loss_function(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


model=MnistModel().to(device)#实例化模型并到gpu上运算
optimizer = optim.Adam(model.parameters())#优化器选择
loss_function= nn.CrossEntropyLoss()#选择交叉熵损失
train_dataloader = DataLoader(mnist_train, batch_size=batch_size,shuffle=True)#shuffle ( bool , optional ) – 设置为True在每个 epoch 重新洗牌数据(默认值:)False。
test_dataloader=DataLoader(mnist_test,batch_size=batch_size,shuffle=True)
epochs=10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_function, optimizer)
    test(test_dataloader, model, loss_function)
print("Done!")


#4.保存
save(model.state_dict(),"minist.pkl")#存模型
save(optimizer.state_dict(),"optimizer.pkl")#存优化器

拿自己图片测试的代码

from torchvision import transforms
from torch import nn
import torch
from PIL import Image

class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()#父类初始化命令行
        self.fc1 = nn.Linear(1*28*28, 100)  # 28是像素,最终为什么是 10,因为手写数字识别最终是10(0,1,2,3,4,5,6,7,8,9)分类的,分类任务中有多少,就分几类
        self.relu = nn.ReLU()#激活函数
        self.fc2 = nn.Linear(100, 10)#线性层

    def forward(self, image):
        image_viwed = image.view(-1, 1*28*28)  #重点: 此处需要拍平
        out_1 = self.fc1(image_viwed)
        fc1 = self.relu(out_1)#激活函数
        out_2 = self.fc2(fc1)
        return out_2


model = MnistModel()
model.load_state_dict(torch.load("D:\pych\pytorch_test\minist.pkl"))#路径尽量写绝对路径吧
image = Image.open(r'D:\Desktop\5.jpg')#需要测试的图片路径
# print(image)<PIL.JpegImagePlugin.JpegImageFile image mode=RGB(三通道) size=224x205 at 0x19583197430>
my_transforms = transforms.Compose(
    [
        transforms.Grayscale(1),#通道变为1,因为图片之前是RGB三通道的
        transforms.ToTensor(),#变张量
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))#z-score 标准化,参数type为元组
    ]
)
image = my_transforms(image)
with torch.no_grad():#禁止梯度计算,因为测试效果不需要
    pred = model(image)
    result=pred.max(dim=1).indices
    print(result)#

标签:loss,代码,transforms,2022,test,model,image,self,MNIST
来源: https://blog.csdn.net/qq_53593099/article/details/122693833

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有