机器学习 - 选择模型

接着这一篇博客做进一步说明:
机器学习 - 准备数据

PyTorch moduleExplain
torch.nnContains all of the building blocks for computational graphs (essentially a series of computations executed in a particular way). nn 模块为用户提供了丰富的神经网络组件,包括各种层,激活函数,损失函数以及其他辅助功能。
torch.nn.ParameterStores tensors that can be used with nn.Module. If requires_grad=True gradients (used for updating model parameters via gradient descent) are calculated automatically, this is often referred to as “autograd”. 通常在定义神经网络模型时用于表示权重 (weights) 和 偏置 (biases) 等参数
torch.nn.ModuleThe base class for all neural network modules, all the building blocks for neural networks are subclasses. If you’re building a neural network in PyTorch, your models should subclass nn.Module. Requires a forward() method be implemented
torch.optimContains various optimization algorithms (these tell the model parameters stored in nn.Parameter how to best change to improve gradient descent and in turn reduce the loss).
def forward()All nn.Module subclasses require a forward() method, this defines the computation that will take place on the data passed to the particular nn.Module (e.g. the linear regression formula above).

可以这么理解,almost everything in a PyTorch neural network comes from torch.nn .

  • nn.Module contains the larger building blocks (layers)
  • nn.Parameter contains the smaller parameters like weights and biases (put these together to make nn.Module )
  • forward() tells the larger blocks how to make calculations on inputs (tensors full of data) within nn.Module(s)
  • torch.optim contains optimization methods on how to improve the parameters within nn.Parameter to better represent input data.

大概可以这么理解:module 里包含各种参数 (parameter),在 module 里做计算 (forward) 甚至可以通过修改参数来优化 (torch.optim)。

这里稍微介绍 Neural Network Block。
Neural Network Block 通常指的是神经网络中的一个模块化组件,它可以包含一个或多个层 (layers) 以及一些额外的操作,被设计用来完成特定的功能或实现特定的神经网络结构。
Neural Network Block的设计旨在简化神经网络模型的构建和管理,提高代码的可读性和可维护性。通过将神经网络模型划分为多个块,可以将模型的不同部分进行分离,使得每个部分都可以独立地设计,调整和复用。这种模块化的设计使得构建复杂的神经网络变得更加灵活和高效。
比如:卷积神经网络中的卷积块。

代码如下所示

import torch 

class LinearRegressionModel(nn.Module):  # child class nn.Module
  def __init__(self):
    super().__init__()

    # Initialize model parameters
    self.weights = nn.Parameter(torch.randn(1,
                                            dtype=torch.float),
                                requires_grad = True)
    self.bias = nn.Parameter(torch.randn(1,
                                         dtype=torch.float),
                             requires_grad = True)  # requires_grad=True means PyTorch will track the gradients of this specific parameter for use with torch.autograd and gradient descent (for many torch.nn modules, requires_grad=True is set by default)

  # Any child class of nn.Module needs to override forward()
  # This defines the forward computation of the model
  def forward(self, x: torch.Tensor) -> torch.tensor:
    return self.weights * x + self.bias

# Set manual seed since nn.Parameter are randomly initizalized
torch.manual_seed(42)

# Create an instance of the model (this is a subclass of nn.Module that contains nn.Parameter(s))
model_0 = LinearRegressionModel()

# Check the nn.Parameter(s) within the nn.Module subclass
print(f"Check the nn.Parameter(s): {list(model_0.parameters())}")

# List named parameters
print(f"List named parameters: {model_0.state_dict()}")

# 输出结果如下
Check the nn.Parameter(s): [Parameter containing:
tensor([0.3367], requires_grad=True), Parameter containing:
tensor([0.1288], requires_grad=True)]
List named parameters: OrderedDict([('weights', tensor([0.3367])), ('bias', tensor([0.1288]))])


使用 torch.inference_mode() 来做预测。
The data is passed to our model. It will go through the model’s forward() method and produce a result using the computation.

# Make predictions with model
with torch.inference_mode():
  y_test_preds = model_0(X_test)

As the name suggests, torch.inference_mode() is used when using a model for inference (making predictions). torch.inference_mode() turns off a bunch of things (like gradient tracking, which is necessary for training but not for inference) to make forward-passes (data going through the forward() method) faster.

# Check the predictions
print(f"Number of testing samples: {len(X_test)}")
print(f"Number of predictions made: {len(y_test_preds)}")
print(f"Predicted values (X_test):\n {y_test_preds}")

