pytorch|pytorch 定义一个网络

声明一个关于网络的类

import torch.nn as nn class NetName(nn.Module): def __init__(self): super(NetName, self).__init__()nn.module1 = ... nn.module2 = ... nn.module3 = ...def forward(self,x): x = self.module1(x) x = self.module1(x) x = self.module2(x) x = self.module3(x) return x

其中在构造函数__init__中构造这个NN中需要使用的各种模块(module),比如:参数完全相同的maxpooling声明为一个模块,或者例如在CV任务中,把feature_extraction的网络和classification的网络分别声明。
forward函数用于声明各个模块间的关系。即,连接整个网络。
net = NetName().to(device) # 创建网络,并放入指定的device

【pytorch|pytorch 定义一个网络】网络创建后,可以通过以下方式遍历模块信息:
for name, module in net._modules.items(): print(name) # name就是__init__中的各个模块名 print(module) # module就是各个模块内具体的层

示例:AlexNet
注释中的tensor大小变化是基于cifar10的图片----(channel=3, height=32, width=32)
import torch.nn as nnclass CNN(nn.Module): def __init__(self): super(CNN, self).__init__()self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), # (3,32,32) -> (64,8,8) nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2),# (64,8,8)-> (64,4,4) nn.Conv2d(64, 192, kernel_size=5, padding=2),# (64,4,4)-> (192,4,4) nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2),# (192,4,4) -> (192,2,2) nn.Conv2d(192, 384, kernel_size=3, padding=1),# (192,2,2) -> (384,2,2) nn.ReLU(inplace=True), nn.Conv2d(384, 256, kernel_size=3, padding=1),# (384,2,2) -> (256,2,2) nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1),# (256,2,2) -> (256,2,2) nn.ReLU(inplace=True), nn.MaxPool2d(kernel_size=2, stride=2),# (256,2,2) -> (256,1,1) )self.classifier = nn.Linear(256, 10)# (batch_size,256) -> (batch_size,10)def forward(self, x): x = self.features(x) x = x.view(x.size(0), -1) # flatten to (batch_size, 256*1*1) x = self.classifier(x) return x

    推荐阅读