C语言
主页 > 软件编程 > C语言 >

C++ OpenCV之手写数字识别的教程

2022-08-07 | 酷站 | 点击:

前言

本案例通过使用machine learning机器学习模块进行手写数字识别。源码注释也写得比较清楚啦,大家请看源码注释!!!

一、准备数据集

原图如图所示:总共有0~9数字类别,每个数字共20个。现在需要将下面图片切分成训练数据图片、测试数据图片。该图片尺寸为560x280,故将其切割成28x28大小数据图片。具体请看源码注释。

1

2

3

4

5

6

7

8

9

10

const int classNum = 10;  //总共有0~9个数字类别

const int picNum = 20;//每个类别共20张图片

const int pic_w = 28;//图片宽

const int pic_h = 28;//图片高

 

//将数据集分为训练集、测试集

double totalNum = classNum * picNum;//图片总数

double per = 0.8;   //百分比--修改百分比可改变训练集、测试集比重

double trainNum = totalNum * per;//训练图片数量

double testNum = totalNum * (1.0 - per);//测试图片数量

下面需要将整张图像一一切割成28x28小尺寸图片作为数据集,填充至训练集与测试集。

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

Mat Train_Data, Train_Label;//用于训练

vector<MyNum>TestData;//用于测试

for (int i = 0; i < picNum; i++)

{

    for (int j = 0; j < classNum; j++)

    {

        //将所有图片数据都拷贝到Mat矩阵里

        Mat temp;

        gray(Range(j*pic_w, j*pic_w + pic_w), Range(i*pic_h, i*pic_h + pic_h)).copyTo(temp);

        Train_Data.push_back(temp.reshape(0, 1)); //将temp数字图像reshape成一行数据,然后一一追加到Train_Data矩阵中

        Train_Label.push_back(j);

 

        //而外用于测试

        if (i * classNum + j >= trainNum)

        {

            TestData.push_back({ temp,Rect(i*pic_w,j*pic_h,pic_w,pic_h),j });

        }

    }

}

接下来就是要将数据集进行格式转换。

1

2

3

4

5

//准备训练数据集

Train_Data.convertTo(Train_Data, CV_32FC1); //转化为CV_32FC1类型

Train_Label.convertTo(Train_Label, CV_32FC1);

Mat TrainDataMat = Train_Data(Range(0, trainNum), Range::all()); //只取trainNum行训练

Mat TrainLabelMat = Train_Label(Range(0, trainNum), Range::all());

二、KNN训练

这里使用OpenCV中的KNN算法进行训练。

1

2

3

4

5

6

7

//KNN训练

const int k = 3;  //k值,取奇数,影响最终识别率

Ptr<KNearest>knn = KNearest::create();  //构造KNN模型

knn->setDefaultK(k);//设定k值

knn->setIsClassifier(true);//KNN算法可用于分类、回归。

knn->setAlgorithmType(KNearest::BRUTE_FORCE);//字符匹配算法

knn->train(TrainDataMat, ROW_SAMPLE, TrainLabelMat);//模型训练

三、模型预测及结果显示

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

//预测及结果显示

double count = 0.0;

Scalar color;

for (int i = 0; i < TestData.size(); i++)

{

    //将测试图片转成CV_32FC1,单行形式

    Mat data = TestData[i].mat.reshape(0, 1);

    data.convertTo(data, CV_32FC1);

    Mat sample = data(Range(0, data.rows), Range::all());

 

    float f = knn->predict(sample); //预测

    if (f == TestData[i].label)

    {

        color = Scalar(0, 255, 0); //如果预测正确,绘制绿色,并且结果+1

        count++;

    }

    else

    {

        color = Scalar(0, 0, 255);//如果预测错误,绘制红色

    }

 

    rectangle(src, TestData[i].rect, color, 2);

}

 

//将绘制结果拷贝到一张新图上

Mat result(Size(src.cols, src.rows + 50), CV_8UC3, Scalar::all(255));

src.copyTo(result(Rect(0, 0, src.cols, src.rows)));

//将得分在结果图上显示

char text[10];

int score = (count / testNum) * 100;

sprintf_s(text, "%s%d%s", "Score:", score, "%");

putText(result, text, Point((result.cols / 2) - 80, result.rows - 15), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0), 2);

如图为不同比重训练集与测试集识别结果。

四、源码

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

#include<iostream>

#include<opencv2/opencv.hpp>

#include<opencv2/ml.hpp>

using namespace std;

using namespace cv;