def plot_predictions(train_data = X_train,
                     train_labels = y_train,
                     test_data = X_test,
                     test_labels = y_test,
                     predictions = None):
  """
  Plots training data, test data and compares predictions
  """
  plt.figure(figsize=(10, 7))

  # Plot training data in blue
  plt.scatter(train_data, train_labels, c="b", s=4, label="Training data")

  # Plot test data in green
  plt.scatter(test_data, test_labels, c="g", s=4, label="Test data")

  if predictions is not None:
    plt.scatter(test_data, predictions, c="r", s=4, label="Predictions")

  plt.legend(prop={"size": 14})

plot_predictions(predictions=y_test_preds)

print(f"check the difference:\n {y_test - y_test_preds}")  # 可以发现两者之间的差距是很大的

# 结果如下
Number of testing samples: 10
Number of predictions made: 10
Predicted values (X_test):
 tensor([[0.3982],
        [0.4049],
        [0.4116],
        [0.4184],
        [0.4251],
        [0.4318],
        [0.4386],
        [0.4453],
        [0.4520],
        [0.4588]])
check the difference:
 tensor([[0.4618],
        [0.4691],
        [0.4764],
        [0.4836],
        [0.4909],
        [0.4982],
        [0.5054],
        [0.5127],
        [0.5200],
        [0.5272]])

将数据显示到图里
效果图

看到这了,给个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/470906.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【理解机器学习算法】之分类问题的模型评估(ROC-AUC)

ROC曲线(接收者操作特性曲线)和AUC(曲线下面积)是在不同阈值设置下,用于分类问题的性能度量工具。下面是它们所代表的含义以及使用方法: ROC曲线 代表含义:ROC曲线是一个图形化的表示&#xf…

反射 Reflection

反射 反射的概念 反射机制允许程序在执行期借助于ReflectionAPI取得任何类的内部信息(比如成员变量,构造器,成员方法等等),并能操作对象的属性及方法。反射在设计模式和框架底层都会用到加载完类之后,在堆中就产生了一个Class类型…

SurfaceFlinger实战dump获取单个Layer图像方案学员改进成果

背景: hi,粉丝朋友们: 在马哥课程的实战实现dump单个图层的发布后,很多学员朋友就纷纷享马哥要了相关源码,相关的链接请参考这里: https://blog.csdn.net/learnframework/article/details/136323076 学员…

前端项目,个人笔记(三)【Vue-cli - api封装-axios使用举例】

目录 前言 1、axios配置与测试 1.1、配置 1.2、测试 2、使用axios案例-渲染header 3、Pinia优化重复请求 3.1、为什么? 3.2、使用Pinia优化代码步骤 步骤一:在main.js中创建 Pinia 实例,并将其作为插件添加到 Vue 应用中 步骤二&am…

【QT+QGIS跨平台编译】之八十四:【QGIS_Gui跨平台编译】—【错误处理:未实例化QgsMapLayer - QgsHighlight】

文章目录 一、未实例化QgsMapLayer二、错误处理 一、未实例化QgsMapLayer 报错信息: 二、错误处理 第31行修改为: #include "qgsmaplayer.h"

【Leetcode】2684. 矩阵中移动的最大次数

