ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

手动实现前馈神经网络解决 多分类 任务

2022-03-06 01:32:50  阅读:232  来源: 互联网

标签:loss torch num sum 手动 前馈 神经网络 train test


1 导入实验需要的包

import torch
import numpy as np
import random
from IPython import  display
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader,TensorDataset
from torchvision import transforms,datasets
from torch import nn

2 加载数据集

mnist_train = datasets.MNIST(root = './Datasets/MNIST/',train = True,download = True,transform =transforms.ToTensor())
mnist_test = datasets.MNIST(root ='./Datasets/MNIST/',train = False,download = True,transform = transforms.ToTensor())

batch_size = 256
train_iter = DataLoader( 
    dataset = mnist_train,
    shuffle = True,
    batch_size = batch_size,
    num_workers = 0
)
test_iter = DataLoader(
    dataset  = mnist_test,
    shuffle  =False,
    batch_size = batch_size,
    num_workers = 0
)

3 初始化参数

num_input ,num_hiddens ,num_output = 784,256,10
W1 =  torch.tensor(np.random.normal(0,0.01,size = (num_hiddens,num_input)),dtype = torch.float32)
b1 = torch.zeros(1,dtype = torch.float32)

W2 =  torch.tensor(np.random.normal(0,0.01,size = (num_output,num_hiddens)),dtype = torch.float32)
b2 = torch.zeros(1,dtype = torch.float32)

params = [W1 ,b1,W2,b2]
for param in params:
    param.requires_grad_(requires_grad = True)

4 定义激活函数

def ReLU(X):
    return torch.max(X,other = torch.tensor(0.0))

5 定义网络模型

def net(x):
    x = x.view(-1,num_input)
    H1 = ReLU(torch.matmul(x,W1.t())+b1)
    H2 = torch.matmul(H1,W2.t()+b2)
    return H2

6 定义损失函数和优化算法

#定义多分类交叉熵损失函数  
loss = torch.nn.CrossEntropyLoss()  
def SGD(params,lr):
    for param in params:
        param.data -= param.grad/batch_size

7 定义评价函数

def evaluate_loss(data_iter,net):
        acc_sum,loss_sum,n= 0,0,0
        for x,y in data_iter:
            y_pred = net(x)
            l = loss(y_pred,y)
            loss_sum += l.item()
            acc_sum += (y_pred.argmax(dim =1)==y).sum().item()
            n += y.shape[0]
        return acc_sum/n,loss_sum/n
# def evaluate_loss():
#         n = mnist_test.data.shape[0]
#         x = torch.tensor(mnist_test.data,dtype = torch.float32)
#         y  = torch.tensor(mnist_test.targets,dtype = torch.float32)
#         y_pred = net(x)
#         acc_sum = (y_pred.argmax(dim = 1)==y).sum().item()
#         loss_sum = loss(y_pred,mnist_test.targets).item()
#         return acc_sum/n,loss_sum/n

8 定义训练函数

def train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr):
    train_ls ,test_ls = [],[]
    for epoch in range(num_epochs):
        train_l_sum, train_acc_num,n = 0.0,0.0,0
        for x ,y in train_iter:
            y_pred = net(x)
            l = loss(y_pred,y)
            if params is not None and params[0].grad is not None:
                for param in params:
                    param.grad.data.zero_()
            l.backward()
            SGD(params,lr)
            train_l_sum += l.item()
            train_acc_num += (y_pred.argmax(dim = 1)==y).sum().item()
            n +=y.shape[0]
        train_ls.append(train_l_sum/n)  
        test_acc,test_l = evaluate_loss(test_iter,net)  
        test_ls.append(test_l)
        print('epoch %d, train_loss %.6f,test_loss %f,train_acc %.6f,test_acc %.6f'%(epoch+1, train_ls[epoch],test_ls[epoch],train_acc_num/n,test_acc))  
    return train_ls,test_ls        

9 训练

lr = 0.01  
num_epochs = 50  
train_loss,test_loss = train(net,train_iter,test_iter,loss,num_epochs,batch_size,lr)   

10 可视化

x = np.linspace(0,len(train_loss),len(train_loss))  
plt.plot(x,train_loss,label="train_loss",linewidth=1.5)  
plt.plot(x,test_loss,label="test_loss",linewidth=1.5)  
plt.xlabel("epoch")  
plt.ylabel("loss")  
plt.legend()  
plt.show()  

标签:loss,torch,num,sum,手动,前馈,神经网络,train,test
来源: https://www.cnblogs.com/BlairGrowing/p/15970091.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有