机器学习(八)-基于KNN分类算法的手写识别系统
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了机器学习(八)-基于KNN分类算法的手写识别系统,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含3677字,纯文字阅读大概需要6分钟。
内容图文
1 项目介绍
基于k-近邻分类器(KNN)的手写识别系统, 这里构造的系统只能识别数字0到9。
- 难点: 图形信息如何处理?
图像转换为文本格式
2 准备数据
将图像转换为测试向量
-
训练集:
- 目录trainingDigits
- 大约2000个例子
- 每个数字大约有200个样本;
-
测试集
- 目录testDigits
- 大约900个测试数据。
将图像格式化处理为一个向量。我们将把一个32×32的二进制图像矩阵转换为1×1024的向量, 如下图所示,
import numpy as np
def img2vector(filename):
"""
# 将图像数据转换为(1,1024)向量
:param filename:
:return: (1,1024)向量
"""
# 生成一个1*1024且值全为0的向量;
returnVect = np.zeros((1, 1024))
# 读取要转换的信息;
file = open(filename)
# 依次填充
# 读取每一行数据;
for i in range(32):
lineStr = file.readline()
# 读取每一列数据;
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect
3 实施 KNN 算法
对未知类别属性的数据集中的每个点依次执行以下操作, 与上一个案例代码相同:
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的k个点;
(4) 确定前k个点所在类别的出现频率;
(5) 返回前k个点出现频率最高的类别作为当前点的预测分类。
def classify(inX, dataSet, labels, k):
"""
:param inX: 要预测的数据
:param dataSet: 我们要传入的已知数据集
:param labels: 我们要传入的标签
:param k: KNN里的k, 也就是说我们要选几个近邻
:return: 排序的结果
"""
dataSetSize = dataSet.shape[0] # (6,2) 6
# tile会重复inX, 把他重复成(datasetsize, 1)型的矩阵
# print(inX)
# (x1 - y1), (x2- y2)
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
# 平方
sqDiffMat = diffMat ** 2
# 相加, axis=1 行相加
sqDistance = sqDiffMat.sum(axis=1)
# 开根号
distances = sqDistance ** 0.5
# print(distances)
# 排序 输出的是序列号index,并不是值
sortedDistIndicies = distances.argsort()
# print(sortedDistIndicies)
classCount = {}
for i in range(k):
voteLabel = labels[sortedDistIndicies[i]]
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
# print(classCount)
sortedClassCount = sorted(classCount.items(), key=lambda d: float(d[1]), reverse=True)
return sortedClassCount[0]
4 测试算法
使用 k-近邻算法识别手写数字
- 测试集里面的信息;
def handWritingClassTest(k):
"""
# 测试手写数字识别错误率的代码
:param k:
:return:
"""
hwLabels = []
import os
# 读取所有的训练集文件;
trainingFileList = os.listdir('data/knn-digits/trainingDigits')
# 获取训练集个数;
m = len(trainingFileList)
# 生成m行1024列全为0的矩阵;
trainingMat = np.zeros((m, 1024))
# 填充训练集矩阵;
for i in range(m):
fileNameStr = trainingFileList[i] # fileNameStr: 0_0.txt
fileStr = fileNameStr.split('.')[0] # fileStr: 0_0
classNumStr = int(fileStr.split('_')[0]) # (数字分类的结果)classNumStr: 0
# 填写真实的数字结果;
hwLabels.append(classNumStr)
# 图形的数据: (1,1024)向量
trainingMat[i, :] = img2vector("data/knn-digits/trainingDigits/%s" % fileNameStr)
# 填充测试集矩阵;
testFileList = os.listdir('data/knn-digits/testDigits')
# 默认错误率为0;
errorCount = 0.0
# 测试集的总数;
mTest = len(testFileList)
# 填充测试集矩阵;
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorTest = img2vector("data/knn-digits/testDigits/%s" % fileNameStr)
# 判断预测结果与真实结果是否一致?
result = classify(vectorTest, trainingMat, hwLabels, k)
if result != classNumStr:
# 如果不一致,则统计出来, 计算错误率;
errorCount += 1.0
print("[预测失误]:分类结果是:%d, 真实结果是:%d" % (result, classNumStr))
print("错误总数:%d" % errorCount)
print("错误率:%f" % (errorCount / mTest))
print("模型准确率:%f" %(1-errorCount / mTest))
return errorCount
print(handWritingClassTest(2))
- 效果展示
5 KNN算法手写识别的缺点
算法的执行效率并不高。
- 每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计要执行900次;
- 需要为测试向量准备2MB的存储空间
有没有更好的方法?
- k决策树就是k-近邻算法的优化版,可以节省大量的计算开销。
内容总结
以上是互联网集市为您收集整理的机器学习(八)-基于KNN分类算法的手写识别系统全部内容,希望文章能够帮你解决机器学习(八)-基于KNN分类算法的手写识别系统所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。