Resume Training a Neural Network in MATLAB
Introduction
Training a "large" deep learning neural network always takes a long time, in which case specifying checkpoints during training is practical. This approach avoids re-training the network from scratch if the training process is interrupted unexpectedly, although it occupies extra storage space.
The MATLAB official example "Resume Training from Checkpoint Network" 1 shows how to set checkpoints and continue training from a specified checkpoint network in the classic MNIST multi-class classification task. Some details of this example are discussed in this post.
Setting checkpoints while training
In the official example 1, the complete code for training the multi-class classification CNN is as follows (Script 1):
clc,clear,close all
rng("default")

% Prepare training dataset
[XTrain,YTrain] = digitTrain4DArrayData;

% Define network architecture
layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,8,"Padding","same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,"Stride",2)
    convolution2dLayer(3,16,"Padding","same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,"Stride",2)
    convolution2dLayer(3,32,"Padding","same")
    batchNormalizationLayer
    reluLayer
    averagePooling2dLayer(7)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Specify training options
options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.1, ...
    "MaxEpochs",20, ...
    "Verbose",false, ...
    "Plots","training-progress", ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd);

% Train the network
net1 = trainNetwork(XTrain,YTrain,layers,options);
Here, setting checkpoints is realized simply by specifying a folder path for the "CheckpointPath" property of the trainingOptions function 2:
% Specify training options
options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.1, ...
    "MaxEpochs",20, ...
    "Verbose",false, ...
    "Plots","training-progress", ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd);
The training progress of running Script 1 is shown in Fig. 1.
After running the script, several .mat files appear in the current folder, and each .mat file contains one and only one variable, a neural network named net.
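For instance, we can verify this with MATLAB's whos function, here applied to the checkpoint file that Script 2 below loads:

% Inspect the variables stored in a checkpoint file;
% the output should list a single variable, net
whos("-file","net_checkpoint__780__2024_02_05__20_02_53.mat")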
Specifically, the file name of each .mat file follows the pattern:

net_checkpoint__<iteration>__<year>_<month>_<day>__<hour>_<minute>_<second>.mat
As shown on the right side of Fig. 1, there are 20 epochs in total in the entire training process, and each epoch contains 39 iterations. Correspondingly, there are 20 .mat files in the current folder, and the first number in each file name, i.e., 39, 78, 117, and so on, is the iteration number, indicating how many iterations the stored network had been trained for.
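Based on this naming convention, we can, for example, list all checkpoint files and sort them in training order (a minimal sketch, not part of the official example, assuming the default file names shown above):

% List all checkpoint files in the current folder
files = dir("net_checkpoint__*.mat");
names = string({files.name});
% Extract the iteration number between "net_checkpoint__" and the next "__"
iters = str2double(extractBetween(names,"net_checkpoint__","__"));
% Sort the file names by iteration number, i.e., in training order
[iters,idx] = sort(iters);
names = names(idx);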
Or rather, under the default settings, "trainNetwork saves one checkpoint network each epoch and automatically assigns unique names to the checkpoint files." 1 If we want to change this behavior, we can specify two other properties of the trainingOptions function 2, i.e., "CheckpointFrequency" and "CheckpointFrequencyUnit".
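For example, the following sketch saves a checkpoint every 5 epochs; it assumes a MATLAB release in which these two options are available, and "CheckpointFrequencyUnit" can be "epoch" or "iteration":

options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.1, ...
    "MaxEpochs",20, ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd, ...
    "CheckpointFrequency",5, ...          % save a checkpoint every 5 units
    "CheckpointFrequencyUnit","epoch");   % the unit is one epoch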
Resuming training from a checkpoint network
If we want to resume training from the last checkpoint network, we can use the following script (Script 2):
clc,clear,close all
rng("default")

% Load checkpoint network
load net_checkpoint__780__2024_02_05__20_02_53.mat

% Prepare training dataset
[XTrain,YTrain] = digitTrain4DArrayData;

% Specify training options
options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.05, ...
    "MaxEpochs",15, ...
    "Verbose",false, ...
    "Plots","training-progress", ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd);

% Resume training
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);
Of course, the property specifications of trainingOptions in the two training runs (Script 1 and Script 2) don't need to be the same (for instance, "InitialLearnRate" and "MaxEpochs" differ between them).
In addition, there is a small detail worth noting: the iteration number in the checkpoint file names is re-counted from the start of the second training run.
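Because of this re-counting, file modification time is a more reliable indicator of the most recent checkpoint than the iteration number in the file name. The following sketch (a hypothetical variant, not part of the official example) picks the newest checkpoint automatically instead of hard-coding its file name:

% Find the most recently saved checkpoint file in the current folder
files = dir("net_checkpoint__*.mat");
[~,idx] = max([files.datenum]);
% Load the network variable "net" from that file and resume training
load(fullfile(files(idx).folder,files(idx).name),"net")
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);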
References