Stop Training Neural Network while Using MATLAB trainNetwork Function

Mar. 21, 2023

Introduction

MATLAB Deep Learning Toolbox provides the trainNetwork function (trainNetwork - MathWorks) for training classification and regression neural networks. When we do not need to define a custom training loop, trainNetwork is a convenient way to train a network. In addition, training options can be configured with the trainingOptions function (trainingOptions - MathWorks) and passed to trainNetwork as the input argument options. For example, here is an example from the MATLAB documentation (Monitor Deep Learning Training Progress):

...
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,Padding="same")
    batchNormalizationLayer
    reluLayer   
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(3,16,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,Stride=2)
    convolution2dLayer(3,32,Padding="same")
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
    
options = trainingOptions("sgdm", ...
    MaxEpochs=8, ...
    ValidationData={XValidation,YValidation}, ...
    ValidationFrequency=30, ...
    Verbose=false, ...
    Plots="training-progress");
    
net = trainNetwork(XTrain,YTrain,layers,options);

An important question when training a neural network is: under what conditions should training stop? This post therefore surveys how to use the trainingOptions function to set stopping criteria when training a network with trainNetwork.


(1) Stop training when the maximum number of training epochs is reached (MaxEpochs property)

The trainingOptions function provides the MaxEpochs property for setting the maximum number of training epochs:

(Screenshot of the MaxEpochs property description in the trainingOptions documentation)

Its default value is 30.

When training a neural network, we usually specify the mini-batch size with the MiniBatchSize property (also set via trainingOptions; the default is 128). Passing one mini-batch of data through the network, followed by the forward and backward passes, constitutes one iteration; once all the data have participated in training, mini-batch by mini-batch, one epoch is complete. Training stops when the number of epochs specified by MaxEpochs is reached. This is the simplest way to stop network training.
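As a minimal sketch of the epoch/iteration relationship (the training-set size of 5000 is an assumption for illustration, not from the example above):

```matlab
% With 5000 training samples and the default MiniBatchSize of 128,
% one epoch consists of floor(5000/128) = 39 iterations, so with
% MaxEpochs = 10 training stops after at most 10*39 = 390 iterations.
numTrain = 5000;                 % assumed number of training samples
options = trainingOptions("sgdm", ...
    MaxEpochs=10, ...            % stop after 10 full passes over the data
    MiniBatchSize=128);          % default mini-batch size
iterationsPerEpoch = floor(numTrain/options.MiniBatchSize);   % 39
```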


(2) Automatic validation stopping (ValidationPatience and OutputNetwork properties)

A convenient aspect of training with trainNetwork is that we can monitor the training process in a graphical interface. The monitored metrics are: training accuracy, smoothed training accuracy, and validation accuracy; training loss, smoothed training loss, and validation loss:

(Screenshot of the training-progress plot)

The accuracy and loss on the training set can always be monitored, but if we also want to evaluate the network on a validation set during training, we need to specify the validation data (ValidationData property) and how often to validate (ValidationFrequency property) in trainingOptions. For example:

options = trainingOptions("sgdm", ...
    MaxEpochs=8, ...
    ValidationData={XValidation,YValidation}, ...
    ValidationFrequency=30, ...
    Verbose=false, ...
    Plots="training-progress");

However, this setting only lets us observe the performance on the validation set (accuracy and loss); it does not use the validation information for early stopping.

To enable automatic validation stopping, we need to specify the ValidationPatience property of trainingOptions:

(Screenshot of the ValidationPatience property description in the trainingOptions documentation)

This property is specified as a positive integer, or as Inf (the default).

As training proceeds, the validation loss may start to rise instead of fall. This suggests that overfitting may be occurring; on the other hand, it may just be normal fluctuation of the training process. We therefore need a way to be (reasonably) confident that overfitting has occurred. The trainNetwork function records the smallest validation loss observed so far and compares each subsequent validation loss against this minimum; once the number of times the validation loss exceeds the minimum is greater than the value of ValidationPatience, training stops.

Another property worth emphasizing here is ValidationFrequency, which specifies how often the validation set is evaluated:

(Screenshot of the ValidationFrequency property description in the trainingOptions documentation)

ValidationFrequency is specified as a positive integer; its default value is 50. Note that the unit of this property is iterations (not epochs), i.e., with the default value the network is validated once every 50 training iterations.

After training stops, there are two choices for the network that is returned: one is the network with the smallest validation loss, and the other is the network from the last training iteration before stopping. We can choose between them with the OutputNetwork property:

(Screenshot of the OutputNetwork property description in the trainingOptions documentation)
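As an illustrative sketch putting these properties together (the validation data, patience value, and frequency are assumptions, not from the example above), automatic validation stopping with the best network returned could be configured as:

```matlab
% Sketch: stop when the validation loss fails to improve for 5
% consecutive validations, and return the network with the smallest
% validation loss rather than the last one trained.
options = trainingOptions("sgdm", ...
    MaxEpochs=100, ...
    ValidationData={XValidation,YValidation}, ...  % assumed validation set
    ValidationFrequency=30, ...                    % validate every 30 iterations
    ValidationPatience=5, ...                      % default is Inf (never stop)
    OutputNetwork="best-validation-loss", ...      % default is "last-iteration"
    Plots="training-progress");
```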


(3) Custom early stopping based on validation accuracy (OutputFcn property)

The trainingOptions function also has another input parameter, OutputFcn, which accepts a function handle or a cell array of function handles:

(Screenshot of the OutputFcn property description in the trainingOptions documentation)

MATLAB provides an example of early stopping based on the OutputFcn property: Customize Output During Deep Learning Network Training - MathWorks. The relevant code is:

miniBatchSize = 128;
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfAccuracyNotImproving(info,3));
function stop = stopIfAccuracyNotImproving(info,N)

stop = false;

% Keep track of the best validation accuracy and the number of validations for which
% there has not been an improvement of the accuracy.
persistent bestValAccuracy
persistent valLag

% Clear the variables when training starts.
if info.State == "start"
    bestValAccuracy = 0;
    valLag = 0;
    
elseif ~isempty(info.ValidationLoss)
    
    % Compare the current validation accuracy to the best accuracy so far,
    % and either set the best accuracy to the current accuracy, or increase
    % the number of validations for which there has not been an improvement.
    if info.ValidationAccuracy > bestValAccuracy
        valLag = 0;
        bestValAccuracy = info.ValidationAccuracy;
    else
        valLag = valLag + 1;
    end
    
    % If the validation lag is at least N, that is, the validation accuracy
    % has not improved for at least N validations, then return true and
    % stop training.
    if valLag >= N
        stop = true;
    end
    
end
end

The early stopping strategy here is:

Define the output function stopIfAccuracyNotImproving(info,N), which stops network training if the best classification accuracy on the validation data does not improve for N network validations in a row.

This strategy is similar to the previous one based on the ValidationPatience property; the only difference is that it uses the accuracy instead of the loss:

This criterion is similar to the built-in stopping criterion using the validation loss, except that it applies to the classification accuracy instead of the loss.

However, this approach is clearly more flexible: we can implement all kinds of early stopping strategies with it. Before doing so, we first need to know what information the info struct, the input argument of stopIfAccuracyNotImproving, contains.

Setting a breakpoint in the stopIfAccuracyNotImproving function of this example file, we can observe the struct for the first three calls:

info = 
  struct with fields:
                 Epoch: 0
             Iteration: 0
        TimeSinceStart: []
          TrainingLoss: []
        ValidationLoss: []
         BaseLearnRate: []
      TrainingAccuracy: []
          TrainingRMSE: []
    ValidationAccuracy: []
        ValidationRMSE: []
                 State: "start"
info = 
  struct with fields:
                 Epoch: 1
             Iteration: 1
        TimeSinceStart: 23.7756
          TrainingLoss: 2.7155
        ValidationLoss: 2.5169
         BaseLearnRate: 0.0100
      TrainingAccuracy: 7.8125
          TrainingRMSE: 3.3138
    ValidationAccuracy: 12.7000
        ValidationRMSE: []
                 State: "iteration"
info = 
  struct with fields:
                 Epoch: 1
             Iteration: 2
        TimeSinceStart: 79.4602
          TrainingLoss: 2.5766
        ValidationLoss: []
         BaseLearnRate: 0.0100
      TrainingAccuracy: 12.5000
          TrainingRMSE: 2.8631
    ValidationAccuracy: []
        ValidationRMSE: []
                 State: "iteration"

Three conclusions can be drawn:

(1) The info struct is passed after every mini-batch iteration, not once per epoch;

(2) The fields of the struct match those listed in the documentation of the OutputFcn property;

(3) Validation-related fields (such as ValidationLoss and ValidationAccuracy here) are only populated once every ValidationFrequency iterations.
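Given these fields, many other stopping criteria can be expressed as output functions. For example, here is a minimal sketch (not from the MATLAB example; stopIfTimeExceeded and maxSeconds are hypothetical names) that stops training once a wall-clock time budget is exceeded, using the TimeSinceStart field shown above:

```matlab
% Sketch: stop training once the elapsed time exceeds a budget (seconds).
% TimeSinceStart is empty in the "start" state, so guard against that.
function stop = stopIfTimeExceeded(info,maxSeconds)
stop = false;
if info.State == "iteration" && ~isempty(info.TimeSinceStart)
    stop = info.TimeSinceStart > maxSeconds;
end
end
```

It would be hooked up in the same way as the example above, e.g. 'OutputFcn',@(info)stopIfTimeExceeded(info,600) to stop after roughly ten minutes.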


(4) Custom early stopping based on training loss (OutputFcn property)

Added on Sep. 15, 2023.

At times, the available data are limited, and it may be impractical to set aside part of them as a validation set. In these scenarios, stopping training according to the training loss or accuracy is an alternative. This can also be realised with the aforementioned OutputFcn property of trainingOptions.

The following example shows how to stop network training when the training loss is lower than a specified threshold (here, 1e-2) three times in a row.

...
validationFrequency = floor(numel(YTrain)/miniBatchSize);
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.01, ...
    'MaxEpochs',100, ...
    'MiniBatchSize',miniBatchSize, ...
    'VerboseFrequency',validationFrequency, ...
    'ValidationData',{XValidation,YValidation}, ...
    'ValidationFrequency',validationFrequency, ...
    'Plots','training-progress', ...
    'OutputFcn',@(info)stopIfTrainingLossNotImproving(info,3));
[net,info] = trainNetwork(XTrain,YTrain,layers,options);

function stop = stopIfTrainingLossNotImproving(info,N)
stop = false;

persistent lossLag
lossThreshold = 1e-2;

if info.State == "start"
    lossLag = 0;
elseif ~isempty(info.TrainingLoss)
    if info.TrainingLoss > lossThreshold
        lossLag = 0;
    else
        lossLag = lossLag + 1;
    end
    if lossLag >= N
        stop = true;
    end
end
end

(Screenshot of the training-progress plot; training stopped early once the criterion was met)

We can inspect the recorded training information info.TrainingLoss to verify this:

>> info.TrainingLoss(end-4:end)
ans =
    0.0173    0.0141    0.0087    0.0068    0.0099


References

[1] trainNetwork - MathWorks.

[2] trainingOptions - MathWorks.

[3] Customize Output During Deep Learning Network Training - MathWorks.