ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

kNN(K近邻) 算法

2021-02-02 07:35:25  阅读:157  来源: 互联网

标签:kNN digits ... 近邻 算法 train test import


目录


KNN:k-Nearest Neighbors

K近邻算法示例:

  • 数据:两类点方块和三角.
  • 绿色的点属于方块还是三角呢?
  • K=3还是K=5?结果一样吗?


K近邻计算流程:

  1. 计算 已知类别数据集中的点 与 当前点 的距离;
  2. 按照距离依次排序;
  3. 选取与当前点距离最小的K个点;
  4. 确定前K个点所在类别的出现概率;
  5. 返回前K个点出现频率最高的类别作为当前点预测分类;

特点

  • 思想极度简单;它是一种 lazy-learning 算法。
  • 应用数学知识少(几乎为0);
  • 效果好;
  • 很适合入门机器学习;
    更完整的刻画机器学习应用的流程;
    分类器不需要使用训练集进行训练,训练时间复杂度为0。
    可以解释机器学习算法 使用过程中的很多细节问题;
  • KNN分类的计算复杂度和训练集中的文档数目成正比,也就是说,如果训练集中文档总数为n,那么KNN的分类时间复杂度为O(n)。
  • K值的选择,距离度量和分类决策规则是该算法的三个基本要素
  • 本质:如果两个样本足够相似,就大概率属于同一个类别。和 n 个样本相似,和n个样本属于同一个类别
  • 可以解决分类问题,也可以解决回归问题;
  • KNN 不需要训练过程,没有产生模型;为了和其他算法统一,可以认为训练数据集 就是 模型本身。

使用

python 原生的实现

from math import sqrt
distances = []
for x_ train in X_train:
    d = sqrt(np. sum((x_ train - x)**2) ) 
    distances.append(d)

 # 排序后 原数据的索引
nearest = np.argsort(distance) 

k = 6
topK_y = [ y_train[i] for i in nearest[:k] ]

from collections import Counter

votes = Counter(topK_y)

votes.most_common(2) # 最多的数据


以上方法封装

import numpy as np
from math import sqrt
from collections import Counter

def kNN_classify(k, X_train, y_train, x):
    assert 1 <= k <= X_train.shape[0], "k must be valid"
    assert X_train.shape[0] == y_train.shape [0], "the size of X_train must equal to the size of y_train"
    assert X_train.shape[1] == x.shape[0], "the feature number of x must be equal to X_train"
    distances = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train ]
    nearest = np.argsort(distances)
    
    topK_y = [y_train[i] for i in nearest[:k]]
    votes = Counter(topK_y)
    return votes.most_common(1)[0][0]

sklearn 的实现

from sklearn.neighbors import KNeighborsClassifier
knn_cls = KNeighborsClassifier(n_neighbors=6)
knn_cls.fit(X_train, y_train) # 返回 机器学习对象自身
X_predict = x.reshape(1, -1) # 将 x 转化为矩阵,-1表示numpy 来决定多少列
# 预测,需要传入矩阵
y_predict = knn_cls.predict(X_predict) # 得到向量
y_predict[0]

sklearn 调用方法的封装

import numpy as np
from math import sqrt
from collections import Counter
from .metrics import accuracy_score

