广告位联系
返回顶部
分享到

C++ OpenCV之手写数字识别的教程

C语言 来源:互联网 作者:酷站 发布时间:2022-08-07 21:45:00 人浏览
摘要

前言 本案例通过使用machine learning机器学习模块进行手写数字识别。源码注释也写得比较清楚啦,大家请看源码注释!!! 一、准备数据集 原图如图所示:总共有0~9数字类别,每个数字

前言

本案例通过使用machine learning机器学习模块进行手写数字识别。源码注释也写得比较清楚啦,大家请看源码注释!!!

一、准备数据集

原图如图所示:总共有0~9数字类别,每个数字共20个。现在需要将下面图片切分成训练数据图片、测试数据图片。该图片尺寸为560x280,故将其切割成28x28大小数据图片。具体请看源码注释。

1

2

3

4

5

6

7

8

9

10

const int classNum = 10;  //总共有0~9个数字类别

const int picNum = 20;//每个类别共20张图片

const int pic_w = 28;//图片宽

const int pic_h = 28;//图片高

 

//将数据集分为训练集、测试集

double totalNum = classNum * picNum;//图片总数

double per = 0.8;   //百分比--修改百分比可改变训练集、测试集比重

double trainNum = totalNum * per;//训练图片数量

double testNum = totalNum * (1.0 - per);//测试图片数量

下面需要将整张图像一一切割成28x28小尺寸图片作为数据集,填充至训练集与测试集。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

Mat Train_Data, Train_Label;//用于训练

vector<MyNum>TestData;//用于测试

for (int i = 0; i < picNum; i++)

{

    for (int j = 0; j < classNum; j++)

    {

        //将所有图片数据都拷贝到Mat矩阵里

        Mat temp;

        gray(Range(j*pic_w, j*pic_w + pic_w), Range(i*pic_h, i*pic_h + pic_h)).copyTo(temp);

        Train_Data.push_back(temp.reshape(0, 1)); //将temp数字图像reshape成一行数据,然后一一追加到Train_Data矩阵中

        Train_Label.push_back(j);

 

        //而外用于测试

        if (i * classNum + j >= trainNum)

        {

            TestData.push_back({ temp,Rect(i*pic_w,j*pic_h,pic_w,pic_h),j });

        }

    }

}

接下来就是要将数据集进行格式转换。

1

2

3

4

5

//准备训练数据集

Train_Data.convertTo(Train_Data, CV_32FC1); //转化为CV_32FC1类型

Train_Label.convertTo(Train_Label, CV_32FC1);

Mat TrainDataMat = Train_Data(Range(0, trainNum), Range::all()); //只取trainNum行训练

Mat TrainLabelMat = Train_Label(Range(0, trainNum), Range::all());

二、KNN训练

这里使用OpenCV中的KNN算法进行训练。

1

2

3

4

5

6

7

//KNN训练

const int k = 3;  //k值,取奇数,影响最终识别率

Ptr<KNearest>knn = KNearest::create();  //构造KNN模型

knn->setDefaultK(k);//设定k值

knn->setIsClassifier(true);//KNN算法可用于分类、回归。

knn->setAlgorithmType(KNearest::BRUTE_FORCE);//字符匹配算法

knn->train(TrainDataMat, ROW_SAMPLE, TrainLabelMat);//模型训练

三、模型预测及结果显示

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

//预测及结果显示

double count = 0.0;

Scalar color;

for (int i = 0; i < TestData.size(); i++)

{

    //将测试图片转成CV_32FC1,单行形式

    Mat data = TestData[i].mat.reshape(0, 1);

    data.convertTo(data, CV_32FC1);

    Mat sample = data(Range(0, data.rows), Range::all());

 

    float f = knn->predict(sample); //预测

    if (f == TestData[i].label)

    {

        color = Scalar(0, 255, 0); //如果预测正确,绘制绿色,并且结果+1

        count++;

    }

    else

    {

        color = Scalar(0, 0, 255);//如果预测错误,绘制红色

    }

 

    rectangle(src, TestData[i].rect, color, 2);

}

 

//将绘制结果拷贝到一张新图上

Mat result(Size(src.cols, src.rows + 50), CV_8UC3, Scalar::all(255));

