深度学习分类问题之Logistic Regression

逻辑回归模型,虽然名字是回归,但是是解决分类问题。

在线性回归里面,我们根据有效信息,预测下一个由已知信息得到的数值,叫做回归问题,但是在机器学习里面,常见的是分类问题。最常见的就是MNIST数据集里面的手写数字问题。

在这个问题里面:我们给出了六万多张训练集对我们的模型进行训练,然后给出一张手写数字,模型可以帮我们判断出这个手写数字是几,这叫做分类问题。通过训练(喂数据),训练模型,而后给出数据,模型基于前面的训练将数据进行归类。

这与线性模型不同,我们得到的是数据属于各个类的概率值,我们输出的是一个概率,所有概率之和是1。我们算出概率值,通过找到概率最大的类,得到预测结果。

MNIST:pytorch里面的torchvision包里面提供了相应的(比较流行的)数据集。

第一个参数表示数据集的位置,第二个参数是是否作为训练集,第三个参数表示是否要从网上进行下载。如果已经下载过了,可以设置为False。建议直接选择True,如果发现已经下载,则不需要再次下载了。  

仍然以之前的那个模型为例:

如果我们考虑由前面数据的规律得到的x=4时y的值,我们得到的是一个点数,这是线性回归问题,如果说x=1,2是得到的y=0表示不能通过考试,而当x=3时得到的y=1表示可以通过考试,你们当我们的x=4时得到的y应该表示的是能否通过考试,在这里我们使用分类问题,将我们得到的结果映射为对应的分类。(可以这样理解,每周学习一两个小时的都没有通过,而学习三个小时的通过了,预测学习四个小时的是否能够通过)。

只有两个类别的分类问题,我们称之为二分类。
有多个类别的分类问题,我们称之为多分类问题。

我们要计算的是y=0(没有通过)的概率,和y=1(通过)的概率。二分类其实只需要计算一个值就可以了,因为二分类问题隶属于两个类别的概率值和是1,所以当我们求出一个的概率,可以用1减去这个概率值,得到隶属于另一个类别的概率。

当我们的学习器计算出来隶属于各种类的比例差不多的时候,我们有信心判断,我们的模型对每种类别都没有相应的把握,这个时候就需要质疑学习器的实用性了。我们要做相应的处理,比如我们想输出A种类,但是发现隶属于A种类的概率不足50%,我们就输出“不确定”。或者在二分类问题中,我们隶属于某一种类的概率在0.4到0.6之间,我们也输出“不确定”。

线性回归方程输出的是一个实数,而分类问题输出的是概率,概率值要在0到1之间,所以我们要将线性模型的输出值由实数空间映射到0到1。换句话说,我们需要找到一个函数,将实数的值x转化为概率值0到1,我们通常使用Logistic函数,明显函数的图形超过某个阈值之后,增长非常缓慢。这种函数称为饱和函数,(导数在分界线一边是越来越小,另一边是越来越大)。明显Logistic函数的导数类似正态分布。

看一下其他的sigmoid函数:

关于这些sigmoid函数,请看链接:点击这里。

这些函数里面最出名的就是Logistic函数,所以在大多数情况下,我们说的sigmoid函数指的是Logistic函数。

在最初的模型中,我们不进行非线性处理,但是在Logistic模型中,我们在进行线性处理后,结果做Logistic函数处理,Logistic函数的函数名,我们直接写成\sigma \left ( x \right ),以后当我们看到这个符号,就默认是Logistic函数。

但是我们用到的非线性函数并不一定是把结果映射到0到1之间,有时候我们需要均值是0,那么要映射到-1到1之间。

同理,这时我们需要计算,从而才能反向传播算出损失对权重的梯度。很明显,分类问题的回归和线性回归的最大区别就是加了一个Logistic函数(或说是激活函数)。 

在我们线性回归问题中,残差项表示预测值和真实值在数轴上的距离,是刻画便宜程度的一个量。那么这显然不试用与分类问题。分类问题要怎么计算损失呢?

分类问题的损失不能用几何之间的度量空间来表示,我们要计算分布的差异。

