ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

Python手写字母识别

2021-05-09 18:59:58  阅读:424  来源: 互联网

标签:... Python 字母 models train import 手写 csv sklearn


目录

Python手写字母识别

准备

#设置镜像
pip install pip -U
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
#安装
pip install numpy
pip install pandas
pip install scipy
pip install scikit-learn
pip install imageio

数据下载

Chars74K数据集:链接
使用EnglishHnd数据集,每个字母/数字有55张
由于原始数据的格式不方便处理,将其处理为以下格式,然后把每张图片缩小为40x30:

img
--- 0
------ 0.jpg
------ 1.jpg
------ ...
------ 54.jpg
--- 1
------ ...
--- lower_a
------ ...
--- ...
--- upper_a
------ ...
--- ...

处理过程我就不细说了,可以直接下载处理后的数据集

数据预处理

将数据集处理为CSV格式:

import numpy as np
import imageio
import os
#将各个数字、字母对应0-61中的一个数字
classes=os.listdir('img')
f=open("map.csv","w")
for i,j in enumerate(classes):
    print(str(i)+','+j,file=f)
#用imageio库读入图片
x=[]
y=[]
for i in classes:
    for j in range(55):
        img=imageio.imread("img/{}/{}.jpg".format(i,j))
        bw=[]
        for k in img:
            bw.append(l[0])
        x_train.append(bw)
        y_train.append(classes.index(i))
#导出csv
f=open("chars.csv","w")
for i,j in zip(x,y):
    print(','.join(i)+','+str(j),file=f)

得到map.csv:

0,0
......
10,lower_a
......
36,upper_a
......

chars.csv:

0,0,...,0,0,0
......
0,0,...,0,0,61

下载:map.csv chars.csv

训练

首先导入包:

from pandas import read_csv
from pandas.plotting import scatter_matrix
from matplotlib import pyplot
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
import joblib

评估模型:

#读取、处理
filename = 'chars.csv'
dataset = read_csv(filename)
array = dataset.values
X = array[:, 0:784]
Y = array[:, 784]
validation_size = 0.2
seed = 7
X_train, X_validation, Y_train, Y_validation = \
    train_test_split(X, Y, test_size=validation_size, random_state=seed)
#模型选择
models = {}
models['LR'] = LogisticRegression()
models['LDA'] = LinearDiscriminantAnalysis()
models['KNN'] = KNeighborsClassifier()
models['CART'] = DecisionTreeClassifier()
models['NB'] = GaussianNB()
models['SVM'] = SVC()
results = []
for key in models:
    kfold = KFold(n_splits=10, random_state=seed, shuffle=True)
    cv_results = cross_val_score(models[key], X_train, Y_train, cv=kfold, scoring='accuracy')
    results.append(cv_results)
    print('%s: %f (%f)' %(key, cv_results.mean(), cv_results.std()))
#可视化
fig = pyplot.figure()
fig.suptitle('Algorithm Comparison')
ax = fig.add_subplot(111)
pyplot.boxplot(results)
ax.set_xticklabels(models.keys())
pyplot.show()

结果如下:

LR: 0.916341 (0.017975)
LDA: 0.906424 (0.013931)
KNN: 0.838157 (0.013883)
CART: 0.893944 (0.023094)
NB: 0.235954 (0.023780)
SVM: 0.916711 (0.017377)


可以看出SVM效果最好,使用SVM分类(SVM训练时间可能较长):

svm=SVC()
svm.fit(X=X_train, y=Y_train)
predictions = svm.predict(X_validation)
print(accuracy_score(Y_validation, predictions))
print(confusion_matrix(Y_validation, predictions))
print(classification_report(Y_validation, predictions))
joblib.dump(svm,"letters.model")
0.9046920821114369
[[502  29]
 [ 36 115]]
              precision    recall  f1-score   support

           0       0.93      0.95      0.94       531
           1       0.80      0.76      0.78       151

    accuracy                           0.90       682
   macro avg       0.87      0.85      0.86       682
weighted avg       0.90      0.90      0.90       682

准确率在90%左右。由于大小写是分离的,做到这个准确率已经很不错了。如果应用不需要区分大小写,可以将同一个字母的大小写放在一起,大家可以尝试一下。
这段程序导出了一个模型。下次使用时,可以使用与处理训练图片相同的方法进行图片处理,然后使用模型预测。
有任何不足之处,欢迎评论区留言

标签:...,Python,字母,models,train,import,手写,csv,sklearn
来源: https://blog.csdn.net/weixin_54756357/article/details/116567160

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有