ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

sklearn代码23 6-线性回归岭回归 套索回归比较

2021-11-08 10:33:36  阅读:162  来源: 互联网

标签:plt ridge linear 23 回归 axes 1.00000000 sklearn lasso


# LinearRegression,Ridge,Lasso

import numpy as np

from sklearn.linear_model import LinearRegression,Ridge,Lasso,RidgeCV,LassoCV

import matplotlib.pyplot as plt
%matplotlib inline
# 50个样本 200个特征
# 无解,无数个解
X = np.random.randn(50,200)
w = np.random.randn(200)
w
array([-0.71763223, -0.01975597, -1.66512775,  1.15509566, -1.30815193,
       -0.07886716, -0.12621629, -0.48452705, -0.76894705, -1.27958424,
        0.20661147,  0.07626266, -1.05664013,  1.43455568,  2.22725443,
        0.13220785,  1.01291249, -1.87467501, -0.80073911, -0.86567154,
       -1.24317069,  1.0130023 ,  0.33956167,  0.75203886,  0.7022749 ,
       -1.14882555,  0.634176  , -0.13809194, -2.03394849, -0.5516863 ,
        1.1398463 , -0.51857542, -0.88925621, -0.27183436,  1.56244012,
        0.66154914,  0.08529891, -0.18766498, -0.7229419 ,  0.6913235 ,
        1.66743931,  1.40285862,  0.50516722,  0.69088917, -0.13801636,
        0.82850681,  0.62598677,  0.5211237 ,  0.57181996,  0.91503186,
       -1.14734987,  0.46803846,  0.49677025, -1.004296  ,  1.3109282 ,
       -1.91016754,  1.45630189,  0.08982377, -0.51922071,  0.46723805,
        0.01055369, -0.48847605, -0.68935962,  1.6901229 , -0.23703418,
       -0.64618434, -0.93594604,  1.48674155,  0.79347216, -2.45278997,
       -0.1055643 , -1.10797711,  0.18312005, -1.63356662,  1.97703239,
       -0.16488839, -0.64318795, -1.14363873,  0.66084745,  1.14099327,
        0.9259731 ,  0.04103045, -0.55955006,  0.52709757, -1.28036461,
        0.74475445, -1.07053689,  0.20305404,  1.39808953,  0.31716686,
        0.63150615, -1.00307068, -0.95333729, -0.69220477, -0.03925317,
       -1.19738869, -0.01158072, -0.40013061, -2.18699458, -0.18176726,
        1.16341707, -0.91878923,  1.03465085, -1.81036414, -0.8452893 ,
        0.88631047, -0.07775361, -1.09726693,  1.18568627,  2.97868689,
        0.15734896,  0.35259873, -1.18522538, -0.20386231,  1.06447013,
       -1.50989228, -1.18713503, -2.72484655,  1.82771012, -0.56030818,
       -1.26393399,  0.09519989,  0.75043212,  0.1845392 ,  0.57406391,
        0.32241044, -0.92922765, -1.81582008, -0.17089422, -0.82638478,
        0.85685134,  0.33737166, -1.27335904, -0.12061047,  0.43116238,
        0.69293522, -0.3116372 ,  2.10697826,  0.22059706, -2.04990896,
        1.20031869, -0.65923924,  0.21741321,  2.69016452,  1.79752197,
        0.07034715, -0.74076325,  1.17818112,  1.27788198,  1.62346993,
       -0.61267043, -0.50887636,  0.0502629 , -0.63902576,  1.78457654,
        0.36369644, -1.59256726,  0.25070796,  0.02888558,  0.27984078,
        0.79969306,  0.81636017,  0.09265504, -1.29414286, -0.41225244,
       -0.90373965, -0.78816351, -0.81712267,  0.35288045,  0.46918612,
        0.43737485,  1.29382308,  0.07552618, -2.20126151, -1.16065126,
       -1.34731012, -0.61742243,  1.55812234, -0.5166435 ,  0.79979653,
       -1.12110086,  0.33189134,  0.63867508,  0.77258482, -0.78317738,
       -0.00803317, -0.28364481, -1.92934529, -1.33925024,  0.01404032,
       -1.69134308,  0.81528267,  0.26279143,  0.0321547 , -0.03219929,
       -0.00751312, -0.08025871, -1.24736631,  0.52277507,  0.30633436])
# 将其中的190个置为0
index = np.arange(0,199)

