Two Simple but Notable Details When Using MATLAB fitcnet to Train a Neural Network
Introduction to the fitcnet function
In MATLAB, the fitcnet function [1] is an easy way to build (and hence train and test) a fully connected neural network. For instance, we can construct a 4-3-4-3 network (4 input features, two hidden layers with 3 and 4 neurons, 3 output classes) to classify the fisheriris dataset:
clc,clear,close all
load fisheriris.mat

% Stratified hold-out split: 60% training, 40% test
cv = cvpartition(species,"HoldOut",0.4,"Stratify",true);

% Assemble one table with the four measurements and the class labels
species = table2array(cell2table(species));
meas = array2table(meas);
dataTable = addvars(meas,species,'After',"meas4");
trainingDataTable = dataTable(cv.training,:);
testDataTable = dataTable(cv.test,:);

% Train a 4-3-4-3 fully connected network
mdl = fitcnet(...
    trainingDataTable,"species", ...
    "LayerSizes",[3,4], ...
    "Activations","relu", ...
    "Standardize",true, ...
    "LossTolerance",1e-6 ... % default value, 1e-6
    );
    % "Verbose",1);

% Test accuracy
pred = mdl.predict(testDataTable(:,1:4));
accu = sum(strcmp(pred,table2array(testDataTable(:,5))))/numel(testDataTable(:,5));
disp(accu)

% Training-loss curve
figure
hold(gca,"on"),grid(gca,"on"),box(gca,"on")
plot(mdl.TrainingHistory.Iteration,mdl.TrainingHistory.TrainingLoss, ...
    "LineWidth",1.5,"LineStyle","-","Marker","o")
title(sprintf("Training loss: %.8f",mdl.TrainingHistory.TrainingLoss(end)))
xlabel("Iteration")
ylabel("Training loss")
xlim([0,mdl.TrainingHistory.Iteration(end)])
The "LossTolerance",1e-6 setting means that training stops once the training loss is less than or equal to the specified value 1e-6 (although 1e-6 is the default value, specifying it explicitly is better practice). This is so obvious that users may neglect two further details about the network's stopping conditions, which are what I want to point out in this post.
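As a quick, hedged illustration (reusing only the tables and parameters defined above; the exact iteration count varies from run to run), loosening LossTolerance makes training stop correspondingly earlier:

% Loosening LossTolerance stops training much earlier, since the
% training loss only has to fall to 1e-2 instead of 1e-6
mdlLoose = fitcnet(...
    trainingDataTable,"species", ...
    "LayerSizes",[3,4], ...
    "Activations","relu", ...
    "Standardize",true, ...
    "LossTolerance",1e-2 ... % far looser than the default 1e-6
    );
disp(mdlLoose.TrainingHistory.Iteration(end)) % typically far fewer iterations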
GradientTolerance property
To begin with, in some runs the network stops before the training loss has fallen to 1e-6. The reason is that the gradient has become extremely small, so MATLAB considers the network well trained and further iterations unnecessary. What counts as "extremely small" for the gradient is defined by the GradientTolerance property, whose default value is also 1e-6. We can verify this in such a case:
>> mdl.TrainingHistory.Gradient(end-1:end)
ans =
1.0e-05 *
0.918975375158474
0.016944200306518
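The last recorded gradient, about 1.69e-7, is indeed below the 1e-6 tolerance. As an additional hedged check (recent releases document a ConvergenceInfo property on the fitted model; availability may depend on your MATLAB version):

% ConvergenceCriterion states which stopping condition actually fired
disp(mdl.ConvergenceInfo.ConvergenceCriterion)
% Compare the final gradient against the tolerance directly
disp(mdl.TrainingHistory.Gradient(end) <= 1e-6) % logical 1 here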
Validation-related properties ValidationData, ValidationPatience, and ValidationChecks
The other thing I want to mention concerns validation and test. In some scenarios we may want to look into the corresponding TEST loss of the model at different training stages, but we do not, and must not, let the test samples get involved in the training process, otherwise data leakage occurs. A seemingly correct way to do this is to set the ValidationData property, like this:
...
% Pass the test set as "validation" data so that its loss is recorded
mdl = fitcnet(...
    trainingDataTable,"species", ...
    "LayerSizes",[3,4], ...
    "Activations","relu", ...
    "Standardize",true, ...
    "LossTolerance",1e-6, ... % default value, 1e-6
    "GradientTolerance",1e-6, ... % default value, 1e-6
    "ValidationData",testDataTable ...
    );
pred = mdl.predict(testDataTable(:,1:4));
accu = sum(strcmp(pred,table2array(testDataTable(:,5))))/numel(testDataTable(:,5));
disp(accu)

% Plot both loss curves
figure
hold(gca,"on"),grid(gca,"on"),box(gca,"on")
plot(mdl.TrainingHistory.Iteration,mdl.TrainingHistory.TrainingLoss, ...
    "LineWidth",1.5,"LineStyle","-","Marker","o","DisplayName","Training loss")
plot(mdl.TrainingHistory.Iteration,mdl.TrainingHistory.ValidationLoss, ...
    "LineWidth",1.5,"LineStyle","-","Marker","square","DisplayName","Test loss")
title(sprintf("Training loss: %.8f",mdl.TrainingHistory.TrainingLoss(end)))
xlabel("Iteration")
ylabel("Loss")
xlim([0,mdl.TrainingHistory.Iteration(end)])
legend
But the results show that the aforementioned phenomenon, i.e., the network stopping before the training loss falls to LossTolerance, now happens even more frequently, even though the gradient does not fall to the GradientTolerance value either.
This is because once the ValidationData property is set, the validation dataset DEFINITELY influences the training process, despite the fact that we only want to observe its loss. Or rather, the stopping condition is now partly determined by the validation loss, through the ValidationPatience property described next.
According to the official documentation, the default value 6 of ValidationPatience means that if the validation loss is greater than or equal to the minimum validation loss computed so far at least 6 times IN A ROW, the network stops training and returns the model trained so far, regardless of the LossTolerance and GradientTolerance settings.
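For intuition, here is a minimal sketch of this patience rule (a hypothetical re-implementation running on a made-up loss sequence, not MATLAB's internal code):

% Hypothetical illustration of the ValidationPatience stopping rule
valLossHistory = [0.9 0.7 0.5 0.45 0.46 0.44 0.47 0.48 0.49 0.50 0.51 0.52];
patience = 6;   % the default "ValidationPatience" value
bestValLoss = Inf;
checks = 0;     % consecutive evaluations without improvement
for k = 1:numel(valLossHistory)
    if valLossHistory(k) < bestValLoss
        bestValLoss = valLossHistory(k);
        checks = 0;          % a new minimum resets the counter
    else
        checks = checks + 1; % >= the minimum so far: one more check
    end
    if checks >= patience
        fprintf("Stopped at step %d: no improvement %d times IN A ROW.\n",k,patience)
        break
    end
end

Note how the single improvement at step 6 resets the counter, so training only stops once six non-improving evaluations occur consecutively. Back in the real model, this point can be verified through the read-only ValidationChecks record in TrainingHistory: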
>> mdl.TrainingHistory.ValidationChecks'
ans =
Columns 1 through 18
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
Columns 19 through 25
0 1 2 3 4 5 6
This means that the loss we have been observing is actually the VALIDATION loss, not a true test loss. On the other hand, based on this property and its explanation, we can figure out that setting ValidationPatience to Inf realises what we expected: the samples packed into ValidationData are kept away from the training process, and we merely observe their loss.
clc,clear,close all
load fisheriris.mat
cv = cvpartition(species,"HoldOut",0.4,"Stratify",true);
species = table2array(cell2table(species));
meas = array2table(meas);
dataTable = addvars(meas,species,'After',"meas4");
trainingDataTable = dataTable(cv.training,:);
testDataTable = dataTable(cv.test,:);

% "ValidationPatience",Inf disables validation-based early stopping,
% so the test set is only observed, never used to stop training
mdl = fitcnet(...
    trainingDataTable,"species", ...
    "LayerSizes",[3,4], ...
    "Activations","relu", ...
    "Standardize",true, ...
    "LossTolerance",1e-6, ... % default value, 1e-6
    "GradientTolerance",1e-6, ... % default value, 1e-6
    "ValidationData",testDataTable, ...
    "ValidationPatience",Inf ... % default value, 6
    );
pred = mdl.predict(testDataTable(:,1:4));
accu = sum(strcmp(pred,table2array(testDataTable(:,5))))/numel(testDataTable(:,5));
disp(accu)
figure
hold(gca,"on"),grid(gca,"on"),box(gca,"on")
plot(mdl.TrainingHistory.Iteration,mdl.TrainingHistory.TrainingLoss, ...
    "LineWidth",1.5,"LineStyle","-","Marker","o","DisplayName","Training loss")
plot(mdl.TrainingHistory.Iteration,mdl.TrainingHistory.ValidationLoss, ...
    "LineWidth",1.5,"LineStyle","-","Marker","square","DisplayName","Test loss")
title(sprintf("Training loss: %.8f",mdl.TrainingHistory.TrainingLoss(end)))
xlabel("Iteration")
ylabel("Loss")
xlim([0,mdl.TrainingHistory.Iteration(end)])
legend
This time, a much more noticeable overfitting phenomenon can be observed from the two loss curves.
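As a short follow-up (using only the TrainingHistory fields already plotted above), we can locate the iteration at which the observed test loss bottoms out, which marks where overfitting begins:

% Find the iteration with the minimum observed test loss
[minValLoss,idx] = min(mdl.TrainingHistory.ValidationLoss);
fprintf("Minimum test loss %.6f at iteration %d\n", ...
    minValLoss,mdl.TrainingHistory.Iteration(idx))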
References
[1] fitcnet - MathWorks.