gan代码解读(乘风破浪的GAN)

本文来源于Hulu知乎官方账号文章《<百面深度学习>试读 | 系列二:回顾GAN的发展之路》,葫芦娃著。

gan代码解读(乘风破浪的GAN)(1)

2014 年,在一间酒吧里,Ian Goodfellow 与朋友觥筹交错畅谈生成模型之际,生成式对抗网络(Generative Adversarial Network, GAN)诞生了。2020年,GAN 一路乘风破浪,披荆斩棘,在各种领域成功“出道”。

有兴趣的读者可以去记录了五百多种 GAN 的 GAN Zoo 里 pick 你喜欢的 GAN。那么 GAN 是如何从一个博弈的想法,进化到如今这样高产的呢?这之中又有多少艰难险阻呢?今天我们就来回顾一下 GAN 的发展之路。

回到 GAN 最原始的模样

GAN 是一种非常能展现“创造性”的模型。从名字就可以看出,GAN 的核心在于“生成”和“对抗”:“生成”指的是生成式模型,其目的是模拟多个变量的联合概率分布;“对抗”指的是对抗训练方法,又名 “互怼”、“相爱相杀”。这两者结合,就产生出了神奇的 GAN。

GAN 是专门为了优化生成任务而提出的模型。生成模型的一大难点在于如何度量生成分布与真实分布的相似度。一般情况下,我们只知道这两个分布的采样结果,很难知道具体的分布表达式,因此难以找到合适的度量方法。GAN 的思路是,把这个度量任务交给一个神经网络来做,这个网络被称为判别器(Discriminator)。GAN 在训练阶段用对抗训练方式来交替优化生成器 G(⋅) 与判别器 D(⋅)。整个模型的优化目标是:

gan代码解读(乘风破浪的GAN)(2)

上述公式直观地解释了 GAN 的原理:判别器D(⋅)的目标是区分真实样本和生成样本,对应在目标函数上就是使上式的值尽可能大,也就是对真实样本的输出D(x)尽量接近1,对生成样本的输出D(G(z))尽量接近0;生成器G(⋅)的目标是欺骗判别器,尽量生成“以假乱真”的样本来逃过判别器的“法眼”,对应在目标函数上就是让上式的值尽可能小,也就是让D(G(z))也尽量接近1。

这是一个 "MiniMax" 游戏,在游戏过程中G(⋅)和D(⋅)的目标是相反的,这就是 GAN 名字中“对抗”的含义。通过对抗训练,生成器与判别器交替优化,共同成长,最终修炼为两个势均力敌的强者。Fig. 1 是 GAN 的基本框架图。

gan代码解读(乘风破浪的GAN)(3)

Figure 1: GAN 的基本框架

不完美的 GAN

GAN 最初的样子看上去简洁又优美。 可是诞生不久,一些对于 GAN 的负面评价开始出现,“梯度消失”、“训练不稳定”、“模式坍塌”...... 对 GAN 有过了解的人应该对这些词汇并不陌生。而想要弄清这些问题究竟是什么原因造成的,需要深入到模型背后对其原理进行详细地分析。

  • 梯度消失

在 Goodfellow 提出的原始 GAN [1] 中,模型的优化目标为 Eq. 1。如果将判别器 D(⋅) 看作一个二分类器,真实样本的训练标签为 1,生成样本的训练标签为 0,D(⋅) 的优化目标可以解释为最大化该分类问题的对数似然函数,也即最小化交叉熵损失。这个看起来简洁又直观的定义,在理论上存在着一些问题。

这是因为,在一开始训练时,生成器还很差,生成的数据与真实数据相差甚远,判别器可以很轻松地作出正确的判断,也就是输出与训练标签基本一致的结果。有监督的训练需要通过损失函数的梯度来更新模型,若模型输出结果与标签基本一致,也就意味着梯度接近于0,这会让判别器觉得自己已经足够强大,失去优化自己的“动力”。简单用一句话概括来说,就是在训练的早期阶段,目标函数 Eq. 1 无法为生成器提供足够大的梯度。

我们还可以从“散度”的角度来理解这个问题:当判别器达到最优时,根据损失函数 L(G),此时生成器的目标其实是最小化真实分布与生成分布之间的 JS 散度。JS 散度(Jensen–Shannon divergence),用于度量两个概率分布的相似度,其定义为:

gan代码解读(乘风破浪的GAN)(4)

然而,JS 散度有一个特性:当两个分布没有重叠的部分,或几乎没有重叠时,JS 散度为常数(这可以根据 JS 散度的定义 Eq. 2 得到)。那么在 GAN 中,真实分布和生成分布的重叠部分有多大呢?

生成器一般是从一个低维空间(如 128 维)中采样一个向量并将其映射到一个高维空间中(比如一个 32×32 的图像就是 1024 维),所以生成数据只是高维空间中的一个低维流形(比如生成样本在上述 1024 维图像空间的所有可能性实际上是被那 128 维的输入向量限定了)。

