这篇笔记的出发点:在自定义损失函数时,发现同样是交叉熵损失函数,我自己写的和库本身的数值是一样的,但是训练时,二者效果有很大差别,于是怀疑是自己对梯度的理解没到位,由此写下这篇笔记。
1. tensor的is_leaf属性
【pytorch|Pytorch梯度理解+自定义损失函数】主要是为了节省内存,如果1个tensor的is_leaf属性为False,说明他不是叶子节点,是中间变量,那么就不会保存他的梯度
''' 1. 用户自己建立的tentor, is_leaf属性都是True, 无论required_grad是否为True
2. 只有float类型才有梯度,因此需要 a = torch.tensor(a, dtype = torch.float) '''x1 = torch.tensor(1.,requires_grad=True)
x2 = torch.tensor(2.,requires_grad=True)w11 = torch.tensor(0.1)
w12 = torch.tensor(0.2)
print(x1.is_leaf)# True
print(w11.is_leaf)# True
2. grad_fn的属性不影响反向传播
后来发现:我的交叉熵损失函数和库本身的,反向传播值是一样的,
虽然grad_fn不同,库本身是grad_fn=
我的是grad_fn=
后来验证过,我的和他的交叉熵损失函数,对训练结果是没有影响的,当时可能是我函数没写对
class my_CEloss(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self,input,label):
result = 0
for i in range(label.size(0)):
# mini_batch里面的1个sample
loss = 0
pred = input[i,:,:]
tar = label[i,:].float()
num = len(tar)
for j in range(len(tar)):
if int(tar[j])==0:
loss += -torch.log(torch.exp(pred[0,j])/(torch.exp(pred[0,j])+torch.exp(pred[1,j])))
else:
loss += -torch.log(torch.exp(pred[1,j])/(torch.exp(pred[0,j])+torch.exp(pred[1,j])))
loss /= num # 对应nn.CrossEntroptLoss()的reduction = 'mean'
result += loss # 求和1个mini_batch的总loss
return result/label.size(0) # 返回每个sample的平均值fn1 = nn.CrossEntropyLoss()
fn2 = my_CEloss()x = torch.randn(2,2,3,requires_grad=True)
x_ = x
y = torch.randint(0,2,(2,3))
y_ = yloss1 = fn1(x,y)
loss2 = fn2(x_,y_)
print(loss1)
print(loss2)
# tensor(1.1663, grad_fn=)
# tensor(1.1663, grad_fn=)
loss1.backward()
loss2.backward()
print(x.grad)
print(x_.grad)# 梯度值是一样的
3. 维度变换不影响反向传播
4. 通常都是scalar反向传播,但是向量也可以反向传播
''' 向量tentor反向传播 '''
x = torch.tensor([1,2,3],dtype= torch.float,requires_grad=True)
y = 2*x
print(y)# tensor([2., 4., 6.], grad_fn=)
y.backward(gradient=torch.ones_like(x))
print(x.grad)# tensor([2., 2., 2.])
推荐阅读
- torch|pytorch入门(三)线性代数的实现
- 线性代数|05 线性代数【动手学深度学习v2】
- 深度学习|深度学习入门之线性回归(PyTorch)
- 深度学习|深度学习入门之自动求导(Pytorch)
- 深度学习|神经网络入门之矩阵计算(Pytorch)
- 深度学习|动手学深度学习----pytorch中数据操作的基本知识
- AI|PyTorch实现LeNet-5
- 神经网络|【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!
- 深度学习|LeNet-5 详解+pytorch简洁实现