python实现——分类类型数据挖掘任务(图形识别分类任务)

  1. 分类类型数据挖掘任务

基于卷积神经网络(CNN)的岩石图像分类。有一岩石图片数据集,共300张岩石图片,图片尺寸224x224。岩石种类有砾岩(Conglomerate)、安山岩(Andesite)、花岗岩(Granite)、石灰岩(Limestone)、石英岩(Quartzite)和5种,每种岩石图片各50张,共250张。请选择合适模型对该数据集进行建模,训练优化模型并给出模型评估指标,再利用GUI框架开发岩石图片分类界面。

1.1总体流程

1.2数据增强

定义:数据增强是利用现有数据生成新的数据来增加数据量的过程,能够有效地扩充训练数据集的大小,提高模型的泛化能力,同时也能够有效地防止过拟合现象的发生。

本项目采用的数据增强方法:

(1)水平翻转

(2)缩放

(3)旋转

(4)添加高斯噪音

(5)调整对比度和亮度

通过数据增强,数据集从之前的250张扩充至1500张,数据量为之前的6倍。

参考代码:

import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):
    rows, cols, _ = img.shape
    # 水平翻转图像
    img_flip = cv2.flip(img, 1)
    img_name = os.path.splitext(save_path)[0] + "_flip.jpg"
    cv2.imwrite(img_name, img_flip)
    print("Saved augmented image:", img_name)
    # 随机缩放图像
    scale = np.random.uniform(0.9, 1.1)
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)
    img_transformed = cv2.warpAffine(img, M, (cols, rows))
    img_name = os.path.splitext(save_path)[0] + "_transform.jpg"
    cv2.imwrite(img_name, img_transformed)
    print("Saved augmented image:", img_name)
    # 随机旋转图像
    angle = np.random.randint(-10, 10)
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    img_rotated = cv2.warpAffine(img, M, (cols, rows))
    img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"
    cv2.imwrite(img_name, img_rotated)
    print("Saved augmented image:", img_name)
    # 添加高斯噪音
    mean = 0
    std = np.random.uniform(5, 15)
    noise = np.zeros(img.shape, np.float32)
    cv2.randn(noise, mean, std)
    noise = np.uint8(noise)
    img_noisy = cv2.add(img, noise)
    img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"
    cv2.imwrite(img_name, img_noisy)
    print("Saved augmented image:", img_name)
    # 随机调整对比度和亮度
    alpha = np.random.uniform(0.8, 1.2)
    beta = np.random.randint(-10, 10)
    img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)
    img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"
    cv2.imwrite(img_name, img_contrast)
    print("Saved augmented image:", img_name)
    return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):
    img = cv2.imread(img_path)
    if img is None:
        print("Error: Unable to read image at", img_path)
        continue
    # 获取保存增强后的图片文件名
    img_name = os.path.basename(img_path)
    save_path = os.path.join(save_dir, img_name)
    # 数据增强
    augmented_img = augment_data(img, save_path)
    if augmented_img is not None:
        # 保存原始图片
        cv2.imwrite(save_path, img)
        print("Saved original image:", save_path)

 结果:

1.3数据预处理

将1500张图片依次读入并转化为可训练的数据(特征变量(X)和标签(Y))

代码:

import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)

# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))

i=0
for name in categories:
    img = Image.open(image_folder + '\\' +name)
    img_rgb = img.split()
    X_list[i,:,:,0] = np.array(img_rgb[0])/255
    X_list[i,:,:,1] = np.array(img_rgb[1])/255
    X_list[i,:,:,2] = np.array(img_rgb[2])/255
    Y_list[i] = name.split('_')[0]
    i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)

# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)

1.4模型构建

1.4.1模型结构定义

模型参数:

参考代码:

from sklearn.model_selection import train_test_split
import seaborn as sns  
import matplotlib.pyplot as plt  
import tensorflow as tf
from sklearn.metrics import confusion_matrix  
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别  
num_classes = 5  
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)  
input_shape = (224, 224, 3)  
# 假设X和Y是您的原始数据  
# X: 图像数据,形状为(num_samples, 224, 224, 3)  
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)  
# 将数据划分为训练集和测试集(只执行一次)  
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)  
# 构建模型  
model = tf.keras.models.Sequential([  
    tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape),  
    tf.keras.layers.MaxPooling2D((2,2), strides=2),  
    tf.keras.layers.Conv2D(16, (5,5), activation='relu'),  
    tf.keras.layers.MaxPooling2D((2,2), strides=2),  
    tf.keras.layers.Conv2D(120, (5,5), activation='relu'),  
    tf.keras.layers.Flatten(),  
    tf.keras.layers.Dense(84, activation='relu'),  
    tf.keras.layers.Dropout(0.3),  
    tf.keras.layers.Dense(num_classes, activation='softmax')  # 确保输出层的神经元数量与类别数量匹配  
])  
  
