Learning Rate Settings while Using MATLAB trainNetwork
Function
在使用MATLAB Deep Learning Toolbox所提供的trainNetwork
函数(trainNetwork - MathWorks)训练神经网络时,可以使用trainingOptions
函数(trainingOptions - MathWorks)设置神经网络的学习率,涉及到的trainingOptions
属性一共有四个,分别是InitialLearnRate
,LearnRateSchedule
,LearnRateDropPeriod
和LearnRateDropFactor
.
(1)InitialLearnRate
:设置初始学习率,sgdm
求解器的默认值为0.01,rmsprop
和adam
求解器的默认值为0.001.
(2)LearnRateSchedule
:设置学习率的dropping策略,一共由两个选项,none
代表整个训练过程中的学习率是个常数(即InitialLearnRate
设置的值),piecewise
代表在经过一定的训练Epochs(而不是Iterations)后,通过乘系数的方式更新学习率。
(3)LearnRateDropPeriod
:当'LearnRateSchedule'='piecewise'
时,设置降低学习率的Epochs.
(4)LearnRateDropFactor
:当'LearnRateSchedule'='piecewise'
时,设置学习率的衰减因子,从0到1取值。
由此可见,在整个训练的过程中,学习率要么保持不变,要么衰减,不能设置学习率增加的训练过程(增加学习率不太符合常理,但并非是不能这么做)。
之后找了一个简单的官方示例,修改了trainingOptions
的这四个属性,做了简单的验证:
matlab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
% MATLAB official example: Train Network for Sequence Classification
clc,clear,close all
[XTrain,YTrain] = japaneseVowelsTrainData;
[XTest,YTest] = japaneseVowelsTestData;
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;
layers = [ ...
sequenceInputLayer(inputSize)
lstmLayer(numHiddenUnits,'OutputMode','last')
fullyConnectedLayer(numClasses)
softmaxLayer
classificationLayer];
maxEpochs = 70;
miniBatchSize = 27;
options = trainingOptions('sgdm', ...
'ExecutionEnvironment','cpu', ...
'MaxEpochs',maxEpochs, ...
'MiniBatchSize',miniBatchSize, ...
'ValidationData', {XTest,YTest},...
'ValidationFrequency', 10,...
'GradientThreshold',1, ...
'Verbose',false, ...
'Plots','training-progress',...
'LearnRateSchedule','piecewise',...
'LearnRateDropPeriod',30,...
'LearnRateDropFactor',0.3);
[net,info] = trainNetwork(XTrain,YTrain,layers,options);
在这样的miniBatchSize
和maxEpochs
设置下,一共会训练70个Epochs,每个Epoch会训练10个Iterations:
因此,'LearnRateDropPeriod',30
和'LearnRateDropFactor',0.3
就意味着每300个Iterations(30 Epochs)学习率衰减0.3.
最后,将训练过程的学习率绘制出来:
matlab
1
plot(info.BaseLearnRate)
符合我们的预期。
最后,找了篇论文 [1] 看了一下作者所采用的学习率设置:
是和MATLAB的学习率设置逻辑是一致的。
References
[1] Zhao, Minghang, et al. “Deep residual shrinkage networks for fault diagnosis.” IEEE Transactions on Industrial Informatics 16.7 (2019): 4681-4690.