C/C++开发,opencv-ml库学习,支持向量机(SVM)应用

目录

一、OpenCV支持向量机(SVM)模块

1.1 openCV的机器学习库

1.2 SVM(支持向量机)模块

1.3 支持向量机(SVM)应用步骤

二、支持向量机(SVM)应用示例

 2.1  训练及验证数据获取

2.2 训练及验证数据加载

2.3 SVM(支持向量机)训练及验证,输出svm模型

2.4 SVM(支持向量机)实时识别应用

三、完整代码编译

3.1 OpenCV+MinGW的MakeFile编译

3.2 OpenCV+vc2015+cmake编译

3.3 执行效果

3.4 附件,main.cpp全文


一、OpenCV支持向量机(SVM)模块

1.1 openCV的机器学习库

        OpenCV-ml库是OpenCV(开放源代码计算机视觉库)中的机器学习模块,常用于分类和回归问题,它是 OpenCV 众多modules下的一个模块。

        该模块提供了一系列常见的统计模型和分类算法,用于进行各种机器学习任务。以下是关于OpenCV-ml库的一些主要功能和特点:

  1. 丰富的算法支持:OpenCV-ml库包含了多种机器学习算法,如支持向量机(SVM)、决策树、Boosting方法、K近邻(KNN)、随机森林等。这些算法可以用于分类、回归、聚类等多种任务。
  2. 易于使用:OpenCV-ml库提供了简洁的API接口,使得开发者能够方便地调用各种机器学习算法。同时,它也支持多种数据格式,方便用户导入和处理数据。
  3. 高效性:OpenCV-ml库经过优化,能够高效地处理大规模数据集,并且具有较快的运算速度。这使得它能够满足实时处理和分析的需求。
  4. 与OpenCV其他模块的集成:OpenCV-ml库与OpenCV的其他模块(如imgproc、features2d等)紧密集成,可以方便地进行图像处理和特征提取,然后将提取的特征用于机器学习任务。
1.2 SVM(支持向量机)模块

        OpenCV 的 SVM(支持向量机)模块是 OpenCV 机器学习库中的一个重要组成部分,它实现了支持向量机算法,用于解决分类和回归问题。支持向量机是一种监督学习模型,广泛应用于各种领域,特别是在图像分类和识别任务中。

        OpenCV 的 SVM 模块提供了灵活的参数设置和多种核函数选择,以适应不同的数据集和问题。以下是一些关于 OpenCV SVM 模块的主要特点:

  1. 多种核函数:支持线性核、多项式核、径向基函数(RBF)核和 Sigmoid 核等,可以根据问题的特性选择合适的核函数。

  2. 参数调整:可以通过调整 SVM 的参数,如 C 值(错误项的惩罚系数)和 gamma 值(对于 RBF、Poly 和 Sigmoid 核函数),来优化模型的性能。

  3. 多类分类支持:通过“一对一”或“一对多”的方式,可以处理多类分类问题。

  4. 概率估计:SVM 可以输出类别的概率估计,这对于某些应用(如置信度评估)非常有用。

  5. 易于使用:OpenCV 提供了简洁的 API,使得 SVM 的训练和测试过程相对简单。

1.3 支持向量机(SVM)应用步骤

        在OpenCV中,使用支持向量机(SVM)进行预测涉及几个步骤。首先,获得训练数据,用于训练一个SVM模型,然后使用该模型对新的、未见过的数据进行预测。

    使用svm模型,包含必要的头文件:

#include <opencv2/opencv.hpp>  
#include <opencv2/ml/ml.hpp>  

   1) 准备训练和测试数据:

    你需要为SVM准备训练和测试数据。这些数据通常是特征向量,存储在cv::Mat对象中。每个特征向量对应一个标签(分类的类别)。
    2)创建和训练SVM模型:
    使用OpenCV的cv::ml::SVM类来创建SVM模型。然后,使用train方法来训练模型。
   3) 进行预测:
    使用训练好的模型对新数据进行预测。这通常涉及将新数据作为输入传递给模型的predict方法。

