【center-loss 中心损失函数】 参数与应用

文章目录

  • 前言
  • 简单总结一下
  • 参数对比
    • 解释参数
    • 权重衰减(L2正则化)
    • 动量
    • 其他参数
    • 运行


前言

之前我们已经完全弄明白了中心损失函数里的代码是什么意思,并且怎么用的了,现在我们来运行它。

论文:https://ydwen.github.io/papers/WenECCV16.pdf
github代码:https://github.com/KaiyangZhou/pytorch-center-loss

前文:【center-loss 中心损失函数】 原理及程序解释(完)

简单总结一下

这段主代码,还是先以小见大。

首先,有很多点,以普通的拟合直线为例子,假设直线是用来做分类问题,一条直线分成两类,或者说是回归问题,则就是,每个点落在两类的例子是多少。(可以想象可以用来做很多事。)
我们是这样一步一步做的
1、确立损失函数(作为评判好的模型的标准)。(损失函数有很多可选,具体使用具体分析)
2、随机设置权重参数(作为最后好的模型的参数)。(随机函数也有很多可选)
3、确立模型(如:y=wx+b)
4、计算当前的参数得出的值与实际值(标签值)的误差(可跳,损失函数里一般有此值),后代入(1)中的损失函数求得损失值。
5、确定降低损失值的方法。(有梯度下降法或数学公式法)注意这里是对损失函数求导!!!
6、得出由方法计算出的参数值。

贴一下:(梯度下降法的代码)

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt

# 生成模拟数据,假设真实的w为2,b为3
np.random.seed(0) # 设置随机种子
x = np.linspace(0, 10, 100) # 生成100个在[0,10]的等距数 (包括0,10)
y = 2 * x + 3 + np.random.normal(0, 1, 100) # 生成y值,加入噪声

# 定义一元线性回归模型
def linear_regression(x, w, b):
    return w * x + b

# 定义均方误差函数
def mean_squared_error(y_true, y_pred):
    return np.mean((y_true - y_pred) ** 2)

# 定义梯度下降算法
def gradient_descent(x, y, w, b, lr, epochs):
    # x: 自变量
    # y: 目标变量
    # w: 斜率的初始值
    # b: 截距的初始值
    # lr: 学习率
    # epochs: 迭代次数
    n = len(x) # 样本数量
    history_w = [] # 用来记录w的历史值
    history_b = [] # 用来记录b的历史值
    history_loss = [] # 用来记录损失函数的历史值

    for i in range(epochs): # 迭代epochs次
        # 计算预测值
        y_pred = linear_regression(x, w, b)
        # 计算损失值
        loss = mean_squared_error(y, y_pred)
        # 计算梯度
        dw = -2/n * np.sum((y - y_pred) * x)
        db = -2/n * np.sum(y - y_pred)
        # 更新w和b
        w = w - lr * dw
        b = b - lr * db
        # 记录w,b和损失值
        history_w.append(w)
        history_b.append(b)
        history_loss.append(loss)
        # 打印结果
        print(f"Epoch {i+1}: w={w:.4f}, b={b:.4f}, loss={loss:.4f}")

    return history_w, history_b, history_loss

# 设置超参数
w = 0 # 斜率的初始值
b = 0 # 截距的初始值
lr = 0.02 # 学习率
epochs = 200 # 迭代次数

# 调用梯度下降算法
history_w, history_b, history_loss = gradient_descent(x, y, w, b, lr, epochs)

# 绘制损失函数的变化曲线
plt.plot(range(epochs), history_loss, color="r")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Loss Curve")
plt.show()

# 绘制最终的拟合直线
plt.scatter(x, y, color="b", label="Data")
plt.plot(x, linear_regression(x, history_w[-1], history_b[-1]), color="g", label="Fitted Line")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Linear Regression")
plt.legend()
plt.show()

这里我们神经网络逻辑上还是一样。
1、确立损失函数 多了中心损失函数
2、随机设置权重参数nn.parameters()
3、确立模型这里用了CNN模型
4、求得损失值
5、确定降低损失值的方法这里也是梯度下降法
6、得出参数值