class KNNClassifier:

    def __init__(self, k):
        """初始化kNN分类器"""
        assert k >= 1, "k must be valid"
        self.k = k
        self._X_train = None
        self._y_train = None
    
    def fit(self, X_train, y_train):
        """根据训练数据集X_train和y_train训练kNN分类器"""
        assert X_train.shape[0] == y_train.shape[0], \
            "the size of X_train must be equal to the size of y_train"
        assert self.k <= X_train.shape[0], \
            "the size of X_train must be at least k."
    
        self._X_train = X_train
        self._y_train = y_train
        return self
    
    def predict(self, X_predict):
        """给定待预测数据集X_predict,返回表示X_predict的结果向量"""
        assert self._X_train is not None and self._y_train is not None, "must fit before predict!"
        assert X_predict.shape[1] == self._X_train.shape[1],   "the feature number of X_predict must be equal to X_train"
    
        y_predict = [self._predict(x) for x in X_predict]
        return np.array(y_predict)
    
    # 给定单个待预测数据x,返回x的预测结果值
    def _predict(self, x):
         
        assert x.shape[0] == self._X_train.shape[1],  "the feature number of x must be equal to X_train"
    
        distances = [sqrt(np.sum((x_train - x) ** 2))
                     for x_train in self._X_train]
        nearest = np.argsort(distances)
    
        topK_y = [self._y_train[i] for i in nearest[:self.k]]
        votes = Counter(topK_y)
    
        return votes.most_common(1)[0][0]
    
    # 根据测试数据集 X_test 和 y_test 确定当前模型的准确度
    def score(self, X_test, y_test): 
        y_predict = self.predict(X_test)
        return accuracy_score(y_test, y_predict)
    
    def __repr__(self):
        return "KNN(k=%d)" % self.k
 

评估算法的准确性

测试数据和训练数据部分可参考:

my_knn_clf = KNNClassifier(k=3)
my_knn_clf.fit(X_train, y_train)
y_predict = my_knn_clf.predict(X_test)
sum(y_test == y_predict)/len(y_test) # 0.8666666666666667

以上评估方式也称为 分类的准确度 accuracy


digits 手写数字识别

import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
 
digits = datasets.load_digits()  # digits 是一个字典类型数据
 
digits
''' 
    {'data': array([[ 0.,  0.,  5., ...,  0.,  0.,  0.],
            [ 0.,  0.,  0., ..., 10.,  0.,  0.],
            ...,
            [ 0.,  0.,  1., ...,  6.,  0.,  0.],
            [ 0.,  0., 10., ..., 12.,  1.,  0.]]),
     'target': array([0, 1, 2, ..., 8, 9, 8]),
     'target_names': array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
     'images': array([[[ 0.,  0.,  5., ...,  1.,  0.,  0.],
             [ 0.,  0., 13., ..., 15.,  5.,  0.],
             ...,
             [ 0.,  4., 16., ..., 16.,  6.,  0.],
             [ 0.,  8., 16., ..., 16.,  8.,  0.],
             [ 0.,  1.,  8., ..., 12.,  1.,  0.]]]),
     'DESCR': ".. _digits_dataset:\n\nOptical recognition of handwritten digits dataset\n--------------------------------------------------\n\n**Data Set Characteristics:**\n\n    :Number of Instances: 5620\n    :Number of Attributes: 64\n    :Attribute Information: 8x8 image of integer pixels in the range 0..16.\n    :Missing Attribute Values: None\n    :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)\n    :Date: July; 1998\n\nThis is a copy of the test set of the UCI ML hand-written digits datasets\nhttps://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits\n\nThe data set contains images of hand-written digits: 10 classes where\neach class refers to a digit.\n\nPreprocessing programs made available by NIST were used to extract\nnormalized bitmaps of handwritten digits from a preprinted form. From a\ntotal of 43 people, 30 contributed to the training set and different 13\nto the test set. 32x32 bitmaps are divided into nonoverlapping blocks of\n4x4 and the number of on pixels are counted in each block. This generates\nan input matrix of 8x8 where each element is an integer in the range\n0..16. This reduces dimensionality and gives invariance to small\ndistortions.\n\nFor info on NIST preprocessing routines, see M. D. Garris, J. L. Blue, G.\nT. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and C.\nL. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,\n1994.\n\n.. topic:: References\n\n  - C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their\n    Applications to Handwritten Digit Recognition, MSc Thesis, Institute of\n    Graduate Studies in Science and Engineering, Bogazici University.\n  - E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.\n  - Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.\n    Linear dimensionalityreduction using relevance weighted LDA. School of\n    Electrical and Electronic Engineering Nanyang Technological University.\n    2005.\n  - Claudio Gentile. A New Approximate Maximal Margin Classification\n    Algorithm. NIPS. 2000."}
''' 
 