np.random.shuffle(index)
index
array([102, 166, 151,  76,  68, 140,  60,  99,  92,  55, 125,  94, 132,
        33, 177,  23, 111,  42, 110,  81, 179, 129, 192, 106,  34, 152,
        29, 139, 171, 109,  52, 142, 173, 180, 131,  73,  98,  36,  26,
        87,  28, 190, 183,  69,  88, 141, 178, 119,  48,  78, 169, 137,
        25, 104, 120, 130,   5, 147, 146, 148, 107, 187,  79,  32,   4,
       181,  77, 156, 112,  50,  59, 149, 172,  46, 114,  12, 134,  93,
        31,  71, 138,  19,  66,  70,  10,   3,  96, 195,  61,  14,  49,
        16, 182,  27, 193,   9, 196, 136,  15, 150,  75, 157, 116, 155,
       133, 165, 145, 160, 162,  97, 121,  56,  84, 122,  82, 108, 194,
       115, 124,  63,   7,  62,  86,   0,  30, 175,  58,  64,  18,  91,
         6,  89,  53,   8, 189,  17, 143,  90,  51, 154, 159, 100, 164,
       128, 174, 170,   1, 191,  39,  37, 117, 184, 105,  80, 144, 176,
       101, 168, 186,  44,  67, 153,  41,  54,  11, 163,  47,  65,  22,
       197,  24,  21, 167,  95, 161, 126,  74,  35,   2,  83, 127, 188,
       198,  13, 135, 103, 118, 123,  38, 185,  85,  72,  45,  20,  40,
        57,  43, 113, 158])
w[index[:190]] = 0
w
array([ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
       -1.24317069,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        1.66743931,  0.        ,  0.        ,  0.69088917,  0.        ,
        0.82850681,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.08982377,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.18312005,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.74475445,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        , -0.20386231,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.02888558,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.30633436])
y = X.dot(w)
y
array([-3.87237903,  2.20010955, -0.2256118 , -0.37557475, -4.14212192,
        1.04985011, -2.10190133, -3.26354501, -0.97369551,  6.60477669,
       -1.41760307,  3.99964396, -0.80688322, -1.80704093, -0.22938875,
       -0.17397568, -1.73476906,  2.66256959, -1.1235785 , -1.79337405,
       -0.68295749,  0.84934801,  3.8256913 , -1.53396707,  2.78321804,
        4.81322558,  5.66743923, -3.47945016, -0.5536171 ,  3.44103892,
        4.39879306, -2.25096992,  4.57986203, -1.27052193,  1.37985431,
       -4.79615756, -0.1770027 ,  0.12112281, -3.08041981, -0.69028005,
       -4.23879438, -0.16577193, -2.06806875,  4.00844061,  0.05162934,
       -4.99438005, -1.92438892, -5.47358088, -0.34889523, -1.70204475])
import warnings
warnings.filterwarnings('ignore')   #针对粉红色的提示,在导入此包后不会再有
linear = LinearRegression()

ridge = RidgeCV(alphas=[0.001,0.01,0.1,1,2,5,10],cv = 5)

lasso = LassoCV(alphas=[0.001,0.01,0.1,1,2,5,10],cv = 3)
linear.fit(X,y)

ridge.fit(X,y)

lasso.fit(X,y)
LassoCV(alphas=[0.001, 0.01, 0.1, 1, 2, 5, 10], copy_X=True, cv=3, eps=0.001,
        fit_intercept=True, max_iter=1000, n_alphas=100, n_jobs=None,
        normalize=False, positive=False, precompute='auto', random_state=None,
        selection='cyclic', tol=0.0001, verbose=False)
a = lasso.alphas_
a
array([  1.00000000e+01,   5.00000000e+00,   2.00000000e+00,
         1.00000000e+00,   1.00000000e-01,   1.00000000e-02,
         1.00000000e-03])
linear_w = linear.coef_

ridge_w = ridge.coef_

lasso_w = lasso.coef_

plt.figure(figsize=(12,9))
axes = plt.subplot(2,2,1)
axes.plot(w)

axes = plt.subplot(2,2,2)
axes.plot(linear_w)
axes.set_title('linear')

axes = plt.subplot(2,2,3)
axes.plot(ridge_w)
axes.set_title('ridge')

axes = plt.subplot(2,2,4)
axes.plot(lasso_w)
axes.set_title('lasso')
Text(0.5,1,'lasso')

请添加图片描述

# 套索回归和标准回归最像

标签:plt,ridge,linear,23,回归,axes,1.00000000,sklearn,lasso
来源: https://blog.csdn.net/weixin_44632711/article/details/121202845

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有