人工智能|机器学习——从0开始构建自己的GAN网络

目录
一 前言
二 生成式对抗网络GAN
三 GAN的训练思路
四 数据集——Chinese MNIST
五 代码——python
1.文件展示
2.代码(一) ——数据预处理
3.代码(二) ——生成器的构建
4.代码(三) ——判别器的构建
5.代码(四) ——图像的储存
6.代码(五) ——网络的训练
7.代码(六) ——网络参数的定义
8.完整代码
六 运行效果
总结
【人工智能|机器学习——从0开始构建自己的GAN网络】
一 前言 本文仅作为经验分享以及学习记录,如有问题,可以在评论区和我讨论。
具体的理论知识暂且不讲,待有时间了我就会慢慢分享理论知识,目前就整点干货,直接上代码,怎么从零开始构建自己的GAN网络。
本项目Github地址
二 生成式对抗网络GAN 生成对抗网络(GAN)有两个部分:生成网络G(Generator)和判别网络D(Discriminator)。
(1)生成网络G:用来生成图片的网络,它接收一个随机的噪声noise,通过这个噪声生成图片。
(2)判别网络D:用来判别图片是否真实的网络。它的输入是一张图片img,输出是img为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片。
生成网络G的目的是努力生成一个图片来骗过判别网络D,判别网络D的目的是努力鉴别出生成出来的图片是假的。两个网络在不断博弈中互相进步,达到理想状态:D(G(noise))=0.5(即判别网络D也不确定是到底是不是真实的)
三 GAN的训练思路 GAN的训练要同时训练两个网络,我们使用的方法是:单独交替迭代训练(即训练一个网络的时候,固定住一个网络,去训练另一个网络)
这样做的目的是防止其中一个网络比另一个网络强大太多,导致网络性能弱化。在整个训练过程中,两个网络不断变强,达到理想状态。
四 数据集——Chinese MNIST 我的数据集选的是Kaggle网站上的Chinese MNIST,下载地址
下载速度慢的可以参考我的另一篇博客——解决Kaggle网站下载数据集速度慢,不方便下载的可以联系我发给你压缩包。
数据集举例:
人工智能|机器学习——从0开始构建自己的GAN网络
文章图片

五 代码——python 如果需要替换成自己数据集,我会在每部分代码首部进行特别说明。
这里我们直接开始,直接上代码,通过代码,一方面有助于我梳理本次学习思路,二是我觉得这样更直接明了一些,毕竟动手才有趣。
1.文件展示 人工智能|机器学习——从0开始构建自己的GAN网络
文章图片

2.代码(一) ——数据预处理 这部分函数用来加载path路径下的文件,即我的数据集,也可以根据你们需求换成别的数据集。只需要更改自己的数据集文件夹即可。

def load_data(self, path): print("loading images...") data = https://www.it610.com/article/[] labels = [] imagePaths = sorted(list(paths.list_images(path))) random.seed(42) random.shuffle(imagePaths) for imagePath in imagePaths: image = cv2.imread(imagePath) image = cv2.resize(image, (self.img_rows, self.img_cols)) image = img_to_array(image) data.append(image)label = str(imagePath.split(os.path.sep)[-2]) labels.append(label)data = np.array(data, dtype="float") / 255.0 return data

3.代码(二) ——生成器的构建 这部分代码用来构建生成网络G,不需要更改,尽管网络性能不是很好,但不是必须修改的。
# 构建生成器 def build_generator(self): model = Sequential()# 模型选用的是传统的线性模型 model.add(Dense(256, input_dim=self.latent_dim))# 全连接层 model.add(LeakyReLU(alpha=0.2))# 带泄露修正线性单元 model.add(BatchNormalization(momentum=0.8))# 批归一化 model.add(Dense(512)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(1024)) model.add(LeakyReLU(alpha=0.2)) model.add(BatchNormalization(momentum=0.8)) model.add(Dense(np.prod(self.img_shape), activation='tanh'))# np.prod()计算所有乘积,输入 model.add(Reshape(self.img_shape))# reshape成图片的尺寸# model.summary()# 日志noise = Input(shape=(self.latent_dim,)) img = model(noise)return Model(noise, img)

