机器学习实验----逻辑回归实现二分类

目录

一、介绍

二、sigmoid函数

(1)公式:

(2)sigmoid函数的输入

预测函数:

以下是sigmoid函数代码:

三、梯度上升

(1)似然函数

公式:

概念:

对数平均似然函数公式:

对数似然函数代码:

代码解释:

梯度上升

学习率:

学习率的选取:

参数更新:

 梯度上升代码:

 代码解释:

四、打印散点图和线性回归图像及数据集处理

1.数据集处理

(1)代码思路

(2)代码展现

2.散点图

思路解析:

代码展现:

 运行结果截图:

3.逻辑回归曲线

代码解析:

代码展现:

运行截图:

五、利用逻辑回归进行分类

代码思路:

代码展现:

运行截图:

六、实验中遇到的问题

七、逻辑回归的优缺点

优点:

缺点:

八、实验的改进

九、总代码


一、介绍

逻辑回归是机器学习中的一种分类模型,逻辑回归是一种分类算法,虽然名字中带有回归。由于算法的简单和高效,在实际中应用非常广泛。

二、sigmoid函数

(1)公式:

g(z)=\frac{1}{1+e^{-z}} 

z取值为负无穷到正无穷,g(z)的取值为[0, 1]。他将一个任意的值映射到[0, 1]区间上,我们从线性回归中得到一个预测值,并将该值映射到sigmoid函数中,因为是二分类,所以我们就可以把这个函数看成预测为正类的概率,我们给定一个阈值,如果大于该阈值就判断为正类,如果小于就预测为负类,以此达到分类的效果。

举个例子:当我们抽奖的时候,假设有60%的概率会中奖,当我们线性回归预测得到的预测值代入sigmoid函数的时候值大等于60%,就预测中奖,否则预测为不中奖。

(2)sigmoid函数的输入

预测函数:

h(x)=g(\theta ^{T} x)=\frac{1}{e^{-\theta ^{T}x}}

其中\theta ^{T}x=\sum_{i=1}^{n}\theta _{i}x_{i}=\theta _{0}x_{0}+\theta _{1}x_{1}+....+\theta _{n}x_{n}

按照上面所说,我们用h(x)来表示预测为正类的概率那么预测为负类的概率就为1-h(x)

当二分类的时候就用1来表示正类用0来表示负类,那么我们就可以得到公式:

p(y|x,\theta )=h(x)^{y}(1-h(x))^{1-y}

这样使得y取值为0和1都能表示为预测为y的概率。

以下是sigmoid函数代码:

其中np.exp()也就是e

 def sigmoid(z):
        return 1 / (1 + np.exp(-z))

三、梯度上升

(1)似然函数

公式:

L(\theta )=\prod_{i=1}^{m}P(y_{i}|x_{i};\theta )=\prod_{i=1}^{m}h(x_{i})^{y}(1-h(x_{i}))^{1-y}

概念:

似然函数是一种衡量参数与观测数据之间关系的函数,通过最大化似然函数来估计参数的取值,从而使模型更好地拟合观测数据。

 也就是说我们要使得这个似然函数的值尽可能大,才能使得这个函数拟合效果好。

这时候我们就需要求导求出这个函数的极值。但是由于这个函数求导计算量大,所以我们可以对这个函数取对数,得到对数似然函数。

对数平均似然函数公式:

我们对L(θ)除以m得到平均似然函数,当然可除可不除。

l(\theta)=\frac{1}{m}lnL(\theta)=\frac{1}{m}\sum_{i=1}^{m}(y_{i}ln h(x_{i})+(1-y_{i})ln(1-h(x_{i})))

转化为对数似然将原来的累乘转化为累加使得通过导数求极值的计算量得到减少。

对数似然函数代码:
def srhs(X, y, theta):
    m = len(y)
    h = sigmoid(X.dot(theta))
    res = (1/m) * (y.dot(np.log(h)) + (1-y).dot(np.log(1-h)))
    return res
代码解释:

