GAN|GAN(生成对抗网络)学习——pytorch实现MINIST手写数据集生成


目录

  • 一、生成对抗网络原理
  • 二、GAN网络结构详解
  • 三、生成器与判别器训练思路
  • 四、实现代码

一、生成对抗网络原理 生成对抗网络,是一种基于博弈思想的网络训练思路,其主要网络模块由两部分组成,分别为generator(生成器)和discriminator(判别器)。
我们以GAM生成Minist手写数据集为例,在这个例子中,我们的目的是为了生成可以以假乱真的手写数字图片。而我们的训练思路,是使用生成器来产生一张照片,并且由判别器来判断这张照片是否是真实的照片。
在这个过程中,生成器会根据判别器返回的结果,一步一步的学习,生成器学习的目的是产生可以骗过判别器的照片,而判别器在这个过程中也会不断的进行学习,其学习的目的是,可以正确的分别出假图片和真图片。
二、GAN网络结构详解
  • 生成器定义
    生成器的输入是一组噪声,网络中包含着全连接层和激活函数,最终会生成一张大小为28x28(784)的图片。
#生成网络 def generator(noise_dim=NOISE_DIM): net = nn.Sequential( nn.Linear(noise_dim, 1024), nn.ReLU(True), nn.Linear(1024, 1024), nn.ReLU(True), nn.Linear(1024, 784), #最终输出大小为784 nn.Tanh() ) return net

  • 判别器定义
    判别器的输入是一张28x28的图片,不过在输入到判别器网络之前,会先将其展开成28x28 = 784的一维向量。
    这个一维向量,可以是真实的图片(从MINIST数据集取出的图片),也可以是生成器生成的图片。
# 判别网络 def discriminator(): net = nn.Sequential( nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1) #输出结果为置信度 ) return net

三、生成器与判别器训练思路 【GAN|GAN(生成对抗网络)学习——pytorch实现MINIST手写数据集生成】按照生成对抗网络的基本原理,我们对生成器和判别器的损失函数进行定义。
  • 判别器损失函数
    判别器的损失函数分为两部分
    第一步是要能准确的将正确的图片识别为正确
    第二步是要能准确的将错误的图片识别为错误
    因此,需要为真正的图片和生成的图片分别生成标签
    正确的图片标签为1,生成的图片标签为0
    因此生成器的损失函数即为两部分的损失函数的和
    在这里损失计算函数采用交叉熵的计算方式
    即只需要关注是否分类正确来进行损失计算,而忽略分类错误的损失部分
real_data = https://www.it610.com/article/Variable(x).view(bs, -1)# 真实数据 logits_real = D_net(real_data)# 判别网络得分 sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5# -1 ~ 1 的均匀分布 g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed)# 生成的假的数据 logits_fake = D_net(fake_images)# 判别网络得分 d_total_error = discriminator_loss(logits_real, logits_fake)# 判别器的 loss

def discriminator_loss(logits_real, logits_fake):# 判别器的 loss size = logits_real.shape[0] true_labels = Variable(torch.ones(size, 1)).float() false_labels = Variable(torch.zeros(size, 1)).float() loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels) return loss

  • 生成器损失函数
    与判别器不同,生成器的目的是要生成出可以骗过判别器的假图片
    因此生成器所生成的图片在经过判别器判别后的结果,需要最大限度的接近真实
    因此需要将判别后的结果,与正确的标签进行损失函数计算
    这里损失函数也是采用交叉熵的方式
# 生成网络 g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed)# 生成的假的数据gen_logits_fake = D_net(fake_images) g_error = generator_loss(gen_logits_fake)# 生成网络的 loss

def generator_loss(logits_fake):# 生成器的 loss size = logits_fake.shape[0] true_labels = Variable(torch.ones(size, 1)).float() loss = bce_loss(logits_fake, true_labels) return loss

四、实现代码
import torch from torch import nn from torch.autograd import Variableimport torchvision.transforms as tfs from torch.utils.data import DataLoader, sampler from torchvision.datasets import MNISTimport numpy as npimport matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import os plt.rcParams['figure.figsize'] = (10.0, 8.0)# 设置画图的尺寸 plt.rcParams['image.interpolation'] = 'nearest' plt.rcParams['image.cmap'] = 'gray'def show_images(images):# 定义画图工具 images = np.reshape(images, [images.shape[0], -1]) sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))fig = plt.figure(figsize=(sqrtn, sqrtn)) gs = gridspec.GridSpec(sqrtn, sqrtn) gs.update(wspace=0.05, hspace=0.05)for i, img in enumerate(images): ax = plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(img.reshape([sqrtimg, sqrtimg])) returndef preprocess_img(x): x = tfs.ToTensor()(x) return (x - 0.5) / 0.5def deprocess_img(x): return (x + 1.0) / 2.0class ChunkSampler(sampler.Sampler):# 定义一个取样的函数 """Samples elements sequentially from some offset. Arguments: num_samples: # of desired datapoints start: offset where we should start selecting from """def __init__(self, num_samples, start=0): self.num_samples = num_samples self.start = startdef __iter__(self): return iter(range(self.start, self.start + self.num_samples))def __len__(self): return self.num_samplesNUM_TRAIN = 50000 NUM_VAL = 5000NOISE_DIM = 96 batch_size = 128train_set = MNIST('./data', train=True, transform=preprocess_img,download=True)train_data = https://www.it610.com/article/DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))val_set = MNIST('./data', train=True, transform=preprocess_img,download=True)val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze()# 可视化图片效果 show_images(imgs)# 判别网络 def discriminator(): net = nn.Sequential( nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1) ) return net# 生成网络 def generator(noise_dim=NOISE_DIM): net = nn.Sequential( nn.Linear(noise_dim, 1024), nn.ReLU(True), nn.Linear(1024, 1024), nn.ReLU(True), nn.Linear(1024, 784), nn.Tanh() ) return net# 判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1bce_loss = nn.BCEWithLogitsLoss()# 交叉熵损失函数def discriminator_loss(logits_real, logits_fake):# 判别器的 loss size = logits_real.shape[0] true_labels = Variable(torch.ones(size, 1)).float() false_labels = Variable(torch.zeros(size, 1)).float() loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels) return lossdef generator_loss(logits_fake):# 生成器的 loss size = logits_fake.shape[0] true_labels = Variable(torch.ones(size, 1)).float() loss = bce_loss(logits_fake, true_labels) return loss# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999 def get_optimizer(net): optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999)) return optimizerdef train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, noise_size=96, num_epochs=10): iter_count = 0 for epoch in range(num_epochs): for x, _ in train_data: bs = x.shape[0] # 判别网络 real_data = Variable(x).view(bs, -1)# 真实数据 logits_real = D_net(real_data)# 判别网络得分sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5# -1 ~ 1 的均匀分布 g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed)# 生成的假的数据 logits_fake = D_net(fake_images)# 判别网络得分d_total_error = discriminator_loss(logits_real, logits_fake)# 判别器的 loss D_optimizer.zero_grad() d_total_error.backward() D_optimizer.step()# 优化判别网络# 生成网络 g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed)# 生成的假的数据gen_logits_fake = D_net(fake_images) g_error = generator_loss(gen_logits_fake)# 生成网络的 loss G_optimizer.zero_grad() g_error.backward() G_optimizer.step()# 优化生成网络if (iter_count % show_every == 0): print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item())) imgs_numpy = deprocess_img(fake_images.data.cpu().numpy()) show_images(imgs_numpy[0:16]) plt.show() print() iter_count += 1D = discriminator() G = generator()D_optim = get_optimizer(D) G_optim = get_optimizer(G)train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)

    推荐阅读