纯纯python实现梯度下降、随机梯度下降

最近面试有要求手撕SGD,这里顺便就把梯度下降、随机梯度下降、批次梯度下降给写出来了
有几个注意点:
1.求梯度时注意label[i]和pred[i]不要搞反,否则会导致模型发散
2.如果跑了几千个epoch,还是没有收敛,可能是学习率太小了

# X:n*k
# Y: n*1

import random
import numpy

class GD:
    def __init__(self,w_dim,r):
        # 随机初始化
        self.w = [random.random() for _ in range(w_dim)]
        self.bias = random.random()
        self.learningRate = r
        print(f"original w is {self.w}, original bias is {self.bias}")
        

    def forward(self,x):
        # 前馈网络
        ans = []
        for i in range(len(x)):
            y=0
            for j in range(len(x[0])):
                y+=self.w[j]*x[i][j]
            ans.append(y+self.bias)
        return ans

    def bp(self,X,pred,label,op="GD"):
        # 计算均方差
        loss = 0
        for i in range(len(pred)):
            loss+=(label[i]-pred[i])**2
        loss = loss/len(X)
            
        
        # 计算梯度
        # 梯度下降
        if op=="GD":
            grad_w = [0 for _ in range(len(self.w))]
            grad_bias=0
            for i in range(len(X)):
                grad_bias+=-2*(label[i]-pred[i])
                for j in range(len(self.w)):
                    grad_w[j]+=-2*(label[i]-pred[i])*X[i][j]  
            # 反向传播,更新梯度
            self.bias=self.bias-self.learningRate*grad_bias/len(X)
            for i in range(len(self.w)):
                self.w[i]-=self.learningRate*grad_w[i]/len(X)
        
        # 随机梯度下降
        if op=="SGD":
            grad_w = [0 for _ in range(len(self.w))]
            grad_bias=0
            randInd = random.randint(0,len(X)-1)
            grad_bias+=-2*(label[randInd]-pred[randInd])
            for j in range(len(self.w)):
                grad_w[j]+=-2*(label[randInd]-pred[randInd])*X[randInd][j]  
            # 反向传播,更新梯度
            self.bias=self.bias-self.learningRate*grad_bias
            for i in range(len(self.w)):
                self.w[i]-=self.learningRate*grad_w[i]
        
        # 批次梯度下降
        if op=="BGD":        
            grad_w = [0 for _ in range(len(self.w))]
            grad_bias=0
            BS=8
            randInd = random.randint(0,len(X)/BS-1)
            X = X[BS*randInd:BS*(randInd+1)]
            label = label[BS*randInd:BS*(randInd+1)]
            pred = pred[BS*randInd:BS*(randInd+1)]
            
            for i in range(len(X)):
                grad_bias+=-2*(label[i]-pred[i])
                for j in range(len(self.w)):
                    grad_w[j]+=-2*(label[i]-pred[i])*X[i][j]  
            # 反向传播,更新梯度
            self.bias=self.bias-self.learningRate*grad_bias/len(X)
            for i in range(len(self.w)):
                self.w[i]-=self.learningRate*grad_w[i]/len(X)

        return loss




def testY(X,w):

    Y = []
    for x in X:
        y=0
        for i in range(len(x)):
            y+=w[i]*x[i]
        Y.append(y)
    return Y

# 构建数据
n = 1000
X=[[random.random() for _ in range(2)] for _ in range(n)]
w=[0.2,0.3]
B=0.4
Y = testY(X,w)

# 设置样本维度为2
k = 2
lr = GD(k,0.01)
Loss=0
epochs=2000

for e in range(epochs):
    Loss = 0
    pred = lr.forward(X)
    loss=lr.bp(X,pred,Y,"BGD")
    Loss+=loss 
       
    if (e%100)==0:       
        print(f"step:{e},Loss:{Loss}") 
    
    
X_test=[[random.random() for _ in range(2)] for _ in range(2)]
Y_test=testY(X_test,w)

