生成对抗网络(GAN)原理详解

1. 引言

生成对抗网络(Generative Adversarial Network, GAN)是一种强大的生成模型,由Ian Goodfellow等人于2014年提出。其核心思想是通过 生成器(Generator)判别器(Discriminator) 的对抗训练,学习数据分布并生成高质量样本。


2. GAN的基本结构

2.1 生成器(Generator)

  • 目标:将随机噪声zzz映射为真实数据分布pdata(x)p_{\text{data}}(x)pdata(x)
  • 输入:随机噪声zzz(通常从高斯分布或均匀分布中采样)。
  • 输出:生成样本G(z)G(z)G(z)

2.2 判别器(Discriminator)

  • 目标:区分真实数据xxx和生成数据G(z)G(z)G(z)
  • 输入:真实数据xxx或生成数据G(z)G(z)G(z)
  • 输出:概率值D(x)D(x)D(x)D(G(z))D(G(z))D(G(z)),表示输入为真实数据的概率。

3. GAN的数学原理

3.1 目标函数

GAN的训练过程可以看作一个极小极大博弈(Minimax Game),其目标函数为:
min⁡Gmax⁡DV(D,G)=Ex∼pdata(x)[log⁡D(x)]+Ez∼pz(z)[log⁡(1−D(G(z)))] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

  • 判别器的目标:最大化V(D,G)V(D, G)V(D,G),即正确区分真实数据和生成数据。
  • 生成器的目标:最小化V(D,G)V(D, G)V(D,G),即生成数据G(z)G(z)G(z)尽可能接近真实数据。

3.2 优化过程

  1. 固定生成器GGG,更新判别器DDD
    max⁡DV(D,G)=Ex∼pdata(x)[log⁡D(x)]+Ez∼pz(z)[log⁡(1−D(G(z)))] \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] DmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
  2. 固定判别器DDD,更新生成器GGG
    min⁡GV(D,G)=Ez∼pz(z)[log⁡(1−D(G(z)))] \min_G V(D, G) = \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] GminV(D,G)=Ezpz(z)[log(1D(G(z)))]

3.3 梯度更新

  • 判别器梯度
    ∇θd1m∑i=1m[log⁡D(x(i))+log⁡(1−D(G(z(i))))] \nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^m [\log D(x^{(i)}) + \log(1 - D(G(z^{(i)})))] θdm1i=1m[logD(x(i))+log(1D(G(z(i))))]
  • 生成器梯度
    ∇θg1m∑i=1mlog⁡(1−D(G(z(i)))) \nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^m \log(1 - D(G(z^{(i)}))) θgm1i=1mlog(1D(G(z(i))))

4. GAN的训练过程

4.1 初始化

  • 初始化生成器GGG和判别器DDD的参数。

4.2 迭代训练

  1. 采样
    • 从真实数据分布pdata(x)p_{\text{data}}(x)pdata(x)中采样mmm个样本{x(1),...,x(m)}\{x^{(1)}, ..., x^{(m)}\}{x(1),...,x(m)}
    • 从噪声分布pz(z)p_z(z)pz(z)中采样mmm个噪声{z(1),...,z(m)}\{z^{(1)}, ..., z^{(m)}\}{z(1),...,z(m)}
  2. 更新判别器
    • 计算判别器梯度并更新参数θd\theta_dθd
  3. 更新生成器
    • 计算生成器梯度并更新参数θg\theta_gθg
  4. 重复:直到生成器和判别器达到平衡。

5. GAN的变体

5.1 DCGAN(Deep Convolutional GAN)

  • 使用卷积神经网络作为生成器和判别器。
  • 引入批量归一化(Batch Normalization)和LeakyReLU激活函数。

5.2 WGAN(Wasserstein GAN)

  • 使用Wasserstein距离作为损失函数,解决训练不稳定的问题。
  • 目标函数:
    min⁡Gmax⁡D∈DEx∼pdata(x)[D(x)]−Ez∼pz(z)[D(G(z))] \min_G \max_{D \in \mathcal{D}} \mathbb{E}_{x \sim p_{\text{data}}(x)}[D(x)] - \mathbb{E}_{z \sim p_z(z)}[D(G(z))] GminDDmaxExpdata(x)[D(x)]Ezpz(z)[D(G(z))]
    其中D\mathcal{D}D为1-Lipschitz函数空间。

5.3 Conditional GAN(CGAN)

  • 在生成器和判别器中引入条件信息yyy(如类别标签)。
  • 目标函数:
    min⁡Gmax⁡DV(D,G)=Ex∼pdata(x)[log⁡D(x∣y)]+Ez∼pz(z)[log⁡(1−D(G(z∣y)))] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x|y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z|y)))] GminDmaxV(D,G)=Expdata(x)[logD(xy)]+Ezpz(z)[log(1D(G(zy)))]

6. 实例分析:图像生成

6.1 任务描述

生成手写数字图像(如MNIST数据集)。

6.2 处理流程

  1. 生成器
    • 输入:随机噪声zzz(如100维向量)。
    • 输出:28×2828 \times 2828×28的手写数字图像。
  2. 判别器
    • 输入:28×2828 \times 2828×28的图像。
    • 输出:图像为真实数据的概率。

6.3 训练结果

  • 初始阶段:生成器生成噪声图像,判别器容易区分。
  • 训练中期:生成器生成模糊但可辨认的数字。
  • 训练后期:生成器生成高质量的手写数字图像。

7. 数学附录

7.1 最优判别器

当生成器固定时,最优判别器为:
D∗(x)=pdata(x)pdata(x)+pg(x) D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)} D(x)=pdata(x)+pg(x)pdata(x)

7.2 全局最优解

当且仅当pg(x)=pdata(x)p_g(x) = p_{\text{data}}(x)pg(x)=pdata(x)时,GAN达到全局最优解,此时判别器无法区分真实数据和生成数据:
D∗(x)=12 D^*(x) = \frac{1}{2} D(x)=21


8. 总结

  • GAN通过生成器和判别器的对抗训练,能够生成高质量的数据样本。
  • GAN的训练过程是一个极小极大博弈,需要平衡生成器和判别器的能力。
  • GAN的变体(如DCGAN、WGAN、CGAN)进一步提升了模型的稳定性和生成质量。
  • GAN在图像生成、图像修复、风格迁移等任务中表现出色,是生成模型领域的重要突破。
Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