目录
- 0、基本信息
- 1、研究动机
- 2、创新点
- 2.1、核心思想:
- 2.2、思想推导:
- 3、准备
- 3.1、符号
- 3.2、互信息
- 3.3、JS散度
- 3.4、Deep InfoMax方法
- 3.5、判别器:f-GAN估计散度
- 4、具体实现
- 4.1、局部-全局互信息最大化
- 4.2、理论动机
- 5、实验设置
- 5.1、直推式学习(Cora,Citeseer and Pubmed)
- 6、代码实现
- 6.1、DGI
- 6.2、GCNLayer
- 6.3、readout function
- 6.4、discriminator
- 7、参考链接
0、基本信息
- 会议:2019-ICLR
- 作者:Petar Veliˇckovi´,William Fedus
- 文章链接:Deep Graph Infomax
- 代码链接:Deep Graph Infomax
1、研究动机
(1)无监督图学习的重要性
尽管图神经网络取得了显著的进步,但是大多数方法采用监督学习的方法,然而,在真实世界中,图的标签是较少的,这些方法很难推广到大量的无标签的图数据中。因此,对于很多重要任务而言就显得不可或缺。
(2)现有方法的缺点
目前,主流的用于图结构数据表征学习的无监督算法,random walk-based objectives,有时简化为重构邻域信息,基本思想为训练编码器使输入图中的接近的结点在表征空间中也接近。但也存在着如下的缺点:
- random walk-based objectives以牺牲结构信息为代价过度强调邻近信息
- 性能很大程度上取决于超参数的选择
- 基于图卷积编码器模型的引入,不清楚random-walk objectives是否提供了有用信号。
基于上述的缺点与不足,本文提出了用于无监督图学习的替代目标,其基于互信息而不是随机游走。将Deep InfoMax(DIM)引入图结构数据,提出Deep Graph Infomax(DGI)模型。
在概率论和信息论中,两个随机变量的互信息(Mutual Information,简称MI)是指变量间相互依赖性的量度,度量两个事件集合之间的相关性(mutual dependence)。
2、创新点
- 将DIM引入图领域。
2.1、核心思想:
训练一个编码器,它的目标函数,不是最小化输入与输出的MSE,而是最大化输入与输出的互信息。
重构误差小,不能说明学习出来的特征好,好特征应该是能够提取出样本的最独特,具体的信息。那如何衡量学习出来的信息是该样本独特的呢?这里就是用“互信息”来衡量。
2.2、思想推导:
->1、首先,我们已知要用互信息来衡量学习特征的好坏,也就是说,最大化互信息是我们的目标。
->2、最大化互信息可以转化为最大化JS散度
->3、JS散度计算困难,通过“局部变分技巧”快速估算,得到JS散度的估算值,即“负采样估计”。
->4、“负采样估计”——引入一个判别网络
σ
(
T
(
x
,
z
)
)
\sigma(T(x,z))
σ(T(x,z)),
x
x
x 及其对应的
z
z
z 视为一个正样本对,
x
x
x 及随机抽取的
z
^
\hat{z}
z^ 则视为负样本,然后最大化似然函数,等价于最小化交叉熵,得到最终的优化目标函数。
->5、接下来,定义细节内容。。。
3、准备
3.1、符号
- 结点特征: X = { x ⃗ 1 , x ⃗ 2 , … , x ⃗ N } \mathbf{X}=\{\vec{x}_{1},\vec{x}_{2},\ldots,\vec{x}_{N}\} X={x1,x2,…,xN}, N N N是图中结点的个数, X ⃗ i ∈ R F \vec{X}_i\in \mathbb{R}^F Xi∈RF表示结点i的特征。
- 邻接矩阵(无权图): A ∈ R N × N A\in \mathbb{R}^{N \times N} A∈RN×N,如果结点i与结点j相连 A i j = 1 A_{ij}=1 Aij=1,否则 A i j = 0 A_{ij}=0 Aij=0
- 编码器: E : R N × F × R N × N → R N × F ′ \mathcal{E}:\mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \to \mathbb{R}^{N \times F'} E:RN×F×RN×N→RN×F′, E ( X , A ) = H = { h ⃗ 1 , h ⃗ 2 , … , h ⃗ N } \mathcal{E}(X,A)=H=\{\vec{h}_{1},\vec{h}_{2},\ldots,\vec{h}_{N}\} E(X,A)=H={h1,h2,…,hN},表示结点 i i i的高阶表征 h ⃗ i ∈ R F ′ \vec{h}_{i}\in \mathbb{R}^{F'} hi∈RF′,称为patch representions
- Readout Function: R : R N × F → R F \mathcal{R}:\mathbb{R}^{N \times F} \to \mathbb{R}^F R:RN×F→RF, s ⃗ = R ( E ( X , A ) ) \vec{s}=\mathcal{R}(\mathcal{E}(X,A)) s=R(E(X,A))为图级别的表示
- 判别器: D : R F × R F → R \mathcal{D}:\mathbb{R}^{F} \times \mathbb{R}^F \to \mathbb{R} D:RF×RF→R, D ( h ⃗ i , s ⃗ ) \mathcal{D}(\vec{h}_i,\vec{s}) D(hi,s)表示分配给该patch-summary对的概率分数
- Corruption Function : C : R N × F × R N × N → R M × F × R M × M \mathcal{C}:\mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \to\mathbb{R}^{M \times F} \times \mathbb{R}^{M \times M} C:RN×F×RN×N→RM×F×RM×M,从原始图获得一个负例样本,即 ( X ~ , A ~ ) = C ( X , A ) (\tilde{X},\tilde{A})=\mathcal{C}(X,A) (X~,A~)=C(X,A).
3.2、互信息
互信息(Mutual Information)是度量两个事件集合之间的相关性(mutual dependence),它是信息论里一种有用的信息度量,它可以看成是一个随机变量中包含的关于另一个随机变量的信息量,或者说是一个随机变量由于已知另一个随机变量而减少的不肯定性。互信息最常用的单位是bit。
互信息指的是两个随机变量之间的关联程度,即给定一个随机变量后,另一个随机变量不确定性的削弱程度,因而互信息取值最小为0,意味着给定一个随机变量对确定一另一个随机变量没有关系,最大取值为随机变量的熵,意味着给定一个随机变量,能完全消除另一个随机变量的不确定性。
3.3、JS散度
JS散度度量了两个概率分布的相似度,基于KL散度(相对熵)的变体,解决了KL散度非对称的问题。一般地,JS散度是对称的,其取值是0到1之间。
设概率空间上有两个概率分布
P
P
P和
Q
Q
Q,
M
=
1
2
(
P
+
Q
)
M=\frac{1}{2}(P+Q)
M=21(P+Q)为
P
P
P和
Q
Q
Q的平均,则
P
P
P和
Q
Q
Q的JS散度定义为:
J
S
(
P
∣
∣
Q
)
=
1
2
D
K
L
(
P
∣
∣
M
)
+
1
2
D
K
L
(
Q
∣
∣
M
)
JS(P||Q)=\frac{1}{2}D_{KL}(P||M)+\frac{1}{2}D_{KL}(Q||M)
JS(P∣∣Q)=21DKL(P∣∣M)+21DKL(Q∣∣M)
其中,
D
K
L
D_{KL}
DKL表示
K
L
KL
KL 散度,定义如下:
D
K
L
(
P
∣
∣
Q
)
=
∑
x
∈
X
P
(
x
)
l
o
g
(
P
(
x
)
Q
(
x
)
)
D_{KL}(P||Q)=\sum_{x\in X}P(x)log(\frac{P(x)}{Q(x)})
DKL(P∣∣Q)=x∈X∑P(x)log(Q(x)P(x))
3.4、Deep InfoMax方法
许多表示学习算法使用像素级的训练目标,当只有一小部分信号在语义层面上起作用时是不利的。Bengio 等研究者假设应该更直接地根据信息内容和统计或架构约束来学习表示,据此提出了 Deep InfoMax(DIM)。该方法可用于学习期望特征的表示,并且在分类任务上优于许多流行的无监督学习方法。
互信息是概率论和信息论中重要的内容,它表示的是一个随机变量中包含另一个随机变量的信息量,可以理解成两个随机变量之间的相关程度。最大化互信息,也就是说对于每个输入样本x,编码器能够尽可能地找出专属于样本x的特征y。因此,这样一来,只通过特征y,也能很好地分辨出原始样本来(因为学习到的特征含有样本的独特信息)。
于是,最大化互信息的计算过程,有了以下的转换:
I ( X ; Y ) = ∫ Y ∫ X p ( x , y ) log p ( x , y ) p ( x ) p ( y ) d x d y = ∫ Y ∫ X p ( y ∣ x ) p ( x ) log p ( y ∣ x ) p ( y ) d x d y \begin{aligned} I(X ; Y) &=\int_{Y} \int_{X} p(x, y) \log \frac{p(x, y)}{p(x) p(y)} d x d y \\ &=\int_{Y} \int_{X} p(y | x) p(x) \log \frac{p(y | x)}{p(y)} d x d y \end{aligned} I(X;Y)=∫Y∫Xp(x,y)logp(x)p(y)p(x,y)dxdy=∫Y∫Xp(y∣x)p(x)logp(y)p(y∣x)dxdy
等价于
I
(
X
;
Y
)
=
K
L
(
p
(
x
,
y
)
∥
p
(
x
)
p
(
y
)
)
I(X ; Y)=K L(p(x, y) \| p(x) p(y))
I(X;Y)=KL(p(x,y)∥p(x)p(y))
最大化互信息,就是要拉大联合分布与边缘分布乘积的距离。
KL散度无上界,利用JS散度与KL散度之间的转换关系:
$$\begin{array}{l}
I(X ; Y) \propto J S(p(x, y), p(x) p(y))
\end{array}$$
此时就可以将最大化互信息这个问题转换成最大化JS散度。
这个过程详细可以看MINE-Mutual Information Neural Estimation
Deep InfoMax是在论文Learning deep representations by mutual information estimation and maximization中提出。
3.5、判别器:f-GAN估计散度
在机器学习中,计算两个概率分布P,Q的散度是有一定难度的,因为很多时候是无法知道两个概率分布的解析形式,或者分布只有采样出来的样本(这时就是比较两批样本之间的相似性)。
f-GAN是通过“局部变分技巧“来进行快速地估算。
D
f
(
P
∥
Q
)
=
max
T
(
E
x
∼
p
(
x
)
[
T
(
x
)
]
−
E
x
∼
q
(
x
)
[
g
(
T
(
x
)
)
]
)
\mathbb{D}_{\mathbf{f}}(\mathbf{P} \| \mathbf{Q})=\max _{\mathbf{T}}\left(\mathbb{E}_{\mathbf{x} \sim \mathbf{p}(\mathbf{x})}[\mathbf{T}(\mathbf{x})]-\mathbb{E}_{\mathbf{x} \sim \mathbf{q}(\mathbf{x})}[\mathbf{g}(\mathbf{T}(\mathbf{x}))]\right)
Df(P∥Q)=Tmax(Ex∼p(x)[T(x)]−Ex∼q(x)[g(T(x))])
分别从两个分布
P
\mathbf{P}
P和
Q
\mathbf{Q}
Q进行采样,然后计算
T
(
x
)
T(x)
T(x)与
g
(
T
(
x
)
)
g(T(x))
g(T(x))的平均值,优化
T
T
T,使得它们的差最大,最终的结果即为散度的估算值。T(x)可以用足够复杂的神经网络去拟合。
因此,最大化互信息的目标函数为:
J
S
(
p
(
x
,
y
)
,
p
(
x
)
p
(
y
)
)
=
max
D
(
E
(
x
,
y
)
∼
p
(
x
,
y
)
[
log
σ
(
D
(
x
,
y
)
)
]
+
E
x
~
∼
p
(
x
)
,
y
~
∼
p
(
y
)
[
log
(
1
−
σ
(
D
(
x
~
,
y
~
)
)
)
]
)
\mathbf{J S}(\mathbf{p}(\mathbf{x}, \mathbf{y}), \mathbf{p}(\mathbf{x}) \mathbf{p}(\mathbf{y}))=\max _{\mathbf{D}}\left(\mathbb{E}_{(\mathbf{x}, \mathbf{y}) \sim \mathbf{p}(\mathbf{x}, \mathbf{y})}[\log \sigma(\mathbf{D}(\mathbf{x}, \mathbf{y}))]+\mathbb{E}_{\tilde{\mathbf{x}} \sim \mathbf{p}(\mathbf{x}), \tilde{y} \sim \mathbf{p}(\mathbf{y})}[\log (1-\sigma(\mathbf{D}(\tilde{\mathbf{x}}, \tilde{\mathbf{y}})))]\right)
JS(p(x,y),p(x)p(y))=Dmax(E(x,y)∼p(x,y)[logσ(D(x,y))]+Ex~∼p(x),y~∼p(y)[log(1−σ(D(x~,y~)))])
这个公式实际上就是“负采样估计”:引入一个判别网络 σ(D(x,y)),x 及其对应的 y 视为一个正样本对,x 及随机抽取的 y 则视为负样本,然后最大化似然函数(等价于最小化交叉熵)。
4、具体实现
4.1、局部-全局互信息最大化
(1)、利用
r
e
a
d
o
u
t
f
u
n
c
t
i
o
n
readout\;\; function
readoutfunction将patch representions即
E
(
X
,
A
)
\mathcal{E}(X,A)
E(X,A),转化为图的全局信息
s
y
m
m
a
r
y
v
e
c
t
o
r
s
⃗
symmary \;\;vector\;\; \vec{s}
symmaryvectors,即
s
⃗
=
R
(
E
(
X
,
A
)
)
\vec{s}=\mathcal{R}(\mathcal{E}(X,A))
s=R(E(X,A));
(2)、采用判别器,
D
:
R
F
×
R
F
→
R
\mathcal{D}:\mathbb{R}^{F} \times \mathbb{R}^F \to \mathbb{R}
D:RF×RF→R,作为最大化局部互信息的近似,
D
(
h
⃗
i
,
s
⃗
)
\mathcal{D}(\vec{h}_i,\vec{s})
D(hi,s)表示分配给该patch-summary对的概率分数,若patch在summary内,则得分越高。
(3)、负样本。在多图情形下,可以直接从训练集的另一个图中选择一个图即可;在单图情况下,构造Corruption Function :
C
:
R
N
×
F
×
R
N
×
N
→
R
M
×
F
×
R
M
×
M
\mathcal{C}:\mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \to\mathbb{R}^{M \times F} \times \mathbb{R}^{M \times M}
C:RN×F×RN×N→RM×F×RM×M,从原始图获得一个负例样本,即
(
X
~
,
A
~
)
=
C
(
X
,
A
)
(\tilde{X},\tilde{A})=\mathcal{C}(X,A)
(X~,A~)=C(X,A)。负样本的选择过程决定了我们要捕获特定类型的结构信息。
(4)、判别器
D
\mathcal{D}
D,可以通过将
s
⃗
\vec{s}
s与另一个图
(
X
~
,
A
~
)
(\tilde{X},\tilde{A})
(X~,A~)中的patch representations
h
~
⃗
j
\vec{\tilde{h}}_j
h~j得到;
(5)、本文使用与DIM一致的、使用噪音对比类型的目标函数,以联合分布(正样本)与边缘分布之积的标准二值交叉熵作为损失函数:
L
=
1
N
+
M
(
∑
i
=
1
N
E
(
X
,
A
)
[
log
D
(
h
⃗
i
,
s
⃗
)
]
+
∑
j
=
1
M
E
(
X
~
,
A
~
)
[
log
(
1
−
D
(
h
~
⃗
j
,
s
⃗
)
)
]
)
\mathcal{L}=\frac{1}{N+M}\left(\sum_{i=1}^{N}\mathbb{E}_{(\mathbf{X},\mathbf{A})}\left[\log\mathcal{D}\left(\vec{h}_i,\vec{s}\right)\right]+\sum_{j=1}^{M}\mathbb{E}_{(\tilde{\mathbf{X}},\tilde{\mathbf{A}})}\left[\log\left(1-\mathcal{D}\left(\vec{\widetilde{h}}_j,\vec{s}\right)\right)\right]\right)
L=N+M1(i=1∑NE(X,A)[logD(hi,s)]+j=1∑ME(X~,A~)[log(1−D(h
j,s))])
基于联合分布和边缘分布之积的JS散度,可以有效地最大化 h ⃗ i \vec{h}_i hi和 s ⃗ \vec{s} s的互信息。
4.2、理论动机
(1)引理1
给定K个图,每个图的结点表示的集合为
X
(
k
)
X^{(k)}
X(k),每个图从分布
p
(
X
)
p(X)
p(X)中被选中的概率是均匀的,那么,联合概率与边缘概率之间的最优分类器在类平衡的条件下,错误的上限是:
E
r
r
∗
=
1
2
∑
k
=
1
∣
X
∣
p
(
s
⃗
k
)
2
\mathrm{Err}^* = \frac{1}{2}\sum_{k=1}^{|X|}p(\vec{s}^{k})^2
Err∗=21k=1∑∣X∣p(sk)2
(2)推论1
假设readout函数是单射的,
∣
s
⃗
∣
>
=
∣
X
∣
|\vec{s}|>=|X|
∣s∣>=∣X∣,则对于
s
⃗
∗
\vec{s}^*
s∗在联合分布和边际分布之积间最优分类器分类误差下的最优summary,存在
∣
s
⃗
∗
∣
=
∣
X
∣
|\vec{s}^*|=|X|
∣s∗∣=∣X∣
(3)定理1
∣
s
⃗
∗
∣
=
a
r
g
m
i
n
s
⃗
M
I
(
X
;
s
⃗
)
|\vec{s}^*|=argmin_{\vec{s}} MI(X;\vec{s})
∣s∗∣=argminsMI(X;s),MI表示互信息.
定理1表明,最小化判别器的分类误差可以最大化输入与 输出之间的互信息。
(4)定理2
假设
∣
X
i
∣
=
∣
X
∣
=
∣
s
⃗
∣
>
=
∣
h
⃗
i
∣
|X_i|=|X|=|\vec{s}|>=|\vec{h}_i|
∣Xi∣=∣X∣=∣s∣>=∣hi∣,则能最小化
p
(
h
⃗
)
i
,
s
⃗
)
p(\vec{h})_i,\vec{s})
p(h)i,s)与
p
(
h
⃗
i
)
p
(
s
⃗
)
p(\vec{h}_i)p(\vec{s})
p(hi)p(s)的分类误差可以最大化
M
I
(
X
(
k
)
,
h
⃗
i
)
MI(X^{(k)},\vec{h}_i)
MI(X(k),hi)。
4.2、DGI的整体架构
- 通过corruption function采样一个负样本: ( X ~ , A ~ ) ∼ C ( X , A ) (\tilde{X},\tilde{A})\sim\mathcal{C}(X,A) (X~,A~)∼C(X,A)
- 通过编码器获得正例样本的patch representations: H = E ( X , A ) = { h ⃗ 1 , h ⃗ 2 , … , h ⃗ N } \mathrm{H}=\mathcal{E}(\mathrm{X},\mathrm{A})=\{\vec{h}_{1},\vec{h}_{2},\ldots,\vec{h}_{N}\} H=E(X,A)={h1,h2,…,hN}
- 通过编码器获得负例样本的patch representations: H ~ = E ( X ~ , A ~ ) = { h ~ ⃗ 1 , h ~ ⃗ 2 , … , h ~ ⃗ N } \mathrm{\tilde{H}}=\mathcal{E}(\mathrm{\tilde{X}},\mathrm{\tilde{A}})=\{\vec{\tilde{h}}_{1},\vec{\tilde{h}}_{2},\ldots,\vec{\tilde{h}}_{N}\} H~=E(X~,A~)={h~1,h~2,…,h~N}
- 将patch represents输入到readout function 计算图的全局信息 s ⃗ = R ( H ) \vec{s}=\mathcal{R}(\mathrm{H}) s=R(H)
- 通过梯度下降算法最大化损失函数,更新 E , R , D \mathcal{E},\mathcal{R},\mathcal{D} E,R,D的参数
5、实验设置
DGI以完全无监督的方式学习patch represents,然后直接使用这些表征来训练和测试简单的线性分类器,来评估这些表征得节点级分类的效果。
这里仅仅列出直推式学习的实验设置。
5.1、直推式学习(Cora,Citeseer and Pubmed)
(1)、编码器(encoder)
由一层的图卷积网络(GCN)组成,如下形式:
E
(
X
,
A
)
=
σ
(
D
^
−
1
2
A
^
D
^
−
1
2
X
Θ
)
\mathcal{E}(X,A)=\sigma(\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}}X\Theta)
E(X,A)=σ(D^−21A^D^−21XΘ)
A
^
=
A
+
I
N
\hat{A}=A+I_N
A^=A+IN是带有自环的邻接矩阵,
D
^
\hat{D}
D^是对应的度矩阵,
σ
\sigma
σ是参数化的ReLU(即PReLU),
Θ
\Theta
Θ为可学习的线性变换参数。
(2)、corruption function
在直推式学习任务中,corruption function旨在使图中不同结点的结构相似性进行正确编码,故保留原始的邻接矩阵,并对原始特征矩阵按行随机排列,即
(
X
~
,
A
~
)
=
(
s
h
u
f
f
l
e
(
X
)
,
A
)
(\tilde{X},\tilde{A})=(shuffle(X),A)
(X~,A~)=(shuffle(X),A).
(3)、readout function
R
(
H
)
=
σ
(
1
N
∑
i
=
1
N
h
⃗
i
)
\mathcal{R}(H)=\sigma(\frac{1}{N}\sum_{i=1}^N\vec{h}_i)
R(H)=σ(N1i=1∑Nhi)
σ
\sigma
σ为sigmoid函数。
(4)、discriminator function
D
(
h
⃗
i
,
s
⃗
)
=
σ
(
h
⃗
i
T
W
s
⃗
)
\mathcal{D}(\vec{h}_i,\vec{s})=\sigma(\vec{h}_i^TW\vec{s})
D(hi,s)=σ(hiTWs)
W
\mathbf{W}
W是一个可学习的得分矩阵,
σ
\sigma
σ为sigmoid激活函数,用来将得分转化为
(
h
⃗
i
,
s
⃗
)
(\vec{h}_i,\vec{s})
(hi,s)的概率。
6、代码实现
完整代码链接
链接:https://pan.baidu.com/s/1JyWhR1LP0Sdzhpl25SSjXw
提取码:6666
6.1、DGI
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.readout import readout
from layers.GCNLayer import GCNLayer
from layers.discriminator import discriminator
class DGI(nn.Module):
def __init__(self,infeat,hidfeat,activation='prelu', ) -> None:
super(DGI,self).__init__()
self.gcn = GCNLayer(infeat,hidfeat,activation)
self.readout = readout()
self.disc = discriminator(hidfeat)
def forward(self,g,x1,x2):
h1 = self.gcn(x1,g)
h2 = self.gcn(x2,g)
s = self.readout(h1)
res = self.disc(s,h1,h2)
return res
def embed(self,x,g):
h = self.gcn(x,g)
s = self.readout(h)
return h.detach(),s.detach()
6.2、GCNLayer
import torch.nn as nn
import torch
class GCNLayer(nn.Module):
def __init__(self, infeat,outfeat,activation,bias=True) -> None:
super(GCNLayer,self).__init__()
self.layer = nn.Linear(infeat,outfeat,bias=False)
self.activation = nn.PReLU()
if bias:
self.bias = nn.Parameter(torch.FloatTensor(outfeat))
self.bias.data.fill_(0.0)
else:
self.register_parameter('bias', None)
for m in self.modules():
self.weights_init(m)
def weights_init(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0.0)
def forward(self,x,g):
out = self.layer(x)
out = torch.spmm(g,out)
if self.bias is not None:
out = out + self.bias
return self.activation(out)
6.3、readout function
import torch
import torch.nn as nn
class readout(nn.Module):
def __init__(self, ) -> None:
super(readout,self).__init__()
self.act = nn.Sigmoid()
def forward(self,seq):
return self.act(torch.mean(seq,dim=1))
6.4、discriminator
import torch
import torch.nn as nn
class discriminator(nn.Module):
def __init__(self,hidfeat) -> None:
super(discriminator,self).__init__()
self.bidlinear = nn.Bilinear(hidfeat,hidfeat,1)
#self.act = nn.Sigmoid()
for m in self.modules():
self.weights_init(m)
def weights_init(self, m):
if isinstance(m, nn.Bilinear):
torch.nn.init.xavier_uniform_(m.weight.data)
if m.bias is not None:
m.bias.data.fill_(0.0)
def forward(self,s,h1,h2):
s = torch.unsqueeze(s, 1)
s = s.expand_as(h1)
dis_1 = self.bidlinear(h1,s)
#dis_1 = self.act(dis_1)
dis_2 = self.bidlinear(h2,s)
#dis_2 = self.act(dis_2)
logits = torch.cat([dis_1,dis_2],dim=0)#.transpose(1,0)
return logits
7、参考链接
参考链接1
参考链接2
参考链接3
参考链接4