我们可以使用KL散度和交叉熵来计算。

下面来个例子:

 交叉熵越大说明匹配程度越高,所以我们加一个负号表示损失,这时就是交叉熵越大(匹配程度越高),损失越小。

二分类里面所用到的损失函数(即上图函数),我们将其称之为BCE。

很明显,预测结果和真实结果越接近,其损失就越小,Mini-Batch Loss计算的是小批量损失的均值。

在代码的实现中:调用sigmiod函数(默认是Logistic函数。当然不止有Logistic函数,也有tanh函数,relu函数。)

损失也有不同,在线性模型中,我们使用的是MSE,现在我们使用的是BCE。CE是什么意思呢,就是我们刚刚写的cross-entropy(交叉熵)。

代码中数据输入由数据变成了表分类的映射:

那么我们编写网络模型的时候要做哪些任务呢:

1,准备数据
2,模型构造
3,构造损失和优化器
4,进行训练

数据可视化处理:

首先先在0到10之间选择200个数据点
然后将其转化为200行1列的数组
送进模型中
将结果用数组的方式表示出来
最后将其画出来
 

为什么在2.5的时候通过率达到了0.5,因为我们在x=2的时候通过率为0,在x=3的时候通过率是1,那么由线性规则可知,在x=2.5的时候,应该是通过与不通过的分界线。符合我们的生活认知。


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/350870.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】sdxl中的 tokenizer tokenizer_2 区别

代码仓库: https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main 截图: 为什么有两个分词器 tokenizer 和 tokenizer_2? 在仔细阅读这些代码后,我们了解到 tokenizer_2 主要是用于 refiner 模型的。 #…

【Flink】记录Flink 任务单独设置配置文件而不使用集群默认配置的一次实践

前言 我们的大数据环境是 CDP 环境。该环境已经默认添加了Flink on Yarn 的客户端配置。 我们的 Flink 任务类型是 Flink on Yarn 的任务。 默认的配置文件是在 /etc/flink/conf 目录下。如今我们的需求是个别任务提供的配置仅用于配置执行参数,例如影响作业的配置…

HCIA学习第四天:静态路由与动态路由

静态路由: 选路原则:尽量选择路径最短的路由条目 扩展配置: 1、负载均衡:当路由器访问同一个目标且目标且目标具有多条开销相似的路径时,可以让设备将流量拆分后延多条路径同时进行传输,以达到叠加带宽的…

JavaScript 学习笔记(JS进阶 Day2)

「写在前面」 本文为 b 站黑马程序员 pink 老师 JavaScript 教程的学习笔记。本着自己学习、分享他人的态度,分享学习笔记,希望能对大家有所帮助。推荐先按顺序阅读往期内容: 1. JavaScript 学习笔记(Day1) 2. JavaSc…

PCL-IO输入输入模块

IO输入输入模块 一、概述二、点云数据格式1. PCD 格式2. PLY 格式3. OBJ 格式4. STL 格式5. OFF 格式 三、读取3D文件1. API 总览2. 示例 四、保存3D文件1. API 总览2. 示例 一、概述 PCL 库提供了一个模块用来对3D数据进行读写操作,这个库提供了一个模块&#xff…

CPQ配置报价 | 面向非标设备制造项目报价系统解决方案

在非标设备细分领域,企业面向定制型项目经常会遇到以下难题:第一,方案制作效率低,易出错;第二,成本核算过程不严谨,准备性差;第三,报价试算过程不科学;第四&a…

最长公共子串的问题(正常方法和矩阵法,动态规划)

题目: 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。 一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情况下删除某些字符…

C++知识点笔记

二维数组 定义方式: 1、数据类型 数组名[行数][列数]; 2、数据类型 数组名[行数][列数]{{数据1,数据2},{数据3,数据4}}; 3、数据类型 数组名[行数][列数]{数据1,数据2,数据3,数据4}; 4、数据类型 数组名[][列数]{数据1,数据2,数据3,数据4}; 建议:以…

ERROR Failed to get response from https://registry.npm.taobao.org/ 错误的解决

