ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

鸢尾花识别

2020-04-16 09:37:31  阅读:492  来源: 互联网

标签:csv predict DecisionTreeClassifier step7 train pd 鸢尾花 识别


任务描述

使用sklearn完成鸢尾花分类任务。
在这里插入图片描述
鸢尾花数据集是一类多重变量分析的数据集。通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类(其中分别用0,1,2代替)。

数据集中部分数据与标签如下图所示:
在这里插入图片描述
在这里插入图片描述

DecisionTreeClassifier

DecisionTreeClassifier的构造函数中有两个常用的参数可以设置:

criterion:划分节点时用到的指标。有gini(基尼系数),entropy(信息增益)。若不设置,默认为gini
max_depth:决策树的最大深度,如果发现模型已经出现过拟合,可以尝试将该参数调小。若不设置,默认为None
和sklearn中其他分类器一样,DecisionTreeClassifier类中的fit函数用于训练模型,fit函数有两个向量输入:

X:大小为[样本数量,特征数量]的ndarray,存放训练样本;

Y:值为整型,大小为[样本数量]的ndarray,存放训练样本的分类标签。

DecisionTreeClassifier类中的predict函数用于预测,返回预测标签,predict函数有一个向量输入:

X:大小为[样本数量,特征数量]的ndarray,存放预测样本。
DecisionTreeClassifier的使用代码如下:

from sklearn.tree import DecisionTreeClassifier
clf = tree.DecisionTreeClassifier()
clf.fit(X_train, Y_train)
result = clf.predict(X_test)

编程要求

实现鸢尾花数据的分类任务,其中训练集数据保存在./step7/train_data.csv中,训练集标签保存在。./step7/train_label.csv中,测试集数据保存在。./step7/test_data.csv中。请将对测试集的预测结果保存至。./step7/predict.csv中。这些csv文件可以使用pandas读取与写入。

注意:当使用pandas读取完csv文件后,请将读取到的DataFrame转换成ndarray类型。这样才能正常的使用fit和predict。

示例代码:

import pandas as pd
# as_matrix()可以将DataFrame转换成ndarray
# 此时train_df的类型为ndarray而不是DataFrame
train_df = pd.read_csv('train_data.csv').as_matrix()

数据文件格式如下图所示:
在这里插入图片描述
在这里插入图片描述

通关代码:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
train_df = pd.read_csv('./step7/train_data.csv').as_matrix()
train_label = pd.read_csv('./step7/train_label.csv').as_matrix()
test_df = pd.read_csv('./step7/test_data.csv').as_matrix()
dt = DecisionTreeClassifier()
dt.fit(train_df, train_label)
result = dt.predict(test_df)
result = pd.DataFrame({'target':result})
result.to_csv('./step7/predict.csv', index=False)

在这里插入图片描述

标签:csv,predict,DecisionTreeClassifier,step7,train,pd,鸢尾花,识别
来源: https://blog.csdn.net/zag666/article/details/105497621

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有