【单层神经网络】基于MXNet的线性回归实现(底层实现)

写在前面

  1. 基于亚马逊的MXNet库
  2. 本专栏是对李沐博士的《动手学深度学习》的笔记,仅用于分享个人学习思考
  3. 以下是本专栏所需的环境(放进一个environment.yml,然后用conda虚拟环境统一配置即可)
  4. 刚开始先从普通的寻优算法开始,熟悉一下学习训练过程
  5. 下面将使用梯度下降法寻优,但这大概只能是局部最优,它并不是一个十分优秀的寻优算法
name: gluon
dependencies:
- python=3.6
- pip:
  - mxnet==1.5.0
  - d2lzh==1.0.0
  - jupyter==1.0.0
  - matplotlib==2.2.2
  - pandas==0.23.4

整体流程

  1. 生成训练数据集(实际工程中,需要从实际对象身上采集数据)
  2. 确定模型及其参数(输入输出个数、阶次,偏置等)
  3. 确定学习方式(损失函数、优化算法,学习率,训练次数,终止条件等)
  4. 读取数据集(不同的读取方式会影响最终的训练效果)
  5. 训练模型

完整程序及注释

from IPython import display
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random


'''
获取(生成)训练集
'''
input_num = 2				# 输入个数
examples_num = 1000			# 生成样本个数
# 确定真实模型参数
real_W = [10.9, -8.7]		
real_bias = 6.5	

features = nd.random.normal(scale=1, shape=(examples_num, input_num))       # 标准差=1,均值缺省=0
labels = real_W[0]*features[:,0] + real_W[1]*features[:,1] + real_bias		# 根据特征和参数生成对应标签
labels_noise = labels + nd.random.normal(scale=0.1, shape=labels.shape)		# 为标签附加噪声,模拟真实情况

# 绘制标签和特征的散点图(矢量图)
# def use_svg_display():
#     display.set_matplotlib_formats('svg')

# def set_figure_size(figsize=(3.5,2.5)):
#     use_svg_display()
#     plt.rcParams['figure.figsize'] = figsize

# set_figure_size()
# plt.scatter(features[:,0].asnumpy(), labels_noise.asnumpy(), 1)
# plt.scatter(features[:,1].asnumpy(), labels_noise.asnumpy(), 1)
# plt.show()


# 创建一个迭代器(确定从数据集获取数据的方式)
def data_iter(batch_size, features, labels):
    num = len(features)
    indices = list(range(num))                                  # 生成索引数组
    random.shuffle(indices)                                     # 打乱indices
    # 该遍历方式同时确保了随机采样和无遗漏
    for i in range(0, num, batch_size):
        j = nd.array(indices[i: min(i+batch_size, num)])        # 对indices从i开始取,取batch_size个样本,并转换为列表
        yield features.take(j), labels.take(j)                  # take方法使用索引数组,从features和labels提取所需数据


"""
训练的基础准备
"""
# 声明训练变量,并赋高斯随机初始值
w = nd.random.normal(scale=0.01, shape=(input_num))
b = nd.zeros(shape=(1,))
# b = nd.zeros(1)       # 不同写法,等价于上面的
w.attach_grad()         # 为需要迭代的参数申请求梯度空间
b.attach_grad()

# 定义模型
def linreg(X, w, b):
    return nd.dot(X,w)+b

# 定义损失函数
def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) **2 /2
    
# 定义寻优算法
def sgd(params, learning_rate, batch_size):
    for param in params:
        # 新参数 = 原参数 - 学习率*当前批量的参数梯度/当前批量的大小
        param[:] = param - learning_rate * param.grad / batch_size

# 确定超参数和学习方式
lr = 0.03
num_iterations = 5
net = linreg				# 目标模型
loss = squared_loss			# 代价函数(损失函数)
batch_size = 10				# 每次随机小批量的大小

'''
开始训练
'''
for iteration in range(num_iterations):		# 确定迭代次数
    for x, y in data_iter(batch_size, features, labels):
        with autograd.record():
            l = loss(net(x,w,b), y)			# 求当前小批量的总损失
        l.backward()						# 求梯度
        sgd([w,b], lr, batch_size)			# 梯度更新参数
    train_l = loss(net(features,w,b), labels)

    print("iteration %d, loss %f" % (iteration+1, train_l.mean().asnumpy()))