src.copyTo(result(Rect(0, 0, src.cols, src.rows)));

//将得分在结果图上显示

char text[10];

int score = (count / testNum) * 100;

sprintf_s(text, "%s%d%s", "Score:", score, "%");

putText(result, text, Point((result.cols / 2) - 80, result.rows - 15), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0), 2);

如图为不同比重训练集与测试集识别结果。

四、源码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

#include<iostream>

#include<opencv2/opencv.hpp>

#include<opencv2/ml.hpp>

using namespace std;

using namespace cv;

using namespace cv::ml;

 

 

//**自定义结构体

struct MyNum

{

    cv::Mat mat; //数字图片

    cv::Rect rect;//相对整张图所在矩形

    int label;//数字标签

};

 

int main()

{

    Mat src = imread("digit.png");

    if (src.empty())

    {

        cout << "No Image..." << endl;

        system("pause");

        return -1;

    }

 

    Mat gray;

    cvtColor(src, gray, COLOR_BGR2GRAY);

 

    const int classNum = 10;  //总共有0~9个数字类别

    const int picNum = 20;//每个类别共20张图片

    const int pic_w = 28;//图片宽

    const int pic_h = 28;//图片高

 

    //将数据集分为训练集、测试集

    double totalNum = classNum * picNum;//图片总数

    double per = 0.8;   //百分比--修改百分比可改变训练集、测试集比重

    double trainNum = totalNum * per;//训练图片数量

    double testNum = totalNum * (1.0 - per);//测试图片数量

 

    Mat Train_Data, Train_Label;//用于训练

    vector<MyNum>TestData;//用于测试

    for (int i = 0; i < picNum; i++)

    {

        for (int j = 0; j < classNum; j++)

        {

            //将所有图片数据都拷贝到Mat矩阵里

            Mat temp;

            gray(Range(j*pic_w, j*pic_w + pic_w), Range(i*pic_h, i*pic_h + pic_h)).copyTo(temp);

            Train_Data.push_back(temp.reshape(0, 1)); //将temp数字图像reshape成一行数据,然后一一追加到Train_Data矩阵中

            Train_Label.push_back(j);

 

            //额外用于测试

            if (i * classNum + j >= trainNum)

            {

                TestData.push_back({ temp,Rect(i*pic_w,j*pic_h,pic_w,pic_h),j });

            }

        }

    }

 

    //准备训练数据集

    Train_Data.convertTo(Train_Data, CV_32FC1); //转化为CV_32FC1类型

    Train_Label.convertTo(Train_Label, CV_32FC1);

    Mat TrainDataMat = Train_Data(Range(0, trainNum), Range::all()); //只取trainNum行训练

    Mat TrainLabelMat = Train_Label(Range(0, trainNum), Range::all());

 

    //KNN训练

    const int k = 3;  //k值,取奇数,影响最终识别率

    Ptr<KNearest>knn = KNearest::create();  //构造KNN模型

    knn->setDefaultK(k);//设定k值

    knn->setIsClassifier(true);//KNN算法可用于分类、回归。

    knn->setAlgorithmType(KNearest::BRUTE_FORCE);//字符匹配算法

    knn->train(TrainDataMat, ROW_SAMPLE, TrainLabelMat);//模型训练

 

    //预测及结果显示

    double count = 0.0;

    Scalar color;

    for (int i = 0; i < TestData.size(); i++)

    {

        //将测试图片转成CV_32FC1,单行形式

        Mat data = TestData[i].mat.reshape(0, 1);

        data.convertTo(data, CV_32FC1);

        Mat sample = data(Range(0, data.rows), Range::all());

 

        float f = knn->predict(sample); //预测

        if (f == TestData[i].label)

        {

            color = Scalar(0, 255, 0); //如果预测正确,绘制绿色,并且结果+1

            count++;

        }

        else

        {

            color = Scalar(0, 0, 255);//如果预测错误,绘制红色

        }

 

        rectangle(src, TestData[i].rect, color, 2);

    }

 

    //将绘制结果拷贝到一张新图上

    Mat result(Size(src.cols, src.rows + 50), CV_8UC3, Scalar::all(255));

    src.copyTo(result(Rect(0, 0, src.cols, src.rows)));

    //将得分在结果图上显示

    char text[10];

    int score = (count / testNum) * 100;

    sprintf_s(text, "%s%d%s", "Score:", score, "%");

    putText(result, text, Point((result.cols / 2) - 80, result.rows - 15), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0), 2);

    imshow("test", result);

    imwrite("result.jpg", result);

    waitKey(0);

    system("pause");

    return 0;

}

