ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

基于鸢尾花卉数据集的Fisher分类器设计

2021-10-06 14:34:13  阅读:265  来源: 互联网

标签:vector 分类器 np train ax test Fisher 鸢尾 mean


基于鸢尾花卉数据集的Fisher分类器设计

本文主要探讨Iris数据集(二维)的Fisher线性分类器的设计。
数据集下载

1. 预处理

# 导包
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math # 数学函数
import sympy as sp # 绘图
from sklearn.model_selection import train_test_split # 拆分数据集的工具

# 导入数据
data = pd.read_excel('3-iris 数据集(2类).xls',header=None)
X = np.array(data.iloc[:,2:4]) # 截取两维
y = np.array(data.iloc[:,4])
y_c = np.unique(y) # 离散化数据
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=0)

2. 求取向量均值

np.set_printoptions(precision=4)
mean_vector = []  # 类别的平均值
for i in y_c:
    mean_vector.append(np.mean(X_train[y_train == i], axis=0))

结果

1:
1.445714285714285507e+00
2.457142857142857462e-01

2:
4.265714285714285126e+00
1.342857142857142749e+00

3. 计算类内离散度矩阵

S_W = np.zeros((X_train.shape[1], X_train.shape[1]))
for i in y_c:
    Xi = X_train[y_train == i] - mean_vector[i-1]
    S_W += np.mat(Xi).T * np.mat(Xi)
print('S_W:',S_W)

结果

[[9.6257 3.0683]
 [3.0683 1.9126]]

4. 计算类间离散度矩阵

S_B = np.zeros((X_train.shape[1], X_train.shape[1]))
mu = np.mean(X_train, axis=0)  # 所有样本平均值
for i in y_c:
    Ni = len(X_train[y_train == i])
    S_B += Ni * np.mat(mean_vector[i-1] - mu).T * np.mat(mean_vector[i-1] - mu)
print('S_B:',S_B)

结果

[[139.167   54.144 ]
 [ 54.144   21.0651]]

5. 计算最优投影方向w

w = (np.linalg.inv(S_W)).dot((mean_vector[0]-mean_vector[1]).T)
print(w)

结果

[-0.2253 -0.2121]

6. 计算决策面常数项

三种不同的 P ( w 2 ) / P ( w 1 ) P(w_2)/P(w_1) P(w2​)/P(w1​)

P = [1,3/7,1/9]
w_0 = []
const1 = -0.5*(((mean_vector[0]+mean_vector[1]).dot(np.linalg.inv(S_W))).dot((mean_vector[0]-mean_vector[1]).T))
for i in P:
    w_0.append(const1-math.log(i))
print(w_0)

结果

[0.8120180310967475, 1.6593158914839512, 3.009242608432967]

7. 根据不同的先验概率比绘图

fig,ax = plt.subplots(1,1)
ax.scatter(X_train[y_train == 1][:,0],X_train[y_train == 1][:,1],c='b',label='1')
ax.scatter(X_train[y_train == 2][:,0],X_train[y_train == 2][:,1],c='r',label='2')
x = sp.Symbol('x')
y = sp.Symbol('y')
X = np.array([x,y])
xx,yy = np.linspace(0,10,7),np.linspace(0,10,7)
x,y = np.meshgrid(xx,yy)
ax.contour(x,y,(0.812018031096748 - 0.225346692821696*x - 0.212130544635166*y),[0]);
ax.contour(x,y,(1.6593158914839512 - 0.225346692821696*x - 0.212130544635166*y),[0]);
ax.contour(x,y,(3.009242608432967 - 0.225346692821696*x - 0.212130544635166*y),[0]);
ax.legend()
plt.show()

图示

从下到上先验概率 P ( w 2 ) / P ( w 1 ) P(w_2)/P(w_1) P(w2​)/P(w1​)分别为1,3/7,1/9:
在这里插入图片描述

8. 对测试数据集进行计算错误率

scores=0
for i in range(len(X_test[:,0])):
    if ((0.812018031096748 - 0.225346692821696*X_test[i,0] - 0.212130544635166*X_test[i,1] > 0)&(y_test[i]==1))|((0.812018031096748 - 0.225346692821696*X_test[i,0] - 0.212130544635166*X_test[i,1] < 0)&(y_test[i]==2)):
        scores+=1;
print('errorRate:',1-scores/len(X_test[:,0]))

结果

errorRate: 0.0

标签:vector,分类器,np,train,ax,test,Fisher,鸢尾,mean
来源: https://blog.csdn.net/linjing_zyq/article/details/120624191

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有