# 打印比较真实参数和训练得到的参数
print("real_w " + str(real_W) + "\n train_w " + str(w))
print("real_w " + str(real_bias) + "\n train_b " + str(b))

具体程序解释

param[:] = param - learning_rate * param.grad / batch_size
将batch_size与参数调整相关联的原因,是为了使得每次更新的步长不受批次大小的影响
具体来说,当计算一批数据的损失函数的梯度时,实际上是将这批数据中每个样本对损失函数的贡献累加起来。这意味着如果批次较大,梯度的模也会相应增大
故更新权值时,使用的是数据集的平均梯度,而不是总和

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/963466.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

5.5.1 面向对象的基本概念

文章目录 基本概念面向对象的5个原则 基本概念 面向对象的方法,特点时其分析与设计无明显界限。虽然在软件开发过程中,用户的需求会经常变化,但客观世界对象间的关系是相对稳定的。对象是基本的运行实体,由数据、操作、对象名组成…

在线免费快速无痕去除照片海报中的文字logo

上期和大家分享了用photoshop快速无痕去除照片海报中的文字logo的方法,有的同学觉得安装PS太麻烦,有那下载安装时间早都日落西山了,问有没有合适的在线方法可以快速去除;达芬奇上网也尝试了几个网站,今天分享一个对国人…

Linux网络 | 网络层IP报文解析、认识网段划分与IP地址

前言:本节内容为网络层。 主要讲解IP协议报文字段以及分离有效载荷。 另外, 本节也会带领友友认识一下IP地址的划分。 那么现在废话不多说, 开始我们的学习吧!! ps:本节正式进入网络层喽, 友友们…

【深度学习】DeepSeek模型介绍与部署

原文链接:DeepSeek-V3 1. 介绍 DeepSeek-V3,一个强大的混合专家 (MoE) 语言模型,拥有 671B 总参数,其中每个 token 激活 37B 参数。 为了实现高效推理和成本效益的训练,DeepSeek-V3 采用了多头潜在注意力 (MLA) 和 De…

STM32 PWM驱动舵机

接线图: 这里将信号线连接到了开发板的PA1上 代码配置: 这里的PWM配置与呼吸灯一样,呼吸灯连接的是PA0引脚,输出比较单元用的是OC1通道,这里只需改为OC2通道即可。 完整代码: #include "servo.h&quo…

51单片机 02 独立按键

一、独立按键控制LED亮灭 轻触按键&#xff1a;相当于是一种电子开关&#xff0c;按下时开关接通&#xff0c;松开时开关断开&#xff0c;实现原理是通过轻触按键内部的金属弹片受力弹动来实现接通和断开。 #include <STC89C5xRC.H> void main() { // P20xFE;while(1){…

本地部署 DeepSeek-R1:简单易上手,AI 随时可用!

&#x1f3af; 先看看本地部署的运行效果 为了测试本地部署的 DeepSeek-R1 是否真的够强&#xff0c;我们随便问了一道经典的“鸡兔同笼”问题&#xff0c;考察它的推理能力。 &#x1f4cc; 问题示例&#xff1a; 笼子里有鸡和兔&#xff0c;总共有 35 只头&#xff0c;94 只…

[EAI-027] RDT-1B,目前最大的用于机器人双臂操作的机器人基础模型

Paper Card 论文标题&#xff1a;RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation 论文作者&#xff1a;Songming Liu, Lingxuan Wu, Bangguo Li, Hengkai Tan, Huayu Chen, Zhengyi Wang, Ke Xu, Hang Su, Jun Zhu 论文链接&#xff1a;https://arxiv.org/ab…

DeepSeek为什么超越了OpenAI?从“存在主义之问”看AI的觉醒

悉尼大学学者Teodor Mitew向DeepSeek提出的问题&#xff0c;在推特上掀起了一场关于AI与人类意识的大讨论。当被问及"你最想问人类什么问题"时&#xff0c;DeepSeek的回答直指人类存在的本质&#xff1a;"如果意识是进化的偶然&#xff0c;宇宙没有内在的意义&a…

在 crag 中用 LangGraph 进行评分知识精炼-下

在上一次给大家展示了基本的 Rag 检索过程&#xff0c;着重描述了增强检索中的知识精炼和补充检索&#xff0c;这些都是 crag 的一部分&#xff0c;这篇内容结合 langgraph 给大家展示通过检索增强生成&#xff08;Retrieval-Augmented Generation, RAG&#xff09;的工作流&am…

