神经网络进行波士顿房价预测

前言

前一阵学校有五一数模节校赛,和朋友一起参加做B题,波士顿房价预测,算是第一次自己动手实现一个简单的小网络吧,虽然很简单,但还是想记录一下。

题目介绍

波士顿住房数据由哈里森和鲁宾菲尔德于1978年Harrison and Rubinfeld1收集。它包括了波士顿大区每个调查行政区的506个观察值。1980年Belsley et al.2曾对此数据做过分析。

数据一共14列,每一列的含义分别如下:

英文简称    详细含义
CRIM    城镇的人均犯罪率
ZN    大于25,000平方英尺的地块的住宅用地比例。
INDUS    每个镇的非零售业务英亩的比例。
CHAS    查尔斯河虚拟变量(如果环河,则等于1;否则等于0)
NOX    一氧化氮的浓度(百万分之几)
RM    每个住宅的平均房间数
AGE    1940年之前建造的自有住房的比例
DIS    到五个波士顿就业中心的加权距离
RAD    径向公路通达性的指标
TAX    每一万美元的全值财产税率
PTRATIO    各镇的师生比率
B    计算方法为 $1000(B_k-0.63)^2$,其中Bk是按城镇划分的非裔美国人的比例
LSTAT    底层人口的百分比(%)
price    自有住房数的中位数,单位(千美元)
基于上述数据,请完成以下问题:

建立波士顿房价预测模型并对预测结果进行评价。

问题分析

首先这道题目的很明确,数据一共是 $506×14$ 的一个矩阵,有十三维的自变量,通过建立一个模型来拟合回归出最终的因变量 price,即户主拥有住房价值的中位数。这是一个回归问题,综合考虑有以下两个思路

通过各种回归算法(GradientBoostingRegressor,RandomForestRegressor,ExtraTreesRegressor,LinearRegressor等)结合全部或部分自变量来回归最终的price

建立前馈神经网络模型,根据通用逼近定理,我们可以拟合此回归模型。

我们对上述模型来进行实现并确定评估标准来对他们进行比较,选择最优的模型作为预测模型。

算法流程

传统的回归算法

自变量的选择

首先,考虑到数据集中13列自变量其中某一些可能和最终的房价并无强相关性,如果全部使用进行预测可能会对模型引入噪声,因此我们首先计算了房价price与各个自变量之间的相关系数 $r$ ,其中 $r$ 计算公式如下: $$ r = \frac{\sum(x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum(x_i-\bar{x})^2\sum(y_i-\bar{y})^2}} $$ 其中 $x_i,y_i$ 为数据的每个分量,$\bar{x},\bar{y}$ 为数据的均值

该系数反映了两变量之间的相关性,$r$ 的绝对值介于 $[0,1]$ 区间内,$|r|$ 越接近1,表示两数据相关性越高,反之越低。计算后结果如下:


观察结果可以发现,在给定的十三个变量中,LSTAT 与 price 的相关程度最高$(|r|>0.7)$,其次是 RM 与PTRATIO $(|r|>0.5)$,再者是 TAX,INDUS,NOX $(|r|>0.4)$,除上述之外的七个变量都与 price 无较强的相关性,因此我们考虑使用六个相关性较强变量和十三个变量分别来对房价进行预测,并对他们进行对比,来寻找最优的回归模型。

模型的构建

首先我们使用了sklearn中自带的 boston 数据集,并将整体数据集随机划分为了训练集和测试集两部分,所占比例分别为80%和20%。

然后,我们利用Linear,Ridge,Lasso,ElasticNet,DecisionTree,GradientBoosting,RandomForest,ExtraTrees八种模型通过训练集对其进行训练。

接下来,我们利用训练集拟合得到的模型,使用测试集对其进行测试,与 Ground Truth 进行对比,并通过 $R^2$ 来评价该预测结果,其中 $R^2$ 计算公式如下,其是衡量回归模型好坏的常见指标,其值一般处于[0,1]之间,$R^2$ 越接近1,说明模型的性能越好。 $$ R^2 = 1-\frac{\sum(\hat{y_i}-y_i)^2}{\sum(\bar{y}-y_i)^2} $$

最后,考虑到模型的训练及预测可能具有偶然性,因此我们对于每一个模型进行20次训练及预测,利用20次的结果对其进行综合评价。利用得到的结果绘制 箱线图 所得结果如下:

分析最终结果可以发现,无论是使用六个相关性较强变量还是十三个变量来进行预测,GradientBoost(梯度提升决策树)回归模型都是最好的,此外,我们可以发现,利用十三个变量要比利用六个主要变量来进行预测比有着更好的效果。

前馈神经网络

模型的构建
近年来,神经网络理论不断发展,前馈神经网络(多层感知机、全连接神经网络)越来越多的被利用到数据分析中,因此考虑使用前馈神经网络来解决此问题。