m是y的样本数量,theta是以数组存储的θ,也就是上面说的θ1到θn,也就是sigmoid的输入,线性回归方程。h得到的也是一个列向量,是数据集x中每个样本点代入到sigmoid函数中的值。因为是累加的,所以我们最后结果就可以直接将其分为两部分用矩阵相乘计算。最后返回一个浮点型的值。

梯度上升

因为我们是求似然函数的最大值,所以我们用梯度上升来求解。当然还有一个算法是梯度下降,求损失函数的最小值,损失函数公式就是我们的平均似然函数乘以-1的结果。本质上来看其实算法原理都一样,只是表示的内容不一样。

学习率:

决定了模型在训练过程中更新权重参数的速度与方向

再梯度上升中就是表示每次对θ值梯度上升的步长,就有点像微积分中某点的导数来估计这个点在一个范围为t内的的导数, t取尽可能小来减少误差。这里的t也就是θ步长。

学习率的选取:

1.选取的太大大,那么在每一步迭代中,模型参数可能会跨过最优解,导致震荡或者发散,这被称为“振荡现象”或“不稳定性”。

2.如果学习率设置得太小,模型收敛到最优解的速度将会非常慢,而且可能会陷入局部极小点,而不是全局最优解。

普通梯度上升和随机梯度上升的区别也就在学习率的选取上,随机梯度上升的学习率不是固定的,更新时通常使用较小的学习率,并且这个学习率可以自适应地调整。

接下来是对似然函数求导以求极大值

在网上查阅资料得到以下求导公式:

\frac{\delta}{\delta_{\theta _{j}}}h(x_{i})=-\frac{e^{\theta _{0}x_{0}+....+\theta _{n}x_{n}}}{(1+e^{\theta _{0}x_{0}+....+\theta _{n}x_{n}})^{2}}x_{i}^{j}=h(x_{i})(1-h(x_{i}))x_{i}^{j}

\frac{\delta }{\delta _{\theta _{j}}}J(\theta )=\frac{1}{m}\sum_{i=1}^{m}(y_{i}\frac{1}{h(x_{i})} \frac{\delta}{\delta_{\theta _{j}}}h(x_{i})-(1-y_{i})\frac{1}{1-h(x_{i})}\frac{\delta}{\delta_{\theta _{j}}}h(x_{i}))

=\frac{1}{m}\sum_{i=1}^{m}(y_{i}\frac{1}{h(x_{i})} -(1-y_{i})\frac{1}{1-h(x_{i})})\frac{\delta}{\delta_{\theta _{j}}}h(x_{i})

=\frac{1}{m}\sum_{i=1}^{m}(y_{i}(1-h(x_{i}))-(1-y_{i})h(x_{i}))x_{i}^j

=\frac{1}{m}\sum_{i=1}^{m}(y_{i}-h(x_{i}))x^j_{i}

 其中的xij就表示对应的θj的那一列。

 那我们就求得了

参数更新:

\theta _{j}:=\theta _{j}+\alpha \frac{1}{m}\sum_{i=1}^{m}(y_{i}-h(x_{i}))x_{i}^{j}

 梯度上升代码:
def gradient_ascent(X, y, num, alpha):
    m, n = X.shape
    theta = np.zeros(n)
    for i in range(num_):
        h = sigmoid(np.dot(X, theta))

        for j in range(n):
            theta[j] += alpha*(np.sum((y - h) * X[:, j]) / m)
    return theta
 代码解释:

按照上面的思路,我们先假设一个\theta _{0}x_{0}+\theta _{1}x_{1}+....+\theta _{n}x_{n},我们用数组theta来存储θ。我们初始值都赋值为0,从0开始使用梯度上升函数对其进行修改。

第一个for循环表示的是总共的迭代次数,也就是进行梯度上升的次数。内层循环遍历就是对θ中的每个值进行修改。theta[j] += alpha*(np.sum((y - h) * X[:, j]) / m)就是套用上面给定的参数更新公式进行迭代。

四、打印散点图和线性回归图像及数据集处理

1.数据集处理

(1)代码思路

数据集使用的是鸢尾花数据集,因为还不会多维的在图像上表示,所以我取的是鸢尾花数据集的"Sepal.Length" "Sepal.Width" 两列来代表二维坐标轴的横纵坐标。