原来神经网路并没有想象中的那么高深,只是在原来基础上,做了很多的优化。
贴一下:(简单线性层的模型)

import torch
import torch.nn as nn
import torch.optim as optim

#设置随机种子
torch.manual_seed(0)
# 生成一些随机的输入和标签
x = torch.randn(100, 1) # 100个样本,每个样本有1个特征  randn
y = 3 * x + 5 + torch.randn(100, 1) # 100个样本,每个样本有1个标签,服从 y = 3x + 5 + 噪声 的分布

# 定义一个简单的线性模型
model = nn.Linear(1, 1) # 输入维度是1,输出维度是1
# 定义一个均方误差损失函数
criterion = nn.MSELoss()
# 定义一个随机梯度下降优化器
optimizer = optim.SGD(model.parameters(), lr=0.01) # 学习率是0.01

# 训练100个迭代
for epoch in range(100):
    # 清零梯度
    optimizer.zero_grad()
    # 得到预测结果
    output = model(x)
    # 计算损失
    loss = criterion(output, y)
    # 反向传播,计算梯度
    loss.backward()
    # 更新参数
    optimizer.step()
    # 打印损失
    print(f"Epoch {epoch}, loss {loss.item():.4f}")

# 打印模型参数
print(model.weight)
print(model.bias)

参数对比

之前我们第一个代码与第二个用到的参数:
w:权重
b:偏置
lr:学习率,乘在梯度前
epoch:迭代次数

此github代码:
model.parameters():里面为W权重矩阵、b偏置
criterion_cent.parameters():里面为W权重矩阵、b偏置
lr:学习率,乘在梯度前 这里lr_model为0.001,lr_cent为0.5
epoch:迭代次数 这里为100
变化
weight_decay:权重衰减(=L2惩罚项),一般很小 ,防止过拟合 一般5e-4 这里为5e-4
momentum:动量,乘在速度项前,用来加速学习过程 一般0.5-0.9 这里为0.9
gamma:学习率下降,乘在学习率前,一般0-1之间 这里为0.5
stepsize:学习率下降周期,每隔多少stepsize下降一次 这里为20

