ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

鸢尾花-k近邻预测算法

2021-11-30 14:01:36  阅读:192  来源: 互联网

标签:iris target 近邻 dataset 算法 train test import 鸢尾花


目录

环境

编程语言: python3.10

运行平台: windows10

依赖库安装: matplotlib pandas numpy scikit-learn

介绍

根据花瓣的长度和宽度以及花萼的长度和宽度,得出花的品种属于setosa、versicolor 或virginica 三个品种之一。

散点图源码

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd
import matplotlib.pyplot as plt
# 获取鸢尾花数据集
iris_dataset = load_iris()
# 打乱数据集,获取训练集与预测集,可以添加test_size train_size参数指定测试集大小,默认25%
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)
# 利用X_train中的数据创建DataFrame
# 利用iris_dataset.feature_names中的字符串对数据列进行标记
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
# 利用DataFrame创建散点图矩阵,按y_train着色
grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(24, 24), alpha=.8)
# 创建窗口
plt.figure(figsize=(24, 24))
# 展示窗口
plt.show()

数据集数据结构

{
	'data': array([[5.1, 3.5, 1.4, 0.2],
       		......
       		[5.9, 3. , 5.1, 1.8]]), 
    'target': array([0, 0, ... 2, 2]), 
    'frame': None, 
    'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'), 	'DESCR': '... more ...', 
    'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'], 
    'filename': 'iris.csv', 
    'data_module': 'sklearn.datasets.data'
}

结构说明:

  • data: 花瓣特征数据集
  • target: 每个花瓣数据对应品种结果,保存的是target_names数组的下标
  • target_names: 结果集,鸢尾花的三个品种
  • DESCR: 数据集的简要说明
  • feature_names: 每一个特征的简要说明
  • filename: 数据集的文件名称
  • data_module: 数据对应的module

散点图

k近邻算法

k近邻算法在训练集中寻找与这个新数据点距离最近的数据点,然后将找到的数据点的标签赋值给这个新数据点。k 近邻算法中k 的含义是,我们可以考虑训练集中与新数据点最近的任意k 个邻居,然后用这些邻居中数量最多的类别做出预测。

k近邻源码

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
iris_dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris_dataset['data'], iris_dataset['target'], random_state=0)
iris_dataframe = pd.DataFrame(X_train, columns=iris_dataset.feature_names)
grr = pd.plotting.scatter_matrix(iris_dataframe, c=y_train, figsize=(24, 24), alpha=.8)
# 设置k近邻算法的k值
knn = KNeighborsClassifier(n_neighbors=5)
# 设置k近邻算法的训练数据集与训练结果集
knn.fit(X_train, y_train)
# 创建一个新的测试数据
X_new = np.array([[5, 2.9, 1, 0.2]])
# 根据测试数据预测结果
prediction = knn.predict(X_new)
# 输出预测结果
print("Prediction: {}".format(prediction))
print("Predicted target name: {}".format(iris_dataset['target_names'][prediction]))
# 根据测试数据集预测结果
y_pred = knn.predict(X_test)
# 输出预测结果与 预测准确性
print("Test set predictions:\n {}".format(y_pred))
print("Test set score: {:.2f}".format(np.mean(y_pred == y_test)))

输出结果

Prediction: [0]
Predicted target name: ['setosa']
Test set predictions:
 [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
 2]
Test set score: 0.97

结论

根据测试数据集的预测结果与测试数据集的正确结果比较,得到预测的准确性可以达到97%.

注意

​ 安装sklearn的时候,可能会需要安装VC.

标签:iris,target,近邻,dataset,算法,train,test,import,鸢尾花
来源: https://www.cnblogs.com/52why/p/15623928.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有