# 编译模型  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数   
              optimizer=tf.keras.optimizers.Adam(),  # 使用Adam优化器  
              metrics=['sparse_categorical_accuracy'])  # 监控准确率  
  
# 打印模型概述  
model.summary()  
  
# 使用model.fit()函数训练模型  
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)  

 

1.4.2模型译

编译参数参考:

# 优化器

optimizer='adam'

# 损失函数

loss='sparse_categorical_crossentropy'

# 评估指标

metrics=['sparse_categorical_accuracy']

1.5模型训练

1.5.1划分训练集和测试集

按照训练集:测试集=8:2的比例对数据集进行划分,建议使用sklearn库中的train_test_split函数。

1.5.2训练

使用fit函数对训练集进行拟合训练,并将训练过程中产生的历史数据history保存至变量中。

训练参数参考:

# 迭代次数

epochs=20

# 验证集比例

validation_split=0.2

1.5.3训练过程可视化

对history中保存下来的训练过程中的loss和sparse_categorical_accuracy的变化情况进行绘图。

参考代码:

# 获取训练和验证的准确率和损失  
acc = history.history['sparse_categorical_accuracy']  
val_acc = history.history['sparse_categorical_accuracy']  
loss = history.history['loss']  
val_loss = history.history['val_loss']  
  
# 使用model.evaluate()函数评估模型在测试集上的性能  
test_loss, test_accuracy = model.evaluate(x_test, y_test)  
print(f'Test accuracy: {test_accuracy}')  
  
# 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')

plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()

plt.rcParams['font.sans-serif'] = ['SimHei'] 
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,
            cmap="Blues",
            cbar=False,
            linewidths=2,
            linecolor='white',
            square=True,
            xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],
            yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩']
            )
plt.show

 

 

1.6.3保存模型

使用save函数对训练好的模型进行保存,方便后续使用。

参考代码:

model.save('roch_classification_cnn.h5')

1.7图形用户界面(GUI)开发

1.7.1配置开发工具

在PyCharm中配置QtDesigner和PyUIC工具。

注意:需提前在python环境中安装好PyQt5和PyQt5-tools库。

  1. 配置QtDesigner

Program:(对应designer.exe的路径)

Working directory: $FileDir$

  1. 配置PyUCI

Program:(对应pyuic5.exe的路径)

Arguments: $FileName$ -o $FileNameWithoutExtension$.py

Working directory: $FileDir$

配置完成后的界面:

1.7.2设计图形用户界面

在PyCharm中“Tools”—“External Tools”中打开QtDesigner

在QtDesigner主界面中选择创建Main Window,然后根据需求选择相应的控件进行设计。

设计界面参考:

设计好之后保存为.ui文件。

1.7.3 ui文件转换为代码

在PyCharm中右键点击.ui文件并使用PyUCI工具进行转换。

1.7.4代码与模型结合

将转化后的代码与之前训练的模型相结合。

参考代码:

# -*- coding: utf-8 -*-
import os