前馈神经网络(全连接神经网络)的网络结构一般由三部分构成,输入层,隐藏层,以及输出层,输入层与输出层一般只有一层,隐藏层可有多层。中间利用非线性函数作为激活函数可以使得网络具有拟合非线性函数的能力

根据通用近似定理:

通用近似定理

对于具有线性输出层和至少一个使用“挤压”性质的激活函数的隐藏层组成的前馈神经网络,只要其隐藏层神经元的数量足够,它可以以任意精度来近似任何从一个定义在实数空间中的有界闭集函数。

只要隐藏层网络维度够高,就可以拟合任意的函数。

考虑到我们的模型有六维or十三维的数据输入,因此我们建立两层前馈神经网络,中间具有一层隐藏层,维度为1000维,激活函数使用Relu,Relu函数有以下优点:

Relu相比于传统的Sigmoid、Tanh,导数更加好求,反向传播就是不断的更新参数的过程,因为其导数不复杂形式简单,可以使得网络训练更快速。

此外,当数值过大或者过小,Sigmoid,Tanh的导数接近于0,Relu为非饱和激活函数则不存在这种现象,可以很好的解决梯度消失的问题

Relu函数及网络结构图如图所示:

$$ Relu:f(x) = max(0,x) $$

具体实现

利用流行的深度学习框架 Pytorch 来对模型进行实现。

首先,将数据集随机划分为训练集和测试集两部分,分别占80%和20%,并将其转化为Pytorch中的张量形式。
然后,利用MinMaxScaler对输入数据进行归一化,利用下列公式将其统一归一化为 $[0,1]$ 之间,以求模型能够更快的收敛。
$$ MinMaxScaler:x^{*} = \frac{x-min(x)}{max(x)-min(x)} $$

接下来,构建网络模型,利用 mseloss 作为损失函数,在训练过程中利用反向传播使其最终收敛为0。
$$ MseLoss = \frac{1}{2n}\sum||y(x)-a^L(x)||^2 $$

最后,我们设置网络的学习率为0.01,训练10000个epoch,发现其loss最终降低到0.3%左右,我们利用上文提到的 $R^2$ 对结果进行评估并与回归模型进行对比,通过观察图片可以发现,前馈神经网络相比于传统的回归模型有着更好的拟合效果, 20次预测得到的$R^2$平均值达到了0.95,此外中位数,最大值,最小值也要比回归模型更加优秀,因此我们采用前馈神经网络模型来对最后的房价进行预测。

最终预测

最终我们利用构建的前馈神经网络模型进行预测,利用测试集对其进行对比,绘制预测如下:

可以看到其中很多点都覆盖的很好,即预测准确。

通过理论对模型进行量化分析,计算预测的 $R^2$ $$ R^2 = 1-\frac{\sum(\hat{y_i}-y_i)^2}{\sum(\bar{y}-y_i)^2} = 1-0.01357 = 0.98643=98.643% $$ 可以发现 $R^2$ 十分接近1,说明回归模型性能良好,符合要求。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/905543.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spark的集群环境部署

一、Standalone集群 1.1、架构 架构:普通分布式主从架构 主:Master:管理节点:管理从节点、接客、资源管理和任务 调度,等同于YARN中的ResourceManager 从:Worker:计算节点:负责利…

[java][基础]JSP

目标: 理解 JSP 及 JSP 原理 能在 JSP中使用 EL表达式 和 JSTL标签 理解 MVC模式 和 三层架构 能完成品牌数据的增删改查功能 1,JSP 概述 JSP(全称:Java Server Pages):Java 服务端页面。是一种动态的…

常见问题 | 数字签名如何保障电子商务交易安全?

如何解决电商交易中数据泄露、交易欺诈等问题? 数字签名是一种类似于电子“指纹”的安全技术,它在电子商务中扮演着至关重要的角色。随着电子商务的迅猛发展,网上交易的数量不断增加,确保交易的安全性和完整性成为了亟待解决的问题…

【Python基础】

一、编程语言介绍 1、分类 机器语言 (直接用 0 1代码编写)汇编语言 (英文单词替代二进制指令)高级语言 2、总结 1、执行效率:机器语言>汇编语言>高级语言(编译型>解释型) 2、开发效率&…

Java项目实战II基于Java+Spring Boot+MySQL的编程训练系统(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在当今数字…

双指针习题篇(上)

双指针习题篇(上) 文章目录 双指针习题篇(上)1.移动零题目描述:算法原理:算法流程:代码实现: 2.复写零题目描述:算法原理:算法流程:代码实现: 3.快乐数题目描述:算法原理…

更安全高效的文件传输工具,Ftrans国产FTP替代方案可以了解

文件传输协议(FTP),诞生于1971年,自20世纪70年代发明以来,FTP已成为传输大文件的不二之选。内置有操作系统的 FTP 可提供一个相对简便、看似免费的文件交换方法,因此得到广泛使用。 随着企业发展过程中新增…

Leetcode21:合并两个有效链表

原题地址:. - 力扣(LeetCode) 题目描述 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4]示…

