目录
- 1. 说明
- 2. 手写数字识别的ANN模型测试
- 2.1 导入相关库
- 2.2 加载数据和模型
- 2.3 设置保存图片的路径
- 2.4 加载图片
- 2.5 图片预处理
- 2.6 对图片进行预测
- 2.7 显示图片
- 3. 完整代码和显示结果
- 4. 多张图片进行测试的完整代码以及结果
1. 说明
本篇文章是对上篇文章训练的模型进行测试。首先是将训练好的模型进行重新加载,然后采用opencv对图片进行加载,最后将加载好的图片输送给模型并且显示结果。
2. 手写数字识别的ANN模型测试
2.1 导入相关库
在这里导入需要的第三方库如cv2,如果没有,则需要自行下载。
from tensorflow import keras
# 引入内置手写体数据集mnist
from keras.datasets import mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw # PIL就是pillow包(保存图像)
import numpy as np
2.2 加载数据和模型
把MNIST数据集进行加载,并且把训练好的模型也加载进来。
# 加载mnist数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 加载ann_mnist.h5文件,重新生成模型对象, 等价于之前训练好的ann_model
recons_model = keras.models.load_model('ann_mnist.h5')
2.3 设置保存图片的路径
将数据集的某个数据以图片的形式进行保存,便于测试的可视化。
在这里设置图片存储的位置。
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test1.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[1]).save(test_file_path)
在书写完上述代码后,需要在代码的当前路径下新建一个imgs的文件夹用于存储图片,如下。
执行完上述代码后就会在imgs的文件中可以发现多了一张图片,如下(下面测试了很多次)。
2.4 加载图片
采用cv2对图片进行加载,下面最后一行代码取一个通道的原因是用opencv库也就是cv2读取图片的时候,图片是三通道的,而训练的模型是单通道的,因此取单通道。
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
2.5 图片预处理
对图片进行预处理,即进行归一化处理和改变形状处理,这是为了便于将图片输入给训练好的模型进行预测。
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 784)
2.6 对图片进行预测
将图片输入给训练好我的模型并且进行预测。
预测的结果是10个概率值,所以需要进行处理, np.argmax()是得到概率值最大值的序号,也就是预测的数字。
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类数字
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别/手写体数字:', class_id)
class_id = str(class_id)
2.7 显示图片
对预测的图片进行显示,把预测的数字显示在图片上。
下面6行代码分别是创建窗口,设定窗口大小,显示数字,显示图片,停留图片,清除内存。
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500) # 自己设定窗口图片的大小
cv2.putText(image, class_id, (2, 5), cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 0.2, (255, 0, 0), 1)
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()
3. 完整代码和显示结果
以下是完整的代码和图片显示结果。
from tensorflow import keras
# 引入内置手写体数据集mnist
from keras.datasets import mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw # PIL就是pillow包(保存图像)
import numpy as np
# 加载mnist数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 加载ann_mnist.h5文件,重新生成模型对象, 等价于之前训练好的ann_model
recons_model = keras.models.load_model('ann_mnist.h5')
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test1.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[1]).save(test_file_path)
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 784)
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类数字
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别/手写体数字:', class_id)
class_id = str(class_id)
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500) # 自己设定窗口图片的大小
cv2.putText(image, class_id, (2, 5), cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 0.2, (255, 0, 0), 1)
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()
4. 多张图片进行测试的完整代码以及结果
为了测试更多的图片,引入循环进行多次测试,效果更好。
# python练习
# 重新学习时间:2023/4/30 23:45
from tensorflow import keras
# 引入内置手写体数据集mnist
from keras.datasets import mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw # PIL就是pillow包(保存图像)
import numpy as np
# 加载mnist数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 加载ann_mnist.h5文件,重新生成模型对象, 等价于之前训练好的ann_model
recons_model = keras.models.load_model('ann_mnist.h5')
prepicture = int(input("input the number of test picture :"))
for i in range(prepicture):
path1 = input("input the test picture path:")
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', path1)
# 存储测试数据的任意一个
num = int(input("input the test picture num:"))
Image.fromarray(x_test[num]).save(test_file_path)
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 784)
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类数字
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别/手写体数字:', class_id)
class_id = str(class_id)
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500) # 自己设定窗口图片的大小
cv2.putText(image, class_id, (2, 5), cv2.FONT_HERSHEY_SCRIPT_SIMPLEX, 0.2, (255, 0, 0), 1)
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()
下面的test picture num指的是数据集中该数据的序号(0-59999),并不是值实际的数字。
2023-07-18 21:24:54.034234: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
input the number of test picture :5
input the test picture path:61.jpg
input the test picture num:1
1/1 [==============================] - 0s 212ms/step
test.png的预测概率: [[6.7599565e-11 5.6974045e-08 9.9999976e-01 1.4167172e-08 4.2876313e-14
8.5433702e-17 9.8270281e-12 2.0837895e-07 2.0044362e-13 3.8371804e-15]]
test.png的预测概率: 0.99999976
test.png的所属类别/手写体数字: 2
input the test picture path:62.jpg
input the test picture num:2
1/1 [==============================] - 0s 25ms/step
test.png的预测概率: [[2.95021305e-08 9.99796808e-01 5.78483643e-08 1.15721946e-07
1.02379022e-06 1.07751411e-07 1.75613415e-04 1.84143373e-05
7.72468411e-06 8.39250518e-08]]
test.png的预测概率: 0.9997968
test.png的所属类别/手写体数字: 1
input the test picture path:63.jpg
input the test picture num:3
1/1 [==============================] - 0s 26ms/step
test.png的预测概率: [[9.9962425e-01 7.8167646e-11 6.5924123e-06 9.7057705e-07 2.3867991e-11
3.1169588e-04 5.6094854e-05 9.8954046e-11 1.0871034e-08 3.3060348e-07]]
test.png的预测概率: 0.99962425
test.png的所属类别/手写体数字: 0
input the test picture path:64.jpg
input the test picture num:4
1/1 [==============================] - 0s 30ms/step
test.png的预测概率: [[1.3954380e-09 5.2584750e-07 7.7287673e-08 2.3394799e-08 9.9983513e-01
4.9446136e-10 1.9493827e-06 4.0978726e-08 3.1354301e-07 1.6186526e-04]]
test.png的预测概率: 0.99983513
test.png的所属类别/手写体数字: 4
input the test picture path:65.jpg
input the test picture num:5
1/1 [==============================] - 0s 47ms/step
test.png的预测概率: [[4.70661676e-10 9.99986053e-01 5.76763526e-10 1.16811161e-09
5.13054097e-08 5.98078254e-10 1.21732055e-05 1.10577037e-06
5.98011809e-07 1.74244752e-09]]
test.png的预测概率: 0.99998605
test.png的所属类别/手写体数字: 1