[机器学习]简单线性回归——梯度下降法

一.梯度下降法概念

2.代码实现  

# 0. 引入依赖
import numpy as np
import matplotlib.pyplot as plt

# 1. 导入数据(data.csv)
points = np.genfromtxt('data.csv', delimiter=',')
points[0,0]

# 提取points中的两列数据,分别作为x,y
x = points[:, 0]
y = points[:, 1]

# 用plt画出散点图
# plt.scatter(x, y)
# plt.show()

# 2. 定义损失函数:最小平方损失函数
# 损失函数是系数的函数,另外还要传入数据的x,y
def compute_cost(w, b, points):
    total_cost = 0
    M = len(points)

    # 逐点计算平方损失误差,然后求平均数
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        total_cost += (y - w * x - b) ** 2

    return total_cost / M

# 3. 定义模型的超参数
alpha = 0.0001 # 根据实际去调节
initial_w = 0 # 初始w,可以随机生成
initial_b = 0 # 初始b,可以随机生成
num_iter = 10 # 迭代次数

# 4. 定义核心梯度下降算法函数
def grad_desc(points, initial_w, initial_b, alpha, num_iter):
    w = initial_w
    b = initial_b
    # 定义一个list保存所有的损失函数值,用来显示下降的过程
    cost_list = []

    # 迭代,获取每次的损失函数值以及更行w、b
    for i in range(num_iter):
        cost_list.append(compute_cost(w, b, points))
        w, b = step_grad_desc(w, b, alpha, points)

    return [w, b, cost_list]


def step_grad_desc(current_w, current_b, alpha, points):
    sum_grad_w = 0
    sum_grad_b = 0
    M = len(points)

    # 对每个点,代入公式求和
    for i in range(M):
        x = points[i, 0]
        y = points[i, 1]
        sum_grad_w += (current_w * x + current_b - y) * x
        sum_grad_b += current_w * x + current_b - y

    # 用公式求当前梯度
    grad_w = 2 / M * sum_grad_w
    grad_b = 2 / M * sum_grad_b

    # 梯度下降,更新当前的w和b
    updated_w = current_w - alpha * grad_w
    updated_b = current_b - alpha * grad_b

    return updated_w, updated_b

# 5. 测试:运行梯度下降算法计算最优的w和b
w, b, cost_list = grad_desc( points, initial_w, initial_b, alpha, num_iter )

print("w is: ", w)
print("b is: ", b)

cost = compute_cost(w, b, points)

print("cost is: ", cost)

# plt.plot(cost_list)
# plt.show() # 画图

# 6. 画出拟合曲线
plt.scatter(x, y)
# 针对每一个x,计算出预测的y值
pred_y = w * x + b

plt.plot(x, pred_y, c='r')
plt.show()

 w is:  1.4774173755483797
b is:  0.02963934787473238
cost is:  112.65585181499748

 3.代码及数据下载

简单线性回归-最小二乘法及梯度下降法资源-CSDN文库

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/356485.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【hcie-cloud】【23】容器编排【k8s】【Kubernetes常用工作负载、Kubernetes调度器简介、Helm简介、缩略词】【下】

文章目录 单机容器面临的问题、Kubernetes介绍与安装、Kubernetes对象的基本操作、Kubernetes YAML文件编写基础Kubernetes常用工作负载Kubernetes常用工作负载简介创建一个无状态nginx集群无状态工作负载Deployment说明无状态工作负载Deployment常见操作创建一个有状态的MySQL…

单链表实现通讯录(增删查改)

前言 之前写了很多次通讯录,一次比一次复杂,从静态到动态,再到文件操作,再到顺序表,今天要好好复习一下单链表,于是乎干脆用单链表再写一遍。 首先我们之前已经用单链表写过他的增删查改了,于…

1.28回溯(中等)

目录 1.格雷编码 2. 复原 IP 地址 3. 火柴拼正方形 1.格雷编码 n 位格雷码序列 是一个由 2n 个整数组成的序列,其中: 每个整数都在范围 [0, 2n - 1] 内(含 0 和 2n - 1)第一个整数是 0一个整数在序列中出现 不超过一次每对 相…

备战蓝桥杯----贪心算法(二进制)

已经差不多掌握了贪心的基本思想,让我们看几道比较趣的题吧! 先来个比较有意思的题热热身: 法1.我们可以先把l,r化成二进制的形式。 然后分俩种情况: (1)若他们位数不一样并且位数高的全为1,…

在Shopee菲律宾站点进行选品时的策略

随着电子商务的快速发展,越来越多的卖家开始将目光投向了海外市场。作为东南亚地区最大的电商平台之一,Shopee菲律宾站点吸引了众多卖家的关注。然而,在这个竞争激烈的市场上,卖家需要制定一系列的策略,才能在选品中脱…

arcgis 批量删除字段

