PyTorch｜PyTorch 载入和保存模型（无格式整理，先记下）

  1. 定义网络结构
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes

    The classifier head uses global (adaptive) average pooling, so inputs of any
    spatial size are accepted. For inputs whose final feature map is 7x7
    (e.g. 224x224 images) the result is identical to a fixed 7x7 average pool.
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):
        super(DenseNet, self).__init__()

        # First convolution: 7x7 stride-2 conv, BN, ReLU, then 3x3 stride-2 max pool.
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate,
                                drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            # The block's output channel count grows by num_layers * growth_rate.
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                # A transition follows every block except the last and halves the channels.
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a batch ``x`` of 3-channel images."""
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Global average pool to 1x1. A hard-coded avg_pool2d(kernel_size=7) only
        # works when the final map is exactly 7x7; adaptive pooling matches it in
        # that case and also handles other input resolutions.
        out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
        out = self.classifier(out)
        return out

  2. 使用网络结构定义模型（DenseNet-161 的配置）：
# Instantiate the network with the DenseNet-161 hyper-parameters
# (the same variant as the densenet161 checkpoint file loaded afterwards).
net = DenseNet(
    block_config=(6, 12, 36, 24),
    growth_rate=48,
    num_init_features=96,
)

  3. 载入模型参数
# Read the pretrained parameter dictionary from disk, then copy it into `net`.
# NOTE(review): this absolute path is specific to one machine; adjust it locally.
# load_state_dict is strict by default, so the checkpoint's keys must match the
# submodule names produced by this DenseNet definition — confirm for your version.
checkpoint_path = '/home/wei.fan/.torch/models/densenet161-17b70270.pth'
pretrained_state = torch.load(checkpoint_path)
net.load_state_dict(pretrained_state)

  4. 训练模型并保存模型参数
# Replace the final classifier so the network outputs 100 classes.
# (The original snippet read `model_conv`, a name that is never defined; the model is `net`.)
num_ftrs = net.classifier.in_features
net.classifier = nn.Linear(num_ftrs, 100)  # adjust the size of the last layer
net = net.cuda()
criterion = nn.CrossEntropyLoss()
# Keep the optimizer in its own variable: assigning it to `net` would discard the model.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# train_net is a user-defined training routine (not shown here).
# NOTE(review): it is expected to use `net`, `criterion` and `optimizer` and to
# return the trained model — confirm against your own implementation.
net = train_net()
# Save only the model parameters (the state_dict), not the whole module object.
torch.save(net.state_dict(), 'net_params.pkl')

    推荐阅读