Train Conditional Generative Adversarial Network (CGAN) in MATLAB
Introduction
Blog post [1] analyzed a GAN example in MATLAB. GANs, however, come in many variants, and MATLAB also provides an example of a CGAN (Conditional Generative Adversarial Network), "Train Conditional Generative Adversarial Network (CGAN)" [2]. The former can be regarded as an unsupervised model, while the latter is a supervised model trained on a labeled dataset. The two examples use the same training set and their code is very similar, so this post focuses on the differences between the two implementations and skips the parts they share.
Differences Between GAN and CGAN
Difference 1: Load Data
For the CGAN, when loading the image data with imageDatastore, we need to add the LabelSource name-value argument:
imageFolder = fullfile(pwd,"flower_photos");
datasetFolder = fullfile(imageFolder);
imds = imageDatastore(datasetFolder,IncludeSubfolders=true,LabelSource="foldernames");
This uses the names of the subfolders under flower_photos as the image labels, so the imds variable now carries an additional Labels property. We can inspect the classes:
classes = categories(imds.Labels);
numClasses = numel(classes);
classes =

  5×1 cell array

    {'daisy'     }
    {'dandelion' }
    {'roses'     }
    {'sunflowers'}
    {'tulips'    }

numClasses =

     5
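As an optional sanity check (this snippet is not part of the original example), imageDatastore also lets you tabulate how many images each label contains:

% Count the number of images per class label (built-in datastore method)
tbl = countEachLabel(imds)   % returns a table with Label and Count columns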
Difference 2: Define the Structures of the Generator and Discriminator
The figures above contrast the GAN and CGAN architectures: in the CGAN, the image labels are additionally fed into both the generator and the discriminator. Accordingly, the structures of the CGAN generator and discriminator must be modified, as shown in the second pair of figures. The two networks are described in turn below.
The Structure of the Generator
The CGAN generator network structure is shown in the figure below; the difference from the GAN lies in the first part of the network. The structure is defined as:
numLatentInputs = 100;
embeddingDimension = 50;
numFilters = 64;
filterSize = 5;
projectionSize = [4 4 1024];

layersGenerator = [
    featureInputLayer(numLatentInputs)                                      % Output size, [100, 1]
    fullyConnectedLayer(prod(projectionSize))                               % Output size, [16384, 1]
    functionLayer(@(X) feature2image(X,projectionSize),Formattable=true)    % Output size, [4, 4, 1024, 1]
    concatenationLayer(3,2,Name="cat");                                     % Output size, [4, 4, 1025, 1]
    transposedConv2dLayer(filterSize,4*numFilters)                          % Output size, [8, 8, 256, 1]
    batchNormalizationLayer                                                 % Output size, [8, 8, 256, 1]
    reluLayer                                                               % Output size, [8, 8, 256, 1]
    transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same") % Output size, [16, 16, 128, 1]
    batchNormalizationLayer                                                 % Output size, [16, 16, 128, 1]
    reluLayer                                                               % Output size, [16, 16, 128, 1]
    transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")   % Output size, [32, 32, 64, 1]
    batchNormalizationLayer                                                 % Output size, [32, 32, 64, 1]
    reluLayer                                                               % Output size, [32, 32, 64, 1]
    transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")            % Output size, [64, 64, 3, 1]
    tanhLayer                                                               % Output size, [64, 64, 3, 1]
    ];

lgraphGenerator = layerGraph(layersGenerator);

layers = [
    featureInputLayer(1)                                                    % Output size, [1, 1]
    embeddingLayer(embeddingDimension,numClasses)                           % Output size, [50, 1]
    fullyConnectedLayer(prod(projectionSize(1:2)))                          % Output size, [16, 1]
    functionLayer(@(X) feature2image(X,[projectionSize(1:2) 1]), ...
        Formattable=true,Name="emb_reshape")                                % Output size, [4, 4, 1, 1]
    ];

lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,"emb_reshape","cat/in2");

% Convert the layer graph to a dlnetwork object
netG = dlnetwork(lgraphGenerator)
plot(lgraphGenerator)
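Before moving on, a quick smoke test of the two-input generator can confirm the wiring. This snippet is not in the original example, and the batch size of 8 is arbitrary:

% Feed random latent codes and random class labels through the generator
Z = dlarray(randn(numLatentInputs,8,"single"),"CB");  % latent codes, batch of 8
T = dlarray(single(randi(numClasses,1,8)),"CB");      % class indices in 1..numClasses
XTest = predict(netG,Z,T);
size(XTest)   % expected: 64 64 3 8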
Among these, the less familiar pieces are the feature2image function layer, the concatenationLayer layer, and the embeddingLayer layer.
feature2image Custom Layer
The feature2image layer is a user-defined function layer:
function Y = feature2image(X,outputSize)
% Y = feature2image(X,outputSize) reshapes input with format "CB" to have
% format "SSCB" with the size given by outputSize.
Y = reshape(X, outputSize(1), outputSize(2), outputSize(3), []);
Y = dlarray(Y,'SSCB');
end
This function reshapes "CB"-formatted (channel, batch) data into "SSCB"-formatted (spatial, spatial, channel, batch) data, i.e., it turns a feature vector into image-shaped feature maps.
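A minimal check of what it does, with sizes matching the generator above (this snippet is illustrative, not from the original example):

X = dlarray(randn(prod([4 4 1024]),8,"single"),"CB");  % 16384 features, batch of 8
Y = feature2image(X,[4 4 1024]);
size(Y)   % 4 4 1024 8
dims(Y)   % 'SSCB'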
concatenationLayer Built-in Layer
concatenationLayer is a built-in MATLAB layer [3]; its job is to join two network branches:
layersGenerator = [
    ...
    concatenationLayer(3,2,Name="cat");
    ...];
Here:
- 3 is the dim argument, meaning the inputs are concatenated along dimension 3 (the channel dimension);
- 2 is the numInputs argument, meaning the layer exposes two input ports.
It is used together with the addLayers and connectLayers functions; the destination "cat/in2" denotes the second input port of the layer named "cat":
lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,"emb_reshape","cat/in2");
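To see the mechanism in isolation, here is a minimal two-branch concatenation; the layer names in1 and in2 are made up for this sketch:

% Two image inputs whose channels are joined along dimension 3
lg = layerGraph([
    imageInputLayer([4 4 2],Normalization="none",Name="in1")
    concatenationLayer(3,2,Name="cat")]);
lg = addLayers(lg,imageInputLayer([4 4 1],Normalization="none",Name="in2"));
lg = connectLayers(lg,"in2","cat/in2");   % wire the second branch into port 2
net = dlnetwork(lg);                      % "cat" outputs [4, 4, 3]: 2 + 1 channels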
embeddingLayer Custom Layer
The embeddingLayer layer is also user-defined:
classdef embeddingLayer < nnet.layer.Layer & ...
        nnet.layer.Formattable

    properties (Learnable)
        % Layer learnable parameters.
        Weights
    end

    methods
        function layer = embeddingLayer(embeddingDimension, inputDimension, NameValueArgs)
            % layer = embeddingLayer(embeddingDimension,inputDimension)
            % creates an embeddingLayer object that embeds the numeric
            % index input using an embedding of the specified dimension
            % and input dimension.
            %
            % layer = embeddingLayer(embeddingDimension,inputDimension,Name=name)
            % also specifies the layer name.

            % Parse input arguments.
            arguments
                embeddingDimension
                inputDimension
                NameValueArgs.Name = "";
            end
            name = NameValueArgs.Name;

            % Set layer name.
            layer.Name = name;

            % Set layer description.
            layer.Description = "Embedding layer with dimension " + embeddingDimension;

            % Initialize embedding weights (initializeGaussian is a
            % supporting function provided with the example).
            sz = [embeddingDimension inputDimension];
            mu = 0;
            sigma = 0.01;
            layer.Weights = initializeGaussian(sz,mu,sigma);
        end

        function Z = predict(layer, X)
            % Forward input data through the layer at prediction time and
            % output the result.
            %
            % Inputs:
            %     layer - Layer to forward propagate through
            %     X     - Numeric indices, specified as a formatted
            %             dlarray with a "C" and optionally a "B" dimension
            % Outputs:
            %     Z     - Output of the layer forward function, returned
            %             as a dlarray with format "CB"

            % Embedding.
            weights = layer.Weights;
            Z = embed(X,weights);
        end
    end
end
At its core the layer is a lookup: embed maps each class index to the corresponding column of the learnable Weights matrix, turning a scalar label into a dense feature vector of length embeddingDimension.
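A back-of-the-envelope sketch of that lookup, without the dlarray plumbing (W, idx, and the sizes here are illustrative):

W = randn(50,5);   % embeddingDimension-by-numClasses weight matrix
idx = [3 1 5];     % a mini-batch of three class labels
Z = W(:,idx);      % each label becomes a 50-dimensional column vector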
The Structure of the Discriminator
Likewise, the difference between the CGAN and GAN discriminators lies in the first part of the network. Its structure is defined as:
dropoutProb = 0.75;
numFilters = 64;
scale = 0.2;
inputSize = [64 64 3];
filterSize = 5;

layersDiscriminator = [
    imageInputLayer(inputSize,Normalization="none")  % Output size, [64, 64, 3, 1]
    dropoutLayer(dropoutProb)                        % Output size, [64, 64, 3, 1]
    concatenationLayer(3,2,Name="cat")               % Output size, [64, 64, 4, 1]
    convolution2dLayer(filterSize,numFilters, ...
        Stride=2,Padding="same")                     % Output size, [32, 32, 64, 1]
    leakyReluLayer(scale)                            % Output size, [32, 32, 64, 1]
    convolution2dLayer(filterSize,2*numFilters, ...
        Stride=2,Padding="same")                     % Output size, [16, 16, 128, 1]
    batchNormalizationLayer                          % Output size, [16, 16, 128, 1]
    leakyReluLayer(scale)                            % Output size, [16, 16, 128, 1]
    convolution2dLayer(filterSize,4*numFilters, ...
        Stride=2,Padding="same")                     % Output size, [8, 8, 256, 1]
    batchNormalizationLayer                          % Output size, [8, 8, 256, 1]
    leakyReluLayer(scale)                            % Output size, [8, 8, 256, 1]
    convolution2dLayer(filterSize,8*numFilters, ...
        Stride=2,Padding="same")                     % Output size, [4, 4, 512, 1]
    batchNormalizationLayer                          % Output size, [4, 4, 512, 1]
    leakyReluLayer(scale)                            % Output size, [4, 4, 512, 1]
    convolution2dLayer(4,1)                          % Output size, [1, 1, 1, 1]
    ];

lgraphDiscriminator = layerGraph(layersDiscriminator);

layers = [
    featureInputLayer(1)                             % Output size, [1, 1]
    embeddingLayer(embeddingDimension,numClasses)    % Output size, [50, 1]
    fullyConnectedLayer(prod(inputSize(1:2)))        % Output size, [4096, 1]
    functionLayer(@(X) feature2image(X,[inputSize(1:2) 1]), ...
        Formattable=true,Name="emb_reshape")         % Output size, [64, 64, 1, 1]
    ];

lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,"emb_reshape","cat/in2");

% Convert the layer graph to a dlnetwork object
netD = dlnetwork(lgraphDiscriminator)
The definition parallels that of the generator.
plot(lgraphDiscriminator)
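As with the generator, a quick smoke test can verify the two-input wiring (not part of the original example; the batch size is arbitrary):

X = dlarray(randn(64,64,3,8,"single"),"SSCB");    % a batch of 8 fake images
T = dlarray(single(randi(numClasses,1,8)),"CB");  % matching class labels
Y = predict(netD,X,T);                            % raw logits, size [1, 1, 1, 8]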
Difference 3: Calculate Losses and Scores
The CGAN also differs slightly in how the losses and scores are computed:
function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ...
modelLoss(netG,netD,X,T,Z,flipFactor)
% Calculate the predictions for real data with the discriminator network.
YReal = forward(netD,X,T);
% Calculate the predictions for generated data with the discriminator network.
[XGenerated,stateG] = forward(netG,Z,T);
YGenerated = forward(netD,XGenerated,T);
% Calculate probabilities.
probGenerated = sigmoid(YGenerated);
probReal = sigmoid(YReal);
% Calculate the generator and discriminator scores.
scoreG = mean(probGenerated);
scoreD = (mean(probReal) + mean(1-probGenerated)) / 2;
% Flip labels.
numObservations = size(YReal,4);
idx = randperm(numObservations,floor(flipFactor * numObservations));
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);
% Calculate the GAN loss.
[lossG, lossD] = ganLoss(probReal,probGenerated);
% For each network, calculate the gradients with respect to the loss.
gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true);
gradientsD = dlgradient(lossD,netD.Learnables);
end
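One detail worth pausing on is the label-flip block: a fraction flipFactor of the real-image probabilities is inverted before the loss is computed, a common trick that keeps the discriminator from learning too quickly. In isolation it looks like this (the values and sizes are illustrative):

flipFactor = 0.3;
probReal = rand(1,1,1,128,"single");            % stand-in for sigmoid(YReal)
idx = randperm(128,floor(flipFactor*128));      % pick floor(0.3*128) = 38 observations
probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx);  % flip their labels

The ganLoss helper that modelLoss calls is defined as follows: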
function [lossG, lossD] = ganLoss(probReal,probGenerated)
% Combine the losses for the discriminator network.
lossD = -mean(log(probReal))-mean(log(1 - probGenerated));
% Calculate the loss for the generator network.
lossG = -mean(log(probGenerated));
end
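To get a feel for the numbers, here is a small hand-made check of ganLoss (the probability values are invented):

% Suppose D is fairly confident on 3 real and 3 generated samples:
probReal      = reshape([0.9 0.8 0.7],1,1,1,[]);
probGenerated = reshape([0.2 0.3 0.1],1,1,1,[]);
[lossG,lossD] = ganLoss(probReal,probGenerated)
% lossD ~ 0.46 (low: D separates real from fake well)
% lossG ~ 1.71 (high: G is still easy to reject)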
The figure below sketches the computation paths for the losses and scores of the CGAN generator and discriminator networks. Compared with the GAN, there are two main differences:
- the labels T are fed into both G and D;
- in this CGAN example, the discriminator output first passes through the sigmoid function before the losses and scores are computed, which makes little practical difference.
Difference 4: Fix Validation Data
When setting up the validation data used to monitor training progress, we must hold fixed not only the random vectors but also the labels:
numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,"single");
TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));
% Convert the data to dlarray objects and specify the dimension labels "CB" (channel, batch).
ZValidation = dlarray(ZValidation,"CB");
TValidation = dlarray(TValidation,"CB");
% For GPU training, convert the data to gpuArray objects
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZValidation = gpuArray(ZValidation);
    TValidation = gpuArray(TValidation);
end
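A note on how the labels are laid out: repmat tiles the vector 1:numClasses, so the 25 validation columns cycle through all five classes and the montage shows numValidationImagesPerClass images per class. A tiny illustration:

repmat(1:3,[1 2])   % returns [1 2 3 1 2 3]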
Difference 5: Generate New Images
Similarly, when testing what the generator produces, we must define labels in addition to the random vectors:
% Create an array of 36 vectors of random values corresponding to the first class
numObservationsNew = 36;
idxClass = 1;
ZNew = randn(numLatentInputs,numObservationsNew,"single");
TNew = repmat(single(idxClass),[1 numObservationsNew]);

% Convert the data to dlarray objects with the dimension labels "CB" (channel, batch)
ZNew = dlarray(ZNew,"CB");
TNew = dlarray(TNew,"CB");

% To generate images using the GPU, also convert the data to gpuArray objects
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZNew = gpuArray(ZNew);
    TNew = gpuArray(TNew);
end

% Generate images using the predict function with the generator network
XGeneratedNew = predict(netG,ZNew,TNew);
Results of the CGAN
Training for 500 epochs on an NVIDIA GeForce RTX 3060 Ti GPU took about 1 h 45 min, yielding the final result. Checking the trained generator shows that it produces reasonable daisy images. Finally, save the models:
% Save models
save('Generator.mat', 'netG')
save('Discriminator.mat', 'netD')
And test them:
clc, clear, close all

load Generator.mat
load Discriminator.mat

% Create an array of 36 vectors of random values corresponding to the first class
numObservationsNew = 36;
numLatentInputs = 100;
idxClass = 1;
executionEnvironment = "auto";
ZNew = randn(numLatentInputs,numObservationsNew,"single");
TNew = repmat(single(idxClass),[1 numObservationsNew]);

% Convert the data to dlarray objects with the dimension labels "CB" (channel, batch)
ZNew = dlarray(ZNew,"CB");
TNew = dlarray(TNew,"CB");

% To generate images using the GPU, also convert the data to gpuArray objects
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    ZNew = gpuArray(ZNew);
    TNew = gpuArray(TNew);
end

% Generate images using the predict function with the generator network
XGeneratedNew = predict(netG,ZNew,TNew);

% Display the generated images in a plot
figure
I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
imshow(I)
title("Class: daisy")
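To generate a different flower class, map the class name back to its index; this sketch assumes the same classes ordering produced by categories(imds.Labels) during training:

classes = {'daisy';'dandelion';'roses';'sunflowers';'tulips'};  % from categories(imds.Labels)
idxClass = find(strcmp(classes,"sunflowers"));                  % -> 4
TNew = dlarray(repmat(single(idxClass),[1 numObservationsNew]),"CB");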