Learning Rate Settings while Using MATLAB trainNetwork Function

Mar. 23, 2023

When training a neural network with the trainNetwork function (trainNetwork - MathWorks) provided by the MATLAB Deep Learning Toolbox, the learning rate is set through the trainingOptions function (trainingOptions - MathWorks). Four trainingOptions properties are involved: InitialLearnRate, LearnRateSchedule, LearnRateDropPeriod, and LearnRateDropFactor.

(1) InitialLearnRate: the initial learning rate. The default value is 0.01 for the sgdm solver and 0.001 for the rmsprop and adam solvers.


(2) LearnRateSchedule: the learning-rate dropping strategy. There are two options: none means the learning rate stays constant throughout training (at the value set by InitialLearnRate), while piecewise means the learning rate is updated by multiplying it by a factor after a certain number of training epochs (not iterations).


(3) LearnRateDropPeriod: when 'LearnRateSchedule' is 'piecewise', the number of epochs between learning-rate drops.


(4) LearnRateDropFactor: when 'LearnRateSchedule' is 'piecewise', the decay factor applied to the learning rate, taking values between 0 and 1.


As these options show, during training the learning rate either stays constant or decays; there is no way to schedule an increasing learning rate (increasing the learning rate is unusual, though not impossible in principle).
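The combined effect of these four properties can be written in closed form. A minimal MATLAB sketch (the variable names are mine; the sgdm default and a 70-epoch run are assumed purely for illustration):

```matlab
% Piecewise schedule: starting from InitialLearnRate, the learning rate
% is multiplied by LearnRateDropFactor every LearnRateDropPeriod epochs.
initialLearnRate = 0.01;   % sgdm default
dropPeriod = 30;           % epochs between drops
dropFactor = 0.3;

epoch = 1:70;
learnRate = initialLearnRate * dropFactor.^floor((epoch - 1)/dropPeriod);
% epochs 1-30: 0.01, epochs 31-60: 0.003, epochs 61-70: 9e-4
```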


I then took a simple official example, modified these four trainingOptions properties, and ran a quick verification:

```matlab
% MATLAB official example: Train Network for Sequence Classification
clc,clear,close all

[XTrain,YTrain] = japaneseVowelsTrainData;
[XTest,YTest] = japaneseVowelsTestData;

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    lstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

maxEpochs = 70;
miniBatchSize = 27;

options = trainingOptions('sgdm', ...
    'ExecutionEnvironment','cpu', ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'ValidationData',{XTest,YTest}, ...
    'ValidationFrequency',10, ...
    'GradientThreshold',1, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'LearnRateSchedule','piecewise', ...
    'LearnRateDropPeriod',30, ...
    'LearnRateDropFactor',0.3);

[net,info] = trainNetwork(XTrain,YTrain,layers,options);
```

With these miniBatchSize and maxEpochs settings, training runs for a total of 70 epochs, with 10 iterations per epoch. The settings 'LearnRateDropPeriod',30 and 'LearnRateDropFactor',0.3 therefore mean that every 300 iterations (30 epochs) the learning rate is multiplied by 0.3.

Then, plot the learning rate recorded during training:

```matlab
plot(info.BaseLearnRate)
```


The resulting step-wise curve matches our expectation.
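Since info.BaseLearnRate records one learning-rate value per iteration, the plot can also be checked numerically. A quick sketch under the settings above (the 10 iterations per epoch follow from the run shown; the variable names are mine):

```matlab
% Reconstruct the expected piecewise schedule per iteration and compare
% it with the rates recorded by trainNetwork in info.BaseLearnRate.
iterPerEpoch = 10;   % 270 training sequences / miniBatchSize of 27
epochOfIter  = ceil((1:numel(info.BaseLearnRate))/iterPerEpoch);
expected     = 0.01 * 0.3.^floor((epochOfIter - 1)/30);
assert(max(abs(info.BaseLearnRate - expected)) < 1e-12)
```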


Finally, I looked at the learning-rate settings adopted in a paper [1]; the authors' schedule follows the same logic as MATLAB's.


References

[1] Zhao, Minghang, et al. “Deep residual shrinkage networks for fault diagnosis.” IEEE Transactions on Industrial Informatics 16.7 (2019): 4681-4690.