ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

python – 使用交叉验证评估Logistic回归

2019-05-19 10:56:51  阅读:1116  来源: 互联网

标签:python scikit-learn cross-validation logistic-regression


我想使用交叉验证来测试/训练我的数据集,并评估逻辑回归模型在整个数据集上的性能,而不仅仅是在测试集上(例如25%).

这些概念对我来说是全新的,我不确定它是否做得对.如果有人能告诉我正确的步骤,我会在错误的地方采取行动,我将不胜感激.我的部分代码如下所示.

另外,如何在当前图形的同一图形上绘制“y2”和“y3”的ROC?

谢谢

import pandas as pd 
Data=pd.read_csv ('C:\\Dataset.csv',index_col='SNo')
feature_cols=['A','B','C','D','E']
X=Data[feature_cols]

Y=Data['Status'] 
Y1=Data['Status1']  # predictions from elsewhere
Y2=Data['Status2'] # predictions from elsewhere

from sklearn.linear_model import LogisticRegression
logreg=LogisticRegression()
logreg.fit(X_train,y_train)

from sklearn.cross_validation import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

from sklearn import metrics, cross_validation
predicted = cross_validation.cross_val_predict(logreg, X, y, cv=10)
metrics.accuracy_score(y, predicted) 

from sklearn.cross_validation import cross_val_score
accuracy = cross_val_score(logreg, X, y, cv=10,scoring='accuracy')
print (accuracy)
print (cross_val_score(logreg, X, y, cv=10,scoring='accuracy').mean())

from nltk import ConfusionMatrix 
print (ConfusionMatrix(list(y), list(predicted)))
#print (ConfusionMatrix(list(y), list(yexpert)))

# sensitivity:
print (metrics.recall_score(y, predicted) )

import matplotlib.pyplot as plt 
probs = logreg.predict_proba(X)[:, 1] 
plt.hist(probs) 
plt.show()

# use 0.5 cutoff for predicting 'default' 
import numpy as np 
preds = np.where(probs > 0.5, 1, 0) 
print (ConfusionMatrix(list(y), list(preds)))

# check accuracy, sensitivity, specificity 
print (metrics.accuracy_score(y, predicted)) 

#ROC CURVES and AUC 
# plot ROC curve 
fpr, tpr, thresholds = metrics.roc_curve(y, probs) 
plt.plot(fpr, tpr) 
plt.xlim([0.0, 1.0]) 
plt.ylim([0.0, 1.0]) 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate)') 
plt.show()

# calculate AUC 
print (metrics.roc_auc_score(y, probs))

# use AUC as evaluation metric for cross-validation 
from sklearn.cross_validation import cross_val_score 
logreg = LogisticRegression() 
cross_val_score(logreg, X, y, cv=10, scoring='roc_auc').mean() 

解决方法:

你得到它几乎是正确的. cross_validation.cross_val_predict为您提供整个数据集的预测.您只需要在代码中删除logreg.fit即可.具体来说,它的作用如下:
它将您的数据集划分为n个折叠,并且在每次迭代中,它将其中一个折叠作为测试集并在其余折叠上训练模型(n-1倍).因此,最终您将获得整个数据的预测.

让我们用sklearn,iris中的一个内置数据集来说明这一点.该数据集包含150个具有4个特征的训练样本. iris [‘data’]是X,iris [‘target’]是y

In [15]: iris['data'].shape
Out[15]: (150, 4)

要通过交叉验证获得整个集合的预测,您可以执行以下操作:

from sklearn.linear_model import LogisticRegression
from sklearn import metrics, cross_validation
from sklearn import datasets
iris = datasets.load_iris()
predicted = cross_validation.cross_val_predict(LogisticRegression(), iris['data'], iris['target'], cv=10)
print metrics.accuracy_score(iris['target'], predicted)

Out [1] : 0.9537

print metrics.classification_report(iris['target'], predicted) 

Out [2] :
                     precision    recall  f1-score   support

                0       1.00      1.00      1.00        50
                1       0.96      0.90      0.93        50
                2       0.91      0.96      0.93        50

      avg / total       0.95      0.95      0.95       150

所以,回到你的代码.你需要的只是这个:

from sklearn import metrics, cross_validation
logreg=LogisticRegression()
predicted = cross_validation.cross_val_predict(logreg, X, y, cv=10)
print metrics.accuracy_score(y, predicted)
print metrics.classification_report(y, predicted) 

要在多类别分类中绘制ROC,您可以按照this tutorial进行以下操作:

一般来说,sklearn有非常好的教程和文档.我强烈建议阅读他们的tutorial on cross_validation.

标签:python,scikit-learn,cross-validation,logistic-regression
来源: https://codeday.me/bug/20190519/1134596.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有