对比学习(Contrastive Learning)中的损失函数

zz/2024/4/20 15:56:43

文章目录

    • 写在前面
    • 一、Info Noise-contrastive estimation(Info NCE)
      • 1.1 描述
      • 1.2 实现
    • 二、HCL
      • 2.1 描述
      • 2.2 实现
    • 三、文字解释
    • 四、代码解释
      • 4.1 Info NCE
      • 4.2 HCL

写在前面

  最近在基于对比学习做实验,github有许多实现,虽然直接套用即可,但是细看之下,损失函数部分甚是疑惑,故学习并记录于此。关于对比学习的内容网络上已经有很多内容了,因此不再赘述。本文重在对InfoNCE的两种实现方式的记录。

一、Info Noise-contrastive estimation(Info NCE)

1.1 描述

  InfoNCE在MoCo中被描述为:
Lq=−log⁡exp⁡(q⋅k+/τ)∑i=0Kexp⁡(q⋅ki/τ)(1)\mathcal{L}_{q}=-\log \frac{\exp \left(q \cdot k_{+} / \tau\right)}{\sum_{i=0}^{K} \exp \left(q \cdot k_{i} / \tau\right)} \tag{1}Lq=logi=0Kexp(qki/τ)exp(qk+/τ)(1)
其中τ\tauτ是超参。

  • 分子表示:qqqk+k_+k+点积。所谓点积就是描述qqqk+k_+k+两个向量之间的距离。
  • 分母表示:qqq所有kkk的点积。所谓所有就是指正例(positive sample)和负例(negative sample),所以求和号是从i=0i=0i=0KKK,一共K+1K+1K+1项。

1.2 实现

  MoCo源码的\moco\builder.py中,实现如下:

	# compute logits# Einstein sum is more intuitive# positive logits: Nx1l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)# negative logits: NxKl_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])# logits: Nx(1+K)logits = torch.cat([l_pos, l_neg], dim=1)# apply temperaturelogits /= self.T# labels: positive key indicatorslabels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()...return logits, labels

这里的变量logits的意义我也查了一下:是未进入softmax的概率

这段代码根据注释即可理解:l_pos表示正样本的得分,l_neg表示所有负样本的得分,logits表示将正样本和负样本在列上cat起来之后的值。值得关注的是,labels的数值,是根据logits.shape[0]的大小生成的一组zero。也就是大小为batch_size的一组0。

  接下来看损失函数部分,\main_moco.py

	# define loss function (criterion) and optimizercriterion = nn.CrossEntropyLoss().cuda(args.gpu)...# compute outputoutput, target = model(im_q=images[0], im_k=images[1])loss = criterion(output, target)

这里直接对输出的logits和生成的labels计算交叉熵,然后就是模型的loss。这里就是让我不是很理解的地方。先将疑惑埋在心里~

二、HCL

2.1 描述

  在文章《Contrastive Learning with Hard Negative Samples》中描述到,使用负样本的损失函数为:
Ex∼p,x+∼px+[−log⁡ef(x)Tf(x+)ef(x)Tf(x+)+QN∑i=1Nef(x)Tf(xi−)](2)\mathbb{E}_{x \sim p, x^{+} \sim p_{x}^{+}}\left[-\log \frac{e^{f(x)^{T} f\left(x^{+}\right)}}{e^{f(x)^{T} f\left(x^{+}\right)}+\frac{Q}{N} \sum_{i=1}^{N} e^{f(x)^{T} f\left(x_{i}^{-}\right)}}\right] \tag{2}Exp,x+px+[logef(x)Tf(x+)+NQi=1Nef(x)Tf(xi)ef(x)Tf(x+)](2)

  • 分子:ef(x)Tf(x+)e^{f(x)^{T} f(x^{+})}ef(x)Tf(x+)表示学到的表示f(x)f(x)f(x)和正样本f(x+)f(x^+)f(x+)的点积。(其实也就是正样本的得分)
  • 分母:第一项表示正样本的得分,第二项表示负样本的得分。

其实本质上适合InfoNCE一个道理,都是mean(-log(正样本的得分/所有样本的得分))

2.2 实现

  但是在这篇文章的实现中,\image\main.py

def criterion(out_1,out_2,tau_plus,batch_size,beta, estimator):# neg scoreout = torch.cat([out_1, out_2], dim=0)neg = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)old_neg = neg.clone()mask = get_negative_mask(batch_size).to(device)neg = neg.masked_select(mask).view(2 * batch_size, -1)# pos scorepos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)pos = torch.cat([pos, pos], dim=0)# negative samples similarity scoringif estimator=='hard':N = batch_size * 2 - 2imp = (beta* neg.log()).exp()reweight_neg = (imp*neg).sum(dim = -1) / imp.mean(dim = -1)Ng = (-tau_plus * N * pos + reweight_neg) / (1 - tau_plus)# constrain (optional)Ng = torch.clamp(Ng, min = N * np.e**(-1 / temperature))elif estimator=='easy':Ng = neg.sum(dim=-1)else:raise Exception('Invalid estimator selected. Please use any of [hard, easy]')# contrastive lossloss = (- torch.log(pos / (pos + Ng) )).mean()return loss

