【Loss|对比学习(Contrastive Learning)中的损失函数】
文章目录
-
- 写在前面
- 一、Info Noise-contrastive estimation(Info NCE)
-
- 1.1 描述
- 1.2 实现
- 二、HCL
-
- 2.1 描述
- 2.2 实现
- 三、文字解释
- 四、代码解释
-
- 4.1 Info NCE
- 4.2 HCL
写在前面 ??最近在基于对比学习做实验,github有许多实现,虽然直接套用即可,但是细看之下,损失函数部分甚是疑惑,故学习并记录于此。关于对比学习的内容网络上已经有很多内容了,因此不再赘述。本文重在对InfoNCE的两种实现方式的记录。
一、Info Noise-contrastive estimation(Info NCE) 1.1 描述
??InfoNCE在MoCo中被描述为:
L q = ? log ? exp ? ( q ? k + / τ ) ∑ i = 0 K exp ? ( q ? k i / τ ) (1) \mathcal{L}_{q}=-\log \frac{\exp \left(q \cdot k_{+} / \tau\right)}{\sum_{i=0}^{K} \exp \left(q \cdot k_{i} / \tau\right)} \tag{1} Lq?=?log∑i=0K?exp(q?ki?/τ)exp(q?k+?/τ)?(1)
其中 τ \tau τ是超参。
- 分子表示: q q q对 k + k_+ k+?的点积。所谓点积就是描述 q q q和 k + k_+ k+?两个向量之间的距离。
- 分母表示: q q q对所有 k k k的点积。所谓所有就是指正例(positive sample)和负例(negative sample),所以求和号是从 i = 0 i=0 i=0到 K K K,一共 K + 1 K+1 K+1项。
??MoCo源码的
\moco\builder.py
中,实现如下: # compute logits
# Einstein sum is more intuitive
# positive logits: Nx1
l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
# negative logits: NxK
l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# apply temperature
logits /= self.T
# labels: positive key indicators
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
...
return logits, labels
这里的变量这段代码根据注释即可理解:logits
的意义我也查了一下:是未进入softmax的概率
l_pos
表示正样本的得分,l_neg
表示所有负样本的得分,logits
表示将正样本和负样本在列上cat起来之后的值。值得关注的是,labels
的数值,是根据logits.shape[0]
的大小生成的一组zero。也就是大小为batch_size
的一组0。??接下来看损失函数部分,
\main_moco.py
: # define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
...
# compute output
output, target = model(im_q=images[0], im_k=images[1])
loss = criterion(output, target)
这里直接对输出的
logits
和生成的labels
计算交叉熵,然后就是模型的loss。这里就是让我不是很理解的地方。先将疑惑埋在心里~二、HCL 2.1 描述
??在文章《Contrastive Learning with Hard Negative Samples》中描述到,使用负样本的损失函数为:
E x ~ p , x + ~ p x + [ ? log ? e f ( x ) T f ( x + ) e f ( x ) T f ( x + ) + Q N ∑ i = 1 N e f ( x ) T f ( x i ? ) ] (2) \mathbb{E}_{x \sim p, x^{+} \sim p_{x}^{+}}\left[-\log \frac{e^{f(x)^{T} f\left(x^{+}\right)}}{e^{f(x)^{T} f\left(x^{+}\right)}+\frac{Q}{N} \sum_{i=1}^{N} e^{f(x)^{T} f\left(x_{i}^{-}\right)}}\right] \tag{2} Ex~p,x+~px+??[?logef(x)Tf(x+)+NQ?∑i=1N?ef(x)Tf(xi??)ef(x)Tf(x+)?](2)
- 分子: e f ( x ) T f ( x + ) e^{f(x)^{T} f(x^{+})} ef(x)Tf(x+)表示学到的表示 f ( x ) f(x) f(x)和正样本 f ( x + ) f(x^+) f(x+)的点积。(其实也就是正样本的得分)
- 分母:第一项表示正样本的得分,第二项表示负样本的得分。
mean(-log(正样本的得分/所有样本的得分))
。2.2 实现
??但是在这篇文章的实现中,
\image\main.py
:def criterion(out_1,out_2,tau_plus,batch_size,beta, estimator):
# neg score
out = torch.cat([out_1, out_2], dim=0)
neg = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
old_neg = neg.clone()
mask = get_negative_mask(batch_size).to(device)
neg = neg.masked_select(mask).view(2 * batch_size, -1)
# pos score
pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
pos = torch.cat([pos, pos], dim=0)
# negative samples similarity scoring
if estimator=='hard':
N = batch_size * 2 - 2
imp = (beta* neg.log()).exp()
reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)
Ng = (-tau_plus * N * pos + reweight_neg) / (1 - tau_plus)
# constrain (optional)
Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))
elif estimator=='easy':
Ng = neg.sum(dim=-1)
else:
raise Exception('Invalid estimator selected. Please use any of [hard, easy]') # contrastive loss
loss = (- torch.log(pos / (pos + Ng) )).mean()
return loss
可以看到最后计算loss的公式是:
loss = (- torch.log(pos / (pos + Ng) )).mean()
的确与我上文中的理解相同,可是为什么这样的实现,没有用到
全0的label
呢?三、文字解释 ??既然是同一种方法的两种实现,已经理解了第二种实现(HCL)。那么,问题就出在了:不理解第一种实现的label为何要这样生成? 于是乎,查看交叉熵的计算方式:
loss ( x , c l a s s ) = ? log ? ( exp ? ( x [ c l a s s ] ) ∑ j exp ? ( x [ j ] ) ) = ? x [ c l a s s ] + log ? ( ∑ j exp ? ( x [ j ] ) ) (3) \text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)= -x[class] + \log\left(\sum_j \exp(x[j])\right) \tag{3} loss(x,class)=?log(∑j?exp(x[j])exp(x[class])?)=?x[class]+log(j∑?exp(x[j]))(3)
交叉熵的label的作用是:
将label作为索引
,来取得 x x x中的项( x [ c l a s s ] x[class] x[class]),因此,这些项就是label。而倘若label是全0的项,那么其含义为: x x x中的第一列为label(正样本),其他列就是负样本。然后带入公式(3)中计算,即可得到交叉熵下的loss值。??而对于HCL的实现方式,是直接将InfoNCE拆解开来,使用正样本的得分和负样本的得分来计算。
四、代码解释 ??首先,生成pos得分和neg的得分:
文章图片
注意,这里省略了生成的特征,直接生成了得分,
4.1 Info NCE
文章图片
4.2 HCL
文章图片
嗒哒~两者的结果“一模一样”(取值范围导致最后一位不太一样)
推荐阅读
- delphi|Delphi的StringReplace
- opencv|opencv(15)---图像膨胀腐蚀
- Machine|【机器学习】当贝叶斯、奥卡姆和香农一起来定义机器学习时
- Machine|在pycharm中部署yolov5报错问题
- DEEP|什么是端到端神经网络()
- Machine|Paper Notes: Cross-Domain Image Translation Based on GAN
- Machine|scikit-learn-分类模型评价标准
- 学习心得|javascript动态添加删除文本框
- 学习心得|JavaScript中只高亮选中文本框中指的的文本