【机器学习】西瓜书习题3.5Python编程实现线性判别分析,并给出西瓜数据集 3.0α上的结果

参考代码
结合自己的理解,添加注释。

代码

  1. 导入相关的库
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
  1. 导入数据,进行数据处理和特征工程
    得到数据集 D = { ( x i , y i ) } i = 1 m , y i ∈ { 0 , 1 } D=\{ (x_i,y_i) \}_{i=1}^m, y_i \in \{0,1\} D={(xi,yi)}i=1m,yi{0,1}
# 1.数据处理,特征工程
data_path = 'watermelon3_0_Ch.csv'
data = pd.read_csv(data_path).values
# 按照数据集3.0α,强制转换数据类型
X = data[:,7:9].astype(float)
y = data[:,9]
y[y=='是'] = 1
y[y=='否'] = 0
y = y.astype(int)
  1. 计算西瓜书60页中的 X i 、 μ i 、 Σ i X_{i}、\mu_i、\Sigma_i XiμiΣi
# 将X的数据根据label值分成X0和X1
pos = y == 1
neg = y == 0
X0 = X[neg]
X1 = X[pos]

# 计算u0,u1 keepdims保持原数据维数
u0 = X0.mean(0, keepdims=True)
u1 = X1.mean(0, keepdims=True)

# 计算sigma0,sigma1
sigma0 = np.dot((X0-u0).T,X0-u0)
sigma1 = np.dot((X1-u1).T,X1-u1)
  1. 根据式3.33计算类内散度矩阵
    S w = Σ 0 + Σ 1 = ∑ x ∈ X 0 ( x − μ 0 ) ( x − μ 0 ) T + ∑ x ∈ X 1 ( x − μ 1 ) ( x − μ 1 ) T S_w=\Sigma_0+\Sigma_1=\sum_{x\in X_{0}}(x-\mu_0)(x-\mu_0)^T+\sum_{x\in X_{1}}(x-\mu_1)(x-\mu_1)^T Sw=Σ0+Σ1=xX0(xμ0)(xμ0)T+xX1(xμ1)(xμ1)T
    根据式3.39计算 w w w
    w = S w − 1 ( μ 0 − μ 1 ) w=S_w^{-1}(\mu_0-\mu_1) w=Sw1(μ0μ1)
# 计算类内散度矩阵 with-class scatter matrix
sw = sigma0 + sigma1

# numpy.linalg.inv() 函数来计算矩阵的逆
w = np.dot(np.linalg.inv(sw),(u0-u1).T).reshape(1,-1)
  1. 画出样本点和得到的直线
fig, ax = plt.subplots()
ax.spines['right'].set_color('none')
ax.spines['top'].set_color('none')
ax.spines['left'].set_position(('data', 0))
ax.spines['bottom'].set_position(('data', 0))

plt.scatter(X1[:, 0], X1[:, 1], c='k', marker='o', label='good')
plt.scatter(X0[:, 0], X0[:, 1], c='r', marker='x', label='bad')

plt.xlabel('密度', labelpad=1)
plt.ylabel('含糖量')
plt.legend(loc='upper right')

x_tmp = np.linspace(-0.05, 0.15)
y_tmp = x_tmp * w[0, 1] / w[0, 0]
plt.plot(x_tmp, y_tmp, '#808080', linewidth=1)

得到下图
在这里插入图片描述

  1. 计算每个样本点在直线上的投影
    计算的理解参考这篇文章
# 求w这个向量的 单位向量 wu
# np.linalg.norm()默认求2 范数,表示向量中各个元素平方和 的 1/2 次方,L2 范数又称 Euclidean 范数或者 Frobenius 范数。
wu = w / np.linalg.norm(w)

# 正负样本点
# 求负样本的投影点,并连线
X0_project = np.dot(X0, np.dot(wu.T, wu))
plt.scatter(X0_project[:, 0], X0_project[:, 1], c='r', s=15)
for i in range(X0.shape[0]):
    plt.plot([X0[i, 0], X0_project[i, 0]], [X0[i, 1], X0_project[i, 1]], '--r', linewidth=1)

