深度学习 - PyTorch基本流程 (代码)

直接上代码

import torch 
import matplotlib.pyplot as plt 
from torch import nn

# 创建data
print("**** Create Data ****")
weight = 0.3
bias = 0.9
X = torch.arange(0,1,0.01).unsqueeze(dim = 1)
y = weight * X + bias
print(f"Number of X samples: {len(X)}")
print(f"Number of y samples: {len(y)}")
print(f"First 10 X & y sample: \n X: {X[:10]}\n y: {y[:10]}")
print("\n")

# 将data拆分成training 和 testing
print("**** Splitting data ****")
train_split = int(len(X) * 0.8)
X_train = X[:train_split]
y_train = y[:train_split]
X_test = X[train_split:]
y_test = y[train_split:]
print(f"The length of X train: {len(X_train)}")
print(f"The length of y train: {len(y_train)}")
print(f"The length of X test: {len(X_test)}")
print(f"The length of y test: {len(y_test)}\n")

# 显示 training 和 testing 数据
def plot_predictions(train_data = X_train,
                     train_labels = y_train,
                     test_data = X_test,
                     test_labels = y_test,
                     predictions = None):
  plt.figure(figsize = (10,7))
  plt.scatter(train_data, train_labels, c = 'b', s = 4, label = "Training data")
  plt.scatter(test_data, test_labels, c = 'g', label="Test data")

  if predictions is not None:
    plt.scatter(test_data, predictions, c = 'r', s = 4, label = "Predictions")
  plt.legend(prop = {"size": 14})
plot_predictions()

# 创建线性回归
print("**** Create PyTorch linear regression model by subclassing nn.Module ****")
class LinearRegressionModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.weight = nn.Parameter(data = torch.randn(1,
                                                  requires_grad = True,
                                                  dtype = torch.float))
    self.bias = nn.Parameter(data = torch.randn(1,
                                                requires_grad = True,
                                                dtype = torch.float))
    
  def forward(self, x):
    return self.weight * x + self.bias

torch.manual_seed(42)
model_1 = LinearRegressionModel()
print(model_1)
print(model_1.state_dict())
print("\n")

# 初始化模型并放到目标机里
print("*** Instantiate the model ***")
print(list(model_1.parameters()))
print("\\n")

# 创建一个loss函数并优化
print("*** Create and Loss function and optimizer ***")
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params = model_1.parameters(),
                            lr = 0.01)
print(f"loss_fn: {loss_fn}")
print(f"optimizer: {optimizer}\n")

# 训练
print("*** Training Loop ***")
torch.manual_seed(42)
epochs = 300
for epoch in range(epochs):
  # 将模型加载到训练模型里
  model_1.train()

  # 做 Forward
  y_pred = model_1(X_train)

  # 计算 Loss
  loss = loss_fn(y_pred, y_train)

  # 零梯度
  optimizer.zero_grad()

  # 反向传播
  loss.backward()

  # 步骤优化
  optimizer.step()

  ### 做测试
  if epoch % 20 == 0:
    # 将模型放到评估模型并设置上下文
    model_1.eval()
    with torch.inference_mode():
      # 做 Forward
      y_preds = model_1(X_test)
      # 计算测试 loss
      test_loss = loss_fn(y_preds, y_test)
      # 输出测试结果
      print(f"Epoch: {epoch} | Train loss: {loss:.3f} | Test loss: {test_loss:.3f}")

# 在测试集上对训练模型做预测
print("\n")
print("*** Make predictions with the trained model on the test data. ***")
model_1.eval()
with torch.inference_mode():
  y_preds = model_1(X_test)
print(f"y_preds:\n {y_preds}")
## 画图
plot_predictions(predictions = y_preds) 

