Generate Synthetic Signals using Conditional GAN (CGAN) in MATLAB
Introduction
In the GAN-based data-generation applications I have seen so far, GANs are almost always used to generate either images or vibration signals. I think these two scenarios dominate because GAN-generated data contains a lot of noise, as in the comparison between the original daisy images and the generator's output in blog post 1:


The noise is clearly visible. In those two scenarios, however, the noise matters less or can be removed afterwards: images can be cleaned with standard denoising, and for vibration signals we mainly care about frequency-domain and time-frequency features, so noise at individual time-domain samples has little impact.
Note: the Signal Feature Visualization section of the official example makes a related point: "Unlike images and audio signals, general signals have characteristics that make them difficult for human perception to tell apart. To compare real and generated signals or healthy and faulty signals, you can apply principal component analysis (PCA) to the statistical features of the real signals and then project the features of the generated signals to the same PCA subspace."
Blog posts 1, 2, and 3 cover the three official MATLAB GAN examples, all of which train GANs on the flower data set http://download.tensorflow.org/example_images/flower_photos.tgz; they introduce a plain GAN, a CGAN, and a WGAN-GP.
This post walks through another official MATLAB example 4, which uses a CGAN to generate pump signals (somewhat similar to vibration signals). Unlike the previous three examples, here the training and use of the CGAN is only one step in a complete fault-diagnosis workflow.
Workflow
Load Data
The training data used in this example are simulated pump signals generated by the Simulink model from example 5; the data archive can be downloaded from https://ssd.mathworks.com/supportfiles/SPT/data/PumpSignalGAN.zip. After extraction it contains two .mat files:
- GANModel.mat, which contains the trained CGAN, i.e. the generator and the discriminator:
  Name                  Size    Bytes       Class        Attributes
  dlnetDiscriminator    1x1     11834154    dlnetwork
  dlnetGenerator        1x1     20873629    dlnetwork
These two networks are not merely pre-trained but fully trained models; a control in the Live Script lets the user choose between loading them and retraining from scratch (because in this example the CGAN is only one step of the fault-diagnosis workflow).
- simulatedDataset.mat, which contains the training data:
  Name      Size          Bytes       Class     Attributes
  flow      1201x1575     15132600    double
  labels    1x1575        12600       double
  Here flow holds the pump-signal data: 1575 signals with 1201 samples each. labels holds the class labels:
  categories(categorical(labels))
  ans =
    2×1 cell array
      {'1'}
      {'2'}
Class 1 represents healthy signals and class 2 represents faulty signals.
Note: because these signals come from a Simulink simulation, every signal has exactly 1201 samples, but that is not the only situation you will meet. In practice, data acquired with a Pico scope or another data-acquisition system contains bad samples (+Inf, -Inf, NaN) and outliers; after these points are handled, the signals may no longer have equal lengths (varying length). You could pad them before feeding them into the network, but padding can misalign the samples across signals, so I think it is better to solve this during preprocessing, for example by using filloutliers instead of rmoutliers so that the cleaned signals keep their original length; a small sketch follows.
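As a minimal sketch of that preprocessing idea (my own toy example, not part of the official workflow), the following shows how fillmissing and filloutliers keep a 1201-sample signal at its original length, whereas rmoutliers shortens it:

x = randn(1201,1);                     % toy acquired signal
x([100 500 900]) = [Inf; -Inf; 50];    % inject two bad samples and one outlier
x(~isfinite(x)) = NaN;                 % mark +/-Inf values as missing
x = fillmissing(x,'linear');           % interpolate the missing samples
xFilled  = filloutliers(x,'linear');   % replace outliers: still 1201 samples
xRemoved = rmoutliers(x);              % remove outliers: the signal gets shorter

Whether linear interpolation is an appropriate fill method depends on the signal; the point is only that the cleaned signals stay equal-length and aligned.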
Define Network Structures
The CGAN in this example is very similar to the CGAN described in blog post 1:


Also, as in the CGAN of blog post 1, even though the data here are one-dimensional signals, the generator and discriminator are still defined with imageInputLayer, transposedConv2dLayer, and convolution2dLayer; the difference is that these layers are all used in a single-channel configuration, with each signal treated as a tall, one-column image.
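As a rough sketch of this idea (my own simplification, not the example's actual architecture, and omitting the label-conditioning input that a CGAN discriminator also needs), a 1201-sample signal can be treated as a 1201-by-1 single-channel image and convolved with [N 1] filters:

% Minimal discriminator-style sketch for 1-D signals treated as images.
signalLength = 1201;
layersD = [
    imageInputLayer([signalLength 1 1],'Normalization','none')
    convolution2dLayer([17 1],8,'Stride',[4 1],'Padding','same')
    leakyReluLayer(0.2)
    convolution2dLayer([17 1],16,'Stride',[4 1],'Padding','same')
    leakyReluLayer(0.2)
    convolution2dLayer([17 1],32,'Stride',[4 1],'Padding','same')
    leakyReluLayer(0.2)
    fullyConnectedLayer(1)];
% Wrap in a dlnetwork for use in a custom training loop.
dlnetDSketch = dlnetwork(layerGraph(layersD));

The generator side is analogous, using transposedConv2dLayer with [N 1] filters to upsample a latent vector back to a 1201-by-1 signal.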
Training Process and Loss Function
The training procedure of this example is wrapped in the trainGAN function and contains nothing unusual; the losses and scores are computed exactly as in blog post 1.
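For reference, here is a rough sketch of the standard non-saturating GAN losses and the score definitions used in the MathWorks GAN examples, assuming dlYPredReal and dlYPredGenerated are the raw discriminator outputs for a batch of real and generated signals:

% Convert discriminator outputs to probabilities.
probReal      = sigmoid(dlYPredReal);
probGenerated = sigmoid(dlYPredGenerated);
% Discriminator loss: classify real data as 1 and generated data as 0.
lossD = -mean(log(probReal)) - mean(log(1-probGenerated));
% Generator loss: push the discriminator towards 1 on generated data.
lossG = -mean(log(probGenerated));
% Scores in [0,1] that track how well each network is currently doing.
scoreD = (mean(probReal) + mean(1-probGenerated))/2;
scoreG = mean(probGenerated);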
In the end, training the CGAN on an NVIDIA GeForce RTX 3060 Ti GPU took about 48 minutes. (Figure: CGAN training progress and examples of generated signals.)
Synthesize Flow Signals
From this step on, we enter the fault-diagnosis stage. The first task is to prepare the data set for training the fault-diagnosis model.
Use the trained CGAN to generate 1000 healthy signals and 1000 faulty signals:
rng default
numTests = 2000;
ZNew = randn(1,1,numLatentInputs,numTests,'single');
dlZNew = dlarray(ZNew,'SSCB');
% Specify that the first 1000 random arrays are healthy and the rest are faulty.
TNew = ones(1,1,1,numTests,'single');
TNew(1,1,1,numTests/2+1:end) = single(2);
dlTNew = dlarray(TNew,'SSCB');
% To generate signals using the GPU, convert the data to gpuArray objects
if executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
    dlTNew = gpuArray(dlTNew);
end
% Use the predict function on the generator with the batch of 1-by-1-by-100 arrays of random values and labels to generate synthetic signals
% and revert the standardization step that you performed on the original flow signals.
dlXGeneratedNew = predict(dlnetGenerator,dlZNew,dlTNew)*stdFlow+meanFlow;
Note that the final line de-standardizes the generated signals, because the flow signals were standardized before the CGAN was trained:
load(fullfile(pwd,'simulatedDataset.mat')) % load data set
meanFlow = mean(flow,2);
flowNormalized = flow-meanFlow;
stdFlow = std(flowNormalized(:));
flowNormalized = flowNormalized/stdFlow;
Signal Feature Visualization
Unlike images and audio signals, general signals have characteristics that make them difficult for human perception to tell apart. To compare real and generated signals or healthy and faulty signals, you can apply principal component analysis (PCA) to the statistical features of the real signals and then project the features of the generated signals to the same PCA subspace.
Feature Extraction
Before extracting features, combine the real and the generated data:
idxGenerated = 1:numTests;
idxReal = numTests+1:numTests+size(flow,2);
XGeneratedNew = squeeze(extractdata(gather(dlXGeneratedNew)));
x = [XGeneratedNew single(flow)];
Then extract features from each signal with the helper function extractFeatures:
features = zeros(size(x,2),14,'like',x);
for ii = 1:size(x,2)
    features(ii,:) = extractFeatures(x(:,ii));
end
The definition of extractFeatures:
function ci = extractFeatures(flow)
%EXTRACTFEATURES Extract features from flow signal.
% Copyright 2020 The MathWorks, Inc.
fA = flow - mean(flow);
[flowSpectrum, flowFrequencies] = pspectrum(fA, 1000, 'FrequencyLimits', [2 250]);
ci = extractCI(flow, flowSpectrum, flowFrequencies);
end
function ci = extractCI(flow, flowP, flowF)
% Compute signal statistical characteristics.
% Find frequency of peak magnitude in power spectrum
pMax = max(flowP);
fPeak = min(flowF(flowP==pMax));
% Compute power in low-frequency range 10 Hz-20 Hz
fRange = flowF >= 10 & flowF <= 20;
pLow = sum(flowP(fRange));
% Compute power in mid-frequency range 40 Hz-60 Hz
fRange = flowF >= 40 & flowF <= 60;
pMid = sum(flowP(fRange));
% Compute power in high-frequency range >100 Hz
fRange = flowF >= 100;
pHigh = sum(flowP(fRange));
% Find frequency of spectral kurtosis peak
[pKur,fKur] = pkurtosis(flow, 1000);
pKur = fKur(pKur == max(pKur));
% Compute flow cumulative sum range
csFlow = cumsum(flow);
csFlowRange = max(csFlow)-min(csFlow);
% Collect features and feature values in cell array.
qMean = mean(flow);
qVar = var(flow);
qSkewness = skewness(flow);
qKurtosis = kurtosis(flow);
qPeak2Peak = peak2peak(flow);
qCrest = peak2rms(flow);
qRMS = rms(flow);
qMAD = mad(flow);
qCSRange = csFlowRange;
pKurtosis = pKur(1);
ci = [fPeak, pLow, pMid, pHigh, pKurtosis, ...
qMean, qVar, qSkewness, qKurtosis, ...
qPeak2Peak, qCrest, qRMS, qMAD, qCSRange];
end
As the function shows, the extracted features are a mix of frequency-domain quantities (peak frequency, band powers, spectral-kurtosis peak frequency) and time-domain statistics (mean, variance, skewness, kurtosis, peak-to-peak, crest factor, RMS, MAD, cumulative-sum range).
Finally, relabel the signals:
L = [squeeze(TNew)+2; labels.'];
The meaning of the new labels L (the generated labels in TNew are 1/2 and get +2, while the real labels stay 1/2):
- 1 — Real healthy signals
- 2 — Real faulty signals
- 3 — Generated healthy signals
- 4 — Generated faulty signals
Note: these four-class labels are constructed only for the visualization below; when the SVM is trained later, the original two classes (1 — Healthy, 2 — Faulty, mapped to the strings 'Healthy' and 'Faulty') are used again.
Principal Component Analysis
Next, reduce the dimensionality of the extracted features with PCA, computed via the SVD of the mean-centered features of the real signals:
% PCA via svd
featuresReal = features(idxReal,:);
mu = mean(featuresReal,1);
[~,S,W] = svd(featuresReal-mu);
S = diag(S);
Y = (features-mu)*W;
>> sum(S(1:3))/sum(S)
ans =
single
0.9923
The first three singular values account for about 99% of the total, so three principal components are enough for the visualization below.
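As a side note (my own addition), the 0.9923 above is the share of the singular values themselves; the more conventional "variance explained" measure uses their squares, and both can be checked directly:

singularShare = cumsum(S)/sum(S);        % singularShare(3) is the 0.9923 reported above
varianceShare = cumsum(S.^2)/sum(S.^2);  % conventional explained-variance ratio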
Next, visualize the first three principal components of the four signal types in 3-D space:
idxHealthyR = L==1;
idxFaultR = L==2;
idxHealthyG = L==3;
idxFaultG = L==4;
pp = Y(:,1:3);
figure
scatter3(pp(idxHealthyR,1),pp(idxHealthyR,2),pp(idxHealthyR,3),'o')
xlabel('1st Principal Component')
ylabel('2nd Principal Component')
zlabel('3rd Principal Component')
hold on
scatter3(pp(idxFaultR,1),pp(idxFaultR,2),pp(idxFaultR,3),'d')
scatter3(pp(idxHealthyG,1),pp(idxHealthyG,2),pp(idxHealthyG,3),'s')
scatter3(pp(idxFaultG,1),pp(idxFaultG,2),pp(idxFaultG,3),'+')
view(-10,20)
legend('Real healthy','Real faulty','Generated healthy','Generated faulty', ...
'Location','Best')
hold off
To better show the difference between real and generated signals, also visualize the features in the plane spanned by the first two principal components:
view(2)
The distribution of the generated signals closely matches that of the real signals.
Predict Labels of Real Signals
In this step, an SVM is first trained on the generated signals; it is then used to predict the labels of the real signals, and the predictions are compared with the true labels.
Train the SVM on the generated signals:
LABELS = {'Healthy','Faulty'};
strL = LABELS([squeeze(TNew);labels.']).';
dataTrain = features(idxGenerated,:);
dataTest = features(idxReal,:);
labelTrain = strL(idxGenerated);
labelTest = strL(idxReal);
predictors = dataTrain;
response = labelTrain;
cvp = cvpartition(size(predictors,1),'KFold',5);
% Train an SVM classifier using the generated signals
SVMClassifier = fitcsvm( ...
    predictors(cvp.training(1),:), ...
    response(cvp.training(1)),'KernelFunction','polynomial', ...
    'PolynomialOrder',2, ...
    'KernelScale','auto', ...
    'BoxConstraint',1, ...
    'ClassNames',LABELS, ...
    'Standardize',true);
Then predict the labels of the real data and compute the prediction accuracy:
actualValue = labelTest;
predictedValue = predict(SVMClassifier,dataTest);
predictAccuracy = mean(cellfun(@strcmp,actualValue,predictedValue));
predictAccuracy =
0.8997
The prediction accuracy reaches 0.8997. Finally, inspect the confusion matrix:
figure
confusionchart(actualValue,predictedValue)
Discussion
Application Scenario
Using generated data as the training set for fault diagnosis is most useful when field experiments are expensive or simulation is slow. As the official example puts it:
The Simulink simulation takes about 14 hours to generate 2000 pump flow signals. This duration can be reduced to about 1.7 hours with eight parallel workers if you have Parallel Computing Toolbox™.
The CGAN takes 1.5 hours to train and 70 seconds to generate the same amount of synthetic data with an NVIDIA Titan V GPU.
About the Data Set
Overall, the workflow of this example is elegant, but from the perspective of machine-learning data splitting I think it still has a flaw.
In this example, the real data set is used to train the CGAN, the CGAN then produces the generated data, the generated data is used to train the SVM fault-diagnosis model, and finally the same real data is used as the test set to measure the SVM's accuracy. In this process, information about the SVM's test set has already leaked into its training pipeline, so the reported accuracy of 0.8997 is likely optimistic. A more reasonable approach is to split the real data into a training set X1 and a test set X2 at the very beginning, train the CGAN on X1, generate data X3 with the CGAN, train the SVM on the real data X1 plus the generated data X3, and finally evaluate it on X2. This avoids data leakage; see blog post 6 for more on this point. A sketch of such a split is shown below.
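As a minimal sketch of that split, assuming the flow and labels variables from simulatedDataset.mat (the names X1, X2, X3, y1, y2 are illustrative, not from the example):

rng default
cv = cvpartition(labels,'HoldOut',0.3);                   % stratified 70/30 split of the real data
X1 = flow(:,training(cv));  y1 = labels(training(cv));    % real training portion -> train the CGAN
X2 = flow(:,test(cv));      y2 = labels(test(cv));        % held-out real portion, used only for testing
% ... train the CGAN on (X1,y1), then synthesize X3 (with its labels) using the generator ...
% Train the SVM on features extracted from [X1 X3] and evaluate it only on
% features extracted from X2, so no test signal influences any training step.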
Reference
- Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein GAN." arXiv preprint arXiv:1701.07875 (2017).
- Train Conditional Generative Adversarial Network (CGAN) in MATLAB - What a starry night~.
- Train Generative Adversarial Network (GAN) in MATLAB - What a starry night~.
- Train Wasserstein GAN with Gradient Penalty (WGAN-GP) in MATLAB - What a starry night~.
- Generate Synthetic Signals Using Conditional GAN - MathWorks.
- Multi-Class Fault Detection Using Simulated Data - MathWorks.