同理,真实分布也是高维空间中的低维流形。高维空间中的两个低维流形,在这样“地广人稀”的空间中碰面的几率趋于 0,所以生成分布与真实分布是几乎没有重叠部分的。高维空间中的 JS 散度遇到了“降维打击”,只能无奈地输出一个常数,导致梯度消失问题。

  • 模式坍塌(mode collapse)

为解决梯度消失问题,Goodfellow 提出了改进方案,采用以下公式来替代生成器的损失函数:

gan代码解读(乘风破浪的GAN)(5)

可以看到,两个公式就是差了一个“1”,但就能在训练早期阶段可以为生成器提供更大的梯度(参见文献[1])。

然而,改进后的损失函数也存在不合理之处。此时最小化损失函数相当于最小化

gan代码解读(乘风破浪的GAN)(6)

(推导略,详情见《百面深度学习》对应章节),这就既要最小化生成分布与真实分布的 KL 散度,即减小两个分布的距离,又要最大化两者的 JS 散度,即增大两个分布的距离,这会在训练时造成梯度的不稳定。

另外,KL 散度是一个非对称度量,因此还存在对不同错误惩罚不一致的问题。当生成器缺乏多样性但能生成某一种简单的模式时,生成的样本对损失函数贡献趋近于0;当生成器尝试去生成其他更复杂的模式下的样本时,惩罚会趋于无穷大。

举个简单的例子,美术课上“判别器”老师,给了一些英短猫、银渐层、缅因猫、Hello Kitty 等等各式各样形态各异的猫的样例,要求“生成器”学生们创作属于自己的猫,选择高难度种类的创作并不会获得奖励,但是与样例的的水准差距明显则会受到请家长并记过的惩罚。由于惩罚过于严厉,学生们纷纷选择了最简单的Hello Kitty,其他形态各异,胖瘦不一的猫猫由于太难而无人问津。

真实数据的分布也往往是高度复杂并且是多模态的,数据分布有很多模式(modes),相似的样本属于一个模式。由于惩罚的不一致,生成器宁愿多生成一些真实却属于同一个模式的样本,也不愿意冒着巨大惩罚的风险去生成其他不同模式的、具有多样性的样本来欺骗判别器,这就是所谓的模式坍塌(mode collapse)。

  • 模型收敛性

在实际应用中,一般常用深度神经网络来表示 G(⋅) 和 D(⋅),然后采用梯度下降法反向传播算法来更新网络参数,而不是直接学习

gan代码解读(乘风破浪的GAN)(7)

本身。然而,Goodfellow 给出的收敛性证明是基于概率密度函数空间上

gan代码解读(乘风破浪的GAN)(8)

的凸性,当问题变成了参数空间的优化时,凸性便不再确定了,所以理论上的收敛性在实际中不再有效。

另外,训练的收敛性的判断也是一个难题。由于存在对抗,生成器与判别器的损失是反相关的,一个增大时另一个减小,因而无法根据损失函数的值来判断什么时候应该停止训练。当然,我们也很难直接通过损失函数或者生成器的输出来判断生成数据的质量,例如难以比较哪个图更“真实”,哪些生成数据多样性更高。

GAN 的成长之路

近几年来,GAN 发展十分迅速,各式各样的 GAN 不断涌现,有很多工作都致力于解决 GAN 训练的不稳定、提高生成数据真实性和多样性等问题。主要可以归类为对于目标函数的优化、对于模型结构的优化以及对于训练过程的优化。

  • 目标函数的优化

前文已经提到,基于 JS 散度距离度量的 GAN 会存在一些问题,在分布距离度量方式上进行优化自然的成为一个改进的方向。采用 Wasserstein 距离的 WGAN [2] 就是一个典型的代表。 另一种比较典型的对目标函数的优化是折页损失(Hinge Loss)形式的损失函数,它起源于 Geometric GAN [3]。

在 Geometric GAN 中,研究者将 GAN 解释为在特征空间进行的三步操作:(1) 分类超平面搜索;(2) 判别器向远离超平面的方向更新;(3) 生成器向超平面的方向更新。各种 GAN 之间的主要区别就在于分类超平面的构建方法以及特征向量的几何尺度缩放因子的选择,具体理论推导参见文献 [3]。

在训练阶段,批(mini-batch)的大小往往远小于特征空间的维度,这种情况下的分类问题被称为高维低采样尺寸(High-Dimension-Low-Sample-Size, HDLSS)问题。支持向量机(SVM)中最大化两类的分类边界以及软边界的思想被广泛应用在 HDLSS 问题中,并被证明具有鲁棒性。Geometric GAN 借鉴 SVM 的思想,判别器的损失函数形式与 SVM 中的折页损失的形式很相似。Geometric GAN 出现后,这种具有折页损失形式的 GAN 损失函数在很多方法中被采用,包括 2018 年热门的 SAGAN [4] 和 BigGAN [5]。

  • 模型结构的优化

生成器和判别器的网络结构对于训练过程的稳定性和模型表现至关重要。在 GAN 模型结构的各种改进方法中,最先要提到的是 DCGAN(Deep Convalutional GAN)[6],它首次将卷积神经网络用到 GAN 中,为 GAN 家族贡献了一个重要的基准结构。之后比较具有代表性的结构改进包括加入自动编码器的结构(auto-encoder architecture),比如 VAE-GAN [7], ALI [8];以及层次化的结构,例如 Stacked GAN [9] 通过堆叠多个“生成器-判别器-编码器”来构造层次化结构。也有方法通过在训练过程中对单个 GAN 进行动态地堆叠以构成层次化结构,例如 Progressive GAN [10] 就仅用一个“生成器-判别器”对,但在训练过程中逐渐增加网络的层数,其模型结构如 Fig. 2 所示。

gan代码解读(乘风破浪的GAN)(9)

Figure 2: Progressive GAN 结构示意图

  • 训练过程的优化

GAN 的训练过程实际上是在极高维参数空间中寻找一个非凸优化问题的纳什均衡点的过程,该过程常常是很不稳定的。很多文献中提出了一些针对神经网络训练过程的改进方法来提升 GAN 的效果,还有一些是在实际应用中发现的能够稳定训练过程的经验和技巧,如特征匹配技术 [11]、单边标签平滑 [11]、谱归一化 [12] 等等。

硕果累累的 GAN

随着 GAN 在理论上突飞猛进的发展,各种 GAN 在不同领域中的应用也遍地开花。在计算机视觉中的应用包括图像和视频的生成、图像与图像或文字之间的翻译、物体检测、语义分割等。除了图像领域,GAN 还被应用在半监督学习、迁移学习、强化学习、多模态学习、特征学习等领域。在 GAN 的应用中,除了上面提到的一些偏基础的改进以外,研究者们往往会根据具体的场景和问题因地制宜地对模型进行进一步的改进和调整。下图以人脸合成任务为例,列举了GAN 在不同时期的合成效果。

gan代码解读(乘风破浪的GAN)(10)

Figure 3: 基于 GAN 的人脸图像合成示例 [13]

GAN 的故事还有很多,GAN 的未来也非常值得期待。如果读者们心中还是有对 GAN 的各种问号,或是想了解更多更深入的关于 GAN 的知识,欢迎到《百面深度学习》书中寻找答案。

gan代码解读(乘风破浪的GAN)(11)

参考文献:

[1] GOODFELLOW I, POUGET-ABADIE J, MIRZA M, 等. Generative adversarial nets[C]//Advances in Neural Information Processing Systems. 2014: 2672–2680.[2] ARJOVSKY M, CHINTALA S, BOTTOU L. Wasserstein GAN[J]. arXiv preprint arXiv:1701.07875, 2017.[3] LIM J H, YE J C. Geometric GAN[J]. arXiv preprint arXiv:1705.02894, 2017.[4] ZHANG H, GOODFELLOW I, METAXAS D, 等. Self-attention generative adversarial networks[J]. arXiv preprint arXiv:1805.08318, 2018.[5] BROCK A, DONAHUE J, SIMONYAN K. Large scale GAN training for high fidelity natural image synthesis[J]. arXiv preprint arXiv:1809.11096, 2018.[6] RADFORD A, METZ L, CHINTALA S. Unsupervised representation learning with deep convolutional generative adversarial networks[J]. arXiv preprint arXiv:1511.06434, 2015.[7] LARSEN A B L, SØNDERBY S K, LAROCHELLE H, 等. Autoencoding beyond pixels using a learned similarity metric[J]. arXiv preprint arXiv:1512.09300, 2015.[8] DUMOULIN V, BELGHAZI I, POOLE B, 等. Adversarially learned inference[J]. arXiv preprint arXiv:1606.00704, 2016.[9] ZHANG H, XU T, LI H, 等. StackGAN: Text to photo-realistic image synthesis with stacked generative adversarial networks[C]//Proceedings of the IEEE International Conference on Computer Vision. 2017: 5907–5915.[10] KARRAS T, AILA T, LAINE S, 等. Progressive growing of GANs for improved quality, stability, and variation[J]. arXiv preprint arXiv:1710.10196, 2017.[11] SALIMANS T, GOODFELLOW I, ZAREMBA W, 等. Improved techniques for training GANs[C]//Advances in Neural Information Processing Systems. 2016: 2234–2242.[12] MIYATO T, KATAOKA T, KOYAMA M, 等. Spectral normalization for generative adversarial networks[J]. arXiv preprint arXiv:1802.05957, 2018.[13] GUI J, SUN Z, WEN Y, 等. A review on generative adversarial networks: Algorithms, theory, and applications[J]. arXiv preprint arXiv:2001.06937, 2020.

,

免责声明:本文仅代表文章作者的个人观点,与本站无关。其原创性、真实性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容文字的真实性、完整性和原创性本站不作任何保证或承诺,请读者仅作参考,并自行核实相关内容。文章投诉邮箱:anhduc.ph@yahoo.com

    分享
    投诉
    首页