stn在mnist上的实现 个人博客 - https://cxy-sky.github.io/
代码参考来源:PyTorch框架实战系列(3)——空间变换器网络STN_Daniel Yuz的博客-CSDN博客
理论:Pytorch中的仿射变换(affine_grid)_liangbaqiang的博客-CSDN博客
【python|stn在mnist上的实现】详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了_黄小猿的博客-CSDN博客_stn
图片显示用的是matplotlib,自己没下opencv。
CNN
import torch
from torch import nn, optim


class CNN(nn.Module):
    """Baseline convolutional classifier mapping 1x28x28 images to 10 logits.

    Two conv/ReLU/max-pool stages reduce a 28x28 input to 128 feature maps
    of size 2x2 (28 -> 7 after the 4x4 pool, 7 -> 2 after the 3x3 pool),
    which are flattened to 512 features and fed to a linear layer.
    """

    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        self.linear = nn.Sequential(
            # The features are already flattened to (batch, 512) here, so
            # element-wise Dropout is the correct layer. Dropout2d is meant
            # for 3-D/4-D inputs and is deprecated for 2-D ones. Dropout has
            # no parameters, so state_dict keys are unchanged.
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input ``x``."""
        features = self.cnn(x)
        # Keep the batch dimension, flatten everything else.
        flat = features.flatten(1)
        return self.linear(flat)


if __name__ == '__main__':
    # Smoke test: print the architecture and the logits for one random image.
    model = CNN()
    x = torch.rand(1, 1, 28, 28)
    print(model)
    y = model(x)
    print(y)
STN
import torch
from torch import nn


class STN(nn.Module):
    """CNN classifier preceded by a spatial transformer for 1x28x28 inputs.

    A small localization network predicts a 2x3 affine matrix per image.
    The image is resampled with that transform and then classified by the
    same conv/linear stack used by the plain CNN model.
    """

    def __init__(self):
        super().__init__()
        # Localization feature extractor: for a 28x28 input the spatial
        # size goes 28 -> 22 (conv 7) -> 11 (pool) -> 7 (conv 5) -> 3 (pool).
        self.location_cov = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2),
        )
        # Regressor from the 10*3*3 localization features to the six
        # entries of the affine matrix.
        self.localization_linear = nn.Sequential(
            nn.Linear(in_features=10 * 3 * 3, out_features=32),
            nn.ReLU(),
            nn.Linear(in_features=32, out_features=2 * 3),
        )
        # Start from the identity transform: zero weights and an identity
        # bias make the freshly built module pass images through unchanged.
        theta_layer = self.localization_linear[2]
        nn.init.zeros_(theta_layer.weight)
        with torch.no_grad():
            theta_layer.bias.copy_(
                torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
            )

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=4),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        self.linear = nn.Sequential(
            # Input is flattened to (batch, 512); Dropout2d is deprecated
            # for 2-D inputs, plain Dropout is the intended layer.
            nn.Dropout(0.5),
            nn.Linear(512, 10),
        )

    def stn(self, x):
        """Warp ``x`` with the affine transform predicted from ``x`` itself.

        Returns a tensor with the same size as ``x``.
        """
        loc = self.location_cov(x)
        loc = loc.flatten(1)
        theta = self.localization_linear(loc).view(-1, 2, 3)
        grid = nn.functional.affine_grid(theta, x.size(), align_corners=True)
        return nn.functional.grid_sample(x, grid, align_corners=True)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for input ``x``."""
        warped = self.stn(x)
        features = self.cnn(warped).flatten(1)
        return self.linear(features)


if __name__ == '__main__':
    # Smoke test: print the architecture and the logits for one random image.
    x = torch.rand(1, 1, 28, 28)
    model = STN()
    print(model)
    print(model(x))
train
import os

import numpy as np
import torch
from torchvision import transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from PIL import Image
from torch import nn, optim

from stn.CNN import CNN
from stn.STN import STN

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed every RNG before anything random happens (data shuffling, random
# rotations, weight initialisation). Seeding only inside the main guard,
# after the model has been built, does not make runs repeatable.
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed_all(1)
torch.backends.cudnn.deterministic = True

# Data pipeline: random rotation of up to 45 degrees, then scale pixel
# values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_data = torchvision.datasets.MNIST('../data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transform)
test_data = torchvision.datasets.MNIST('../data/mnist',
                                       download=True,
                                       train=False,
                                       transform=transform)

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=64,
                                          shuffle=True)

# Preview one training batch as an 8-column grid, undoing the normalisation
# so the pixel values are back in [0, 1] for display.
data_iter = iter(train_loader)
imgs = torchvision.utils.make_grid(next(data_iter)[0], 8)
imgs = imgs.numpy().transpose(1, 2, 0)
imgs = imgs * 0.5 + 0.5
plt.imshow(imgs)
plt.show()

# model = CNN()
model = STN()
model = model.to(device)
loss_fun = nn.CrossEntropyLoss().to(device)
opt_fun = optim.Adam(params=model.parameters(), lr=0.001)

# Module-level bookkeeping kept from the original script.
loss = 0
train_acc_count = []
test_acc_count = []
train_loss = []
test_loss = []


def train(epoch):
    """Train the global ``model`` for ``epoch`` passes over ``train_loader``.

    Every 100 batches the current loss is printed and appended to the
    global ``train_loss`` list.
    """
    model.train()  # re-enable dropout in case test() switched to eval mode
    for i in range(epoch):
        for index, (imgs, labels) in enumerate(train_loader):
            imgs = imgs.to(device)
            labels = labels.to(device)
            opt_fun.zero_grad()
            loss = loss_fun(model(imgs), labels)
            loss.backward()
            opt_fun.step()
            if index % 100 == 0:
                print("第{}轮,第{}次,loss为:{}".format(i + 1, index, loss.item()))
                train_loss.append(loss.item())


def test():
    """Evaluate the global ``model`` on ``test_loader``.

    Prints and returns the accuracy over the whole test set.
    """
    model.eval()  # disable dropout so the evaluation is deterministic
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            labels = labels.to(device)
            outputs = model(imgs.to(device))
            # Accumulate over all batches; overwriting these counters would
            # report the accuracy of the last batch only.
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    print("测试集准确率{}".format(accuracy))
    return accuracy


if __name__ == '__main__':
    train(10)
    test()
    save_path = '../model/mnistStn.pth'
    # Create the target directory first so a missing folder does not throw
    # away the finished training run.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    plt.plot(train_loss)
    plt.show()
showImage
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import torch
import matplotlib.pyplot as plt

from stn.STN import STN


def _grid_image(batch):
    """Tile ``batch`` into an 8-column grid as an HxWxC array in [0, 1]."""
    grid = torchvision.utils.make_grid(batch, 8)
    # CHW -> HWC for matplotlib, then undo Normalize(0.5, 0.5).
    return grid.numpy().transpose(1, 2, 0) * 0.5 + 0.5


# Same preprocessing as the training script.
transform = transforms.Compose([
    transforms.RandomRotation(45),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_data = torchvision.datasets.MNIST('../data/mnist',
                                        download=True,
                                        train=True,
                                        transform=transform)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)

data_iter = iter(train_loader)
imgs, labels = next(data_iter)

# Top panel: the rotated inputs as the network receives them.
pre = _grid_image(imgs)
plt.subplot(2, 1, 1)
plt.imshow(pre)
plt.title('pre')

model = STN()
# map_location lets a checkpoint saved from a GPU model load on a
# CPU-only machine; the model is used on the CPU here anyway.
model.load_state_dict(torch.load('../model/mnistStn.pth', map_location='cpu'))
model.eval()

# Bottom panel: the same batch after the learned spatial transform.
with torch.no_grad():
    now = model.stn(imgs)
now = _grid_image(now)
plt.subplot(2, 1, 2)
plt.imshow(now)
plt.title('now')

plt.show()
train,epoch=10
文章图片
展示transform后的图片,还是感觉很神奇
文章图片
推荐阅读
- python|基于PyTorch搭建CNN实现视频动作分类任务
- B站|B站马士兵python入门基础版详细笔记(5)
- flask|阿里云宝塔部署python-flask项目
- python|python爬虫实战演示
- python|day04 循环练习题
- 面试|面经自己汇总(三维视觉算法&机器学习&深度学习)——持续更新
- Debug生涯|seaborn urllib.error.URLError: < urlopen error [WinError 10054] 远程主机强迫关闭了一个现有的连接。>
- python|使用python-Django创建Web站点
- 分类|基于vgg16的猫狗识别(二分类)