knn分类算法实现手写体数字识别python
之前写过knn分类算法代码,想把knn用于设别手写体数字,看下正确率。
大概思路:获取图片(可以自己写,我之前有写过黑白图片转文本的代码,也可以网上找,反正数据量大会更好)->转成文本->建立大量的训练数据集->建立好训练数据与类别的关联->测试
注意:训练数据一定要明确给出类别。本次实验手写体数字一共就是10中类别,0-9
获取图片转成文本之前写过,跳过,直接从建立训练数据开始。
首先加载数据,图片保存在文本里,不方便处理,转成数组。这里的32都是保存图片的宽和高(px),按照具体图片大小决定。
之后将所有训练数据存储在一个数组里,一个文本存储在数组的一行里,一共有多少文本,就有多少行,列数是固定的,32*32=1024
上图就是建立了类别和训练数据的联系。
测试数据,用knns算法去给测试数据分类。
单个手写体数字文件识别:
trainarray,labels=traindata()
tfile="1_32.txt"#注意:1是该文本的真实类别,32是该类别第32个数据
tarray=datatoarray("D:/xx/testdata"+tfile)
result=knn(4,tarray,trainarray,labels)
print(result)
批量手写体数字文件识别:
结果:
一共是964个文件,设别错误11个,k为4,可以看出KNN正确率还是可以的。
源码:
from numpy import *
import operator
from os import listdir
def knn(k, testdata, traindata, labels):
traindatasize = traindata.shape[0]
dif = tile(testdata, (traindatasize, 1)) - traindata
sqdif = dif ** 2
sumsqdif = sqdif.sum(axis=1)
distance = sumsqdif ** 0.5
sortdistance = distance.argsort()
count = {}
for i in range(0, k):
vote = labels[sortdistance[i]]
count[vote] = count.get(vote, 0) + 1
sortcount = sorted(count.items(), key=operator.itemgetter(1), reverse=True)
return sortcount[0][0]
from PIL import Image
im=Image.open("C:/xx/xx/3.jpg")
fh=open("C:/xx/xx/3_20.txt","a")
width=im.size[0]
height=im.size[1]
for i in range(0,width):
for j in range(0,height):
cl=im.getpixel((i,j))
clall=cl[0]+cl[1]+cl[2]
if(clall==0):
fh.write("1")
else:
fh.write("0")
fh.write(" ")
fh.close()
def datatoarray(fname):
arr = []
fh = open(fname)
for i in range(0, 32):
thisline = fh.readline()
for j in range(0, 32):
arr.append(int(thisline[j]))
return arr
# 建立一个函数取文件名前缀
def seplabel(fname):
filestr = fname.split(".")[0]
label = int(filestr.split("_")[0])
return label
def traindata():
labels = []
trainfile = listdir("D:/xx/traindata")
num = len(trainfile)
trainarr = zeros((num, 1024))
for i in range(0, num):
thisfname = trainfile[i]
thislabel = seplabel(thisfname)
labels.append(thislabel)
trainarr[i, :] = datatoarray("D:/xx/traindata/" + thisfname)
return trainarr, labels
def datatest():
trainarr, labels = traindata()
testlist = listdir("D:/xx/testdata")
tnum = len(testlist)
count = 0
for i in range(0, tnum):
thistestfile = testlist[i]
reallabel = seplabel(thistestfile)
testarr = datatoarray("D:/xx/" + thistestfile)
rknn = knn(3, testarr, trainarr, labels)
if (rknn != reallabel):
count = count + 1
print("kNN识别的是" + str(rknn) + "错误,真实类别是" + str(reallabel))
print("KNN正确率:" + str((tnum - count) / tnum))
datatest()
#抽某一个测试文件出来进行试验
trainarr,labels=traindata()
testfile=listdir("D:/pythonlianxi/result/traindata")
for i in range(0,len(testfile)):
thisfname=testfile[i]
reallabel=seplabel(thisfname)
testarr[i,:]=datatoarray("D:/pythonlianxi/result/testdata/"+testfile[i])
rknn=knn(4,testarr,trainarr,labels)
print(rknn)