Train an SVM Using Data Generated by an MVN-RNG, and Test with Real Data

Sep. 10, 2022

Warning ⚠️⚠️⚠️: All of the examples in this post suffer from severe data leakage, which I did not consider at the time; this will be corrected later. 2023.08.31: The data leakage problem has been fixed; see the blog post Correct Data Leakage Problem - What a starry night~ .

Introduction

A classification model is generally trained on a dataset with features and labels, and its performance is then verified on a test set. A common problem when training such models, however, is an insufficient number of training samples. One way to address it is to generate training samples with mathematical tools; this post tries a method that generates sample data with a multivariate normal random number generator (MVN-RNG).

First, the real dataset is split by label, and the sample mean and sample covariance matrix of each class are computed. Assuming that, within each class, the features (i.e., the random variables) follow a multivariate normal distribution, an RNG is then used to generate a dataset for each class from that class's sample mean and sample covariance matrix. This post also considers a worst-case setup: all of the generated data are used as the training set to train a machine-learning model (e.g., an SVM), and then all of the real data are used as the test set to evaluate the model's predictions.
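The procedure above (split by class, fit a sample mean and covariance, sample from the fitted MVN) can be sketched with NumPy. This is a minimal illustration, not the post's MATLAB code, and the tiny two-class "real" dataset below is a synthetic stand-in for datasets like fisheriris:

```python
import numpy as np

rng = np.random.default_rng(0)

def fit_and_generate(X_real, y_real, n_per_class):
    """Per class: fit a sample mean and covariance to the real features,
    then draw n_per_class synthetic samples from that fitted MVN."""
    X_gen, y_gen = [], []
    for c in np.unique(y_real):
        Xc = X_real[y_real == c]
        mu = Xc.mean(axis=0)                 # sample mean of this class
        Sigma = np.cov(Xc, rowvar=False)     # sample covariance matrix
        X_gen.append(rng.multivariate_normal(mu, Sigma, size=n_per_class))
        y_gen.append(np.full(n_per_class, c))
    return np.vstack(X_gen), np.concatenate(y_gen)

# Toy two-class "real" dataset standing in for e.g. fisheriris
X_real = np.vstack([rng.normal(0.0, 1.0, size=(50, 4)),
                    rng.normal(3.0, 1.0, size=(50, 4))])
y_real = np.repeat([0, 1], 50)

X_gen, y_gen = fit_and_generate(X_real, y_real, n_per_class=500)
print(X_gen.shape, y_gen.shape)   # (1000, 4) (1000,)
```

The generated set can then be fed to any classifier as the training data, with the real set held back purely for testing.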

MATLAB provides several sample data sets: Sample Data Sets - MathWorks


Numerical Features Input

Example 1: fisheriris Dataset

The fisheriris dataset has 150 observations, 4 features, and 3 label classes, making it a three-class problem. 500 samples are generated per class as the training set.

clc, clear, close all

% Load real dataset
load fisheriris

% Split the features of 3 kinds of species
idx = strcmp(species, "setosa");
Features_seto = meas(idx, :);
idx = strcmp(species, "versicolor");
Features_vers = meas(idx, :);
idx = strcmp(species, "virginica");
Features_virg = meas(idx, :);

% Calculate the mu and Sigma of the features of each species
mu_seto = mean(Features_seto);
Sigma_seto = cov(Features_seto);
mu_vers = mean(Features_vers);
Sigma_vers = cov(Features_vers);
mu_virg = mean(Features_virg);
Sigma_virg = cov(Features_virg);

% Generate the features and the corresponding labels
numPerSpecies = 500;
X_generate = [mvnrnd(mu_seto, Sigma_seto, numPerSpecies);
    mvnrnd(mu_vers, Sigma_vers, numPerSpecies);
    mvnrnd(mu_virg, Sigma_virg, numPerSpecies)];
Y_generate = [repmat("setosa", numPerSpecies, 1);
    repmat("versicolor", numPerSpecies, 1); 
    repmat("virginica", numPerSpecies, 1)];

% Create an SVM template, and standardize the predictors
t = templateSVM('Standardize', true);

% Train the ECOC classifier using generated dataset
mdl = fitcecoc(X_generate, Y_generate, 'Learners', t,...
    'ClassNames', {'setosa','versicolor','virginica'});

% Predict the labels using the true features
pred = mdl.predict(meas);

% Calculate the accuracy
accu = sum(strcmp(pred, species))/numel(species)*100;
accu =

    98


Example 2: ionosphere Dataset

The ionosphere dataset has 351 observations, 34 features, and 2 label classes, making it a binary classification problem. 2000 samples are generated per class as the training set.

clc, clear, close all

load ionosphere.mat

idx = strcmp(Y, 'g');
Features_g = X(idx, :);
idx = strcmp(Y, 'b');
Features_b = X(idx, :);

mu_g = mean(Features_g);
Sigma_g = cov(Features_g);
mu_b = mean(Features_b);
Sigma_b = cov(Features_b);

numPerClass = 2000;
X_generate  = [mvnrnd(mu_g, Sigma_g, numPerClass);
    mvnrnd(mu_b, Sigma_b, numPerClass)];
Y_generate  = [repmat("g", numPerClass, 1);
    repmat("b", numPerClass, 1)];

t = templateSVM('Standardize', true);
mdl = fitcecoc(X_generate, Y_generate, 'Learners', t, 'ClassNames', {'g','b'});

pred = mdl.predict(X);

accu = sum(strcmp(pred, Y))/numel(Y)*100;
accu =
   90.0285


Example 3: ovariancancer Dataset

The ovariancancer dataset has 216 observations, 4000 features, and 2 label classes, making it a binary classification problem. 1000 samples are generated per class as the training set.

clc, clear, close all

load ovariancancer.mat

idx = strcmp(grp, "Cancer");
Features_cancer = obs(idx, :);
idx = strcmp(grp, "Normal");
Features_normal = obs(idx, :);

mu_cancer = mean(Features_cancer);
Sigma_cancer = cov(Features_cancer);
mu_normal = mean(Features_normal);
Sigma_normal = cov(Features_normal);

numPerClass = 1000;

X_generate = [mvnrnd(mu_cancer, Sigma_cancer, numPerClass);
    mvnrnd(mu_normal, Sigma_normal, numPerClass)];
Y_generate = [repmat("Cancer", numPerClass, 1);
    repmat("Normal", numPerClass, 1)];

t = templateSVM('Standardize', true);
mdl = fitcecoc(X_generate, Y_generate, 'Learners', t, 'ClassNames', {'Cancer', 'Normal'});

pred = mdl.predict(obs);

accu = sum(strcmp(pred, grp))/numel(grp)*100;
accu =
   100


Sequential Signal Input

Example: simulatedDataset

The previous examples assume that a set of numerical features follows an MVN distribution; this example instead assumes that the random vector formed by the data points of a signal follows an MVN distribution. simulatedDataset.mat stores 1571 signals, each with 1201 data points, and there are 2 label classes, making it a binary classification problem. 1000 samples are generated per class as the training set.

clc, clear, close all

load simulatedDataset.mat

flow = flow';
labels = labels';

idx = labels==1;
signals_1 = flow(idx, :);
idx = labels==2;
signals_2 = flow(idx, :);

mu_1 = mean(signals_1);
Sigma_1 = cov(signals_1);
mu_2 = mean(signals_2);
Sigma_2 = cov(signals_2);

numPerClass = 1000;
X_generate = [mvnrnd(mu_1, Sigma_1, numPerClass);
    mvnrnd(mu_2, Sigma_2, numPerClass)];
Y_generate = [ones(numPerClass, 1);
    2*ones(numPerClass, 1)];

t = templateSVM('Standardize', true);
mdl = fitcecoc(X_generate, Y_generate, 'Learners', t);

pred = mdl.predict(flow);
accu = sum(pred==labels)/numel(labels)*100;
accu =
   98.1587

Again, the accuracy is very high.


2-D Image Input

Example: digitTrainCellArrayData

The digitTrainCellArrayData function successively calls the helper functions digitTrainTable and digitTableToCell to organize the local DigitDataset folder (a handwritten-digit dataset) into the features xTrainImages and the classification labels tTrain.

% Load the training data into memory
[xTrainImages, tTrain] = digitTrainCellArrayData;
  Name               Size                Bytes  Class
  tTrain            10x5000             400000  double
  xTrainImages       1x5000           31880000  cell

Here, the labels tTrain form a one-hot matrix, and the data xTrainImages is a cell array in which each cell stores a 28x28 double matrix representing one grayscale image.

Display the images:

% Display some of the training images
for i = 1:20
    subplot(4, 5, i);
    imshow(xTrainImages{i});
end

[Figure: 20 sample training images from the digit dataset]

Next, flatten each feature matrix (28-by-28) into a row vector, and convert the one-hot labels into categorical class indices (1-10, where 10 represents the digit 0):

X = zeros(5000, 28*28);
for i = 1:numel(xTrainImages)
    X(i, :) = xTrainImages{i}(:);
end
[Y, ~] = find(tTrain == 1);
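The one-hot-to-index conversion above can be cross-checked with a small NumPy sketch; the 4-sample one-hot matrix here is hypothetical, not taken from the dataset:

```python
import numpy as np

# Hypothetical stand-in for tTrain: a 10 x N one-hot matrix
# (rows are classes 1..10, columns are samples)
labels = np.array([3, 10, 1, 7])             # 1-based class indices
T = np.zeros((10, labels.size))
T[labels - 1, np.arange(labels.size)] = 1.0  # set the one-hot entries

# NumPy analogue of MATLAB's [Y, ~] = find(tTrain == 1):
# each column holds exactly one 1, so argmax recovers the class row
Y = T.argmax(axis=0) + 1
print(Y)   # [ 3 10  1  7]
```

Both versions rely on every column containing exactly one nonzero entry, which is what makes the recovered indices line up with the samples.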

Then generate the dataset, train the SVM with the generated dataset, and finally validate with the real dataset:

X_generate = [];
Y_generate = [];
numPerClass = 2000;
for i = 1:10
    % Fit a per-class mean and covariance, then sample from the MVN
    classImages = X(Y == i, :);
    mu = mean(classImages);
    Sigma = cov(classImages);
    X_generate = [X_generate; mvnrnd(mu, Sigma, numPerClass)]; %#ok<AGROW>
    Y_generate = [Y_generate; i*ones(numPerClass, 1)]; %#ok<AGROW>
end

t = templateSVM('Standardize', true);
mdl = fitcecoc(X_generate, Y_generate, 'Learners', t);

pred = mdl.predict(X);
accu = sum(pred==Y)/numel(Y)*100;
accu =
   95.5600

The accuracy is very high.

Next, visualize some of the generated images:

% Show generated image in figure(2)
idx = randperm(size(X_generate, 1))';
idx = idx(1:20);
X_show = X_generate(idx, :);
Y_show = Y_generate(idx, :);

X_show_cell = {};
for i = 1:size(X_show, 1)
    X_show_cell(i) = {reshape(X_show(i, :), 28, 28)};
end

figure(2)
for i = 1:size(X_show_cell, 2)
    subplot(4, 5, i);
    imshow(X_show_cell{i});
    title(num2str(Y_show(i)))
end

[Figure: 20 randomly selected generated images, titled with their class labels]

As can be seen, although the model trained on the generated data predicts the real dataset well (95.56% accuracy), the generated images themselves are of poor quality.
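One toy illustration of why this happens (an assumption-laden sketch, not part of the original experiment): a single Gaussian fitted to bimodal, bounded pixel intensities concentrates its samples around the gray mean and lets many values escape the valid [0, 1] range:

```python
import numpy as np

rng = np.random.default_rng(1)

# Toy "pixel" data: each pixel is close to either dark (0.05) or
# bright (0.95), so the per-class distribution is bimodal, not Gaussian
X = rng.choice([0.05, 0.95], size=(500, 16))

# Fit a single MVN to this bimodal, bounded data and sample from it
mu = X.mean(axis=0)
Sigma = np.cov(X, rowvar=False)
samples = rng.multivariate_normal(mu, Sigma, size=500)

# MVN samples are unconstrained: a sizeable fraction of the generated
# "pixels" fall outside the valid intensity range [0, 1]
frac_outside = np.mean((samples < 0) | (samples > 1))
print(round(float(frac_outside), 2))
```

This is one plausible reason the generated digits look blurry even though the classifier, which mainly needs the class-conditional means and covariances to be roughly right, still performs well.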

Complete code:

clc, clear, close all

% Load the training data into memory
[xTrainImages, tTrain] = digitTrainCellArrayData;

% Display some of the training images
figure(1)
for i = 1:20
    subplot(4, 5, i);
    imshow(xTrainImages{i});
end

X = zeros(5000, 28*28);
for i = 1:numel(xTrainImages)
    X(i, :) = xTrainImages{i}(:);
end
[Y, ~] = find(tTrain == 1);

X_generate = [];
Y_generate = [];
numPerClass = 2000;
for i = 1:10
    % Fit a per-class mean and covariance, then sample from the MVN
    classImages = X(Y == i, :);
    mu = mean(classImages);
    Sigma = cov(classImages);
    X_generate = [X_generate; mvnrnd(mu, Sigma, numPerClass)]; %#ok<AGROW>
    Y_generate = [Y_generate; i*ones(numPerClass, 1)]; %#ok<AGROW>
end

t = templateSVM('Standardize', true);
mdl = fitcecoc(X_generate, Y_generate, 'Learners', t);

pred = mdl.predict(X);
accu = sum(pred==Y)/numel(Y)*100;

% Show generated image in figure(2)
idx = randperm(size(X_generate, 1))';
idx = idx(1:20);
X_show = X_generate(idx, :);
Y_show = Y_generate(idx, :);

X_show_cell = {};
for i = 1:size(X_show, 1)
    X_show_cell(i) = {reshape(X_show(i, :), 28, 28)};
end

figure(2)
for i = 1:size(X_show_cell, 2)
    subplot(4, 5, i);
    imshow(X_show_cell{i});
    title(num2str(Y_show(i)))
end


Conclusion

In summary, generating data with an MVN-RNG is indeed a good data-augmentation method; these results also indirectly confirm how widely applicable the multivariate normal distribution is.

However, everything here is still a classification problem. As the final image example shows, while this way of generating data works well as augmentation for classification, the MVN-RNG performs rather poorly when treated as a generator in its own right. For structure-learning problems such as generating images or sequences, more sophisticated generators such as GANs would likely perform better, at the cost of far greater complexity.