print("X_test=",X_test)
print("Y_test=",Y_test)
print("Y_pred=",lr.forward(X_test))




测试效果如下:
在这里插入图片描述
也还行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/540814.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux 秋招必知必会(三、线程、线程同步)

六、线程 38. 什么是线程 线程是参与系统调度的最小单位,它被包含在进程之中,是进程中的实际运行单位 一个进程中可以创建多个线程,多个线程实现并发运行,每个线程执行不同的任务 主线程:当一个程序启动时&#xff0…

【Qt 学习笔记】Qt控件概述

博客主页:Duck Bro 博客主页系列专栏:Qt 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ Qt控件概述 文章编号:Qt 学习笔记 / 14 文章目录 Qt控件概…

排序之快速排序

代码 class Solution {public int[] sortArray(int[] nums) {merge(nums, 0, nums.length - 1);return nums;}private void merge(int[] nums, int l, int r){if(l > r) return;// 随机选取主元int i new Random().nextInt(r - l 1) l;int temp nums[i];nums[i] nums[…

探索ElasticSearch高级特性:从映射到智能搜索

欢迎关注我的公众号“知其然亦知其所以然”,获取更多技术干货! 大家好,我是小米!今天我们来聊聊阿里巴巴面试题中的一个高级话题:ElasticSearch(以下简称ES)的高级特性。ES作为一款强大的搜索引擎,在处理大规模数据和复杂查询时发挥着重要作用。而了解其高级特性,则是…

微服务-6 Gateway网关

一、网关搭建 此时浏览器访问 localhost:10010/user/list 后正常返回数据,说明网关已生效,其原理流程图如下: 二、网关过滤器 作用:处理一切进入网关的请求和微服务响应。 1. 网关过滤器的分类: a. 某个路由的过滤器 …

购物车实现

目录 1.购物车常见的实现方式 2.购物车数据结构介绍 3.实例分析 1.controller层 2.service层 1.购物车常见的实现方式 方式一:存储到数据库 性能存在瓶颈方式二:前端本地存储 localstorage在浏览器中存储 key/value 对,没有过期时间。s…

Linux中使用Alias技术实现虚拟网卡

背景 在《Linux中虚拟网络技术有哪些》一文中,我们介绍了多种创建虚拟网卡的方法。本文介绍使用Alias技术创建虚拟网卡。 分析 Alias技术 在计算机领域中,Alias技术指的是给一个实体(如文件、命令、网络接口等)起一个别名或替代…

【leetcode】 跳跃游戏 IV

跳跃游戏IV 题目思路代码 题目 给你一个整数数组 arr &#xff0c;你一开始在数组的第一个元素处&#xff08;下标为 0&#xff09;。每一步&#xff0c;你可以从下标 i 跳到下标 i 1 、i - 1 或者 j &#xff1a;i 1 需满足&#xff1a;i 1 < arr.length i - 1 需满足&…

C++静态库与动态库

什么是库 库是写好的现有的&#xff0c;成熟的&#xff0c;可以复用的代码。现实中每个程序都要依赖很多基础的底层库&#xff0c;不可能每个人的代码都从零开始&#xff0c;因此库的存在意义非同寻常。 本质上来说库是一种可执行代码的二进制形式&#xff0c;可以被操作系统载…

Linux中磁盘的分区,格式化,挂载和文件系统的修复

一.分区工具 1.分区工具介绍 fdisk 2t及以下分区 推荐 (分完区不保存不生效&#xff0c;有反悔的可能) gdisk 全支持 推荐 parted 全支持 不推荐 ( 即时生效&#xff0c;分完立即生效) 2.fdisk 分区,查看磁盘 格式:fdisk -l [磁盘设备] fdisk -l 查看…

运动听歌哪款耳机靠谱?精选五款热门开放式耳机

随着人们对运动健康的重视&#xff0c;越来越多的运动爱好者开始关注如何在运动中享受音乐。开放式蓝牙耳机凭借其独特的设计&#xff0c;成为了户外运动的理想选择。它不仅让你在运动时能够清晰听到周围环境的声音&#xff0c;保持警觉&#xff0c;还能让你在需要时与他人轻松…

【数据结构】常见的排序算法

&#x1f9e7;&#x1f9e7;&#x1f9e7;&#x1f9e7;&#x1f9e7;个人主页&#x1f388;&#x1f388;&#x1f388;&#x1f388;&#x1f388; &#x1f9e7;&#x1f9e7;&#x1f9e7;&#x1f9e7;&#x1f9e7;数据结构专栏&#x1f388;&#x1f388;&#x1f388;&…

基于单链表实现通讯管理系统!(有完整源码!)

​ 个人主页&#xff1a;秋风起&#xff0c;再归来~ 文章专栏&#xff1a;C语言实战项目 个人格言&#xff1a;悟已往之不谏&#xff0c;知来者犹可追 克心守己&#xff0c;律己则安&#xff01; 1、前言 友友们&#xff0c;这篇文章是基于单链…

解决window10 utf-8编码软件中文全部乱码问题

问题描述 很多软件都是乱码状态&#xff0c;不管是Keil还是ISP或者是其他的一些非知名软件&#xff0c;都出现了中文乱码&#xff0c;英文正常显示问题&#xff0c;这个时候是系统出了问题。 解决方法 打开控制面板 点击更改日期、时间或数字格式 点击管理和更改系统区域…

华为云配置安全组策略开放端口

&#x1f436;博主主页&#xff1a;ᰔᩚ. 一怀明月ꦿ ❤️‍&#x1f525;专栏系列&#xff1a;线性代数&#xff0c;C初学者入门训练&#xff0c;题解C&#xff0c;C的使用文章&#xff0c;「初学」C &#x1f525;座右铭&#xff1a;“不要等到什么都没有了&#xff0c;才下…

mysql 查询实战-变量方式-解答

对mysql 查询实战-变量方式-题目&#xff0c;进行一个解答。&#xff08;先看题&#xff0c;先做&#xff0c;再看解答&#xff09; 1、查询表中⾄少连续三次的数字 1&#xff0c;处理思路 要计算连续出现的数字&#xff0c;加个前置变量&#xff0c;记录上一个的值&#xff0c…

类和对象(拷贝构造函数)

目录 拷贝构造函数 特征 结论&#xff1a; 拷贝构造函数 拷贝构造函数&#xff1a;只有单个形参&#xff0c;该形参是对本类类型对象的引用(一般常用const修饰)&#xff0c;在用已存 在的类类型对象创建新对象时由编译器自动调用。 特征 拷贝构造函数也是特殊的成员函数&…

SQL注入sqli_labs靶场第十一、十二、十三、十四题详解

第十一题 方法一 poss提交 输入1显示登录失败 输入1 显示报错信息 根据提示得知&#xff1a;SQL查询语句为 username参数 and password and是与运算&#xff1b;两个或多个条件同时满足&#xff0c;才为真&#xff08;显示一条数据&#xff09; or是或运算&#xff0c;两个…

itext7 pdf转图片

https://github.com/thombrink/itext7.pdfimage 新建asp.net core8项目&#xff0c;安装itext7和system.drawing.common 引入itext.pdfimage核心代码 imageListener下有一段不安全的代码 unsafe{for (int y 0; y < image.Height; y){byte* ptrMask (byte*)bitsMask.Scan…

B站大数据平台元数据业务分享

背景介绍 元数据是数据平台的衍生数据&#xff0c;比如调度任务信息&#xff0c;离线hive表&#xff0c;实时topic&#xff0c;字段信息&#xff0c;存储信息&#xff0c;质量信息&#xff0c;热度信息等。在数据平台建设初期&#xff0c;这类数据主要散落于各种平台子系统的数…