# 保存训练好的模型
print("\n")
print("*** Save the trained model ***")
from pathlib import Path 
## 创建模型的文件夹
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)
## 创建模型的位置
MODEL_NAME = "trained model"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME 
## 保存模型到刚创建好的文件夹
print(f"Saving model to {MODEL_SAVE_PATH}")
torch.save(obj = model_1.state_dict(), f = MODEL_SAVE_PATH)
## 创建模型的新类型
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(f = MODEL_SAVE_PATH))
## 做预测,并跟之前的做预测
y_preds_new = loaded_model(X_test)
print(y_preds == y_preds_new)

结果如下

**** Create Data ****
Number of X samples: 100
Number of y samples: 100
First 10 X & y sample: 
 X: tensor([[0.0000],
        [0.0100],
        [0.0200],
        [0.0300],
        [0.0400],
        [0.0500],
        [0.0600],
        [0.0700],
        [0.0800],
        [0.0900]])
 y: tensor([[0.9000],
        [0.9030],
        [0.9060],
        [0.9090],
        [0.9120],
        [0.9150],
        [0.9180],
        [0.9210],
        [0.9240],
        [0.9270]])

**** Splitting data ****
The length of X train: 80
The length of y train: 80
The length of X test: 20
The length of y test: 20

**** Create PyTorch linear regression model by subclassing nn.Module ****
LinearRegressionModel()
OrderedDict([('weight', tensor([0.3367])), ('bias', tensor([0.1288]))])


*** Instantiate the model ***
[Parameter containing:
tensor([0.3367], requires_grad=True), Parameter containing:
tensor([0.1288], requires_grad=True)]

*** Create and Loss function and optimizer ***
loss_fn: L1Loss()
optimizer: SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.01
    maximize: False
    momentum: 0
    nesterov: False
    weight_decay: 0
)

*** Training Loop ***
Epoch: 0 | Train loss: 0.757 | Test loss: 0.725
Epoch: 20 | Train loss: 0.525 | Test loss: 0.454
Epoch: 40 | Train loss: 0.294 | Test loss: 0.183
Epoch: 60 | Train loss: 0.077 | Test loss: 0.073
Epoch: 80 | Train loss: 0.053 | Test loss: 0.116
Epoch: 100 | Train loss: 0.046 | Test loss: 0.105
Epoch: 120 | Train loss: 0.039 | Test loss: 0.089
Epoch: 140 | Train loss: 0.032 | Test loss: 0.074
Epoch: 160 | Train loss: 0.025 | Test loss: 0.058
Epoch: 180 | Train loss: 0.018 | Test loss: 0.042
Epoch: 200 | Train loss: 0.011 | Test loss: 0.026
Epoch: 220 | Train loss: 0.004 | Test loss: 0.009
Epoch: 240 | Train loss: 0.004 | Test loss: 0.006
Epoch: 260 | Train loss: 0.004 | Test loss: 0.006
Epoch: 280 | Train loss: 0.004 | Test loss: 0.006


*** Make predictions wit the trained model on the test data. ***
y_preds:
 tensor([[1.1464],
        [1.1495],
        [1.1525],
        [1.1556],
        [1.1587],
        [1.1617],
        [1.1648],
        [1.1679],
        [1.1709],
        [1.1740],
        [1.1771],
        [1.1801],
        [1.1832],
        [1.1863],
        [1.1893],
        [1.1924],
        [1.1955],
        [1.1985],
        [1.2016],
        [1.2047]])


*** Save the trained model ***
Saving model to models/trained model
tensor([[True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True],
        [True]])

第一个结果图
第二个结果图

点个赞支持一下咯~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/486331.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ZYNQ学习之Ubuntu环境下的Shell与APT下载工具

基本都是摘抄正点原子的文章&#xff1a;<领航者 ZYNQ 之嵌入式Linux 开发指南 V3.2.pdf&#xff0c;因初次学习&#xff0c;仅作学习摘录之用&#xff0c;有不懂之处后续会继续更新~ 一、Ubuntu Shell操作 简单的说Shell 就是敲命令。国内把 Linux 下通过命令行输入命令叫…

