线性回归预测波士顿房价 loss为NAN原因 画散点图找特征与标签的关系

波士顿房价csv文件

链接: https://pan.baidu.com/s/1uz6oKs7IeEzHdJkfrpiayg?pwd=vufb 提取码: vufb

代码

%matplotlib inline
import random
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

从CSV中取出数据集

# 加载数据,第一行是无用行,直接跳过
boston = pd.read_csv('../data/boston_house_prices.csv',skiprows=[0])
# 共有14列,前面十三列是特征,最后一列是价格
boston

在这里插入图片描述

取最后一列设置为labels,前面所有列为features

# 最后一列作为labels,把前面十三列的内容作为features
# 直接让最后一列出栈,boston剩下前面13列
labels = boston.pop('MEDV')
features = boston

画散点图,看特征与房价的关系,如果是线性关系,则说明该特征与标签存在一定的相关性。选出与labels相关的特征,作为最终的features

# 看各个特征与房价的散点图
data_xTitle = ['CRIM','ZN','INDUS','CHAS','NOX','RM','AGE','DIS','RAD','TAX','PTRATIO','B', 'LSTAT']
# 设置5行,3列 =15个子图
fig, a = plt.subplots(5, 3)
m = 0
for i in range(0, 5):
    if i == 4:
        a[i][0].scatter(features[str(data_xTitle[m])], labels, s=30, edgecolor='white')
        a[i][0].set_title(str(data_xTitle[m]))
    else:
        for j in range(0, 3):
            a[i][j].scatter(features[str(data_xTitle[m])], labels, s=30, edgecolor='white')
            a[i][j].set_title(str(data_xTitle[m]))
            m = m + 1
plt.show()
# 由下面的图可以看出CRIM,RM,LSTAT 与y是线性的关系,所以选择这三个特征作为特征值。

在这里插入图片描述

# CRIM,RM,LSTAT 与y是线性的关系,所以选择这三个特征作为特征值。
features = features[['LSTAT','CRIM','RM']]

把数据格式转为tensor

features = torch.tensor(np.array(features)).to(torch.float32)
labels = torch.tensor(np.array(labels)).to(torch.float32)
features.shape,labels.shape

(torch.Size([506, 13]), torch.Size([506]))

定义线性回归,损失函数,优化函数

# 制定线性回归模型
def linreg(X,w,b):
    return torch.matmul(X,w) + b
    
# 定义损失函数
def squared_loss(y_hat,y):
    return (y_hat - y.reshape(y_hat.shape)) **2 /2
    
# 定义优化函数
def sgd(params,lr,batch_size):
    '''小批量随机梯度下降'''
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()

data_iter函数,按批次取数据

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # 这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices)
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

设置参数

