橘智橘智
FakeOrange
预计阅读时间:6分钟28秒

通俗解释生成对抗网络(GAN)背后的数学原理

本篇文章将介绍一个独特的深度学习框架——生成对抗网络(Generative Adversarial Networks,简称 GANs)

0
0


前言


本文属于搬运内容,原作者:Ameh Emmanuel Sunday,原文链接。



作者引言


本篇文章将介绍一个独特的深度学习框架——生成对抗网络(Generative Adversarial Networks,简称 GANs)。我之所以对它们感到十分着迷,不仅是因为它们的工作原理令人惊艳,更因为它们正在彻底改变我们在流体力学领域(尤其是在降阶建模和动力系统)中处理科研问题的方式。



GANs 解析


在 GAN 出现之前,大多数机器学习模型是判别式的,也就是说,它们主要被用于分类或回归任务。而我认为,GAN 的出现实际上标志着机器学习和深度学习进入了一个“创造性时代”的开始。Meta AI 的首席科学家 Yann LeCun 曾表示,GAN 是“过去十年中机器学习领域最有趣的想法”——对此我完全认同。


尽管这个框架自 2014 年才问世,但 GAN 已经在流体力学研究社区中占据了重要地位。它们被用来生成逼真的流场快照,通过学习真实流体流动的数据分布——这非常关键。因为这意味着我们现在可以在不运行完整 CFD(计算流体动力学)仿真的情况下生成流动数据,而 CFD 仿真通常代价高昂。这一点在你希望构建一个机器学习模型但又只有有限数据时尤为有用。GAN 也被用来基于已有数据生成新的、可信的仿真边界条件分布——这只是它们众多酷炫应用之一。


即使你不像我这样热爱流体力学 😎,GAN 仍然无处不在。它们被用来生成高度逼真的“假人脸”图像(这些人其实根本不存在,哈哈)。你可以访问这个网站,每次刷新页面都会生成一张完全虚构但看起来极其真实的人脸。除此之外,GAN 也被用于生成式设计,比如帮你设计出很酷的 3D 家具。Adobe 利用 GAN 构建下一代 Photoshop 工具,Google 用它们进行文本生成,IBM 用于数据增强,而 Snapchat 和 TikTok 等平台早就用它们来开发图像滤镜了。


在本文中,我将讲解 GAN 背后的数学原理。我认为学习这些数学原理非常重要,因为这不仅能帮助我们理解“GAN 为什么能起作用”,而不是仅仅知道“GAN 确实能起作用”,还能让我们跳出深度学习“黑箱思维”的框架,获得改进现有生成模型或发明新模型的工具。我还强烈建议大家去读一下 Ian Goodfellow 在 2014 年发表的那篇开创性论文,这正是 GAN 首次被介绍给世界。如果你逐字逐句读完那篇论文(原文链接),那么日后你看到任何一篇关于 GAN 的论文,基本都能一眼看出他们的研究意图和创新点。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/4161image.png


首先从一个高层次的概览说起,生成对抗网络(GAN)是一种深度学习框架,由两部分组成:生成器(Generator)和判别器(Discriminator)。

生成器通常从一个正态或高斯分布中采样噪声,并生成一些“假样本”,比如一个假的二维温度流场。在这个例子中,我们可以将生成器看作一个“艺术伪造者”,试图伪造一个二维温度流场的样本。


而判别器则试图判断这个流场样本是“真的”还“是假的”。在这个情景中,判别器就像一个“裁判”或“法官”。


如果生成器能够成功地欺骗判别器,让它认为伪造的样本是真的,那么生成器就完成得很好,而判别器表现得不好。反之,如果判别器能够很好地识别出生成器生成的假样本,那么它就是表现出色,而生成器就还有许多学习要做。


这种你来我往的过程,通常称为对抗训练(adversarial training),会持续进行一段时间,直到双方都变得更强。训练只有在判别器无法再区分样本是真还是假时才会停止。这个时候,模型就停止学习了,生成器已经学会了训练数据的概率分布,我们可以说 GAN 的训练已经收敛。



data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/eb48image.png


(GAN 的图示描述)


既然这篇博客主要讨论 GAN 背后的数学,我们现在来深入了解它背后的数学逻辑🎯。


在判别模型(如逻辑回归、支持向量机,以及一些用于分类的简单前馈神经网络)中,我们学习的是“给定某些输入特征,属于某个标签的概率”。而在生成模型中,我们学习的是两件事:“某个标签的概率”以及“在已知该标签的情况下,输入特征出现的概率”。这称为“输入特征与其对应标签的联合概率分布(joint probability distribution)”。