《Mini-internVL》论文阅读:OpenGVLab+清华/南大等开源Mini-InternVL | 1~4B参数,仅用5%参数实现90%性能

论文地址Mini-InternVL: A Flexible-Transfer Pocket Multimodal Model with 5% Parameters and 90% PerformanceGitHub仓库地址模型使用教程和权重下载地址 该论文发表于2024年10月份&#xff0c;截止2024年11月&#xff0c;引用数<10 文章目录 论文摘要1. 引用介绍2. 本文…

双目视觉标定——1原理与实践

0 前言 双目视觉定位是目前机器&#xff08;机器人&#xff09;等领域中使用得非常广泛的视觉定位技术&#xff0c;双目视觉是模拟人的视觉系统利用两个不同位置的摄像头的视差来确定物体的位置。由于有需要采集两个摄像头的图像共同参与计算&#xff0c;所以双目相机装配要求…

免杀对抗—DLL劫持白加黑隐写分离EDRSyscall-hook

前言 今天讲点比较高端的东西—DLL反射注入&#xff0c;首先什么是DLL文件&#xff0c;简答来说就是程序为了实现某个功能而调用的文件。举个例子&#xff0c;某个代码想要实现某个功能是不是会调用一些封装好的函数&#xff0c;exe同样如此&#xff0c;想要实现某个功能就会调…

故障诊断 | MTF-TLSSA-DarkNet-GRU-MSA迁移学习故障识别程序(t分布+莱维飞行改进麻雀优化)

故障诊断 | 故障诊断实例代码 目录 故障诊断 | 故障诊断实例代码效果一览基本介绍程序设计参考资料 效果一览 基本介绍 利用了迁移学习和多项技术改进&#xff0c;包括麻雀搜索法、DarkNet19、GRU、多头注意力机制等&#xff0c;以提高故障识别的准确性和效率 模型框架&#x…

MyBatis中的多级缓存机制(一级缓存和二级缓存)

MyBatis中的多级缓存机制&#xff08;一级缓存和二级缓存&#xff09; 缓存&#xff08;Cache&#xff09;技术在互联网系统的开发过程中应用非常广泛。当系统中出现性能瓶颈时&#xff0c;很多场景都可以使用缓存技术来重构业务处理流程&#xff0c;从而获取性能的提升。缓存…

day14:RSYNC同步

一&#xff0c;概述 概述 rsync &#xff08;开源&#xff09;是一个高效的文件同步和传输工具&#xff0c;广泛用于 Linux 和 Unix 系统中。它可以在本地和远程系统之间同步文件和目录&#xff0c;同时支持增量备份&#xff0c;能够只传输更改过的文件部分&#xff0c;以减少…

Matlab实现白鲸优化算法(BWO)求解路径规划问题

目录 1.内容介绍 2.部分代码 3.实验结果 4.内容获取 1内容介绍 白鲸优化算法&#xff08;BWO&#xff09;是一种受自然界白鲸捕食行为启发的新型优化算法&#xff0c;它通过模拟白鲸的群体捕猎策略和社会互动来探索问题的最优解。BWO因其强大的全局搜索能力和高效的局部搜索能…

python 模块和包、类和对象

模块 模块是包含 Python 代码的文件&#xff0c;通常用于组织相关的函数、类和其他语句。模块可以被导入并在其他 Python 文件中使用。 创建模块 假设你创建了一个名为 mymodule.py 的文件&#xff0c;内容如下&#xff1a; # mymodule.pydef greet(name): return f"…

SpringBoot节奏:Web音乐网站构建手册

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

使用Django REST framework构建RESTful API

使用Django REST framework构建RESTful API Django REST framework简介 安装Django REST framework 创建Django项目 创建Django应用 配置Django项目 创建模型 迁移数据库 创建序列化器 创建视图 配置URL 配置全局URL 配置认证和权限 测试API 使用Postman测试API 分页 过滤和排序…

MySQL 9从入门到性能优化-系统信息函数

【图书推荐】《MySQL 9从入门到性能优化&#xff08;视频教学版&#xff09;》-CSDN博客 《MySQL 9从入门到性能优化&#xff08;视频教学版&#xff09;&#xff08;数据库技术丛书&#xff09;》(王英英)【摘要 书评 试读】- 京东图书 (jd.com) MySQL9数据库技术_夏天又到了…

芯片上音频相关的验证

通常芯片设计公司&#xff08;比如QUALCOMM&#xff09;把芯片设计好后交由芯片制造商&#xff08;比如台积电&#xff09;去生产&#xff0c;俗称流片。芯片设计公司由ASIC部门负责设计芯片。ASIC设计的芯片只有经过充分的验证&#xff08;这里说的验证是FPGA&#xff08;现场…