ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

Kd树实现K近邻算法

2022-02-05 19:03:26  阅读:173  来源: 互联网

标签:node index Kd 近邻 节点 算法 neighbor sets self


参考 GitHub

"""
Kd树搜索的k近邻模型,和《统计学习方法》上介绍的最近邻算法差距有点大..
(1) 设定一个当前最优点集合,用来保存当前离搜索点最近的样本点集合
(2) 从根节点开始,并设其为当前节点;在此code中由query执行,下面的(3)(4)(5)(6)(7)由一个函数_search执行
(3) 如果当前节点为空,则更新集合并结束;
(4) 如果被搜索点的划分维度小于当前节点的划分维度,则设当前节点的左孩子为新的下一次应该访问的节点;反之设当前节点的右孩子为下一次应该访问的节点
(5) 如果当前节点到被搜索点的距离小于当前全局最短距离,则更新最优k点集;
(6) 如果被搜索点到当前节点划分维度的距离小于全局最短距离,则说明全局最佳点可能存在于当前节点的另外一个子树中,
所以设当前节点的另外一个孩子为下一次应该访问的节点并执行步骤(3)(4)(5)(6)(7);
(7) 搜索下一次应该访问的节点,它是由(4)设置的,直到叶子节点,即触发(3)

算法减少搜索量就是通过只检查有可能足够近的点,而对于不可能是最近的k个点的那些点则永远不理会
哪些是不可能的点?答:假设当前节点离目标点够近(有可能是那k个点),比一下当前点和目标点某维度的大小,左子树都是比当前点小的点,
如果目标点也比当前点小,显然左子树的那些点应该考虑一下是不是的那k个点,因为这时左子节点必然比右子节点离目标点近,再判断一下右子节点离目标点的距离
如果这个距离比最差的最优点都小,那右子树的其他节点就不用考虑了

"""
import json
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import numpy as np
import pandas as pd
import time


class Node:
	"""节点类"""

	def __init__(self, value, index, left_child, right_child):
		"""
		:param value:节点存储的k维特征
		:param index:此节点存储的样本点在训练集中的index
		:param left_child:左子树
		:param right_child:右子树
		"""
		self.value = value.tolist()
		self.index = index
		self.left_child = left_child
		self.right_child = right_child

	def __repr__(self):
		"""
		print对象时相当于执行print(对象.__repr__)
		"""
		return json.dumps(self, indent=3, default=lambda obj: obj.__dict__, ensure_ascii=False)


