GAN 生成图像入坑

50次阅读
没有评论

共计 13564 个字符,预计需要花费 34 分钟才能阅读完成。

下面是一个 完整、可运行的 GAN 示例,分别适用于 MNIST(手写数字)和 CelebA(人脸)数据集。我们将使用 PyTorch 实现一个简化版的 DCGAN(Deep Convolutional GAN),这是训练稳定、结构清晰的经典 GAN 架构。


🧪 一、准备工作

安装依赖(如未安装)

pip install torch torchvision matplotlib tqdm

🖼️ 二、GAN 生成 MNIST 图像(简单入门)

1. 导入库

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

2. 超参数设置

# 设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 超参数
batch_size = 128
nz = 100          # 噪声维度
lr = 0.0002
epochs = 20
image_size = 28   # MNIST 是 28x28
nc = 1            # 通道数(灰度图)

3. 数据加载(MNIST)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # [-1, 1]
])

dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

4. 构建生成器(Generator)

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(# 输入: (batch, nz, 1, 1)
            nn.ConvTranspose2d(nz, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # (256, 7, 7)
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (128, 14, 14)
            nn.ConvTranspose2d(128, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # 输出: (nc, 28, 28)
        )

    def forward(self, x):
        return self.main(x)

5. 构建判别器(Discriminator)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(# 输入: (nc, 28, 28)
            nn.Conv2d(nc, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (128, 14, 14)
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # (256, 7, 7)
            nn.Conv2d(256, 1, 7, 1, 0, bias=False),
            nn.Sigmoid()
            # 输出: (1, 1, 1)
        )

    def forward(self, x):
        return self.main(x).view(-1, 1).squeeze(1)

6. 初始化模型与优化器

netG = Generator().to(device)
netD = Discriminator().to(device)

criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))

# 固定噪声用于可视化
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

7. 训练循环

real_label = 1.0
fake_label = 0.0

for epoch in range(epochs):
    for i, (data, _) in enumerate(dataloader):
        ############################
        # (1) 更新判别器 D
        ############################
        netD.zero_grad()
        real = data.to(device)
        b_size = real.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        output = netD(real)
        errD_real = criterion(output, label)
        errD_real.backward()

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        optimizerD.step()

        ############################
        # (2) 更新生成器 G
        ############################
        netG.zero_grad()
        label.fill_(real_label)  # trick: 让 G 欺骗 D
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f'[{epoch}/{epochs}][{i}/{len(dataloader)}]'
                  f'Loss_D: {errD_real + errD_fake:.4f}, Loss_G: {errG:.4f}')

    # 每个 epoch 保存生成图像
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    grid = torchvision.utils.make_grid(fake, normalize=True, nrow=8)
    plt.figure(figsize=(8,8))
    plt.imshow(np.transpose(grid, (1,2,0)))
    plt.axis('off')
    plt.savefig(f'mnist_epoch_{epoch}.png')
    plt.close()

结果:你会看到 mnist_epoch_*.png 文件,随着训练进行,生成的数字越来越清晰。


👤 三、GAN 生成 CelebA 人脸图像(进阶)

注意:CelebA 数据集较大(约 1.4GB),需先下载。

1. 数据准备(CelebA)

# 下载 CelebA(首次运行会自动下载)# 或手动下载到 ./data/celeba 并解压

transform = transforms.Compose([
    transforms.CenterCrop(138),
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # RGB, [-1,1]
])

dataset = torchvision.datasets.CelebA(root='./data', split='train', download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)

2. 修改超参数

nc = 3        # RGB
image_size = 64
nz = 100
lr = 0.0002
epochs = 30   # 建议至少 30 轮

3. 调整网络结构(适配 64×64)