二、支持向量机(SVM)应用示例

 2.1  训练及验证数据获取

        以下展示如何使用OpenCV的机器学习模块来实现一个基于SVM的手写数字识别器。首先前往网站:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges,下载MNIST database,用于实现一个SVM的手写数字识别模型训练及验证。

        下载完成后,进行解压操作:

        解压后是idx1-ubyteidx3-ubyte 是两种常见的标签编码格式,主要用于图像分割任务中。它们都是用来表示图像中每个像素所属类别的标签图像(也称为掩码或mask)。

  1. idx1-ubyte:

    • idx: 表示这是一个索引图像。
    • 1: 表示每个像素用一个字节(8位)来表示,且这些值从0开始,通常是连续的整数。
    • ubyte: 表示无符号字节类型,其值的范围是0到255。在idx1-ubyte格式中,通常会将0用作背景或未标记的类别,而其他值则用于表示不同的分割区域或类别。
  2. idx3-ubyte:

    • idx: 同样表示这是一个索引图像。
    • 3: 这里并不是指每个像素用3个字节来表示,而是指每个像素用一个字节来表示,但值的范围是从0到255,通常用来表示256个不同的类别(包括0作为背景或未标记的类别)。注意,虽然名为idx3,但实际上它并不是用3个字节来存储每个像素的值。
    • ubyte: 同样表示无符号字节类型。

        在图像分割任务中,这些标签图像通常与原始RGB图像一起使用。RGB图像用于显示给人类观察者或作为模型的输入,而标签图像则用于训练模型或评估模型的性能。

2.2 训练及验证数据加载

        idx3-ubyte 文件通常与 MNIST 数据集相关联,这是一个大型的手写数字数据库,经常用于机器学习和深度学习中的图像识别任务。MNIST 数据集包含两个文件:train-images-idx3-ubytetrain-labels-idx1-ubyte(用于训练),以及 t10k-images-idx3-ubytet10k-labels-idx1-ubyte(用于测试)。这些文件使用特定的二进制格式存储图像和标签。

        通过两个函数来读取手写图像数据集和手写图像数据对应的标签(每个标签都是一个 0 到 9 之间的整数,表示对应图像中的手写数字)。

//大小端转换
int intReverse(int num)
{
	return (num>>24|((num&0xFF0000)>>8)|((num&0xFF00)<<8)|((num&0xFF)<<24));
}

//读取手写图像数据集
cv::Mat read_mnist_image(const std::string fileName) {
	int magic_number = 0;
	int number_of_images = 0;
	int img_rows = 0;
	int img_cols = 0;

	cv::Mat DataMat;

	std::ifstream file(fileName, std::ios::binary);
	if (file.is_open())
	{
		std::cout << "open images file: "<< fileName << std::endl;

		file.read((char*)&magic_number, sizeof(magic_number));//format
		file.read((char*)&number_of_images, sizeof(number_of_images));//images number
		file.read((char*)&img_rows, sizeof(img_rows));//img rows
		file.read((char*)&img_cols, sizeof(img_cols));//img cols

		magic_number = intReverse(magic_number);
		number_of_images = intReverse(number_of_images);
		img_rows = intReverse(img_rows);
		img_cols = intReverse(img_cols);
		std::cout << "format:" << magic_number
			<< " img num:" << number_of_images
			<< " img row:" << img_rows
			<< " img col:" << img_cols << std::endl;

		std::cout << "read img data" << std::endl;

		DataMat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);
		unsigned char temp = 0;
		for (int i = 0; i < number_of_images; i++) {
			for (int j = 0; j < img_rows * img_cols; j++) {
				file.read((char*)&temp, sizeof(temp));
				//svm data is CV_32FC1
				float pixel_value = float(temp);
				DataMat.at<float>(i, j) = pixel_value;
			}
		}
		std::cout << "read img data finish!" << std::endl;
	}
	file.close();
	return DataMat;
}
//读取手写标签
cv::Mat read_mnist_label(const std::string fileName) {
	int magic_number;
	int number_of_items;

	cv::Mat LabelMat;

	std::ifstream file(fileName, std::ios::binary);
	if (file.is_open())
	{
		std::cout << "open label file: "<< fileName << std::endl;

		file.read((char*)&magic_number, sizeof(magic_number));
		file.read((char*)&number_of_items, sizeof(number_of_items));
		magic_number = intReverse(magic_number);
		number_of_items = intReverse(number_of_items);

		std::cout << "format:" << magic_number << "  ;label_num:" << number_of_items << std::endl;

		std::cout << "read Label data" << std::endl;
		//data type:CV_32SC1,channel:1
		LabelMat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);
		for (int i = 0; i < number_of_items; i++) {
			unsigned char temp = 0;
			file.read((char*)&temp, sizeof(temp));
			LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
		}
		std::cout << "read label data finish!" << std::endl;

	}
	file.close();
	return LabelMat;
}
2.3 SVM(支持向量机)训练及验证,输出svm模型

        1)加载训练图像数据和标签数据,采用cv::Mat存储,图像数据虚归一化;

        2)创建svm模型,设置svm模型的各关联参数,不同参数设置,对应模型精度有较大影响;

        3)加载测试图像数据和标签数据,采用cv::Mat存储,图像数据虚归一化;

        4)采用测试图像数据验证已经训练好的svm模型,获得测试推演结果;

        5)通过测试结果和已有的标签数据进行校对,验证该模型精度。

        6)将训练好的模型保持输出。便于后续用于实时识别应用。

