ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

【实验】鸢尾花分类——简单的神经网络

2022-02-21 13:02:33  阅读:279  来源: 互联网

标签:loss torch nn 分类 神经网络 train test 鸢尾花 net


import torch
from torch import nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt

X = torch.tensor(load_iris().data, dtype=torch.float32)  
y = torch.tensor(load_iris().target, dtype=torch.long)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

导入鸢尾花数据集,这里注意数据和标签类型的设置:dtype=torch.float32,dtype=torch.long,否则会报错

net = nn.Sequential(nn.Linear(4, 10), 
                    nn.ReLU(),
                    nn.Linear(10, 10),
                    nn.ReLU(),
                    nn.Linear(10, 3))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weights, std=0.01)

loss = nn.CrossEntropyLoss(reduction="none")      

trainer = torch.optim.Adam(net.parameters(), lr=0.05)

train_loss = []
test_loss = []
train_l = sum(loss(net(X_train), y_train)).detach().numpy()
test_l = sum(loss(net(X_test), y_test)).detach().numpy()
train_loss.append(train_l)
test_loss.append(test_l)

epochs = 1000

for i in range(epochs):
    trainer.zero_grad()
    l = sum(loss(net(X_train), y_train))
    l.backward()
    trainer.step()
    l = sum(loss(net(X), y))
    
    train_l = sum(loss(net(X_train), y_train)).detach().numpy()
    test_l = sum(loss(net(X_test), y_test)).detach().numpy()
    train_loss.append(train_l)
    test_loss.append(test_l)
epoch_index = range(epochs + 1)
plt.plot(epoch_index, train_loss, 'green', epoch_index, test_loss, 'blue')    
plt.show()

使用交叉熵损失函数时, 定义神经网络架构的时候不需要用Softmax !     (我一开始在神经网络最后一层加了nn.Softmax有报错)

关于交叉熵损失函数,nn.CrossEntropyLoss(),有一些需要注意的点

贴篇网上介绍的博客,后面看自己有没有时间总结下。https://blog.csdn.net/geter_CS/article/details/84857220

 

有些场合(例如用matplotlib绘图)需要用numpy的数组,使用能求梯度的tensor是会报错的!

这里用.detach().numpy()来完成,例子可以见上面的代码

 

实验结果:

 

标签:loss,torch,nn,分类,神经网络,train,test,鸢尾花,net
来源: https://www.cnblogs.com/kyfishing/p/15918401.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有