文章目录 题目思路代码结果 题目 题目链接🔗 给你一个下标从 0 开始、大小为 m x n 的矩阵 grid ,矩阵由若干 正 整数组成。 你可以从矩阵第一列中的 任一 单元格出发,按以下方式遍历 grid : 从单元格 (row, col) 可以移动到 (…

Flutter-仿携程首页类型切换

效果 唠叨 闲来无事,不小心下载了携程app,还幻想可以去旅游一番,奈何自己运气不好,自从高考时第一次吹空调导致自己拉肚子考试,物理,数学考了一半就交卷,英语2B铅笔除了问题,导致原…

MySQL数据库的基本概念与安装

目录 引言 一、数据库的基本概念 (一)数据、表与数据库 1.数据(Data) 2.表 3.数据库 (二)数据库管理系统 (三)数据库系统 二、数据库的发展 三、主流数据库的介绍 (一)关…

一命通关深度优先遍历

前言 在解释深度优先遍历之前,我想先让大家去思考一个可能从未想过的问题: 为什么我们在学习基础数据结构的时候,都没有出现dfs和bfs这两个词,而在学习二叉树和图的时候,突然蹦出了深度优先遍历和广度优先遍历这两个…

力扣 第 126 场双周赛 解题报告 | 珂学家 | 贡献法思维场 + 贪心构造 + 0-1背包

前言 整体评价 T3是道好题,一开始思路偏了往按字母前缀和和DP去想了,但是感觉很难下手,后来发现从贡献的角度,其实和位子无关系,只需要贪心即可。 T4也是一道贡献思路题,理清核心的点,就能简单…

Vue项目使用process.env关键字及Vue.config.js配置解决前端跨域问题

1.process.env 是Node.js 中的一个环境 1.打开命令行查看环境: 2.process.env与Vue CLI 项目 Vue Cli 有以下三种运行模式 development 模式用于 vue-cli-service serve test 模式用于 vue-cli-service test:unit production 模式用于 vue-cli-service build 和 vue-cli-se…

每日一练:LeeCode-167. 两数之和 II - 输入有序数组【双指针】

给你一个下标从 1 开始的整数数组 numbers &#xff0c;该数组已按 非递减顺序排列 &#xff0c;请你从数组中找出满足相加之和等于目标数 target 的两个数。如果设这两个数分别是 numbers[index1] 和 numbers[index2] &#xff0c;则 1 < index1 < index2 < numbers.…

如何本地部署1Panel面板

文章目录 前言1. Linux 安装1Panel2. 安装cpolar内网穿透3. 配置1Panel公网访问地址4. 公网远程访问1Panel管理界面5. 固定1Panel公网地址 前言 1Panel 是一个现代化、开源的 Linux 服务器运维管理面板。高效管理,通过 Web 端轻松管理 Linux 服务器&#xff0c;包括主机监控、…

YOLOv8改进 | 主干篇 | 利用图像分割网络UNetV2改善图像分割检测性能(全网独家首发)

一、本文介绍 本文给大家带来的改进机制是利用图像分割网络UNetV2的主干来改进我们的YOLOv8分割模型&#xff08;本文的内容虽然YOLOv8所有的功能的用户都能使用&#xff0c;但是还是建议分割的用户使用&#xff09;&#xff0c;U-Net v2 旨在改进医学图像分割的性能&#xff…

C# 方法(函数)

文章目录 C# 方法&#xff08;函数&#xff09;简单示例程序代码运行效果 值传递和引用传递示例程序 运行效果按输出传递参数运行结果 C# 方法&#xff08;函数&#xff09; 简单示例 程序代码 访问的限制符 using System; using System.Collections.Generic; using Syste…

GO-初识包管理

初识包管理&#xff0c;知道项目中文件和文件夹之间的关系 输出&#xff0c;代码&#xff0c;在go编译器运行时会显示在屏幕中 初识数据类型 整型&#xff0c;数字。例如&#xff1a;1、2、3、4 字符串类型&#xff0c;表示文本信息的。例如:“张三”“李四” 布尔类型&#x…

Day51-Nginx多实例知识与大厂企业级实战

Day51-Nginx多实例知识与大厂企业级实战 1. 什么是nginx多实例&#xff1f;2. 为什么要用多实例&#xff1f;3. 大厂数百个业务项目&#xff0c;如何来管理&#xff1f;4. 大厂上百项目web分用户解决方案4.1 编译nginx环境实践&#xff1a;4.2 zuma实例(利用普通用户权限将不同…

Arduino 开发 野火ESP8266模块的串口通信

一、硬件连接 Arduino ESP8266 串口通信是一个常见的任务&#xff0c;允许通过串行接口与其他设备或计算机进行通信。 连接硬件&#xff1a;将野火ESP8266模块通过USB连接到电脑。注意电源为3.3V。 选择开发板和端口&#xff0c;在Arduino IDE中&#xff0c;选择ESP8266开发板…

图神经网络实战(5)——常用图数据集

图神经网络实战&#xff08;5&#xff09;——常用图数据集 0. 前言0. 图数据集介绍2. Cora 数据集3. Facebook Page-Page 数据集小结系列链接 0. 前言 图数据集往往比单纯的连接集合更丰富&#xff0c;节点和边也可以具有表示分数、颜色、单词等的特征。在输入数据中包含这些…

NLP---Bert分词

目录&#xff1a; Q&#xff1a;bert分词步骤1&#xff1a;构建N * N 的相关性矩阵&#xff0c;计算相邻两个字的相关性&#xff0c;低的话&#xff08;<阈值&#xff09;就切割。2&#xff1a;将A词进行mask计算出A的embedding&#xff0c;然后将AB两个词一起mask&#xff…