【机器学习】单变量线性回归

文章目录

  • 线性回归模型(linear regression model)
  • 损失/代价函数(cost function)——均方误差(mean squared error)
  • 梯度下降算法(gradient descent algorithm)
  • 参数(parameter)和超参数(hyperparameter)
  • 代码实现样例
  • 运行结果

源代码文件请点击此处!

线性回归模型(linear regression model)

  • 线性回归模型:

f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b(x)=wx+b

其中, w w w 为权重(weight), b b b 为偏置(bias)

  • 预测值(通常加一个帽子符号):

y ^ ( i ) = f w , b ( x ( i ) ) = w x ( i ) + b \hat{y}^{(i)} = f_{w,b}(x^{(i)}) = wx^{(i)} + b y^(i)=fw,b(x(i))=wx(i)+b

损失/代价函数(cost function)——均方误差(mean squared error)

  • 一个训练样本: ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i))
  • 训练样本总数 = m m m
  • 损失/代价函数是一个二次函数,在图像上是一个开口向上的抛物线的形状。

J ( w , b ) = 1 2 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] 2 = 1 2 m ∑ i = 1 m [ w x ( i ) + b − y ( i ) ] 2 \begin{aligned} J(w, b) &= \frac{1}{2m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}]^2 \\ &= \frac{1}{2m} \sum^{m}_{i=1} [wx^{(i)} + b - y^{(i)}]^2 \end{aligned} J(w,b)=2m1i=1m[fw,b(x(i))y(i)]2=2m1i=1m[wx(i)+by(i)]2

  • 为什么需要乘以 1/2?因为对平方项求偏导后会出现系数 2,是为了约去这个系数。

梯度下降算法(gradient descent algorithm)

  • α \alpha α:学习率(learning rate),用于控制梯度下降时的步长,以抵达损失函数的最小值处。若 α \alpha α 太小,梯度下降太慢;若 α \alpha α 太大,下降过程可能无法收敛。
  • 梯度下降算法:

r e p e a t { t m p _ w = w − α ∂ J ( w , b ) w t m p _ b = b − α ∂ J ( w , b ) b w = t m p _ w b = t m p _ b } u n t i l   c o n v e r g e \begin{aligned} repeat \{ \\ & tmp\_w = w - \alpha \frac{\partial J(w, b)}{w} \\ & tmp\_b = b - \alpha \frac{\partial J(w, b)}{b} \\ & w = tmp\_w \\ & b = tmp\_b \\ \} until \ & converge \end{aligned} repeat{}until tmp_w=wαwJ(w,b)tmp_b=bαbJ(w,b)w=tmp_wb=tmp_bconverge

其中,偏导数为

∂ J ( w , b ) w = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] x ( i ) ∂ J ( w , b ) b = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] \begin{aligned} & \frac{\partial J(w, b)}{w} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] x^{(i)} \\ & \frac{\partial J(w, b)}{b} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] \end{aligned} wJ(w,b)=m1i=1m[fw,b(x(i))y(i)]x(i)bJ(w,b)=m1i=1m[fw,b(x(i))y(i)]

参数(parameter)和超参数(hyperparameter)

  • 超参数(hyperparameter):训练之前人为设置的任何数量都是超参数,例如学习率 α \alpha α
  • 参数(parameter):模型在训练过程中创建或修改的任何数量都是参数,例如 w , b w, b w,b

代码实现样例

import numpy as np
import matplotlib.pyplot as plt

# 计算误差均方函数 J(w,b)
def cost_function(x, y, w, b):
    m = x.shape[0] # 训练集的数据样本数
    cost_sum = 0.0
    for i in range(m):
        f_wb = w * x[i] + b
        cost = (f_wb - y[i]) ** 2
        cost_sum += cost
    return cost_sum / (2 * m)

# 计算梯度值 dJ/dw, dJ/db
def compute_gradient(x, y, w, b):
    m = x.shape[0] # 训练集的数据样本数
    d_w = 0.0
    d_b = 0.0
    for i in range(m):
        f_wb = w * x[i] + b
        d_wi = (f_wb - y[i]) * x[i]
        d_bi = (f_wb - y[i])
        d_w += d_wi
        d_b += d_bi
    dj_dw = d_w / m
    dj_db = d_b / m
    return dj_dw, dj_db

# 梯度下降算法
def linear_regression(x, y, w, b, learning_rate=0.01, epochs=1000):
    J_history = [] # 记录每次迭代产生的误差值
    for epoch in range(epochs):
        dj_dw, dj_db = compute_gradient(x, y, w, b)
        # w 和 b 需同步更新
        w = w - learning_rate * dj_dw
        b = b - learning_rate * dj_db
        J_history.append(cost_function(x, y, w, b)) # 记录每次迭代产生的误差值
    return w, b, J_history