digits.keys()
''' 
    dict_keys(['data', 'target', 'target_names', 'images', 'DESCR'])
''' 
 
digits.DESCR 

X = digits.data
y = digits.target
X.shape, y.shape
# ((1797, 64), (1797,))

digits.target_names
# array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
 
y[:100]
''' 
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1,
           2, 3, 4, 5, 6, 7, 8, 9, 0, 9, 5, 5, 6, 5, 0, 9, 8, 9, 8, 4, 1, 7,
           7, 3, 5, 1, 0, 0, 2, 2, 7, 8, 2, 0, 1, 2, 6, 3, 3, 7, 3, 3, 4, 6,
           6, 6, 4, 9, 1, 5, 0, 9, 5, 2, 8, 2, 0, 0, 1, 7, 6, 3, 2, 1, 7, 4,
           6, 3, 1, 3, 9, 1, 7, 6, 8, 4, 3, 1])
''' 
from sklearn.model_selection import train_test_split
 
X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=2) 
 
y_test.shape, y_train.shape
# ((360,), (1437,))
 
from sklearn.neighbors import KNeighborsClassifier
 
knn_cls = KNeighborsClassifier(n_neighbors=3)
knn_cls.fit(X_train, y_train) 
''' 
    KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
                         metric_params=None, n_jobs=None, n_neighbors=3, p=2,
                         weights='uniform')
''' 
y_predict = knn_cls.predict(X_test)
 
sum(y_predict == y_test)/len(y_test)
# 0.9888888888888889

from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_predict)  # 0.9888888888888889

超参数 & 模型参数

超参数:运行算法之前需要决定的参数
模型参数:算法过程中学习的参数

kNN 算法中没有模型参数,k 是经典的超参数;


如何寻找好的超参数:

  • 领域知识
  • 经验数值
    如 kNN 中 k 的经典数值为5;但实际问题以实际为准。
  • 试验搜索

kNN 算法中的超参数

如果只考虑最邻近节点,那么附近蓝色节点多于红色节点,会被判断为蓝色节点;
但红色距离它更近,所以需要将距离(距离的倒数)也考虑进去。 1 > 1/3 + 1/4
这样还可以解决 平票 的问题;即四个节点时,两蓝两红。


使用网格搜索查找最优超参数

数据归一化后进行KNN

机器学习中的距离


更多

  • kNN 是解决分类问题的算法,天然的可以解决多分类问题。
  • kNN 解决回归问题
    kNN 还可以解决回归问题,如预测房价、考试分数。
    方法:找到距离最近的 k 个节点,计算他们的平均值,或者根据距离添加权重。
    sklearn 中对应的类:KNeighborsRegressor

缺点:

  1. 效率低下(最大的缺点)
    如果训练集有 m 个样本,n 个特征,则预测 每一个新的数据,需要 O(m*n) 的时间复杂度。
    优化:使用树结构:KD-Tree, Ball-Tree;即使如此,kNN 的效率依然非常低。

  2. 高度数据相关
    相较其他机器学习算法,kNN 对 outlier 更加敏感。

  3. 预测的结果不具有可解释性。

  4. 维数灾难
    随着维度的增加,看似相近的两个点之间的距离 越来越大。处理高维度的数据,很容易产生维数灾难。如下:

维数 距离
1维 0到1的距离 1
2维 (0,0)到(1,1) 的距离 1.414
3维 (0,0,0)到(1,1,1)的距离 1.73
64维 (0,0...0)到(1,1...1)的距离 8
10000维 (0,0...0)到(1,1...1)的距离 100

维数灾难解决方法:降维



https://zhuanlan.zhihu.com/p/25994179

标签:kNN,digits,...,近邻,算法,train,test,import
来源: https://www.cnblogs.com/devwalks/p/14360131.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有