Train Conditional Generative Adversarial Network (CGAN) in MATLAB

Oct. 09, 2022

Introduction

A previous post[1] analyzed one of MATLAB's GAN examples, but GANs in fact come in many variants, and MATLAB also provides a CGAN (Conditional Generative Adversarial Network) example, "Train Conditional Generative Adversarial Network (CGAN)"[2]. The former can be viewed as an unsupervised learning model, while the latter is a supervised model trained on a labeled dataset. The two examples use the same training set and very similar code, so this post focuses on the differences between the two implementations and skips the parts they have in common.


Differences Between GAN and CGAN

Difference 1: Load Data

For the CGAN, when loading the image data with imageDatastore, the LabelSource option must be added:

imageFolder = fullfile(pwd,"flower_photos");
datasetFolder = fullfile(imageFolder);
imds = imageDatastore(datasetFolder,IncludeSubfolders=true,LabelSource="foldernames");

This tells imageDatastore to use the names of the subfolders under flower_photos as the image labels. As a result, the imds variable gains a Labels property:

(Figure: the imds variable in the workspace, showing the new Labels property)
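
As a quick check (not part of the original example), countEachLabel tabulates how many images each label contributes:

countEachLabel(imds)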

We can inspect the classes:

classes = categories(imds.Labels);
numClasses = numel(classes);
classes =
  5×1 cell array
    {'daisy'     }
    {'dandelion' }
    {'roses'     }
    {'sunflowers'}
    {'tulips'    }

numClasses =
     5

Difference 2: Define the Structures of the Generator and Discriminator

(Figure: (a) the structure of GAN; (b) the structure of CGAN)

The figures above show the structural difference between GAN and CGAN: in the CGAN, the image labels are fed into both the Generator and the Discriminator. The CGAN's Generator and Discriminator structures therefore need to be changed as follows:

(Figure: (a) the structure of the Generator of CGAN; (b) the structure of the Discriminator of CGAN)

Each is described below.

The structure of the Generator

The CGAN Generator network is shown below:

(Figure: the CGAN Generator network)

The difference from the GAN lies in the first part of the network:

(Figure: the input section of the CGAN Generator)

Its structure is defined as follows:

numLatentInputs = 100;
embeddingDimension = 50;
numFilters = 64;

filterSize = 5;
projectionSize = [4 4 1024];

layersGenerator = [
    featureInputLayer(numLatentInputs)               % Output size, [100, 1]
    fullyConnectedLayer(prod(projectionSize))        % Output size, [16384, 1]
    functionLayer(@(X) feature2image(X,projectionSize),Formattable=true) % Output size, [4, 4, 1024, 1]
    concatenationLayer(3,2,Name="cat");              % Output size, [4, 4, 1025, 1]

    transposedConv2dLayer(filterSize,4*numFilters)   % Output size, [8, 8, 256, 1]
    batchNormalizationLayer                          % Output size, [8, 8, 256, 1]
    reluLayer                                        % Output size, [8, 8, 256, 1]
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same") % Output size, [16, 16, 128, 1]
    batchNormalizationLayer                          % Output size, [16, 16, 128, 1]
    reluLayer                                        % Output size, [16, 16, 128, 1]
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same") % Output size, [32, 32, 64, 1]
    batchNormalizationLayer                          % Output size, [32, 32, 64, 1]
    reluLayer                                        % Output size, [32, 32, 64, 1]
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")% Output size, [64, 64, 3, 1]
    tanhLayer                                        % Output size, [64, 64, 3, 1]
    ];

lgraphGenerator = layerGraph(layersGenerator);

layers = [
    featureInputLayer(1)                           % Output size, [1, 1]
    embeddingLayer(embeddingDimension,numClasses)  % Output size, [50, 1]
    fullyConnectedLayer(prod(projectionSize(1:2))) % Output size, [16, 1]
    functionLayer(@(X) feature2image(X,[projectionSize(1:2) 1]), ...
    Formattable=true,Name="emb_reshape")           % Output size, [4, 4, 1, 1]
    ];

lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,"emb_reshape","cat/in2");

% convert the layer graph to a dlnetwork object
netG = dlnetwork(lgraphGenerator)
plot(lgraphGenerator)

(Figure: plot of lgraphGenerator)

Among these, the less familiar layers are the feature2image layer, the concatenationLayer layer, and the embeddingLayer layer.

feature2image Custom Layer

The feature2image layer is a user-defined function layer:

function Y = feature2image(X,outputSize)
% Y = feature2image(X,outputSize) reshapes input with format "CB" to
% format "SSCB" with the size given by outputSize.
Y = reshape(X, outputSize(1), outputSize(2), outputSize(3), []);
Y = dlarray(Y,'SSCB');
end

This function reshapes "CB"-formatted data into "SSCB"-formatted data, i.e., into an image-shaped array (spatial, spatial, channel, batch).
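
A minimal sketch (with made-up sizes) verifying this behavior, assuming feature2image.m is on the path:

X = dlarray(randn(prod([4 4 1024]),8,"single"),"CB"); % a batch of 8 observations
Y = feature2image(X,[4 4 1024]);
size(Y) % expected: 4 4 1024 8
dims(Y) % expected: SSCB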

concatenationLayer Built-in Layer

The concatenationLayer layer is a built-in MATLAB layer[3] that joins two network branches together:

layersGenerator = [
	...
	concatenationLayer(3,2,Name="cat");
	...];

where:

  • 3 is the dim property, specifying the dimension along which the inputs are concatenated (dimension 3, the channel dimension);
  • 2 is the numInputs property, specifying the number of inputs the layer accepts;

It is used together with the addLayers and connectLayers functions:

lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,"emb_reshape","cat/in2");
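
To make the dimension bookkeeping concrete, here is a minimal standalone sketch (with made-up data) of what the layer computes: channel-wise concatenation of the projected latent input and the reshaped label embedding.

A = dlarray(randn(4,4,1024,1,"single"),"SSCB"); % projected latent input
B = dlarray(randn(4,4,1,1,"single"),"SSCB");    % reshaped label embedding
C = cat(3,A,B);                                 % concatenate along channels
size(C)                                         % expected: 4 4 1025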

embeddingLayer Custom Layer

The embeddingLayer layer is likewise a user-defined layer:

classdef embeddingLayer < nnet.layer.Layer & ...
        nnet.layer.Formattable

    properties (Learnable)
        % Layer learnable parameters.
        Weights
    end
    
    methods
        function layer = embeddingLayer(embeddingDimension, inputDimension, NameValueArgs)
            % layer = embeddingLayer(embeddingDimension,inputDimension)
            % creates an embeddingLayer object that embeds the input
            % using an embedding of the specified dimension and input
            % dimension (number of classes).
            %
            % layer = embeddingLayer(embeddingDimension,inputDimension,Name=name)
            % also specifies the layer name.
            
            % Parse input arguments.
            arguments
                embeddingDimension
                inputDimension
                NameValueArgs.Name = "";
            end
            
            name = NameValueArgs.Name;
            
            % Set layer name.
            layer.Name = name;

            % Set layer description.
            layer.Description = "Embedding layer with dimension " + embeddingDimension;
            
            % Initialize embedding weights from N(mu,sigma^2) using the
            % initializeGaussian helper function supplied with the example.
            sz = [embeddingDimension inputDimension];
            mu = 0;
            sigma = 0.01;
            layer.Weights = initializeGaussian(sz,mu,sigma);
        end
        
        function Z = predict(layer, X)
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %         layer - Layer to forward propagate through
            %         X     - Numeric indices, specified as a formatted
            %                 dlarray with a "C" and optionally a "B"
            %                 dimension.
            % Outputs:
            %         Z     - Output of layer forward function, returned as
            %                 a dlarray with format "CB".

            % Embedding.
            weights = layer.Weights;
            Z = embed(X,weights);
        end
    end
end

At its core, this layer performs an embedding lookup: the embed function maps each integer class label to the corresponding column of the learnable weight matrix Weights, turning a scalar label into an embeddingDimension-by-1 feature vector with format "CB".
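
A minimal sketch (with made-up weights) of the lookup that embed performs:

W = randn(50,5,"single");            % embeddingDimension-by-numClasses weights
T = dlarray(single([3 1 5]),"CB");   % a mini-batch of three class indices
Z = embed(T,W);                      % format "CB", size 50-by-3
isequal(extractdata(Z),W(:,[3 1 5])) % expected: true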

The structure of the Discriminator

Likewise, the difference between the CGAN and GAN Discriminators lies in the first part of the network:

(Figure: the input section of the CGAN Discriminator)

Its structure is defined as follows:

dropoutProb = 0.75;
numFilters = 64;
scale = 0.2;

inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")   % Output size, [64, 64, 3, 1]
    dropoutLayer(dropoutProb)                         % Output size, [64, 64, 3, 1]
    concatenationLayer(3,2,Name="cat")                % Output size, [64, 64, 4, 1]
    convolution2dLayer(filterSize,numFilters, ...     
    Stride=2,Padding="same")                          % Output size, [32, 32, 64, 1]
    leakyReluLayer(scale)                             % Output size, [32, 32, 64, 1]
    convolution2dLayer(filterSize,2*numFilters, ...   
    Stride=2,Padding="same")                          % Output size, [16, 16, 128, 1]
    batchNormalizationLayer                           % Output size, [16, 16, 128, 1]
    leakyReluLayer(scale)                             % Output size, [16, 16, 128, 1]
    convolution2dLayer(filterSize,4*numFilters, ...
    Stride=2,Padding="same")                          % Output size, [8, 8, 256, 1]
    batchNormalizationLayer                           % Output size, [8, 8, 256, 1]
    leakyReluLayer(scale)                             % Output size, [8, 8, 256, 1]
    convolution2dLayer(filterSize,8*numFilters, ...
    Stride=2,Padding="same")                          % Output size, [4, 4, 512, 1]
    batchNormalizationLayer                           % Output size, [4, 4, 512, 1]
    leakyReluLayer(scale)                             % Output size, [4, 4, 512, 1]
    convolution2dLayer(4,1)                           % Output size, [1, 1, 1, 1]
    ];

lgraphDiscriminator = layerGraph(layersDiscriminator);

layers = [
    featureInputLayer(1)                                   % Output size, [1, 1]
    embeddingLayer(embeddingDimension,numClasses)          % Output size, [50, 1]
    fullyConnectedLayer(prod(inputSize(1:2)))              % Output size, [4096, 1]
    functionLayer(@(X) feature2image(X,[inputSize(1:2) 1]), ...
    Formattable=true,Name="emb_reshape")                   % Output size, [64, 64, 1, 1]
    ]; 

lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,"emb_reshape","cat/in2");

% Convert the layer graph to a dlnetwork object
netD = dlnetwork(lgraphDiscriminator)

This is analogous to the Generator definition.

plot(lgraphDiscriminator)

(Figure: plot of lgraphDiscriminator)
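
As a quick sanity check (not part of the original example), a dummy forward pass through both networks confirms the input ordering and output sizes:

Z = dlarray(randn(numLatentInputs,1,"single"),"CB"); % one latent vector
T = dlarray(single(1),"CB");                         % one class index
XFake = predict(netG,Z,T);     % expected: 64-by-64-by-3-by-1, format "SSCB"
YFake = predict(netD,XFake,T); % expected: a single unbounded logit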

Difference 3: Calculate losses and scores

The CGAN also differs slightly in how the losses and scores are computed:

function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
    modelLoss(netG,netD,X,T,Z,flipFactor)

% Calculate the predictions for real data with the discriminator network.
YReal = forward(netD,X,T);

% Calculate the predictions for generated data with the discriminator network.
[XGenerated,stateG] = forward(netG,Z,T);
YGenerated = forward(netD,XGenerated,T);

% Calculate probabilities.
probGenerated = sigmoid(YGenerated);
probReal = sigmoid(YReal);

% Calculate the generator and discriminator scores.
scoreG = mean(probGenerated);
scoreD = (mean(probReal) + mean(1-probGenerated)) / 2;

% Flip labels.
numObservations = size(YReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);

% Calculate the GAN loss.
[lossG, lossD] = ganLoss(probReal,probGenerated);

% For each network, calculate the gradients with respect to the loss.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);
end
function [lossG, lossD] = ganLoss(probReal,probGenerated)
% Combine the losses for the discriminator network.
lossD = -mean(log(probReal))-mean(log(1 - probGenerated));

% Calculate the loss for the generator network.
lossG = -mean(log(probGenerated));
end
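
For reference, these two functions implement the standard non-saturating GAN objective, here conditioned on the labels $T$:

$$\mathcal{L}_D = -\mathbb{E}\big[\log D(x \mid T)\big] - \mathbb{E}\big[\log\big(1 - D(G(z \mid T))\big)\big], \qquad \mathcal{L}_G = -\mathbb{E}\big[\log D(G(z \mid T))\big]$$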

The computation paths of the losses and scores for the CGAN Generator and Discriminator are sketched below:

(Figures: loss and score computation paths of the CGAN Generator and Discriminator)

Compared with the GAN:

(Figures: loss and score computation paths of the GAN Generator and Discriminator)

Two main differences are visible:

  1. The labels T are fed into both G and D;
  2. In this CGAN example, the Discriminator's output is first passed through the sigmoid function before the losses are computed, which has little practical effect.
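
Inside the training loop, modelLoss is evaluated through dlfeval so that dlgradient can trace the computation. A minimal sketch with dummy data (the flipFactor value here is illustrative, not necessarily the example's setting):

X = dlarray(randn(64,64,3,8,"single"),"SSCB");       % dummy "real" image batch
T = dlarray(single(randi(numClasses,1,8)),"CB");     % dummy class labels
Z = dlarray(randn(numLatentInputs,8,"single"),"CB"); % latent batch
flipFactor = 0.35;                                   % illustrative value
[lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
    dlfeval(@modelLoss,netG,netD,X,T,Z,flipFactor);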

Difference 4: Fix validation data

When setting up the validation data used to monitor training progress, the labels must be held fixed in addition to the random vectors:

numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,"single");

TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));

% Convert the data to dlarray objects and specify the dimension labels "CB" (channel, batch).
ZValidation = dlarray(ZValidation,"CB");
TValidation = dlarray(TValidation,"CB");

% For GPU training, convert the data to gpuArray objects
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZValidation = gpuArray(ZValidation);
    TValidation = gpuArray(TValidation);
end
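
During training, this fixed batch can be rendered periodically to monitor progress. A minimal sketch (assuming the netG built above); with this TValidation layout, each column of the tile corresponds to one class:

XValidation = predict(netG,ZValidation,TValidation);
I = imtile(gather(extractdata(XValidation)), ...
    GridSize=[numValidationImagesPerClass numClasses]);
imshow(rescale(I))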

Difference 5: Generate New Images

Similarly, when testing the Generator's ability to produce images, labels must be defined in addition to the random vectors:

% Create an array of 36 vectors of random values corresponding to the first class
numObservationsNew = 36;
idxClass = 1;
ZNew = randn(numLatentInputs,numObservationsNew,"single");
TNew = repmat(single(idxClass),[1 numObservationsNew]);

% Convert the data to dlarray objects with the dimension labels "CB" (channel, batch)
ZNew = dlarray(ZNew,"CB");
TNew = dlarray(TNew,"CB");

% To generate images using the GPU, also convert the data to gpuArray objects
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZNew = gpuArray(ZNew);
    TNew = gpuArray(TNew);
end
% Generate images using the predict function with the generator network
XGeneratedNew = predict(netG,ZNew,TNew);


Results of CGAN

Training for 500 epochs on an NVIDIA GeForce RTX 3060 Ti GPU took 1 h 45 min; the final result:

(Figure: training progress after 500 epochs)

Checking the trained Generator's output for the daisy class:

(Figure: daisy images generated by the trained Generator)

Finally, save the models:

% Save models 
save('Generator.mat', 'netG')
save('Discriminator.mat', 'netD')

and test them:

clc, clear, close all

load Generator.mat
load Discriminator.mat

% Create an array of 36 vectors of random values corresponding to the first class
numObservationsNew = 36;
numLatentInputs = 100;
idxClass = 1;
executionEnvironment = "auto";

ZNew = randn(numLatentInputs,numObservationsNew,"single");
TNew = repmat(single(idxClass),[1 numObservationsNew]);

% Convert the data to dlarray objects with the dimension labels "CB" (channel, batch)
ZNew = dlarray(ZNew,"CB");
TNew = dlarray(TNew,"CB");

% To generate images using the GPU, also convert the data to gpuArray objects
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZNew = gpuArray(ZNew);
    TNew = gpuArray(TNew);
end
% Generate images using the predict function with the generator network
XGeneratedNew = predict(netG,ZNew,TNew);

% Display the generated images in a plot
figure
I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
imshow(I)
title("Class: daisy")

(Figure: 36 generated images of class daisy)


Reference