目录
- 1 介绍
- 2 数据爬虫
- 3 模型训练和验证
- 3.1 模型训练
- 3.2 导入一张图片进行验证
- 4 后台flask部署
- 5 微信小程序
1 介绍
本项目使用深度学习模型,训练5种中药材数据集,然后将其集成到微信小程序,通过微信小程序拍照,将图片传输给后端,后端将返回的结果展示到前端页面,项目主要包含以下内容:
- 数据爬取:使用爬虫爬取百度图片,可以自己定义要爬取的中草药种类、数量等信息。
- 模型训练使用基于keras训练分类模型,模型可以修改,例如:ResNet50系列,MobileNet系列等,支持在gpu、cpu训练。
- 后台flask部署:使用flask将模型部署到后台,提供ip地址和端口号
- 前端微信小程序:制作前端的微信小程序页面,将图片传输给后端,并且将分类结果返回到前端展示
2 数据爬虫
使用requests进行爬虫
示例:
for i in range(30):
image_url = result['data'][i]['middleURL']
image_name = "%d.jpg" % count
response = requests.get(image_url, headers=headers, stream=True, timeout=10)
with open(os.path.join(download_path, image_name), 'wb') as f:
f.write(response.content)
count += 1
爬取输入参数,可以自己输入爬取哪些中草药,输入到list里面即可,下面展示只爬取两种中草药。
# 设置搜索关键字和爬取图片的数量
name_list = ['枸杞','金银花']
save_path = "data_爬虫"
page_num = 1 #爬取多少页,每页30个
for keyword in name_list:
get_images(save_path, keyword, page_num)
3 模型训练和验证
此处,我们分别使用keras版本进行训练和验证,具体代码和结果展示如下:
3.1 模型训练
导入必要的包
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.applications import MobileNetV2
from keras.layers import GlobalAveragePooling2D, Dense
from keras.models import Sequential
import json
# 定义ImageDataGenerator
datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
validation_split=0.2 # 设置验证集的比例
)
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
model = Sequential([
base_model,
GlobalAveragePooling2D(),
Dense(128, activation='relu'),
Dense(num_classes, activation='softmax')
])
# 训练模型
model.fit(
train_generator,
steps_per_epoch=train_generator.samples // batch_size,
epochs=10,
validation_data=validation_generator,
validation_steps=validation_generator.samples // batch_size)
# 步骤6: 评估模型性能
eval_result = model.evaluate(validation_generator)
print(f"Test accuracy: {eval_result[1]*100:.2f}%")
部分结果截图
36/36 [==============================] - 22s 449ms/step - loss: 0.7144 - accuracy: 0.7664 - val_loss: 0.7706 - val_accuracy: 0.7278
Epoch 2/10
36/36 [==============================] - 13s 352ms/step - loss: 0.1504 - accuracy: 0.9601 - val_loss: 0.5325 - val_accuracy: 0.8278
Epoch 3/10
36/36 [==============================] - 13s 352ms/step - loss: 0.0959 - accuracy: 0.9829 - val_loss: 0.2743 - val_accuracy: 0.9222
Epoch 4/10
36/36 [==============================] - 13s 351ms/step - loss: 0.0896 - accuracy: 0.9758 - val_loss: 0.3960 - val_accuracy: 0.8500
Epoch 5/10
36/36 [==============================] - 13s 354ms/step - loss: 0.0743 - accuracy: 0.9758 - val_loss: 0.2853 - val_accuracy: 0.9111
Epoch 6/10
36/36 [==============================] - 13s 351ms/step - loss: 0.0525 - accuracy: 0.9829 - val_loss: 0.2473 - val_accuracy: 0.9222
3.2 导入一张图片进行验证
导入图片
import cv2
import numpy as np
import json
from keras.models import load_model
def get_img(img_path,img_width, img_height ):
img = cv2.imread(img_path)
img = cv2.resize(img, (img_width, img_height)) # 调整图像大小
img = img.astype("float") / 255.0 # 数据预处理,确保与训练时一致
img = np.expand_dims(img, axis=0)
return img
img_width = 224
img_height = 224
model = load_model(r'E:\project\1-zhongcaoyao\model-keras.h5')
print(class_indict)
img_file_path = 'data_all/baihe/b (20).jpg'
classify_img = get_img(img_file_path,img_width, img_height)
results = np.squeeze(model.predict(classify_img)).astype(np.float64) # 获得预测结果(注意:1.降维2.json中的小数类型为float)
predict_class = np.argmax(results) # 获得预测结果中置信度最大值所对应的下标
例如:我们导入一张百合的图片,下面是输出结果。
注意,可能会出现如下错误,原因是模型路径包含中文名称,只需要把模型放到全英文路径下就行。
DecodeError: 'utf-8' codec can't decode byte 0xc6 in position 10: invalid continuation byte
4 后台flask部署
app = flask.Flask(__name__)
idx2class = {0:"百合",1:"党参",2:"枸杞",3:"槐花",4:"金银花"}
idx2info ={}
# 导入药效信息
with open("info.txt", "r", encoding="UTF-8") as fin:
lines = fin.readlines()
for line in lines:
idx = int(line.strip().split(":")[0])
info = line.strip().split(":")[1]
idx2info[idx] = info
img_bytes = flask.request.form.get('picture') # 获取值
image = base64.b64decode(img_bytes)# 编码转换
image = Image.open(io.BytesIO(image))
classify_img = prepare_image(image,224,224) # 预处理图像
results = np.squeeze(model.predict(classify_img)).astype(np.float64) # 获得预测结果(注意:1.降维2.json中的小数类型为float)
predicted_idx = np.argmax(results) # 获得预测结果中置信度最大值所对应的下标
score = results[predicted_idx]
label_name = idx2class[predicted_idx]
label_info = idx2info[predicted_idx]
5 微信小程序
我们使用一个界面,完成图片的上传,结果展示等
核心代码,将图片传输到后台,并且将data结果拿回来,再解析里面的各个字段,最后将字段展示出来。
wx.request({
url: 'http://127.0.0.1:8080/predict', //本地服务器地址
method: 'POST',
header: {
'content-type': 'application/x-www-form-urlencoded'
},
data: {
"picture": that.data.picture,
},
success: (res)=>{
that.setData({
class_name: res.data['class_name'],
prob: res.data['prob'],
info:res.data['info']
})
以上就是所有的内容,包含了前端后端、模型训练、数据爬取等功能,详细咨询完整代码:https://docs.qq.com/doc/DWEtRempVZ1NSZHdQ