from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):
    def setupUi(self, MainWindow):
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(800, 600)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))
        self.label.setScaledContents(False)
        self.label.setObjectName("label")
        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))
        self.pushButton_2.setObjectName("pushButton_2")
        self.label_2 = QtWidgets.QLabel(self.centralwidget)
        self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))
        self.label_2.setText("")
        self.label_2.setObjectName("label_2")
        self.label_3 = QtWidgets.QLabel(self.centralwidget)
        self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))
        self.label_3.setText("")
        self.label_3.setObjectName("label_3")
        self.label_4 = QtWidgets.QLabel(self.centralwidget)
        self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))
        self.label_4.setObjectName("label_4")
        self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))
        self.textBrowser.setObjectName("textBrowser")
        self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))
        self.textBrowser_2.setObjectName("textBrowser_2")
        self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))
        self.textBrowser_3.setObjectName("textBrowser_3")
        self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))
        self.textBrowser_4.setObjectName("textBrowser_4")
        self.textBrowser_2.raise_()
        self.label.raise_()
        self.textBrowser.raise_()
        self.textBrowser_3.raise_()
        self.pushButton.raise_()
        self.pushButton_2.raise_()
        self.label_2.raise_()
        self.label_4.raise_()
        self.textBrowser_4.raise_()
        self.label_3.raise_()
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.toolBar = QtWidgets.QToolBar(MainWindow)
        self.toolBar.setObjectName("toolBar")
        MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)
        # 模型相关变量初始化
        self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')
        self.path = ''
        self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']
        # 将“导入图片”按钮与openImage函数绑定
        self.pushButton.clicked.connect(self.openImage)
        # 将“岩石分类”按钮与classify函数绑定
        self.pushButton_2.clicked.connect(self.classify)
    def retranslateUi(self, MainWindow):
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.label.setText(_translate("MainWindow", "岩石图像分类"))
        self.pushButton.setText(_translate("MainWindow", "导入图像"))
        self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))
        self.label_4.setText(_translate("MainWindow", "分类结果"))
        self.textBrowser_3.setHtml(_translate("MainWindow",
                                              "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
                                              "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
                                              "p, li { white-space: pre-wrap; }\n"
                                              "</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
                                              "<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))
        self.textBrowser_4.setHtml(_translate("MainWindow",
                                              "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
                                              "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
                                              "p, li { white-space: pre-wrap; }\n"
                                              "</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
                                              "<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))
        self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
    # 导入图片函数

    def resource_path(relative):
        if hasattr(sys, "_MEIPASS"):
            absolute_path = os.path.join(sys._MEIPASS, relative)
        else:
            absolute_path = os.path.join(relative)
        return absolute_path

    # 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))
    def openImage(self):
        imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")
        jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())
        self.label_2.setPixmap(jpg)
        self.path=imgPath
        self.label_3.setText('')
    def classify(self):
        img = Image.open(self.path)  # 读取图像
        img_rgb = img.split()
        x = np.zeros((1, 224, 224, 3))
        x[0,:, :, 0] = np.array(img_rgb[0]) / 255
        x[0,:, :, 1] = np.array(img_rgb[1]) / 255
        x[0,:, :, 2] = np.array(img_rgb[2]) / 255
        y = self.model.predict(x)
        result = self.rock_types[np.argmax(y)]
        self.label_3.setText(result)
if __name__=='__main__':
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app=QtWidgets.QApplication(sys.argv)
    MainWindow=QtWidgets.QMainWindow()
    ui_test=Ui_MainWindow()
    ui_test.setupUi(MainWindow)
    MainWindow.show()
    sys.exit(app.exec_())

1.7.5测试

执行程序测试“导入图片”和“鉴定分类”功能。

1.8打包可执行文件(exe)

在命令窗口中使用如下指令对上一步的程序进行打包。

Pyinstaller -F -w xxxxx.py

运行生成的.exe文件并测试功能。

打完包之后可能出现错误

报错信息:

=============================================================

A RecursionError (maximum recursion depth exceeded) occurred.

For working around please follow these instructions

=============================================================

1. In your program's .spec file add this line near the top::

     import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

2. Build your program by running PyInstaller with the .spec file as

   argument::

     pyinstaller myprog.spec

3. If this fails, you most probably hit an endless recursion in

   PyInstaller. Please try to track this down has far as possible,

   create a minimal example so we can reproduce and open an issue at

   https://github.com/pyinstaller/pyinstaller/issues following the

   instructions in the issue template. Many thanks.

Explanation: Python's stack-limit is a safety-belt against endless recursion,

eating up memory. PyInstaller imports modules recursively. If the structure

how modules are imported within your program is awkward, this leads to the

nesting being too deep and hitting Python's stack-limit.

With the default recursion limit (1000), the recursion error occurs at about

115 nested imported, with limit 2000 at about 240, with limit 5000 at about

660.

————————————————

你打包目录下会生成如下文件

打开你的main.spec文件

在顶端添加代码:

import sys

sys.setrecursionlimit(sys.getrecursionlimit() * 5)

然后在运行命令(对应的文件名)

pyinstaller 你的文件名.spec

然后就完成了

打完包之的运行闪退问题:

先安装一个新的第三方库ordereddict

安装命令:

pip install ordereddict

注意自己python代码的文件引入路径(确保对应的路径下有对应的文件,我这里设置的是根目录下)

重新打包

完成之后

打开对应的文件夹双击就可以了

完整代码:

import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):
    rows, cols, _ = img.shape
    # 水平翻转图像
    img_flip = cv2.flip(img, 1)
    img_name = os.path.splitext(save_path)[0] + "_flip.jpg"
    cv2.imwrite(img_name, img_flip)
    print("Saved augmented image:", img_name)
    # 随机缩放图像
    scale = np.random.uniform(0.9, 1.1)
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)
    img_transformed = cv2.warpAffine(img, M, (cols, rows))
    img_name = os.path.splitext(save_path)[0] + "_transform.jpg"
    cv2.imwrite(img_name, img_transformed)
    print("Saved augmented image:", img_name)
    # 随机旋转图像
    angle = np.random.randint(-10, 10)
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    img_rotated = cv2.warpAffine(img, M, (cols, rows))
    img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"
    cv2.imwrite(img_name, img_rotated)
    print("Saved augmented image:", img_name)
    # 添加高斯噪音
    mean = 0
    std = np.random.uniform(5, 15)
    noise = np.zeros(img.shape, np.float32)
    cv2.randn(noise, mean, std)
    noise = np.uint8(noise)
    img_noisy = cv2.add(img, noise)
    img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"
    cv2.imwrite(img_name, img_noisy)
    print("Saved augmented image:", img_name)
    # 随机调整对比度和亮度
    alpha = np.random.uniform(0.8, 1.2)
    beta = np.random.randint(-10, 10)
    img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)
    img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"
    cv2.imwrite(img_name, img_contrast)
    print("Saved augmented image:", img_name)
    return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):
    img = cv2.imread(img_path)
    if img is None:
        print("Error: Unable to read image at", img_path)
        continue
    # 获取保存增强后的图片文件名
    img_name = os.path.basename(img_path)
    save_path = os.path.join(save_dir, img_name)
    # 数据增强
    augmented_img = augment_data(img, save_path)
    if augmented_img is not None:
        # 保存原始图片
        cv2.imwrite(save_path, img)
        print("Saved original image:", save_path)
