运动想象迁移学习系列:SSMT
- 0. 引言
- 1. 主要贡献
- 2. 网络结构
- 3. 算法
- 4. 补充
- 4.1 为什么设置一种新的适配器?
- 4.2 动态加权融合机制究竟是干啥的?
- 5. 实验结果
- 6. 总结
- 欢迎来稿
论文地址:https://link.springer.com/article/10.1007/s11517-024-03032-z
论文题目:Semi-supervised multi-source transfer learning for cross-subject EEG motor imagery classification
论文代码:无
0. 引言
脑电图(EEG)运动意象(MI)分类是指利用脑电信号对受试者的运动意象活动进行识别和分类;随着脑机接口(BCI)的发展,这项任务越来越受到关注。然而,脑电图数据的收集通常是耗时且劳动密集型的,这使得很难从新受试者那里获得足够的标记数据来训练新模型。此外,不同个体的脑电信号表现出显着差异,导致在直接对从新受试者获得的脑电信号进行分类时,在现有受试者上训练的模型的性能显着下降。因此,充分利用现有受试者的脑电数据和新目标受试者的未标记脑电数据,提高目标受试者达到的心肌梗死分类性能至关重要。本研究提出了一种半监督多源迁移(SSMT)学习模型来解决上述问题;该模型学习信息和域不变表示,以解决跨主题的 MI-EEG 分类任务。具体而言,该文提出了一种动态转移加权模式,通过整合从多源域派生的加权特征来获得最终预测。
文中主要解决方法是针对无监督的脑电数据迁移学习方案,是一个不错的角度,也提出了很有新意的算法设计!!!
1. 主要贡献
- 一种基于
MMD
和CMMD
的域适应方法
,用于解决单个 MI-EEG信号差异
的问题,对齐每个源域和靶域之间的条件和边际分布差异。此外,伪标签被应用于目标域的未标记数据,并在整个训练过程中迭代更新。通过这种方式,条件分布信息将更新为近似真实的条件分布。 - 基于域间差异度量设计了一种
动态权重转移模型
,使每个源域能够根据其与目标域的相似性
为训练过程做出贡献。因此,通过减轻与目标域显著差异的源域的不利影响,可以进一步提高分类器对目标域的预测性能。 - 通过一系列实验,在两个公开可用的 BCI数据集上
评估
了所提出的方法。结果表明,所提方法的每一项创新都有助于提高解码性能,与基线相比,解码性能更好。
2. 网络结构
SSMT
由两个主要阶段
组成。预训练阶段
预训练所有可用于在特征提取任务和原始监督分类任务中训练的标记数据,以获得仅包含特征提取器和分类器的全局模型
。然后,利用预训练模型对目标域的未标记数据进行伪标记;再训练阶段
包括三个主要步骤。首先,域适配器
旨在减少每个源域和目标域之间的差异。然后,使用伪标签信息
并不断更新以优化模型。最后,最终决策由MLP分类器的转移权重融合
产生。
3. 算法
符号说明:
{
X
s
k
,
y
s
k
}
k
=
1
n
\{X_s^k, y_s^k\}_{k=1}^n
{Xsk,ysk}k=1n 表示存在n个源域;
X
t
X_t
Xt 表示目标域,包含两个部分,分别是
X
l
X_l
Xl 和
X
u
X_u
Xu;
X
l
X_l
Xl 和
y
l
y_l
yl 表示目标域中已知(标记)的样本;
X
u
X_u
Xu 表示目标域中未标记的样本,即也不知道其对应的类别。
SSMT算法步骤:
输入: { X s k , y s k } k = 1 n , X l , y l , X u \{X_s^k, y_s^k\}_{k=1}^n, X_l, y_l, X_u {Xsk,ysk}k=1n,Xl,yl,Xu
-
初始化权重参数 θ f , θ c \theta_f, \theta_c θf,θc
-
通过输入 { X s k , y s k } k = 1 n , X l , y l \{X_s^k, y_s^k\}_{k=1}^n, X_l, y_l {Xsk,ysk}k=1n,Xl,yl 直接训练预训练模型中的特征提取器 G f G_f Gf 和MLP分类器 G c G_c Gc , 并根据下面等式更新参数 θ f , θ c \theta_f, \theta_c θf,θc L c = − ∑ k = 1 n y s k ⋅ log ( G c ( G f ( X s k ; θ f ) ; θ c ) ) − y l ⋅ log ( G c ( G f ( X l ; θ f ) ; θ c ) ) , \begin{aligned} L_c= & {} -\sum _{k=1}^n \textbf{y}^k_s\cdot \log (G_c(G_f(\textbf{X}^k_s;\theta _f);\theta _c))\nonumber \\{} & {} -\textbf{y}_l\cdot \log (G_c(G_f(\textbf{X}_l;\theta _f);\theta _c)), \end{aligned} Lc=−k=1∑nysk⋅log(Gc(Gf(Xsk;θf);θc))−yl⋅log(Gc(Gf(Xl;θf);θc)),
-
生成测试集的伪标签: y ^ u = G c ( G f ( X u ; θ f ) ; θ c ) , \begin{aligned} \hat{\textbf{y}}_u=G_c(G_f(\textbf{X}_u;\theta _f);\theta _c), \end{aligned} y^u=Gc(Gf(Xu;θf);θc), 预训练阶段结束
-
将 X l X_l Xl 和 X u X_u Xu 的数据合并为目标域 X t X_t Xt,并连接所有域的数据(将 X s k X_s^k Xsk 和 X t X_t Xt 的数据进行连接)
-
重复
-
将连接的数据输入 G f G_f Gf 来得到所有域的特征:
F = [ G f ( X s 1 ; θ f ) , . . . , G f ( X s n ; θ f ) , G f ( X t ; θ f ) ] T F=[G_f(X_s^1;\theta_f),...,G_f(X_s^n;\theta_f),G_f(X_t;\theta_f)]^T F=[Gf(Xs1;θf),...,Gf(Xsn;θf),Gf(Xt;θf)]T -
根据以下公式获取每个源域的差异损失和转移权重: L d k = M M D ( D s k , D t ) + C M M D ( D s k , D t ) . \begin{aligned} L_d^k=MMD(\mathcal {D}^k_s, \mathcal {D}_t)+CMMD(\mathcal {D}^k_s, \mathcal {D}_t). \end{aligned} Ldk=MMD(Dsk,Dt)+CMMD(Dsk,Dt). C M M D ( D s k , D t ) = ∑ c = 1 C ∥ 1 m c ∑ x s k , i ∣ y s k , i = c ϕ ( G f ( x s k , i ; θ f ) ) − 1 n ^ c + n c ( ∑ x l i ∣ y l i = c ϕ ( G f ( x l i ; θ f ) ) + ∑ x u i ∣ y ^ u i = c ϕ ( G f ( x u i ; θ f ) ) ∥ , \begin{aligned} CMMD(\mathcal {D}^k_s, \mathcal {D}_t)= & {} \sum _{c=1}^C\Vert \frac{1}{m_c} \sum _{\textbf{x}_s^{k,i} |y^{k,i}_s=c} \phi (G_f(\textbf{x}_s^{k,i};\theta _f))\nonumber \\{} & {} -\frac{1}{\hat{n}_c+n_c}(\sum _{\textbf{x}_l^i |{y}_l^i=c} \phi (G_f(\textbf{x}_l^i;\theta _f))\nonumber \\{} & {} +\sum _{\textbf{x}_u^i |\hat{y}_u^i=c} \phi (G_f(\textbf{x}_u^i;\theta _f))\Vert , \end{aligned} CMMD(Dsk,Dt)=c=1∑C∥mc1xsk,i∣ysk,i=c∑ϕ(Gf(xsk,i;θf))−n^c+nc1(xli∣yli=c∑ϕ(Gf(xli;θf))+xui∣y^ui=c∑ϕ(Gf(xui;θf))∥, M M D ( D s k , D t ) = ∥ 1 n s k ∑ i = 1 n s k ϕ ( G f ( x s k , i ; θ f ) ) − 1 n t ∑ i = 1 n t ϕ ( G f ( x t i ; θ f ) ) ∥ , \begin{aligned} MMD\left( \mathcal {D}^k_s, \mathcal {D}_t\right)= & {} \Bigg \Vert \frac{1}{n^k_s} \sum _{i=1}^{n^k_s} \phi (G_f(\textbf{x}_s^{k,i};\theta _f))\nonumber \\{} & {} - \frac{1}{n_t} \sum _{i=1}^{n_t} \phi (G_f(\textbf{x}_t^i;\theta _f))\Bigg \Vert , \end{aligned} MMD(Dsk,Dt)= nsk1i=1∑nskϕ(Gf(xsk,i;θf))−nt1i=1∑ntϕ(Gf(xti;θf)) ,
-
基于下面式子对每个域的特征进行动态加权,然后将 F ∗ F^* F∗ 作为 G c G_c Gc 的输入:
w = [ W d 1 , … , W d n ] ⊤ = [ K − L d 1 2 ∑ k = 1 n K − L d k 2 , … , K − L d n 2 ∑ k = 1 n K − L d k 2 ] ⊤ , \begin{aligned} \textbf{w}= & {} [W^1_d, \ldots , W^n_d]^{\top }\nonumber \\= & {} \left[ \frac{K^{- {L_d^1}^2}}{\sum _{k=1}^n K^{- {L_d^k}^2}}, \ldots , \frac{K^{- {L_d^n}^2}}{\sum _{k=1}^n K^{- {L_d^k}^2}}\right] ^{\top }, \end{aligned} w==[Wd1,…,Wdn]⊤[∑k=1nK−Ldk2K−Ld12,…,∑k=1nK−Ldk2K−Ldn2]⊤, F ∗ = [ F s 1 ∗ , … , F s n ∗ , F t ] ⊤ = [ W d 1 F s 1 , … , W d n F s n , F t ] ⊤ , \begin{aligned} \textbf{F}^*=[{\textbf{F}^1_s}^*,\ldots ,{\textbf{F}^n_s}^*,\textbf{F}_t]^\top =[W^1_d\textbf{F}^1_s,\ldots ,W^n_d\textbf{F}^n_s,\textbf{F}_t]^\top , \end{aligned} F∗=[Fs1∗,…,Fsn∗,Ft]⊤=[Wd1Fs1,…,WdnFsn,Ft]⊤,
-
根据下面等式,通过最小化 L L L 更新参数 θ f , θ c \theta_f, \theta_c θf,θc
L = L c + λ L d , \begin{aligned} L=L_c+\lambda L_d, \end{aligned} L=Lc+λLd,
-
通过预测 X u X_u Xu 更新 y ^ u \hat{y}_u y^u
-
直到收敛
-
返回 y ^ u \hat{y}_u y^u
4. 补充
4.1 为什么设置一种新的适配器?
最近的研究表明,随着域间差异的增加,分类器对特征的可转移性显着降低,这表明直接转移提取的特征是一种不安全的策略。因此,在不考虑个体信号差异的情况下,使用所有可用数据进行预训练的模型可能会导致目标受试者分类的性能下降。为了防止传统两级流水线引起的分布过拟合问题,设计了一种域适配器来减轻单个信号差异的负面影响。
尽管经典MMD已被广泛用作分布差异度量,但现有研究表明,在处理类权重偏差(即类不平衡数据)时,MMD并不总是可靠的。调查发现类条件分布之间的差异 P s ( x s k , i ∣ y s k , i = c ) P_s\left( \textbf{x}_s^{k,i} \mid y^{k,i}_s=c\right) Ps(xsk,i∣ysk,i=c) 和 P t ( x l i ∣ y l i = c ) P_t\left( \textbf{x}_l^i \mid y_l^i=c\right) Pt(xli∣yli=c)可以提供更合适的域差异量表,并导致卓越的域适应性能。什么时候 P s ( x s k , i ∣ y s k , i = c ) = P t ( x l i ∣ y l i = c ) P_s\left( \textbf{x}_s^{k,i} \mid y^{k,i}_s=c\right) =P_t\left( \textbf{x}_l^i \mid y_l^i=c\right) Ps(xsk,i∣ysk,i=c)=Pt(xli∣yli=c),在源域中学习的分类器可以更安全地应用于目标域。基于这一概念,引入了条件最大均值差异(CMMD)度量,以对齐所有源域和目标域特征的类条件分布.
4.2 动态加权融合机制究竟是干啥的?
从所有数据中获得的特征 G f G_f Gf 可直接用于输入 G c G_c Gc 用于训练,但分类器的这种无歧视训练输入可能会导致不良结果。这一结果可归因于负转移,当通过蛮力利用与目标关系不相关的来源时,就会发生负转移,从而导致对目标域的分类器预测有偏差。
为了减轻负迁移的影响,分类器被赋予了动态加权特征,用于最终决策融合。
5. 实验结果
对比实验结果:
消融实验结果:
PT
:PT是仅包含特征提取器和MLP分类器的基本模型,可以完成简单的特征提取和分类任务。DA
:域适配器 (DA) 基于 MMD 和 CMMD。特别是,DA 仅使用通过预训练生成的伪标签来计算域间差异。SS
:SS 是一个迭代标签更新器。它的作用是在重新训练过程中周期性地生成和更新伪标签。WF
:WF是指动态加权模型,它对来自多源域的加权特征进行动态加权和整合。
6. 总结
到此,使用 SSMT 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。
如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。
欢迎来稿
欢迎投稿合作,投稿请遵循科学严谨、内容清晰明了的原则!!!! 有意者可以后台私信!!