//change path for real paths
std::string trainImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
std::string trainLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
std::string testImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
std::string testLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";

void train_SVM()
{
	//read train images, data type CV_32FC1
	cv::Mat trainingData = read_mnist_image(trainImgFile);
	//images data normalization
	trainingData = trainingData/255.0;
	std::cout << "trainingData.size() = " << trainingData.size() << std::endl; 
	//read train label, data type CV_32SC1
	cv::Mat labelsMat = read_mnist_label(trainLabeFile);
	std::cout << "labelsMat.size() = " << labelsMat.size() << std::endl; 
	std::cout << "trainingData & labelsMat finish!" << std::endl;  

    //create SVM model
    cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();  
	//set svm args,type and KernelTypes
    svm->setType(cv::ml::SVM::C_SVC);  
	svm->setKernel(cv::ml::SVM::POLY);  
	//KernelTypes POLY is need set gamma and degree
	svm->setGamma(3.0);
	svm->setDegree(2.0);
	//Set iteration termination conditions, maxCount is importance
	svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8)); 
	std::cout << "create SVM object finish!" << std::endl;  

	std::cout << "trainingData.rows = " << trainingData.rows << std::endl; 
	std::cout << "trainingData.cols = " << trainingData.cols << std::endl; 
	std::cout << "trainingData.type() = " << trainingData.type() << std::endl; 
    // svm model train 
    svm->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);  
	std::cout << "SVM training finish!" << std::endl; 
    // svm model test  
	cv::Mat testData = read_mnist_image(testImgFile);
	//images data normalization
	testData = testData/255.0;
	std::cout << "testData.rows = " << testData.rows << std::endl; 
	std::cout << "testData.cols = " << testData.cols << std::endl; 
	std::cout << "testData.type() = " << testData.type() << std::endl; 
	//read test label, data type CV_32SC1
	cv::Mat testlabel = read_mnist_label(testLabeFile);
	cv::Mat testResp;
	float response = svm->predict(testData,testResp); 
	// std::cout << "response = " << response << std::endl; 
	testResp.convertTo(testResp,CV_32SC1);
	int map_num = 0;
	for (int i = 0; i <testResp.rows&&testResp.rows==testlabel.rows; i++)
	{
		if (testResp.at<int>(i, 0) == testlabel.at<int>(i, 0))
		{
			map_num++;
		}
		// else{
		// 	std::cout << "testResp.at<int>(i, 0) " << testResp.at<int>(i, 0) << std::endl;
		// 	std::cout << "testlabel.at<int>(i, 0) " << testlabel.at<int>(i, 0) << std::endl;
		// }
	}
	float proportion  = float(map_num) / float(testResp.rows);
	std::cout << "map rate: " << proportion * 100 << "%" << std::endl;
	std::cout << "SVM testing finish!" << std::endl; 
	//save svm model
	svm->save("mnist_svm.xml");
}
2.4 SVM(支持向量机)实时识别应用

        将t10k-images.idx3-ubyte处理成图片数据,用于svm模型调用示例,本文主要是通过一段python代码,将t10k-images.idx3-ubyte另存为一张张手写图片。

import numpy as np 
import os 
from PIL import Image  
from struct import unpack  
  
def read_idx3_ubyte(filename):  
    with open(filename, 'rb') as f:  
        magic, num_images, rows, cols = unpack('>IIII', f.read(16))  
        buf = f.read()  
        data = np.frombuffer(buf, dtype=np.uint8).reshape((num_images, rows, cols))  
    return data  
  