一、打开ArcToolbox-数据管理工具-字段-删除字段。 二、在输入表中选择要删除字段的要素,在删除字段栏中选择要删除的字段,点击确认即可。

git配置用户名和邮箱

1.git 1.配置用户名和邮箱 2.git初体验 git init 初始化git仓库 管理项目让git管理你的本次代码变更 git add .git commit -m “你完成的功能” 后续如果新增/修改/删除代码, 完成新功能时 重复2 3.查看日志 1.git log 4.版本回退 1.查看提交的版本记录 git l…

UE5.1_常用节点说明(经常忘记怎么用?)(常改)

UE5.1_常用节点说明(经常忘记怎么用?)(常改) 1. Gate——门节点。只有当门是Open状态才会执行Exit后面的代码。 Open开门;Close关门;Toggle开门和关门交替。 2. 关于控制ArmLength即控制相机前…

基于saltstack开发自动化开通主机防火墙策略工具

一、前言 企业安全防护策略中会要求操作系统开启防火墙,开启iptables防火墙后,对于业务网络访问意味着要经常去变更调整iptables防火墙策略。如果是管理几台服务器,手工登录操作下还能接受。但在实际大型IT架构中,可能涉及到的服…

【JavaScript基础入门】05 JavaScript基础语法(三)

JavaScript基础语法(三) 目录 JavaScript基础语法(三)数组概述数组语法多维数组 操作数组修改数组获取数组长度数组和字符串之间的转换添加和删除数组项 Null 和 Undefined字符串连接字符串字符串转换获取字符串的长度在字符串中查…

后台管理系统模板搭建/项目配置

1 项目初始化 一个项目要有统一的规范,需要使用eslintstylelintprettier来对我们的代码质量做检测和修复,需要使用husky来做commit拦截,需要使用commitlint来统一提交规范,需要使用preinstall来统一包管理工具。 1.1 环境准备 1…

合合信息技术能力

体验中心网址: https://www.textin.com/TextIn体验中心 - 在线免费体验中心https://www.textin.com/ 合合技术团队CSDN: 合合技术团队的博客_CSDN博客-基于深度学习的文本检测与识别技术白皮书,【通用文本信息抽取技术白皮书】,【技术应用】领域博主 …

xss靶场实战

靶场链接:https://pan.baidu.com/s/1ors60QJujcmIZPf3iU3SmA?pwd4mg4 提取码:4mg4 XSS漏洞原理 XSS又叫CSS(Cross Site Script),跨站脚本攻击。因为与html中的css样式同,所以称之为XSS。在OWASP top 1…

Python Tornado 实现SSE服务端主动推送方案

一、SSE 服务端消息推送 SSE 是 Server-Sent Events 的简称, 是一种服务器端到客户端(浏览器)的单项消息推送。对应的浏览器端实现 Event Source 接口被制定为HTML5 的一部分。相比于 WebSocket,服务器端和客户端工作量都要小很多、简单很多&#xff0c…

深度强化学习(王树森)笔记05

深度强化学习(DRL) 本文是学习笔记,如有侵权,请联系删除。本文在ChatGPT辅助下完成。 参考链接 Deep Reinforcement Learning官方链接:https://github.com/wangshusen/DRL 源代码链接:https://github.c…

政安晨的机器学习笔记——基于Ubuntu系统的Miniconda安装Jupyter Notebook

一、准备工作 Miniconda的安装请参考我的另一篇博客文章: 实例讲解深度学习工具PyTorch在Ubuntu系统上的安装入门(基于Miniconda)(非常详细)https://blog.csdn.net/snowdenkeke/article/details/135887509 这里…

༺༽༾ཊ—Unity之-05-抽象工厂模式—ཏ༿༼༻

首先创建一个项目, 在这个初始界面我们需要做一些准备工作, 建基础通用文件夹, 创建一个Plane 重置后 缩放100倍 加一个颜色, 任务:使用 抽象工厂模式 创建 人物与宠物 模型, 首先资源商店下载 人物与宠物…

C++STL之map、set的使用和模拟实现

绪论​: “我这个人走得很慢,但是我从不后退。——亚伯拉罕林肯”,本章是接上一章搜索二叉树中红黑树的后续文章,若没有看过强烈建议观看,否则后面模拟实现部分很看懂其代码原理。本章主要讲了map、set是如何使用的&am…

torch与cuda\cudnn和torchvision的对应

以上图片来源于这篇博客 于是,我需要手动下载0.9.0torchvision 直接在网站https://pypi.tuna.tsinghua.edu.cn/simple/后面加上torchvision,就不用ctrlF搜torchvision了,即进入下面这个网站,找到对应版本的包下载安装即可 https…

html页面练习——公司发展流程图

1.效果图 2.html <div class"center"><header><h1>发展历程</h1><h3>CONMPANY HISTORY</h3></header><main><div class"left"><div class"time1">2012.12</div><div cla…