#%%
import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)

# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))

i=0
for name in categories:
    img = Image.open(image_folder + '\\' +name)
    img_rgb = img.split()
    X_list[i,:,:,0] = np.array(img_rgb[0])/255
    X_list[i,:,:,1] = np.array(img_rgb[1])/255
    X_list[i,:,:,2] = np.array(img_rgb[2])/255
    Y_list[i] = name.split('_')[0]
    i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)

# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)
#%%
from sklearn.model_selection import train_test_split
import seaborn as sns  
import matplotlib.pyplot as plt  
import tensorflow as tf
from sklearn.metrics import confusion_matrix  
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别  
num_classes = 5  
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)  
input_shape = (224, 224, 3)  
# 假设X和Y是您的原始数据  
# X: 图像数据,形状为(num_samples, 224, 224, 3)  
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)  
# 将数据划分为训练集和测试集(只执行一次)  
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)  
# 构建模型  
model = tf.keras.models.Sequential([  
    tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape),  
    tf.keras.layers.MaxPooling2D((2,2), strides=2),  
    tf.keras.layers.Conv2D(16, (5,5), activation='relu'),  
    tf.keras.layers.MaxPooling2D((2,2), strides=2),  
    tf.keras.layers.Conv2D(120, (5,5), activation='relu'),  
    tf.keras.layers.Flatten(),  
    tf.keras.layers.Dense(84, activation='relu'),  
    tf.keras.layers.Dropout(0.3),  
    tf.keras.layers.Dense(num_classes, activation='softmax')  # 确保输出层的神经元数量与类别数量匹配  
])  
  
# 编译模型  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数   
              optimizer=tf.keras.optimizers.Adam(),  # 使用Adam优化器  
              metrics=['sparse_categorical_accuracy'])  # 监控准确率  
  
# 打印模型概述  
model.summary()  
  
# 使用model.fit()函数训练模型  
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)  

#%%
y_pred = model.predict(x_test) 
print(y_pred)
#%%

#%%
# 获取训练和验证的准确率和损失  
acc = history.history['sparse_categorical_accuracy']  
val_acc = history.history['sparse_categorical_accuracy']  
loss = history.history['loss']  
val_loss = history.history['val_loss']  
  
# 使用model.evaluate()函数评估模型在测试集上的性能  
test_loss, test_accuracy = model.evaluate(x_test, y_test)  
print(f'Test accuracy: {test_accuracy}')  
  