# 求正样本的投影点,并连线
X1_project = np.dot(X1, np.dot(wu.T, wu))
plt.scatter(X1_project[:, 0], X1_project[:, 1], c='k', s=15)
for i in range(X1.shape[0]):
    plt.plot([X1[i, 0], X1_project[i, 0]], [X1[i, 1], X1_project[i, 1]], '--k', linewidth=1)

得到下图
在这里插入图片描述

将上述代码封装成类,如下:

class LDA(object):

    def fit(self, X_, y_, plot_=False):
        pos = y_ == 1
        neg = y_ == 0
        X0 = X_[neg]
        X1 = X_[pos]

        u0 = X0.mean(0, keepdims=True)  # (1, n)
        u1 = X1.mean(0, keepdims=True)

        sw = np.dot((X0 - u0).T, X0 - u0) + np.dot((X1 - u1).T, X1 - u1)
        w = np.dot(np.linalg.inv(sw), (u0 - u1).T).reshape(1, -1)  # (1, n)

        if plot_:
            # 设置字体为楷体
            plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
            plt.rcParams['font.sans-serif'] = ['KaiTi']
            fig, ax = plt.subplots()
            ax.spines['right'].set_color('none')
            ax.spines['top'].set_color('none')
            ax.spines['left'].set_position(('data', 0))
            ax.spines['bottom'].set_position(('data', 0))

            plt.scatter(X1[:, 0], X1[:, 1], c='k', marker='o', label='good')
            plt.scatter(X0[:, 0], X0[:, 1], c='r', marker='x', label='bad')

            plt.xlabel('密度', labelpad=1)
            plt.ylabel('含糖量')
            plt.legend(loc='upper right')

            x_tmp = np.linspace(-0.05, 0.15)
            y_tmp = x_tmp * w[0, 1] / w[0, 0]
            plt.plot(x_tmp, y_tmp, '#808080', linewidth=1)

            wu = w / np.linalg.norm(w)

            # 正负样板店
            X0_project = np.dot(X0, np.dot(wu.T, wu))
            plt.scatter(X0_project[:, 0], X0_project[:, 1], c='r', s=15)
            for i in range(X0.shape[0]):
                plt.plot([X0[i, 0], X0_project[i, 0]], [X0[i, 1], X0_project[i, 1]], '--r', linewidth=1)

            X1_project = np.dot(X1, np.dot(wu.T, wu))
            plt.scatter(X1_project[:, 0], X1_project[:, 1], c='k', s=15)
            for i in range(X1.shape[0]):
                plt.plot([X1[i, 0], X1_project[i, 0]], [X1[i, 1], X1_project[i, 1]], '--k', linewidth=1)

            # 中心点的投影
            u0_project = np.dot(u0, np.dot(wu.T, wu))
            plt.scatter(u0_project[:, 0], u0_project[:, 1], c='#FF4500', s=60)
            u1_project = np.dot(u1, np.dot(wu.T, wu))
            plt.scatter(u1_project[:, 0], u1_project[:, 1], c='#696969', s=60)

            ax.annotate(r'u0 投影点',
                        xy=(u0_project[:, 0], u0_project[:, 1]),
                        xytext=(u0_project[:, 0] - 0.2, u0_project[:, 1] - 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->",
                                        color="k",
                                        )
                        )

            ax.annotate(r'u1 投影点',
                        xy=(u1_project[:, 0], u1_project[:, 1]),
                        xytext=(u1_project[:, 0] - 0.1, u1_project[:, 1] + 0.1),
                        size=13,
                        va="center", ha="left",
                        arrowprops=dict(arrowstyle="->",
                                        color="k",
                                        )
                        )
            plt.axis("equal")  # 两坐标轴的单位刻度长度保存一致
            plt.show()

        self.w = w
        self.u0 = u0
        self.u1 = u1
        return self

    def predict(self, X):
        project = np.dot(X, self.w.T)

        wu0 = np.dot(self.w, self.u0.T)
        wu1 = np.dot(self.w, self.u1.T)

        return (np.abs(project - wu1) < np.abs(project - wu0)).astype(int)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/57645.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

