ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

python3实现Kmeans++算法

2019-08-12 18:52:44  阅读:293  来源: 互联网

标签:.__ ++ self Kmeans KMeans 算法 中心点 python3


零:环境

python 3.6.5

JetBrains PyCharm 2018.1.4 x64

一:KMeans算法大致思路

  KMeans算法是机器学习中的一种无监督聚类算法,是针对不具有类型的数据进行分类的一种算法

  形象的来说可以说成是给定一组点data,给定要分类的簇数k,来求中心点和对应的簇的集合

  中心点所在的簇中的其他点都是距离该中心点最近的点,因而才在一个簇里

  具体步骤

  1、首先在点集中随机寻找k个点来当作中心点

  2、然后初始化k个集合,用于存放对应的簇的对象

  3、开始KMeans算法的一轮。计算第i个点到k个中心点的距离[l1,l2,l3,……,ln],然后记录下距离最短的中心点,并将该点加入到对应的簇集合中

  4、全部点都计算完之后开始计算每个簇内的所有点的中心点,即取各个维度上的平均值的点作为新的中心点

  5、计算所有新旧中心点的距离的平方的和,看是否为0,不为0则继续循环或递归

  6、重复第3,4,5步骤,直到循环或递归跳出

  可以看出步骤还是非常简单明了的

  关于第5步为什么是0,因为当簇的分类趋于稳定的时候,各个簇之间应当没有数据的摆动。什么是数据的摆动呢?就是簇中的某个数上一次归属于簇A,这回归属于簇B,反复变化的情况即为摆动。

  对于KMeans算法来说是不存在的,因为新的中心点是簇内点集的中心点,所以当簇内稳定时新中心点也是稳定的,所以可以以0作为判断条件

  因为KMeans++算法与KMeans算法区别非常小,所以在讨论完KMeans++算法之后再一起发代码

二:KMeans++的思路

  KMeans++算法实际就是修改了KMeans算法的第一步操作

  之所以进行这样的优化,是为了让随机选取的中心点不再只是趋于局部最优解,而是让其尽可能的趋于全局最优解。要注意“尽可能”的三个字,即使是正常的KMeans++算法也无法保证百分百全局最优,在说取值原理之后我们就能知道为什么了

  思路就是我们要尽可能的保证各个簇的中心点的距离要尽可能的远

  当簇的中心尽可能的远的时候就能够尽可能的保证中心点之间不会在同一个簇内

  KMeans的迭代实际上就是簇的形状的修改,只要初始形状不太出格就会回归于正确形状

  具体步骤如下

  1、首先随机寻找一个点作为中心点

  2、然后计算其他点到目前的全部簇中心点的距离(最开始只有一个中心点)

  3、计算出映射到对应点的概率

