Resume Training Neural Network in MATLAB

Feb. 05, 2024

Introduction

Training a “large” deep learning neural network usually takes a long time, so specifying checkpoints during training is practical: it avoids re-training the network from scratch if the training process is interrupted unexpectedly, although the checkpoints occupy extra storage space.

The official MATLAB example “Resume Training from Checkpoint Network” 1 shows how to set checkpoints and continue training from a specified checkpoint network in the classic MNIST multi-class classification task. Some details of this example are discussed in this post.


Setting checkpoints while training

In the official example 1, the complete code for training the multi-class classification CNN is as follows:

clc,clear,close all

rng("default")

% Prepare training dataset
[XTrain,YTrain] = digitTrain4DArrayData;

% Define network architecture
layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,"Padding","same")
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,"Stride",2) 
    
    convolution2dLayer(3,16,"Padding","same")
    batchNormalizationLayer
    reluLayer    
    maxPooling2dLayer(2,"Stride",2)
    
    convolution2dLayer(3,32,"Padding","same")
    batchNormalizationLayer
    reluLayer   
    averagePooling2dLayer(7)  
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Specify Training Options
options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.1, ...
    "MaxEpochs",20, ...
    "Verbose",false, ...
    "Plots","training-progress", ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd);

% Train the network
net1 = trainNetwork(XTrain,YTrain,layers,options);

Here, checkpointing is enabled simply by passing a folder path to the "CheckpointPath" option of the trainingOptions function 2:

(Figure: documentation of the "CheckpointPath" option of trainingOptions)

% Specify Training Options
options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.1, ...
    "MaxEpochs",20, ...
    "Verbose",false, ...
    "Plots","training-progress", ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd);

The training progress of Script 1 is shown below:

(Fig. 1: training progress of Script 1)

After running, there are some .mat files in the current folder, and each .mat file contains exactly one variable, a neural network named net:

(Figure: checkpoint .mat files saved in the current folder)

Specifically, each .mat file is named following the pattern:

net_checkpoint__<iterations>__<year>_<month>_<date>__<hour>_<minute>_<second>.mat

As shown on the right side of Fig. 1, there are 20 epochs in total in the training process, and each epoch contains 39 iterations. Correspondingly, there are 20 .mat files in the current folder, and the first number in each file name, i.e., 39, 78, 117, and so on, is the iteration count, indicating how many iterations the stored network had been trained for.
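The figure of 39 iterations per epoch follows from the data size and the default mini-batch size: digitTrain4DArrayData provides 5000 training images, and trainingOptions uses a "MiniBatchSize" of 128 by default, so each epoch runs floor(5000/128) = 39 full mini-batches. A quick sanity check:

```matlab
% Sanity check: iterations per epoch = floor(numImages/miniBatchSize).
% digitTrain4DArrayData provides 5000 28-by-28 grayscale digit images,
% and the default "MiniBatchSize" of trainingOptions is 128.
[XTrain,~] = digitTrain4DArrayData;
numImages = size(XTrain,4);               % 5000
iterationsPerEpoch = floor(numImages/128) % 39, matching Fig. 1
```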

In other words, under the default settings, “trainNetwork saves one checkpoint network each epoch and automatically assigns unique names to the checkpoint files.” 1 To change this behavior, we can specify two other options of the trainingOptions function 2, i.e., "CheckpointFrequency":

(Figure: documentation of the "CheckpointFrequency" option)

and "CheckpointFrequencyUnit":

(Figure: documentation of the "CheckpointFrequencyUnit" option)
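Combining the two options changes the checkpoint cadence. For example, a minimal sketch (reusing the options of Script 1; note that these two options are only available in newer MATLAB releases) that saves one checkpoint every 5 epochs instead of every epoch:

```matlab
% Sketch: save one checkpoint network every 5 epochs instead of every epoch.
options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.1, ...
    "MaxEpochs",20, ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd, ...
    "CheckpointFrequency",5, ...            % save every 5 ...
    "CheckpointFrequencyUnit","epoch");     % ... epochs
```

With these settings, a 20-epoch run would leave 4 checkpoint files instead of 20.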


Resume training from checkpoint network

If we want to resume training from the last checkpoint network, we can use the following script:

clc,clear,close all

rng("default")

% Load checkpoint network
load net_checkpoint__780__2024_02_05__20_02_53.mat

% Prepare training dataset
[XTrain,YTrain] = digitTrain4DArrayData;

% Specify training options
options = trainingOptions("sgdm", ...
    "InitialLearnRate",0.05, ...
    "MaxEpochs",15, ...
    "Verbose",false, ...
    "Plots","training-progress", ...
    "Shuffle","every-epoch", ...
    "CheckpointPath",pwd);

% Resume training
net2 = trainNetwork(XTrain,YTrain,net.Layers,options);

(Figure: training progress of Script 2, resumed from the checkpoint network)
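Hard-coding the checkpoint file name, as in Script 2, is brittle. A small sketch (assuming the checkpoints live in the current folder) that loads the most recent checkpoint instead, by exploiting the common prefix of the file names:

```matlab
% Sketch: load the most recent checkpoint in the current folder
% instead of hard-coding its file name.
files = dir("net_checkpoint__*.mat");
[~,idx] = max([files.datenum]);  % index of the most recently saved file
load(files(idx).name)            % loads the variable `net`
```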

Of course, the options passed to trainingOptions in the two training runs (Script 1 and Script 2) do not need to be the same (here, for example, "InitialLearnRate" and "MaxEpochs" differ).

In addition, note a small detail: the iteration number in the checkpoint file names is re-counted from zero in the second training run:

(Figure: checkpoint files from the second training run, with iteration numbers re-counted)


References