# 绘制线性方程的图像
def draw_line(w, b, xmin, xmax, title):
    x = np.linspace(xmin, xmax)
    y = w * x + b
    # plt.axis([0, 10, 0, 50]) # xmin, xmax, ymin, ymax
    plt.xlabel("X-axis", size=15)
    plt.ylabel("Y-axis", size=15)
    plt.title(title, size=20)
    plt.plot(x, y)

# 绘制散点图
def draw_scatter(x, y, title):
    plt.xlabel("X-axis", size=15)
    plt.ylabel("Y-axis", size=15)
    plt.title(title, size=20)
    plt.scatter(x, y)

# 从这里开始执行
if __name__ == '__main__':
    # 训练集样本
    x_train = np.array([1, 2, 3, 5, 6, 7])
    y_train = np.array([15.5, 19.7, 24.4, 35.6, 40.7, 44.8])
    w = 0.0 # 权重
    b = 0.0 # 偏置
    epochs = 10000 # 迭代次数
    learning_rate = 0.01 # 学习率
    J_history = [] # # 记录每次迭代产生的误差值

    w, b, J_history = linear_regression(x_train, y_train, w, b, learning_rate, epochs)
    print(f"result: w = {w:0.4f}, b = {b:0.4f}") # 打印结果

    # 绘制迭代计算得到的线性回归方程
    plt.figure(1)
    draw_line(w, b, 0, 10, "Linear Regression")
    plt.scatter(x_train, y_train) # 将训练数据集也表示在图中
    plt.show()

    # 绘制误差值的散点图
    plt.figure(2)
    x_axis = list(range(0, 10000))
    draw_scatter(x_axis, J_history, "Cost Function in Every Epoch")
    plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/382802.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于Linux的HTTP代理服务器搭建与配置实战

在数字化世界中,HTTP代理服务器扮演着至关重要的角色,它们能够帮助我们管理网络请求、提高访问速度,甚至在某些情况下还能保护我们的隐私。而Linux系统,凭借其强大的功能和灵活性,成为了搭建HTTP代理服务器的理想选择。…

I2C基础协议详解

串口是传感器、外设常用的接口,在低速器件中可以通过串口传输数据。高速复杂的器件,往往内部存在很多寄存器,这些寄存器的配置一般也是采用串口通信,可以节省IO口。 常用串口大致分为UART、IIC、SPI三种,其中IIC时序稍…

unity学习案例总结

动态标签 GitHub - SarahMit/DynamicLabel3D: Simple dynamic labels for a 3D Unity scene

《乱弹篇(十三)明朝事儿》

2024年农历除夕夜,因追剧收看电视连续剧《后宫》而放弃了收看一年一度的《春晚》,至到春节(农历正月初一)晚才看完了《后宫》。 社交网站“必应”图片《后宫》 电视连续剧《后宫》, 讲的是明朝英宗末年的历史故事&…

【大厂AI课学习笔记】【1.5 AI技术领域】(10)对话系统

对话系统,Dialogue System,也称为会话代理。是一种模拟人类与人交谈的计算机系统,旨在可以与人类形成连贯通顺的对话,通信方式主要有语音/文本/图片,当然也可以手势/触觉等其他方式 一般我们将对话系统,分…

[算法学习]

矩阵乘法 只有当左矩阵列数等于右矩阵行数,才能相乘N*M的矩阵和M*K的矩阵做乘法后矩阵大小为N*k矩阵乘法规则:第一个矩阵A的第 i 行与第二个矩阵的第 j 列的各M个元素对应相乘再相加得到新矩阵C[i][j]的值 整除 同余 同余的性质 线性运算,…

【制作100个unity游戏之25】3D背包、库存、制作、快捷栏、存储系统、砍伐树木获取资源、随机战利品宝箱1(附带项目源码)

效果演示 文章目录 效果演示系列目录前言人物和视角基本控制简单的背包系统和物品交互绘制背包UI脚本控制 源码完结 系列目录 前言 欢迎来到【制作100个Unity游戏】系列!本系列将引导您一步步学习如何使用Unity开发各种类型的游戏。在这第25篇中,我们将…

【c++基础】扑克牌组合

说明 小明从一副扑克牌中(没有大小王,J认为是数字11,Q是12,K是13,A是1)抽出2张牌求和,请问能够组合出多少个不相等的数,按照由小到大输出这些数。 输入数据 第一行是一个整数n代表…

2-8 单链表+双链表+模拟栈+模拟队列

今天给大家用数组来实现链表栈和队列 单链表: 首先要明白是如何用数组实现, 在这里需要用到几个数组,head表示头节点的下标,e[i]表示表示下标为i的值,ne[i]表示当前节点下一个节点的下标。idx表示当前已经用到那个点…

