ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

支持向量回归机(SVR)代码

2021-12-13 14:32:14  阅读:221  来源: 互联网

标签:plt predict data 代码 print train test SVR 向量


SVR的代码(python

项目中一个早期版本的代码,PCA-SVR,参数寻优采用传统的GridsearchCV。

  1 from sklearn.decomposition import PCA
  2 from sklearn.svm import SVR
  3 from sklearn.model_selection import train_test_split
  4 from sklearn.model_selection import GridSearchCV
  5 from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
  6 from sklearn.preprocessing import StandardScaler, MinMaxScaler
  7 from numpy import *
  8 import numpy as np
  9 import matplotlib.pyplot as plt
 10 import xlrd
 11 from svmutil import *
 12 import pandas as pd
 13 
 14 '''前言'''
 15 # pca - svr
 16 # CG测试
 17 
 18 '''预设参数'''
 19 fname = "all01.xlsx"  # 训练数据文件读取 26hao
 20 random_1 = 34  # 样本集选取随机种子
 21 random_2 = 4  # 训练集选取随机种子
 22 newpca = 6  # 降维
 23 yuzhi = 50  # 异常点阈值
 24 rate_1 = 0.8  # 样本集验证集
 25 rate_2 = 0.8  # 训练集测试集
 26 bestc = 384  # c
 27 bestg = 9  # gamma
 28 
 29 '''数据读取'''
 30 # xlrd生成对excel表进行操作的对象
 31 ...
 32 
 33 # 输入输出分割
 34 data_x = data[:, 1:11]
 35 data_y = data[:, 0:1]
 36 
 37 '''PCA'''
 38 pca = PCA(n_components=newpca)  # 加载PCA算法,设置降维后主成分数目为
 39 data_x = pca.fit_transform(data_x)  # 对样本进行降维
 40 print(pca.components_)  # 输出主成分,即行数为降维后的维数,列数为原始特征向量转换为新特征的系数
 41 print(pca.explained_variance_ratio_)  # 新特征 每维所能解释的方差大小在全方差中所占比例
 42 
 43 '''数据划分'''
 44 # 样本数据分割
 45 train_data_x, predict_data_x, train_data_y, predict_data_y = train_test_split(data_x, data_y, test_size=rate_1,
 46                                                                               random_state=random_1)
 47 
 48 # 训练数据分割
 49 train_x, test_x, train_y, test_y = train_test_split(train_data_x, train_data_y, test_size=rate_2, random_state=random_2)
 50 predict_x = predict_data_x
 51 predict_y = predict_data_y
 52 
 53 # reshape y
 54 test_y = np.reshape(test_y, -1)
 55 train_y = np.reshape(train_y, -1)
 56 predict_y = np.reshape(predict_y, (-1, 1))
 57 
 58 # StandardScaler x
 59 ss_X = StandardScaler()
 60 ss_X.fit(train_data_x)  # 20%
 61 train_x = ss_X.transform(train_x)
 62 test_x = ss_X.transform(test_x)
 63 predict_x = ss_X.transform(predict_x)
 64 
 65 '''参数优化与SVR'''
 66 # 网格搜索交叉验证(GridSearchCV):以穷举的方式遍历所有可能的参数组合
 67 # 测试用
 68 # param_grid = {'gamma': [bestg], 'C': [bestc]}
 69 # rbf_svr_cg = GridSearchCV(SVR(kernel='rbf'), param_grid, cv=5)
 70 # rbf_svr_cg.fit(train_x,train_y)
 71 # bestc = rbf_svr_cg.best_params_.get('C')
 72 # bestg = rbf_svr_cg.best_params_.get('gamma')
 73 
 74 # 最优参数
 75 print(bestc, bestg)
 76 param_grid = {'gamma': [bestg], 'C': [bestc]}
 77 rbf_svr = SVR(kernel='rbf',param_grid) # 需要修改
 78 
 79 # 训练
 80 rbf_svr.fit(train_x, train_y)
 81 
 82 # 预测
 83 test_y_predict = rbf_svr.predict(test_x)
 84 test_y_predict = np.reshape(test_y_predict, (-1, 1))
 85 predict_y_predict = rbf_svr.predict(predict_x)
 86 predict_y_predict = np.reshape(predict_y_predict, (-1, 1))
 87 
 88 '''去异常点'''
 89 print('样本集:', len(train_data_y))
 90 print('验证集:', len(predict_data_y))
 91 size = len(test_y_predict)
 92 count = 0
 93 for i in range(size):
 94     if abs(test_y_predict[size - i - 1] - test_y[size - i - 1]) > yuzhi:
 95         test_y_predict = np.delete(test_y_predict, size - i - 1)
 96         test_y = np.delete(test_y, size - i - 1)
 97         count = count + 1
 98 print('测试集异常点', count)
 99 size = len(predict_y_predict)
100 count = 0
101 for i in range(size):
102     if abs(predict_y_predict[size - i - 1] - predict_y[size - i - 1]) > yuzhi:
103         predict_y_predict = np.delete(predict_y_predict, size - i - 1)
104         predict_y = np.delete(predict_y, size - i - 1)
105         count = count + 1
106 print('验证集异常点', count)
107 
108 '''评估'''
109 # # 使用r2__score模块,并输出评估结果,拟合程度,R2决定系数,衡量模型预测能力好坏(真实与预测的相关程度百分比)
110 # print('The value of R-squared of kernal=rbf is',r2_score(test_y,test_y_predict))
111 # # 使用mean_squared_error模块,输出评估结果,均方误差
112 # print('The mean squared error of kernal=rbf is',mean_squared_error(test_y,test_y_predict))
113 # # 使用mean_absolute_error模块,输出评估结果,平均绝对误差
114 # print('The mean absolute error of kernal=rbf is',mean_absolute_error(test_y,test_y_predict))
115 
116 # 使用r2__score模块,并输出评估结果,拟合程度,R2决定系数,衡量模型预测能力好坏(真实与预测的相关程度百分比)
117 print('The value of R-squared of kernal=rbf is', r2_score(predict_y, predict_y_predict))
118 # 使用mean_squared_error模块,输出评估结果,均方误差
119 print('The mean squared error of kernal=rbf is', mean_squared_error(predict_y, predict_y_predict))
120 # 使用mean_absolute_error模块,输出评估结果,平均绝对误差
121 print('The mean absolute error of kernal=rbf is', mean_absolute_error(predict_y, predict_y_predict))
122 # r
123 X1 = pd.Series(np.reshape(predict_y,-1))
124 Y1 = pd.Series(np.reshape(predict_y_predict,-1))
125 print('The r is', X1.corr(Y1, method="pearson"))
126 print('The r is', sqrt(r2_score(predict_y, predict_y_predict)))
127 
128 '''作图'''
129 # PRN
130 print('PRN:', fname)
131 
132 # PCA
133 print()
134 
135 # 残差
136 diff_predict = predict_y_predict - predict_y
137 plt.plot(diff_predict, color='black', label='error')
138 plt.xlabel("no.")
139 plt.ylabel("error(m)")
140 plt.title('xxx')
141 plt.grid()
142 plt.legend()
143 plt.show()
144 
145 # 真实/模型_1
146 plt.plot(predict_y, color='g', label='dtu15mss')
147 plt.plot(predict_y_predict, color='b', label='pre')
148 plt.xlabel("xxx")
149 plt.ylabel("error(m)")
150 plt.title('xxx')
151 plt.grid()
152 plt.legend()
153 plt.show()
154 
155 # 真实/模型_2
156 fig = plt.figure(3)
157 ax1 = fig.add_subplot(2, 1, 1)
158 ax1.plot(predict_y, color='g', label='dtu15mss')
159 ax2 = fig.add_subplot(2, 1, 2)
160 ax2.plot(predict_y_predict, color='b', label='pre')
161 plt.show()
162 
163 # 真实/模型_3
164 p_x = [x for x in range(int(min(predict_y)) - 5, int(max(predict_y)) + 5)]
165 p_y = p_x
166 plt.plot(p_x, p_y, color='black', label='1')
167 plt.scatter(predict_y_predict, predict_y, s=10, color='b', marker='x',
168             label='0')  # https://www.cnblogs.com/shanlizi/p/6850318.html
169 plt.xlabel('PRE')
170 plt.ylabel('DTU')
171 plt.show()

 

标签:plt,predict,data,代码,print,train,test,SVR,向量
来源: https://www.cnblogs.com/ltkekeli1229/p/15683225.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有