Compare Deep Learning Networks Based on ROC Curves in MATLAB

Jan. 06, 2023

Introduction

博客1介绍了ROC曲线以及如何在MATLAB中绘制和计算ROC曲线。本博客介绍MATLAB提供的另一个使用ROC曲线对比模型的示例,”Compare Deep Learning Model Using ROC Curves”2. 该示例对比的两个模型均为深度学习模型,虽然更加复杂,但是基本思想并没有区别;另一方面,该示例提供了一些基于模型ROC曲线对比的角度,值得学习借鉴。


Load data set

本示例所采用的数据集是”Flowers dataset”3,该数据集有3670张图片,一共有五种分类,分别是:daisy,dandelion,roses,sunflowers和tulips。创建image datastore,并使用splitEachLabel函数按照6:2:2的比例将数据集划分为训练集、验证集和测试集:

1
2
3
imds = imageDatastore("flower_photos", IncludeSubfolders=true, LabelSource="foldernames");

[imdsTrain, imdsValidation, imdsTest] = splitEachLabel(imds, 0.6, 0.2, 0.2, "randomize");

设置数据的类别数为5:

1
numClasses = 5;


Specify network structures

创建两个图像分类模型。对于第一个模型,我们从头开始构建和训练一个深度神经网络;第二个模型则采用transfer learning的技术基于pretrained GoogLeNet Network针对新数据进行再训练。两个神经网络输入层的大小均为:

1
inputSize = [224 224 3];

Network 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
numFilters = 16;
filterSize = 3;
poolSize = 2;

