ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

广义线性模型--1.1.普通最小二乘法

2020-04-21 16:02:49  阅读:249  来源: 互联网

标签:plt 20 1.1 fit test 线性 import 乘法 diabetes


1.最小二乘法数学表达式:

 

使经验函数风险最小化 = 损失函数(平方损失)

2.示例

1 from sklearn import linear_model
2 reg = linear_model.LinearRegression()
3 reg.fit ([[0, 0], [1, 1], [2, 2]], [0, 1, 2])
4 LinearRegression(copy_X=True, fit_intercept=True, n_jobs=1, normalize=False)
5 reg.coef_
6 array([ 0.5,  0.5])

模型参数 coef = w 相关系数 intercept = w0 截距 在fit_intercept = Flase的时候将会返回0

  copy_x  是否对x值进行复制保存

  fit_intercept 是否显示截距 即b的值
  n_jobs  运算时电脑几核运行
  normalize  在进行最小二乘法运算时x是否先标准化,当_fit_intercep_t为False时该参数默认取消

3.官方案例

 1 import matplotlib.pyplot as plt
 2 import numpy as np
 3 from sklearn import datasets, linear_model
 4 from sklearn.metrics import mean_squared_error, r2_score
 5 
 6 # Load the diabetes dataset
 7 diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)
 8 print(diabetes_y)
 9 # Use only one feature
10 diabetes_X = diabetes_X[:, np.newaxis, 2]
11 
12 
13 # Split the data into training/testing sets
14 diabetes_X_train = diabetes_X[:-20]
15 diabetes_X_test = diabetes_X[-20:]
16 
17 # Split the targets into training/testing sets
18 diabetes_y_train = diabetes_y[:-20]
19 diabetes_y_test = diabetes_y[-20:]
20 
21 # Create linear regression object
22 regr = linear_model.LinearRegression()
23 
24 # Train the model using the training sets
25 regr.fit(diabetes_X_train, diabetes_y_train)
26 
27 # Make predictions using the testing set
28 diabetes_y_pred = regr.predict(diabetes_X_test)
29 
30 # The coefficients
31 print('Coefficients: \n', regr.coef_)
32 # The mean squared error  #残差平方和
33 print('Mean squared error: %.2f'
34       % mean_squared_error(diabetes_y_test, diabetes_y_pred))
35 # The coefficient of determination: 1 is perfect prediction
36 print('Coefficient of determination: %.2f'
37       % r2_score(diabetes_y_test, diabetes_y_pred))
38 
39 # Plot outputs
40 plt.scatter(diabetes_X_test, diabetes_y_test,  color='black')
41 plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)
42 
43 plt.xticks(())
44 plt.yticks(())

 

标签:plt,20,1.1,fit,test,线性,import,乘法,diabetes
来源: https://www.cnblogs.com/zhengyinboke/p/12745223.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有