pytorch如何使用Focal Loss

Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 FL对于简单样本(p比较大)回应较小的loss。 如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss, 但是对于FL就有相对较小的loss回应。这样就是对简单样本的一种decay。其中alpha 是对每个类别在训练数据中的频率有关, 但是下面的实现我们是基于alpha=1进行实验的。

在这里插入图片描述

PyTorch中使用Focal Loss,你可以按照以下步骤进行操作

方法一:

1、创建FocalLoss.py文件,添加一下代码

在这里插入图片描述

代码修改处:

  • classnum 处改为你分类的数量
  • P = F.softmax(inputs) 改为 P = F.softmax(inputs,dim=1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num=5, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)


        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)


        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

2、在你的训练函数里加入模块

from FocalLoss import FocalLoss

loss = FocalLoss()

方法二:

首先,确保你已经导入了torchtorch.nn模块,其中torch.nn提供了各种常见的损失函数。

import torch
import torch.nn as nn

然后,定义一个自定义的Focal Loss类,继承自torch.nn.Module。在类的构造函数中,可以指定Focal Loss所需的参数,例如γ(调节因子)和权重。

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(weight=self.weight)(inputs, targets)  # 使用交叉熵损失函数计算基础损失
        pt = torch.exp(-ce_loss)  # 计算预测的概率
        focal_loss = (1 - pt) ** self.gamma * ce_loss  # 根据Focal Loss公式计算Focal Loss
        return focal_loss

接下来,在模型训练时,使用自定义的Focal Loss替代交叉熵损失函数即可。

# 定义模型
model = YourModel()

# 定义损失函数(使用自定义的Focal Loss)
criterion = FocalLoss(gamma=2, weight=None)

# 初始化优化器等

# 开始训练循环
for epoch in range(num_epochs):
    # 前向传播、计算损失
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # 反向传播、更新模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 其他操作(如打印训练日志等)

通过以上步骤,就可以在PyTorch中将损失函数由交叉熵损失函数换为Focal Loss。请注意,上述代码示例中的一些细节(例如模型、输入、优化器等)可能需要根据你的实际情况进行修改和补充。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/100426.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

nginx反向代理 负载均衡

目录 1.反向代理介绍: 2.七层代理和四层代理: 2.1 七层代理: 2.2 四层代理: 3.反向代理web服务器: 3.1 代理服务器配置: 3.2 服务器配置 : 3.3 客户端访问: 3.4 代理不同端口&am…

使用nps实现内网穿透

1、介绍 ​ 当我们想把内网的一些资源暴露在公网上时,可以使用内网穿透功能。比如公司的内网服务器,部署了平时需要开发的项目,但是回到家中无法访问,就可以使用内网穿透,将公司内网的接口映射到一台公网的服务器上&a…

C++之std::search应用实例(一百八十九)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

typescript的~~和双感叹号符号使用

(标题不给用“!”) "~~"符号使用 1.对于number类型的值,~~是取整作用 const num: number 3.14; const roundedNum: number ~~num; // 3 2.对于boolean类型的值,~ ~ true还是true,false还是false “!!”…

6.(Python数模)游泳接力指派问题

Python解决游泳接力指派问题 该问题也属于一个线性规划问题 源代码 import pulp # 导入 pulp 库 import numpy as np# 主程序 def main():# 问题建模:"""决策变量:x(i,j) 0, 第 i 个人不游第 j 种姿势x(i,j) 1, 第 i 个人游第 …

javaee spring 静态代理

静态代理 package com.test.staticProxy;public interface IUsersService {public void insert(); }package com.test.staticProxy;//目标类 public class UsersService implements IUsersService {Overridepublic void insert() {System.out.println("添加用户");…

A. Increasing and Decreasing

题目:样例: 输入 3 1 4 3 1 3 3 100 200 4输出 1 3 4 -1 100 150 180 200 思路: 思维题,这里我们看一下规律,我们已知a(1),a(n) ,又因为 数列b 应该是递减的,而观察规律可知 &#x…

