ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

任务6的大概总结和思考

2020-01-06 21:02:55  阅读:350  来源: 互联网

标签:总结 KNN plt 10 模型 train 大概 思考 图片


在该任务中,我们要做的是基于KNN算法的图像识别

  首先,我在下载cifar-10数据集的时候遇到一些麻烦,第一次出现错误是说找不到load_data,通过发现找到了jupyter notebook新建的文件都是放在了D盘下Python3的Lib文件下,然后发现是我没把老师给的文件放在根目录下,我选择了直接将文件目录写进去,即直接给出数据集在本地的绝对地址

把cifar10_dir = 'F:\Jupyter\CIFAR10'

改为cifar10_dir = 'D:\python3\Lib\cifar-10-batches-py'

可是放了之后又出现问题,百度之后得到要下载pillow(pip install pillow),下载完之后还是不行,于是根据百度到的教程安装scipy的1.2.1版本(pip install scipy==1.2.1),然后就可以了。

 

   CIFAR-10,是一个用于做图像分类研究的数据集

  • 由60000个图片组成
  • 6万个图片中,5万张用于训练,1万张用于测试
  • 每个图片是32x32像素
  • 所有图片可以分成10类
  • 每个图片都有一个标签,标记属于哪一个类
  • 测试集中一个类对应1000张图
  • 训练集中将5万张图分为5份
  • 类之间的图片是互斥的,不存在类别重叠的情况

 

  下载了cifar-10数据集后就开始了下面的操作,我是从整体到细节的层面去理解该任务的代码的,从整体来看,我们的操作分为4步

   一, 文件的读取、可视化、以及采样

   二,使用KNN算法识别图片

   三,抽取图片特征,再用KNN算法来识别图片

   四,使用PCA对图片做降维,并做可视化

先看第一步

   一, 文件的读取、可视化、以及采样

     在这部分我们需要读取图片文件,并展示部分图片便于观察,以及做少量的采样。

  1. 文件的读取: 读取部分的代码已经提供,你只需要调用一下即可以读取图片数据。
  2. 可视化: 选择其中的一些样本做可视化,也就是展示图片的内容以及它的标签。
  3. 采样:统计一下各类出现的个数以及采样部分样本作为后续的模型的训练。

              (1) 将数据集cifar-10读入并保存到训练集和测试集中,

1 # 读取文件,并把数据保存到训练集和测试集合。 
2 X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
读取文件

 

              (2) 选择部分的样本做可视化(作用是什么?单纯为了展示吗?)

 

 1  plt.imshow(X_train[1]/255)#imshow的要求:像素值要么是整数,要么是[0,1]的实数
 2  plt.imshow(X_train[1][0:16]/255)
 3  plt.imshow(X_train[1][:,0:16]/255)
 4 
 5  plt.subplot(1,2,1) #一行两列的子图,最后的1是位置1
 6  plt.imshow(X_train[1]/255)#在刚才指定位置绘制第1副图
 7  plt.subplot(1,2,2)#一行两列的子图,最后的2是位置2
 8  plt.imshow(X_train[2]/255)#在刚才指定的位置绘制第2副图
 9  plt.figure(figsize = (8,8)) #创建一个指定大小的窗口(可根据需要调节)
10    for i in range(1,5): #左闭右开 [ 1,5 )
11     plt.subplot(2,2,i)
12     plt.imshow(X_train[i]/255)#绘制第i副图
13     plt.axis('off')
可视化

 

 

 然后选出10个类别,每个类别选5张图来展示。

 之后统计测试集中每个种类的个数,看是否平均,防止影响模型训练的准确率

 

 

              (3) 随机采样

                       随机采样训练样本5000个和测试样本500个。训练样本从训练集里采样,测试样本从测试集里采样。

                       从5万训练集从抽出5千训练样本,从1万训练集中抽取5百测试样本,并将样本打乱,但要用种子确保打乱后y_train[i]和X_train[i]仍对应

 

通过上述步骤得到采样后的样本X_train, y_train, X_test, y_test

  二, 使用KNN算法识别图片

——通过K折交叉验证来训练KNN,以及选择最合适的K值和p值

 

   因为KNN依赖于两个样本之间的距离运算,所以首先要用reshape把我们的样本图片转为向量模式,以便可以作为模型的输入

   也就是把原来大小为(32, 32, 3)的图片直接转换成一个长度为32*32*3=3072的向量。

   向量模式下记为X_train1 , X_test1

 通过K折交叉验证构造最好的KNN模型,并输出最好的模型参数,以及测试集上的准确率。

 需要搜索的参数为Kp。对于交叉验证,在这里使用GridSearchCV,这是一种参数搜索的方法也叫作网格搜索,其实就是考虑所有的组合,而选出最优的K,P组合。

 

 1 params_k = [1,3,5]  # 可以选择的K值
 2 params_p = [1,2] # 可以选择的P值
 3 
 4 params = {
 5     'n_neighbors':params_k,
 6     'p':params_p
 7 }
 8 # 构建模型
 9 knn = neighbors.KNeighborsClassifier()