打印一下我的数据:

因为实验是二分类,所以我挑选的是鸢尾花数据集中的setosa和versicolor类别。为了方便我将versicolor赋值为0,setosa赋值为1.最后得到整合的数据。

对迭代次数赋值为1000,步长alafa赋值为0.01。

(2)代码展现

data = pd.read_csv(r"C:\\Users\\李烨\\Desktop\\ljhg.txt", sep=' ')

X = data.iloc[:, :2].values 
y1 = data.iloc[:, -1].values 

#print(X);
len1=len(y1);
y=np.zeros(len1);
for i in range(len1):
    if y1[i] == 'setosa':
        y[i]=1
    if y1[i] == 'versicolor':
        y[i]=0

#print(y)
X = np.c_[np.ones(X.shape[0]), X]
num = 1000  
alafa = 0.1  

2.散点图

思路解析:

用plt.rcParams['font.sans-serif'] = ['SimHei']使得图像上面可以正常显示中文。

用plt.scatter将数据中的点都画到图上

代码展现:

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.scatter(X[:, 1], X[:, 2], c=y, cmap='viridis')
plt.xlabel('Sepal.Length')
plt.ylabel('Sepal.Width')
plt.title('散点图')
plt.show()

 运行结果截图:

3.逻辑回归曲线

代码解析:

因为我们要连接这一条直线,直线其实是很多的点连接起来的,我们就在x的范围内平均取100个点,因为我们的特征只有两个,我们的线性回归方程为\theta _{0}x_{0}+\theta _{1}x_{1}+....+\theta _{n}x_{n}=0。所以第二个特征的解析式x2的取值就为代码所示。然后就是把点代入得到直线。

当然我们看到总共代码有一些和散点图中的重复。但是还得再敲一遍代码,因为plt.show()会把当前的图像打印出来,并不会保留。

代码展现:
xi = np.linspace(np.min(X[:, 1]), np.max(X[:, 1]), 100)
yi = -(theta[0] + theta[1] * xi) / theta[2]

plt.scatter(X[:, 1], X[:, 2], c=y, cmap='viridis')
plt.plot(xi, yi, "r-", label='线性回归曲线')
plt.xlabel('Sepal.Length')
plt.ylabel('Sepal.Width')
plt.title('鸢尾花数据集二分类')
plt.legend()
plt.show()
运行截图:

五、利用逻辑回归进行分类

代码思路:

导入测试集,处理一下数据。

将每一个数据代入到线性回归方程上,得到一个值,这个值的取值就是负无穷到正无穷,代入到sigmoid函数上映射到0,1上,如果大等于0.5就将其分类为正类,否则就是分类为负类,如果分类正确就记录下来。以此得到准确率

代码展现:

test_data = pd.read_csv("C:\\Users\\李烨\\Desktop\\ljhgtest.txt", sep='\s+')
X_test = test_data[["Sepal.Length", "Sepal.Width"]].values
y_test = test_data["Species"].values
cnt=len(y_test)
#print(X_test[1][0])
num=0
for i in range(cnt):
    t = theta[0] + theta[1] * X_test[i][0] + theta[2] * X_test[i][1]

    res1 = res1 = sigmoid(t)
    print(X_test[i],end=' ')
    flag=0
    if res1 >= 0.5 :
        flag=1
        print(f"预测类型为setosa",end=' ')
    else :
        print(f"预测类型为versicolor",end=' ')
    print("实际结果为",y_test[i])
    if flag==1 and y_test[i]=='setosa':
        num=num+1
    if flag==0 and y_test[i]=='versicolor':
        num=num+1

print("准确率为",num/cnt)

运行截图:

六、实验中遇到的问题

(1)本来以为逻辑回归这是简单的线性回归加上sigmoid函数,但是逻辑回归中得到的线和线性回归中得到的不太一样,线性回归中得到的好像是只有一个类别的拟合函数,而逻辑回归得到的是将多个类别分开的曲线,二者的性质不太一样。