可以看到最后计算loss的公式是:

	loss = (- torch.log(pos / (pos + Ng) )).mean()

的确与我上文中的理解相同,可是为什么这样的实现,没有用到全0的label呢?

三、文字解释

  既然是同一种方法的两种实现,已经理解了第二种实现(HCL)。那么,问题就出在了:不理解第一种实现的label为何要这样生成? 于是乎,查看交叉熵的计算方式:
loss(x,class)=−log⁡(exp⁡(x[class])∑jexp⁡(x[j]))=−x[class]+log⁡(∑jexp⁡(x[j]))(3)\text{loss}(x, class) = -\log\left(\frac{\exp(x[class])}{\sum_j \exp(x[j])}\right)= -x[class] + \log\left(\sum_j \exp(x[j])\right) \tag{3}loss(x,class)=log(jexp(x[j])exp(x[class]))=x[class]+log(jexp(x[j]))(3)

交叉熵的label的作用是:将label作为索引,来取得xxx中的项(x[class]x[class]x[class]),因此,这些项就是label。而倘若label是全0的项,那么其含义为:xxx中的第一列为label(正样本),其他列就是负样本。然后带入公式(3)中计算,即可得到交叉熵下的loss值。

  而对于HCL的实现方式,是直接将InfoNCE拆解开来,使用正样本的得分和负样本的得分来计算。

四、代码解释

  首先,生成pos得分和neg的得分:
得分
注意,这里省略了生成的特征,直接生成了得分,

4.1 Info NCE

在这里插入图片描述

4.2 HCL

HCL loss
嗒哒~两者的结果“一模一样”(取值范围导致最后一位不太一样)


http://www.ngui.cc/zz/2390172.html

相关文章

常见学习率衰减方式

学习率 学习率的作用 ​ 在机器学习中,监督式学习通过定义一个模型,并根据训练集上的数据估计最优参数。梯度下降法是一个广泛被用来最小化模型误差的参数优化算法。梯度下降法通过多次迭代,并在每一步中最小化成本函数(cost 来…

i=i++深入解释

以下内容是在JAVA虚拟机中探究,学习C语言的小伙伴请自行绕开 一道基础的题目: int i0; ii; i?? 执行结果:0; why??不应该是1吗?大脑中快速飞过计算步骤: i初始化位0,题目中是…

超简单!一部手机就能提取视频中的语音转换成文字

当我们工作中去整理一些视频资料时,有时候需要对视频中所讲的内容进行整理,这时候很多办公小白会采用传统的方法,就是需要一遍又一遍地看,并记录其中的内容。实际上我们可以提取视频中的语言,将相应的语音内容转换成文…

教你如何将语音转换成文字

语音识别是一种将人的语音转换为文本的的技术。语音识别可以直接把你说的话直接转换成文字, 使用起来也比较方便,不用动手,就可以输入你想要的文字。下面小编就来教大家如何将语音转换成文字。 工具:迅捷PDF阅读器 操作方法&#…

想把语音转成文字,就这样做

将语音转成文字的方法很多,如果你不怕麻烦你可以边听语音边敲文字,就是比较费时间。当我们想转化的语音时间比较长的时候往往是行不通的,那比较快速、省力的方法就是使用软件进行转写。给你推荐2个比较好用的转写软件。 一:滴答转…

语音识别技术,将语音转换成文字

现在越来越多的同学都不想打字,而是用语音来代替文字的输入,现在随着语音识别 技术的越来越成熟,完全可以应用到我们的日常生活里了。其实这项技术也可以应用 到工作上,比如利用语音来写文档,方便快捷。那么我们怎么实…

实用系列1 —— 视频中的语音转换成文字

实用系列1 —— 视频中的语音转换成文字python版本 背景说明 疫情原因,家里的老师亲戚需要对着电脑上网课,晋升为十八线小主播~ 备课的内容来源都是当地教育局的公开课,为了学习公开课的上课方法,只能自己慢速播放视…

某优job

目标url : aHR0cHM6Ly93d3cuNTFqb2IuY29tLw 抓取相关数据 通过对源代码的查看,可以很明确的知道,这些数据是同步加载的。 抓包分析: payload不变,由此可以确定url和发送的params 确定headers需要的字段 凭经验和实际测试可得…

手写代码(笔试面试真题)

★★★ 手写代码:实现forEach map filter reduce ★★★ 手写实现一个简易的 Vue Reactive ★★★ 手写代码,监测数组变化,并返回数组长度 ★★★ 手写原生继承,并说出局限性? ★★★★ 手写一个柯里化函数 ★★★…

2021-最新Web前端经典面试试题及答案-史上最全前端面试题(含答案)---手写代码篇

★★★ 手写代码:实现forEach map filter reduce ★★★ 手写实现一个简易的 Vue Reactive ★★★ 手写代码,监测数组变化,并返回数组长度 ★★★ 手写原生继承,并说出局限性? ★★★★ 手写一个柯里化函数 ★★★…