layers = [
    imageInputLayer(inputSize)
    
    convolution2dLayer(filterSize,numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(filterSize,Stride=2)

    convolution2dLayer(filterSize,2*numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize,Stride=2)

    convolution2dLayer(filterSize,4*numFilters,Padding="same")
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(poolSize,Stride=2)
   
    dropoutLayer(0.8)

    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

lgraphSmallNet = layerGraph(layers);

Network 2

首先加载GoogLeNet:

1
lgraphGoogLeNet = layerGraph(googlenet);

查看一下这个神经网络的结构:

1
plot(lgraphGoogLeNet)

image-20230104112343706

为了能够将pretrained network用于transfer learning,我们首先需要修改预训练的神经网络结构使其匹配新的数据集的特点。

  • Replace the last learnable layer with a new layer that is adapted to the new data. For GoogLeNet, this layer is the final fully connected layer, loss3-classifier. Set the output size in the new layer to match the number of classes in the new data.
  • Increase the learning in the new layer by increasing the weight and bias learn rate factors. This increase ensures that learning is faster in the new layer than in the transferred layers.
  • Replace the output layer, output, with a new output layer that is adapted to the new data.

image-20230104112833124

1
2
3
4
5
6
7
newLearnableLayer = fullyConnectedLayer(numClasses, ...
    WeightLearnRateFactor=10, ...
    BiasLearnRateFactor=10);
lgraphGoogLeNet = replaceLayer(lgraphGoogLeNet,"loss3-classifier",newLearnableLayer);

newOutputLayer = classificationLayer("Name","ClassificationLayer_predictions");
lgraphGoogLeNet = replaceLayer(lgraphGoogLeNet,"output",newOutputLayer);

我们可以对比一下修改前后神经网络的变化:

1
2
3
4
5
6
7
8
9
lgraphGoogLeNet = layerGraph(googlenet);
% plot(lgraphGoogLeNet)
NetBefore = lgraphGoogLeNet;
...
replaceLayer(lgraphGoogLeNet,"output",newOutputLayer);
NetAfter = lgraphGoogLeNet;
...
analyzeNetwork(NetBefore)
analyzeNetwork(NetAfter)

image-20230104114014040

image-20230104114033174

Compare networks

1
2
analyzeNetwork(lgraphGoogLeNet)
analyzeNetwork(lgraphSmallNet)

Network 1:

image-20230104114517225

Network 2:

image-20230104114543380

第一个小的神经网络一共有17个layers和近300,000个参数,大的GoogLeNet Network有144个layers和大约6,000,000个参数。但是由于我们采用了transfer learning的技术,后者的GoogLeNet Network训练时长并不会很长,因为该pretrained network已经学习到了一定的特征提取能力,我们可以将其作为一个starting point for new data。


Prepare data set

对原始图像数据集使用一些图像增强技术,获得一个augmented image datastore:

注:Data augmentation helps prevent the network from overfitting and memorizing the exact details of the training images.

1
2
3
% For train
augmenter = imageDataAugmenter(RandXReflection=true, RandScale=[0.5 1.5]);
augimdsTrain = augmentedImageDatastore(inputSize, imdsTrain, DataAugmentation=augmenter);

而对于validation images,则仅仅resize而不进行任何的data augmentation,不指定任何的preprocessing operations:

1
2
% For validation
augimdsValidation = augmentedImageDatastore(inputSize,imdsValidation);

因为验证集的作用是监测神经网络模型的训练过程。

Set training options

对于神经网络1:

1
2
3
4
5
6
7
8
% Set training options for Network 1
optsSmallNet = trainingOptions("sgdm", ...
    MaxEpochs=150, ...
    InitialLearnRate=0.002, ...
    ValidationData=augimdsValidation, ...
    ValidationFrequency=150, ...
    Verbose=false, ...
    Plots="training-progress");

对于神经网络2:

1
2
3
4
% Set training options for Network 2
optsGoogLeNet = optsSmallNet;
optsGoogLeNet.MaxEpochs = 15;
optsGoogLeNet.InitialLearnRate = 0.0001;

对于pretrained network,我们不需要训练太多轮,因此我们将最大轮数设置为15。在之前,我们增加了new learnable layer的learning rate。为了放慢pretrained network的earlier layers的训练过程,我们选择一个小的initial learning rate,设置为0.0001。


Train networks

1
2
% For Network 1 
netSmallNet = trainNetwork(augimdsTrain, lgraphSmallNet, optsSmallNet);

image-20230106200837717

1
2
% For Network 2
netGoogLeNet = trainNetwork(augimdsTrain, lgraphGoogLeNet, optsGoogLeNet);

image-20230106200825645

从模型训练的训练过程中我们可以看到,采用了pretrained的模型虽然具有更多参数,但是反而能在较短的时间内训练出更好的效果。


Test networks

本部分基于测试集数据从不同角度测试两个训练好的神经网络模型。

Compare network accuracy

首先,对比预测准确率。

同样地,我们需要首先对test data进行resize:

1
2
% Prepare the test data
augimdsTest = augmentedImageDatastore(inputSize, imdsTest);

测试两个networks的预测准确率:

1
2
3
4
5
6
7
8
% Classify the test data
[YTestSmallNet, scoresSmallNet] = classify(netSmallNet, augimdsTest);
[YTestGoogLeNet, scoresGoogLeNet] = classify(netGoogLeNet, augimdsTest);

% Compare the accuracy
TTest = imdsTest.Labels;
accSmallNet = sum(TTest == YTestSmallNet)/numel(TTest);
accGoogLeNet = sum(TTest == YTestGoogLeNet)/numel(TTest);
1
2
3
4
5
>> accSmallNet, accGoogLeNet
accSmallNet =
    0.7401
accGoogLeNet =
    0.8898

这一部分代码另一个主要作用就是为后面rocmetrics函数提供Scores输入:scoresSmallNetscoresGoogLeNet

Confusion matrix

绘制混淆矩阵:

1
2
3
4
5
6
7
8
9
% Plot confunsion matrix
figure 
tiledlayout(1,2)
nexttile
confusionchart(TTest, YTestSmallNet)
title("SmallNet")
nexttile
confusionchart(TTest, YTestGoogLeNet)
title("GoogLeNet")

image-20230106202354561

ROC curves

Compare ROC curves

使用rocmetrics函数计算并绘制ROC曲线:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
% Compare ROC Curves
% Create rocmetrics objects
classNames = netSmallNet.Layers(end).Classes;
rocSmallNet = rocmetrics(TTest, scoresSmallNet, classNames);
rocGoogLeNet = rocmetrics(TTest, scoresGoogLeNet, classNames);

% Plot ROC Curves
figure
tiledlayout(1,2)
nexttile
plot(rocSmallNet)
title("ROC Curve: SmallNet")
nexttile
plot(rocGoogLeNet)
title("ROC Curve: GoogLeNet")

image-20230106202641040

Compare AUC values

1
2
3
% Access the AUC value for each class
aucSmallNet = rocSmallNet.AUC;
aucGoogLeNet = rocGoogLeNet.AUC;
1
2
3
4
5
6
7
8
>> aucSmallNet, aucGoogLeNet
aucSmallNet =
  1×5 single row vector
    0.9518    0.9278    0.8823    0.9632    0.9064

aucGoogLeNet =
  1×5 single row vector
    0.9830    0.9919    0.9771    0.9900    0.9776

并以条形图的方式可视化:

1
2
3
4
5
figure
bar([aucSmallNet; aucGoogLeNet]')
xticklabels(classNames)
legend(["SmallNet","GoogLeNet"],Location="southeast")
title("AUC")

image-20230106202814337

Investigate the specific class

本部分针对类别”sunflowers”对比两个模型的ROC曲线:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
% ROC curves of 'sunflowers' class
classToInvestigate = "sunflowers";

figure
c = cell(2,1);
g = cell(2,1);
[c{1},g{1}] = plot(rocSmallNet,ClassNames=classToInvestigate);
hold on
[c{2},g{2}] = plot(rocGoogLeNet,ClassNames=classToInvestigate);
modelNames = ["SmallNet","GoogLeNet"];
for i = 1:2
    c{i}.DisplayName = replace(c{i}.DisplayName, ...
        classToInvestigate,modelNames(i));
    g{i}(1).DisplayName = join([modelNames(i),"Model Operating Point"]);
end
title("ROC Curve","Class: " + classToInvestigate)
hold off

image-20230106202849289

Compare average ROC curves

对比两个神经网络模型的平均ROC曲线:

1
2
3
4
5
6
7
8
9
10
11
12
13
% Compare average ROC curves
figure
[FPR1, TPR1, Thresholds1, AUC1] = average(rocSmallNet, "macro");
[FPR2, TPR2, Thresholds2, AUC2] = average(rocGoogLeNet, "macro");
hold(gca, "on")
box(gca, "on")
plot([0; FPR1], [0; TPR1],...
    DisplayName=sprintf("SmallNet (AUC=%.4f)", AUC1))
plot([0; FPR2], [0; TPR2],...
    DisplayName=sprintf("GoogLelNet (AUC=%.4f)", AUC2))
legend("Location", "southeast")
xlabel("FPR")
ylabel("TPR")

image-20230106204356697

Small sumarry

上述的这些测试效果都表明:采用了GoogLeNet神经网络结构和Transfer Learning的神经网络具有更优秀的性能。


References