生成器

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(nn.ConvTranspose2d(nz, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 4x4
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 8x8
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 16x16
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 32x32
            nn.ConvTranspose2d(64, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # 64x64
        )
    def forward(self, x):
        return self.main(x)

判别器

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(nn.Conv2d(nc, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())
    def forward(self, x):
        return self.main(x).view(-1, 1).squeeze(1)

其余训练代码与 MNIST 完全相同(只需替换 dataloader 和网络)。

结果:训练 30 轮后,你会看到模糊但可辨认的人脸;50+ 轮后效果更佳。

💡 提示:使用 GPU(如 Colab)可显著加速训练。CelebA 训练一轮约需 5~10 分钟(Tesla T4)。


🛠️ 四、常见问题与改进建议

问题 解决方案
模式崩溃(只生成一种数字 / 人脸) 使用 WGAN-GP、增加噪声、调整学习率
生成图像模糊 使用 StyleGAN增加训练轮数
训练不稳定 使用 BatchNorm(G)+ LeakyReLU(D),学习率设为 0.0002,betas=(0.5, 0.999)
CelebA 下载慢 手动从 Kaggle 下载

📚 五、进一步探索

  • 尝试 WGANLSGAN 改进损失函数
  • 使用 StyleGAN2 生成高清人脸(推荐 NVIDIA 官方 repo
  • 在生成器中加入 条件标签(cGAN),生成指定数字或属性人脸

附录一

改进 GAN 的损失函数是解决传统 GAN 训练过程中出现的一些问题(如模式崩溃、训练不稳定等)的有效方法。WGAN(Wasserstein GAN)和 LSGAN(Least Squares GAN)就是两种常用的改进方案。接下来,我将分别介绍如何在上述代码基础上实现 WGAN 和 LSGAN。

一、Wasserstein GAN (WGAN)

WGAN 主要通过使用 Wasserstein 距离代替传统的 JS 散度来衡量生成分布与真实分布之间的差异。为了保证判别器输出满足 Lipschitz 连续条件,通常会对判别器的参数进行裁剪或采用梯度惩罚(WGAN-GP)的方法。

实现步骤

  1. 移除 Sigmoid 层:WGAN 的判别器输出不需要经过 Sigmoid 函数。
  2. 修改损失函数:对于判别器,最大化真实样本评分与生成样本评分之差;对于生成器,最小化生成样本的评分。
  3. 权重裁剪(可选) 或者 梯度惩罚(推荐):为确保判别器满足 Lipschitz 条件。

这里以 WGAN-GP 为例:

# 在 Discriminator 类中移除 Sigmoid
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # ...(其他层保持不变)nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            # 移除了 Sigmoid
        )
    def forward(self, x):
        return self.main(x).view(-1)

然后,在训练循环中加入梯度惩罚部分:

def compute_gradient_penalty(D, real_samples, fake_samples):
    """计算梯度惩罚"""
    alpha = torch.rand(real_samples.size(0), 1, 1, 1).to(device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = torch.ones(real_samples.size(0), 1).requires_grad_(False).to(device)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# 在训练循环中:gradient_penalty = compute_gradient_penalty(netD, real.data, fake.data)
lambda_gp = 10  # 梯度惩罚系数
errD = -torch.mean(netD(real)) + torch.mean(netD(fake.detach())) + lambda_gp * gradient_penalty

二、Least Squares GAN (LSGAN)

LSGAN 使用最小二乘法作为损失函数,旨在使生成分布更接近真实分布,并且能够提供比原始 GAN 更平滑的损失表面。

实现步骤

  • 修改判别器和生成器的损失函数。LSGAN 的目标是最小化平方误差。
# 在训练循环中替换原来的 BCELoss 部分
a, b, c = -1, 1, 0  # 参数 a, b, c 根据论文设置

for epoch in range(epochs):
    for i, (data, _) in enumerate(dataloader):
        real = data.to(device)
        b_size = real.size(0)

        # 更新 D
        netD.zero_grad()
        output_real = netD(real).view(-1)
        errD_real = 0.5 * torch.mean((output_real - b) ** 2)

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = netG(noise)
        output_fake = netD(fake.detach()).view(-1)
        errD_fake = 0.5 * torch.mean((output_fake - a) ** 2)

        errD = errD_real + errD_fake
        errD.backward()
        optimizerD.step()

        # 更新 G
        netG.zero_grad()
        output_fake = netD(fake).view(-1)
        errG = 0.5 * torch.mean((output_fake - c) ** 2)
        errG.backward()
        optimizerG.step()

以上便是对 WGAN 和 LSGAN 的基本实现方法。根据实际需要选择适合的模型结构和优化策略,可以显著提升生成图像的质量和训练稳定性。

附录二

在生成器中加入 条件标签(Conditional GAN, cGAN),可以让模型根据指定类别(如 MNIST 的数字 0–9,或 CelebA 的“戴眼镜”“微笑”等属性)生成对应图像。这是实现 可控图像生成 的关键技术。


🧠 一、cGAN 原理简述

  • 传统 GAN:输入噪声 $z$ → 生成图像 $G(z)$
  • cGAN:输入噪声 $z$ + 条件标签 $y$ → 生成图像 $G(z, y)$
  • 判别器也接收条件 $y$:判断 $(x, y)$ 是否为真实配对

损失函数(BCE 版):

$$L_D = -E_x,y[log D(x, y)] – E_z,y[log(1 – D(G(z,y), y))]$$

$$L_G = -E_z,y[log D(G(z,y), y)]$$


🖼️ 二、cGAN 生成指定 MNIST 数字

1. 修改数据加载(保留标签)

# MNIST 数据集自带标签(0-9)dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

2. 构建条件生成器(Generator)

我们将 标签嵌入为向量,并与噪声拼接后输入网络。

class ConditionalGenerator(nn.Module):
    def __init__(self, nz=100, num_classes=10, embed_dim=10):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)  # 标签嵌入
        self.nz = nz
        self.embed_dim = embed_dim

        self.main = nn.Sequential(nn.ConvTranspose2d(nz + embed_dim, 256, 7, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 1, 4, 2, 1, bias=False),
            nn.Tanh())

    def forward(self, z, labels):
        # z: (B, nz, 1, 1)
        # labels: (B,)
        label_emb = self.embed(labels)  # (B, embed_dim)
        label_emb = label_emb.view(label_emb.size(0), label_emb.size(1), 1, 1)
        x = torch.cat([z, label_emb], dim=1)  # (B, nz+embed_dim, 1, 1)
        return self.main(x)

3. 构建条件判别器(Discriminator)

判别器也需要接收标签信息。我们将标签嵌入后与图像特征拼接。

class ConditionalDiscriminator(nn.Module):
    def __init__(self, num_classes=10, embed_dim=10):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim * 28 * 28)
        self.embed_dim = embed_dim

        self.main = nn.Sequential(nn.Conv2d(1 + embed_dim, 128, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 7, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, x, labels):
        # x: (B, 1, 28, 28)
        # labels: (B,)
        label_emb = self.embed(labels)  # (B, embed_dim*28*28)
        label_emb = label_emb.view(label_emb.size(0), self.embed_dim, 28, 28)
        x = torch.cat([x, label_emb], dim=1)  # (B, 1+embed_dim, 28, 28)
        return self.main(x).view(-1, 1).squeeze(1)

4. 训练循环(关键修改)

netG = ConditionalGenerator(nz=100, num_classes=10).to(device)
netD = ConditionalDiscriminator(num_classes=10).to(device)

criterion = nn.BCELoss()
optimizerG = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))

fixed_noise = torch.randn(10, 100, 1, 1, device=device)  # 10 个噪声
fixed_labels = torch.arange(10, dtype=torch.long, device=device)  # 0~9

for epoch in range(epochs):
    for i, (data, labels) in enumerate(dataloader):
        real = data.to(device)
        labels = labels.to(device)
        b_size = real.size(0)

        # Train Discriminator
        netD.zero_grad()
        label_real = torch.ones(b_size, device=device)
        output = netD(real, labels)
        errD_real = criterion(output, label_real)
        errD_real.backward()

        noise = torch.randn(b_size, 100, 1, 1, device=device)
        fake = netG(noise, labels)
        label_fake = torch.zeros(b_size, device=device)
        output = netD(fake.detach(), labels)
        errD_fake = criterion(output, label_fake)
        errD_fake.backward()
        optimizerD.step()

        # Train Generator
        netG.zero_grad()
        output = netD(fake, labels)
        errG = criterion(output, label_real)
        errG.backward()
        optimizerG.step()

    # 生成 0~9 各一张
    with torch.no_grad():
        fake = netG(fixed_noise, fixed_labels).detach().cpu()
    grid = torchvision.utils.make_grid(fake, nrow=10, normalize=True)
    plt.figure(figsize=(10, 1))
    plt.imshow(np.transpose(grid, (1, 2, 0)), cmap='gray')
    plt.axis('off')
    plt.savefig(f'cgan_mnist_epoch_{epoch}.png')
    plt.close()

结果:每张图从左到右依次是 0,1,2,…,9,清晰可控!


👤 三、cGAN 生成带属性的 CelebA 人脸

CelebA 提供 40 个二值属性(如 Eyeglasses, Smiling, Male 等)。我们将使用其中几个属性作为条件。

1. 数据加载(带属性标签)

# CelebA 属性:dataset.attr 是 [N, 40] 的 0/1 张量
# 我们选取前 2 个属性:0=5_o_Clock_Shadow, 1=Arched_Eyebrows, ..., 15=Male, 31=Smiling

selected_attrs = [15, 31]  # Male, Smiling
transform = transforms.Compose([
    transforms.CenterCrop(138),
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = torchvision.datasets.CelebA(root='./data', split='train', download=True, transform=transform)

# 自定义 Dataset 包装器(只取选定属性)class CelebAWithAttrs(torch.utils.data.Dataset):
    def __init__(self, dataset, attr_indices):
        self.dataset = dataset
        self.attr_indices = attr_indices
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        attrs = self.dataset.attr[idx][self.attr_indices].float()  # (2,)
        return img, attrs

attr_dataset = CelebAWithAttrs(dataset, selected_attrs)
dataloader = DataLoader(attr_dataset, batch_size=128, shuffle=True, num_workers=2)

2. 条件生成器(使用属性向量)

class ConditionalGeneratorCelebA(nn.Module):
    def __init__(self, nz=100, num_attrs=2):
        super().__init__()
        self.nz = nz
        self.num_attrs = num_attrs

        self.main = nn.Sequential(nn.ConvTranspose2d(nz + num_attrs, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh())

    def forward(self, z, attrs):
        # z: (B, nz, 1, 1)
        # attrs: (B, num_attrs)
        attrs = attrs.view(attrs.size(0), self.num_attrs, 1, 1)
        x = torch.cat([z, attrs], dim=1)  # (B, nz+num_attrs, 1, 1)
        return self.main(x)

3. 条件判别器

class ConditionalDiscriminatorCelebA(nn.Module):
    def __init__(self, num_attrs=2):
        super().__init__()
        self.num_attrs = num_attrs
        self.attr_proj = nn.Linear(num_attrs, 64*64)  # 投影到图像尺寸

        self.main = nn.Sequential(nn.Conv2d(3 + 1, 64, 4, 2, 1, bias=False),  # +1 通道用于属性
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, 1, 0, bias=False),
            nn.Sigmoid())

    def forward(self, x, attrs):
        # x: (B, 3, 64, 64)
        # attrs: (B, num_attrs)
        attr_map = self.attr_proj(attrs)  # (B, 64*64)
        attr_map = attr_map.view(-1, 1, 64, 64)  # (B, 1, 64, 64)
        x = torch.cat([x, attr_map], dim=1)  # (B, 4, 64, 64)
        return self.main(x).view(-1, 1).squeeze(1)

4. 固定条件生成示例

# 生成 4 种组合:[Male=0, Smiling=0], [0,1], [1,0], [1,1]
fixed_noise = torch.randn(4, 100, 1, 1, device=device)
fixed_attrs = torch.tensor([
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0]
], device=device)

# 训练后生成
with torch.no_grad():
    fake = netG(fixed_noise, fixed_attrs).detach().cpu()
grid = torchvision.utils.make_grid(fake, nrow=2, normalize=True)
plt.figure(figsize=(6,6))
plt.imshow(np.transpose(grid, (1,2,0)))
plt.axis('off')
plt.savefig('celeba_cgan_samples.png')

结果:你会看到 2×2 网格,分别对应:

  • 女性不笑
  • 女性微笑
  • 男性不笑
  • 男性微笑

🔧 四、注意事项

  1. 标签对齐:确保生成器和判别器使用相同的条件表示。
  2. 属性平衡:CelebA 中某些属性(如 Wearing_Hat)样本极少,建议选择常见属性。
  3. 嵌入维度:对于多类别(如 1000 类 ImageNet),建议用 Embedding;对于二值属性,直接拼接即可。
  4. 训练技巧
  • 使用 WGAN-GP + cGAN 可进一步提升稳定性
  • 对属性做 数据增强(如随机翻转时同步翻转 Male 属性)

正文完
 0
一诺
版权声明:本站原创文章,由 一诺 于2025-09-17发表,共计13564字。
转载说明:除特殊说明外本站文章皆由CC-4.0协议发布,转载请注明出处。
评论(没有评论)
验证码