ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pytorch实战:CIFAR-10分类

2020-11-08 12:00:44  阅读:281  来源: 互联网

标签:10 nn loss self labels CIFAR Pytorch print data


最近在学习Pytorch,先照着别人的代码过一遍,加油!!!

 

加载数据集

# 加载数据集及预处理
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch as t
show=ToPILImage() #可以将Tensor转成Image,方便可视化

划分数据集为训练集和测试集

#定义对数据的预处理
transform=transforms.Compose([
    transforms.ToTensor(),  #转为Tensor
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)), #归一化
])

#训练集
trainset=tv.datasets.CIFAR10(
    root='/home/cy/data',
    train=True,
    download=True,
    transform=transform
)

trainloader=t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2
)

testset=tv.datasets.CIFAR10(
    '/home/cy/data/',
    train=False,
    download=True,
    transform=transform
)

testloader=t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2
)

classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
Files already downloaded and verified
Files already downloaded and verified

可视化看下图片效果
(data, label)=trainset[100]
print(classes[label])

#(data+1)是为了还原被归一化的数据
show((data+1)/2).resize((100,100))

展示一个mini-batch中的图片

dataiter=iter(trainloader)
images,labels=dataiter.next() #返回4张图片及标签
print(' '.join('%11s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid((images+1)/2)).resize((400,100))

 

定义网络结构,挺方便的

## 定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(3,6,5)
        self.conv2=nn.Conv2d(6,16,5)
        self.fc1=nn.Linear(16*5*5,120)
        self.fc2=nn.Linear(120,84)
        self.fc3=nn.Linear(84,10)
        
        
    def forward(self,x):
        x=F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        x=F.max_pool2d(F.relu(self.conv2(x)),2)
        x=x.view(x.size()[0],-1)
        x=F.relu(self.fc1(x))
        x=F.relu(self.fc2(x))
        x=self.fc3(x)
        return x

net=Net()
print(net)
Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

定义损失函数和优化器
## 定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9) #随机梯度下降,stochastic gradient descent

开始训练网络

一共有三个步骤。输入数据,前向传播+反向传播,更新参数

from torch.autograd import Variable

for epoch in range(2):
    running_loss=0.0
    for i,data in enumerate(trainloader,0):
        #输入数据
        inputs,labels=data
        inputs,labels=Variable(inputs),Variable(labels)
        
        #梯度清零
        optimizer.zero_grad()
        
        #forward+backward
        outputs=net(inputs)
        loss=criterion(outputs,labels)
        loss.backward()
        
        #更新参数
        optimizer.step()
        
        #打印log信息
        #running_loss +=loss.data[0]
        running_loss +=loss.item()
        if i%2000 ==1999:   #每2000个batch打印一次训练状态
            print('[%d, %5d] loss: %.3f' \
                 %(epoch+1,i+1,running_loss / 2000))
            running_loss=0.0
print('Finished Training')

 

检查一下网络在一个batch内的效果如何

## 检验网络效果
dataiter=iter(testloader)
images,labels=dataiter.next() #一个batch返回4张图片
print('实际的label: ',' '.join(\
            '%08s'%classes[labels[j]] for j in range(4)))
show(tv.utils.make_grid(images/2 -0.5)).resize((400,100))

# 计算网络预测的label
outputs=net(Variable(images))
_,predicted=t.max(outputs.data,1)
print('预测结果: ',' '.join('%5s'\
        % classes[predicted[j]] for j in range(4)))

 

测试集上计算正确率

correct=0
total=0
for data in testloader:
    images,labels=data
    outputs=net(Variable(images))
    _,predicted=t.max(outputs.data,1)
    total +=labels.size(0)
    correct +=(predicted==labels).sum()
    
print('1000张测试集中的准确率为: %d  %%' %(100* correct/total))
1000张测试集中的准确率为: 52  %

 

可以看到,在CIFAR-10上的正确率为52%,网络训练还是有些效果的。

 

标签:10,nn,loss,self,labels,CIFAR,Pytorch,print,data
来源: https://www.cnblogs.com/keeptry/p/13943820.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有