安装企业级高负载web服务器tomcat,并部署应用

web服务器Tocamt 1.Tocmat简介2.Tocmat安装1.安装jdk2.部署Tomcat1.配置环境变量2.启动tocmat3.Tomcat web管理功能 3.部署jpress应用 1.Tocmat简介 Tomcat是Apache软件基金会&#xff08;Apache Software Foundation&#xff09;的Jakarta 项目中的一个核心项目&#xff0c;由…

C. Binary String Copying - 思维

分析&#xff1a; 赛时我是直接模拟的&#xff0c;tle然后mle&#xff0c;补提&#xff0c;发现规律&#xff0c;每一个改变的字符串都只会对应一个需要改变的区间&#xff0c;例如第一个样例前两个101100 -> 011100和101100 -> 011100&#xff0c;对应区间在确定改变的范…

paddle实现获取pdf的内容

paddle实现获取pdf的内容 1. 环境安装2. 实现代码 源码链接 1. 环境安装 安装paddlepaddle gpu版本python -m pip install paddlepaddle-gpu -i https://pypi.tuna.tsinghua.edu.cn/simplecpu版本&#xff1a;python -m pip install paddlepaddle -i https://pypi.tuna.tsing…

实战:Prometheus+Grafana监控Linux服务器及Springboot项目

文章目录 前言知识积累什么是Prometheus什么是Grafana怎样完成数据采集和监控 环境搭建docker与docker-compose安装docker-compose编写 监控配置grafana配置prometheus数据源grafana配置dashboardLinux Host Metrics监控Spring Boot 监控 写在最后 前言 相信大家都知道一个项目…

云原生全栈体系(二)

Kubernetes实战入门 第一章 Kubernetes基础概念 一、是什么 我们急需一个大规模容器编排系统kubernetes具有以下特性&#xff1a; 服务发现和负载均衡 Kubernetes 可以使用 DNS 名称或自己的 IP 地址公开容器&#xff0c;如果进入容器的流量很大&#xff0c;Kubernetes 可以负…

SSL 证书过期巡检脚本 (Python 版)

哈喽大家好&#xff0c;我是咸鱼 之前写了个 shell 版本的 SSL 证书过期巡检脚本 &#xff08;文章&#xff1a;《SSL 证书过期巡检脚本》&#xff09;&#xff0c;后台反响还是很不错的 那么今天咸鱼给大家介绍一下 python 版本的 SSL 证书过期巡检脚本 &#xff08;完整代码…

王道《操作系统》学习(二)—— 进程管理(二)

2.1 处理机调度的概念、层次 2.1.1 调度的基本概念 2.1.2 调度的三个层次 &#xff08;1&#xff09;高级调度&#xff08;作业调度&#xff09; &#xff08;2&#xff09;中级调度&#xff08;内存调度&#xff09; 补充知识&#xff1a;进程的挂起状态和七状态模型 &#x…

SAP从放弃到入门系列之创建特殊库存转储预留

文章概览 一、思路二、过程2.1前台的主要过程&#xff1a;2.2 BAPI的实现过程&#xff1a; 之前写过几篇生产领料的思路包括代码&#xff0c;有兴趣的可以翻翻之前我发的文章。最近遇到既有项目专用物资、按单专用物资、通用物资合并领料的业务模式&#xff0c;所以领料的库存的…

VGG卷积神经网络-笔记

VGG卷积神经网络-笔记 VGG是当前最流行的CNN模型之一&#xff0c; 2014年由Simonyan和Zisserman提出&#xff0c; 其命名来源于论文作者所在的实验室Visual Geometry Group。 测试结果为&#xff1a; 通过运行结果可以发现&#xff0c;在眼疾筛查数据集iChallenge-PM上使用VGG…

什么是高级持续威胁(APT)攻击

目录 前言什么是高级持续威胁高级持续威胁攻击有哪些独特特征APT攻击的五个阶段APT检测及防护措施总结 前言 APT攻击是利用多个阶段和不同攻击技术的复合网络攻击。APT不是一时兴起2构思或实施的攻击。相反&#xff0c;攻击者故意针对特定目标定制攻击策略。并在较长时间内进行…

