ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

python – 有没有办法在GridSearchCV中查看交叉验证的折叠?

2019-07-01 21:44:40  阅读:235  来源: 互联网

标签:python grid-search


我目前正在使用Python中的GridSearchCV进行3倍的cv来优化超参数.我只是想知道是否有任何方法可以在GridSearchCV中使用的cv中查看训练和测试数据的索引?

解决方法:

如果你不想在CV阶段折叠之前将样品洗牌,你可以.您可以将KFold(或另一个CV类)的实例传递给GridSearchCV构造函数,并像这样访问它的折叠:

import pandas as pd
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

params = {'penalty' : ['l1', 'l2'], 'C' : [1,2,3]}
grid = GridSearchCV(LogisticRegression(), params, cv=KFold(n_splits=3))

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8]])

for train, test in grid.cv.split(X):
    print('TRAIN: ', train, ' TEST: ', test)

打印:

TRAIN:  [2 3 4 5]  TEST:  [0 1]
TRAIN:  [0 1 4 5]  TEST:  [2 3]
TRAIN:  [0 1 2 3]  TEST:  [4 5]

对于非混洗的CV,折叠总是相同的,因此您可以确定这些是在网格搜索期间使用的折叠.

如果你想在折叠之前对样本进行混洗,那就更复杂了,因为每次调用cv.split()都会产生不同的分割.我可以想到两种方式:

>您可以为CV对象提供固定的random_state,例如KFold(n_splits = 3,shuffle = True,random_state = 42).
>在创建GridSearchCV对象之前,从KFold迭代器创建一个列表.

因此,对于第二种方法,请执行:

grid = GridSearchCV(LogisticRegression(), params, 
                    cv=list(KFold(n_splits=3, shuffle=True).split(X)))

除了迭代器之外,列表是固定对象,除非您手动操作它,否则它将在所有GridSearch迭代中保持相同的值.

标签:python,grid-search
来源: https://codeday.me/bug/20190701/1351216.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有