如果你对上面这段仍然感觉不够直观,下面用动物举个例子来说明:

  • 判别模型的思路是:“给定这张带有某些特征的图像,它是猫还是狗?”
  • 生成模型的思路则是:“我已经学会了猫的样子,现在我可以凭空生成一张真实感很强的猫的图像——即便这只猫在现实中并不存在。”


通过学习这种联合概率分布,生成器学会了训练数据的分布,因此可以生成那些看起来真实但实际上并不存在的图像,比如超真实的“假猫”。下面这张图展示了从概率角度看,判别模型与生成模型的差异:


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/9622image.png


(判别模型与生成模型的概率视角(Discriminative models: 判别模型,Generative models: 生成模型))


接下来的这张图展示了 GAN 的整体结构概览。生成器可以是一个简单的前馈神经网络、反卷积网络(transposed CNN),或者一个解码器(decoder);而判别器可以是一个简单的前馈神经网络、卷积神经网络(CNN),甚至是一个自编码器(autoencoder)。网络结构的选择取决于你想构建的 GAN 的应用场景,但无论结构如何,GAN 的两个部分都会认真地发挥各自的作用。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/c131image.png



GAN 结构中的符号总结:


  • z:表示从高斯或正态分布中采样的噪声
  • G:表示生成器,用于从噪声中生成假图像
  • G(z):表示生成器生成的假图像
  • D:表示判别器
  • X:是真实数据样本的分布域
  • θg:表示生成器的权重和偏置,在梯度上升中被更新
  • θd:表示判别器的权重和偏置,在梯度下降中被更新


如果你仔细想一想,其实 GAN 的工作机制完全可以被建模为一个极小极大博弈(minimax game)


生成器的目标是最小化它被判别器识破所生成“假数据样本”的概率,而判别器的目标则是最大化它识别出这些“假样本”的能力。

从数学角度来看,我们可以说,生成器在最小化,而判别器在最大化一个称为 V 的值函数(也就是损失函数)的值。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/28d1image.png


GAN 的极小极大博弈


在上图中你可以看到两个公式:一个是值函数的极小极大表示,另一个是值函数本身。我会以易于理解的方式来解释第二个公式。

观察第二个公式,你会发现它非常类似于二元交叉熵损失函数(Binary Cross-Entropy, BCE),如下所示:


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/4929image.png


(二元交叉熵损失函数公式)


  • Y 表示第 i 个样本的真实标签
  • Ŷ(Y-hat) 表示判别器对该样本的预测结果
  • n 是数据集或小批量中的样本数量


这个值函数就是从二元交叉熵损失中推导而来的。二元交叉熵目标函数衡量了我们与“正确区分真实样本和伪造样本”的目标之间的差距。

为了更好地理解 BCE 损失的工作机制,请看下图:


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/dd0eimage.png


(BCE 损失函数图解)


当样本的真实标签是 1(真实) 时,只有损失函数左半部分(绿色标出)才有效;

而当真实标签是 0(伪造) 时,只有损失函数右半部分才起作用。

从图中可以看到:

  • 当 y = 1 时,如果判别器预测值接近 1,损失接近于 0,预测越接近 0,损失趋近于 无穷大;
  • 当 y = 0 时,如果预测值接近 0,损失趋近于 0,预测越接近 1,损失同样趋向于 无穷大。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/32dfimage.png


(BCE 损失函数中的预测与损失关系图)


将这两部分(二元交叉熵中 y=1 和 y=0 两种情况)的损失合并,就构成了生成对抗网络(GAN)中的值函数 V。由于这个值函数需要在多个数据点上进行计算,我们引入期望(E),来表示整个数据分布上的平均损失。


在 GAN 的框架中:


  • 判别器的目标是最大化这个损失函数 —— 也就是说,它希望尽可能准确地区分真实样本和伪造样本;
  • 生成器的目标则是最小化这个损失 —— 它想要欺骗判别器,让它认为伪造数据也是真实的。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/871aimage.png


(值函数 V(D, G) 的极小极大过程)


还有一点非常重要:生成器对 D(x)(即判别器对真实样本的判断)没有任何影响,因为生成器本身并不产生真实样本。所有真实样本都是在训练过程中单独输入判别器进行学习的。



训练生成对抗网络(GANs)


下图来自 2014 年首次提出 GAN 的原始论文,展示了用于训练 GAN 的算法。我会用直观易懂的方式来解释这个过程🙃,但在此之前,请先花点时间阅读这段算法说明 —— 如果你还没读过这篇论文,强烈建议你看一看完整的原文!


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/43a6image.png


