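1. The OpenCV Support Vector Machine (SVM) Module
1.1 OpenCV's machine learning library
opencv-ml is the machine learning module of OpenCV (the Open Source Computer Vision Library). It is one of the many modules under OpenCV's modules directory and is commonly used for classification and regression problems.
The module provides a set of common statistical models and classification algorithms for a range of machine learning tasks. Its main features are:
- Broad algorithm support: the module includes support vector machines (SVM), decision trees, boosting, k-nearest neighbors (KNN), random forests, and others. These cover classification, regression, and clustering tasks.
- Ease of use: the module exposes a concise API, so the individual algorithms are easy to call. It also accepts several data layouts, which simplifies importing and preparing data.
- Efficiency: the module is optimized to process large datasets with fast execution, which makes it usable for real-time processing and analysis.
- Integration with other OpenCV modules: opencv-ml works closely with modules such as imgproc and features2d, so you can process images, extract features, and then feed those features to a machine learning task.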
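1.2 The SVM (support vector machine) module
OpenCV's SVM module is an important part of the OpenCV machine learning library. It implements the support vector machine algorithm for classification and regression problems. A support vector machine is a supervised learning model that is widely used in many fields, particularly in image classification and recognition.
The module offers flexible parameter settings and a choice of kernel functions to suit different datasets and problems. Its main features are:
- Multiple kernel functions: linear, polynomial, radial basis function (RBF), and sigmoid kernels are supported, so you can pick the kernel that fits the problem.
- Parameter tuning: model performance can be tuned through parameters such as C (the penalty coefficient on misclassified samples) and gamma (used by the RBF, POLY, and SIGMOID kernels).
- Multi-class classification: problems with more than two classes are handled by combining binary classifiers; OpenCV's C_SVC type accepts n-class labels directly.
- Decision values: OpenCV's SVM does not output calibrated class probabilities. For two-class problems, calling predict with the cv::ml::StatModel::RAW_OUTPUT flag returns the raw decision-function value, which can serve as a rough confidence score.
- Ease of use: OpenCV provides a concise API, so training and testing an SVM is relatively simple.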
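The four kernels listed above can be written out in a few lines. The sketch below is plain Python rather than OpenCV code. The function names and the sample vectors are illustrative assumptions; the formulas follow the kernel definitions documented for cv::ml::SVM, where gamma, coef0, and degree are the values set through setGamma, setCoef0, and setDegree.

```python
import math


def dot(x, y):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(x, y))


def linear_kernel(x, y):
    # K(x, y) = x . y
    return dot(x, y)


def poly_kernel(x, y, gamma=1.0, coef0=0.0, degree=2):
    # K(x, y) = (gamma * x . y + coef0) ^ degree
    return (gamma * dot(x, y) + coef0) ** degree


def rbf_kernel(x, y, gamma=1.0):
    # K(x, y) = exp(-gamma * ||x - y||^2)
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-gamma * sq_dist)


def sigmoid_kernel(x, y, gamma=1.0, coef0=0.0):
    # K(x, y) = tanh(gamma * x . y + coef0)
    return math.tanh(gamma * dot(x, y) + coef0)


# Hypothetical two-dimensional feature vectors.
x = [1.0, 2.0]
y = [0.5, -1.0]

print(linear_kernel(x, y))                      # -1.5
print(poly_kernel(x, y, gamma=3.0, degree=2))   # 20.25
print(rbf_kernel(x, x))                         # 1.0 (identical vectors)
print(sigmoid_kernel(x, y))
```

The polynomial call uses gamma = 3 and degree = 2, the same values the MNIST example later in this article passes to the SVM.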
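1.3 Steps for applying a support vector machine (SVM)
Predicting with an SVM in OpenCV takes a few steps: obtain training data, train an SVM model on it, then use the model to predict new, unseen data.
To use the SVM model, include the necessary headers:
#include <opencv2/opencv.hpp>
#include <opencv2/ml/ml.hpp>
1) Prepare the training and test data:
The SVM needs training and test data. These are usually feature vectors stored in cv::Mat objects, and each feature vector has a label (the class it belongs to).
2) Create and train the SVM model:
Create the model with OpenCV's cv::ml::SVM class, then train it with the train method.
3) Predict:
Use the trained model to predict new data, normally by passing the new data to the model's predict method.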
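2. SVM Application Example
2.1 Obtaining the training and validation data
This section shows how to build an SVM-based handwritten digit recognizer with OpenCV's machine learning module. First go to the site "MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges" and download the MNIST database, which is used here to train and validate the SVM model.
After the download finishes, decompress the archives: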
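After extraction you get files with the suffixes idx1-ubyte and idx3-ubyte. Both are IDX files, a simple binary format for vectors and multidimensional arrays of numbers. They are not segmentation masks: in MNIST they hold the digit labels and the digit images respectively.
- idx1-ubyte:
  - idx: the file uses the IDX format.
  - 1: the data is a one-dimensional array, here one label per image.
  - ubyte: each element is an unsigned byte (0 to 255). In MNIST the labels only take the values 0 to 9. The header consists of two 4-byte big-endian integers, the magic number 2049 and the number of items, followed by one byte per label.
- idx3-ubyte:
  - idx: again the IDX format.
  - 3: the data is a three-dimensional array (number of images × rows × columns). It does not mean three bytes per pixel; each pixel occupies one byte.
  - ubyte: each pixel is an unsigned byte, a grayscale value from 0 to 255. The header consists of four 4-byte big-endian integers (the magic number 2051, the number of images, the rows, and the columns), followed by the pixels in row-major order. MNIST images are 28 × 28.
The i-th entry of a label file is the digit shown in the i-th image of the matching image file. The images are the input to the SVM; the labels are used to train the model and to evaluate its accuracy.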
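MNIST files use the IDX layout: two zero bytes, a type code, a dimension count, then one big-endian 32-bit size per dimension. The following Python sketch decodes such a header. It is only an illustration: parse_idx_header and the two synthetic byte strings are assumptions made for this example and stand in for the real MNIST files.

```python
import struct

# Element type codes defined by the IDX format (third byte of the magic number).
IDX_TYPES = {
    0x08: "ubyte", 0x09: "byte", 0x0B: "short",
    0x0C: "int", 0x0D: "float", 0x0E: "double",
}


def parse_idx_header(buf):
    """Return (type_name, dims, data_offset) for a bytes object in IDX format."""
    zero1, zero2, type_code, ndim = struct.unpack(">BBBB", buf[:4])
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: the first two bytes must be zero")
    # One big-endian unsigned 32-bit size per dimension.
    dims = struct.unpack(">" + "I" * ndim, buf[4:4 + 4 * ndim])
    return IDX_TYPES[type_code], dims, 4 + 4 * ndim


# Synthetic stand-ins: two labels, and two 2x2 "images".
label_bytes = struct.pack(">II", 2049, 2) + bytes([7, 3])
image_bytes = struct.pack(">IIII", 2051, 2, 2, 2) + bytes(range(8))

print(parse_idx_header(label_bytes))  # ('ubyte', (2,), 8)
print(parse_idx_header(image_bytes))  # ('ubyte', (2, 2, 2), 16)
```

The magic numbers 2049 (0x00000801) and 2051 (0x00000803) therefore encode "unsigned byte, 1 dimension" and "unsigned byte, 3 dimensions", which is where the idx1-ubyte and idx3-ubyte suffixes come from.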
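2.2 Loading the training and validation data
The MNIST dataset is a large database of handwritten digits that is widely used for image recognition tasks in machine learning and deep learning. It consists of four files: train-images-idx3-ubyte and train-labels-idx1-ubyte (for training), and t10k-images-idx3-ubyte and t10k-labels-idx1-ubyte (for testing). These files store the images and labels in the binary IDX format.
Two functions read the image set and the corresponding labels (each label is an integer from 0 to 9 giving the digit written in the matching image). Because the header integers are stored big-endian, a small helper swaps their byte order after they are read on a little-endian machine.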
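// Byte-order swap: IDX headers are big-endian, x86 CPUs are little-endian
int intreverse(int num)
{
    // Work on an unsigned copy so that >> never sign-extends
    unsigned int u = static_cast<unsigned int>(num);
    return static_cast<int>((u >> 24) | ((u & 0xff0000) >> 8) | ((u & 0xff00) << 8) | ((u & 0xff) << 24));
}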
// Read the handwritten-digit images: one image per row, stored as CV_32FC1
cv::Mat read_mnist_image(const std::string& filename) {
    int magic_number = 0;
    int number_of_images = 0;
    int img_rows = 0;
    int img_cols = 0;
    cv::Mat datamat;
    std::ifstream file(filename, std::ios::binary);
    if (file.is_open())
    {
        std::cout << "open images file: " << filename << std::endl;
        file.read((char*)&magic_number, sizeof(magic_number));        // format (magic number)
        file.read((char*)&number_of_images, sizeof(number_of_images)); // number of images
        file.read((char*)&img_rows, sizeof(img_rows));                // rows per image
        file.read((char*)&img_cols, sizeof(img_cols));                // columns per image
        magic_number = intreverse(magic_number);
        number_of_images = intreverse(number_of_images);
        img_rows = intreverse(img_rows);
        img_cols = intreverse(img_cols);
        std::cout << "format:" << magic_number
                  << " img num:" << number_of_images
                  << " img row:" << img_rows
                  << " img col:" << img_cols << std::endl;
        std::cout << "read img data" << std::endl;
        datamat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);
        unsigned char temp = 0;
        for (int i = 0; i < number_of_images; i++) {
            for (int j = 0; j < img_rows * img_cols; j++) {
                file.read((char*)&temp, sizeof(temp));
                // SVM training data must be CV_32FC1
                float pixel_value = float(temp);
                datamat.at<float>(i, j) = pixel_value;
            }
        }
        std::cout << "read img data finish!" << std::endl;
    }
    else
    {
        std::cout << "cannot open images file: " << filename << std::endl;
    }
    file.close();
    return datamat;
}
// Read the handwritten-digit labels: one label per row, stored as CV_32SC1
cv::Mat read_mnist_label(const std::string& filename) {
    int magic_number = 0;
    int number_of_items = 0;
    cv::Mat labelmat;
    std::ifstream file(filename, std::ios::binary);
    if (file.is_open())
    {
        std::cout << "open label file: " << filename << std::endl;
        file.read((char*)&magic_number, sizeof(magic_number));
        file.read((char*)&number_of_items, sizeof(number_of_items));
        magic_number = intreverse(magic_number);
        number_of_items = intreverse(number_of_items);
        std::cout << "format:" << magic_number << " ;label_num:" << number_of_items << std::endl;
        std::cout << "read label data" << std::endl;
        // data type: CV_32SC1, one channel
        labelmat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);
        for (int i = 0; i < number_of_items; i++) {
            unsigned char temp = 0;
            file.read((char*)&temp, sizeof(temp));
            // CV_32SC1 elements are signed 32-bit integers, so access them as int
            labelmat.at<int>(i, 0) = (int)temp;
        }
        std::cout << "read label data finish!" << std::endl;
    }
    else
    {
        std::cout << "cannot open label file: " << filename << std::endl;
    }
    file.close();
    return labelmat;
}
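The reason the header fields need a byte swap can be checked without any MNIST file. The sketch below is a Python illustration, not part of the C++ program: int_reverse is a hypothetical counterpart of the byte-swap helper, and the four bytes are how the training image count 60000 would appear in a big-endian header.

```python
import struct


def int_reverse(num):
    """Swap the byte order of a 32-bit unsigned value."""
    num &= 0xFFFFFFFF
    return ((num >> 24) | ((num & 0xFF0000) >> 8)
            | ((num & 0xFF00) << 8) | ((num & 0xFF) << 24))


raw = bytes([0x00, 0x00, 0xEA, 0x60])      # 60000 stored big-endian
as_little = struct.unpack("<I", raw)[0]    # what a little-endian CPU reads
as_big = struct.unpack(">I", raw)[0]       # the value the file intends

print(hex(as_little))           # 0x60ea0000
print(int_reverse(as_little))   # 60000
print(as_big)                   # 60000
```

Reading the four bytes directly on a little-endian machine yields a meaningless large number; swapping the bytes recovers the intended count.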
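2.3 Training and validating the SVM, and saving the model
1) Load the training images and labels into cv::Mat objects; the image data must be normalized.
2) Create the SVM model and set its parameters; different settings have a large effect on the model's accuracy.
3) Load the test images and labels into cv::Mat objects; the image data must be normalized in the same way.
4) Run the trained SVM model on the test images to obtain the predictions.
5) Compare the predictions with the known labels to measure the accuracy of the model.
6) Save the trained model to a file so that it can be reused later for real-time recognition.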
// Change these to the real paths on your machine
std::string trainimgfile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
std::string trainlabefile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
std::string testimgfile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
std::string testlabefile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";
void train_svm()
{
    // Read the training images, data type CV_32FC1
    cv::Mat trainingdata = read_mnist_image(trainimgfile);
    // Normalize pixel values to [0, 1]
    trainingdata = trainingdata / 255.0;
    std::cout << "trainingdata.size() = " << trainingdata.size() << std::endl;
    // Read the training labels, data type CV_32SC1
    cv::Mat labelsmat = read_mnist_label(trainlabefile);
    std::cout << "labelsmat.size() = " << labelsmat.size() << std::endl;
    std::cout << "trainingdata & labelsmat finish!" << std::endl;
    // Create the SVM model
    cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();
    // Set the SVM type and kernel type
    svm->setType(cv::ml::SVM::C_SVC);
    svm->setKernel(cv::ml::SVM::POLY);
    // The POLY kernel needs gamma and degree to be set
    svm->setGamma(3.0);
    svm->setDegree(2.0);
    // Termination criteria; the iteration cap (maxCount) strongly affects the result
    svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8));
    std::cout << "create svm object finish!" << std::endl;
    std::cout << "trainingdata.rows = " << trainingdata.rows << std::endl;
    std::cout << "trainingdata.cols = " << trainingdata.cols << std::endl;
    std::cout << "trainingdata.type() = " << trainingdata.type() << std::endl;
    // Train the SVM model; each row of trainingdata is one sample
    svm->train(trainingdata, cv::ml::ROW_SAMPLE, labelsmat);
    std::cout << "svm training finish!" << std::endl;
    // Test the SVM model
    cv::Mat testdata = read_mnist_image(testimgfile);
    // Normalize the test images the same way as the training images
    testdata = testdata / 255.0;
    std::cout << "testdata.rows = " << testdata.rows << std::endl;
    std::cout << "testdata.cols = " << testdata.cols << std::endl;
    std::cout << "testdata.type() = " << testdata.type() << std::endl;
    // Read the test labels, data type CV_32SC1
    cv::Mat testlabel = read_mnist_label(testlabefile);
    cv::Mat testresp;
    // With a multi-row input, predict writes one float response per row into testresp
    svm->predict(testdata, testresp);
    testresp.convertTo(testresp, CV_32SC1);
    int map_num = 0;
    // Count the predictions that match the labels (skipped if the row counts differ)
    for (int i = 0; i < testresp.rows && testresp.rows == testlabel.rows; i++)
    {
        if (testresp.at<int>(i, 0) == testlabel.at<int>(i, 0))
        {
            map_num++;
        }
        // else{
        //     std::cout << "testresp.at<int>(i, 0) " << testresp.at<int>(i, 0) << std::endl;
        //     std::cout << "testlabel.at<int>(i, 0) " << testlabel.at<int>(i, 0) << std::endl;
        // }
    }
    float proportion = float(map_num) / float(testresp.rows);
    std::cout << "map rate: " << proportion * 100 << "%" << std::endl;
    std::cout << "svm testing finish!" << std::endl;
    // Save the trained SVM model
    svm->save("mnist_svm.xml");
}
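The two data-handling steps around the SVM calls, scaling pixels to [0, 1] and computing the match rate, are simple enough to check by hand. The following is a pure-Python sketch with made-up values. The function names normalize and accuracy are assumptions for this illustration and are not OpenCV functions.

```python
def normalize(row, max_value=255.0):
    """Scale raw 0..255 pixel values to the range [0, 1]."""
    return [p / max_value for p in row]


def accuracy(predicted, expected):
    """Fraction of predictions that equal the expected labels."""
    if len(predicted) != len(expected):
        raise ValueError("prediction and label counts differ")
    if not expected:
        return 0.0
    # The SVM returns float responses, so compare them as integers.
    matches = sum(1 for p, e in zip(predicted, expected) if int(p) == e)
    return matches / len(expected)


pixels = [0, 51, 255]
print(normalize(pixels))            # [0.0, 0.2, 1.0]

predicted = [7.0, 2.0, 1.0, 0.0]    # hypothetical SVM responses
expected = [7, 2, 1, 4]             # hypothetical ground-truth labels
print(accuracy(predicted, expected) * 100, "%")   # 75.0 %
```

Unlike the C++ loop, this sketch raises an error when the two lists have different lengths instead of silently reporting 0%.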
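2.4 Using the SVM for real-time recognition
To demonstrate calling the saved model, t10k-images.idx3-ubyte is first converted into image files. The following Python code saves each handwritten digit in t10k-images.idx3-ubyte as a separate PNG image.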
import os
from struct import unpack

import numpy as np
from PIL import Image


def read_idx3_ubyte(filename):
    with open(filename, 'rb') as f:
        # Header: magic number, image count, rows, cols (big-endian 32-bit integers)
        magic, num_images, rows, cols = unpack('>IIII', f.read(16))
        buf = f.read()
    data = np.frombuffer(buf, dtype=np.uint8).reshape((num_images, rows, cols))
    return data


def save_images_as_png(idx3_file, output_dir, prefix='image'):
    images = read_idx3_ubyte(idx3_file)
    for i, img in enumerate(images):
        # A 2-D uint8 array is interpreted as an 8-bit grayscale ('L') image
        image_pil = Image.fromarray(img)
        filename = f"{output_dir}/{prefix}_{i}.png"
        image_pil.save(filename)


# Usage example
# idx3_file = 'train-images.idx3-ubyte'
# output_dir = 'train-images'
# if not os.path.exists(output_dir):  # check whether the directory exists
#     os.makedirs(output_dir)         # create it if it does not
# save_images_as_png(idx3_file, output_dir)
idx3_file = 't10k-images.idx3-ubyte'
output_dir = 't10k-images'
if not os.path.exists(output_dir):  # check whether the directory exists
    os.makedirs(output_dir)         # create it if it does not
save_images_as_png(idx3_file, output_dir)
Once the images are available, load them together with the SVM model saved above (mnist_svm.xml) and run the model on them to verify it.
void prediction(const std::string& filename, cv::Ptr<cv::ml::SVM> svm)
{
    // Read a 28*28 image as single-channel grayscale
    cv::Mat image = cv::imread(filename, cv::IMREAD_GRAYSCALE);
    if (image.empty())
    {
        std::cout << "cannot read image: " << filename << std::endl;
        return;
    }
    // uchar -> float32
    image.convertTo(image, CV_32F);
    // Normalize exactly as was done for the training data
    image = image / 255.0;
    // 28*28 -> 1*784 (1 channel, 1 row)
    image = image.reshape(1, 1);
    // Predict the digit shown in the image
    float ret = svm->predict(image);
    std::cout << "predict val = " << ret << std::endl;
}
std::string imgdir = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\t10k-images\\";
std::string imgfiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
void predictimgs()
{
    // Load the saved SVM model
    cv::Ptr<cv::ml::SVM> svm = cv::ml::StatModel::load<cv::ml::SVM>("mnist_svm.xml");
    for (size_t i = 0; i < 5; i++)
    {
        prediction(imgdir + imgfiles[i], svm);
    }
}
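The prediction input has to match the layout used in training: one row of 784 values in [0, 1], with pixels taken row by row. The Python sketch below shows that layout on a tiny stand-in image. The function to_feature_row and the 2 × 3 sample are assumptions made for this illustration.

```python
def to_feature_row(image, max_value=255.0):
    """Flatten a 2-D list of pixels row by row and scale each value to [0, 1]."""
    cols = len(image[0])
    if any(len(row) != cols for row in image):
        raise ValueError("all rows must have the same length")
    return [pixel / max_value for row in image for pixel in row]


# A tiny 2x3 stand-in for a 28x28 digit image.
image = [[0, 51, 255],
         [255, 0, 51]]
features = to_feature_row(image)

rows, cols = len(image), len(image[0])
r, c = 1, 2
print(len(features))                                   # 6
print(features[r * cols + c] == image[r][c] / 255.0)   # True
```

Pixel (r, c) always lands at index r * cols + c, so a 28 × 28 image becomes a feature vector of length 784 in the same order the training rows were built.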
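3. Building the Complete Code
3.1 Building with OpenCV + MinGW and a Makefile
This article uses Windows, with OpenCV built as static libraries using MinGW (see the CSDN blog post "C/C++ development: setting up an OpenCV + MinGW build environment on Windows"). Create the following Makefile: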
#/bin/sh
#win32
cx= g++ -DWIN32
#linux
#cx= g++ -Dlinux
bin := ./
target := opencv_ml01.exe
flags := -std=c++11 -static
srcdir := ./
#includes
includedir := -I"../../opencv_mingw/include" -I"./"
#-I"$(srcdir)"
staticdir := ../../opencv_mingw/x64/mingw/staticlib/
#libdir := $(staticdir)/libopencv_world460.a\
#	$(staticdir)/libade.a \
#	$(staticdir)/libIlmImf.a \
#	$(staticdir)/libquirc.a \
#	$(staticdir)/libzlib.a \
#	$(wildcard $(staticdir)/liblib*.a) \
#	-lgdi32 -lcomdlg32 -loleaut32 -lole32 -luuid
# opencv_world goes first, then the third-party libraries OpenCV depends on, then the MinGW toolchain libraries
libdir := -L $(staticdir) -lopencv_world460 -lade -lIlmImf -lquirc -lzlib \
	-llibjpeg-turbo -llibopenjp2 -llibpng -llibprotobuf -llibtiff -llibwebp \
	-lgdi32 -lcomdlg32 -loleaut32 -lole32 -luuid
source := $(wildcard $(srcdir)/*.cpp)
$(target) :
	$(cx) $(flags) $(includedir) $(source) -o $(bin)/$(target) $(libdir)
clean:
	rm $(bin)/$(target)
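The build runs as follows:
3.2 Building with OpenCV + VC2015 and CMake
The second build uses an OpenCV library compiled with VS2015 x64 (see the CSDN blog post "C/C++ development: installing and using OpenCV on Windows").
Create the CMake file: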
# Minimum required CMake version
cmake_minimum_required (VERSION 2.8)
# Project information
project (opencv_test)
#
message(STATUS "windows compiling...")
add_definitions(-D_PLATFORM_IS_WINDOWS_)
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /MT")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /MTd")
set(win_os true)
#
set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)
# Source files, stored in variables
set(source_h
    #
    )
set(source_cpp
    #
    ${PROJECT_SOURCE_DIR}/main.cpp
    )
# Header directories
include_directories(${PROJECT_SOURCE_DIR}/../../opencv_vc/include)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4819")
add_definitions(
    "-D_CRT_SECURE_NO_WARNINGS"
    "-D_WINSOCK_DEPRECATED_NO_WARNINGS"
    "-DNO_WARN_MBCS_MFC_DEPRECATION"
    "-DWIN32_LEAN_AND_MEAN"
    )
link_directories(
    ${PROJECT_SOURCE_DIR}/../../opencv_vc/x64/vc14/bin
    ${PROJECT_SOURCE_DIR}/../../opencv_vc/x64/vc14/lib
    )
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${PROJECT_SOURCE_DIR})
    # Build target
    add_executable(opencv_testd ${source_h} ${source_cpp})
else()
    set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_RELEASE ${PROJECT_SOURCE_DIR})
    # Build target
    add_executable(opencv_test ${source_h} ${source_cpp})
    target_link_libraries(opencv_test opencv_world460.lib opencv_img_hash460.lib)
endif()
# mkdir build_win
# cd build_win
# cmake -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=Release ..
# msbuild opencv_test.sln /p:Configuration="Release" /p:Platform="x64"
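Open the VS2015 x64 command prompt (so that the environment variables configured earlier take effect), change to the directory containing main.cpp, and build:
mkdir build_win
cd build_win
cmake -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=Release ..
msbuild opencv_test.sln /p:Configuration="Release" /p:Platform="x64"
The build output looks roughly like this: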
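3.3 Results
[1] The program built with OpenCV + MinGW + Makefile reports an accuracy above 98%. (P.S. Try adjusting the SVM parameters to see which settings give a higher accuracy.)
All five images passed to the saved model are recognized correctly (not surprising, since they come from the test set).
[2] The program built with OpenCV + VC2015 + CMake produces the same result.
3.4 Appendix: full text of main.cpp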
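#include <opencv2/opencv.hpp>
#include <opencv2/ml/ml.hpp>
#include <opencv2/imgcodecs.hpp>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
// Byte-order swap: IDX headers are big-endian, x86 CPUs are little-endian
int intreverse(int num)
{
    // Work on an unsigned copy so that >> never sign-extends
    unsigned int u = static_cast<unsigned int>(num);
    return static_cast<int>((u >> 24) | ((u & 0xff0000) >> 8) | ((u & 0xff00) << 8) | ((u & 0xff) << 24));
}
// Convert an integer to its decimal string (helper, not used below)
std::string inttostring(int num)
{
    return std::to_string(num);
}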
// Read the handwritten-digit images: one image per row, stored as CV_32FC1
cv::Mat read_mnist_image(const std::string& filename) {
    int magic_number = 0;
    int number_of_images = 0;
    int img_rows = 0;
    int img_cols = 0;
    cv::Mat datamat;
    std::ifstream file(filename, std::ios::binary);
    if (file.is_open())
    {
        std::cout << "open images file: " << filename << std::endl;
        file.read((char*)&magic_number, sizeof(magic_number));        // format (magic number)
        file.read((char*)&number_of_images, sizeof(number_of_images)); // number of images
        file.read((char*)&img_rows, sizeof(img_rows));                // rows per image
        file.read((char*)&img_cols, sizeof(img_cols));                // columns per image
        magic_number = intreverse(magic_number);
        number_of_images = intreverse(number_of_images);
        img_rows = intreverse(img_rows);
        img_cols = intreverse(img_cols);
        std::cout << "format:" << magic_number
                  << " img num:" << number_of_images
                  << " img row:" << img_rows
                  << " img col:" << img_cols << std::endl;
        std::cout << "read img data" << std::endl;
        datamat = cv::Mat::zeros(number_of_images, img_rows * img_cols, CV_32FC1);
        unsigned char temp = 0;
        for (int i = 0; i < number_of_images; i++) {
            for (int j = 0; j < img_rows * img_cols; j++) {
                file.read((char*)&temp, sizeof(temp));
                // SVM training data must be CV_32FC1
                float pixel_value = float(temp);
                datamat.at<float>(i, j) = pixel_value;
            }
        }
        std::cout << "read img data finish!" << std::endl;
    }
    else
    {
        std::cout << "cannot open images file: " << filename << std::endl;
    }
    file.close();
    return datamat;
}
// Read the handwritten-digit labels: one label per row, stored as CV_32SC1
cv::Mat read_mnist_label(const std::string& filename) {
    int magic_number = 0;
    int number_of_items = 0;
    cv::Mat labelmat;
    std::ifstream file(filename, std::ios::binary);
    if (file.is_open())
    {
        std::cout << "open label file: " << filename << std::endl;
        file.read((char*)&magic_number, sizeof(magic_number));
        file.read((char*)&number_of_items, sizeof(number_of_items));
        magic_number = intreverse(magic_number);
        number_of_items = intreverse(number_of_items);
        std::cout << "format:" << magic_number << " ;label_num:" << number_of_items << std::endl;
        std::cout << "read label data" << std::endl;
        // data type: CV_32SC1, one channel
        labelmat = cv::Mat::zeros(number_of_items, 1, CV_32SC1);
        for (int i = 0; i < number_of_items; i++) {
            unsigned char temp = 0;
            file.read((char*)&temp, sizeof(temp));
            // CV_32SC1 elements are signed 32-bit integers, so access them as int
            labelmat.at<int>(i, 0) = (int)temp;
        }
        std::cout << "read label data finish!" << std::endl;
    }
    else
    {
        std::cout << "cannot open label file: " << filename << std::endl;
    }
    file.close();
    return labelmat;
}
// Change these to the real paths on your machine
std::string trainimgfile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\train-images.idx3-ubyte";
std::string trainlabefile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\train-labels.idx1-ubyte";
std::string testimgfile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\t10k-images.idx3-ubyte";
std::string testlabefile = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\t10k-labels.idx1-ubyte";
void train_svm()
{
    // Read the training images, data type CV_32FC1
    cv::Mat trainingdata = read_mnist_image(trainimgfile);
    // Normalize pixel values to [0, 1]
    trainingdata = trainingdata / 255.0;
    std::cout << "trainingdata.size() = " << trainingdata.size() << std::endl;
    // Read the training labels, data type CV_32SC1
    cv::Mat labelsmat = read_mnist_label(trainlabefile);
    std::cout << "labelsmat.size() = " << labelsmat.size() << std::endl;
    std::cout << "trainingdata & labelsmat finish!" << std::endl;
    // Create the SVM model
    cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();
    // Set the SVM type and kernel type
    svm->setType(cv::ml::SVM::C_SVC);
    svm->setKernel(cv::ml::SVM::POLY);
    // The POLY kernel needs gamma and degree to be set
    svm->setGamma(3.0);
    svm->setDegree(2.0);
    // Termination criteria; the iteration cap (maxCount) strongly affects the result
    svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::EPS | cv::TermCriteria::COUNT, 1000, 1e-8));
    std::cout << "create svm object finish!" << std::endl;
    std::cout << "trainingdata.rows = " << trainingdata.rows << std::endl;
    std::cout << "trainingdata.cols = " << trainingdata.cols << std::endl;
    std::cout << "trainingdata.type() = " << trainingdata.type() << std::endl;
    // Train the SVM model; each row of trainingdata is one sample
    svm->train(trainingdata, cv::ml::ROW_SAMPLE, labelsmat);
    std::cout << "svm training finish!" << std::endl;
    // Test the SVM model
    cv::Mat testdata = read_mnist_image(testimgfile);
    // Normalize the test images the same way as the training images
    testdata = testdata / 255.0;
    std::cout << "testdata.rows = " << testdata.rows << std::endl;
    std::cout << "testdata.cols = " << testdata.cols << std::endl;
    std::cout << "testdata.type() = " << testdata.type() << std::endl;
    // Read the test labels, data type CV_32SC1
    cv::Mat testlabel = read_mnist_label(testlabefile);
    cv::Mat testresp;
    // With a multi-row input, predict writes one float response per row into testresp
    svm->predict(testdata, testresp);
    testresp.convertTo(testresp, CV_32SC1);
    int map_num = 0;
    // Count the predictions that match the labels (skipped if the row counts differ)
    for (int i = 0; i < testresp.rows && testresp.rows == testlabel.rows; i++)
    {
        if (testresp.at<int>(i, 0) == testlabel.at<int>(i, 0))
        {
            map_num++;
        }
        // else{
        //     std::cout << "testresp.at<int>(i, 0) " << testresp.at<int>(i, 0) << std::endl;
        //     std::cout << "testlabel.at<int>(i, 0) " << testlabel.at<int>(i, 0) << std::endl;
        // }
    }
    float proportion = float(map_num) / float(testresp.rows);
    std::cout << "map rate: " << proportion * 100 << "%" << std::endl;
    std::cout << "svm testing finish!" << std::endl;
    // Save the trained SVM model
    svm->save("mnist_svm.xml");
}
void prediction(const std::string& filename, cv::Ptr<cv::ml::SVM> svm)
{
    // Read a 28*28 image as single-channel grayscale
    cv::Mat image = cv::imread(filename, cv::IMREAD_GRAYSCALE);
    if (image.empty())
    {
        std::cout << "cannot read image: " << filename << std::endl;
        return;
    }
    // uchar -> float32
    image.convertTo(image, CV_32F);
    // Normalize exactly as was done for the training data
    image = image / 255.0;
    // 28*28 -> 1*784 (1 channel, 1 row)
    image = image.reshape(1, 1);
    // Predict the digit shown in the image
    float ret = svm->predict(image);
    std::cout << "predict val = " << ret << std::endl;
}
std::string imgdir = "d:\\workformy\\opencvlib\\opencv_demo\\opencv_ml01\\t10k-images\\";
std::string imgfiles[5] = {"image_0.png","image_10.png","image_20.png","image_30.png","image_40.png",};
void predictimgs()
{
    // Load the saved SVM model
    cv::Ptr<cv::ml::SVM> svm = cv::ml::StatModel::load<cv::ml::SVM>("mnist_svm.xml");
    for (size_t i = 0; i < 5; i++)
    {
        prediction(imgdir + imgfiles[i], svm);
    }
}
int main()
{
    train_svm();
    predictimgs();
    return 0;
}