适用平台:Matlab 2021a及以上
Wasserstein GAN(Wasserstein Generative Adversarial Network,WGAN)是一种生成对抗网络(GAN)的改进模型,旨在解决传统GAN模型中存在的梯度消失、训练不稳定以及模式崩溃等问题。WGAN引入了Wasserstein距离(也称为Earth Mover's Distance,EMD)作为GAN的优化目标,以更稳定和可解释的方式进行训练生成器和判别器。
以下是Wasserstein GAN的一些关键特点和工作原理:
Wasserstein距离:传统GAN使用的JS(Jensen-Shannon)散度或KL(Kullback-Leibler)散度作为损失函数,但这些损失函数存在问题,导致训练不稳定。WGAN引入Wasserstein距离,它被认为更好地度量了两个分布之间的差异,尤其对于高维数据更加有效。Wasserstein距离计算了将一个分布转化成另一个分布的最小代价,通常被称为“运输”成本。
梯度稳定性:WGAN通过使用Wasserstein距离作为损失函数显著提高了梯度稳定性。这意味着在训练过程中,生成器和判别器之间的梯度不会突然消失或爆炸,从而更容易收敛到稳定的解。
Lipschitz连续性:为了确保判别器是Lipschitz连续的(一种数学性质,对Wasserstein距离的计算至关重要),WGAN对判别器的权重进行了剪裁或权重约束,以确保梯度不会变得不稳定。这一步骤被称为权重剪裁。
生成器和判别器平衡:WGAN的训练过程更容易实现生成器和判别器的平衡。这意味着生成器和判别器之间的性能差异不会太大,生成器更容易生成高质量的样本。
生成高质量样本:由于WGAN的稳定性和平衡性,生成器更容易生成高质量的数据样本,这对于图像生成、数据生成等任务非常有用。
综上Wasserstein GAN通过引入Wasserstein距离和一系列改进方法,解决了传统GAN模型中的一些问题,使生成对抗网络更容易训练,生成更高质量的数据。这使得WGAN成为生成模型领域的一项重要进展。
对抗生成样本对比:

部分代码:
%xxxxx
%% 训练
iterationG = 0;
iterationD = 0;
start = tic;
%% 循环处理像批量数据
while iterationG < numIterationsG
% 生成器迭代次数 + 1
iterationG = iterationG + 1;
% 训练判别器
for n = 1 : numIterationsDPerG
iterationD = iterationD + 1;
%重置并打乱数据
temp = randperm(size(augimds, 4));
data = augimds(:, : , :, temp);
% 读取批次数据
X = single(data);
% 数据类型转换
[X, ps_output] = mapminmax(X, -1, 1);
dlX = dlarray(X, 'SSCB');
% 生成生成器输入样本,并转换格式
Z = randn([numLatentInputs, size(dlX, 4)], 'like', dlX);
dlZ = dlarray(Z, 'CB');
% 得到判别器损失和梯度
[gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda);
% 更新判别器参数
[dlnetD, trailingAvgD, trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ...
trailingAvgD, trailingAvgSqD, iterationD, ...
learnRateD, gradientDecayFactor, squaredGradientDecayFactor);
end
% 得到生成器输入样本,并转换格式.
Z = randn([numLatentInputs, size(dlX, 4)], 'like', dlX);
dlZ = dlarray(Z, 'CB');
% 得到生成器梯度
gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ);
% 更新判别器参数
[dlnetG, trailingAvgG, trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ...
trailingAvgG, trailingAvgSqG, iterationG, ...
learnRateG, gradientDecayFactor, squaredGradientDecayFactor);
%% 更新显示曲线
subplot(1, 1, 1)
% 得到判别器损失函数和未经梯度惩罚的损失函数
lossD = double(gather(extractdata(lossD)));
lossDUnregularized = double(gather(extractdata(lossDUnregularized)));
% 更新曲线
addpoints(lineLossD, iterationG, lossD);
addpoints(lineLossDUnregularized, iterationG, lossDUnregularized);
% 更新标题
D = duration(0, 0, toc(start), 'Format', 'hh:mm:ss');
title( ...
"Iteration: " + iterationG + ", " + ...
"Elapsed: " + string(D))
drawnow
end
%% 生成生成器输入数据
ZNew = randn(numLatentInputs, M, 'single');
dlZNew = dlarray(ZNew, 'CB');
%% 判断是否存在GPU
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
dlZNew = gpuArray(dlZNew);
end
%% 生成数据
dlXGeneratedNew = predict(dlnetG, dlZNew);