Python人工智能学习PyTorch实现WGAN示例详解
目录
- 1.GAN简述
- 2.生成器模块
- 3.判别器模块
- 4.数据生成模块
- 5.判别器训练
- 6.生成器训练
- 7.结果可视化
1.GAN简述 在GAN中,有两个模型,一个是生成模型,用于生成样本,一个是判别模型,用于判断样本是真还是假。但由于在GAN中,使用的JS散度去计算损失值,很容易导致梯度弥散的情况,从而无法进行梯度下降更新参数,于是在WGAN中,引入了Wasserstein Distance,使得训练变得稳定。本文中我们以服从高斯分布的数据作为样本。
2.生成器模块 这里从2维数据,最终生成2维,主要目的是为了可视化比较方便。也就是说,在生成模型中,我们输入杂乱无章的2维的数据,通过训练之后,可以生成一个赝品,这个赝品在模仿高斯分布。
文章图片
3.判别器模块 判别器同样输入的是2维的数据。比如我们上面的生成器,生成了一个2维的赝品,输入判别器之后,它能够最终输出一个sigmoid转换后的结果,相当于是一个概率,从而判别,这个赝品到底能不能达到以假乱真的程度。
文章图片
4.数据生成模块 由于我们使用的是高斯模型,因此,直接生成我们需要的数据即可。我们在这个模块中,生成8个服从高斯分布的数据。
文章图片
5.判别器训练 由于使用JS散度去计算损失的时候,会很容易出现梯度极小,接近于0的情况,会使得梯度下降无法进行,因此计算损失的时候,使用了Wasserstein Distance,去度量两个分布之间的差异。因此我们假如了梯度惩罚的因子。
文章图片
其中,梯度惩罚的模块如下:
文章图片
6.生成器训练 这里的训练是紧接着判别器训练的。也就是说,在一个周期里面,先训练判别器,再训练生成器。
文章图片
7.结果可视化 通过visdom可视化损失值,通过matplotlib可视化分布的预测结果。
【Python人工智能学习PyTorch实现WGAN示例详解】
文章图片
以上就是人工智能学习PyTorch实现WGAN示例详解的详细内容,更多关于PyTorch实现WGAN的资料请关注脚本之家其它相关文章!
推荐阅读
- 由浅入深理解AOP
- 继续努力,自主学习家庭Day135(20181015)
- python学习之|python学习之 实现QQ自动发送消息
- 逻辑回归的理解与python示例
- 一起来学习C语言的字符串转换函数
- python自定义封装带颜色的logging模块
- 【Leetcode/Python】001-Two|【Leetcode/Python】001-Two Sum
- 定制一套英文学习方案
- 漫画初学者如何学习漫画背景的透视画法(这篇教程请收藏好了!)
- 《深度倾听》第5天──「RIA学习力」便签输出第16期