系列文章目录
文章目录
- 系列文章目录
- 前言
- 1. 什么是 Focal Loss
- 2. 逐过程解析 Focal Loss
- 3. Focal Loss 的 PyTorch 实现
- 总结
前言
类别不平衡是一个在目标检测领域被广泛讨论的问题,因为目标数量的多少在数据集中能很直观的体现。同时,在分割中这也是一个值得关注的问题,毕竟分割的本质是对像素进行分类。而处理类别不平衡一个非差常用的方法就是通过 Focal Loss
来引导模型更关注困难的类。
1. 什么是 Focal Loss
Focal Loss
是在标准交叉熵损失基础上修改得到的。相比 CrossEntropy Loss
它增加了容易和难分样本的权重,对于难分的样本增加权重,增加 loss 的贡献度;减少易分类样本的权重,使得模型在训练时更专注于难分类的样本。
Focal Loss
从另外的视角来解决样本不平衡问题,那就是根据置信度动态调整 CE Loss
,当预测正确的置信度增加时,loss 的权重系数会逐渐衰减至0,这样模型训练的 loss 更关注难例,而大量容易的例子其 loss 贡献很低。
比如假如一张图片上有 10 个正样本,每个正样本的损失值是 3,那么这些正样本的总损失是 10x3=30。而假如该图片上有 10000 个简单易分负样本,尽管每个负样本的损失值很小,假设是 0.1,那么这些简单易分负样本的总损失是 10000x0.1=1000,那么损失值要远远高于正样本的损失值。所以如果在训练的过程中使用全部的正负样本,那么它的训练效果会很差。
2. 逐过程解析 Focal Loss
- 公式一览:
- α \alpha α 侧重的是正负样本之间的不平衡,一般设置为 0.25
- γ \gamma γ 难易样本上的权重调节,一般设置为 2
- 简单的加权 CE Loss 可能只能实现正负样本之间不平衡的调节,所以对于大多数不平衡任务来说 Focal Loss 应该还是能起到更好的效果
- 首先看一下二分类交叉熵损失函数
- 二分类交叉熵损失函数: y y y 是样本的标签值,而 p p p 是模型预测某一个样本为正样本的概率,对于真实标签为正样本的样本,它的概率 p p p 越大说明模型预测的越准确,对于真实标签为负样本的样本,它的概率 p p p 越小说明模型预测的越准确
- 如果我们定义
p
t
p_t
pt 为如下的形式
- 公式 (1) 可以修改为如下形式 (2)
- 现在我们定义一个参数
α
\alpha
α 和
1
−
α
1 - \alpha
1−α 来平衡正负样本的权重,定义
α
t
\alpha_t
αt 如下,需要注意的是,
α
\alpha
α 是个超参数用来平衡正负样本的权重,并不是实际的正负样本的比例,
- 公式 (2) 可以修改为如下形式 (3)
- 又因为样本有难易之分,所以我们必须要能区分出困难样本和简单样本,所以我们设置一个系数 ( 1 − p t ) γ ( 1-p_t )^{\gamma} (1−pt)γ
- 它可以降低简单样本的损失贡献,而使得训练时更重视一些困难样本,
Focal Loss
可以定义为:
- 看一些权重计算的例子:
- 如果预测正样本概率是 0.95(即对于一个真实标签为正样本的样本,使用模型预测它也是正样本的概率是 0.95),这显然是一个简单的样本
- 如果预测正样本概率是 0.5 ,这显然是一个稍微困难一定的样本
- 如果预测负样本的概率为 0.9(即对于一个真实标签为负样本的样本,使用模型预测它是正样本的概率是 0.9),这显然是一个困难的样本,则该样本的难易权重是
- 如果预测负样本的概率为 0.1(即对于一个真实标签为负样本的样本,使用模型预测它是正样本的概率是 0.1),这显然是一个简单的样本,
- 为此,我们得到最终的
Focal Loss
3. Focal Loss 的 PyTorch 实现
首先感谢上海 AI Lab 的杰出工作,SAM-Med2D
我这里的实现来自仓库:SAM-Med2D
如果能对大家有帮助,希望后期大家不要忘记引用这个工作:
class FocalLoss(nn.Module):
def __init__(self, gamma=2.0, alpha=0.25):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, pred, mask):
"""
pred: [B, 1, H, W]
mask: [B, 1, H, W]
"""
assert pred.shape == mask.shape, "pred and mask should have the same shape."
p = torch.sigmoid(pred)
num_pos = torch.sum(mask)
num_neg = mask.numel() - num_pos
w_pos = (1 - p) ** self.gamma
w_neg = p ** self.gamma
loss_pos = -self.alpha * mask * w_pos * torch.log(p + 1e-12)
loss_neg = -(1 - self.alpha) * (1 - mask) * w_neg * torch.log(1 - p + 1e-12)
loss = (torch.sum(loss_pos) + torch.sum(loss_neg)) / (num_pos + num_neg + 1e-12)
return loss
总结
参考链接:
深入剖析Focal loss损失函数