ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

22.详解过拟合代码

2021-11-15 20:02:51  阅读:157  来源: 互联网

标签:plt dout 22 ofit torch 详解 拟合 test data


'''
Description: overfitting-review
Autor: 365JHWZGo
Date: 2021-11-15 18:41:20
LastEditors: 365JHWZGo
LastEditTime: 2021-11-15 19:59:11
'''
# 导包
import torch
import matplotlib.pyplot as plt

# hyper parameters
LR = 0.01           #Adam学习效率
N_HIDDENS = 300     #隐藏神经元的个数
N_POINTS = 20       #数据点的个数

# create some data for training
#创造一个从[-10,10]的十个均等的间隔,并给他们新加一个列维度
x = torch.unsqueeze(torch.linspace(-10, 10, N_POINTS), dim=1)
# torch.zeros(row,column)
# torch.normal(means【均值】,std【标准值】)
#使其于y=x^2较为拟合
#10*torch.normal(torch.zeros(N_POINTS, 1), torch.ones(N_POINTS, 1))每一个值都需要添加一个随机噪点
y = x**2+10*torch.normal(torch.zeros(N_POINTS, 1), torch.ones(N_POINTS, 1))

# create some data for testing
test_x = torch.unsqueeze(torch.linspace(-10, 10, N_POINTS), dim=1)
# torch.zeros(row,column)
# torch.normal(means【均值】,std【标准值】)
test_y = test_x**2+10 * \
    torch.normal(torch.zeros(N_POINTS, 1), torch.ones(N_POINTS, 1))

#画点
plt.scatter(x, y, c='r', lw=3, alpha=0.2, label='train data')
plt.scatter(test_x, test_y, c='b', lw=3, alpha=0.2, label='test data')
#展示画布
plt.show()

# create network
# overfitting
net_ofit = torch.nn.Sequential(
    #输入一个点的纵坐标,经过N_HIDDENS个神经元
    torch.nn.Linear(1, N_HIDDENS),
    #激活函数
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDENS, N_HIDDENS),
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDENS, 1)
)

# dropout
net_dout = torch.nn.Sequential(
    torch.nn.Linear(1, N_HIDDENS),
    #将神经元的个数随机丢掉50%
    torch.nn.Dropout(0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDENS, N_HIDDENS),
    torch.nn.Dropout(0.5),
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDENS, 1)
)

#创造优化器
opt_ofit = torch.optim.Adam(net_ofit.parameters(), lr=LR)
opt_dout = torch.optim.Adam(net_dout.parameters(), lr=LR)
#创建损失函数
loss_func = torch.nn.MSELoss()
#进入互动模式
plt.ion()

# training
if __name__ == '__main__':
    for i in range(1000):
        pred_ofit = net_ofit(x)
        pred_dout = net_dout(x)

        loss_ofit = loss_func(pred_ofit, y)
        loss_dout = loss_func(pred_dout, y)

        #优化
        opt_ofit.zero_grad()    #梯度清零
        opt_dout.zero_grad()
        loss_ofit.backward()    #回滚
        loss_dout.backward()
        opt_ofit.step()         #更新梯度参数
        opt_dout.step()

        if i % 50 == 0:
            #使其进入预测模式,此时不需要将神经元随机丢掉50%,因为神经网络已经训练过了,只需要使用就可以
            net_ofit.eval()
            net_dout.eval()

            test_pred_ofit = net_ofit(test_x)
            test_pred_dout = net_dout(test_x)

            plt.cla()
            plt.scatter(x.data.numpy(), y.data.numpy(), c='r', lw=3, alpha=0.2, label='train data')
            plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='b', lw=3, alpha=0.2, label='test data')
            plt.plot(test_x.data.numpy(),test_pred_ofit.data.numpy(),'r-',lw=3,label='overfitting line')
            plt.plot(test_x.data.numpy(),test_pred_dout.data.numpy(),'b--',lw=3,label='dropout line')
            plt.legend(loc='best')
            plt.ylim((-20,120))
            plt.text(-5,50,'overfitting_loss=%.2f'%loss_func(test_pred_ofit,test_y).data.numpy(),fontdict={'color':'orange','size':13})
            plt.text(-5,55,'dropout_loss=%.2f'%loss_func(test_pred_dout,test_y).data.numpy(),fontdict={'color':'pink','size':13})
            plt.pause(0.1)

            #进入训练模式
            net_ofit.train()
            net_dout.train()
    #停止互动模式
    plt.ioff()
    #展示图画
    plt.show()

在这里插入图片描述

过拟合

简介

当数据量一定时,机器为了将误差减到最小,从而使得模型不再符合实际中的真实样子,这种现象叫做过拟合

蓝线:我们希望计算机学习到的模型

红线:计算机为了减小误差学习到的模型

出现原因

  1. 数据量过于少
  2. 神经网络过于复杂

解决方法

  1. 增加数据量

  2. 运用正规化

    Y=Wx (W:机器学习的各种参数)

    1. l1正规化

    2. l2正规化

    3. Dropout[专门用于神经网络]

      在训练过程中随机忽略一些神经元,让神经网络变得不完整,使得每一次预测结果都不会太依赖其中的某些值

总结

在这个例子中,首先是先创造出一些数据,使其较拟合一个曲线函数,然后通过神经网络学习,没有使用dropout的神经网络虽然它能很好的拟合train data 但是对于test data的曲线拟合确显得误差很大,不能拟合更多的点,而随机dropout50%的神经元之后,使得神经网络的对于某些点的依赖减少,从而获得了很好的贴合度,总之,凡事都有个度,可不要贪杯呀!

标签:plt,dout,22,ofit,torch,详解,拟合,test,data
来源: https://blog.csdn.net/qq_44833392/article/details/121341395

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有