(2)对似然函数的求导化简过程需要先取对数减少计算量,靠查阅资料得到化简后的函数,进而得到参数更新的公式。

(3)学习率的选取上需要选择合适的数据,太大太小都会影响程序的准确率。

(4)对于实验数据的选择,实验实现二分类,如果样本特征太多的话,就会导致在坐标轴上无法正确画出图像,我在实验运行选取的是在二维上,所以只选取两个特征。当然我的程序是可以分类多特征值的算法,公式\theta _{0}x_{0}+\theta _{1}x_{1}+....+\theta _{n}x_{n}就可以保证我的算法可以解决多特征值的算法,但是特征太多就无法正确映射到坐标轴上。

七、逻辑回归的优缺点

优点:

  1. 输出值自然地落在0到1之间,类比到概率
  2. 参数代表每个特征对输出的影响
  3. 时间空间复杂度低,计算量小

缺点:

  1. 因为它本质上是一个线性的分类器,所以处理不好特征之间相关的情况。
  2. 容易欠拟合,精度不高

八、实验的改进

(1)实验中可以实现从二分类到多分类的算法,

(2)实验中的梯度上升可以升级为随机梯度上升

随机梯度和梯度的区别主要是在梯度上升每次迭代使用整个训练集计算梯度,然后更新参数;而随机梯度上升每次迭代只使用一个样本数据计算梯度,并更新参数。

修改我的代码改为随机梯度上升只需要修改td函数,两个函数大差不差。

def sgd(X, y, num, alpha):
    m, n = X.shape
    theta = np.zeros(n)
    for j in range(num):
        for i in range(m):
            pos = np.random.randint(0, m)
            h = sigmoid(np.dot(X[pos], theta))
            theta += alpha * (y[pos] - h) * X[pos]
            
    return theta

 与梯度上升相比随机梯度上升的优缺点:

随机梯度上升在每次迭代中只需要计算一个样本数据的梯度,因此通常比梯度上升更快。但是由于每次只用一个样本就会导致曲线的效果没那么好。

九、总代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
def sigmoid(z):
    return 1.00 / (1.00 + np.exp(-z))

def sr(X, y, theta):
    m = len(y)
    h = sigmoid(X.dot(theta))
    res = (1 / m) * (y.dot(np.log(h)) + (1 - y).dot(np.log(1 - h)))
    return res

def td(X, y, num_iterations, alpha):
    m, n = X.shape
    theta = np.zeros(n)
    for iteration in range(num_iterations):
        h = sigmoid(np.dot(X, theta))

        for j in range(n):
            theta[j] += alpha*(np.sum((y - h) * X[:, j]) / m)
    return theta


data = pd.read_csv(r"C:\\Users\\李烨\\Desktop\\ljhg.txt", sep=' ')

X = data.iloc[:, :2].values
y1 = data.iloc[:, -1].values

#print(X)
len1 = len(y1)
y=np.zeros(len1)
for i in range(len1):
    if y1[i] == 'setosa':
        y[i]=1
    if y1[i] == 'versicolor':
        y[i]=0
#print(y)
X = np.c_[np.ones(X.shape[0]), X]
num = 1000
alafa = 0.1

theta = td(X, y, num, alafa)
print("线性回归曲线 theta:", theta)

plt.rcParams['font.sans-serif'] = ['SimHei']
plt.scatter(X[:, 1], X[:, 2], c=y, cmap='viridis')
plt.xlabel('Sepal.Length')
plt.ylabel('Sepal.Width')
plt.title('散点图')
plt.show()

xi = np.linspace(np.min(X[:, 1]), np.max(X[:, 1]), 100)
yi = -(theta[0] + theta[1] * xi) / theta[2]

plt.scatter(X[:, 1], X[:, 2], c=y, cmap='viridis')
plt.plot(xi, yi, "r-", label='线性回归曲线')
plt.xlabel('Sepal.Length')
plt.ylabel('Sepal.Width')
plt.title('鸢尾花数据集二分类')
plt.legend()
plt.show()