def save_images_as_png(idx3_file, output_dir, prefix='image'):  
    images = read_idx3_ubyte(idx3_file)  
    for i, image in enumerate(images):  
        image_pil = Image.fromarray(image, 'L')  # 'L' 表示灰度模式  
        filename = f"{output_dir}/{prefix}_{i}.png"  
        image_pil.save(filename)  
  
# 使用示例  
# idx3_file = 'train-images.idx3-ubyte'  
# output_dir = 'train-images' 
# if not os.path.exists(output_dir):#检查目录是否存在
# 	os.makedirs(output_dir)#如果不存在则创建目录
# save_images_as_png(idx3_file, output_dir)

idx3_file = 't10k-images.idx3-ubyte'  
output_dir = 't10k-images' 
if not os.path.exists(output_dir):#检查目录是否存在
	os.makedirs(output_dir)#如果不存在则创建目录
save_images_as_png(idx3_file, output_dir)

        在获得图片数据后,将加载这些图片,和上述已保存的svm模型(mnist_svm.xml),实现模型调用验证。

void prediction(const std::string fileName,cv::Ptr<cv::ml::SVM> svm)
{
	//read img 28*28 size
	cv::Mat image = cv::imread(fileName, cv::IMREAD_GRAYSCALE);
	//uchar->float32
	image.convertTo(image, CV_32F);
	//image data normalization
	image = image / 255.0;
	//28*28 -> 1*784
	image = image.reshape(1, 1);

	//预测图片
	float ret = svm->predict(image);
	std::cout << "predict val = "<< ret << std::endl;
}

std::string imgDir = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images\\";
std::string ImgFiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
void predictimgs()
{
	//load svm model
	cv::Ptr<cv::ml::SVM> svm = cv::ml::StatModel::load<cv::ml::SVM>("mnist_svm.xml");
	for (size_t i = 0; i < 5; i++)
	{
		prediction(imgDir+ImgFiles[i],svm);
	}
}

三、完整代码编译

3.1 OpenCV+MinGW的MakeFile编译

        本文是采用win系统下,opencv采用MinGW编译的静态库(C/C++开发,win下OpenCV+MinGW编译环境搭建_opencv mingw-CSDN博客),建立makefile:

#/bin/sh
#win32
CX= g++ -DWIN32 
#linux
#CX= g++ -Dlinux 

BIN 		:= ./
TARGET      := opencv_ml01.exe
FLAGS		:= -std=c++11 -static
SRCDIR 		:= ./
#INCLUDES
INCLUDEDIR 	:= -I"../../opencv_MinGW/include" -I"./"
#-I"$(SRCDIR)"
staticDir   := ../../opencv_MinGW/x64/mingw/staticlib/
#LIBDIR		:= $(staticDir)/libopencv_world460.a\
#			   $(staticDir)/libade.a \
#			   $(staticDir)/libIlmImf.a \
#			   $(staticDir)/libquirc.a \
#			   $(staticDir)/libzlib.a \
#			   $(wildcard $(staticDir)/liblib*.a) \
#			   -lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid 
#opencv_world放弃前,然后是opencv依赖的第三方库,后面的库是MinGW编译工具的库

LIBDIR 	    := -L $(staticDir) -lopencv_world460 -lade -lIlmImf -lquirc -lzlib \
				-llibjpeg-turbo -llibopenjp2 -llibpng -llibprotobuf -llibtiff -llibwebp \
				-lgdi32 -lComDlg32 -lOleAut32 -lOle32 -luuid 
