机器学习-自定义Loss函数

1、简介

机器学习框架中使用自定义的Loss函数,

2、应用

(1)sklearn

from sklearn.metrics import max_error
from sklearn.metrics import make_scorer
from sklearn.model_selection import cross_val_score
from sklearn.linear_model import Ridge

def custom_loss(y_true, y_pred, **kwargs):
    # Define your custom loss calculation here
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if y_true.ndim == 1 :
        y_true = y_true.reshape((-1, 1))

    if y_pred.ndim == 1:
        y_pred = y_pred.reshape((-1, 1))   
        
    loss = max(y_true-y_pred)
    return loss


data = pd.DataFrame(np.array([[i for i in range(0,300)],[i for i in range(100,400)],[i for i in range(200,500)]]).T,columns=['a','b','c'])

X_train ,y_train = data[['a','b']],data[['c']]
clf = Ridge()

custom_scorer = make_scorer(custom_loss, greater_is_better=False)

# Create and train a model using the custom loss function
# model = Ridge()
scores = cross_val_score(clf, X_train, y_train, cv=5, scoring=custom_scorer)

输出是cv=5,交叉验证的5个结果,评估模型

(2)pycaret

from pycaret.regression import *
from pycaret.datasets import get_data
import pandas as pd
import numpy as np
from sklearn.metrics import max_error
from sklearn.metrics import make_scorer


def custom_loss(y_true, y_pred, **kwargs):
    # Define your custom loss calculation here
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    if y_true.ndim == 1 :
        y_true = y_true.reshape((-1, 1))

    if y_pred.ndim == 1:
        y_pred = y_pred.reshape((-1, 1))   
        
    loss = max(y_true-y_pred)
    return loss

# # load sample dataset
# # data = get_data('insurance')
data = pd.DataFrame(np.array([[i for i in range(0,300)],[i for i in range(100,400)],[i for i in range(200,500)]]).T,columns=['a','b','c'])
s = setup(data, target='c')
# custom_loss = make_scorer(custom_loss)
add_metric('custom_loss', 'Custom Loss', custom_loss)
best = compare_models()
predict_model(best)

 

3、深度学习框架

(1)torch

import torch
import torch.nn as nn
import numpy as np
import os

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def get_x_y():
    np.random.seed(0)
    x = np.random.randint(0, 50, 300)
    y_values = 2 * x + 21
    x = np.array(x, dtype=np.float32)
    y = np.array(y_values, dtype=np.float32)
    x = x.reshape(-1, 1)
    y = y.reshape(-1, 1)
    return x, y


class LinearRegressionModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LinearRegressionModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)  # 输入的个数,输出的个数

    def forward(self, x):
        out = self.linear(x)
        return out