using namespace cv::ml;

 

 

//**自定义结构体

struct MyNum

{

    cv::Mat mat; //数字图片

    cv::Rect rect;//相对整张图所在矩形

    int label;//数字标签

};

 

int main()

{

    Mat src = imread("digit.png");

    if (src.empty())

    {

        cout << "No Image..." << endl;

        system("pause");

        return -1;

    }

 

    Mat gray;

    cvtColor(src, gray, COLOR_BGR2GRAY);

 

    const int classNum = 10;  //总共有0~9个数字类别

    const int picNum = 20;//每个类别共20张图片

    const int pic_w = 28;//图片宽

    const int pic_h = 28;//图片高

 

    //将数据集分为训练集、测试集

    double totalNum = classNum * picNum;//图片总数

    double per = 0.8;   //百分比--修改百分比可改变训练集、测试集比重

    double trainNum = totalNum * per;//训练图片数量

    double testNum = totalNum * (1.0 - per);//测试图片数量

 

    Mat Train_Data, Train_Label;//用于训练

    vector<MyNum>TestData;//用于测试

    for (int i = 0; i < picNum; i++)

    {

        for (int j = 0; j < classNum; j++)

        {

            //将所有图片数据都拷贝到Mat矩阵里

            Mat temp;

            gray(Range(j*pic_w, j*pic_w + pic_w), Range(i*pic_h, i*pic_h + pic_h)).copyTo(temp);

            Train_Data.push_back(temp.reshape(0, 1)); //将temp数字图像reshape成一行数据,然后一一追加到Train_Data矩阵中

            Train_Label.push_back(j);

 

            //额外用于测试

            if (i * classNum + j >= trainNum)

            {

                TestData.push_back({ temp,Rect(i*pic_w,j*pic_h,pic_w,pic_h),j });

            }

        }

    }

 

    //准备训练数据集

    Train_Data.convertTo(Train_Data, CV_32FC1); //转化为CV_32FC1类型

    Train_Label.convertTo(Train_Label, CV_32FC1);

    Mat TrainDataMat = Train_Data(Range(0, trainNum), Range::all()); //只取trainNum行训练

    Mat TrainLabelMat = Train_Label(Range(0, trainNum), Range::all());

 

    //KNN训练

    const int k = 3;  //k值,取奇数,影响最终识别率

    Ptr<KNearest>knn = KNearest::create();  //构造KNN模型

    knn->setDefaultK(k);//设定k值

    knn->setIsClassifier(true);//KNN算法可用于分类、回归。

    knn->setAlgorithmType(KNearest::BRUTE_FORCE);//字符匹配算法

    knn->train(TrainDataMat, ROW_SAMPLE, TrainLabelMat);//模型训练

 

    //预测及结果显示

    double count = 0.0;

    Scalar color;

    for (int i = 0; i < TestData.size(); i++)

    {

        //将测试图片转成CV_32FC1,单行形式

        Mat data = TestData[i].mat.reshape(0, 1);

        data.convertTo(data, CV_32FC1);

        Mat sample = data(Range(0, data.rows), Range::all());

 

        float f = knn->predict(sample); //预测

        if (f == TestData[i].label)

        {

            color = Scalar(0, 255, 0); //如果预测正确,绘制绿色,并且结果+1

            count++;

        }

        else

        {

            color = Scalar(0, 0, 255);//如果预测错误,绘制红色

        }

 

        rectangle(src, TestData[i].rect, color, 2);

    }

 

    //将绘制结果拷贝到一张新图上

    Mat result(Size(src.cols, src.rows + 50), CV_8UC3, Scalar::all(255));

    src.copyTo(result(Rect(0, 0, src.cols, src.rows)));

    //将得分在结果图上显示

    char text[10];

    int score = (count / testNum) * 100;

    sprintf_s(text, "%s%d%s", "Score:", score, "%");

    putText(result, text, Point((result.cols / 2) - 80, result.rows - 15), FONT_HERSHEY_SIMPLEX, 1, Scalar(0, 255, 0), 2);

    imshow("test", result);

    imwrite("result.jpg", result);

    waitKey(0);

    system("pause");

    return 0;

}

总结

本文使用OpenCV C++ 利用ml模块进行手写数字识别,源码注释也比较详细,主要操作有以下几点。

1、数据集划分为训练集与测试集

2、进行KNN训练

3、进行模型预测以及结果显示

原文链接:https://blog.csdn.net/Zero___Chen/article/details/126206827
相关文章
最新更新