# 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')

plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()

plt.rcParams['font.sans-serif'] = ['SimHei'] 
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,
            cmap="Blues",
            cbar=False,
            linewidths=2,
            linecolor='white',
            square=True,
            xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],
            yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩']
            )
plt.show
#%%
model.save('roch_classification_cnn.h5')

# -*- coding: utf-8 -*-
import os

from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):
    def setupUi(self, MainWindow):
        MainWindow.setObjectName("MainWindow")
        MainWindow.resize(800, 600)
        self.centralwidget = QtWidgets.QWidget(MainWindow)
        self.centralwidget.setObjectName("centralwidget")
        self.label = QtWidgets.QLabel(self.centralwidget)
        self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))
        self.label.setScaledContents(False)
        self.label.setObjectName("label")
        self.pushButton = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)
        self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))
        self.pushButton_2.setObjectName("pushButton_2")
        self.label_2 = QtWidgets.QLabel(self.centralwidget)
        self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))
        self.label_2.setText("")
        self.label_2.setObjectName("label_2")
        self.label_3 = QtWidgets.QLabel(self.centralwidget)
        self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))
        self.label_3.setText("")
        self.label_3.setObjectName("label_3")
        self.label_4 = QtWidgets.QLabel(self.centralwidget)
        self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))
        self.label_4.setObjectName("label_4")
        self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))
        self.textBrowser.setObjectName("textBrowser")
        self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))
        self.textBrowser_2.setObjectName("textBrowser_2")
        self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))
        self.textBrowser_3.setObjectName("textBrowser_3")
        self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)
        self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))
        self.textBrowser_4.setObjectName("textBrowser_4")
        self.textBrowser_2.raise_()
        self.label.raise_()
        self.textBrowser.raise_()
        self.textBrowser_3.raise_()
        self.pushButton.raise_()
        self.pushButton_2.raise_()
        self.label_2.raise_()
        self.label_4.raise_()
        self.textBrowser_4.raise_()
        self.label_3.raise_()
        MainWindow.setCentralWidget(self.centralwidget)
        self.menubar = QtWidgets.QMenuBar(MainWindow)
        self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))
        self.menubar.setObjectName("menubar")
        MainWindow.setMenuBar(self.menubar)
        self.statusbar = QtWidgets.QStatusBar(MainWindow)
        self.statusbar.setObjectName("statusbar")
        MainWindow.setStatusBar(self.statusbar)
        self.toolBar = QtWidgets.QToolBar(MainWindow)
        self.toolBar.setObjectName("toolBar")
        MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)

        self.retranslateUi(MainWindow)
        QtCore.QMetaObject.connectSlotsByName(MainWindow)
        # 模型相关变量初始化
        self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')
        self.path = ''
        self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']
        # 将“导入图片”按钮与openImage函数绑定
        self.pushButton.clicked.connect(self.openImage)
        # 将“岩石分类”按钮与classify函数绑定
        self.pushButton_2.clicked.connect(self.classify)
    def retranslateUi(self, MainWindow):
        _translate = QtCore.QCoreApplication.translate
        MainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))
        self.label.setText(_translate("MainWindow", "岩石图像分类"))
        self.pushButton.setText(_translate("MainWindow", "导入图像"))
        self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))
        self.label_4.setText(_translate("MainWindow", "分类结果"))
        self.textBrowser_3.setHtml(_translate("MainWindow",
                                              "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
                                              "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
                                              "p, li { white-space: pre-wrap; }\n"
                                              "</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
                                              "<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))
        self.textBrowser_4.setHtml(_translate("MainWindow",
                                              "<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n"
                                              "<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n"
                                              "p, li { white-space: pre-wrap; }\n"
                                              "</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n"
                                              "<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))
        self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))
    # 导入图片函数

    def resource_path(relative):
        if hasattr(sys, "_MEIPASS"):
            absolute_path = os.path.join(sys._MEIPASS, relative)
        else:
            absolute_path = os.path.join(relative)
        return absolute_path

    # 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))
    def openImage(self):
        imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")
        jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())
        self.label_2.setPixmap(jpg)
        self.path=imgPath
        self.label_3.setText('')
    def classify(self):
        img = Image.open(self.path)  # 读取图像
        img_rgb = img.split()
        x = np.zeros((1, 224, 224, 3))
        x[0,:, :, 0] = np.array(img_rgb[0]) / 255
        x[0,:, :, 1] = np.array(img_rgb[1]) / 255
        x[0,:, :, 2] = np.array(img_rgb[2]) / 255
        y = self.model.predict(x)
        result = self.rock_types[np.argmax(y)]
        self.label_3.setText(result)
if __name__=='__main__':
    QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)
    app=QtWidgets.QApplication(sys.argv)
    MainWindow=QtWidgets.QMainWindow()
    ui_test=Ui_MainWindow()
    ui_test.setupUi(MainWindow)
    MainWindow.show()
    sys.exit(app.exec_())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/666979.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