\[\frac{{D{{(k)}^2}}}{{\sum\limits_{i = 0}^{\rm{m}} {D{{(i)}^2}} }}\]

  其中D(k)就是第k个点到其他中心点的最短距离,注意还有平方

  4、根据这个概率来利用轮盘法随机出一个中心点作为下一个中心点,然后重复2,3,4步骤直至找到全部中心点

  我们可以看出即使是KMeans++算法也只是概率性的选择,所以还是不稳定的,但是实际效果上已经比原有的随机选取K值好多了,当然最好的还是人工根据数据手动选取中心点

  以下是参考代码

  1 import csv
  2 import math
  3 import random
  4 from functools import reduce
  5 import matplotlib.pyplot as plt
  6 import numpy
  7 
  8 #   KMeans++算法,优化后的KMeans的算法
  9 class KMeansPP():
 10     def __init__(self,pBasePoints,pN = 5,pPointsCSVName = "kmeans_points.csv",pSetsCSVName = "kmeas_sets.csv"):
 11         """
 12         初始化KMeans++算法的构造函数
 13         :param pBasePoints: 所要计算的数据,为点的二维数组
 14         :param pN: 要分成的簇的个数
 15         :param pPointsCSVName: 要写入的点集的CSV文件
 16         :param pSetsCSVName: 要写入的簇的CSV文件
 17         """
 18         self.__N = pN
 19         self.__PCSVName = pPointsCSVName
 20         self.__SCSVName = pSetsCSVName
 21         self.__M = len(pBasePoints)#数据的个数
 22         self.__basePoints = pBasePoints
 23 
 24         self.__initBaseCenterPoint()   #kmeans++算法初始化中心点
 25         #self.__centerPoints = random.sample(self.__basePoints,self.__N) #kmeans算法初始化中心点
 26         self.__initSetsAndNewCenter()#初始化簇集合
 27         pass
 28 
 29     #   初始化N个点
 30     #   这里改进为Kmeans++算法
 31     def __initBaseCenterPoint(self):
 32         self.__centerPoints = []
 33         self.__centerPoints.append(self.__basePoints[random.randint(0, self.__M - 1)])#   首先初始化一个中心点
 34         while len(self.__centerPoints) < self.__N:#添加中心点直到N个
 35             tempDX = [min([KMeansPP.f_dAB(a,b) for b in self.__centerPoints])**2 for a in self.__basePoints]#D(x)的平方的列表。这一步中的a是遍历了所有的点,然后将a再分别与中心点集合进行遍历求出两点距离求出最短距离
 36             DXSum = sum(tempDX)#kmeans++公式中的分母
 37             DXP = []#轮盘法的值域范围计算,从开始的0到最后的1
 38             for i in range(len(tempDX)):
 39                 if i == 0:
 40                     DXP.append(tempDX[0]/DXSum)
 41                 else:
 42                     DXP.append(DXP[i-1]+tempDX[i]/DXSum)
 43             #   因为中心点到其他中心点的最短距离必定是0,所以必定不会选中中心点
 44             self.__centerPoints.append(self.__basePoints[KMeansPP.f_Roulette(DXP)])
 45         pass
 46 
 47     #   初始化新中心点和中心点集合
 48     def __initSetsAndNewCenter(self):
 49         self.__sets = {k:[] for k in self.__centerPoints}
 50         self.__newCenterPoints = []
 51 
 52     #   计算新的中心点
 53     def __countNewCenterPoints(self):
 54         self.__newCenterPoints = []
 55         pDim = len(self.__basePoints[0])
 56         for i in range(self.__N):#重新计算每个簇的中心点
 57             tp = self.__sets[self.__centerPoints[i]]#获取簇集合
 58             point = tuple([sum([p[i] for p in tp])/len(tp) for i in range(pDim)])#计算新的点。先i遍历维度,然后遍历每个点,对每个点的维度i取出来作为集合再求平均值。实际上就是矩阵的转置
 59             self.__newCenterPoints.append(point)
 60         pass
 61 
 62     #   求AB距离
 63     @staticmethod
 64     def f_dAB(A,B):
 65         dim = min(len(A),len(B))
 66         return sum([(A[i] - B[i]) ** 2 for i in range(dim)]) ** 0.5
 67 
 68     #   轮盘法,返回下标
 69     @staticmethod
 70     def f_Roulette(_list):
 71         tr = random.random()
 72         for i in range(len(_list)):
 73             if i == 0 and _list[i] > tr:
 74                 return 0
 75             else:
 76                 if _list[i] > tr and _list[i - 1] <= tr:
 77                     return i
 78 
 79     #   划分集合,kmeans算法
 80     def __kmeans(self):
 81 
 82         #   {其他点:[这个点到N个中心点的距离],……}
 83         t_dList = {b:[KMeansPP.f_dAB(a, b) for a in self.__centerPoints] for b in self.__basePoints}#先遍历b为其他点,a为中心点。计算点b到其他所有的中心点的距离
 84         for k,v in t_dList.items():
 85             self.__sets[self.__centerPoints[v.index(min(v))]].append(k)#将距离最小的添加到对应的簇里
 86 
 87         self.__countNewCenterPoints()#计算新中心点
 88         #   当各个簇之间有点变动时,就继续
 89         if sum([KMeansPP.f_dAB(self.__centerPoints[i],self.__newCenterPoints[i]) for i in range(self.__N)]) > 0:
 90             self.__centerPoints = self.__newCenterPoints[:]#把新中心点作为中心点
 91             self.__initSetsAndNewCenter()#重置集合和新中心点
 92             self.k_means()#递归调用
 93         pass
 94 
 95     #   k_means算法的对外接口
 96     def k_means(self):
 97         self.__kmeans()
 98         return self.__sets,self.__centerPoints
 99 
100     def writeToCSV(self):
101         with open(self.__SCSVName,"w",newline="") as fpc:
102             fpcWriter = csv.writer(fpc)
103             fpcWriter.writerow(self.__centerPoints)
104             maxIndex = max([len(v) for k, v in self.__sets.items()])
105             fpcWriter.writerows([[v[i] if len(v) > i else "" for (k, v) in self.__sets.items()] for i in range(maxIndex)])
106             pass
107 
108         with open(self.__PCSVName,"w",newline="") as fpp:
109             fppWriter = csv.writer(fpp)
110             fppWriter.writerows([[self.__basePoints[i*10 + j] if i*10+j < self.__M else "" for j in range(10)] for i in range(self.__M//10)])
111             pass
112         pass
kmeans与kmeans++代码

 


 

  本文原创,转载请注明出处https://www.cnblogs.com/dofstar/p/11341494.html

标签:.__,++,self,Kmeans,KMeans,算法,中心点,python3
来源: https://www.cnblogs.com/dofstar/p/11341494.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有