文章目录
-
- AutoEncoder自编码
-
- 1. 获取训练数据
- 2. AutoEncoder模型
- 3. 训练
AutoEncoder自编码 神经网络也能进行非监督学习, 只需要训练数据, 不需要标签数据. 自编码就是这样一种形式. 自编码能自动分类数据, 而且也能嵌套在半监督学习的上面, 用少量的有标签样本和大量的无标签样本学习.
1. 获取训练数据
【机器学习|机器学习_pytorch_高级神经网络结构_AutoEncoder自编码】自编码只用训练集就好了, 而且只需要训练 training data 的 image, 不用训练 labels.
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision# 超参数
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005
DOWNLOAD_MNIST = True# 下过数据的话, 就可以设置成 False
N_TEST_IMG = 5# 到时候显示 5张图片看效果, 如上图一# Mnist digits dataset
train_data = https://www.it610.com/article/torchvision.datasets.MNIST(
root='./mnist/',
train=True,# this is training data
transform=torchvision.transforms.ToTensor(),# Converts a PIL.Image or numpy.ndarray to
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
download=DOWNLOAD_MNIST,# download it if you don't have it
)
2. AutoEncoder模型
AutoEncoder 形式很简单, 分别是 encoder 和 decoder, 压缩和解压, 压缩后得到压缩的特征值, 再从压缩的特征值解压成原图片.
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__()# 压缩
self.encoder = nn.Sequential(
nn.Linear(28*28, 128),
nn.Tanh(),
nn.Linear(128, 64),
nn.Tanh(),
nn.Linear(64, 12),
nn.Tanh(),
nn.Linear(12, 3),# 压缩成3个特征, 进行 3D 图像可视化
)
# 解压
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.Tanh(),
nn.Linear(12, 64),
nn.Tanh(),
nn.Linear(64, 128),
nn.Tanh(),
nn.Linear(128, 28*28),
nn.Sigmoid(),# 激励函数让输出值在 (0, 1)
)def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return encoded, decodedautoencoder = AutoEncoder()
3. 训练
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()for epoch in range(EPOCH):
for step, (x, b_label) in enumerate(train_loader):
b_x = x.view(-1, 28*28)# batch x, shape (batch, 28*28)
b_y = x.view(-1, 28*28)# batch y, shape (batch, 28*28)encoded, decoded = autoencoder(b_x)loss = loss_func(decoded, b_y)# mean square error
optimizer.zero_grad()# clear gradients for this training step
loss.backward()# backpropagation, compute gradients
optimizer.step()# apply gradients
推荐阅读
- 自动驾驶|手握全球最大ADAS激光雷达订单(这家公司股价却急速“跌落”)
- 数据科学从0到1|python使用numpy生成指定步长的浮点数序列
- EECS6127 分类器
- 数据分析|python机器学习之模型选择与优化
- 人工智能+大数据|逻辑回归(使用激活函数sigmoid)详细介绍
- 统计学|灰色关联分析,Python实现GRA(gray relation analysis)
- python|总结|图像分割5大经典方法
- 目标检测|yolov5——断点训练/继续训练【解决方法、使用教程】
- 目标检测|yolov5——detect.py代码【注释、详解、使用教程】