无涯教程-Android - Spinner函数

Spinner允许您从下拉菜单中选择一个项目 例如。使用Gmail应用程序时,将显示如下所示的下拉菜单,您需要从下拉菜单中选择一个项目。 Spinner Example 示例 本示例演示计算机的类别,您需要从类别中选择一个类别。 以下是修改后的主要Activity文件src/com.example.spinner/Andr…

opencv-人脸识别

对https://blog.csdn.net/weixin_46291251/article/details/117996591这哥们代码的一些修改 import cv2 import numpy as np import os import shutil import threading import tkinter as tk from PIL import Image, ImageTkchoice 0# 首先读取config文件,第一行…

postman-使用Postman的模拟服务来模拟(mock)后端数据,完成前端模拟API调用

最近项目上比较忙,任务多时间紧,导致后端开发任务繁多,无法及时开发完毕,但是前端同学已经把对应功能开发完成,需要进行前后端联调来验证API及一些交互问题;这不能因为后端的进度来影响前端的工作完成情况&…

python爬虫-Selenium

一、Selenium简介 Selenium是一个用于Web应用程序测试的工具,Selenium 测试直接运行在浏览器中,就像真正的用户在操作一样。模拟浏览器功能,自动执行网页中的js代码,实现动态加载。 二、环境配置 1、查看本机电脑谷歌浏览器的版…

【设备树笔记整理7】实践操作

1 使用设备树给DM9000网卡_触摸屏指定中断 1.1 修改方法 根据设备节点的compatible属性,在驱动程序中构造/注册 platform_driver,在 platform_driver 的 probe 函数中获得中断资源。 1.2 实验方法 以下是修改好的代码:第6课第1节_网卡_触摸…

PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化实践技术应用

我国高分辨率对地观测系统重大专项已全面启动,高空间、高光谱、高时间分辨率和宽地面覆盖于一体的全球天空地一体化立体对地观测网逐步形成,将成为保障国家安全的基础性和战略性资源。未来10年全球每天获取的观测数据将超过10PB,遥感大数据时…

WEBGL(3):鼠标动态绘制点

1 实现思路 绘制单个点鼠标事件监听点击事件将点推送到数组中绘制数组中所有点 2 实现代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge&…

应届生面试指南:如何在缺乏经验的情况下脱颖而出

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

运用Python解析HTML页面获取资料

目录 安装Beautiful Soup库&#xff1a; 解析HTML页面&#xff1a; 怎么获取图片、视频、音频资源 1. 图片资源&#xff1a; 2. 视频资源&#xff1a; 3. 音频资源&#xff1a; 可能遇到的问题 1. 编码问题&#xff1a; 2. 动态内容&#xff1a; 3. 反爬虫机制&#…

关于在香橙派安装mysql时遇到的坑

前言 基础环境&#xff1a;硬件&#xff1a;香橙派5操作系统&#xff1a;openkylincpu架构&#xff1a;arm 过程&#xff1a;最近有个任务&#xff0c;要在新的环境中验证一些服务是否可用。目的时向全国产化靠拢。 需要在香橙派上安装openkylin&#xff0c;一开始尝试的是香橙…

用于设计和分析具有恒定近心点半径的低推力螺旋轨迹研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

Oracle报错 PLS-00103: 出现符号 ““在需要下列之一时

在IDEA中执行以下SQL时&#xff0c;报了这个异常 检查了语法&#xff0c;你会发现语法没有任何问题&#xff0c;标点也没有任何问题。 罪魁祸首在这&#xff1a; 换行符为CRLF&#xff0c;我们需要改成LF 即可执行成功

B080-RabbitMQ

目录 RabbitMQ认识概念使用场景优点AMQP协议JMS RabbitMQ安装安装elang安装RabbitMQ安装管理插件登录RabbitMQ消息队列的工作流程 RabbitMQ常用模型HelloWorld-基本消息模型生产者发送消息导包获取链接工具类消息的生产者 消费者消费消息模拟消费者手动签收消息 Work QueuesSen…