基于深度学习的手写汉字识别系统(含PyQt+代码+训练数据集)
- 前言
- 一、数据集
- 1.1 数据集介绍
- 1.2 数据预处理
- 二、模型搭建
- 三、训练与测试
- 3.1 模型训练
- 3.2 模型测试
- 四、PyQt界面实现
- 参考资料
前言
本项目是基于深度学习网络模型的人脸表情识别系统,核心采用CNN卷积神经网络搭建,详述了数据集处理、模型构建、训练代码、以及基于PyQt5的应用界面设计。在应用中可以支持手写汉字图像的识别。本文附带了完整的应用界面设计、深度学习模型代码和训练数据集的下载链接。
完整资源下载链接:博主在面包多网站上的完整资源下载页
项目演示视频:
【项目分享】基于深度学习的手写汉字识别系统(含PyQt+代码+训练数据集)
一、数据集
1.1 数据集介绍
本项目的数据集在下载后的data
文件夹下,主要分为训练数据集train
和测试数据集test
,如下图所示。
以训练集为例,如下图所示。其中训练集包含零、一、计、算、机等20个中文手写汉字,图像共计4757张。可以自己添加相应的手写汉字,可获取的 手写中文数据集链接。
1.2 数据预处理
首先,加载数据集中的图像文件,并将它们调整为相同的大小(64x64)。然后,根据文件所在的目录结构,为每个图像文件分配一个标签(label),标签是根据文件所在的子目录来确定的。最后,使用 train_test_split
函数将数据集划分为训练集和验证集,以便后续模型训练和评估。
def load_data(filepath):
# 遍历filepath下所有文件,包括子目录
files = os.listdir(filepath)
for fi in files:
fi_d = os.path.join(filepath, fi+'/')
if os.path.isdir(fi_d):
global label
load_data(fi_d)
label += 1
else:
labels.append(label)
img = mi.imread(fi_d[:-1])
img2 = cv2.resize(img, (64, 64))
dataset.append(img2)
# 在训练集中取一部分作为验证集
train_image, val_image, train_label, val_label = train_test_split(
np.array(dataset), np.array(labels), random_state=7)
return train_image, val_image, train_label, val_label
二、模型搭建
CNN(卷积神经网络)主要包括卷积层、池化层和全连接层。输入数据经过多个卷积层和池化层提取图片信息后,最后经过若干个全连接层获得最终的输出。CNN的实现主要包括以下步骤:数据加载与预处理、模型搭建、定义损失函数和优化器、模型训练、模型测试。想了解更多关于CNN卷积神经网络的请自行百度。本项目基于tensorflow实现的,并搭建如下图所示的CNN网络模型。
具体代码:
def get_model():
k.clear_session()
# 创建一个新模型
model = Sequential()
model.add(Conv2D(32, 3, padding='same', activation='relu', input_shape=(64, 64, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, 3, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, 3, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(20, activation='softmax'))
model.summary()
# 选择优化器和损失函数
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
三、训练与测试
3.1 模型训练
使用Keras库和自定义的CNN模型进行手写汉字识别的训练。首先,设置了一些关键的训练参数,如训练周期、选择模型、迭代器等参数。
epochs = 8 # 选择批次
model = get_model.get_model() # 选择模型
# 加载训练数据和测试数据
(train_image, val_image, train_label, val_label) = get_array_1.load_data('data/train/')
(test_image, test_label) = get_array_2.load_data('data/test/')
# 训练, fit方法自带shuffle随机读取
history = model.fit(
train_image, train_label, epochs=epochs, validation_data=(val_image, val_label))
# 测试, 单用evaluate方法不会自动输出数值,需要手动输出他返回的两个数值
test_scores = model.evaluate(test_image, test_label)
epochs_range = range(1, epochs+1)
train_loss = history.history['loss']
val_loss = history.history['val_loss']
test_loss = test_scores[0]
train_accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']
test_accuracy = test_scores[1]
# 将模型保存为 HDF5 文件
model.save('Chinese_recognition_model.h5')
print("save model: Chinese_recognition_model.h5")
# 绘制图表
get_pyplot.show(epochs_range, train_loss, val_loss, train_accuracy, val_accuracy)
# 打印得分
print('')
print('train loss:', train_loss[-1], ' ', 'train accuracy:', train_accuracy[-1])
print('val loss:', val_loss[-1], ' ', 'val accuracy:', val_accuracy[-1])
print('test loss:', test_loss, ' ', 'test accuracy:', test_accuracy)
print('')
然后使用fit.py
文件进行训练,整个训练过程的损失、准确率如下图所示。
3.2 模型测试
使用predict.py
对手写汉字图像进行识别测试,具体代码实现如下:
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.image as mpimg
class_names = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九', '肇', '庆', '学', '院',
'计', '算', '机', '杨', '先', '生']
model = tf.keras.models.load_model('Chinese_recognition_model.h5')
img = mpimg.imread('data/test/r/29.png')
img2 = cv2.resize(img, (64, 64))
img3 = np.zeros((1, img2.shape[0], img2.shape[1], img2.shape[2]))
img3[0, :] = img2
pre = model.predict(img3) # 预测
pre = np.argmax(pre, axis=1)
result = class_names[pre[0]]
print('预测结果:', result)
测试结果:
四、PyQt界面实现
当整个项目构建完成后,使用PyQt5编写可视化界面,可以支持输入手写汉字图像进行识别。
整个界面显示代码如下:
class Ui_MainWindow(object):
def setupUi(self, MainWindow):
MainWindow.setObjectName("MainWindow")
MainWindow.resize(745, 590)
MainWindow.setStyleSheet("background-color: rgb(255, 255, 255);")
self.centralwidget = QtWidgets.QWidget(MainWindow)
self.centralwidget.setObjectName("centralwidget")
self.label = QtWidgets.QLabel(self.centralwidget)
self.label.setGeometry(QtCore.QRect(200, 30, 411, 81))
self.label.setStyleSheet("font: 28pt \"黑体\";")
self.label.setObjectName("label")
self.label_2 = QtWidgets.QLabel(self.centralwidget)
self.label_2.setGeometry(QtCore.QRect(80, 180, 271, 261))
self.label_2.setStyleSheet("background-color: rgb(234, 234, 234);")
self.label_2.setText("")
self.label_2.setObjectName("label_2")
self.pushButton = QtWidgets.QPushButton(self.centralwidget)
self.pushButton.setGeometry(QtCore.QRect(160, 470, 101, 51))
self.pushButton.setStyleSheet("background-color: rgb(226, 226, 226);\n"
"font: 12pt \"黑体\";")
self.pushButton.setObjectName("pushButton")
self.label_3 = QtWidgets.QLabel(self.centralwidget)
self.label_3.setGeometry(QtCore.QRect(410, 170, 191, 81))
self.label_3.setStyleSheet("font: 22pt \"黑体\";\n"
"background-color: transparent;")
self.label_3.setObjectName("label_3")
self.label_4 = QtWidgets.QLabel(self.centralwidget)
self.label_4.setGeometry(QtCore.QRect(540, 290, 81, 71))
self.label_4.setStyleSheet("font: 40pt \"黑体\";\n"
"background-color: rgb(234, 234, 234);\n"
"background-color: transparent;")
self.label_4.setObjectName("label_4")
self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
self.pushButton_2.setGeometry(QtCore.QRect(520, 470, 101, 51))
self.pushButton_2.setStyleSheet("background-color: rgb(226, 226, 226);\n"
"font: 12pt \"黑体\";")
self.pushButton_2.setObjectName("pushButton_2")
self.label_5 = QtWidgets.QLabel(self.centralwidget)
self.label_5.setGeometry(QtCore.QRect(490, 260, 151, 141))
self.label_5.setStyleSheet("background-color: rgb(234, 234, 234);")
self.label_5.setText("")
self.label_5.setObjectName("label_5")
self.label_5.raise_()
self.label.raise_()
self.label_2.raise_()
self.pushButton.raise_()
self.label_3.raise_()
self.label_4.raise_()
self.pushButton_2.raise_()
MainWindow.setCentralWidget(self.centralwidget)
self.statusbar = QtWidgets.QStatusBar(MainWindow)
self.statusbar.setObjectName("statusbar")
MainWindow.setStatusBar(self.statusbar)
self.retranslateUi(MainWindow)
QtCore.QMetaObject.connectSlotsByName(MainWindow)
self.pushButton.clicked.connect(self.open_img) # 图片选择按钮 连接open_img函数
self.pushButton_2.clicked.connect(self.detect)
def open_img(self):
self.fname, _ = QFileDialog.getOpenFileName(None, 'open file', '', "*.jpg;*.png;;All Files(*)")
print(self.fname)
img = cv2.imdecode(np.fromfile(self.fname, dtype=np.uint8), -1)
# img = cv2.imread(self.fname)
show = cv2.resize(img, (271, 261))
cv2.imwrite('./linshi.png', show)
self.label_2.setStyleSheet("image: url(./linshi.png)")
def detect(self):
class_names = ['零', '一', '二', '三', '四', '五', '六', '七', '八', '九', '肇', '庆', '学', '院',
'计', '算', '机', '杨', '先', '生']
model = tf.keras.models.load_model('Chinese_recognition_model.h5')
img = mpimg.imread(self.fname)
img2 = cv2.resize(img, (64, 64))
img3 = np.zeros((1, img2.shape[0], img2.shape[1], img2.shape[2]))
img3[0, :] = img2
pre = model.predict(img3) # 预测
pre = np.argmax(pre,axis=1)
result = class_names[pre[0]]
print('预测结果:', result)
self.label_4.setText(result)
def retranslateUi(self, MainWindow):
_translate = QtCore.QCoreApplication.translate
MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
self.label.setText(_translate("MainWindow", "手写汉字识别系统"))
self.pushButton.setText(_translate("MainWindow", "选择图像"))
self.label_3.setText(_translate("MainWindow", "识别结果:"))
self.label_4.setText(_translate("MainWindow", ""))
self.pushButton_2.setText(_translate("MainWindow", "开始检测"))
界面显示效果:
参考资料
- 手写中文数据集
- CNN实现手写数字识别(Pytorch)