source		:= $(wildcard $(SRCDIR)/*.cpp) 

$(TARGET) :
	$(CX) $(FLAGS) $(INCLUDEDIR) $(source)  -o $(BIN)/$(TARGET) $(LIBDIR)

clean:
	rm  $(BIN)/$(TARGET)

        编译如下:

3.2 OpenCV+vc2015+cmake编译

        第二种编译,本文采用了vs2015 x64编译了opencv库(C/C++开发,opencv在win下安装及应用_windows安装opencv c++-CSDN博客)。

        建立cmake文件:

# CMake 最低版本号要求
cmake_minimum_required (VERSION 2.8)
# 项目信息
project (opencv_test)
#
message(STATUS "windows compiling...")
add_definitions(-D_PLATFORM_IS_WINDOWS_)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd")
set(WIN_OS true)

#
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)

# 指定源文件的目录,并将名称保存到变量
SET(source_h
    #
  )
  
SET(source_cpp
    #
	${PROJECT_SOURCE_DIR}/main.cpp
  )
  
#头文件目录
include_directories(${PROJECT_SOURCE_DIR}/../../opencv_VC/include)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819")

add_definitions(
  "-D_CRT_SECURE_NO_WARNINGS"
  "-D_WINSOCK_DEPRECATED_NO_WARNINGS"
  "-DNO_WARN_MBCS_MFC_DEPRECATION"
  "-DWIN32_LEAN_AND_MEAN"
)

link_directories(
	${PROJECT_SOURCE_DIR}/../../opencv_VC/x64/vc14/bin
	${PROJECT_SOURCE_DIR}/../../opencv_VC/x64/vc14/lib
	)

if (CMAKE_BUILD_TYPE STREQUAL "Debug")

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR})
# 指定生成目标
add_executable(opencv_testd ${source_h} ${source_cpp})

else(CMAKE_BUILD_TYPE)

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR})
# 指定生成目标
add_executable(opencv_test ${source_h} ${source_cpp})

target_link_libraries(opencv_test opencv_world460.lib opencv_img_hash460.lib)

endif (CMAKE_BUILD_TYPE)

# mkdir build_win
# cd build_win
# cmake -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=Release ..
# msbuild opencv_test.sln /p:Configuration="Release" /p:Platform="x64"

启动vs2015 x64的命令工具(使前面配置的环境变量生效),进入main.cpp文件目录,编译如下:

mkdir build_win
cd build_win
cmake -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=Release ..
msbuild opencv_test.sln /p:Configuration="Release" /p:Platform="x64"

        编译输出大致如下:

3.3 执行效果

        【1】OpenCV+MinGW+makefile编译程序执行输出,准确率达到98%以上(PS,大家可尝试去调设SVM模型的参数设置,看怎样设置可以获得更高的准确率)

        通过模型调用识别图片全OK(呵呵,毕竟是测试集内的图片数据)

【2】opencv+vc2015+cmake编译程序执行输出,同样能到达效果。

3.4 附件,main.cpp全文
#include <opencv2/opencv.hpp>  
#include <opencv2/ml/ml.hpp>  
#include <opencv2/imgcodecs.hpp>
#include <iostream>  
#include <vector>  
#include <iostream>
#include <fstream>

int intReverse(int num)
{
	return (num>>24|((num&0xFF0000)>>8)|((num&0xFF00)<<8)|((num&0xFF)<<24));
}

std::string intToString(int num)
{
	char buf[32]={0};
	itoa(num,buf,10);
	return std::string(buf);
}


cv::Mat read_mnist_image(const std::string fileName) {
	int magic_number = 0;
	int number_of_images = 0;
	int img_rows = 0;
	int img_cols = 0;

	cv::Mat DataMat;

	std::ifstream file(fileName, std::ios::binary);
	if (file.is_open())
	{
		std::cout << "open images file: "<< fileName << std::endl;

		file.read((char*)&magic_number, sizeof(magic_number));//format
		file.read((char*)&number_of_images, sizeof(number_of_images));//images number
		file.read((char*)&img_rows, sizeof(img_rows));//img rows
		file.read((char*)&img_cols, sizeof(img_cols));//img cols

		magic_number = intReverse(magic_number);
		number_of_images = intReverse(number_of_images);
		img_rows = intReverse(img_rows);
		img_cols = intReverse(img_cols);
		std::cout << "format:" << magic_number
			<< " img num:" << number_of_images
			<< " img row:" << img_rows
			<< " img col:" << img_cols << std::endl;

		std::cout << "read img data" << std::endl;

		DataMat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);
		unsigned char temp = 0;
		for (int i = 0; i < number_of_images; i++) {
			for (int j = 0; j < img_rows * img_cols; j++) {
				file.read((char*)&temp, sizeof(temp));
				//svm data is CV_32FC1
				float pixel_value = float(temp);
				DataMat.at<float>(i, j) = pixel_value;
			}
		}
		std::cout << "read img data finish!" << std::endl;
	}
	file.close();
	return DataMat;
}

cv::Mat read_mnist_label(const std::string fileName) {
	int magic_number;
	int number_of_items;

	cv::Mat LabelMat;

	std::ifstream file(fileName, std::ios::binary);
	if (file.is_open())
	{
		std::cout << "open label file: "<< fileName << std::endl;

		file.read((char*)&magic_number, sizeof(magic_number));
		file.read((char*)&number_of_items, sizeof(number_of_items));
		magic_number = intReverse(magic_number);
		number_of_items = intReverse(number_of_items);

		std::cout << "format:" << magic_number << "  ;label_num:" << number_of_items << std::endl;

		std::cout << "read Label data" << std::endl;
		//data type:CV_32SC1,channel:1
		LabelMat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);
		for (int i = 0; i < number_of_items; i++) {
			unsigned char temp = 0;
			file.read((char*)&temp, sizeof(temp));
			LabelMat.at<unsigned int>(i, 0) = (unsigned int)temp;
		}
		std::cout << "read label data finish!" << std::endl;

	}
	file.close();
	return LabelMat;
}

//change path for real paths
std::string trainImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
std::string trainLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
std::string testImgFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
std::string testLabeFile = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";

void train_SVM()
{
	//read train images, data type CV_32FC1
	cv::Mat trainingData = read_mnist_image(trainImgFile);
	//images data normalization
	trainingData = trainingData/255.0;
	std::cout << "trainingData.size() = " << trainingData.size() << std::endl; 
	//read train label, data type CV_32SC1
	cv::Mat labelsMat = read_mnist_label(trainLabeFile);
	std::cout << "labelsMat.size() = " << labelsMat.size() << std::endl; 
	std::cout << "trainingData & labelsMat finish!" << std::endl;  

    //create SVM model
    cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();  
	//set svm args,type and KernelTypes
    svm->setType(cv::ml::SVM::C_SVC);  
	svm->setKernel(cv::ml::SVM::POLY);  
	//KernelTypes POLY is need set gamma and degree
	svm->setGamma(3.0);
	svm->setDegree(2.0);
	//Set iteration termination conditions, maxCount is importance
	svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8)); 
	std::cout << "create SVM object finish!" << std::endl;  

	std::cout << "trainingData.rows = " << trainingData.rows << std::endl; 
	std::cout << "trainingData.cols = " << trainingData.cols << std::endl; 
	std::cout << "trainingData.type() = " << trainingData.type() << std::endl; 
    // svm model train 
    svm->train(trainingData, cv::ml::ROW_SAMPLE, labelsMat);  
	std::cout << "SVM training finish!" << std::endl; 
    // svm model test  
	cv::Mat testData = read_mnist_image(testImgFile);
	//images data normalization
	testData = testData/255.0;
	std::cout << "testData.rows = " << testData.rows << std::endl; 
	std::cout << "testData.cols = " << testData.cols << std::endl; 
	std::cout << "testData.type() = " << testData.type() << std::endl; 
	//read test label, data type CV_32SC1
	cv::Mat testlabel = read_mnist_label(testLabeFile);
	cv::Mat testResp;
	float response = svm->predict(testData,testResp); 
	// std::cout << "response = " << response << std::endl; 
	testResp.convertTo(testResp,CV_32SC1);
	int map_num = 0;
	for (int i = 0; i <testResp.rows&&testResp.rows==testlabel.rows; i++)
	{
		if (testResp.at<int>(i, 0) == testlabel.at<int>(i, 0))
		{
			map_num++;
		}
		// else{
		// 	std::cout << "testResp.at<int>(i, 0) " << testResp.at<int>(i, 0) << std::endl;
		// 	std::cout << "testlabel.at<int>(i, 0) " << testlabel.at<int>(i, 0) << std::endl;
		// }
	}
	float proportion  = float(map_num) / float(testResp.rows);
	std::cout << "map rate: " << proportion * 100 << "%" << std::endl;
	std::cout << "SVM testing finish!" << std::endl; 
	//save svm model
	svm->save("mnist_svm.xml");
}


void prediction(const std::string fileName,cv::Ptr<cv::ml::SVM> svm)
{
	//read img 28*28 size
	cv::Mat image = cv::imread(fileName, cv::IMREAD_GRAYSCALE);
	//uchar->float32
	image.convertTo(image, CV_32F);
	//image data normalization
	image = image / 255.0;
	//28*28 -> 1*784
	image = image.reshape(1, 1);

	//预测图片
	float ret = svm->predict(image);
	std::cout << "predict val = "<< ret << std::endl;
}

std::string imgDir = "D:\\workForMy\\OpenCVLib\\opencv_demo\\opencv_ml01\\t10k-images\\";
std::string ImgFiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
void predictimgs()
{
	//load svm model
	cv::Ptr<cv::ml::SVM> svm = cv::ml::StatModel::load<cv::ml::SVM>("mnist_svm.xml");
	for (size_t i = 0; i < 5; i++)
	{
		prediction(imgDir+ImgFiles[i],svm);
	}
}

int main()  
{  
	train_SVM();
	predictimgs();	
    return 0;  
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/574565.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

报错:OpenGL.error.NullFunctionError: Attempt to call an undefined function”

文件我已经上传 CSDN默认就是收费的 我修改不了 免费链接在文中 请寻找 OpenGL.error.NullFunctionError: Attempt to call an undefined function” 环境陈述: windows11 AMD-R9 python版本3.9.9 背景: 通过pip安装pip install PyOpenGL安装PyOpenGL模块后 运行出现的问题…

NLP Step by Step -- How to use pipeline

正如我们在摸鱼有一手&#xff1a;NLP step by step -- 了解Transformer中看到的那样&#xff0c;Transformers模型通常非常大。对于数以百万计到数千万计数十亿的参数&#xff0c;训练和部署这些模型是一项复杂的任务。此外&#xff0c;由于几乎每天都在发布新模型&#xff0c…

数据挖掘实验一

一、实验环境及背景 使用软件&#xff1a; Anaconda3 Jupyter Notebook 实验内容&#xff1a; 1.使用Tushare或者其他手段获取任意两支股票近三个月的交易数据。做出收盘价的变动图像。2.使用Pandas_datareader获取世界银行数据库中美国&#xff08;USA&#xff09;、瑞典&…

Windows电脑中护眼(夜间)模式的开启异常

我的电脑是联想小新16pro&#xff0c;Windows11版本。之前一直可以正常使用夜间模式&#xff0c;但是经过一次电脑的版本更新之后&#xff0c;我重启电脑发现我的夜间模式不能使用了。明明显示开启状态&#xff0c;但是却不能使用&#xff0c;电脑还是无法显示夜间模式。 询问…

基于Spring Boot的考研资讯平台设计与实现

基于Spring Boot的考研资讯平台设计与实现 开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/idea 系统部分展示 系统功能界面图&#xff0c;在系统首页可以查看首页、考…

【Qt QML】TabBar的用法

Qt Quick中的TabBar提供了一个基于选项卡的导航模型。TabBar由TabButton控件填充&#xff0c;并且可以与任何提供currentIndex属性的布局或容器控件一起使用&#xff0c;例如StackLayout或SwipeView。 import QtQuick import QtQuick.Controls import QtQuick.LayoutsWindow …

FPGA实现AXI4总线的读写_如何写axi4逻辑

FPGA实现AXI4总线的读写_如何写axi4逻辑 一、AXI4 接口描述 通道信号源信号描述全局信号aclk主机全局时钟aresetn主机全局复位&#xff0c;低有效写通道地址与控制信号通道M_AXI_WR_awid[3:0]主机写地址ID&#xff0c;用来标志一组写信号M_AXI_WR_awaddr[31:0]主机写地址&…

贪吃蛇身子改进加贪吃蛇向右移动

1. 蛇移动的思想&#xff1a; 其实就是删除头节点 &#xff0c;增加尾节点&#xff1b;一句代码搞定 struct Snake *p; p head; head head -> next; free(p) 防止造成多的空间节点 2.增加尾节点代码思想&#xff1a; 2.1 .开辟new 节点的空间 struct Snake *new (stru…

每日OJ题_DFS回溯剪枝①_力扣46. 全排列(回溯算法简介)

目录 回溯算法简介 力扣46. 全排列 解析代码 回溯算法简介 回溯算法是一种经典的递归算法&#xff0c;通常⽤于解决组合问题、排列问题和搜索问题等。 回溯算法的基本思想&#xff1a;从一个初始状态开始&#xff0c;按照⼀定的规则向前搜索&#xff0c;当搜索到某个状态无…

Quarto Dashboards 教程 3:Dashboard Data Display

「写在前面」 学习一个软件最好的方法就是啃它的官方文档。本着自己学习、分享他人的态度&#xff0c;分享官方文档的中文教程。软件可能随时更新&#xff0c;建议配合官方文档一起阅读。推荐先按顺序阅读往期内容&#xff1a; 1.quarto 教程 1&#xff1a;Hello, Quarto 2.qu…

耐酸碱腐蚀PFA冷凝回流装置进口透明聚四氟材质PFA梨形漏斗特氟龙圆底烧瓶

PFA分液漏斗&#xff1a;也叫特氟龙分液漏斗、特氟龙梨型分液漏斗。 规格参考&#xff1a;125ml、250ml、500ml、1000ml 其主要特性有&#xff1a; 1.内壁对溶剂无粘贴性和吸附&#xff0c;可完全排空&#xff0c;分界面清晰可见&#xff1b; 2.密封性好&#xff0c;可防止…

excel文件导入dbeaver中文乱码

1.将excel文件进行另存为&#xff0c;保存类型选择【CSV】 2.选择【工具】–>【web选项】–> 【编码】–> 【简体中文&#xff08;GB18030&#xff09;】 3.在DBeaver进行数据导入 直接导入应该就可以&#xff0c;如果不行的话按下面处理。 选择【导入数据——选择cs…

云原生Kubernetes: K8S 1.29版本 部署Nexus

目录 一、实验 1.环境 2.搭建NFS 3. K8S 1.29版本 部署Nexus 二、问题 1.volumeMode有哪几种模式 一、实验 1.环境 &#xff08;1&#xff09;主机 表1 主机 主机架构版本IP备注masterK8S master节点1.29.0192.168.204.8 node1K8S node节点1.29.0192.168.204.9node2K…

Java毕业设计 基于SpringBoot vue养老院管理系统 微信小程序

Java毕业设计 基于SpringBoot vue养老院管理系统 微信小程序 SpringBoot 养老院管理系统 功能介绍 小程序 护工登录注册 忘记密码 护工信息维护 首页 图片轮播 床位调动申请 床位展示 床位详情 床位分配 房间展示 公告信息 公告详情 健康信息 请假申请 离职申请 后台管理 登…

09.JAVAEE之网络初识

1.网络 单机时代 >局域网时代 >广域网时代 >移动互联网时代 1.1 局域网LAN 局域网&#xff0c;即 Local Area Network&#xff0c;简称LAN。 Local 即标识了局域网是本地&#xff0c;局部组建的一种私有网络。 局域网内的主机之间能方便的进行网络通信&#xff0…

有哪些人工智能/数据分析领域可以考取的证书?

一、TensorFlow谷歌开发者认证 TensorFlow面向学生、开发者、数据科学家等人群&#xff0c;帮助他们展示自己在用 TensorFlow 构建、训练模型的过程中所学到的实用机器学习技能。 添加图片注释&#xff0c;不超过 140 字&#xff08;可选&#xff09; TensorFlow 的产品总监 …

抖音智能运营系统源码

这是一个一站式服务的抖音智能运营系统&#xff0c;旨在提升内容创作者和营销人员的工作效率。它是一个综合性的在线服务平台&#xff0c;专为抖音内容创作者和营销人员设计。系统基于高性能、可扩展性强的ThinkPHP框架&#xff0c;整合了视频处理、数据分析、文案生成与配音等…

Redis网络部分相关的结构体2 和 绑定回调函数细节

目录 1. struct connection ConnectionType属性 创建connection 2. struct client 3. 绑定客户端回调函数的流程 3.1. 读事件回调函数的设置 3.2. 写事件回调函数的设置 3.3. connSocketEventHandler函数 3.4. Redis5版本的设置回调函数 3.5. 个人的一些想法&#xf…

2024贵州康博会|特色健康食品展|医药展|医疗器械展会

2024 中国(贵州)大健康产业博览会2024 特色食品(农产品、水、饮料)暨第22届医药及医疗器械、设备展览会邀请函 时间:2024 年 9 月 26 日 -28 日(共三天) 地点:贵阳国际会议展览中心 &#xff08;观山湖区&#xff09; 主办单位: 贵州省天然饮用水行业协会 贵州省大健康产业…

diskMirror docker 使用容器部署 diskMirror 服务器!!!

Welcome to diskMirror-docker 获取项目 这个项目是 diskMirror-spring-boot 镜像版本的项目&#xff0c;您可以使用下面的命令将此项目编译为一个镜像&#xff01; # 进入到您下载的源码包目录 cd diskMirror-docker# 点击脚本来进行版本的设置以及对应版本的下载 设置 和 编…