4.代码(三) ——判别器的构建 这部分代码用来构建判别网络D,不需要更改,尽管网络性能不是很好,但不是必须修改的。
# 构建判别器 def build_discriminator(self): # 模型选用的是传统的线性模型,CNN中用的也是这个 model = Sequential()model.add(Flatten(input_shape=self.img_shape))# 展平层 model.add(Dense(512))# 全连接层 model.add(LeakyReLU(alpha=0.2))# 带泄露修正线性单元 model.add(Dense(256)) model.add(LeakyReLU(alpha=0.2)) model.add(Dense(1, activation='sigmoid')) # model.summary()img = Input(shape=self.img_shape)# 输入尺寸 validity = model(img)return Model(img, validity)

5.代码(四) ——图像的储存 这部分代码用来储存生成网络不同epoch的输出,不必更改。
def sample_images(self, epoch): r, c = 5, 5 noise = np.random.normal(0, 1, (r * c, self.latent_dim)) gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c) cnt = 0 for i in range(r): for j in range(c): axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray') axs[i, j].axis('off') cnt += 1 # 保存地址为:"images/" fig.savefig("images/%d.png" % epoch) plt.close()

6.代码(五) ——网络的训练 这部分代码用来训练网络,不必更改。
def train(self, epochs, batch_size=128, sample_interval=50, file_path=None):# 加载数据 X_train = self.load_data(file_path) # 标准化 # X_train = np.expand_dims(X_train, axis=3)# 创建标签 valid = np.ones((batch_size, 1)) fake = np.zeros((batch_size, 1))for epoch in range(epochs): idx = np.random.randint(0, X_train.shape[0], batch_size) imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim)) gen_imgs = self.generator.predict(noise)d_loss_real = self.discriminator.train_on_batch(imgs, valid) d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake) d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)noise = np.random.normal(0, 1, (batch_size, self.latent_dim))g_loss = self.combined.train_on_batch(noise, valid)if epoch % 200 == 0: print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))# 图像的保存,每sample_interval次保存图片一次 if epoch % sample_interval == 0: self.sample_images(epoch) # 模型权重的保存,每2000个epoch,保存一次模型,保存地址为"weights/" if epoch % 2000 == 0: os.makedirs('weights', exist_ok=True) self.generator.save_weights("weights/gen_epoch%d.h5" % epoch) self.discriminator.save_weights("weights/dis_epoch%d.h5" % epoch)

7.代码(六) ——网络参数的定义 这部分代码定义了网络的一些参数,比如输入尺寸(我的数据集图片大小是[64,64,3]),优化器等等。
需要根据自己的数据集图片的大小,更改self.img_rows、self.img_cols、self.channels
def __init__(self): # 图片尺寸 在这里更改!!!! self.img_rows = 64 self.img_cols = 64 self.channels = 3 # 输入的图片尺寸 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100# Adam优化器 optimizer = Adam(0.0002, 0.5)self.discriminator = self.build_discriminator() self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])self.generator = self.build_generator()z = Input(shape=(self.latent_dim,)) img = self.generator(z)self.discriminator.trainable = Falsevalidity = self.discriminator(img)self.combined = Model(z, validity) self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

8.完整代码 需要根据自己需求,更改
epochs训练次数;batch_size每组的数量;sample_interval多少次输出一张图片;file_path数据集路径
完整代码我已上传到github中,代码地址
六 运行效果 迭代0次
人工智能|机器学习——从0开始构建自己的GAN网络
文章图片


迭代10000次
人工智能|机器学习——从0开始构建自己的GAN网络
文章图片

迭代30000次
人工智能|机器学习——从0开始构建自己的GAN网络
文章图片

迭代50000次
人工智能|机器学习——从0开始构建自己的GAN网络
文章图片

由于时间以及本人显卡配置的限制,只进行了50000次迭代,为了更好的效果可以增加迭代次数。
总结 至此本博客,从0开始搭建GAN网络就结束了,有什么问题欢迎和我讨论。

    推荐阅读