class KDTree:
	"""kd tree类"""

	def __init__(self, data):
		# 原始训练数据集
		self.data = np.asarray(data)
		# 构造的kd树的根节点
		self.kd_tree = None
		# 每个节点所存储的样本特征数量,即特征空间的维度
		self.K = self.data.shape[1]
		# 创建平衡kd树
		self._create_kd_tree(data)

	def __repr__(self):
		"""
		调用时相当于print的是kd_tree的根节点,它是一个Node对象,相当于print的是Node类的__repr__
		"""
		return str(self.kd_tree)

	def SplitData_GetNode(self, data, depth):
		"""
		根据数据集构造节点
		:param data:剩下待分割的数据集
		:param depth:现在构造的节点所处的深度
		:return:子树的节点
		"""
		if len(data) == 0:
			return None

		split_feature_index = depth % self.K
		# 对数据在split_feature_index维度进行排序
		data = data[data[:, split_feature_index].argsort()]
		# 切分点
		median_index = data.shape[0] // 2
		# 获取结点在原始训练数据集中的位置
		node_index = [i for i, v in enumerate(self.data) if list(v) == list(data[median_index])]
		return Node(
			# 本结点
			value=data[median_index],
			# 本结点在数据集中的位置
			index=node_index[0],
			# 左子结点
			left_child=self.SplitData_GetNode(data[:median_index], depth + 1),
			# 右子结点
			right_child=self.SplitData_GetNode(data[median_index + 1:], depth + 1)
		)

	def _create_kd_tree(self, X):
		"""
		:param X:数据集,只在初始化时调用,递归的开启:从最原始数据集开始,构造深度为0的节点
		:return:根节点
		"""
		self.kd_tree = self.SplitData_GetNode(X, 0)

	def query(self, TargetPoint, k=1):
		"""
		调用时搜索目标点,开始递归
		:param TargetPoint: 目标点
		:param k:
		:return:
		"""
		TargetPoint = np.asarray(TargetPoint)
		results = self._search(TargetPoint, 0, self.kd_tree, k=k, k_neighbor_sets=list())
		DistanceList_KBestPoints_And_TargetPoint = np.array([results[0] for results in results])
		IndexList_KBestPoints_in_OriginalData = np.array([results[1] for results in results])
		return DistanceList_KBestPoints_And_TargetPoint, IndexList_KBestPoints_in_OriginalData

	@staticmethod
	def _cal_node_distance(node1, node2):
		"""计算两个结点之间的距离"""
		return np.sqrt(np.sum(np.square(node1 - node2)))

	def _search(self, point, depth, node=None, k=1, k_neighbor_sets=None):
		"""
		搜索附近可能的那k个点
		:param point: 目标点,即新输入模型的样本点,现在要预测这个点的label
		:param depth: 当前考察节点的深度
		:param node: 当前考察点,即判断这个点是否确实是那k个点
		:param k: k近邻中的k
		:param k_neighbor_sets: 最优点集,存储的是(节点离目标点的距离, 节点存储值在原始训练集中的位置, 节点值)
		:return:无
		"""
		if k_neighbor_sets is None:
			k_neighbor_sets = []
		if node is None:
			return k_neighbor_sets

		#到了叶节点,停止搜索,附近可能的点已经搜索结束
		if node.left_child is None and node.right_child is None:
			# 更新当前k近邻点集
			return self._update_k_neighbor_sets(k_neighbor_sets, k, node, point)

		# 递归地向下访问kd树
		if point[depth % self.K] < node.value[depth % self.K]:
			direct = 'left'
			next_node = node.left_child
		else:
			direct = 'right'
			next_node = node.right_child
		if next_node is not None:
			#判断当前结点是否符合进入最优点集的条件
			k_neighbor_sets = self._update_k_neighbor_sets(k_neighbor_sets, k, next_node, point)
			#检查兄弟节点对应的区域是否相交,
			if direct == 'left' and node.right_child is not None:
				node_distance = self._cal_node_distance(point, node.right_child.value)
				if k_neighbor_sets[0][0] > node_distance:
					# 如果相交,递归地进行近邻搜索
					return self._search(point, depth=depth + 1, node=node.right_child, k=k,
										k_neighbor_sets=k_neighbor_sets)
			elif node.left_child is not None:
				node_distance = self._cal_node_distance(point, node.left_child.value)
				if k_neighbor_sets[0][0] > node_distance:
					return self._search(point, depth=depth + 1, node=node.left_child, k=k,
										k_neighbor_sets=k_neighbor_sets)

		return self._search(point, depth=depth + 1, node=next_node, k=k, k_neighbor_sets=k_neighbor_sets)

	def _update_k_neighbor_sets(self, k_neighbor_sets, k, node, point):
		"""
		更新最优点集
		:param k_neighbor_sets:最优点集
		:param k: k近邻中的k
		:param node: 符合进入最优点集的条件
		:param point: 目标点
		:return: 更新后的最优点集
		"""
		# 计算目标点与当前结点的距离
		node_distance = self._cal_node_distance(point, node.value)
		if len(k_neighbor_sets) == 0:
			k_neighbor_sets.append((node_distance, node.index, node.value))
		elif len(k_neighbor_sets) < k:
			# 如果“当前k近邻点集”元素数量小于k
			self._insert_k_neighbor_sets(k_neighbor_sets, node, node_distance)
		else:
			# 叶节点距离小于“当前k近邻点集”中最远点距离
			if k_neighbor_sets[0][0] > node_distance:
				k_neighbor_sets = k_neighbor_sets[1:] #将差的点移除
				self._insert_k_neighbor_sets(k_neighbor_sets, node, node_distance)
		return k_neighbor_sets

	@staticmethod
	def _insert_k_neighbor_sets(k_neighbor_sets, node, node_distance):
		"""
		发现新的更符合的点
		:param k_neighbor_sets:最优点集
		:param node: 节点
		:param node_distance:节点与目标点间的距离
		:return: 无
		"""
		n = len(k_neighbor_sets)
		for i, item in enumerate(k_neighbor_sets):
			if item[0] < node_distance:
				# 将距离最远的结点插入到前面
				k_neighbor_sets.insert(i, (node_distance, node.index, node.value))
				break
		if len(k_neighbor_sets) == n:
			k_neighbor_sets.append((node_distance, node.index, node.value))


iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target
data = np.array(df)
X = data[:, 0:4]
y = data[:, 4]
train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.3)  # (105, 4) (105,)
# print(train_X.shape,train_y.shape)
my_clock_start = time.perf_counter()
kd_tree = KDTree(train_X)
k = 2
cnt = 0
for i, point in enumerate(test_X):
	# print("目标点是{},对应标签是{}".format(point,test_y[i]))
	DistanceList_KBestPoints_And_x, IndexList_KBestPoints_in_OriginalData = kd_tree.query(np.array(point), k=k)
	predict_y = []
	for index in IndexList_KBestPoints_in_OriginalData:
		# print(train_y[index],end=",")
		predict_y.append(train_y[index])
	# print("")
	if (max(predict_y, key=predict_y.count) == test_y[i]):
		cnt += 1
print("我的模型准确率为{}".format(cnt / len(test_X)))
my_clock_end = time.perf_counter()
print("我的模型花费时间{}".format(my_clock_end - my_clock_start))

cnt = 0
from sklearn.neighbors import KNeighborsClassifier

clocl_start = time.perf_counter()
model = KNeighborsClassifier(n_neighbors=2)
model.fit(train_X, train_y)
predict_y = model.predict(test_X)
for index, label in enumerate(predict_y):
	if label == test_y[index]:
		cnt += 1
print("sklearn模型准确率为{}".format(cnt / len(test_X)))
clock_end = time.perf_counter()
print("sklearn模型花费时间{}".format(clock_end - clocl_start))

"""
我的模型准确率为0.9333333333333333    准确率有时比sklearn中的高,要看k的选择
我的模型花费时间0.10518830000000001   时间差接近十倍,因为做了一些其他事
sklearn模型准确率为0.9111111111111111
sklearn模型花费时间0.007459700000000069
"""

标签:node,index,Kd,近邻,节点,算法,neighbor,sets,self
来源: https://blog.csdn.net/weixin_52812620/article/details/122792354

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有