UE5.3 C++ CDO的初步理解

一.UObject UObject是所有对象的基类&#xff0c;往上还有UObjectBaseUtility。 注释&#xff1a;所有虚幻引擎对象的基类。对象的类型由基于 UClass 类来定义。 这为创建和使用UObject的对象提供了 函数&#xff0c;并且提供了应在子类中重写的虚函数。 /** * The base cla…

知识库管理在提升企业决策效率与知识共享中的应用探讨

内容概要 知识库管理是指企业对内部知识、信息进行系统化整理和管理的过程&#xff0c;其重要性在于为企业决策提供了坚实的数据支持与参考依据。知识库管理不仅能够提高信息的获取速度&#xff0c;还能有效减少重复劳动&#xff0c;提升工作效率。在如今快速变化的商业环境中…

Linux:线程池和单例模式

一、普通线程池 1.1 线程池概念 线程池&#xff1a;一种线程使用模式。线程过多会带来调度开销&#xff0c;进而影响缓存局部性和整体性能。而线程池维护着多个线程&#xff0c;等待着监督管理者分配可并发执行的任务。这避免了在处理短时间任务时创建与销毁线程的代价&…

AJAX笔记原理篇

黑马程序员视频地址&#xff1a; AJAX-Day03-01.XMLHttpRequest_基本使用https://www.bilibili.com/video/BV1MN411y7pw?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes&p33https://www.bilibili.com/video/BV1MN411y7pw?vd_sour…

ComfyUI安装调用DeepSeek——DeepSeek多模态之图形模型安装问题解决(ComfyUI-Janus-Pro)

ComfyUI 的 Janus-Pro 节点&#xff0c;一个统一的多模态理解和生成框架。 试用&#xff1a; https://huggingface.co/spaces/deepseek-ai/Janus-1.3B https://huggingface.co/spaces/deepseek-ai/Janus-Pro-7B https://huggingface.co/spaces/deepseek-ai/JanusFlow-1.3B 安装…

3D图形学与可视化大屏:什么是材质属性,有什么作用?

一、颜色属性 漫反射颜色 漫反射颜色决定了物体表面对入射光进行漫反射后的颜色。当光线照射到物体表面时&#xff0c;一部分光被均匀地向各个方向散射&#xff0c;形成漫反射。漫反射颜色的选择会直接影响物体在光照下的外观。例如&#xff0c;一个红色的漫反射颜色会使物体在…

JVM方法区

一、栈、堆、方法区的交互关系 二、方法区的理解: 尽管所有的方法区在逻辑上属于堆的一部分&#xff0c;但是一些简单的实现可能不会去进行垃圾收集或者进行压缩&#xff0c;方法区可以看作是一块独立于Java堆的内存空间。 方法区(Method Area)与Java堆一样&#xff0c;是各个…

租赁管理系统在促进智能物业运营中的关键作用和优化策略分析

租赁管理系统在智能物业运营中的关键作用与优化策略 随着科技的飞速发展&#xff0c;租赁管理系统在智能物业运营中扮演着越来越重要的角色。这种系统不仅提高了物业管理的效率&#xff0c;更是促进了资源的优化配置和客户关系的加强。对于工业园、产业园、物流园、写字楼和公…

LLMs之DeepSeek:Math-To-Manim的简介(包括DeepSeek R1-Zero的详解)、安装和使用方法、案例应用之详细攻略

LLMs之DeepSeek&#xff1a;Math-To-Manim的简介(包括DeepSeek R1-Zero的详解)、安装和使用方法、案例应用之详细攻略 目录 Math-To-Manim的简介 1、特点 2、一个空间推理测试—考察不同大型语言模型如何解释和可视化空间关系 3、DeepSeek R1-Zero的简介&#xff1a;处理更…

【课题推荐】基于t分布的非高斯滤波框架在水下自主导航中的应用研究

水下自主导航系统在海洋探测、环境监测及水下作业等领域具有广泛的应用。然而&#xff0c;复杂的水下环境常常导致传感器输出出现野值噪声&#xff0c;这些噪声会严重影响导航信息融合算法的精度&#xff0c;甚至导致系统发散。传统的卡尔曼滤波算法基于高斯噪声假设&#xff0…