github有趣项目:自制“我的世界” project make

videocodehttps://www.youtube.com/watch?v4O0_-1NaWnY,https://www.bilibili.com/video/BV1oj411p7qM/?https://github.com/jdah/minecraft-weekend MAKE git clone --recurse-submodules https://github.com/jdah/minecraft-weekend.git 正克隆到 minecraft-weekend... …

【笔记】关于brew install ffmpeg出现问题解决

Macos系统需要安装ffmpeg使用&#xff0c;通过brew install ffmpeg安装相关依赖时&#xff0c;当安装至flac时出现下列问题 环境&#xff1a;有代理开启 使用国内数据源 brew install ffmpeg --verbose --debug 安装过程中显示日志 curl: (35) error:1400442E:SSL routines:C…

前端3剑客(第1篇)-初识HTML

100编程书屋_孔夫子旧书网 当今主流的技术中&#xff0c;可以分为前端和后端两个门类。 前端&#xff1a;简单的理解就是和用户打交道 后端&#xff1a;主要用于组织数据 而前端就Web开发方向来说&#xff0c; 分为三门语言&#xff0c; HTML、CSS、JavaScript 语言作用HT…

Apache Pulsar 中文社区有奖问卷调查(2024 上半年度)

Apache Pulsar 中文社区有奖问卷调查&#xff08;2024 上半年度&#xff09; &#x1f4e3; &#x1f4e3; &#x1f4e3; Hi&#xff0c;Apache Pulsar 社区的小伙伴们&#xff0c;社区 2024 上半年度的有奖问卷调查来啦&#xff01; &#x1f64c; 本次调查旨在了解用户使用 …

EIS 2019 webshell