这个问题最近才出现的。可能跟淘宝镜像的证书到期有关。 解决方式一:更新淘宝镜像(本人测试无效,但建议尝试) 虽然无效,但感觉是有很大关系的。还是设置一下比较好。 淘宝镜像的地址(registry.npm.taobao…

leetcode hot 100 电话号码的字母组合

在本题目中,要求我们根据给的数字字符串对应电话号码上的字母组合。所以我们需要建立起数字和电话上字母的对应关系。 然后,组合问题依旧采用回溯来做。首先我们需要确定一下参数,我们需要给的digits,然后还需要字母和数字对应关…

使用IP爬虫代理提取数据的步骤是什么?爬虫代理IP怎么提高采集效率?

​​​​​ 一、使用IP爬虫代理提取数据的步骤 在使用爬虫代理IP提取数据之前,需要先了解数据来源和目标网站的结构。以下是一个基本的步骤:1.确定数据来源 首先需要确定要提取数据的网站或数据源,了解网站的结构、数据存储方式以及数据更新…

HTML - 介绍

一.简介 HTML,超文本标记语言(HyperText Markup Language),是一种用于创建网页的标准标记语言。我们可以使用HTML建立自己的WEB网站或特定页面。HTML运行在浏览器上,由浏览器解析。 ⚠️注意:HTML文件的后缀…

HTML以及CSS相关知识总结(二)

css文件写样式时建议遵循以下顺序: 1.布局定位属性:display/position/float/ear/visibility/overflow(建议display第一个写,毕竟关系到模式) 2.自身属性: width/height/margin/ padding /border/ background 3.文本属性: color/font / text-decoration/t…

区块链中分叉机制

在区块链中我们经常会听到分叉【fork】的概念,今天通过这篇文章来详细的介绍下分叉 什么是分叉 在介绍区块链的分叉机制中,我们以公有链来说明,公有链是去中心化的。任何协议的改变都是代价巨大的,因为全网那么多节点&#xff0…

国产GC6610应用于打印机,医疗器械等产品中,可替代TMC2208/2209/trinamic的参数分析

电机驱动芯片应用范围十分广泛,目前已经广泛应用于消费电子、电动工具、打印机、3D打印、安防监控、通信设备、汽车,以及工业控制等领域。据市场调研机构ResearchAndMarkets统计,2021年全球电机驱动芯片是市场规模为38.8亿美元,预…

uniapp小程序:内存超过2mb解决方法(简单)message:Error: 上传失败:网络请求错误 代码包大小超过限制。

分析:这种情况是代码文件内存超过2mb无法进行预览上传 解决方法: 1、Hbuilder中点击运行-->运行到小程序模拟器--->运行时是否压缩代码 2、在微信小程序中点击详情--->本地设置: 3、点击预览即可运行了

Java通过模板替换实现excel的传参填写

以模板为例子 将上面$转义的内容替换即可 package com.gxuwz.zjh.util;import org.apache.poi.ss.usermodel.*; import org.apache.poi.xssf.usermodel.XSSFWorkbook; import java.io.*; import java.util.HashMap; import java.util.Map; import java.io.IOException; impor…

vue3 常见的路由传参无刷新修改当前路由url带参

无刷新修改当前路由url带参 //tabs切换部分 <el-tabs v-model"activeName" class"demo-tabs" tab-click"handleClick"><el-tab-pane v-for"(item,index) in tagList" :label"item.title" :name"item.name…

4-4 D. 银行排队问题之单队列多窗口加VIP服务

题目描述 假设银行有K个窗口提供服务&#xff0c;窗口前设一条黄线&#xff0c;所有顾客按到达时间在黄线后排成一条长龙。当有窗口空闲时&#xff0c;下一位顾客即去该窗口处理事务。当有多个窗口可选择时&#xff0c;假设顾客总是选择编号最小的窗口。 有些银行会给VIP客户以…

gitee仓库使用中的警告

当 Git 执行 git pull 命令时&#xff0c;有时候会出现类似下面的警告信息&#xff1a; warning: ----------------- SECURITY WARNING ---------------- warning: | TLS certificate verification has been disabled! | warning: ------------------------------------------…