ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

GridSearchCV用于python中的多类SVM

2019-09-10 14:58:46  阅读:377  来源: 互联网

标签:grid-search python scikit-learn svm


我正在尝试学习如何为分类器找到最佳参数.所以,我使用GridSearchCV来解决多类分类问题.在Does not GridSearchCV support multi-class?上生成了一个虚拟代码我正在使用n_classes = 3的代码.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler,label_binarize
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import accuracy_score, recall_score, f1_score, roc_auc_score, make_scorer

X, y = make_classification(n_samples=3000, n_features=10, weights=[0.1, 0.9, 0.3],n_classes=3, n_clusters_per_class=1,n_informative=2)

pipe = make_pipeline(StandardScaler(), SVC(kernel='rbf', class_weight='auto'))

param_space = dict(svc__C=np.logspace(-5,0,5), svc__gamma=np.logspace(-2, 2, 10))

f1_score
my_scorer = make_scorer(f1_score, greater_is_better=True)

gscv = GridSearchCV(pipe, param_space, scoring=my_scorer)

我正在尝试按照这里建议的一热编码Scikit-learn GridSearch giving “ValueError: multiclass format is not supported” error.此外,有时会有像Kaggle上的Toxic Comment Classification数据集这样的数据集,它会为您提供二值化标签.

y = label_binarize(y, classes=[0, 1, 2])
for i in classes:    
gscv.fit(X, y[i])

print gscv.best_params_

我正进入(状态:

ValueError: bad input shape (2000L, 3L)

我不知道为什么我会收到这个错误.我的目标是找到多类分类问题的最佳参数.

解决方法:

代码的两个部分有两个问题.

1)当你没有对标签进行单热编码时,让我们从第一部分开始.你看,SVC支持多类案件就好了.但是当与(内部)GridSearchCV结合时,f1_score不会.

在二进制分类的情况下,f1_score默认返回正标签的分数,因此会在您的情况下抛出错误.

或者它也可以返回一个分数数组(每个类一个),但GridSearchCV只接受一个值作为分数,因为它需要找到最佳分数和超参数的最佳组合.因此,您需要在f1_score中传递平均方法以从数组中获取单个值.

根据f1_score documentation,允许采用平均方法:

average : string, [None, ‘binary’ (default), ‘micro’, ‘macro’,
‘samples’, ‘weighted’]

所以像这样改变你的make_scorer:

my_scorer = make_scorer(f1_score, greater_is_better=True, average='micro')

根据您的需要更改上面的“平均”参数.

2)现在进入第二部分:当您对标签进行单热编码时,y的形状变为2-d,但SVC仅支持1-d数组,如文档中指定的y:

06001

但即使您对标签进行编码并使用支持2-d标签的分类器,也必须解决第一个错误.所以我建议你不要对标签进行单热编码,只需更改f1_score即可.

标签:grid-search,python,scikit-learn,svm
来源: https://codeday.me/bug/20190910/1799375.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有