10 #knn = KNeighborsClassifier()
11 model1 = GridSearchCV(knn, params, cv=3, n_jobs=-1, verbose=1)
GridSearchCV

 

以上代码构建好了KNN模型

 

#把训练集代入该模型

model1.fit(X_train1, y_train)

 

# 输出最好的K和p值

print(model1.best_params_)

 

 

# 输出在测试集上的准确率

 

print("准确率为 %.3f" % (model.score(X_test1,y_test)))

 

三,抽取图片特征,再用KNN算法来识别图片

  ——提取两种类型的特征,分别是color histogramHOG特征,并把它们拼接在一起作为最终的特征向量。

   1,调用特征提取工具给每一个图片提取特征。

   2, 使用K折交叉验证去学出最好的模型(同上)

 

1,两个特征抽取器,分别是hog_feature, color_histogram_hsv

   # 分别设置接下来需要调用的两个特征抽取器,分别是hog_feature, color_histogram_hsv

feature_fns = [hog_feature, lambda img: color_histogram_hsv(img, nbin=num_color_bins)]

 

 

# 抽取特征,分别对特征数据和测试数据,把结果存放在X_train2和X_test2

X_train2 = extract_features(X_train, feature_fns, verbose=True)
#X_val_feats = extract_features(X_val, feature_fns)
X_test2 = extract_features(X_test, feature_fns)

  !!!注意,要对特征数据做归一化,由于特征提取之后的,每一个维度的特征范围差异有可能比较大,所以使用KNN之前需要做归一化

 

2,同样用GridSearchCV网格搜索法,K折交叉验证,来构造KNN模型

      然后训练数据: (X_train2, y_train)代入模型, 用测试数据:(X_test2, y_test)来算该模型在该测试集上的准确率(步骤同上)

 

 

四,使用PCA对图片做降维,并做可视化

  ——PCA是一种常用的降维工具,可以把高维度的特征映射到任意低维的空间,所以这个方法也经常用来做数据的可视化

 

  其中有主要的参数需要设计,就是n_components, 指的是降维之后的维度。比如设置为2,就代表降维到2维的空间。

     接下来的任务是:

              通过PCA把数据降维,然后再通过KNN来分类

 

         首先使用PCA对数据做降维,之后再用KNN做交叉验证。 每一个PCA的维度都需要做一次KNN的交叉验证过程

          KNN的交叉验证过程同上,是先用GridSearchCV网格搜索法建立KNN模型,然后代入不同维度下的训练集,得到最好K,P的值,以及得到该模型在测试集上的准确率。

---------------------------------------------------------------------------------------------------------------------------------

未编译通过的部分:

 

 

一些知识点:

(1)a = np.unique(A): 

                              对于一维数组或者列表,unique函数去除其中重复的元素,并按元素由大到小返回一个新的无元素重复的元组或者列表

 (2)x[m,n]是通过numpy库引用数组或矩阵中的某一段数据集的一种写法,

          m代表第m维,n代表m维中取第几段特征数据。

             通常用法:

               x[:,n]或者x[n,:]

               x[:,n]表示在全部数组(维)中取第n个数据,直观来说,x[:,n]就是取所有集合的第n个数据

  (3) 在matplotlib中,整个图像为一个Figure对象

            plt.figure()#用于绘图

  (4)for i in range ()作用:

             range()是一个函数, for i in range () 就是给i赋值:

              比如 for i in range (1,3):就是把1,2依次赋值给i

               range () 函数的使用是这样的:

               range(start, stop[, step]),分别是起始、终止和步长

               range(3)即:从0到3,不包含3,即0,1,2

   (5)matplotlib下, 一个 Figure 对象可以包含多个子图(Axes), 可以使用 subplot() 快速绘制

       subplot()#绘制子图

   (6)where()的用法

            首先强调一下,where()函数对于不同的输入,返回的只是不同的。

               1,当数组是一维数组时,返回的值是一维的索引,所以只有一组索引数组

               2,当数组是二维数组时,满足条件的数组值返回的是值的位置索引,因此会有两组索引数组来表示值的位置

 

总结

  该任务确实很有难度,我在第一步都卡了半天,抛开python的基本语法不说,整体理解下来发现我对任务4,5的模型还不算很了解,不知道训练集中的验证数据和测试数据具体是怎么来评估的,以及本次任务中的网格搜索法,具体是怎么实现建立模型的,还有每个参数的意义都还不是很了解。但是由于我对python没有基础,可能要真正解决我的问题还要花更多时间,所以希望我能找时间把本次任务剩下的疑点解决。

 

相关参考资料:

GridSearchCV官网

 

https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html#sklearn.model_selection.GridSearchCV

 

KNeighborsClassifier官网

https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html

以及一些博客

https://www.imooc.com/article/260226

gridSearchCV(网格搜索)的参数、方法及示例

 

https://blog.csdn.net/weixin_41988628/article/details/83098130

标签:总结,KNN,plt,10,模型,train,大概,思考,图片
来源: https://www.cnblogs.com/tann/p/12158459.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有