第二章:Pytorch框架构建残差神经网络(ResNet) 第一章: Pytorch框架制作自己的数据集实现图像分类
第二章: Pytorch框架构建残差神经网络(ResNet)
第三章: Pytorch框架构建DenseNet神经网络
提示:本文第二部分为代码实现
文章目录
- 第二章:Pytorch框架构建残差神经网络(ResNet)
- 前言
- 一、残差网络(ResNet)简介
-
- 1.背景介绍
- 2.提出ResNet的原因
- 3.关键技术
- 3.残差网络结构特点
- 二、代码实现一个简单Residual Block
-
- 1.导入相关数据包
- 2.定义ResnetbasicBlock类,实现一个简单block
- 3.展示ResNet34网络架构
- 本文代码
前言 ??神经网络模型想取得更高的正确率,一种显然的思路就是给模型添加更多的层。随着层数的增加,模型的准确率得到提升。但是如果出现过拟合现象,这时再增加更多的层,准确率会下降。在到达一定深度后加入更多层,模型可能产生梯度消失(梯度衰减到0)或梯度爆炸(梯度变成一个非常大的值)等问题。一般来说可以通过更好的初始化权重、添加BN层、设计更好的架构等解决此类问题。
??残差网络(ResNet)则是通过残差连接解决此问题。残差网络结构非常容易修改和扩展,通过调整block内的channel数量以及堆叠的block数量,就可以很容易地调整网络的宽度和深度,来得到不同表达能力的网络,而不用过多地担心网络的“退化”问题,只要训练数据足够,逐步加深网络,就可以获得更好的性能表现。
提示:以下是本篇文章正文内容,下面案例可供参考
一、残差网络(ResNet)简介 1.背景介绍 ??ResNet是在2015年由微软实验室提出,作者是何恺明(本科清华、博士香港中文大学出来的大神)、孙剑(现任西安交大人工智能学院首任院长)等人提出,其论文《Deep Residual Learning for Image Recognition》,斩获当年ImageNet竞赛中分类任务第一名,目标检测第一名。
2.提出ResNet的原因 ??卷积网络发展史中,VGG最高达到了19层,再就是GoogleNet,达到了22层。 增加网络的宽度和深度可以很好的提高网络的性能,深的网络一般都比浅的网络效果好。比如说VGG,该网络就是在AlexNex的基础上通过增加网络的深度大幅度提高了网络性能。但是,简单地不断增加深度,又会导致以下问题:
- 梯度消失或者梯度爆炸??可以通过Batch Normalization可以避免。
- 模型过拟合???????可以通过增大数据量,配合Dropout来避免。
- 计算资源的消耗?????可以通过CPU集群来解决
- 退化问题 ???????可以通过现代网络架构解决
文章图片
3.关键技术 ??ResNet是一种残差网络,可以把它理解为一个模块,这个模块经过堆叠可以构成一个很深的网络。
??ResNet通过增加残差连接(shortcut connection),显示地让网络中的层拟合残差映射(residual mapping)。
??ResNet不再尝试学习 x x x到 H ( x ) H\left( x \right) H(x) 的潜在映射,而是学习两者之间的不同,或者说是残差(residual),然后为了计算 H ( x ) H(x) H(x),可将残差加到输入上。假设残差是 F ( x ) = H ( x ) ? x F\left( x \right) =H\left( x \right) -x F(x)=H(x)?x,我们将尝试学习 F ( x ) + x F\left( x \right) +x F(x)+x,而不是直接学习 H ( x ) H(x) H(x)。
文章图片
每个ResNet块都包含一系列层,残差连接把块的输入加到块的输出上。
由于加操作是在元素级别执行的,所以输入和输出的大小要一致。如果它们的大小不同,我们可以采用填充的方式。
ResNet网络结构为多个Residual Block串联
实验表明学习残差比直接学习输入、输出间映射要容易收敛,可达到更高的分类精度,ResNet在上百层都有很好的表现。
3.残差网络结构特点
- List item 与纯层的堆叠相比, ResNet多了很多“残差连接”,即shortcut路径,也就是Residual Block;
- ResNet中,所有的Residual Block都没有pooling层,降采样是通过conv的stride实现的;
- 通过Average Pooling得到最终的特征,而不是通过全连接层;
- 每个卷积层之后都紧接着BatchNorm层、
二、代码实现一个简单Residual Block 1.导入相关数据包
代码如下:
import torchvision.models
import torch.nn.functional as F
import torch.nn as nn
2.定义ResnetbasicBlock类,实现一个简单block
文章图片
代码如下:
class ResnetbasicBlock(nn.Module):
def __init__(self,in_channels, out_channels):
super().__init()# 继承父类属性
#第一层初始化
self.conv1 = nn.Conv2d(in_channels,
out_channels,
kernel_size=3,#使用3*3的卷积
padding=1,#填充减小的部分,保证原有图片不变,padding代表对图片的上下左右填充的像素值
bias=False#kernel_size = n , 损失 n - 1个像素
)
self.bn1 = nn.BatchNorm2d(out_channels) #对图像按照批次做标准化,使图像失去原有量纲
#第二层初始化
self.conv2 = nn.Conv2d(in_channels,
out_channels,
kernel_size=3,#使用3*3的卷积
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
#定义前向传播过程
def forward(self, x):
residual = x#残差
# 调用第一层,经过第一层权重
out.self.conv1(x)
out.F.relu(self.bn1(out), inplace=True)
# 调用第二层,经过第二层权重
out.self.conv2(out)
out.self.bn2(out)
out = out + residual
return F.rule(out)#激活输出
【python|使用Pytorch框架自己制作做数据集进行图像分类(二)】该代码实现了一个简单的残差块,根据实际情况需要将一个个残差块相连即可获得一个定制的残差神经网络。
3.展示ResNet34网络架构
代码如下:
model = torchvision.models.resnet34()
print(model)
调用Pytroch框架所提供的网络架构,输出显示后大家可以查看其各个层设计模式,以后也可以根据自己的需要去自己设计。
本文代码
可以直接跑通,运行后会显示ResNet34的网络架构。
import torchvision.models
import torch.nn.functional as F
import torch.nn as nn
class ResnetbasicBlock(nn.Module):
def __init__(self,in_channels, out_channels):
super().__init()# 继承父类属性
#第一层初始化
self.conv1 = nn.Conv2d(in_channels,
out_channels,
kernel_size=3,#使用3*3的卷积
padding=1, #填充减小的部分,保证原有图片不变,padding代表对图片的上下左右填充的像素值
bias=False
)#kernel_size=n,损失 n - 1个像素
self.bn1 = nn.BatchNorm2d(out_channels)#对图像按照批次做标准化,使图像失去原有量纲
#第二层初始化
self.conv2 = nn.Conv2d(in_channels,
out_channels,
kernel_size=3, #使用3*3的卷积
padding=1,
bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
#定义前向传播过程
def forward(self, x):
residual = x#残差
# 调用第一层,经过第一层权重
out.self.conv1(x)
out.F.relu(self.bn1(out), inplace=True)
# 调用第二层,经过第二层权重
out.self.conv2(out)
out.self.bn2(out)
out = out + residual
return F.rule(out)#激活输出model = torchvision.models.resnet34()
print(model)
推荐阅读
- 机器学习|Dynamic Graph Learning-Neural Network for Multivariate Time Series Modeling
- 时空数据预测(基于图神经网络)|图神经网络应用变体(时空数据挖掘一)
- #|数理统计与机器学习
- 可视化|文献阅读|Nomograms列线图在肿瘤中的应用
- Python Pandas时间戳isoformat介绍
- Python Pandas时间戳now用法介绍
- Python Pandas时间戳替换
- Python Pandas时间戳介绍和用法实例
- Python Pandas.to_datetime()用法介绍