Excel·VBA定量装箱、凑数值金额、组合求和问题

如图&#xff1a;对图中A-C列数据&#xff0c;根据C列数量按照一定的取值范围&#xff0c;组成一个分组装箱&#xff0c;要求如下&#xff1a; 1&#xff0c;每箱数量最好凑足50&#xff0c;否则为47-56之间&#xff1b; 2&#xff0c;图中每行数据不得拆分&#xff1b; 3&…

webpack基础知识一:说说你对webpack的理解?解决了什么问题?

一、背景 Webpack 最初的目标是实现前端项目的模块化&#xff0c;旨在更高效地管理和维护项目中的每一个资源 模块化 最早的时候&#xff0c;我们会通过文件划分的形式实现模块化&#xff0c;也就是将每个功能及其相关状态数据各自单独放到不同的JS 文件中 约定每个文件是一…

Matlab对TMS320F28335编程-新建工程闪烁led灯

前言 工具&#xff1a;Matlab2022b Matlab对接C2000插件&#xff0c;下载连接如下 Embedded Coder Support Package for Texas Instruments C2000 Processors - File Exchange - MATLAB Central 在Matlab中加载此插件后&#xff0c;按照要求一步一步的进行就可以&#xff0c…

基于 JavaScript 的富文本编辑器框架简单使用

1.打开wangEditor wangEditor开源 Web 富文本编辑器&#xff0c;开箱即用&#xff0c;配置简单https://www.wangeditor.com/ 2.html文件 <link href"https://unpkg.com/wangeditor/editorlatest/dist/css/style.css" rel"stylesheet"> <style&…

qt源码--事件系统之QAbstractEventDispatcher

1、QAbstractEventDispatcher内容较少&#xff0c;其主要是定义了一些注册接口&#xff0c;如定时器事件、socket事件、注册本地事件、自定义事件等等。其源码如下&#xff1a; 其主要定义了大量的纯虚函数&#xff0c;具体的实现会根据不同的系统平台&#xff0c;实现对应的方…

MQTT服务器详细介绍:连接物联网的通信枢纽

随着物联网技术的不断发展&#xff0c;MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;协议作为一种轻量级、可靠、灵活的通信协议&#xff0c;被广泛应用于物联网领域。在MQTT系统中&#xff0c;MQTT服务器扮演着重要的角色&#xff0c;作为连接物联网设备和…

C高级【day2】

思维导图&#xff1a; 递归实现&#xff0c;输入一个数&#xff0c;输出这个数的每一位&#xff1a; #include<myhead.h>//递归函数 void fun(int num){//num没值不再递归if(0 num){return;}//输出数的最后一位printf("%d\t", num%10);//递归fun(num/10);}…

linux du命令解析(递归计算文件子目录大小)(计算大小)(计算容量)

文章目录 du命令简介用法常用选项示例 文档原 中文选项详细解释示例递归统计某个目录下所有文件大小&#xff08;不足单位会向上取整&#xff09;&#xff08;注意&#xff1a;可能会将目录大小也统计进去&#xff0c;目录大小为4096字节4kb&#xff1f;&#xff09; du命令使用…

AI算法图形化编程加持|OPT(奥普特)智能相机轻松适应各类检测任务

OPT&#xff08;奥普特&#xff09;基于SciVision视觉开发包&#xff0c;全新推出多功能一体化智能相机&#xff0c;采用图形化编程设计&#xff0c;操作简单、易用&#xff1b;不仅有上百种视觉检测算法加持&#xff0c;还支持深度学习功能&#xff0c;能轻松应对计数、定位、…

Vulnhub: blogger:1靶机

kali&#xff1a;192.168.111.111 靶机&#xff1a;192.168.111.176 信息收集 端口扫描 nmap -A -sC -v -sV -T5 -p- --scripthttp-enum 192.168.111.176 在80端口的/assets/fonts/目录下发现blog目录&#xff0c;访问后发现为wordpress 利用wpscan发现wordpress插件wpdisc…