梯度下降法、模拟训练、拟合二次曲线、最小二乘法、MSELoss、拟合:f(x)=ax^2+bx+c

本文目标:

f(x)=a*x^2+b*x+c

以这个公式为例,设计一个算法,用梯度下降法来模拟训练过程,最终得出参数a,b,c

原理介绍

目标函数:h_{\theta}(x) = a^{2}+bx+c

损失函数:Loss=\frac{1}{2m}\sum_{1}^{m}(h_{\theta}(x^{i})-y^{i})^2,就是mse

损失函数展开:Loss=\frac{1}{2m}\sum_{1}^{m}((ax^{i})^{2}+bx^i+c-y^{i})^2

损失函数对a,b,c求导数:

{L_{a}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)*x^2

{L_{b}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)*x

{L_{c}}^{'}=\frac{1}{m}\sum_{1}^{m}(ax^2+bx+c-y)

导数就是梯度,也就是目标参数与当前参数的差异,这个差异需要用梯度下降法更新

\Delta a{L_{a}}^{'}   \Delta b={L_{b}}^{'}     \Delta c={L_{c}}^{'}

a = a - lr*\Delta a

b = b - lr*\Delta b

c = c - lr*\Delta c

重复上面的过程,参数就可以更新了,然后可以看看新参数的效果,也就是损失有没有降低

具体流程

  1. 预设模型的表达式为:f(x)=a*x^2+b*x+c,也就是二次函数。同时随机初始化模型参数a,b,c。如果是其他函数如f(x)=ax^3+bx^2+cx,就无法在本版本适用(修改求导方式后才可用)。即本模型需要提前知道模型的表达式。
  2. 通过不断喂入(x_input,y_true),得出y_{out} = ax_{input}^{2}+bx_{input}+c.而y_out与y_true之间具有差异。
  3. 将差异封装成一个loss函数,并分别对a,b,c进行求导。得到a,b,c的梯度\Delta a\Delta b\Delta c
  4. \Delta a\Delta b\Delta c和原始的参数a,b,c和学习率作为输入,用梯度下降法来对a,b,c参数进行更新.
  5. 重复2,3,4过程。直到训练结束或者loss降低到较小值

python实现

  # 初始化a,b,c为:-11/6 , -395/3,-2400  目标a,b,c为:(2,-4,3)

class QuadraticFunc():
    def drew(self,w,name="show"):
        a,b,c = w
        x1 = np.array(range(-80,80))
        y1 = a*x1*x1 + b*x1 + c
        y2 = 2*x1*x1 - 4*x1 + 3
        plt.clf()
        plt.plot(x1, y1)
        plt.plot(x1, y2)
        plt.scatter(x1, y1, c='r')# set color
        plt.xlim((-50,50))
        plt.ylim((-500,500))
        plt.xlabel('X Axis')
        plt.ylabel('Y Axis')
        if name == "first":
            plt.pause(3)
        else:
            plt.pause(0.01)
        plt.ioff()

    #计算loss
    def cal_loss(self,y_out,y_true):
        # return np.dot((y_out - y_true),(y_out - y_true)) * 0.5
        return np.mean((y_out - y_true)**2)

    #计算梯度  
    def cal_grad(self,x,y_out,y_true):
        # x(batch),y_out(batch),y_true(batch)
        a_grad = (y_out-y_true)*x**2 #
        b_grad = (y_out-y_true)*x
        c_grad = (y_out-y_true)
        return np.array([np.mean(a_grad),np.mean(b_grad),np.mean(c_grad)])        

    #梯度下降法更新参数
    def update_theta(self,step,w,grad):
        new_w = w - step*grad
        return new_w

    def run(self):
        feed_x = np.array(range(-400,400))/400
        feed_y = 2*feed_x*feed_x - 4*feed_x + 3
        step = 0.5
        base_lr = 0.5
        lr = base_lr
        # 初始化参数
        a,b,c = -11/6 , -395/3,-2400#-1,10,26
        w = np.array([a,b,c])
        self.drew(w,"first")
        epochs = 100
        for epoch in range(epochs):
            # 每隔10轮 降低一半的学习率
            lr = base_lr/(2**(int((epoch+1)/10)))
            for i in range(len(feed_x)):
                x_input = feed_x[i]
                y_true = feed_y[i]
                y_out = w[0]*x_input*x_input +w[1]*x_input + w[2]
                #计算loss
                loss = self.cal_loss(y_out,y_true)
                #计算梯度
                grad = self.cal_grad(x_input,y_out,y_true)
                #更新参数,梯度下降
                w = self.update_theta(lr,w,grad)
                # self.drew(w)
                grad = np.round(grad,2)
                loss = np.round(loss,2)
                w = np.round(w,2)
            print("train times is:",epoch,"  grad is:",grad,"   loss is:","%.4f"%loss, "  w is:",w,"\n")
            self.drew(w)
            if loss<1e-5:
                print("train finish:",w)
                break
    def run_batch(self):
        feed_x = np.array(range(-400,400))/400
        feed_y = 2*feed_x*feed_x - 4*feed_x + 3
        x_y = [[feed_x[i],feed_y[i]] for i in range(len(feed_x))]
        base_lr = 0.5
        lr = base_lr
        # 初始化参数
        a,b,c = -11/6 , -395/3,-2400#-1,10,26
        w = np.array([a,b,c])
        self.drew(w,"first")
        batch_size = 16
        data_len = len(x_y)//batch_size
        epochs = 100
        for epoch in range(epochs):
            random.shuffle(x_y)
            # 每隔10轮 降低一半的学习率
            lr = base_lr/(2**(int((epoch+1)/10)))
            print("epoch,lr:",epoch,lr)
            for i in range(data_len):
                x_y_list = x_y[i*batch_size:(i+1)*batch_size]
                x_y_np = np.array(x_y_list)
                x_input = x_y_np[:,0]
                y_true = x_y_np[:,1]
                y_out = w[0]*x_input*x_input +w[1]*x_input + w[2]
                #计算loss
                loss = self.cal_loss(y_out,y_true)
                #计算梯度
                grad = self.cal_grad(x_input,y_out,y_true)
                #更新参数,梯度下降
                w = self.update_theta(lr,w,grad)
                grad = np.round(grad,2)
                loss = np.round(loss,2)
                w = np.round(w,2)
            print("train times is:",epoch,"  grad is:",grad,"   loss is:","%.4f"%loss, "  w is:",w,"\n")
            self.drew(w)
            if loss<1e-5:
                print("train finish:",w)
                # break
            time.sleep(0.1)
            
            
if __name__ == "__main__":
    qf = QuadraticFunc()
    qf.run()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/351249.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JavaScript高级:闭包

1 概念 一个函数对周围状态的引用&#xff0c;捆绑在一起&#xff0c;内层函数中可以访问到外层函数的作用域。 简单理解&#xff1a;闭包 内层函数 外层函数的变量 先看个简单的代码&#xff1a; function outer() {let a 1function inner() {console.log(a)} } outer(…

tee漏洞学习-翻译-1:从任何上下文中获取 TrustZone 内核中的任意代码执行

原文&#xff1a;http://bits-please.blogspot.com/2015/03/getting-arbitrary-code-execution-in.html 目标是什么&#xff1f; 这将是一系列博客文章&#xff0c;详细介绍我发现的一系列漏洞&#xff0c;这些漏洞将使我们能够将任何用户的权限提升到所有用户的最高权限 - 在…

重磅!讯飞星火V3.5马上发布!AI写作、AI编程、AI绘画等功能全面提升!

讯飞星火大模型相信很多友友已经不陌生了&#xff0c;可以说是国内GPT相关领域的龙头标杆&#xff0c;而对于1月30日即将在讯飞星火发布会发出的V3.5新版本来说&#xff0c;讯飞星火V3.5与之前版本相比&#xff0c;性能提升方面相当明显&#xff0c;在提示语义理解、内容生成、…

Java项目:15 springboot vue的智慧养老手表管理系统

作者主页&#xff1a;源码空间codegym 简介&#xff1a;Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 本系统共分为两个角色&#xff1a;家长&#xff0c;养老院管理员 框架&#xff1a;springboot、mybatis、vue 数据库&#xff1a;mysql 5.7&…

【幻兽帕鲁】开服务器,高性能高带宽(100mbps),免费!!!【学生党强推】

【幻兽帕鲁】开服务器&#xff0c;高性能高带宽&#xff08;100mbps&#xff09;&#xff0c;免费&#xff01;&#xff01;&#xff01;【学生党强推】 教程相关视频地址&#xff1a;https://www.bilibili.com/video/BV16e411Y7Fd/ 目前幻兽帕鲁开服务器有以下几套比较性价比的…

【计算机图形学】实验五 一个简单的交互式绘图系统(实验报告分析+截图+源码)

可以先看一看这篇呀~【计算机图形学】专栏前言-CSDN博客https://blog.csdn.net/m0_55931547/article/details/135863062 目录 一、实验目的 二、实验内容

docker-compose部署单机ES+Kibana

记录部署的操作步骤 准备工作编写docker-compose.yml启动服务验证部署结果 本次elasticsearch和kibana版本为8.2.2 使用环境&#xff1a;centos7.9 本次记录还包括&#xff1a;安装elasticsearch中文分词插件和拼音分词插件 准备工作 1、创建目录和填写配置 mkdir /home/es/s…

杰理方案——WIFI连接物联网配置阿里云操作步骤

demo——DevKitBoard 注意:最好用这个Demo,其它Demo可能会有莫名其妙的错误问题。 wifi配置 需要在app_config.h文件中定义USE_DEMO_WIFI_TEST,工程会在wifi_demo_task.c文件中自动启动wifi相关的任务, 我们将工程配置为连接外部网络STA模式 默认工程会使用如下账号密码 这…

go slice 基本用法

slice&#xff08;切片&#xff09;是 go 里面非常常用的一种数据结构&#xff0c;它代表了一个变长的序列&#xff0c;序列中的每个元素都有相同的数据类型。 一个 slice 类型一般写作 []T&#xff0c;其中 T 代表 slice 中元素的类型&#xff1b;slice 的语法和数组很像&…

一款强大的矢量图形设计软件:Adobe Illustrator 2023 (AI2023)软件介绍

Adobe Illustrator 2023 (AI2023) 是一款强大的矢量图形设计软件&#xff0c;为设计师提供了无限创意和畅行无阻的设计体验。AI2023具备丰富的功能和工具&#xff0c;让用户可以轻松创建精美的矢量图形、插图、徽标和其他设计作品。 AI2023在界面和用户体验方面进行了全面升级…

python-自动化篇-运维-监控-简单实例-道出如何使⽤Python进⾏网络监控?

如何使⽤Python进⾏⽹络监控&#xff1f; 使⽤Python进⾏⽹络监控可以帮助实时监视⽹络设备、流量和服务的状态&#xff0c;以便及时识别和解决问题。 以下是⼀般步骤&#xff0c;说明如何使⽤Python进⾏⽹络监控&#xff1a; 选择监控⼯具和库&#xff1a;选择适合⽹络监控需…

网络防御——NET实验

一、实验拓扑 二、实验要求 1、生产区在工作时间&#xff08;9&#xff1a;00---18&#xff1a;00&#xff09;内可以访问服务区&#xff0c;仅可以访问http服务器&#xff1b; 2、办公区全天可以访问服务器区&#xff0c;其中&#xff0c;10.0.2.20可以访问FTP服务器和HTTP服…

OSI七层模型 | TCP/IP模型 | 网络和操作系统的联系 | 网络通信的宏观流程

文章目录 1.OSI七层模型2.TCP/IP五层(或四层)模型3.网络通信的宏观流程3.1.同网段通信3.2.跨网段通信 1.OSI七层模型 在计算机通信诞生之初&#xff0c;不同的厂商都生产自己的设备&#xff0c;都有自己的网络通讯标准&#xff0c;导致了不同厂家之间各种协议不兼容&#xff0…

Oracle篇—分区索引的重建和管理(第三篇,总共五篇)

☘️博主介绍☘️&#xff1a; ✨又是一天没白过&#xff0c;我是奈斯&#xff0c;DBA一名✨ ✌✌️擅长Oracle、MySQL、SQLserver、Linux&#xff0c;也在积极的扩展IT方向的其他知识面✌✌️ ❣️❣️❣️大佬们都喜欢静静的看文章&#xff0c;并且也会默默的点赞收藏加关注❣…

Web3创业:去中心化初创公司的崛起

随着Web3时代的到来&#xff0c;去中心化技术的崛起不仅令人瞩目&#xff0c;也为创业者带来了前所未有的机遇。在这个新的时代&#xff0c;一批去中心化初创公司正崭露头角&#xff0c;重新定义着商业和创新的边界。本文将深入探讨Web3创业的趋势&#xff0c;以及去中心化初创…

VScode 好用的插件合集

VS Code是一个轻量级但功能强大的源代码编辑器&#xff0c;轻量级指的是下载下来的VS Code其实就是一个简单的编辑器&#xff0c;强大指的是支持多种语言的环境插件拓展&#xff0c;也正是因为这种支持插件式安装环境开发让VS Code成为了开发语言工具中的霸主&#xff0c;让其同…

策略者模式-C#实现

该实例基于WPF实现&#xff0c;直接上代码&#xff0c;下面为三层架构的代码。 目录 一 Model 二 View 三 ViewModel 一 Model using System; using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks;namespace 设计模式练…

Facebook 广告帐户:多账号运营如何防止封号?

Facebook目前是全球最受欢迎的社交媒体平台之一&#xff0c;拥有超过27亿活跃用户。因此&#xff0c;它已成为个人和企业向全球受众宣传其产品和服务的重要平台。 然而&#xff0c;Facebook 制定了广告商必须遵守的严格政策和准则&#xff0c;以确保其广告的质量和相关性&…

vulnhub靶场之Five86-1

由于这些文章都是从我的hexo博客上面复制下来的&#xff0c;所以有的图片可能不是很完整&#xff0c;但是不受影响&#xff0c;如果有疑问&#xff0c;可以在评论区留言&#xff0c;我看到之后会回复。 一.环境搭建 1.靶场描述 Five86-1 is another purposely built vulnerab…

Vue2:通过代理服务器解决跨域问题

一、场景描述 现在的项目大多数是前后端分离的。Vue前端项目通过ajax去请求后端接口的时候&#xff0c;会有同源策略的限制。从而产生跨域问题。 二、基本概念 1、什么是同源策略&#xff1f; 就是前端服务和后端服务的协议名&#xff0c;IP或主机名&#xff0c;端口号不完…