The NCE loss comes from the paper "A fast and simple algorithm for training neural probabilistic language models", published at ICML 2012.
Background
Back in 2012, language models were generally n-gram models that count word/context co-occurrence statistics, and they worked better than neural probabilistic language models (NPLMs).
Today's mainstream language models are all neural probabilistic language models. The core idea: given a context $h$, predict the probability that the next word is $w$, then apply a decoding method (e.g., greedy search or beam search) to those probabilities to pick the next word. Greedy search simply picks the word with the highest probability.
The reason NPLMs underperformed in 2012 is that they were hard to train. Part of it was hardware: NVIDIA had only just released the GTX 680 that year, which is nowhere near today's A100/H100, and with the GPUs of the day academia could not do much. The other part was algorithmic efficiency: large-scale classification was hard to train. "Given context $h$, predict the probability that the next word is $w_i$" is modeled as a classification task whose goal is to classify the next word into one of the words in the vocabulary.
For example, suppose the context is "我想去" ("I want to go to") and we must predict the next word. The vocabulary has 4 words, ['北京','上海','天津','广州'], so the next word must be classified into one of these 4. But what if the vocabulary has 100,000 words? You simply cannot train that.
That was the dilemma at the time. NCE optimizes the classification procedure and makes classification over a huge vocabulary feasible.
Principle
With the informal background out of the way, let's move on to the formal part.
Problem Formulation
Given a context $h$, the probability that the next word is $w$ is:

$$P_{\theta}^h(w)=\frac{\exp(s_{\theta}(w,h))}{\sum_{w_i}{\exp(s_{\theta}(w_i,h))}}\tag{1}$$
where $s_{\theta}(w,h)$ is the predicted score for the next word being $w$ given context $h$, and $\sum_{w_i}$ runs over all words in the vocabulary.
In general, $s_{\theta}(w,h)$ is computed by stacking several fully connected layers on top of the context representation and the word (class) representation. The simplest strategy: map the context representation $f_h$ through a single fully connected layer $W$, then take its dot product with the word representation $f_{w}$:
$$s_{\theta}(w,h)=(f_h W) \cdot f_{w}$$
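To make this concrete, here is a minimal PyTorch sketch of that simplest scoring function (the sizes and variable names are my own toy choices, not from the paper): project the context representation with one linear layer, then take a dot product with the word embedding.

```python
import torch
import torch.nn as nn

hidden, vocab_size = 512, 10_000
W = nn.Linear(hidden, hidden, bias=False)    # the fully connected layer applied to f_h
word_emb = nn.Embedding(vocab_size, hidden)  # word representations f_w

f_h = torch.randn(1, hidden)                 # context representation f_h (stand-in for an encoder output)
w = torch.tensor([42])                       # one candidate next word
s = (W(f_h) * word_emb(w)).sum(dim=-1)       # s_theta(w, h) = (f_h W) . f_w, shape (1,)
```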
Difficulty Analysis
Looking at equation (1): the numerator $\exp(s_{\theta}(w,h))$ is cheap to compute; for a single $w$ it takes just one evaluation.
The denominator $\sum_{w_i}{\exp(s_{\theta}(w_i,h))}$ is not: for a single $w$ we have to compute $\exp(s_{\theta}(w_1,h)), \exp(s_{\theta}(w_2,h)), \ldots, \exp(s_{\theta}(w_n,h))$, and with a large vocabulary that is a lot of computation.
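To get a feel for the cost, a rough sketch with toy sizes of my own choosing: computing the denominator of equation (1) means scoring every word in the vocabulary for every context.

```python
import torch

vocab_size, hidden = 100_000, 512
context_repr = torch.randn(1, hidden)              # f_h W for a single context
word_embeddings = torch.randn(vocab_size, hidden)  # f_w for every word in the vocabulary

scores = context_repr @ word_embeddings.T          # (1, vocab_size): 100k dot products per context
denominator = scores.exp().sum()                   # the normalizer of equation (1)
```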
Today, most of the work on very-large-scale classification, in both academia and industry, focuses on optimizing this denominator: InfoNCE only looks at the negative samples within the batch; KNN softmax clusters the classes to reduce the number of classes; Partial FC samples the classes and shards them evenly across GPU memory to reduce computation; Inf-CL borrows the idea behind FlashAttention and trades space against time.
Optimization Strategy
Since large-scale classification over all n words of the vocabulary is so hard to do, just flip the table and refuse to do it! Convert the original multi-class task into something easier to implement: a new binary classification task. Besides the normal, real data, sample noise data from a noise distribution and train a binary classifier to tell real data from noise data. One can show that the more noise data there is, the closer the converted task's optimization target gets to the original task's.
The New Binary Classification Task
Given a context $h$, there are now two data distributions: the real data distribution $P_d^h(w)$ (strictly this should be written $P_d(w|h)$; $P_d^h(w)$ is the shorthand) and the noise distribution $P_n(w)$, with real data and noise data mixed at a ratio of 1:k. The full distribution of the training data is therefore $P^h(w)=\frac{1}{k+1}P_d^h(w)+\frac{k}{k+1}P_n(w)$, and the training task is to predict $D=1$ (the sample is real) versus $D=0$ (the sample is noise).
We want to optimize the network parameters $\theta$ so as to fit the real data distribution, $P_d^h(w)=P^h_{\theta}(w)$, where $P^h_{\theta}(w)$ is the distribution the model learns. The full distribution of the training data can then be written as

$$P^h(w,\theta)=\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)$$
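A small Monte Carlo sketch of this mixture, with toy distributions of my own: each training sample is real with probability 1/(k+1) and noise otherwise, and the empirical frequencies should match $P^h(w)$.

```python
import torch

k = 8
p_data = torch.tensor([0.7, 0.2, 0.1])    # toy P_d^h over a 3-word vocabulary
p_noise = torch.tensor([1/3, 1/3, 1/3])   # toy P_n

n = 200_000
is_real = torch.rand(n) < 1.0 / (k + 1)                        # D = 1 with probability 1/(k+1)
real_draws = torch.multinomial(p_data, n, replacement=True)    # samples from P_d^h
noise_draws = torch.multinomial(p_noise, n, replacement=True)  # samples from P_n
samples = torch.where(is_real, real_draws, noise_draws)

empirical = torch.bincount(samples, minlength=3) / n
mixture = p_data / (k + 1) + k / (k + 1) * p_noise             # P^h(w)
print(empirical, mixture)   # the two vectors should be close
```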
The training objective is, as usual, to maximize the expected log-likelihood of the posterior $P^h(D|w,\theta)$, i.e. $E\left[\log(P^h(D|w,\theta))\right]$, so we need the posterior $P^h(D|w,\theta)$ itself.
$$P^h(D|w,\theta)=\begin{cases}P^h(D=1|w,\theta), & D=1\\ P^h(D=0|w,\theta), & D=0\end{cases}\tag{2}$$
The posterior probability for real data is:
$$\begin{aligned} P^h(D=1|w,\theta) &= \frac{P^h(w,\theta|D=1)}{P^h(w,\theta)}P^h(D=1) \\ &=\frac{P_{\theta}^h(w)}{\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)}\cdot\frac{1}{k+1} \\ &=\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)} \end{aligned}\tag{3}$$
Let's check why each step holds:
- Marginal probability: $P^h(w,\theta)=\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)$.
- Prior probability: $P^h(D=1)=\frac{1}{k+1}$, because real data and noise data are mixed at a ratio of 1:k.
- Likelihood: $P^h(w,\theta|D=1)=P^h_{\theta}(w)$, meaning that under the real data distribution, the probability that the next word drawn from the vocabulary is $w$ is $P^h_{\theta}(w)$; this is exactly the function we want to fit.
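A quick numeric sanity check of equation (3), with toy values I picked: the full Bayes computation and the simplified form give the same posterior.

```python
k = 8
p_theta = 0.05   # P_theta^h(w) for some word w
p_n = 0.01       # P_n(w)

marginal = p_theta / (k + 1) + k * p_n / (k + 1)      # P^h(w, theta)
posterior_bayes = p_theta / marginal * (1 / (k + 1))  # likelihood / evidence * prior
posterior_simple = p_theta / (p_theta + k * p_n)      # the simplified form in equation (3)
print(posterior_bayes, posterior_simple)              # both ~0.3846
```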
Similarly, the posterior probability for noise data is:
$$\begin{aligned} P^h(D=0|w,\theta) &= \frac{P^h(w,\theta|D=0)}{P^h(w,\theta)}P^h(D=0) \\ &=\frac{P_n(w)}{\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)}\cdot\frac{k}{k+1} \\ &=\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)} \end{aligned}\tag{4}$$
The expected log-likelihood of the posterior $P^h(D|w,\theta)$, $E\left[\log(P^h(D|w,\theta))\right]$, is (weighting the noise term by $k$, since each real sample comes with $k$ noise samples):
$$\begin{aligned} J^h(\theta)&=E\left[\log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[\log P^h(D=1|w,\theta)\right] + kE_{P_n}\left[\log P^h(D=0|w,\theta)\right] \\ &= E_{P_d^h}\left[\log\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\right] + kE_{P_n}\left[\log\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\right] \end{aligned}\tag{5}$$
Let's compute its gradient:
$$\begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= E_{P_d^h}\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}\log P_{\theta}^h(w)\right] - kE_{P_n}\left[\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}\log P_{\theta}^h(w)\right] \end{aligned}\tag{6}$$
Simplifying equation (6):
$$\begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= E_{P_d^h}\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}\log P_{\theta}^h(w)\right] - kE_{P_n}\left[\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}\log P_{\theta}^h(w)\right]\\ &=\sum_w\left[P_d^h(w)\cdot\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}\log P_{\theta}^h(w) - kP_{n}(w)\cdot\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}\log P_{\theta}^h(w) \right]\\ &=\sum_w\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\left(P_d^h(w)-P_{\theta}^h(w)\right)\frac{\partial}{\partial\theta}\log P_{\theta}^h(w) \right] \end{aligned}\tag{7}$$
When the amount of noise data is huge, $k\to \infty$ and $\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\to1$, so
$$\begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= \sum_w\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\left(P_d^h(w)-P_{\theta}^h(w)\right)\frac{\partial}{\partial\theta}\log P_{\theta}^h(w) \right]\\ &\to \sum_w\left[\left(P_d^h(w)-P_{\theta}^h(w)\right)\frac{\partial}{\partial\theta}\log P_{\theta}^h(w) \right] \end{aligned}\tag{8}$$
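A tiny illustration with toy values of why this limit matters: the weight $\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}$ in equation (8) approaches 1 as $k$ grows, so the NCE gradient approaches the maximum-likelihood gradient derived in the next section.

```python
p_theta, p_n = 0.05, 0.01     # toy P_theta^h(w) and P_n(w)
for k in (1, 10, 100, 1000):
    weight = k * p_n / (p_theta + k * p_n)
    print(k, round(weight, 4))
# 1 -> 0.1667, 10 -> 0.6667, 100 -> 0.9524, 1000 -> 0.995
```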
The Original Multi-class Task
Now let's compute the expected log-likelihood and the gradient of the original multi-class task, and see how the new binary task with $k\to \infty$ relates to it. The objective of the original multi-class task is
$$\begin{aligned} J^h(\theta)&=E_{P_d^h} \left[\log P_{\theta}^h(w)\right] \\ &= E_{P_d^h} \left[\log\left(\frac{\exp(s_{\theta}(w,h))}{\sum_w{\exp(s_{\theta}(w,h))}}\right)\right]\\ &=E_{P_d^h}\left[s_{\theta}(w,h)\right]-E_{P_d^h}\left[\log\left(\sum_w{\exp\left(s_{\theta}(w,h)\right)}\right)\right]\\ &=E_{P_d^h}\left[s_{\theta}(w,h)\right]-\log\left(\sum_w{\exp\left(s_{\theta}(w,h)\right)}\right) \end{aligned}\tag{9}$$
The last step holds because $\log\left(\sum_w{\exp\left(s_{\theta}(w,h)\right)}\right)$ depends only on the model's predicted distribution $P_{\theta}^h$, not on the real data distribution $P_d^h$, so it is a constant under the expectation.
Taking the gradient of equation (9):
$$\begin{aligned} \frac{\partial}{\partial\theta}J^h(\theta)&=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{\partial}{\partial\theta}\log\left(\sum_w{\exp\left(s_{\theta}(w,h)\right)}\right)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{1}{\sum_w{\exp\left(s_{\theta}(w,h)\right)}}\frac{\partial}{\partial\theta}\sum_w{\exp\left(s_{\theta}(w,h)\right)}\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{1}{\sum_w{\exp\left(s_{\theta}(w,h)\right)}}\sum_w\left(\exp\left(s_{\theta}(w,h)\right)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_w\frac{\exp\left(s_{\theta}(w,h)\right)}{\sum_w{\exp\left(s_{\theta}(w,h)\right)}}\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=\sum_wP_d^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=\sum_w\left(P_d^h(w)-P_{\theta}^h(w)\right)\frac{\partial}{\partial\theta}s_{\theta}(w,h) \end{aligned}\tag{10}$$

Compare equation (8) with equation (10): they look alike, but they are not identical. Equation (8) ends in $\frac{\partial}{\partial\theta}\log P_{\theta}^h(w)$, while equation (10) ends in $\frac{\partial}{\partial\theta}s_{\theta}(w,h)$. What is going on?
The difference is expected. In NCE we are allowed to treat $\sum_w{\exp\left(s_{\theta}(w,h)\right)}$ as equal to 1, which makes $\log P_{\theta}^h(w)=s_{\theta}(w,h)$, and then equations (8) and (10) coincide. Why is this allowed? The paper's argument is, roughly: the model has plenty of parameters, so the normalization term can be treated as a constant, and the other quantities in the formula, such as $s_{\theta}$, can learn to absorb it (the "normalization term" here is $\sum_w{\exp\left(s_{\theta}(w,h)\right)}$). Then it does not matter for convergence whether $\sum_w{\exp\left(s_{\theta}(w,h)\right)}$ equals 1 or 100; for simplicity, just take it to be 1.
That argument is still rather abstract. Is there a more intuitive explanation?
Why the Two Tasks Are Equivalent
The original multi-class task
$$\begin{aligned} J^h(\theta)&=E_{P_d^h} \left[\log P_{\theta}^h(w)\right] \\ &= E_{P_d^h} \left[\log\left(\frac{\exp(s_{\theta}(w,h))}{\sum_w{\exp(s_{\theta}(w,h))}}\right)\right] \end{aligned}\tag{11}$$
Equation (11) is the expected log-likelihood of this task. Keep the shape of the $\log$ curve in mind (monotonically increasing, unbounded above):
If $P_{\theta}^h(w)=\exp(s_{\theta}(w,h))\in[0,+\infty)$, i.e. the scores are left unnormalized, then $J^h(\theta)=E_{P_d^h} \left[\log P_{\theta}^h(w)\right]$ is unbounded above: it has no maximum and cannot converge.
If instead we normalize, $P_{\theta}^h(w)=\frac{\exp(s_{\theta}(w,h))}{\sum_w{\exp(s_{\theta}(w,h))}}\in(0,1)$, then $\log P_{\theta}^h(w)<0$ and $J^h(\theta)=E_{P_d^h} \left[\log P_{\theta}^h(w)\right]$ is bounded above, so a maximum exists and training can converge.
The new binary classification task
From equation (5),
$$\begin{aligned} J^h(\theta)&=E \left[\log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[\log P^h(D=1|w,\theta)\right] + kE_{P_n}\left[\log P^h(D=0|w,\theta)\right] \\ &= E_{P_d^h}\left[\log\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\right] + kE_{P_n}\left[\log\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\right] \\ &= E_{P_d^h}\left[\log\sigma({\Delta})\right] + kE_{P_n}\left[\log(1-\sigma({\Delta}))\right] \end{aligned}\tag{12}$$
where $\Delta=\log P_{\theta}^h(w)-\log kP_n(w)$. We rewrite equation (5) into the $\sigma$ form of equation (12) because it makes differentiation convenient: $\frac{\partial}{\partial x}\sigma(x)=\sigma(x)(1-\sigma(x))$. The derivation from (5) to (12) goes as follows:
$$\begin{aligned} \frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}&=\frac{1}{1+\frac{kP_n(w)}{P_{\theta}^h(w)}}\\ &=\frac{1}{1+\exp\left(\log\frac{kP_n(w)}{P_{\theta}^h(w)}\right)}\\ &=\frac{1}{1+\exp\left(\log kP_n(w)-\log P_{\theta}^h(w)\right)}\\ &=\frac{1}{1+\exp\left(-(\log P_{\theta}^h(w)-\log kP_n(w))\right)}\\ &=\sigma\left(\log P_{\theta}^h(w)-\log kP_n(w)\right) \end{aligned}\tag{13}$$
$$\begin{aligned} \frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}&=1-\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\\ &=1-\sigma\left(\log P_{\theta}^h(w)-\log kP_n(w)\right) \end{aligned}\tag{14}$$
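A numeric check with toy values that the rewrite in equations (13) and (14) is just the logistic sigmoid of the log-ratio:

```python
import math

k, p_theta, p_n = 8, 0.05, 0.01
direct = p_theta / (p_theta + k * p_n)         # left-hand side of equation (13)
delta = math.log(p_theta) - math.log(k * p_n)  # Delta = log P_theta - log(k P_n)
via_sigmoid = 1 / (1 + math.exp(-delta))       # sigma(Delta)
print(direct, via_sigmoid)                     # both ~0.3846
```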
Now take the first derivative of the expected log-likelihood (equation (12)) with respect to $\log P_{\theta}^h(w)$:
$$\begin{aligned} \frac{\partial J^h(\theta)}{\partial \log P_{\theta}^h(w)} &=\frac{\partial J^h(\theta)}{\partial \Delta}\frac{\partial \Delta}{\partial \log P_{\theta}^h(w)} =\frac{\partial J^h(\theta)}{\partial \Delta}\\ &=\frac{\partial }{\partial \Delta}\left\{E_{P_d^h}\left[\log\sigma({\Delta})\right] +kE_{P_n}\left[\log(1-\sigma({\Delta}))\right]\right\}\\ &=E_{P_d^h}\left[1-\sigma({\Delta})\right] +kE_{P_n}\left[-\sigma({\Delta})\right]\\ &=\sum_w\left[P_d^h(w)\left(1-\sigma({\Delta})\right)-kP_n(w)\,\sigma({\Delta})\right] \end{aligned}\tag{15}$$
Each summand can be rewritten as $\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\left(P_d^h(w)-P_{\theta}^h(w)\right)$, so if $P_{\theta}^h(w)=P_d^h(w)$ the derivative vanishes and the expected log-likelihood reaches its maximum. (This is almost a tautology: the training goal is exactly $P_{\theta}^h(w)\to P_d^h(w)$, and at the start of the Optimization Strategy section we already set $P_{\theta}^h(w)= P_d^h(w)$.) Here $P_d^h(w)$ denotes the real data distribution.
Next, take the second derivative of the expected log-likelihood (equation (12)) with respect to $\log P_{\theta}^h(w)$:
$$\begin{aligned} \frac{\partial^2 J^h(\theta)}{\partial \left(\log P_{\theta}^h(w)\right)^2} &=\frac{\partial^2J^h(\theta)}{\partial \Delta^2}\\ &=\frac{\partial}{\partial \Delta} \left\{E_{P_d^h}\left[1-\sigma({\Delta})\right] +kE_{P_n}\left[-\sigma({\Delta})\right] \right\} \\ &= E_{P_d^h}\left[\frac{\partial}{\partial \Delta}\left(1- \sigma({\Delta})\right)\right] +kE_{P_n}\left[\frac{\partial}{\partial \Delta}\left(-\sigma({\Delta})\right)\right] \\ &= E_{P_d^h}\left[-\sigma(\Delta)(1-\sigma(\Delta))\right] +kE_{P_n}\left[-\sigma(\Delta)(1-\sigma(\Delta))\right] \end{aligned}\tag{16}$$
Since $-\sigma(\Delta)(1-\sigma(\Delta))$ is always negative, the second derivative is always negative, which means the expected log-likelihood of the new binary task is a concave function of $\log P_{\theta}^h(w)$ with a unique maximum. That maximum must therefore be at $P_{\theta}^h(w)=P_d^h(w)$.
Most importantly, nothing in this derivation requires $P_{\theta}^h$ to be normalized. Since normalization is not required, we can simply set $\sum_w{\exp\left(s_{\theta}(w,h)\right)}=1$, so that $\log P_{\theta}^h(w)=s_{\theta}(w,h)$.
Code Implementation
From equation (12), we know:
$$\Delta=\log P_{\theta}^h(w)-\log kP_n(w)$$
$$\begin{aligned} J^h(\theta)&=E \left[\log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[\log\sigma({\Delta})\right] +kE_{P_n}\left[\log(1-\sigma({\Delta}))\right] \\ &= E_{P_d^h}\left[\log\sigma\left(\log P_{\theta}^h(w)-\log kP_n(w)\right)\right] + kE_{P_n}\left[\log\left(1-\sigma\left(\log P_{\theta}^h(w)-\log kP_n(w)\right)\right)\right] \\ &= \sum_w P_d^h(w)\log\sigma\left(\log P_{\theta}^h(w)-\log kP_n(w)\right) + k\sum_w P_n(w)\log\left(1-\sigma\left(\log P_{\theta}^h(w)-\log kP_n(w)\right)\right) \\ &\to \log\sigma\left(\log P_{\theta}^h(w_0)-\log kP_n(w_0)\right) + \sum_{i=1}^k\log\left(1-\sigma\left(\log P_{\theta}^h(w_i)-\log kP_n(w_i)\right)\right) \\ &=\log\sigma\left(s_{\theta}(w_0,h)-\log kP_n(w_0)\right) + \sum_{i=1}^k\log\left(1-\sigma\left(s_{\theta}(w_i,h)-\log kP_n(w_i)\right)\right) \end{aligned}\tag{17}$$
In the concrete implementation, the positive term keeps only the target class $w_0$, and the negative term draws $k$ random classes; the expectations are approximated by Monte Carlo sampling.
So how should the final loss function be written in code?
$$\begin{aligned} \mathrm{loss} &= -J^h(\theta) \\ &=-\log\sigma\left(s_{\theta}(w_0,h)-\log kP_n(w_0)\right) - \sum_{i=1}^k\log\left(1-\sigma\left(s_{\theta}(w_i,h)-\log kP_n(w_i)\right)\right) \end{aligned}\tag{18}$$
Equation (18) takes four inputs:
- $s_{\theta}(w_0,h)$: the logit of the target class
- $P_n(w_0)$: the noise probability of the target class
- $s_{\theta}(w_i,h)$: the logits of the sampled noise classes
- $P_n(w_i)$: the noise probabilities of the sampled noise classes
```python
import torch

bs, k = 2, 8          # batch size, number of noise samples per example
num_classes = 16

# Build the noise distribution: sample classes according to their frequency.
# (The noise distribution should be roughly the real data distribution;
#  the closer the two are, the better NCE works.)
class_freq = torch.tensor([20, 10, 30, 5, 45, 56, 76, 43, 23, 11, 34, 5, 6, 54, 23, 7],
                          dtype=torch.float)
class_probs = class_freq / class_freq.sum()   # P_n(w)
# k noise classes per example, sampled with replacement
noise_classes = torch.multinomial(class_probs, bs * k, replacement=True).reshape(bs, k)

# Model logits s_theta(w, h) for each example (random stand-in for a real model)
logits = torch.randn(bs, num_classes)
# Labels of the 2 examples (the target classes w_0)
labels = torch.tensor([2, 4])

# Logit of the target class: s_theta(w_0, h)
true_class_logits = logits.take_along_dim(labels[:, None], dim=1).squeeze(1)  # (bs,)
# Noise probability of the target class: P_n(w_0)
true_class_noise = class_probs[labels]                                        # (bs,)

# Logits of the noise classes: s_theta(w_i, h)
noise_class_logits = logits.take_along_dim(noise_classes, dim=1)              # (bs, k)
# Noise probabilities of the noise classes: P_n(w_i)
noise_class_noise = class_probs[noise_classes]                                # (bs, k)

# NCE loss, following equation (18): positive term averaged over the batch,
# noise term summed over the k noise samples and averaged over the batch
true_class_loss = -torch.log(torch.sigmoid(true_class_logits - torch.log(k * true_class_noise))).mean()
noise_class_loss = -torch.log(1 - torch.sigmoid(noise_class_logits - torch.log(k * noise_class_noise))).sum(dim=1).mean()
loss = true_class_loss + noise_class_loss
print("nce loss is {:.4f}".format(loss.item()))
```