论文笔记
资料
1.代码地址
https://github.com/locuslab/tta_conjugate
2.论文地址
https://arxiv.org/abs/2207.09640
3.数据集地址
论文摘要的翻译
测试时间适应(TTA)指的是使神经网络适应分布变化,在测试时间仅访问来自新领域的未标记测试样本。以前的TTA方法对无监督目标进行优化,如Tent中模型预测的熵,但尚不清楚究竟是什么造成了良好的TTA损失。本文首先提出一个令人惊讶的现象:如果我们试图在一个广泛类别函数上元学习可能的最佳TTA损失,那么我们恢复的函数与Tent使用的Softmax-熵非常相似(温度缩放的版本)。然而,这只有在我们正在适应的分类器是通过交叉熵损失来训练的情况下才成立;如果分类器是通过平方误差来训练的,那么就会出现不同的“最佳”TTA损失。为了解释这一现象,我们通过训练损失的凸共轭来分析测试时间适应。我们证明了在自然条件下,这个(无监督的)共轭函数可以被看作是对原始监督损失的良好局部逼近,并且确实,它恢复了元学习发现的“最佳”损失。这导致了一个通用方法,该方法可用于为一般类的任何给定的有监督训练损失函数找出良好的TTA损失。从经验上看,我们的方法在广泛的领域适应基准上始终主导着其他TTA替代方案。当我们的方法被应用于用新的损失函数训练的分类器时,尤其令人感兴趣,例如最近提出的PolyLoss函数,其中它与基于熵的损失有很大的不同(并且性能更优)。此外,我们还证明了我们的基于共轭的方法也可以解释为一种使用非常特定的软标签的自我训练,我们称之为共轭伪标签。总体而言,我们的方法为更好地理解和改进测试时间适应提供了一个广泛的框架。
1 引言
现代深度网络在接近训练分布的测试分布输入上表现得非常好。然而,在来自不同分布的测试输入上,这种性能会显著降低。虽然在提高模型的稳健性方面有大量的工作,但大多数健壮的训练方法都高度专业化于它们所迎合的环境。例如,它们假设预先指定的扰动、子总体和虚假相关性,或者从目标分布访问未标记的数据,并且大多数方法除了它们被训练的之外,几乎没有提供对一般分布偏移的改进。
在实践中,准确地描述模型可能遇到的所有可能的分布变化并随后进行相应的训练通常是繁琐的(甚至不可能的)。相反,已经在某些源数据上训练的模型必须能够在测试时适应来自不同领域的新输入。近年来,这种测试时间适应(TTA)的设置引起了人们的兴趣。TTA通常是通过在涉及来自目标分布的新测试样本的非监督目标上的几个优化步骤更新源模型参数来实现的。选择这种不受监督的目标,我们称之为TTA损失,决定了适应程序的成功。对测试样本使用自监督目标,使用模型预测的熵,几个后续行动提出了变体或备选方案。然而,如何选择或指导选择这一TTA损失仍不清楚,到目前为止,这些损失的选择在很大程度上仍然是启发式的。
在这项工作中,我们首先提出一组有趣的实验,在这些实验中,我们试图学习给定来源分类器和分布转移的“最佳”TTA损失。我们通过另一个神经网络对TTA损失进行参数化,该神经网络的参数是通过元学习学习的,其中我们通过自适应过程进行区分,以找到在分布平移上实现最佳适应的TTA损失。令人惊讶的是,我们最终得到了一个TTA损失,它看起来非常类似于已经提出的Softmax-entropy损失(按温度缩放的版本)。为什么我们恢复了常用的Softmax-entropy损失,尽管该过程能够学习非常一般的损失类别,并且元学习过程可能专门用于源分类器和感兴趣的分布转移?此外,我们发现,这种模式只有当用于训练源分类器的损失是交叉熵损失时才适用;当使用不同的损失,如平方损失时,元学习过程恢复TTA损失,该TTA损失本身看起来更像负平方误差,并且与Softmax-熵损失(第3节)非常不同。
为了解释这一现象,我们建议通过凸共轭函数的透镜来考虑TTA。具体地说,在给定假设函数
h
(
x
)
h(x)
h(x)和标签
y
y
y的情况下,对于某些函数
f
f
f,可以以
L
(
h
(
x
)
,
y
)
=
f
(
h
(
x
)
)
−
y
T
h
(
x
)
\mathcal L(h(x),y)=f(h(x))−y^Th(x)
L(h(x),y)=f(h(x))−yTh(x)的形式来写出几个共同的损失(交叉熵和它们之间的平方损失,但不限于这些)。在这些情况下,我们证明了对于这种分类器来说,“自然的”TTa损失恰好是在h的梯度处评估的凸共轭的(否定),
L
T
T
A
(
x
)
=
−
f
∗
(
∇
f
(
h
(
X
)
)
\mathcal L_{TTA}(x)=−f∗(∇f(h(X))
LTTA(x)=−f∗(∇f(h(X)),其中f∗是f的凸共轭。这个框架不仅恢复了我们的元学习实验的结果,而且也证明了为什么以前文献中的一些特定的TTA损失选择很好地工作(例如,这个框架恢复了Content为交叉熵训练的分类器选择的Softmax-熵)。此外,它还提供了一个广泛的框架,当使用各种不同的损失函数(例如最近提出的PolyLoss)训练源模型时,TTA损失应该是什么,这在机器学习中变得越来越常见。此外,我们还证明了我们提出的共轭适应损失实际上是一种带有伪标签的自我训练[42],这是机器学习中的一种经典方法。在文献中已经提出了各种伪标记的公式,并且我们的共轭分析提供了由ˆy(X)=∇f(h(X))给出的软伪标记的“正确”选择的一般公式。因此,我们将这些称为共轭伪标签(Conjugate PL),并相信我们的工作为理解与未标记数据的适应提供了一个广泛的框架。
最后,我们在几个数据集和训练损失(如交叉熵和平方损失)以及最近提出的PolyLoss25上经验地验证了我们提出的共轭适应损失的有效性。在所有模型、数据集和训练损失中,我们发现我们提出的共轭伪标记始终优于先前的TTA损失,并改善了当前技术状态下的TTA性能。
2背景和准备工作
- TTA。
我们感兴趣的是将输入 x ∈ R d x\in \mathbb {R^ d} x∈Rd映射到标签 y ∈ Y y\in \mathcal Y y∈Y。我们学习了一个由 θ θ θ参数化的模型 h θ : R d → R ∣ y ∣ h_θ:\mathbb {R^ d} →\mathbb {R^ {|y|}} hθ:Rd→R∣y∣,该模型将输入 x x x映射到预测 h θ ( x ) h_θ(x) hθ(x)。我们假设访问经过训练的源模型,并在进行最终预测之前,在测试时对测试输入进行调整。这是标准测试时间自适应(TTA)设置。在TTA过程中,我们更新无监督目标 L ( x , h θ ) \mathcal L(x,h_θ) L(x,hθ)的模型参数。例如,在TENT中,这种损失是模型的softmax归一化预测的熵。在适应的每个时间步骤,我们观察一批测试输入,并采取梯度步骤优化该测试批次的TTA损失。按照标准,我们测量自适应过程中所有步骤(测试批次输入的数量)中模型的平均在线性能。 - Meta learning the loss function
为了探索不同TTA损失的存在,我们采用了元学习程序,试图学习TTA损失。我们使用与先前关于元学习损失函数的工作类似的程序,并通过神经网络 m φ : R ∣ y ∣ → R m_φ:\mathbb R^{|y|}→\mathbb R mφ:R∣y∣→R来参数化损失函数,该神经网络接受模型预测/逻辑并输出损失值。我们想学习参数 φ φ φ,这样当我们通过损失函数m φ φ φ更新 θ θ θ时,我们的最终性能是最优的。为了做到这一点,设x是要适应的未标记测试样本, y y y是相应的标签。我们交替更新 θ θ θ和 φ φ φ如下。 θ t + 1 ← θ t − α ∂ m ϕ t ( h θ t ( x ) ) ∂ θ t , ϕ t + 1 ← ϕ t − β ∂ L ( h θ t + 1 ( x ′ ) , y ′ ) ∂ ϕ t , ( 1 ) \theta^{t+1}\leftarrow\theta^t-\alpha\frac{\partial m_{\phi^t}(h_{\theta^t}(x))}{\partial\theta^t}, \phi^{t+1}\leftarrow\phi^t-\beta\frac{\partial\mathcal{L}(h_{\theta^{t+1}}(x^{\prime}),y^{\prime})}{\partial\phi^t},\quad(1) θt+1←θt−α∂θt∂mϕt(hθt(x)),ϕt+1←ϕt−β∂ϕt∂L(hθt+1(x′),y′),(1)其中 L \mathcal L L是一些有监督的替代损失函数,例如交叉熵。有关元学习设置的更多详细信息,请参阅附录A3。注意,上面的元学习过程假设访问测试输入的标签 y y y。在本文中,我们不建议将TTA损失作为元学习的一种方法。相反,我们使用元学习来探索“最佳”TTA损失是什么样子的。我们将在下一节中讨论我们的探索结果。
3 论文方法的概述
TENT中使用的目标是模型预测的softmax熵,这本质上使分类器对其当前预测更有信心。这也可以通过各种其他损失公式来实现,如之前方法中提到的公式。损失函数有这么多可能的选择,我们应该使用什么来进行TTA?在本节中,我们试图从经验上回答这个问题,并提出一些有趣的观察结果。
- Experiment 1
我们通过元学习学习由神经网络参数化的TTA损失,如第2节所述。我们的源分类器是在CIFAR-10上训练的ResNet-26,我们适应了CIFAR-10-C中的分布变化。我们使用CIFAR-10-C中的4个标记的验证噪声来学习元损失网络参数,并通过元TTA损失来表示最终的学习损失函数。然后,我们通过优化元TTA损失,使源分类器适应15个损坏的测试集。 - Observations
首先,我们发现使用元TTA损失的TTA比TENT表现更好(12.35%对13.14%),这表明TTA损失比以前基于softmax熵的损失更好。
然而,在研究这种元TTA损失时,我们发现了一个令人惊讶的观察结果。图1a(蓝色曲线)显示了学习到的元损失与模型预测的关系,因为我们改变了单个类别的预测,其余的固定不变。定性地说,所学习的元损失在一维上看起来与softmax熵非常相似。事实上,我们可以将其与缩放的softmax熵函数(红色虚线)密切拟合: α ⋅ H ( s o f t m a x ( H θ ( x ) / T )) \alpha·\mathcal H(softmax(H_θ(x)/T)) α⋅H(softmax(Hθ(x)/T)),其中 α α α是幅度参数, T T T是温度缩放器。我们想测试元损失是否基本上是在学习softmax熵函数。因此,我们使用拟合的softmax熵函数(红色虚线曲线)进行测试时间自适应,并实现12.32%的误差,基本上恢复了元TTA的性能。
尽管能够表示许多不同的损失函数,并可能专门用于CIFAR10-C设置,但元损失过程还是返回了标准熵目标。我们总是恢复看起来像softmax熵的损失吗? - Experiment 2
当我们回到熵目标时,为了隔离,我们改变了一些事情。我们尝试了源分类器的不同架构,在元学习过程公式(1)中的不同损失 L \mathcal L L,以及源分类器不同的训练损失。 - Results
我们观察到,除了我们改变源分类器的训练损失时,我们在所有情况下都一致地恢复了温度标度的softmax熵函数(附录A.10)。在使用平方损失函数时,出现了截然不同的元TTA损失。图1b(蓝色曲线)显示了该网络的学习元损失(13.48%的误差)。这里,元TTA损失优于熵(14.57%),但这不仅仅是由于比例因子。现在的损失看起来像是负平方误差(红色曲线)。与之前一样,我们试图将二次损失直接拟合到图1b中的元损失,这一次我们甚至略优于元TTA损失。
总之,我们使用元学习过程来搜索“最佳”TTA损失,其中损失本身由神经网络参数化,该神经网络可能表示任意复杂的损失函数。然而,我们最终发现损失函数显示出显著的结构:在元学习的不同架构和不同变体中,对于用交叉熵训练的分类器,元TTA损失是温度标度的softmax熵,而对于用平方损失训练的分类器来说,元TTA损失是负平方损失。从实用和概念的角度来看,这都很有趣,因为“最佳”TTA损失取决于用于以干净的方式训练源分类器的损失。我们试图在下一节中理解和解释这一现象。
4 Conjugate Pseudo Labels
上一节的结果提出了一个明显的问题:为什么TENT中使用的softmax熵似乎是通过交叉熵训练的分类器的“最佳”测试时自适应损失(至少,在元学习一致地恢复本质上模仿softmax熵的东西的意义上是最佳的,尽管元损失是由神经网络参数化的,因此可以学习特定于模型和特定偏移的更复杂的函数)?或者,为什么当通过平方损失训练分类器时,二次TTA损失似乎表现最好?
在本节中,我们通过构造凸共轭函数来解释这一现象。正如我们将看到的,我们的方法将softmax熵和二次损失恢复为分别通过交叉熵和平方损失训练的分类器的“自然”目标。此外,对于通过其他损失函数训练的分类器,正如在深度学习中越来越常见的那样,我们的方法自然会提出相应的测试时自适应损失,我们在下一节中展示了这一点,以相对优于替代方案。因此,我们认为,我们的框架总体上提供了一个令人信服的配方,用于为一大类可能的损失指定TTA的“正确”方法。
4.1 Losses and the convex conjugate
我们首先正式考虑假设输出
h
θ
(
x
)
h_θ(x)
hθ(x)(例如,分类器的logit输出或回归器的直接预测)和目标y之间的损失函数,其形式如下
L
(
h
θ
(
x
)
,
y
)
=
f
(
h
θ
(
x
)
)
−
y
T
h
θ
(
x
)
(2)
\mathcal{L}(h_\theta(x),y)=f(h_\theta(x))-y^Th_\theta(x)\text{(2)}
L(hθ(x),y)=f(hθ(x))−yThθ(x)(2)或者一些函数f;当没有混淆的风险时,为了便于记法,我们将使用
h
h
h代替
h
θ
(
x
)
h_θ(x)
hθ(x)。虽然不是每一种损失都可以用这样的形式表示,但这涵盖了各种常见的损失(可能按常数值缩放)。例如,交叉熵损失对应于选择
f
(
h
)
=
log
∑
i
exp
(
h
i
)
f(h)=\log\sum_{i}\exp(h_{i})
f(h)=log∑iexp(hi),其中y表示类标签的一次热编码;类似地,平方损失对应于选择
:
f
(
h
)
=
1
2
∥
h
∥
2
2
.
:f(h)=\frac12\|h\|_{2}^{2}.
:f(h)=21∥h∥22.当训练过参数化分类器时,我们可以粗略地将训练过程视为(近似地)为每个训练示例获得最小过假设h
min
θ
1
t
∑
i
=
1
t
L
(
h
θ
(
x
i
)
,
y
i
)
≈
1
t
∑
i
=
1
t
min
h
L
(
h
,
y
i
)
(3)
\min_\theta\frac1t\sum_{i=1}^t\mathcal{L}(h_\theta(x_i),y_i)\approx\frac1t\sum_{i=1}^t\min_h\mathcal{L}(h,y_i)\text{(3)}
θmint1i=1∑tL(hθ(xi),yi)≈t1i=1∑thminL(h,yi)(3)其中
t
t
t是训练样本的数量。然而,在公式(2)中的损失的情况下,这种形式的
h
h
h上的最小化代表了一个非常具体和众所周知的优化问题:它被称为函数
f
f
f的凸共轭
min
h
L
(
h
,
y
)
=
min
h
{
f
(
h
)
−
y
T
h
}
=
−
f
⋆
(
y
)
(4)
\min_h\mathcal{L}(h,y)=\min_h\{f(h)-y^Th\}=-f^\star(y)\text{(4)}
hminL(h,y)=hmin{f(h)−yTh}=−f⋆(y)(4)f在哪里?表示
f
f
f的凸共轭。
f
f
f是
y
y
y中的凸函数(事实上,无论f是否凸,它都是凸的)。此外,对于f是凸可微的情况,这个最小化问题的最优性条件由
∇
f
(
h
o
p
t
)
=
y
,
\nabla f(h^{\mathrm{opt}})=y,
∇f(hopt)=y,给出,所以我们也有
f
⋆
(
y
)
=
f
⋆
(
∇
f
(
h
o
p
t
)
)
(
5
)
f^\star(y)=f^\star(\nabla f(h^{\mathrm{opt}}))\quad(5)
f⋆(y)=f⋆(∇f(hopt))(5)其中
h
o
p
t
h_{opt}
hopt指的是最优分类器(可与
h
θ
o
p
t
h_{\theta ^{opt}}
hθopt互换使用)。将所有这些放在一起,我们可以声明(诚然,以一种相当非正式的方式),在假设选择
θ
o
p
t
\theta ^{opt}
θopt是为了在过度参数化的设置中近似最小化源数据的经验损失的情况下,我们得出对于t个输入
1
t
∑
i
=
1
t
L
(
h
θ
∗
(
x
i
)
,
y
i
)
≈
1
t
∑
i
=
1
t
−
f
⋆
(
∇
f
(
h
θ
∗
(
x
i
)
)
)
(6)
\frac{1}{t}\sum_{i=1}^{t}\mathcal{L}(h_{\theta^{*}}(x_{i}),y_{i})\approx\frac{1}{t}\sum_{i=1}^{t}-f^{\star}(\nabla f(h_{\theta^{*}}(x_{i})))\text{(6)}
t1i=1∑tL(hθ∗(xi),yi)≈t1i=1∑t−f⋆(∇f(hθ∗(xi)))(6)经验损失可以通过应用于
f
f
f的梯度的(负)共轭来近似,至少在接近使经验损失最小化的最优
θ
o
p
t
θ^{opt}
θopt的区域中是这样。但后面的表达式具有显著的优点,即它不需要任何标签
y
i
y_i
yi来计算损失,因此可以用作假设函数
h
θ
o
p
t
h_{\theta ^{opt}}
hθopt的目标域上的TTA的基础。
- 定理1(共轭自适应损失)
考虑采用公式2中给出的形式的损失函数,用于在过参数化状态下训练模型假设 h θ h_θ hθ。我们定义共轭自适应损失 L c o n j ( h θ ( x ) ) : R ∣ Y ∣ ↦ R \mathcal{L}^{conj}(h_{\theta}(x)):\mathbb{R}^{|\mathcal{Y}|}\mapsto\mathbb{R} Lconj(hθ(x)):R∣Y∣↦R如下。 L c o n j ( h θ ( x ) ) = − f ⋆ ( ∇ f ( h θ ( x ) ) ) = f ( h θ ( x ) ) − ∇ f ( h θ ( x ) ) ⊤ h θ ( x ) . ( 7 ) \mathcal{L}^{conj}(h_\theta(x))=-f^\star(\nabla f(h_\theta(x)))=f(h_\theta(x))-\nabla f(h_\theta(x))^\top h_\theta(x).\quad(7) Lconj(hθ(x))=−f⋆(∇f(hθ(x)))=f(hθ(x))−∇f(hθ(x))⊤hθ(x).(7)
4.2 Recovery of existing test-time adaptation strategies
- 交叉熵
这种形式主义的有趣之处在于,当应用于用交叉熵训练的分类器时,它准确地恢复了TTA的TENT方法:最小化 h θ ( x )的 s o f t m a x h_θ(x)的softmax hθ(x)的softmax熵。事实上,当使用元学习来学习“最佳”测试时间适应损失时,这种损失也得到了弥补。要看到这一点,请注意,对于交叉熵,我们有 f ( h ) = log ∑ i exp ( h i ) , {f}(h)=\log\sum_{i}{\exp}(h_{i}), f(h)=log∑iexp(hi),,给出了最优化条件 y = ∇ f ( h o p t ) = exp ( h o p t ) ∑ i exp ( h i o p t ) y=\nabla f(h^{\mathrm{opt}})=\frac{\exp(h^{\mathrm{opt}})}{\sum_{i}\exp(h_{i}^{\mathrm{opt}})} y=∇f(hopt)=∑iexp(hiopt)exp(hopt)并且共轭函数 f ⋆ ( y ) = { ∑ i y i log y i if ∑ i y i = 1 ∞ otherwise . ( 8 ) f^\star(y)=\left\{\begin{array}{ll}\sum_iy_i\log y_i&\quad\text{if}\sum_iy_i=1\\\infty&\quad\text{otherwise}\end{array}\right..\quad(8) f⋆(y)={∑iyilogyi∞if∑iyi=1otherwise.(8)
此外 L c o n j ( h θ ( x ) ) = − f ⋆ ( ∇ f ( h θ ( x ) ) ) = − ∑ i exp ( h i ) ∑ j exp ( h j ) log exp ( h i ) ∑ j exp ( h j ) ( 9 ) \mathcal{L}^{\mathbf{conj}}(h_\theta(x))=-f^\star(\nabla f(h_\theta(x)))=-\sum_i\frac{\exp(h_i)}{\sum_j\exp(h_j)}\log\frac{\exp(h_i)}{\sum_j\exp(h_j)}\quad(9) Lconj(hθ(x))=−f⋆(∇f(hθ(x)))=−i∑∑jexp(hj)exp(hi)log∑jexp(hj)exp(hi)(9)模型预测的softmax-entropy,这正是TENT使用的TTA损失。 - Squared loss
对于平方损失,我们得到 f ( h ) = 1 2 ∥ h ∥ 2 2 f(h)=\frac12\|h\|_{2}^{2} f(h)=21∥h∥22,从而得到最优性条件 y = h y=h y=h和共轭函数 f ⋆ ( y ) = 1 2 ∥ y ∥ 2 2 . f^{\star}(y)=\frac{1}{2}\|y\|_{2}^{2}. f⋆(y)=21∥y∥22.。因此,这种情况下的自适应损失可以简单地由 L c o n j ( h θ ( x ) ) = − f ⋆ ( ∇ f ~ ( h θ ( x ) ) ) = − 1 2 ∥ h ∥ 2 2 \mathcal{L}^{\mathrm{conj}}(h_\theta(x))=-f^\star(\nabla\tilde{f}(h_\theta(x)))=-\frac12\|h\|_2^2 Lconj(hθ(x))=−f⋆(∇f~(hθ(x)))=−21∥h∥22,这也是我们在第3节讨论的元学习实验中观察到的。
4.3 Conjugate pseudo-labels
我们现在强调,根据近似的性质,对共轭损失有一个额外的简单解释:它也等于应用于“伪标签”的原始损失(公式2)
y
~
θ
C
P
L
(
x
)
=
∇
f
(
h
θ
(
x
)
~
)
\tilde{y}_{\theta}^{\mathrm{CPL}}(x)=\nabla f(\tilde{h_{\theta}(x)})
y~θCPL(x)=∇f(hθ(x)~)其中CPL是指共轭伪标签。
L
c
o
n
j
(
h
θ
(
x
)
)
=
−
f
⋆
(
∇
f
(
h
θ
(
x
)
)
)
=
f
(
h
θ
(
x
)
)
−
∇
f
(
h
θ
(
x
)
)
T
h
θ
(
x
)
=
L
(
h
θ
(
x
)
,
∇
f
(
h
θ
(
x
)
)
)
.
(
10
)
\mathcal{L}^{\mathrm{conj}}(h_\theta(x))=-f^\star(\nabla f(h_\theta(x)))=f(h_\theta(x))-\nabla f(h_\theta(x))^Th_\theta(x)=\mathcal{L}(h_\theta(x),\nabla f(h_\theta(x))).\\(10)
Lconj(hθ(x))=−f⋆(∇f(hθ(x)))=f(hθ(x))−∇f(hθ(x))Thθ(x)=L(hθ(x),∇f(hθ(x))).(10)这个性质被称为Fenchel-Young不等式,即
f
(
x
)
+
f
⋆
(
u
)
≥
x
T
u
f(x)+f^\star(u)\geq x^Tu
f(x)+f⋆(u)≥xTu在
u
=
∇
f
(
x
)
u=\nabla f(x)
u=∇f(x)时保持相等。换言之,我们的共轭自适应损失精确地等价于在特定的软伪标签下的自训练,该软伪标签由?
y
~
C
P
L
=
∇
f
(
h
θ
(
x
)
)
.
\tilde{y}^{\mathrm{CPL}}=\nabla f(h_{\theta}(x)).
y~CPL=∇f(hθ(x)).给出。事实上,在许多情况下,这可能是一种比显式计算共轭函数更方便的计算形式。出于这个原因,我们将我们的方法称为共轭伪标签的方法。
在交叉熵损失的情况下,这种方法正好对应于使用由应用于当前假设的softmax给出的标签的自训练。然而,我们必须强调,虽然我们的共轭公式在交叉熵损失的情况下确实具有这种“简单”的形式,但真正的优势在于它提供了与其他损失一起使用的“正确”伪标签,这可能导致与“常见”softmax运算不同的伪标签。
- Test-time adaptation.
最后,我们注意到,上面的讨论实际上没有涉及任何与OOD数据的测试时间自适应相关的主题,而只是提供了形式公式(2)的一般损失函数的自训练过程的一般特征。然而,在OOD数据上应用TTA是相当简单的:只要学习到的源参数 θ θ θ是移位域上真正最优 θ o p t θopt θopt的合理近似值,使用共轭伪标签的自训练就为在真正的OOD损失上微调网络提供了合理的代理。我们强调,与大多数TTA方法一样,仍有一些设计决策必须到位;这些在第5.1节中有详细说明。在实践中,我们观察到OOD泛化通常(在所有基线上)受益于额外的“温度”缩放,即,对于某些固定温度T,将TTA损失应用于hθ(x)/T,尽管它需要一个保留的验证数据集来调整T。然而,我们应该强调,真正无监督的TTA需要对这些超参数的值进行知情的猜测。通过共轭伪标签进行测试时间自适应的完整过程如算法1所示。
6 结论、局限性和未来方向
在这项工作中,我们提出了一种基于凸共轭公式的通用测试时间自适应损失,这反过来又受到了有趣的元学习实验的启发。元学习恢复了所提出的损失这一事实暗示了损失的某种最优性。在第4节中,我们证明了对于一组广义损失函数,所提出的(无监督的)共轭损失接近于预言机监督损失。然而,这仍然不能完全回答最佳测试时间自适应损失是什么以及为什么。
这项工作中的元学习框架被限制为在每个单独输入的logits上学习函数。它可以扩展到更多涉及的设置,在这些设置中,我们也考虑中间表示上的函数,也考虑一批输入上的学习函数,同时考虑它们的交互
除了自适应损失本身的选择之外,实现良好的测试时间自适应通常涉及几种启发式方法,如仅更新批范数参数。虽然我们的工作是由损失函数驱动的,但通过元学习实验,我们发现温度标度是另一个重要的超参数,它也提高了以前所有基线的性能。在高水平上,必须对测试时间自适应进行适当的正则化,以防止批量更新使模型走得太远:只更新几个批量范数参数是实现这一点的一种方法,也许温度缩放通过降低对未标记输入的网络预测的置信度,提供了类似的有益正则化效果。更具体地理解这些启发式的作用是未来工作的一个有趣方向。了解基于自我训练的方法在什么样的现实世界分布变化下会有所帮助,这仍然是一个悬而未决的问题。
最后,还值得将共轭伪标记扩展并应用于其他环境,如半监督学习