[Pytorch] dropout and eval()
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(5, 5)
        self.dp = nn.Dropout(0.5)

    def forward(self, x):
        x = self.fc(x)
        x = self.dp(x)
        return x

if __name__ == '__main__':
    x = torch.FloatTensor([1] * 5)
    z = torch.FloatTensor([1] * 5)
    print(x)

    net = Net()
    criterion = nn.MSELoss()
    optim = torch.optim.SGD(net.parameters(), 0.1)

    # Training mode: dropout is active.
    optim.zero_grad()
    net.train()
    y = net(x)
    loss = criterion(y, z)
    loss.backward()
    optim.step()
    print(y)
    print(net(x))
    # net(x) is not the same as y -> in train mode, dropout() changes the
    # result on every forward pass.

    # Evaluation mode: dropout is disabled.
    net.eval()
    optim.zero_grad()
    y = net(x)
    loss = criterion(y, z)
    loss.backward()
    optim.step()
    print(y)
    print(net(x))
    # eval() (equivalent to train(False)) only changes the state of some
    # modules, e.g. dropout; it does NOT disable loss back-propagation.
    # With the module in eval mode, dropout becomes a no-op and is
    # temporarily removed from the chain of computation.

    # no_grad: autograd records nothing during the forward pass.
    with torch.no_grad():
        optim.zero_grad()
        y = net(x)
        loss = criterion(y, z)
        # loss.backward()  # under torch.no_grad() no computation graph is
        #                  # built, so calling loss.backward() would raise
        #                  # a RuntimeError.
        optim.step()
        print(y)
        print(net(x))
    # net(x) == y -> without loss.backward() there are no gradients, so
    # optim.step() leaves the network parameters unchanged.
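To isolate the dropout behavior shown above, here is a minimal standalone sketch (not part of the original post) using a bare nn.Dropout(0.5) on a vector of ones: in train mode each element is either zeroed or scaled by 1/(1-p) = 2.0, while in eval mode the input passes through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dp = nn.Dropout(0.5)
x = torch.ones(8)

dp.train()
y_train = dp(x)   # survivors are scaled by 1/(1-p) = 2.0, the rest are zeroed
print(y_train)

dp.eval()
y_eval = dp(x)    # identity in eval mode
print(y_eval)
```

The scaling by 1/(1-p) (inverted dropout) keeps the expected activation the same between train and eval mode, which is why no extra rescaling is needed at inference time.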
Output:
tensor([1., 1., 1., 1., 1.])
tensor([ 0.0000, -0.0000,  1.6945,  0.7744,  0.0000])
tensor([ 0.0000, -0.0000,  0.0000,  0.9910,  0.0000])
tensor([ 0.3758, -1.1484,  0.5139,  0.4955,  0.7577])
tensor([ 0.5256, -0.6328,  0.6306,  0.6166,  0.8158])
tensor([ 0.5256, -0.6328,  0.6306,  0.6166,  0.8158])
tensor([ 0.5256, -0.6328,  0.6306,  0.6166,  0.8158])
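The distinction between eval() and torch.no_grad() drawn in the comments can be checked directly. This separate sketch (an illustration, not from the original post) shows that a module in eval mode still builds a graph and supports backward(), whereas a forward pass under torch.no_grad() records nothing, so backward() raises.

```python
import torch
import torch.nn as nn

net = nn.Linear(5, 5)
x = torch.ones(5)

# eval() does not turn off autograd: gradients still flow.
net.eval()
loss = net(x).sum()
loss.backward()
print(net.weight.grad is not None)   # gradients were computed

# Under torch.no_grad() the forward pass records no graph at all.
with torch.no_grad():
    out = net(x).sum()
print(out.requires_grad)             # nothing to backpropagate through

backward_failed = False
try:
    out.backward()                   # no grad_fn -> RuntimeError
except RuntimeError:
    backward_failed = True
print(backward_failed)
```

In short: eval() changes module behavior (dropout, batch norm) but leaves autograd on; torch.no_grad() turns autograd off but leaves module behavior alone. The two are independent and are typically combined during inference.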