请求中可以确定是http POST流量 同时可以判断是 蚁剑的流量 进一步过滤 http.request.method "POST" 直接追踪其tcp流 得到 列举部分 eVAl(cHr(0x40).ChR(0x69).ChR(0x6e).ChR(0x69).ChR(0x5f).ChR(0x73).ChR(0x65).ChR(0x74).ChR(0x28)直接输出一下 内容 <…

数据治理基础知识

文章目录 基本概念相关名词术语数据治理对象 基本概念 1&#xff09;从管理者视角看数据治理 数据治理是企业发展战略的组成部分&#xff0c;是指导整个集团进行数字化变革的基石&#xff0c;要将数据治理纳入企业的顶 层规划&#xff0c;各分/子公司、各业务部门都需要按照企…

智慧园区整理技术方案(ppt,软件全套建设方案)

智慧园区管控平台整体技术方案 1.平台概述 2.公共安全 3.物业管理 4.综合管理 5.企业服务 平台规划&#xff0c;整理技术架构搭建&#xff0c;统一门户&#xff0c;lot物联平台&#xff0c;视频云管理平台&#xff0c;GIS服务平台&#xff0c;服务器架构&#xff0c;统一身份认…

发现一个ai工具网站

网址 https://17yongai.com/ 大概看了下&#xff0c;这个网站收集的数据还挺有用的&#xff0c;有很多实用的ai教程。 懂ai工具的可以在这上面找找灵感。

HTML如何让文字底部线条不紧贴在文字下面(既在内容下方又超出内容区域)

hello&#xff0c;大家好&#xff0c;星途小鹏今天给大家带来的内容是如何让文字底部线条不紧贴在文字下面。 话不多说&#xff0c;先上效果图 简单来说就是padding和margin的区别。 在网页设计中&#xff0c;有时我们想要给某个元素添加一个装饰性的线条&#xff0c;比如底部…

【设计模式】创建型-建造者模式

前言 在面向对象的软件开发中&#xff0c;构建复杂对象时经常会遇到许多挑战。一种常见的解决方案是使用设计模式&#xff0c;其中建造者模式是一个强大而灵活的选择。本文将深入探讨建造者模式的原理、结构、优点以及如何在实际项目中应用它。 一、复杂的对象 public class…

安卓如何书写注册和登录界面

一、如何跳转一个活动 左边的是本活动名称&#xff0c; 右边的是跳转界面活动名称 Intent intent new Intent(LoginActivity.this, RegisterActivity.class); startActivity(intent); finish(); 二、如果在不同的界面传递参数 //发送消息 SharedPreferences sharedPreferen…

【再探】设计模式—中介者模式、观察者模式及模板方法模式

中介者模式让多对多的复杂引用关系变成一对多&#xff0c;同时能通过中间类来封装多个类中的行为&#xff0c;观察者模式在目标状态更新时能自动通知给订阅者&#xff0c;模版方法模式则是控制方法的执行顺序&#xff0c;子类在不改变算法的结构基础上可以扩展功能实现。 1 中…

Python 之SQLAlchemy使用详细说明

目录 1、SQLAlchemy 1.1、ORM概述 1.2、SQLAlchemy概述 1.3、SQLAlchemy的组成部分 1.4、SQLAlchemy的使用 1.4.1、安装 1.4.2、创建数据库连接 1.4.3、执行原生SQL语句 1.4.4、映射已存在的表 1.4.5、创建表 1.4.5.1、创建表的两种方式 1、使用 Table 类直接创建表…

【稳定检索/投稿优惠】2024年商务、信息管理与大数据经济国际会议(BIMBDE 2024)

2024 International Conference on Business, Information Management, and Big Data Economy 2024年商务、信息管理与大数据经济国际会议 【会议信息】 会议简称&#xff1a;BIMBDE 2024 大会地点&#xff1a;中国北京 会议官网&#xff1a;www.bimbde.com 会议邮箱&#xff…

MySql part1 安装和介绍

MySql part1 安装和介绍 数据 介绍 什么是数据库&#xff0c;数据很好理解&#xff0c;一般来说数据通常是我们所认识的 描述事物的符号记录&#xff0c; 可以是数字、 文字、图形、图像、声音、语言等&#xff0c;数据有多种形式&#xff0c;它们都以经过数字化后存入计算机…

CS4344国产替代音频DAC数模转换芯片DP7344采样率192kHz

目录 DAC应用简介DP7344简介结构框图DP7344主要特性微信号&#xff1a;dnsj5343参考原理图 应用领域 DAC应用简介 DAC&#xff08;中文&#xff1a;数字模拟转换器&#xff09;是一种将数字信号转换为模拟信号&#xff08;以电流、电压或电荷的形式&#xff09;的设备。电脑对…

Golang | Leetcode Golang题解之第123题买卖股票的最佳时机III

题目&#xff1a; 题解&#xff1a; func maxProfit(prices []int) int {buy1, sell1 : -prices[0], 0buy2, sell2 : -prices[0], 0for i : 1; i < len(prices); i {buy1 max(buy1, -prices[i])sell1 max(sell1, buy1prices[i])buy2 max(buy2, sell1-prices[i])sell2 m…

Docker 环境下 3D Guassian Splatting 的编译和配置

Title: Docker 环境下 3D Guassian Splatting 的编译和配置 文章目录 前言I. 宿主系统上的安装配置1. 安装 nvidia driver2. 安装 docker3. 安装 nvidia-container-toolkit II. Docker 容器安装配置1. 拉取 ubuntu 22.042. 创建容器3. 进入容器4. 容器中安装 cuda SDK5. 容器中…

python class __new__、__init__、__call__ 区别

在Python中&#xff0c;__new__、__init__ 和 __call__ 是三个不同的特殊方法&#xff0c;它们在类的创建和调用过程中扮演着不同的角色。以下是它们的区别和用法&#xff1a; 1. __new__ 方法 作用&#xff1a;__new__ 是一个静态方法&#xff0c;负责创建并返回一个新的实例…

携手亚马逊云科技,神州泰岳如何打通生成式AI落地最后三公里

导读&#xff1a;神州泰岳成为首批获得亚马逊云科技生成式AI能力认证的合作伙伴。 “过去6年来&#xff0c;在与亚马逊云科技的合作过程中&#xff0c;我们大概签约了300家以上的中国出海企业。”近日在一次沟通会上&#xff0c;神州泰岳副总裁兼云事业部总经理刘家歆这样向媒…