抛弃Spring Cloud Gateway,得物 使用Netty架构100Wqps网关

说在前面 在40岁老架构师 尼恩的读者交流群(50)中,很多小伙伴拿到一线互联网企业如阿里、网易、有赞、希音、百度、滴滴的面试资格。 最近,尼恩指导一个小伙伴简历,写了一个《高并发网关项目》,此项目帮这个小伙拿到 字节/阿里/…

线程池7个参数描述

所谓的线程池的 7 大参数是指&#xff0c;在使用 ThreadPoolExecutor 创建线程池时所设置的 7 个参数&#xff0c;如以下源码所示&#xff1a; public ThreadPoolExecutor(int corePoolSize,int maximumPoolSize,long keepAliveTime,TimeUnit unit,BlockingQueue<Runnable&…

检查链表是否回文

根据回文对称的特点&#xff0c;不难想到将链表分成前后两部分&#xff0c;然后将其中一部分反转的方法。 可以使用快慢指针的方式找到链表的中点&#xff0c;其中快指针每次移动两步&#xff0c;慢指针每次移动一步。当快指针到达链表末尾时&#xff0c;慢指针指向的位置即为链…

[LeetCode周赛复盘] 第 384 场周赛20240211

[LeetCode周赛复盘] 第 384 场周赛20240211 一、本周周赛总结100230. 修改矩阵1. 题目描述2. 思路分析3. 代码实现 100219. 回文字符串的最大数量1. 题目描述2. 思路分析3. 代码实现 100198. 匹配模式数组的子数组数目 II1. 题目描述2. 思路分析3. 代码实现 参考链接 一、本周…

前端JavaScript篇之ajax、axios、fetch的区别

目录 ajax、axios、fetch的区别AjaxAxiosFetch总结注意 ajax、axios、fetch的区别 在Web开发中&#xff0c;ajax、axios和fetch都是用于与服务器进行异步通信的技术&#xff0c;但它们在实现方式和功能上有所不同。 Ajax 定义与特点&#xff1a;Ajax是一种在无需重新加载整个…

【c++基础】国王的魔镜

说明 国王有一个魔镜&#xff0c;可以把任何接触镜面的东西变成原来的两倍——只是&#xff0c;因为是镜子嘛&#xff0c;增加的那部分是反的。 比如一条项链&#xff0c;我们用AB来表示&#xff0c;不同的字母表示不同颜色的珍珠。如果把B端接触镜面的话&#xff0c;魔镜会把…

小游戏和GUI编程(5) | SVG图像格式简介

小游戏和GUI编程(5) | SVG图像格式简介 0. 问题 Q1: SVG 是什么的缩写&#xff1f;Q2: SVG 是一种图像格式吗&#xff1f;Q3: SVG 相对于其他图像格式的优点和缺点是什么&#xff1f;Q4: 哪些工具可以查看 SVG 图像&#xff1f;Q5: SVG 图像格式的规范是怎样的&#xff1f;Q6…

中文GPTS详尽教程,字节扣子Coze插件使用全输出

今天&#xff0c;斜杠君和大家分享如何在字节扣子Coze中创建插件&#xff0c;并在创建后如何使用这个插件。 01 新建插件 首先&#xff0c;进入到插件页面&#xff0c;创建一个插件。 https://www.coze.cn/home 点击左侧的个人空间。 在上面选择”插件“标签&#xff0c;来到…

精灵图,字体图标,CSS3三角

精灵图 1.1为什么需要精灵图 一个网页中往往会应用很多小的背景图像作为修饰&#xff0c;当网页中的图像过多时&#xff0c;服务器就会频繁的接受和发送请求图片&#xff0c;造成服务器请求压力过大&#xff0c;这将大大降低页面的加载速度。 因此&#xff0c;为了有效地减少…

《计算思维导论》笔记:10.4 关系模型-关系运算

《大学计算机—计算思维导论》&#xff08;战德臣 哈尔滨工业大学&#xff09; 《10.4 关系模型-关系运算》 一、引言 本章介绍数据库的基本数据模型&#xff1a;关系模型-关系运算。 二、什么是关系运算 在数据库理论中&#xff0c;关系运算&#xff08;Relational Operatio…

读书笔记之《运动改造大脑》:运动是最佳的健脑丸

《运动改造大脑》的作者是约翰•瑞迪&#xff08;John Ratey&#xff09; / 埃里克•哈格曼&#xff08;Eric Hagerman&#xff09;&#xff0c;原著名称为&#xff1a;Spark&#xff1a;the revolutionary new science of exercise and the brain&#xff0c;于 2013年出版约翰…