torch.optim.SGD(params, lr=0.001, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

未用到的默认参数:
nesterov:使用Nesterov动量方法,默认为False。
maximize:梯度找山顶,默认为False。
foreach:None时,在使用Cuda的情况下性能会更好,默认为None。
differentiable:选择为True时,可能会损害性能,默认为False。

解释参数

还是用官方SGD的图来解释一下更新的代码。
已知y(学习率),θ0(参数),f(θ)(目标损失函数),λ(权重衰减),μ(动量),τ(阻尼)
默认θ0为随机值,b1为0。(bt为中间量,或者叫动量缓冲区。累积了之前的梯度信息)
在这里插入图片描述
第一次迭代t=1时,求损失函数的梯度赋给g1,
如果权重衰减λ不为0,则g1 = g1 + λθ0。

以下讲解这行代码:

权重衰减(L2正则化)

这里在梯度中加入了权重参数θ0的信息,从而提高模型泛化的能力。

这个参数λ之所以很小,是因为我们不希望正则化项主导整个损失函数,而只是作为一个轻微的调整。

这个惩罚项会鼓励模型学习到更小的权重,因为大的权重会导致惩罚项增大,从而增加整体的损失函数值。

所以,当我们更新权重时,实际上是在原始梯度的基础上加上了这个惩罚项的梯度。这样做可以防止权重变得过大,有助于防止模型过拟合。

原理参考:机器学习中,L2正则化的原理,及其可以防止过拟合的原因

公式:
在这里插入图片描述
其中:L2范数为
在这里插入图片描述
这里是在损失函数环节上加的L2正则化。

这里解释下上文:g1 = g1 + λθ0
1、这里不是θ0 2,原因其实就是上面一步求梯度,已经将平方项移下来了(当然也有可能损失函数没有平方项)。
或者,我们可以看成,梯度项相加时,我们也要对 λ·IIwII2 求导,此时的平方项2就下来了,放进超参数里了。

2、至于L2范数里的求和为什么在g1 = g1 + λθ0 里没写,解释:是有这个操作的,只是求gt里也没写求和的操作,那为简便写伪代码起见,就也没写。

继续:

动量

如果动量u为0,则直接进行判断是否是求梯度最小还是梯度最大,如果最大则θ1 = θ0 + y x g1 ,如果最小则是θ1 = θ0 - y x g1 。

如果动量u不为0,此时判断t=1,则b1 = g1,
如果使用Nesterov动量方法,则g1 = g1 + ub1 (实际上是g1 =(1+u)x g1),如果不使用则g1 = b1。(可以看出Nesterov动量方法相比于原方法收敛更快)

然后如上更新。

t =2时,(λ、u不为0的情况 )损失函数求导得出g2-> g2 = g2+λθ1
此时判断t>1, 则b2 = ub1 +(1-τ)x g2 (此时b1 = g1) 可以看出当τ为0时,动量u起到一个加速作用

-> 之后同上。

momentum动量解释:
是用来加速梯度下降过程的,但如果动量值设置大于1,会导致更新步长过大,从而可能导致优化过程在最小值附近震荡,甚至发散,而不是收敛到最小值。动量的目的是为了帮助梯度下降算法更快地穿过平坦区域并减少震荡,但它也需要保证整个过程的稳定性。

dampening阻尼解释:
“dampening” 通常指的是减少振荡和过度调整的过程。在随机梯度下降(SGD)中,当使用动量(momentum)时,dampening 参数用于减少动量的影响,从而帮助稳定学习过程。如果设置为非零值,dampening 会减少累积过去梯度的速度,这样可以防止在最小值附近的过度震荡。

例如,在PyTorch的 torch.optim.SGD 优化器中,dampening 参数通常与动量一起使用。如果不希望使用动量的衰减效果,可以将 dampening 设置为0。如果设置了 dampening,每次更新时累积的动量会乘以 (1 - dampening )。这样,即使动量保持不变,通过调整 dampening,也可以控制优化过程的平滑程度

其他参数

# argparse.ArgumentParser() 创建一个ArgumentParser对象 用来处理命令行参数
parser = argparse.ArgumentParser("Center Loss Example")
# dataset # 数据集
# add_argument() 方法用于指定程序需要接受的命令参数
parser.add_argument('-d', '--dataset', type=str, default='mnist', choices=['mnist']) # 选择数据集 例:python main.py -d mnist
parser.add_argument('-j', '--workers', default=4, type=int,
                    help="number of data loading workers (default: 4)") # 数据加载工作线程数 例:python main.py -j 4 #-j表示短名称
# optimization # 优化
parser.add_argument('--batch-size', type=int, default=128) # 批大小
parser.add_argument('--lr-model', type=float, default=0.001, help="learning rate for model") # 学习率
parser.add_argument('--lr-cent', type=float, default=0.5, help="learning rate for center loss") # 中心损失学习率
parser.add_argument('--weight-cent', type=float, default=1, help="weight for center loss") # 中心损失权重
parser.add_argument('--max-epoch', type=int, default=100) # 最大迭代次数
parser.add_argument('--stepsize', type=int, default=20) # 学习率下降间隔 : 每隔多少个epoch下降一次
parser.add_argument('--gamma', type=float, default=0.5, help="learning rate decay") # 学习率衰减 : 学习率下降的倍数 比如0.5表示学习率减半
# model # 模型
parser.add_argument('--model', type=str, default='cnn') # 模型选择
# misc # 其他
parser.add_argument('--eval-freq', type=int, default=10) # 评估频率
parser.add_argument('--print-freq', type=int, default=50) # 打印频率 
parser.add_argument('--gpu', type=str, default='0') # GPU
parser.add_argument('--seed', type=int, default=1) # 随机种子
parser.add_argument('--use-cpu', action='store_true') # 是否使用CPU action='store_true' 表示如果有这个参数则为True
parser.add_argument('--save-dir', type=str, default='log') # 保存路径 保存训练日志 保存在log文件夹下
parser.add_argument('--plot', action='store_true', help="whether to plot features for every epoch") # 是否绘制特征图

args = parser.parse_args() # 解析参数 保存到args中

运行

进入colab官网白嫖gpu。
参考:利用谷歌colab跑github代码详细步骤

$ git clone https://github.com/KaiyangZhou/pytorch-center-loss
$ cd pytorch-center-loss
$ python main.py --eval-freq 1 --gpu 0 --save-dir log/ --plot

评估频率1epoch1次, gpu选择0号。保存在log文件夹下并绘图。
运行如下:
在这里插入图片描述
运行截图如下:
在这里插入图片描述

发现在第12epoch时就达到了97%的正确率。

在第32epoch时就基本上达到了高峰,然后逐渐下降。
在这里插入图片描述
如预期所见生成了log的文件夹。

在这里插入图片描述
选择train/epoch_33.png 如下:
在这里插入图片描述
分离表现的还不错。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/440070.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数学+前缀和】第十四届蓝桥杯省赛C++ A组《平方差》(c++)

【问题描述】 给定 L,R,问 L≤x≤R 中有多少个数 x 满足存在整数 y,z 使得 xy的平方−z的平方。 【输入格式】 输入一行包含两个整数 L,R,用一个空格分隔。 【输出格式】 输出一行包含一个整数满足题目给定条件的 x 的数量。 【数据范围】 对于 40% 的…

创建RAID0,RAID5并管理,热备盘,模拟故障

目录 1. RAID介绍以及mdadm安装 1.1 安装mdadm工具 2. 创建raid0 2.1 环境准备 2.2 使用两个磁盘创建RAID0 2.3 查看RAID0信息 2.4 对创建的RAID0进行格式化并挂载 2.5 设置成开机挂载 2.6 删除RAID0 3. 创建raid5 3.1 环境准备 3.2 用3个磁盘来模拟R…

【C++杂货铺】详解string

目录 🌈前言🌈 📁 为什么学习string 📁 认识string(了解) 📁 string的常用接口 📂 构造函数 📂 string类对象的容量操作 📂 string类对象的访问以及遍历操…

(MATLAB)应用实例13-时域信号的频谱分析

采用傅里叶变换来计算存在噪声的适于信号频谱。 假设数据采样频率为1000Hz,一个信号包含两个正弦波,频率50Hz、120Hz,振幅0.7、1,噪声为零平均值的随机噪声,采用FFT方法分析其频谱。 clearFs 1000; …

分享axios+signalr简单封装示例

Ajax Axios Axios 是一个基于 promise 网络请求库,作用于node.js 和浏览器中。 它是 isomorphic 的(即同一套代码可以运行在浏览器和node.js中)。在服务端它使用原生 node.js http 模块, 而在客户端 (浏览端) 则使用 XMLHttpRequests。 从浏览器创建 XMLHttpReque…

kafka 可视化工具

kafka可视化工具 随着科技发展,中间件也百花齐放。平时我们用的redis,我就会通过redisInsight-v2 来查询数据,mysql就会使用goland-ide插件来查询,都挺方便。但是kafka可视化工具就找了半天,最后还是觉得redpandadata…

一招教你优化TCP提高大文件传输效率

在当今企业的数据传输实践中,传统的传输控制协议(TCP)在处理大型文件传输时,其固有的可靠性和复杂性有时会导致效率不足。为了提升大文件传输的效率,对TCP进行优化成为了一个关键任务。 TCP传输的可靠性是其核心优势&a…

Kubernetes-2

Kubernetes学习第二天 k8s-21、Kubernetes的核心组件2、pod2.1、什么是pod 3、3种启动pod的方式3.1、命令行启动pod3.1.1、执行下面命令,背后发生了什么? 3.2、启动一个pod背后发生了什么3.3、使用yml文件3.3.1、标准的pod3.3.2、使用部署控制器启动pod3…

windows部署腾讯tmagic-editor01-Hello world

之前写过一篇使用yarn实现的https://blog.csdn.net/qq_36437991/article/details/133644558,后面的两个没有写,这次准备重新实现 环境 pnpm 8.15.1 node 18.19.0 创建vue项目 pnpm create vitecd hello-world pnpm install执行pnpm dev启动项目 安…

[PTA] 分解质因子

输入一个正整数n(1≤n≤1e15),编程将其分解成若干个质因子(素数因子)积的形式。 输入格式: 任意给定一个正整数n(1≤n≤1e15)。 输出格式: 将输入的正整数分解成若干个质因子积的形式&#…

TypeScript 基础(一)

目录 一、概述 二、开发环境 三、数据类型 1.boolean 2.number 3.string 4.Array 5.type 6.tuple 7.enum 8.any 9.null / undefined 10.never 11.object 结束 一、概述 TypeScript 是一种由微软开发的开源编程语言。它是 JavaScript 的一个超集,这意…

正则表达式-分组

1、oracle-正则表达式:将09/29/2008 用正则表达式转换成2008-09-29 select regexp_replace(09/29/2008, ^([0-9]{2})/([0-9]{2})/([0-9]{4})$, \3-\1-\2) replace from dual; 解析:regexp_replace-替换, 第一个参数:需要进行处…

5个实用的PyCharm插件

大家好,本文向大家推荐五个顶级插件,帮助开发人员提升PyCharm工作流程,将生产力飞升到新高度。 1.CodiumAI 安装链接:https://plugins.jetbrains.com/plugin/21206-codiumate--code-test-and-review-with-confidence--by-codium…

项目的搭建与配置

vue create calendar_pro 选择如下配置选项 安装 vue3 支持 vue add vue-next package.json 关闭 eslint 检测。 vue.config.js 配置跨域同源策略。 const { defineConfig } require(vue/cli-service) module.exports defineConfig({transpileDependencies: true,devServe…

【uni-app】condition 启动模式配置,生产环境无效,仅开发期间生效

在小程序开发过程中,每次代码修改后,都会启动到首页,有时非常不方便,为了更高效的开发,有时需要模拟直接跳转到指定的页面, 操作方法如下: 在pages.joson里面配置下列代码: "…

【自然语言处理六-最重要的模型-transformer-上】

自然语言处理六-最重要的模型-transformer-上 什么是transformer模型transformer 模型在自然语言处理领域的应用transformer 架构encoderinput处理部分(词嵌入和postional encoding)attention部分addNorm Feedforward & add && NormFeedforw…

24/03/07总结

esayx: 贪吃蛇: #include "iostream" #include "cmath" #include "conio.h" #include "easyx.h" #include "time.h" #define NODE_WIDTH 40 using namespace std; typedef struct {int x;int y; }node; enum direction /…

Python笔记|基础算数运算+数字类型(1)

重新整理记录一下python的基础知识 基础运算符 、-、*、/ ;括号 ()用来分组。 >>>2 2 4 >>>50 - 5*6 20 >>>(50 - 5*6) / 4 5.0 >>>8 / 5 1.6向下取整除法:向下舍入到最接近的整数的数学除法。运算符是 //。比如1…

力扣面试经典150 —— 6-10题

力扣面试经典150题在 VScode 中安装 LeetCode 插件即可使用 VScode 刷题,安装 Debug LeetCode 插件可以免费 debug本文使用 python 语言解题,文中 “数组” 通常指 python 列表;文中 “指针” 通常指 python 列表索引 文章目录 6. [中等] 轮转…

谷歌浏览器打包扩展插件,提示清单文件缺失或不可读取

今天想把谷歌浏览器的扩展打包一下,放到我虚拟机的谷歌浏览器,但是一直打包不成功。 问题 打包扩展程序错误 提示‘清单文件缺失或不可读取’ 原因 路径没有选择正确! 解决办法 1.首先找到google浏览器的安装路径。在谷歌浏览器地址栏输…