test_data = pd.read_csv("C:\\Users\\李烨\\Desktop\\ljhgtest.txt", sep='\s+')
X_test = test_data[["Sepal.Length", "Sepal.Width"]].values
y_test = test_data["Species"].values
cnt=len(y_test)
#print(X_test[1][0])
num=0
for i in range(cnt):
    t = theta[0] + theta[1] * X_test[i][0] + theta[2] * X_test[i][1]

    res1 = sigmoid(t)
    print(X_test[i],end=' ')
    flag=0
    if res1 >= 0.5 :
        flag=1
        print(f"预测类型为setosa",end=' ')
    else :
        print(f"预测类型为versicolor",end=' ')
    print("实际结果为",y_test[i])
    if flag==1 and y_test[i]=='setosa':
        num=num+1
    if flag==0 and y_test[i]=='versicolor':
        num=num+1

print("准确率为",num/cnt)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635806.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Golang | Leetcode Golang题解之第100题相同的树

题目: 题解: func isSameTree(p *TreeNode, q *TreeNode) bool {if p nil && q nil {return true}if p nil || q nil {return false}queue1, queue2 : []*TreeNode{p}, []*TreeNode{q}for len(queue1) > 0 && len(queue2) > …

10个顶级的论文降重指令,让你的论文降重至1.9%

10个顶级的论文降重指令,本硕博写论文必备! 在ChatGPT4o对话框中输入:写一个Spring BootVue实现的车位管理系统的论文大纲,并对其具体章节进行详细描述。 几小时即可完成一份1万字论文的编写 在GPTS中搜索论文降重,使…

FFmpeg开发笔记(三十)解析H.264码流中的SPS帧和PPS帧

《FFmpeg开发实战:从零基础到短视频上线》一书的“2.1.1 音视频编码的发展历程”介绍了H.26x系列的视频编码标准,其中H.264至今仍在广泛使用,无论视频文件还是网络直播,H.264标准都占据着可观的市场份额。 之所以H.264取得了巨大…

并发编程常见面试题

文章目录 为什么要使用线程池为什么不建议使用 Executors静态工厂构建线程池synchronized的实现原理Synchronized和Lock的区别什么是AQS什么是阻塞队列 为什么要使用线程池 关于线程池的作用和线程池的执行流程参考:java线程池 为什么不建议使用 Executors静态工厂…

BUSCO安装及使用(生物信息学工具-019)

01 背景 Benchmarking Universal Single-Copy Orthologs (BUSCO)是用于评估基因组组装和注释的完整性的工具。通过与已有单拷贝直系同源数据库的比较,得到有多少比例的数据库能够有比对,比例越高代表基因组完整度越好。基于进化信息的近乎全基因单拷贝直…

521源码-免费教程-Linux系统硬盘扩容教程

本教程来自521源码:更多网站源码下载学习教程,请点击👉-521源码-👈获取最新资源 首先:扩容分区表 SSH登陆服务器输入命令:df -TH,获得数据盘相关信息 可以看到演示服务器的数据盘分区是&…

如何让外网访问内网服务?

随着互联网的快速发展,越来越多的企业和个人需要将内网服务暴露给外网用户访问。由于安全和隐私等因素的考虑,直接将内网服务暴露在外网是非常不安全的做法。如何让外网用户安全访问内网服务成为了一个重要的问题。 在这个问题上,天联公司提供…

【吊打面试官系列】Java高并发篇 - AQS 支持几种同步方式 ?

大家好,我是锋哥。今天分享关于 【AQS 支持几种同步方式 ?】面试题,希望对大家有帮助; AQS 支持几种同步方式 ? 1、独占式 2、共享式 这样方便使用者实现不同类型的同步组件,独占式如 ReentrantLock&…

第一份工资

当我拿到我人生的第一份工资时,那是一种难以言表的激动。我记得那个下午,阳光透过窗户洒在了我的办公桌上,我看着那张支票,心中满是欣喜和自豪。那是我独立生活的开始,也是我对自己能力的一种肯定。 我记得我是如何支配…

《Rust奇幻之旅:从Java和C++开启》第1章Hello world 2/5

讲动人的故事,写懂人的代码 很多程序员都在自学Rust。 🤕但Rust的学习曲线是真的陡,让人有点儿怵头。 程序员工作压力大,能用来自学新东西的时间简直就是凤毛麟角。 📕目前,在豆瓣上有7本Rust入门同类书。它们虽有高分评价,但仍存在不足。 首先,就是它们介绍的Rust新…

web自动化文件上传弹框处理

目录 文件上传介绍文件上传处理Alert 弹窗介绍Alert 弹窗处理 课程目标 掌握文件上传的场景以及文件上传的处理方式。掌握 Alert 弹窗的场景以及 Alert 弹窗的处理方式。 思考 碰到需要上传文件的场景,自动化测试应该如何解决? 文件上传处理 找到文…

大数据量MySQL的分页查询优化

目录 造数据查看耗时优化方案总结 造数据 我用MySQL存储过程生成了100多万条数据&#xff0c;存储过程如下。 DELIMITER $$ USE test$$ DROP PROCEDURE IF EXISTS proc_user$$CREATE PROCEDURE proc_user() BEGINDECLARE i INT DEFAULT 1;WHILE i < 1000000 DOINSERT INT…

SkyEye对接CANoe:助力汽车软件功能验证

01.简介 CANoe&#xff08;CAN open environment&#xff09;是德国Vector公司专为汽车总线设计而开发的一款通用开发环境&#xff0c;作为车载网络和ECU开发、测试和分析的专业工具&#xff0c;支持从需求分析到系统实现的整个系统的开发过程。CANoe丰富的功能和配置选项被OE…

【php开发系统性学习】——thinkphp框架的安装和启动保姆式教程

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;开发者-曼亿点 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 曼亿点 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a…

Linux应用入门(二)

1. 输入系统应用编程 1.1 输入系统介绍 常见的输入设备有键盘、鼠标、遥控杆、书写板、触摸屏等。用户经过这些输入设备与Linux系统进行数据交换。这些设备种类繁多&#xff0c;如何去统一它们的接口&#xff0c;Linux为了统一管理这些输入设备实现了一套能兼容所有输入设备的…

CSS基础(第五天)

目录 定位 为什么需要定位 定位组成 边偏移 静态定位 static&#xff08;了解&#xff09; 相对定位 relative 绝对定位 absolute&#xff08;重要&#xff09; 子绝父相的由来 固定定位 fixed &#xff08;重要&#xff09; 粘性定位 sticky&#xff08;了解&#xff…

【笔记】树(Tree)

一、树的基本概念 1、树的简介 之前我们都是在谈论一对一的线性数据结构&#xff0c;可现实中也有很多一对多的情况需要处理&#xff0c;所以我们就需要一种能实现一对多的数据结构--“树”。 2、树的定义 树&#xff08;Tree&#xff09;是一种非线性的数据结构&#xff0…

Hadoop3:HDFS中NameNode和SecondaryNameNode的工作机制(较复杂)

一、HDFS存储数据的机制简介 HDFS存储元数据(meta data)的时候 结果&#xff0c;记录在fsImage文件里 过程&#xff0c;记录在Edits文件里 同时fsImageEdits最终结果&#xff0c;这个最终结果&#xff08;fsImageEdits&#xff09;会保存一份在内存中&#xff0c;为了提升性能…

5月30日在线研讨会 | 面向智能网联汽车的产教融合解决方案

随着智能网联汽车技术的快速发展&#xff0c;产业对高素质技术技能人才的需求日益增长。为了促进智能网联汽车行业的健康发展&#xff0c;推动教育链、人才链与产业链、创新链的深度融合&#xff0c;经纬恒润推出产教融合相关方案&#xff0c;旨在通过促进教育链与产业链的深度…

Cookie 和 Session概念及相关API

目录 1.Cookie概念 2.理解会话机制 (Session) 3.相关API 3.1HttpServletRequest 3.2HttpServletResponse 3.3HttpSession 3.4Cookie 4.代码示例: 实现用户登陆 1.Cookie概念 Cookie 是存储在用户本地终端&#xff08;如计算机、手机等&#xff09;上的数据片段。 它…