ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

利用pytorch构建alexnet网络对cifar-10进行分类

2021-04-05 14:33:43  阅读:335  来源: 互联网

标签:10 nn loss self epoch cifar label pytorch data


文章目录

(一)概述

(二)数据预处理

(三)构建网络

(四)选择优化器

(五)训练测试加保存模型

正文

(一)概述

1、CIFAR-10数据集包含10个类别的60000个32x32彩色图像,每个类别有6000张图像。有50000张训练图像和10000张测试图像。
2、数据集分为五个训练批次和一个测试批次,每个批次具有10000张图像。测试集包含从每个类别中1000张随机选择的图像。剩余的图像按照随机顺序构成5个批次的训练集,每个批次中各类图像的数量不相同,但总训练集中每一类都正好有5000张图片
3、数据集中的class(类),以及每个class的10个随机图像:

(二)数据预处理

1、引入包库

import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import torch.nn as nn
import torch.nn.functional as F
import os
import time

2、定义超参数

#定义超参数
batch_size=100
learning_rate=1e-2
epochs=200

3、标准化

data_tf=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

4、读取数据

train_data = datasets.CIFAR10(root='./data',train=True,transform=data_tf,download=False)
test_data = datasets.CIFAR10(root='./data',train=False,transform=data_tf)

5、装载数据

train_loader=DataLoader(train_data,batch_size=batch_size,shuffle=True)
test_loader=DataLoader(test_data,batch_size=batch_size,shuffle=True)

(三)构建网络

代码

class Alexnet(nn.Module):
    def __init__(self):
        super(Alexnet, self).__init__()
        self.conv1 = nn.Conv2d(3,64,3,2,1)
        self.pool = nn.MaxPool2d(3, 2)
        self.conv2 = nn.Conv2d(64,192, 5, 1, 2)
        self.conv3 = nn.Conv2d(192, 384, 3, 1, 1)
        self.conv4 = nn.Conv2d(384,256, 3, 1, 1)
        self.conv5 = nn.Conv2d(256,256, 3, 1, 1)
        self.drop = nn.Dropout(0.5)
        self.fc1 = nn.Linear(256*6*6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 1000)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(F.relu(self.conv5(x)))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.drop(F.relu(self.fc1(x)))
        x = self.drop(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

网络结构

(四)选择模型、优化器,定义loss

model=AlexNet()
#定义loss与参数更新
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=learning_rate)

(五)训练测试

#训练
for epoch in range(epochs):
    total = 0
    running_loss = 0.0
    running_correct = 0
    print("epoch {}/{}".format(epoch, epochs))
    print("-" * 10)
    for data in train_loader:
        img, label = data
     
        img = Variable(img)
        if torch.cuda.is_available():
            img = img.cuda()
            label = label.cuda()
        else:
            img = Variable(img)
            label = Variable(label)
        out = model(img)  # 得到前向传播的结果
        loss = criterion(out, label)  # 得到损失函数
        print_loss = loss.data.item()
        optimizer.zero_grad()  # 归0梯度
        loss.backward()  # 反向传播
        optimizer.step()  # 优化
        running_loss += loss.item()
        epoch += 1
        if epoch % 50 == 0:
            print('epoch:{},loss:{:.4f}'.format(epoch, loss.data.item()))
    _, predicted = torch.max(out.data, 1)
    total += label.size(0)
    running_correct += (predicted == label).sum()
    print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * running_correct / total)))

 

标签:10,nn,loss,self,epoch,cifar,label,pytorch,data
来源: https://blog.csdn.net/MosterSakura/article/details/115441316

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有