class CustomLoss(nn.Module):
    def __init__(self):
        super(CustomLoss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def forward(self, x, y):
        mse_loss = torch.mean(torch.pow((x - y), 2))
        return mse_loss


if __name__ == '__main__':
    input_dim = 1
    output_dim = 1
    x_train, y_train = get_x_y()

    model = LinearRegressionModel(input_dim, output_dim)
    epochs = 1000  # 迭代次数
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    # model_loss = nn.MSELoss() # 使用MSE作为loss
    model_loss = CustomLoss()  # 自定义loss
    # 开始训练模型
    for epoch in range(epochs):
        epoch += 1
        # 注意转行成tensor
        inputs = torch.from_numpy(x_train)
        labels = torch.from_numpy(y_train)
        # 梯度要清零每一次迭代
        optimizer.zero_grad()
        # 前向传播
        outputs: torch.Tensor = model(inputs)
        # 计算损失
        loss = model_loss(outputs, labels)
        # 返向传播
        loss.backward()
        # 更新权重参数
        optimizer.step()
        if epoch % 50 == 0:
            print('epoch {}, loss {}'.format(epoch, loss.item()))

参看:pytorch自定义loss损失函数_python_脚本之家

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/71482.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

腾讯云标准型CVM云服务器详细介绍

腾讯云CVM服务器标准型实例的各项性能参数平衡,标准型云服务器适用于大多数常规业务,例如:web网站及中间件等,常见的标准型云服务器有CVM标准型S5、S6、SA3、SR1、S5se等规格,腾讯云服务器网来详细说下云服务器CVM标准…

Ubuntu22.04安装docker

在ubuntu22.04上安装docker还是比较容易的,之前在公司的centos6上边装docker,那才真是一言难尽呀,废话不多说,开始安装 1、更新包管理器 apt update 2、安装必要的软件包,以便允许 apt 使用 HTTPS 仓库 sudo apt i…

手撕Java集合——链表

链表 一、链表概念特性二、不带头单向非循环链表实现🍑1、定义结点🍑2、打印链表🍑3、使用递归逆序打印链表🍑4、头插🍑5、尾插🍑6、指定位置插入🍑7、查找是否包含关键字key是否在单链表当中&a…

多传感器融合相关技术

重要说明:本文从网上资料整理而来,仅记录博主学习相关知识点的过程,侵删。 一、参考资料 多传感器融合定位学习 深蓝-多传感器定位融合 深蓝学院 多传感器融合定位 作业 多传感器融合详解 二、相关介绍 1. 毫米波雷达(Radar&a…

R语言4_安装BayesSpace

环境Ubuntu22/20, R4.1 你可能会报错说你的R语言版本没有这个库,但其实不然。这是一个在Bioconductor上的库。 同时我也碰到了这个问题,ERROR: configuration failed for package systemfonts’等诸多类似问题,下面的方法可以一并解决。 第…

790. 多米诺和托米诺平铺

题目描述&#xff1a; 主要思路&#xff1a; class Solution { public:int numTilings(int n) {long long f[n][4],mod1e97;f[0][0]1;f[0][1]f[0][2]0;f[0][3]1;for(int i1;i<n;i){f[i][0]f[i-1][3];f[i][1] (f[i-1][0]f[i-1][2])%mod;f[i][2] (f[i-1][0]f[i-1][1])%mod;f…

力扣221.最大正方形(动态规划)

思路&#xff1a; 思路&#xff1a;从[0,0]元素开始&#xff0c;计算每个元素对应其与[0,0]之间矩阵块中最大正方形边长情况&#xff1a;1&#xff09;matrix [ i , j ] ‘0’ --> 元素对应的最大正方形为0。情况&#xff1a;2&#xff09;matrix [ i , j ] ‘1’ -->…

中级课程-SSRF(CSRF进阶)

文章目录 成因危害挖掘 成因 危害 挖掘

详解JAVA远程debug

目录 1.什么是远程debug&#xff1f; 2.远程debug普通JAVA程序 环境 测试程序 程序启动指令 编译器配置 3.远程debug JAVA Web程序 4.远程debug spring boot程序 1.什么是远程debug&#xff1f; 远程debug&#xff0c;也就是可以在本地debug远端部署的程序&#xff0c…

界面控件DevExpress WPF Chart组件——拥有超快的数据可视化库!

DevExpress WPF Chart组件拥有超大的可视化数据集&#xff0c;并提供交互式仪表板与高性能WPF图表库。DevExpress Charts提供了全面的2D / 3D图形集合&#xff0c;包括数十个UI定制和数据分析/数据挖掘选项。 PS&#xff1a;DevExpress WPF拥有120个控件和库&#xff0c;将帮助…

快速上手PyCharm指南

PyCharm简介 PyCharm是一种Python IDE&#xff08;Integrated Development Environment&#xff0c;集成开发环境&#xff09;&#xff0c;带有一整套可以帮助用户在使用Python语言开发时提高其效率的工具&#xff0c;比如调试、语法高亮、项目管理、代码跳转、智能提示、自动…

Fabric系列 - 知识点整理

知识点 源码编译 主机编译 容器编译 手动部署(docker-compose) 单peer 多peer 中途加peer 多主机多peer 链码 语法, 接口 (go版) 命令行调用 ca server 在DApp中使用SDK调用 (js版) 部署的几个阶段 部署1排序和1节点, 1组织1通道 光部署能Dapp 带ca server (每个组织一个)…

CI/CD—K8S 基本理解与部署

1 K8S 是什么 Kubernetes 是一款容器的编排调度工具&#xff0c;来源于 Google 开源的 Brog 系统。Kubernetes简称K8S&#xff0c;是用8代替8个字符 “ubernete” 而成的缩写&#xff0c;用于管理云平台中多个主机上的容器化的应用&#xff0c;Kubernetes 的目标是让部署容器化…

【javaweb】学习日记Day1 - HTML CSS入门

目录 一、图片标签 ① 绝对路径 1.绝对磁盘路径 2.绝对网络路径 ② 相对路径 &#xff08;推荐&#xff09; 二、标题标签 三、水平线标签 四、标题样式 1、CSS引入样式 ① 行内样式 ② 内嵌样式 ③ 外嵌样式 2、CSS选择器 ① 元素选择器 ② id选择器 ③…

读书笔记 |【项目思维与管理】➾ 顺势而动

读书笔记 |【项目思维与管理】➾ 顺势而动 一、企业步入“终结者时代”二、过去成功的经验也许是最可怕的三、做好非重复性的事四、适应客户是出发点五、向知识型企业转变六、速度是决胜条件 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; …

Python web实战之Django 的缓存机制详解

关键词&#xff1a;Python、Web 开发、Django、缓存 1. 缓存是什么&#xff1f;为什么需要缓存&#xff1f; 在 Web 开发中&#xff0c;缓存是一种用于存储数据的临时存储区域。它可以提高应用程序的性能和响应速度&#xff0c;减轻服务器的负载。 当用户访问网页时&#xff…

网安周报|半数人工智能开源项目引用存在漏洞的软件包

1、半数人工智能开源项目引用存在漏洞的软件包 根据EndorLabs的数据&#xff0c;开源在AI技术堆栈中发挥着越来越重要的作用&#xff0c;但大多数项目(52%)引用了存在已知漏洞的易受攻击的依赖项。EndorLabs在最新的《软件依赖管理状态报告》声称&#xff0c;在发布仅五个月后&…

【Minecraft】Fabric Mod开发完整流程4 - 自定义物品方块以及食物、燃料

目录 自定义物品与方块自动侦测矿藏工具工具功能实现执行结果 自定义音乐方块自定义食物自定义燃料 自定义物品与方块 自动侦测矿藏工具 探测器纹理下载地址&#xff1a; https://url.kaupenjoe.net/mbkj57/assets 众所周知&#xff0c;正经人永远不喜欢常规套路挖矿&#xff0…

web-xss-dvwa

目录 xss&#xff08;reflected&#xff09; low medium high xss(store) low medium high xss(dom) low medium high xss&#xff08;reflected&#xff09; low 没有什么过滤&#xff0c;直接用最普通的标签就可以了 http://127.0.0.1/DVWA-master/vulnerabili…

AT89C51单片机实现单片机串口互动(中断方式,单片机--单片机,应答)

说一下功能&#xff1a;客户机发送0x01到服务机 2服务单片机应答0xf2到客户机 3客户机接收到0xf2,发送信息153432这6个数字到服务机 4client发送完信息后发送0xaa结束通信 5server接收到0xaa后回复0xaa结束通信&#xff0c;从此老死不相往来 看代码&#xff1a; //发送端…