ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

123

2022-03-21 19:03:12  阅读:140  来源: 互联网

标签:real res mid 123 np data out


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def sigmoid(x):    # 定义网络激活函数
    return 1/(1+np.exp(-x))

data_tr = pd.read_csv('D:\\人工智能\\xunlian.txt')  # 训练集样本
data_te = pd.read_csv('D:\\人工智能\\ceshi.txt')  # 测试集样本
n = len(data_tr)
yita = 0.1  # 自己设置学习率

out_in = np.array([0.0, 0, 0, 0, -1])   # 输出层的输入,即隐层的输出
w_mid = np.zeros([3,4])  # 隐层神经元的权值&阈值
w_out = np.zeros([5])     # 输出层神经元的权值&阈值

delta_w_out = np.zeros([5])      # 输出层权值&阈值的修正量
delta_w_mid = np.zeros([3,4])   # 中间层权值&阈值的修正量
Err = []
'''
模型训练
'''
for j in range(800):
    error = []
    for it in range(n):
        net_in = np.array([data_tr.iloc[it, 0], data_tr.iloc[it, 1], -1])  # 网络输入
        real = data_tr.iloc[it, 2]
        for i in range(4):
            out_in[i] = sigmoid(sum(net_in * w_mid[:, i]))  # 从输入到隐层的传输过程
        res = sigmoid(sum(out_in * w_out))   # 模型预测值
        error.append(abs(real-res))#误差

        print('第',it, '个样本的模型输出:', res, 'real:', real)
        delta_w_out = yita*res*(1-res)*(real-res)*out_in  # 输出层权值的修正量
        delta_w_out[4] = -yita*res*(1-res)*(real-res)     # 输出层阈值的修正量
        w_out = w_out + delta_w_out   # 更新,加上修正量

        for i in range(4):
            delta_w_mid[:, i] = yita*out_in[i]*(1-out_in[i])*w_out[i]*res*(1-res)*(real-res)*net_in   # 中间层神经元的权值修正量
            delta_w_mid[2, i] = -yita*out_in[i]*(1-out_in[i])*w_out[i]*res*(1-res)*(real-res)         # 中间层神经元的阈值修正量,第2行是阈值
        w_mid = w_mid + delta_w_mid   # 更新,加上修正量
    Err.append(np.mean(error))
print(w_mid,w_out)
plt.plot(Err)#训练集上每一轮的平均误差
plt.show()
plt.close()

'''
将测试集样本放入训练好的网络中去
'''
for it in range(len(data_te)):
    net_in = np.array([data_te.iloc[it, 0], data_te.iloc[it, 1], -1])  # 网络输入
    for i in range(4):
        out_in[i] = sigmoid(sum(net_in * w_mid[:, i]))  # 从输入到隐层的传输过程
    res = sigmoid(sum(out_in * w_out))   # 模型预测值
    print('第',it+1,'个测试值:',res)

复制代码  

 

 

标签:real,res,mid,123,np,data,out
来源: https://www.cnblogs.com/jiujiuwawa/p/16036013.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有