总结

本文使用OpenCV C++ 利用ml模块进行手写数字识别,源码注释也比较详细,主要操作有以下几点。

1、数据集划分为训练集与测试集

2、进行KNN训练

3、进行模型预测以及结果显示


版权声明 : 本文内容来源于互联网或用户自行发布贡献,该文观点仅代表原作者本人。本站仅提供信息存储空间服务和不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权, 违法违规的内容, 请发送邮件至2530232025#qq.cn(#换@)举报,一经查实,本站将立刻删除。
原文链接 : https://blog.csdn.net/Zero___Chen/article/details/126206827
相关文章
  • C++中类的六大默认成员函数的介绍

    C++中类的六大默认成员函数的介绍
    一、类的默认成员函数 二、构造函数Date(形参列表) 构造函数主要完成初始化对象,相当于C语言阶段写的Init函数。 默认构造函数:无参的构
  • C/C++实现遍历文件夹最全方法总结介绍

    C/C++实现遍历文件夹最全方法总结介绍
    一、filesystem(推荐) 在c++17中,引入了文件系统,使用起来非常方便 在VS中,可以直接在项目属性中调整: 只要是C++17即以上都可 然后头文件
  • C语言实现手写Map(数组+链表+红黑树)的代码

    C语言实现手写Map(数组+链表+红黑树)的代码
    要求 需要准备数组集合(List) 数据结构 需要准备单向链表(Linked) 数据结构 需要准备红黑树(Rbtree)数据结构 需要准备红黑树和链表适配策略
  • MySQL系列教程之使用C语言来连接数据库

    MySQL系列教程之使用C语言来连接数据库
    写在前面 知道了 Java中使用 JDBC编程 来连接数据库了,但是使用 C语言 来连接数据库却总是连接不上去~ 立即安排一波使用 C语言连接 MySQL数
  • 基于C语言实现简单学生成绩管理系统

    基于C语言实现简单学生成绩管理系统
    一、系统主要功能 1、密码登录 2、输入数据 3、查询成绩 4、修改成绩 5、输出所有学生成绩 6、退出系统 二、代码实现 1 2 3 4 5 6 7 8 9 10 11
  • C语言实现共享单车管理系统

    C语言实现共享单车管理系统
    1.功能模块图; 2.各个模块详细的功能描述。 1.登陆:登陆分为用户登陆,管理员登陆以及维修员登录,登陆后不同的用户所执行的操作
  • C++继承与菱形继承的介绍

    C++继承与菱形继承的介绍
    继承的概念和定义 继承机制是面向对象程序设计的一种实现代码复用的重要手段,它允许程序员在保持原有类特性的基础上进行拓展,增加
  • C/C++指针介绍与使用介绍

    C/C++指针介绍与使用介绍
    什么是指针 C/C++语言拥有在程序运行时获得变量的地址和操作地址的能力,这种用来操作地址的特殊类型变量被称作指针。 翻译翻译什么
  • C++进程的创建和进程ID标识介绍
    进程的ID 进程的ID,可称为PID。它是进程的唯一标识,类似于我们的身份证号是唯一标识,因为名字可能会和其他人相同,生日可能会与其他
  • C++分析如何用虚析构与纯虚析构处理内存泄漏

    C++分析如何用虚析构与纯虚析构处理内存泄漏
    一、问题引入 使用多态时,如果有一些子类的成员开辟在堆区,那么在父类执行完毕释放后,没有办法去释放子类的内存,这样会导致内存
  • 本站所有内容来源于互联网或用户自行发布,本站仅提供信息存储空间服务,不拥有版权,不承担法律责任。如有侵犯您的权益,请您联系站长处理!
  • Copyright © 2017-2022 F11.CN All Rights Reserved. F11站长开发者网 版权所有 | 苏ICP备2022031554号-1 | 51LA统计