w = torch.normal(0, 0.01, size=(features.shape[1],1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
lr = 0.03
# lr = 0.0001
num_epochs = 100
net = linreg
loss = squared_loss
batch_size = 10

w和b的shape为:
torch.Size([3, 1])
torch.Size([1])

开始训练

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)
        # X和y的小批量损失
        # 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,
        # 并以此计算关于[w,b]的梯度
        l.sum().backward()
        sgd([w, b], lr, batch_size)
    # 使用参数的梯度更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')

当模型的学习率设置为0.03,loss直接变为NAN

epoch 1, loss nan
epoch 2, loss nan
epoch 3, loss nan
epoch 4, loss nan
epoch 5, loss nan
epoch 6, loss nan
epoch 7, loss nan
epoch 8, loss nan
epoch 9, loss nan
epoch 10, loss nan
epoch 11, loss nan
epoch 12, loss nan
epoch 13, loss nan
epoch 14, loss nan
epoch 15, loss nan
epoch 16, loss nan
epoch 17, loss nan
epoch 18, loss nan
epoch 19, loss nan
epoch 20, loss nan
epoch 21, loss nan
epoch 22, loss nan
epoch 23, loss nan
epoch 24, loss nan
epoch 25, loss nan
epoch 26, loss nan
epoch 27, loss nan
epoch 28, loss nan
epoch 29, loss nan
epoch 30, loss nan
epoch 31, loss nan
epoch 32, loss nan
epoch 33, loss nan
epoch 34, loss nan
epoch 35, loss nan
epoch 36, loss nan
epoch 37, loss nan
epoch 38, loss nan
epoch 39, loss nan
epoch 40, loss nan
epoch 41, loss nan
epoch 42, loss nan
epoch 43, loss nan
epoch 44, loss nan
epoch 45, loss nan
epoch 46, loss nan
epoch 47, loss nan
epoch 48, loss nan
epoch 49, loss nan
epoch 50, loss nan
epoch 51, loss nan
epoch 52, loss nan
epoch 53, loss nan
epoch 54, loss nan
epoch 55, loss nan
epoch 56, loss nan
epoch 57, loss nan
epoch 58, loss nan
epoch 59, loss nan
epoch 60, loss nan
epoch 61, loss nan
epoch 62, loss nan
epoch 63, loss nan
epoch 64, loss nan
epoch 65, loss nan
epoch 66, loss nan
epoch 67, loss nan
epoch 68, loss nan
epoch 69, loss nan
epoch 70, loss nan
epoch 71, loss nan
epoch 72, loss nan
epoch 73, loss nan
epoch 74, loss nan
epoch 75, loss nan
epoch 76, loss nan
epoch 77, loss nan
epoch 78, loss nan
epoch 79, loss nan
epoch 80, loss nan
epoch 81, loss nan
epoch 82, loss nan
epoch 83, loss nan
epoch 84, loss nan
epoch 85, loss nan
epoch 86, loss nan
epoch 87, loss nan
epoch 88, loss nan
epoch 89, loss nan
epoch 90, loss nan
epoch 91, loss nan
epoch 92, loss nan
epoch 93, loss nan
epoch 94, loss nan
epoch 95, loss nan
epoch 96, loss nan
epoch 97, loss nan
epoch 98, loss nan
epoch 99, loss nan
epoch 100, loss nan

当模型的学习率设置为0.0001,loss正常,模型开始收敛

epoch 1, loss 141.555878
epoch 2, loss 115.449852
epoch 3, loss 101.026237
epoch 4, loss 90.287994
epoch 5, loss 81.646828
epoch 6, loss 74.384491
epoch 7, loss 68.148872
epoch 8, loss 62.699074
epoch 9, loss 57.872326
epoch 10, loss 53.601421
epoch 11, loss 49.778000
epoch 12, loss 46.333401
epoch 13, loss 43.253365
epoch 14, loss 40.471313
epoch 15, loss 37.963455
epoch 16, loss 35.711601
epoch 17, loss 33.679176
epoch 18, loss 31.841145
epoch 19, loss 30.203505
epoch 20, loss 28.699686
epoch 21, loss 27.352037
epoch 22, loss 26.142868
epoch 23, loss 25.045834
epoch 24, loss 24.059885
epoch 25, loss 23.171280
epoch 26, loss 22.369287
epoch 27, loss 21.646309
epoch 28, loss 20.998608
epoch 29, loss 20.407761
epoch 30, loss 19.874365
epoch 31, loss 19.396839
epoch 32, loss 18.967056
epoch 33, loss 18.576946
epoch 34, loss 18.234808
epoch 35, loss 17.904724
epoch 36, loss 17.623093
epoch 37, loss 17.360590
epoch 38, loss 17.126835
epoch 39, loss 16.916040
epoch 40, loss 16.727121
epoch 41, loss 16.555841
epoch 42, loss 16.401901
epoch 43, loss 16.264545
epoch 44, loss 16.145824
epoch 45, loss 16.026453
epoch 46, loss 15.927325
epoch 47, loss 15.830773
epoch 48, loss 15.748351
epoch 49, loss 15.672281
epoch 50, loss 15.606522
epoch 51, loss 15.546185
epoch 52, loss 15.490641
epoch 53, loss 15.458157
epoch 54, loss 15.395338
epoch 55, loss 15.359412
epoch 56, loss 15.331330
epoch 57, loss 15.284848
epoch 58, loss 15.264071
epoch 59, loss 15.238921
epoch 60, loss 15.206428
epoch 61, loss 15.184341
epoch 62, loss 15.190187
epoch 63, loss 15.144171
epoch 64, loss 15.127305
epoch 65, loss 15.115336
epoch 66, loss 15.111353
epoch 67, loss 15.098548
epoch 68, loss 15.077714
epoch 69, loss 15.075640
epoch 70, loss 15.072990
epoch 71, loss 15.051690
epoch 72, loss 15.046121
epoch 73, loss 15.038815
epoch 74, loss 15.038069
epoch 75, loss 15.027984
epoch 76, loss 15.028069
epoch 77, loss 15.030132
epoch 78, loss 15.015227
epoch 79, loss 15.014658
epoch 80, loss 15.010786
epoch 81, loss 15.005883
epoch 82, loss 15.007875
epoch 83, loss 15.003115
epoch 84, loss 15.015619
epoch 85, loss 14.996306
epoch 86, loss 15.008889
epoch 87, loss 14.993307
epoch 88, loss 14.997282
epoch 89, loss 14.990996
epoch 90, loss 14.991257
epoch 91, loss 14.997286
epoch 92, loss 14.989521
epoch 93, loss 14.987417
epoch 94, loss 14.989147
epoch 95, loss 14.989621
epoch 96, loss 14.984948
epoch 97, loss 14.984961
epoch 98, loss 14.984855
epoch 99, loss 14.983346
epoch 100, loss 14.999675

补充

学习率为0.03为什么会出现loss为NAN的情况?

说明对于模型的损失函数来说,步子太大了,最优的地方直接跨过去了。调小学习率,随着epoch增多,loss降低,模型收敛。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/138266.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SAP ABAP列表格式及表格输出

REPORT YTEST001. DATA wa LIKE spfli. WRITE: /. WRITE: 10航班承运人,40航班连接,60国家代码,80起飞城市,100起飞机场. SELECT * INTO wa FROM spfli.WRITE: / wa-carrid UNDER 航班承运人,wa-connid UNDER 航班连接,wa-countryfr UNDER 国家代码,wa-cityfrom UNDER 起飞城市…

保洁行业上门预约小程序源码系统 轻松预约 避免排队 源码开源可二开 带完整部署教程

生活节奏的逐步加快,人们对家庭保洁服务的需求日益增长。为了满足这一需求,我们为您打造了一款保洁行业上门预约小程序源码系统。这款系统让您轻松预约保洁服务,避免排队等待,同时源码开源可进行二次开发,还带有完整的…

详解Python中单引号双引号三引号的用法(适合小白)

单引号和双引号的使用 python 中单引号和双引号都是用来表示字符串,在一般情况下两者没有任何差别,在编码时统一规则即可 str1hello python! str2"hello python!" print(str1) print(str2) 有的时候我们需要在输出的字符串中输出双引号或者…

上课笔记(11.11之前笔记)

一.数据结构的分类 1.数据结构中分为四大类:线性表,哈希表,树,图。 2.线性表(line table):呈现线性结构的一种数据结构。具有顺序性,也就是所有数据都是有序的; 数组&…

【无标题】111

这里写自定义目录标题 欢迎使用Markdown编辑器新的改变功能快捷键合理的创建标题,有助于目录的生成如何改变文本的样式插入链接与图片如何插入一段漂亮的代码片生成一个适合你的列表创建一个表格设定内容居中、居左、居右SmartyPants 创建一个自定义列表如何创建一个…

通讯协议学习之路(实践部分):UART开发实践

通讯协议之路主要分为两部分,第一部分从理论上面讲解各类协议的通讯原理以及通讯格式,第二部分从具体运用上讲解各类通讯协议的具体应用方法。 后续文章会同时发表在个人博客(jason1016.club)、CSDN;视频会发布在bilibili(UID:399951374) 本文…

Javascript享元模式

Javascript享元模式 1 什么是享元模式2 内部状态与外部状态3 享元模式的通用结构4 文件上传4.1 对象爆炸4.2 享元模式重构 5 没有内部状态的享元模式6 对象池7 通用对象池实现 1 什么是享元模式 享元(flyweight)模式是一种用于性能优化的模式&#xff0…

数据恢复工具推荐,高效恢复,这4款很实用!

很多电脑用户都会选择将文件直接保存在电脑上,但是在实际的操作过程中,数据丢失的情况难免会出现。而实用的数据恢复工具或许能有效帮助我们找回丢失的数据。电脑上有哪些使用效果比较好的数据恢复工具呢? 今天小编总结了几款好用的工具&…

leetcode:21. 合并两个有序链表

一、题目 函数原型: struct ListNode* mergeTwoLists(struct ListNode* list1, struct ListNode* list2) 二、思路 合并两个有序链表为一个新的升序链表,只需要遍历两个有序链表并比较结点值大小,依次将较小的结点尾插到新链表即可。 三、代码…

C#中.NET Framework 4.8控制台应用通过EF访问已建数据库

目录 一、创建.NET Framework 4.8控制台应用 二、建立数据库 1. 在SSMS中建立数据库Blogging 2.在VS上新建数据库连接 三、安装EF程序包 四、自动生成EF模型和上下文 1.Blog.cs类的模型 2.Post.cs类的模型 3.BloggingContext.cs数据库上下文 五、编写应用程序吧 我们…

Vatee万腾数字化引领未来,vatee创新思维

随着数字化时代的全面来临,Vatee万腾正以其独特的创新思维,为未来描绘出令人瞩目的数字化画卷。在这个充满变革和机遇的时代,Vatee万腾所展现的数字化引领力和创新思维,成为业界的翘楚。 Vatee万腾的创新思维贯穿于其数字化战略的…

数据结构 | 队列的实现

数据结构 | 队列的实现 文章目录 数据结构 | 队列的实现队列的概念及结构队列的实现队列的实现头文件,需要实现的接口 Queue.h初始化队列队尾入队列【重点】队头出队列【重点】获取队列头部元素获取队列队尾元素获取队列中有效元素个数检测队列是否为空销毁队列 Que…

更新:扶风解析计费系统V1.8.2源码/免授权优化版+附教程/修正完整版

源码简介: 最新的扶风解析计费系统V1.8.2源码,它是修正完整版,免授权优化版附带了教程。是更新优化版最新 V1.8 版本免授权版本。 之前分享过1.7.1版本的扶风计费系统,该版本已经存在相当长的时间,并且一直没有进行更…

一文读懂:什么是RISC-V?为啥它是国产芯崛起的关键?

各位ICT的小伙伴们大家好呀。 提到CPU, 大家首先就会想到"卡脖子"事件。 X86和ARM的IP授权虽然方便,但是不自主和不可控, 一被限制就可能导致国内一夜间"无芯"可用。 今天我们就来聊聊一个解决芯片卡脖子的有效方式-…

多路复用IO:select、poll、epoll

文章目录 一、常见的IO模型二、什么是多路IO复用?三、select、poll、epollselectpollepoll 四、总结 一、常见的IO模型 概念优点       缺点适用场景阻塞IOBlocking IO当应用程序执行IO操作时,会被阻塞,直到数据准备好或者IO操作完成才…

项目管理工具:提高团队协作效率,确保项目按时完成

项目管理对于企业的成功至关重要,一个好的项目管理工具可以提高团队协作效率,确保项目按时完成,并保持项目进度的高效跟踪。 近年来,一款名为“进度猫”的项目管理工具逐渐崭露头角,它以其独特的功能和优势&#xff…

删除快一年的数据,能够恢复吗?

在数字化时代,数据已经成为了企业和个人生活中不可或缺的一部分。然而,由于各种原因,我们有时会需要删除某些数据,比如过期的文件、无用的照片或者账号下的旧信息等。但是,当我们删除这些数据后,是否真的能…

提高生产效率和质量,这个方式很有效

在当今竞争激烈的市场环境下,企业需要不断提高生产效率和质量水平以保持竞争优势。而精益生产正是一种能够帮助企业实现这一目标的方法。其中,持续改善是精益生产的核心理念之一。它是指通过不断地寻找和消除浪费,改善流程和提高效率来实现质…

PHP中$_SERVER全局变量

在PHP中,$_SERVER 是一个全局数组变量,它包含了有关服务器和当前脚本的信息。$_SERVER 数组中的每个元素都是服务器环境的一个参数,如请求的方法、请求的 URI、客户端 IP 地址等。 PATH 系统环境变量的值,包含了多个目录的路径…

Xmind 24 for Mac思维导图软件

XMind是一款流行的思维导图软件,可以帮助用户创建各种类型的思维导图和概念图。 以下是XMind的主要特点: - 多样化的导图类型:XMind提供了多种类型的导图,如鱼骨图、树形图、机构图等,可以满足不同用户的需求。 - 强大…