代码随想录算法训练营第三十二天 | 122.买卖股票的最佳时机II ,55. 跳跃游戏 , 45.跳跃游戏II

贪心&#xff1a;只要把每一个上升区间都吃到手&#xff0c;就能一直赚 class Solution { public:int maxProfit(vector<int>& prices) {int res 0;for(int i 1;i< prices.size();i){int diff prices[i] - prices[i-1];if(prices[i] > prices[i-1]){res d…

WSL使用

WSL使用 WSL安装和使用 Termianl和Ubuntu的安装 打开Hype-V虚拟化配置Microsoft Store中搜索Window Terminal并安装Microsoft Store中搜索Ubuntu, 选择安装Ubuntu 22.04.3 LTS版本打开Window Terminal选择Ubuntu标签栏, 进入命令行 中文输入法安装 查看是否安装了fcitx框架…

【官方】操作指南,附代码!银河麒麟服务器迁移运维管理平台V2.1中间件及高可用服务部署(4)

1.RocketMQ集群模式 主机配置示例&#xff1a; IP 角色 架构模式 对应配置文件 1.1.1.1 nameserver1 master broker-n0.conf 2.2.2.2 nameserver2 salve1 broker-n1.conf 3.3.3.3 nameserver3 salve2 broker-n2.conf 1.1.安装rocketmq 在服务器上安装rocket…

第14篇:2线-4线译码器

Q&#xff1a;有编码器那对应的就会有译码器&#xff0c;本期我们来设计实现2线-4线二进制译码器 。 A&#xff1a;基本原理&#xff1a;译码器是编码器的逆过程&#xff0c;其功能是将具有特定含义的二进制码转换为对应的输出信号。2线-4线二进制译码器有2个输入共4种不同的组…

java目标和(力扣Leetcode106)

目标和 力扣原题 问题描述 给定一个正整数数组 nums 和一个整数 target&#xff0c;向数组中的每个整数前添加 ‘’ 或 ‘-’&#xff0c;然后串联起所有整数&#xff0c;可以构造一个表达式。返回可以通过上述方法构造的、运算结果等于 target 的不同表达式的数目。 示例 …

【MySQL】11. 复合查询(重点)

4. 子查询 子查询是指嵌入在其他sql语句中的select语句&#xff0c;也叫嵌套查询 4.1 单行子查询 返回一行记录的子查询 显示SMITH同一部门的员工 mysql> select * from emp where deptno (select deptno from emp where ename SMITH); -----------------------------…

小目标检测篇 | YOLOv8改进之添加BiFormer注意力机制

前言:Hello大家好,我是小哥谈。BiFormer是一种具有双层路由的动态稀疏注意力机制,它通过查询自适应的方式关注一小部分相关标记,从而提供了更灵活的计算分配和内容感知。它在多个计算机视觉任务中表现出了良好的性能和高计算效率。BiFormer注意力机制比较适合处理小尺度目标…

聚类算法之高斯混合模型聚类 (Gaussian Mixture Model, GMM)

注意&#xff1a;本文引用自专业人工智能社区Venus AI 更多AI知识请参考原站 &#xff08;[www.aideeplearning.cn]&#xff09; 高斯混合模型&#xff08;GMM&#xff09;是统计模型中的一颗璀璨之星&#xff0c;它为数据提供了一种复杂而又强大的表示方法。在机器学习的许多…

大数据基础:Linux基础详解

课程介绍 本课程主要通过对linux基础课程的详细讲解&#xff0c;让大家熟练虚拟机的安装使用&#xff0c;Linux系统的安装配置&#xff0c;学习掌握linux系统常用命令的使用&#xff0c;常用的软件安装方法&#xff0c;制作快照&#xff0c;克隆&#xff0c;完成免密登录&…

linux系统--------------mysql数据库管理

