Wasserstein-GAN(WGAN)对抗生成网络,数据由Excel导入,直接

作品简介

适用平台:Matlab 2021a及以上

Wasserstein GAN(Wasserstein Generative Adversarial Network,WGAN)是一种生成对抗网络(GAN)的改进模型,旨在解决传统GAN模型中存在的梯度消失、训练不稳定以及模式崩溃等问题。WGAN引入了Wasserstein距离(也称为Earth Mover's Distance,EMD)作为GAN的优化目标,以更稳定和可解释的方式进行训练生成器和判别器。

以下是Wasserstein GAN的一些关键特点和工作原理:

Wasserstein距离:传统GAN使用的JS(Jensen-Shannon)散度或KL(Kullback-Leibler)散度作为损失函数,但这些损失函数存在问题,导致训练不稳定。WGAN引入Wasserstein距离,它被认为更好地度量了两个分布之间的差异,尤其对于高维数据更加有效。Wasserstein距离计算了将一个分布转化成另一个分布的最小代价,通常被称为“运输”成本。

梯度稳定性:WGAN通过使用Wasserstein距离作为损失函数显著提高了梯度稳定性。这意味着在训练过程中,生成器和判别器之间的梯度不会突然消失或爆炸,从而更容易收敛到稳定的解。

Lipschitz连续性:为了确保判别器是Lipschitz连续的(一种数学性质,对Wasserstein距离的计算至关重要),WGAN对判别器的权重进行了剪裁或权重约束,以确保梯度不会变得不稳定。这一步骤被称为权重剪裁。

生成器和判别器平衡:WGAN的训练过程更容易实现生成器和判别器的平衡。这意味着生成器和判别器之间的性能差异不会太大,生成器更容易生成高质量的样本。

生成高质量样本:由于WGAN的稳定性和平衡性,生成器更容易生成高质量的数据样本,这对于图像生成、数据生成等任务非常有用。

综上Wasserstein GAN通过引入Wasserstein距离和一系列改进方法,解决了传统GAN模型中的一些问题,使生成对抗网络更容易训练,生成更高质量的数据。这使得WGAN成为生成模型领域的一项重要进展。

对抗生成样本对比​:

部分代码:

​
%xxxxx
%% 训练
iterationG = 0;
iterationD = 0;
start = tic;
​
%% 循环处理像批量数据
while iterationG < numIterationsG
    % 生成器迭代次数 + 1
    iterationG = iterationG + 1;
​
    % 训练判别器
    for n = 1 : numIterationsDPerG
        iterationD = iterationD + 1;
​
        %重置并打乱数据
        temp = randperm(size(augimds, 4));
        data = augimds(:, : , :, temp);
​
        % 读取批次数据
        X = single(data);
        % 数据类型转换
        [X, ps_output] = mapminmax(X, -1, 1);
        dlX = dlarray(X, 'SSCB');
​
        % 生成生成器输入样本,并转换格式
        Z = randn([numLatentInputs, size(dlX, 4)], 'like', dlX);
        dlZ = dlarray(Z, 'CB');
​
        % 得到判别器损失和梯度
        [gradientsD, lossD, lossDUnregularized] = dlfeval(@modelGradientsD, dlnetD, dlnetG, dlX, dlZ, lambda);
​
        % 更新判别器参数
        [dlnetD, trailingAvgD, trailingAvgSqD] = adamupdate(dlnetD, gradientsD, ...
            trailingAvgD, trailingAvgSqD, iterationD, ...
            learnRateD, gradientDecayFactor, squaredGradientDecayFactor);
    end
​
    % 得到生成器输入样本,并转换格式.
    Z = randn([numLatentInputs, size(dlX, 4)], 'like', dlX);
    dlZ = dlarray(Z, 'CB');
​
    % 得到生成器梯度
    gradientsG = dlfeval(@modelGradientsG, dlnetG, dlnetD, dlZ);
​
    % 更新判别器参数
    [dlnetG, trailingAvgG, trailingAvgSqG] = adamupdate(dlnetG, gradientsG, ...
        trailingAvgG, trailingAvgSqG, iterationG, ...
        learnRateG, gradientDecayFactor, squaredGradientDecayFactor);
​
   %% 更新显示曲线
    subplot(1, 1, 1)
    % 得到判别器损失函数和未经梯度惩罚的损失函数
    lossD = double(gather(extractdata(lossD)));
    lossDUnregularized = double(gather(extractdata(lossDUnregularized)));
    
    % 更新曲线
    addpoints(lineLossD, iterationG, lossD);
    addpoints(lineLossDUnregularized, iterationG, lossDUnregularized);
    
    % 更新标题
    D = duration(0, 0, toc(start), 'Format', 'hh:mm:ss');
    title( ...
        "Iteration: " + iterationG + ", " + ...
        "Elapsed: " + string(D))
    drawnow
end%% 生成生成器输入数据
ZNew = randn(numLatentInputs, M, 'single');
dlZNew = dlarray(ZNew, 'CB');
​
%% 判断是否存在GPU
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
    dlZNew = gpuArray(dlZNew);
end%% 生成数据
dlXGeneratedNew = predict(dlnetG, dlZNew);



创作时间: