雷锋网按:本文原作者 天雨粟 ,原文载于作者的知乎专栏—— 机器不学习 ,雷锋网经授权发布。 前言 GAN 从 2014 年诞生以来发展的是相当火热,比较著名的 GAN 的应用有 Pix2Pix、CycleGAN 等。本篇文章主要是让初学者通过代码了解 GAN 的结构和运作机制,对理论细节不做过多介绍。我们还是采用 MNIST 手写数据集(不得不说这个数据集对于新手来说非常好用)来作为我们的训练数据,我们将构建一个简单的 GAN 来进行手写数字图像的生成。
认识 GAN GAN 主要包括了两个部分,即生成器 generator 与判别器 discriminator。生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,以骗过判别器。判别器则需要对接收的图片进行真假判别。在整个过程中,生成器努力地让生成的图像更加真实,而判别器则努力地去识别出图像的真假,这个过程相当于一个二人博弈,随着时间的推移,生成器和判别器在不断地进行对抗,最终两个网络达到了一个动态均衡:生成器生成的图像接近于真实图像分布,而判别器识别不出真假图像,对于给定图像的预测为真的概率基本接近 0.5(相当于随机猜测类别)。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/470eeb25c5f7ce8bb780fc8137031239.png" data-rawwidth="2520" data-rawheight="1142" class="origin_image zh-lightbox-thumb" width="2520" data-original="https://pic2.zhimg.com/v2-bea410c4a9ac2ac5f86ac4ac4fc827cd_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/470eeb25c5f7ce8bb780fc8137031239.png"/>
对于 GAN 更加直观的理解可以用一个例子来说明:造假币的团伙相当于生成器,他们想通过伪造金钱来骗过银行,使得假币能够正常交易,而银行相当于判别器,需要判断进来的钱是真钱还是假币。因此假币团伙的目的是要造出银行识别不出的假币而骗过银行,银行则是要想办法准确地识别出假币。
因此,我们可以将上面的内容进行一个总结。给定真 = 1,假 = 0,那么有:
对于给定的真实图片(real image),判别器要为其打上标签 1;
对于给定的生成图片(fake image),判别器要为其打上标签 0;
对于生成器传给辨别器的生成图片,生成器希望辨别器打上标签 1。
有了上面的直观理解,下面就让我们来实现一个 GAN 来生成手写数据吧!还有一些细节会在代码部分进行介绍。
说明 建议将代码 pull 下来,有部分代码实现没有写在文章中。 代码部分 数据加载与查看
数据我们使用 TensorFlow 中给定的 MNIST 数据接口。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/8f74c5dca5fb79cb239f6c674d32e7ff.png" data-rawwidth="2054" data-rawheight="960" class="origin_image zh-lightbox-thumb" width="2054" data-original="https://pic1.zhimg.com/v2-2559f5ad3ea16c4648d48665eba1fd50_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/8f74c5dca5fb79cb239f6c674d32e7ff.png"/>
在构建模型之前,我们首先来看一下我们需要完成的任务:
Inputs
generator
discriminator
定义参数
loss & optimizer
训练模型
显示结果
输入 inputs
输入函数主要来定义真实图片与生成图片两个 tensor。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/c856a3b5b4b0894e93edc95888233e5a.png" data-rawwidth="1522" data-rawheight="330" class="origin_image zh-lightbox-thumb" width="1522" data-original="https://pic2.zhimg.com/v2-89760fdc1c3201fb6a7536b7cd6847e1_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/c856a3b5b4b0894e93edc95888233e5a.png"/>
定义生成器
我们的生成器结构如下:
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/4b73d5230f59a585207d05b5d598a67f.png" data-rawwidth="1366" data-rawheight="1286" class="origin_image zh-lightbox-thumb" width="1366" data-original="https://pic4.zhimg.com/v2-cc92aafbeafed55a98ccd9453d98e397_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/4b73d5230f59a585207d05b5d598a67f.png"/>
我们使用了一个采用 Leaky ReLU 作为激活函数的隐层,并在输出层加入 tanh 激活函数。
下面是生成器的代码。注意在定义生成器和判别器时,我们要指定变量的 scope,这是因为 GAN 中实际上包含生成器与辨别器两个网络,在后面进行训练时是分开训练的,因此我们要把 scope 定义好,方便训练时候指定变量。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/f056d06ea565c4529b9620c506554d05.png" data-rawwidth="1282" data-rawheight="798" class="origin_image zh-lightbox-thumb" width="1282" data-original="https://pic3.zhimg.com/v2-1f41bb3189a3316eaf6a675cc857c06e_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/f056d06ea565c4529b9620c506554d05.png"/>
在这个网络中,我们使用了一个隐层,并加入 dropout 防止过拟合。通过输入噪声图片,generator 输出一个与真实图片一样大小的图像。
在这里我们的隐层激活函数采用的是 Leaky ReLU(中文不知道咋翻译),这个函数在 ReLU 函数基础上改变了左半边的定义。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/8b9b7144ff764f597b03fe44e32379e1.png" data-rawwidth="2034" data-rawheight="576" class="origin_image zh-lightbox-thumb" width="2034" data-original="https://pic3.zhimg.com/v2-7b556b137a1a8e9115bd338d24d3bdb2_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/8b9b7144ff764f597b03fe44e32379e1.png"/>
图片来自维基百科。Andrej Karpathy 在 CS231n 中也提到有模型通过这个函数取得了不错的效果。 由于 TensorFlow 中没有这个函数的实现,在这里我们通过函数定义实现了 Leaky ReLU,其中 alpha 是一个很小的数。在输出层我们使用 tanh 函数,这是因为 tanh 在这里相比 sigmoid 的结果会更好一点(在这里要注意,由于生成器的生成图片像素限制在了 (-1, 1) 的取值之间,而 MNIST 数据集的像素区间为 [0, 1],所以在训练时要对 MNIST 的输入做处理,具体见训练部分的代码) 。到此,我们构建好了生成器,它通过接收一个噪声图片输出一个与真实图片一样 size 的图像。
定义判别器
判别器的结构如下:
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/6cd371f69ce00ccb78100ab4c3509aea.png" data-rawwidth="1202" data-rawheight="1184" class="origin_image zh-lightbox-thumb" width="1202" data-original="https://pic3.zhimg.com/v2-cab82713a90ab308b18941bde607e352_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/6cd371f69ce00ccb78100ab4c3509aea.png"/>
判别器接收一张图片,并判断它的真假,同样隐层使用了 Leaky ReLU,输出层为 1 个结点,输出为 1 的概率。代码如下:
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/890d6271a9bb0b0faa4b19336350bff6.png" data-rawwidth="1164" data-rawheight="664" class="origin_image zh-lightbox-thumb" width="1164" data-original="https://pic4.zhimg.com/v2-b7238f5ddd5711a8d3a544810d3ca053_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/890d6271a9bb0b0faa4b19336350bff6.png"/>
在这里,我们需要注意真实图片与生成图片是共享判别器的参数的,因此在这里我们留了 reuse 接口来方便我们后面调用。
定义参数
img_size 是我们真实图片的 size=32*32=784。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/4847194c337e1ff2ca4606afb9801618.png" data-rawwidth="904" data-rawheight="560" class="origin_image zh-lightbox-thumb" width="904" data-original="https://pic3.zhimg.com/v2-ab972524d8027f976261c63467a7a512_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/4847194c337e1ff2ca4606afb9801618.png"/>
smooth 是进行 Label Smoothing Regularization 的参数,在后面会介绍。 构建网络
接下来我们来构建我们的网络,并获得生成器与判别器返回的变量。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/754c4d0e14cc56fdd4c7b8b3c9aae6f0.png" data-rawwidth="1448" data-rawheight="398" class="origin_image zh-lightbox-thumb" width="1448" data-original="https://pic1.zhimg.com/v2-70e5cb23e9eb00bab74edd45f6b5c7d8_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/754c4d0e14cc56fdd4c7b8b3c9aae6f0.png"/>
我们分别获得了生成器与判别器的 logits 和 outputs。注意真实图片与生成图片是共享参数的,因此在判别器输入生成图片时,需要 reuse 参数。
定义 Loss 和 Optimizer
有了上面的 logits,我们就可以定义我们的 loss 和 Optimizer。在这之前,我们再来回顾一下生成器和判别器各自的目的是什么:
我们来把上面这三句话转换成代码:
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/5d463e523ceeab07f3fb2f03fa1033f5.png" data-rawwidth="2056" data-rawheight="504" class="origin_image zh-lightbox-thumb" width="2056" data-original="https://pic2.zhimg.com/v2-910e968a6d5d3acaf3b1670d2d4b92a9_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/5d463e523ceeab07f3fb2f03fa1033f5.png"/>
d_loss_real 对应着真实图片的 loss,它尽可能让判别器的输出接近于 1。在这里,我们使用了单边的 Label Smoothing Regularization,它是一种防止过拟合的方式,在传统的分类中,我们的目标非 0 即 1,从直觉上来理解的话,这样的目标不够 soft,会导致训练出的模型对于自己的预测结果过于自信。因此我们加入一个平滑值来让判别器的泛化效果更好。
d_loss_fake 对应着生成图片的 loss,它尽可能地让判别器输出为 0。
d_loss_real 与 d_loss_fake 加起来就是整个判别器的损失。
而在生成器端,它希望让判别器对自己生成的图片尽可能输出为 1,相当于它在于判别器进行对抗。
下面我们定义了优化函数,由于 GAN 中包含了生成器和判别器两个网络,因此需要分开进行优化,这也是我们在之前定义 variable_scope 的原因。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/d84e7cbd0d1c174092e7cfaa70e6c456.png" data-rawwidth="1482" data-rawheight="392" class="origin_image zh-lightbox-thumb" width="1482" data-original="https://pic3.zhimg.com/v2-f5befa264fa631be5696ce85dbde3a6e_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/d84e7cbd0d1c174092e7cfaa70e6c456.png"/>
训练模型
由于训练部分代码太长,我在这里就不贴出来了,请前往我的 GitHub 下载代码。在训练部分,我们记录了部分图像的生成过程,并记录了训练数据的 loss 变化。
我们将整个训练过程的 loss 变化绘制出来:
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/7a717f4f0e143c2887481b3e021b7631.png" data-rawwidth="1990" data-rawheight="738" class="origin_image zh-lightbox-thumb" width="1990" data-original="https://pic3.zhimg.com/v2-804e10e73615c14906fa567c6ef66166_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/7a717f4f0e143c2887481b3e021b7631.png"/>
从图中可以看出来,最终的判别器总体 loss 在 1 左右波动,而 real loss 和 fake loss 几乎在一条水平线上波动,这说明判别器最终对于真假图像已经没有判别能力,而是进行随机判断。
查看过程结果
我们在整个训练过程中记录了 25 个样本在不同阶段的 samples 图像,以序列化的方式进行了保存,我们的将 samples 加载进来。samples 的 size=epochs x 2 x n_samples x 784,我们的迭代次数为 300 轮,25 个样本,因此,samples 的 size=300 x 2 x 25 x 784。我们将最后一轮的生成结果打印出来:
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/18d7a377373add1fbbba5a098133e18d.png" data-rawwidth="822" data-rawheight="794" class="origin_image zh-lightbox-thumb" width="822" data-original="https://pic4.zhimg.com/v2-7ea77283fed8c649972f629b083c86cf_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/18d7a377373add1fbbba5a098133e18d.png"/>
这就是我们的 GAN 通过学习真实图片的分布后生成的图像结果。
那么有同学可能会问了,我们如果想要看这 300 轮中生成图像的变化是什么样该怎么办呢?因为我们已经有了 samples,存储了每一轮迭代的结果,我们可以挑选几次迭代,把对应的图像打出来:
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/6be01cba5bacfd2ae52e1fcdeb3452b6.png" data-rawwidth="1982" data-rawheight="792" class="origin_image zh-lightbox-thumb" width="1982" data-original="https://pic3.zhimg.com/v2-63d322a93048370ef9e7593fa4633bee_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/6be01cba5bacfd2ae52e1fcdeb3452b6.png"/>
这里我挑选了第 0, 5, 10, 20, 40, 60, 80, 100, 150, 250 轮的迭代效果图,在这个图中,我们可以看到最开始的时候只有中间是白色,背景黑色块中存在着很多噪声。随着迭代次数的不断增加,生成器制造 “假图” 的能力也越来越强,它逐渐学得了真实图片的分布,最明显的一点就是图片区分出了黑色背景和白色字符的界限。
生成新的图片
如果我们想重新生成新的图片呢?此时我们只需要将我们之前保存好的模型文件加载进来就可以啦。
<img src="https://static.leiphone.com/uploads/new/article/pic/201707/afa32b414c683d2229d667b26cbe1fed.png" data-rawwidth="1478" data-rawheight="1212" class="origin_image zh-lightbox-thumb" width="1478" data-original="https://pic3.zhimg.com/v2-66cf48d6bb5f1ab86544f833d0652746_r.png" _src="https://static.leiphone.com/uploads/new/article/pic/201707/afa32b414c683d2229d667b26cbe1fed.png"/>
总结 整篇文章基于 MNIST 数据集构造了一个简单的 GAN 模型,相信小伙伴看完代码会对 GAN 有一个初步的了解。从最终的模型结果来看,生成的图像能够将背景与数字区分开,黑色块噪声逐渐消失,但从显示结果来看还是有很多模糊区域的。
对于这里的图片处理,相信很多小伙伴会想到卷积神经网络,那么后面我们还会将生成器和判别器改为卷积神经网络来构造深度卷积 GAN,它对于图片的生成会取得更好的效果。
如果觉得不错,请给 GitHub 点个 Star 吧~