ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

scikit基础与机器学习入门(11) 欠拟合,过拟合和交叉验证

2021-09-19 20:34:27  阅读:188  来源: 互联网

标签:11 index 验证 scikit test 拟合 交叉 import iris


欠拟合和过拟合的定义

image

在机器学习问题中,经常会出现模型在训练数据上的得分很高,但是在新的数据上表现很差的情况,这称之为过拟合overfitting,又叫高方差high variance

而如果在训练数据上得分就很低,这称之为欠拟合underfitting,又叫高偏差high bias

留出法与验证集

为了解决过拟合问题,常见的方法将数据分为训练集和测试集,用训练集去训练模型的参数,用测试集去测试训练后模型的表现。有时对于一些具有超参数的模型(例如svm.SVC的参数C和kernel就属于超参数),还需要从训练集中划出一部分数据去验证超参数的有效性。

交叉验证法

在数据数量有限时,按留出法将数据分成3部分将会严重影响到模型训练的效果。为了有效利用有限的数据,可以采用交叉验证cross_validation方法。

交叉验证的基本思想是:以不同的方式多次将数据集划分成训练集和测试集,分别训练和测试,再综合最后的测试得分。每个数据在一些划分情况下属于训练集,在另外一些划分情况下属于测试集。
常用的交叉验证方法:K折(KFold),留一交叉验证(LeaveOneOut,LOO),留P交叉验证
(LeavePOut,LPO),重复K折交叉验证(RepeatedKFold),随机排列交叉验证(ShuffleSplit)。
此外,为了保证训练集中每种标签类别数据的分布和完整数据集中的分布一致,可以采用分层交叉验证方法(StratifiedKFold,StratifiedShuffleSplit)。

当数据集的来源有不同的分组时,独立同分布假设(independent identical distributed:i.i.d)将被打破,可以使用分组交叉验证方法来确保测试集合中的的所有样本来自训练样本中没有表示过的新的分组。(GroupKFold,LeaveOneGroupOut,LeavePGroupsOut,GroupShuffleSplit)

对于时间序列数据,一个非常重要的特点是时间相邻的观测之间的相关性(自相关性),因此用过去的数据训练而用未来的数据测试非常重要。TimeSeriesSplit可以实现这样的分割。

代码实现

注意·,这里生成的是数据集生成的下标

随机排列交叉验证

import numpy as np
from sklearn.model_selection import ShuffleSplit
X = np.arange(5)
ss = ShuffleSplit(n_splits=10, test_size=0.6,
random_state=0)
for train_index, test_index in ss.split(X):
	print("%s %s" % (train_index, test_index))

分层K折交叉验证

分三折,n_split在这里指生成数据的组数

# 分层K折交叉验证
from sklearn.model_selection import StratifiedKFold
X = np.ones(10)
y = [0, 0, 0, 0, 1, 1, 1, 1, 1,1]
skf = StratifiedKFold(n_splits=3,shuffle = False)
for train_index, test_index in skf.split(X, y):
	print("%s %s" % (train_index, test_index))

留P分组交叉验证

from sklearn.model_selection import LeavePGroupsOut
X = np.arange(6)
y = [1, 1, 1, 2, 2, 2]
groups = [1, 1, 2, 2, 3, 3]
lpgo = LeavePGroupsOut(n_groups=2)
for train_index, test_index in lpgo.split(X, y, groups=groups):
	print("%s %s" % (train_index, test_index))

对于分类问题,可以用这种方法,此时留两组作为验证集

image

实现时间序列的分割

# 时间序列分割
from sklearn.model_selection import TimeSeriesSplit
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4],[2, 2],[4, 6]])
y = np.array([1, 2, 3, 4, 5, 6, 7, 8])
tscv = TimeSeriesSplit(n_splits=3,max_train_size = 3)
for train_index, test_index in tscv.split(X,y):
	print("%s %s" % (train_index, test_index))

我们可以看出,这里的训练集和数据集是从前到后一致的

image

交叉验证综合评分

调用 cross_val_score 函数可以计算模型在各交叉验证数据集上的得分。可以指定metrics中的打分函数,也可以指定交叉验证迭代器。

from sklearn.model_selection import cross_val_score
from sklearn import svm
from sklearn import datasets
iris = datasets.load_iris()
clf = svm.SVC(kernel='linear', C=1)
scores = cross_val_score(clf, iris.data, iris.target, cv=5) #采用5折交叉验证
print(scores)
#平均得分和 95% 置信区间
print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))

默认情况下,每个 CV 迭代计算的分数是估计器的 score 方法。可以通过使用 scoring 参数来改变计算方式如下

from sklearn import metrics
scores = cross_val_score(
	clf, iris.data, iris.target, cv=5, scoring='f1_macro')
scores

可以通过传入一个自己定义的分类器

# 通过传入一个交叉验证迭代器来指定其他交叉验证策略
from sklearn.model_selection import ShuffleSplit
n_samples = iris.data.shape[0]
ss = ShuffleSplit(n_splits=10, test_size=0.3, random_state=0)
cross_val_score(clf, iris.data, iris.target, cv=ss)

cross_validate函数和cross_val_score函数类似,但功能更为强大,它允许指定多个指标进行评估,并且除返回指定的指标外,还会返回一个fit_time和score_time即训练时间和评分时间。

from sklearn.model_selection import cross_validate
from sklearn.metrics import recall_score
clf = svm.SVC(kernel='linear', C=1, random_state=0)
scores = cross_validate(clf, iris.data, iris.target, scoring=['f1_macro','f1_micro'],
cv=10, return_train_score=False)
print(sorted(scores.keys()))
#scores['test_recall_macro']
print(scores['fit_time']) #训练集拟合时间
print(scores['score_time'])#测试集评分时间
print('f1_micro:',scores['test_f1_micro'])
print('f1_macro:',scores['test_f1_macro'])

使用cross_val_predict可以返回每条样本作为CV中的测试集时,对应的模型对该样本的预测结果。

这就要求使用的CV策略能保证每一条样本都有机会作为测试数据,否则会报异常。

from sklearn.model_selection import cross_val_predict
predicted = cross_val_predict(clf, iris.data, iris.target, cv=10)
print(predicted)
metrics.accuracy_score(iris.target, predicted)

这里的方法仅用于评估超参数的优劣,至于用交叉验证法寻找超参数,下一章会提到

标签:11,index,验证,scikit,test,拟合,交叉,import,iris
来源: https://www.cnblogs.com/xiaoyunbowen/p/15312605.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有