目录 一、SQL语句 1.1SQL语言分类 1.2查看数据库信息 1.3登录到你想登录的库 1.4查看数据库中的表信息 1.5显示数据表的结构&#xff08;字段&#xff09; 1.5.1数据表的结构 1.5.2常用的数据类型: 二、关系型数据库的四种语言 2.1DDL&#xff1a;数据定义语言&am…

跨域与Spring Boot中CORS的应用

摘要&#xff1a;前后端独立开发期间&#xff0c;交互主要通过接口文档&#xff0c;前端Mock数据&#xff0c;后端使用Postman都不会发现跨域问题。当联调时前端尝试调用后端接口&#xff0c;这往往就需要需要处理的跨域问题…… 下面总结下跨域问题产生的前因后果以及如何通过…

day03_mysql_课后练习 - 参考答案

文章目录 day03_mysql_课后练习mysql练习题第1题第2题第3题第4题第5题 day03_mysql_课后练习 mysql练习题 第1题 案例&#xff1a; 1、创建一个数据库&#xff1a;day03_test01_school 2、创建如下表格 表1 Department表的定义 字段名字段描述数据类型主键外键非空唯一D…

Java学习笔记 | JavaSE基础语法 | 04 | 数组

文章目录 0.前言1.数组2.数组声明2.1 数组定义2.2 数组初始化1.静态初始化2.动态初始化3.区别4.数组的默认初始化值&#xff1a; 2.3 数组名 3.访问数组3.1 索引3.2 访问数组3.3 length属性 4.数组常见问题5.数组内存分析5.1 内存分配5.2 数组内存分配 6.数组的练习练习1&#…

用Springboot(java程序)访问Salesforce RestAPI

本文讲一下&#xff0c;如何从0构建一个Springboot的应用程序&#xff0c;并且和Salesforce系统集成&#xff0c;取得Salesforce里面的数据。 一、先在Salesforce上构建一个ConnectApp。 有了这个&#xff0c;SF才允许你和它集成。手顺如下&#xff1a; 保存后&#xff0c;…

华为ensp中vrrp虚拟路由器冗余协议 原理及配置命令

CSDN 成就一亿技术人&#xff01; 作者主页&#xff1a;点击&#xff01; ENSP专栏&#xff1a;点击&#xff01; CSDN 成就一亿技术人&#xff01; ————前言————— VRRP&#xff08;Virtual Router Redundancy Protocol&#xff0c;虚拟路由器冗余协议&#xff0…

使用 CSS 预处理器的优缺点

使用CSS预处理器在前端开发中已经成为一种流行的趋势&#xff0c;它们提供了一种更灵活、更高效的方式来编写和管理样式表。然而&#xff0c;就像任何工具一样&#xff0c;CSS预处理器也有其优点和缺点。本文将深入探讨使用CSS预处理器的优缺点&#xff0c;并讨论如何在项目中明…

Luminar Neo:让每一张照片都散发独特魅力 mac/win版

Luminar Neo是一款引领摄影艺术新纪元的智能影像处理软件。它融合了先进的算法和人工智能技术&#xff0c;为摄影师提供了前所未有的创作自由度和影像处理能力。 Luminar Neo软件获取 作为一款强大的后期处理工具&#xff0c;Luminar Neo不仅具备丰富的调整选项和滤镜效果&…

MES管理系统生产调度模块的工作原理是什么

在现代制造业中&#xff0c;MES管理系统发挥着举足轻重的作用&#xff0c;其中的生产调度模块更是整个生产流程的核心。它集成了自动排产和手动排产的功能&#xff0c;能够精确安排每个工单在各个工序的具体生产线体、计划开始时间和计划结束时间&#xff0c;从而确保生产的高效…

一分钟学习Markdown语法

title: 一分钟学习Markdown语法 date: 2024/3/24 19:33:29 updated: 2024/3/24 19:33:29 tags: MD语法文本样式列表结构链接插入图片展示练习实践链接问题 欢迎来到Markdown语法的世界&#xff01;Markdown是一种简单而直观的标记语言&#xff0c;让文本排版变得轻松有趣。接下…