(原始 GAN 论文中的训练算法图示)



首先,从**均匀分布(uniform distribution)**中采样噪声,并将其输入生成器,生成“假数据样本”。然后断开生成器,将这些“假图像”与从真实数据分布中采样的“真实图像”一起输入判别器。此时,通过梯度上升(gradient ascent),更新判别器的权重和偏置,以寻找其值函数的全局最大值。

接着,断开判别器的连接,开始训练生成器:生成一些“假图像”,并将其传入判别器。然后通过梯度下降(gradient descent),更新生成器的参数(权重和偏置)。在每一轮训练循环中,判别器和生成器都会分别更新一次。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/b2ebimage.png


(梯度上升与下降在损失函数曲面中的图示描述)


在实际中,判别器和生成器所拟合的函数比上图中简单的凸函数和凹函数复杂得多。图中仅展示了梯度下降和上升的最简形式,用以帮助理解。



生成器与判别器的结构设计


  • 生成器(Generator) 是一个简单的前馈神经网络,仅包含一个隐藏层,隐藏层使用 ReLU 激活函数,输出层使用 tanh 激活函数将输出压缩到 -1 到 1 之间的值范围,表示 MNIST 数据集中图像的像素强度。
  • 判别器(Discriminator) 也是一个简单的前馈神经网络,包含一个隐藏层,使用 Leaky ReLU 作为激活函数,输出层只有一个神经元,采用 sigmoid 激活函数,将输出压缩为“是否为真实图像”的概率(0 表示伪造,1 表示真实)。



从数学角度理解值函数的最大化与最小化


在固定或断开生成器之后,判别器的目标是最大化值函数,这个目标可以通过如下数学表达式来表示:


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/e5a0image.png


D(x) 被表示为


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/2511image.png


时,判别器的目标就是最大化这个值函数。其中,Pₙ(x) 是真实数据样本的概率分布,P𝗀(x) 是生成样本的概率分布。而生成器的目标则相反,是要最小化这个被最大化的函数


要最小化这个值函数,就意味着我们希望真实数据的概率分布与生成数据的概率分布相同。既然目标是让两者相等,我们就需要有一个度量它们相距多远的指标。为此,我们使用 JS 散度(Jensen–Shannon divergence),而它是通过 KL 散度(Kullback-Leibler divergence) 计算出来的👌。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/e9c3image.png


(生成器对值函数的最小化过程)


JS 散度的表达形式与我们希望由生成器来最小化的那个最大值函数非常相似👌。要最小化这个函数,就必须让真实数据的概率分布与生成数据的概率分布完全相同。当两者达到一致时,值函数将变为 −2ln2


此时,判别器将无法再分辨输入到底是真还是假,它对所有输入的输出都变成了 0.5


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/301bimage.png


当判别器“卡住”无法判断时,就意味着训练结束


当判别器被“困住”,无法再判断哪些数据是真、哪些是假时,我们就可以停止训练了。下图直观地说明了 GAN 的训练过程,来自原始的 GAN 论文。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/5eaecb24-4e52-4f7c-8ae5-9d46897e71a0/a0c2image.png


(生成器 - 绿色,判别器 - 蓝色,真实数据分布 - 黑色虚线。图源:GAN 论文 2014)


训练一开始时,生成器和判别器都“一无所知”。在图 (b) 中,判别器的权重和偏置被更新,而生成器是冻结的;在图 (c) 中,判别器被冻结,生成器被训练并更新其权重和偏置。


这个循环不断重复,形成所谓的对抗训练(adversarial training)。随着训练的推进,图 (d) 中可以看到,生成器输出的数据分布逐渐逼近真实数据分布,而判别器则变成一条“直线”,对所有样本都输出 0.5


这意味着生成器已经学会了训练集的概率分布,能够生成高度逼真的数据样本,其效果与真实样本十分相似。


谢谢你读到这里☺️!

希望这些关于 GAN 数学原理的讲解对你有所帮助,也希望你现在已经更有信心去构建属于你自己的 GAN 网络了。

如果你想进一步加深理解,建议深入学习以下几个关键概念:

  • KL 散度
  • JS 散度
  • GAN 训练中的稳定性问题

此外,你可能也会对更高级的主题感兴趣,比如:

  • Wasserstein GANs(WGAN)
  • 边界平衡 GAN(BEGAN)
  • Lipschitz 连续性约束
  • GAN 中的 Batch Normalization 技术
  • Spectral Normalization(谱归一化)
  • DCGAN(深度卷积生成对抗网络)


这些内容将帮助你更全面地理解和训练更稳定、更强大的 GAN 模型。


评论