ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pytorch实现mnist手写体数字识别(非常非常详细!!!!最新!!!!!!!!!)

2021-03-04 21:30:03  阅读:1342  来源: 互联网

标签:loss nn torch Pytorch train 手写体 model valid mnist


Pytorch实现mnist

读取Mnist数据集

from pathlib import Path   # python3中 取代os.path
import requests

DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"

PATH.mkdir(parents=True, exist_ok=True)

URL = "http://deeplearning.net/data/mnist/"
FILENAME = "mnist.pkl.gz"

if not (PATH / FILENAME).exists():    #mnist文件未下载时,requests去get URL下载文件
        content = requests.get(URL + FILENAME).content
        (PATH / FILENAME).open("wb").write(content)
import pickle
import gzip

with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
```![在这里插入图片描述](https://www.icode9.com/i/ll/?i=20210304210945960.png?,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NTA3NDU2OA==,size_16,color_FFFFFF,t_70#pic_center)
#### 注意数据需转换成tensor才能参与后续建模训练

```python
import torch

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)  # map(function,iter)   iter中的元素调用function
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())

torch.nn.functional 很多层和函数在这里都会见到

torch.nn.functional中有很多功能,后续会常用的。那什么时候使用nn.Module,什么时候使用nn.functional呢?一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些.

import torch.nn.functional as F

loss_func = F.cross_entropy

def model(xb):
    return xb.mm(weights) + bias
bs = 64
xb = x_train[0:bs]  # a mini-batch from x
yb = y_train[0:bs]
weights = torch.randn([784, 10], dtype = torch.float,  requires_grad = True) 
bs = 64
bias = torch.zeros(10, requires_grad=True)

print(loss_func(model(xb), yb))

创建一个model来更简化代码

(1)必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
(2)无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
(3)Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器

from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out  = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)
        return x
        

使用TensorDataset和DataLoader来简化

这里TensorDataset将x_train和y_train进行绑定,DataLoader类似python的生成器,在做数据增强的时候使用的。

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

train_ds = TensorDataset(x_train, y_train)  #类似python中zip   
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)   #一般与数据增强搭配使用,类似python中生成器

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np

def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()  # 在训练的时候调用,因为train的时候用的数据都是不同的
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()#  在预测的时候调用,因为val中 用的数据集是整体的
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))
from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

三行搞定测试!!!

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)  # 获取batch数据,跟python生成器类似
model, opt = get_model()# 获得model,以及优化器
fit(100, model, loss_func, opt, train_dl, valid_dl)# 将参数传入fit

标签:loss,nn,torch,Pytorch,train,手写体,model,valid,mnist
来源: https://blog.csdn.net/weixin_45074568/article/details/114377911

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有