深度学习_18_模型的下载与读取

在深度学习的过程中,需要将训练好的模型运用到我们要使用的另一个程序中,这就需要模型的下载与转移操作

代码:

import math
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt

# 生成随机的数据集
max_degree = 20  # 多项式的最大阶数
n_train, n_test = 100, 100  # 训练和测试数据集大小
true_w = torch.zeros(max_degree)
true_w[0:4] = torch.Tensor([5, 1.2, -3.4, 5.6])

# 生成特征
features = torch.randn((n_train + n_test, 1))
permutation_indices = torch.randperm(features.size(0))
# 使用随机排列的索引来打乱features张量(原地修改)
features = features[permutation_indices]
poly_features = torch.pow(features, torch.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)

# 生成标签
labels = torch.matmul(poly_features, true_w)
labels += torch.normal(0, 0.1, size=labels.shape)


# 以下是你原来的训练函数,没有修改
def evaluate_loss(net, data_iter, loss):
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]


def l2_penalty(w):
    w = w[0].weight
    return torch.sum(w.pow(2)) / 2


def train(train_features, test_features, train_labels, test_labels, lambd,
          num_epochs=100):
    loss = d2l.squared_loss
    input_shape = train_features.shape[-1]
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型
    batch_size = min(10, train_labels.shape[0])

    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)

    # 用于存储训练和测试损失的列表
    train_losses = []
    test_losses = []
    total_loss = 0
    total_samples = 0
    for epoch in range(num_epochs):
        for X, y in train_iter:
            out = net(X)
            y = y.reshape(-1, 1)  # 确保y是二维的
            l = loss(out, y) + lambd * l2_penalty(net)

            # 反向传播和优化器更新
            l.sum().backward()
            d2l.sgd(net.parameters(), lr=0.01, batch_size= batch_size)
            total_loss += l.sum().item()  # 统计所有元素损失
            total_samples += y.numel()  # 统计个数
        a = total_loss / total_samples  # 本次训练的平均损失
        train_losses.append(a)
        test_loss = evaluate_loss(net, test_iter, loss)
        test_losses.append(test_loss)
        total_loss = 0
        total_samples = 0
        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"训练损失: {a:.4f}   测试损失: {test_loss:.4f} ")
    print(net[0].weight)

    torch.save(net.state_dict(), "NetSave")  # 存模型
    net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    print("net_try")
    print(net_try[0].weight)
    net_try.load_state_dict(torch.load("NetSave"))
    net_try.eval()  # 评估模式
    print("net_try_load")
    print(net_try[0].weight)
    # 绘制损失曲线
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='train', color='blue', linestyle='-', marker='.')
    plt.plot(test_losses, label='test', color='purple', linestyle='--', marker='.')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.title('Loss over Epochs')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 1)  # 设置y轴的范围从0.01到100
    plt.show()


# 选择多项式特征中的前4个维度
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:], 0)

##  net.parameters() 是一个 PyTorch 模型的方法,用于返回模型所有参数的迭代器。这个迭代器产生模型中所有可学习的参数(例如权重和偏置)。

上述代码的模型是简单线性模型

net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 模型

此模型的下载与储存如下

    torch.save(net.state_dict(), "NetSave")  # 存模型
    net_try = nn.Sequential(nn.Linear(input_shape, 1, bias=False))  # 搭建模型框架
    print("net_try")
    print(net_try[0].weight)
    net_try.load_state_dict(torch.load("NetSave"))  # 下载模型
    net_try.eval()  # 评估模式
    print("net_try_load")
    print(net_try[0].weight)

效果
在这里插入图片描述

所以说要想在另一个程序中将训练好的模型加载到上面去,首先要保存训练好的模型,另一个程序必须有和本模型一样的框架,再将训练好的模型权重加载到另一个程序框架内即可

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/434540.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

服务器为什么会卡顿,出现卡顿情况怎么办

从维护服务器的角度来看,服务器卡顿是一种常见的问题,但服务器卡顿可能会影响到网站、游戏或平台的正常访问和运行,所以出现卡顿问题首先需要对服务器进行全面的检查,确定卡顿原因,然后选取适合的解决方案,…

基于Spring Boot的秒杀系统(附项目源码+论文)

摘要 社会发展日新月异,用计算机应用实现数据管理功能已经算是很完善的了,但是随着移动互联网的到来,处理信息不再受制于地理位置的限制,处理信息及时高效,备受人们的喜爱。本次开发一套基于Spring Boot的秒杀系统&am…

网络编程作业day6

数据库操作的增、删、改完成 #include <myhead.h>//查询的回调函数 int callback(void* data,int count,char** argv, char** columnName) {//count是字段数//argv是字段内容//columnName是字段名称for(int i0;i<count;i) {printf("%s%s\n", columnName[…

智能驾驶规划控制理论学习06-基于优化的规划方法

目录 一、优化概念 1、一般优化问题 2、全局最优和局部最优 二、无约束优化 1、无约束优化概述 2、梯度方法 通用框架 线性搜索 回溯搜索 3、梯度下降 基本思想 实现流程 ​4、牛顿法 基本思想 实现流程 5、高斯牛顿法 6、LM法&#xff08;Le…

通过hyperbeam创建梁单元截面属性

1、为模型中标准的圆柱形创建梁单元和赋予属性&#xff1b; 2、为模型中不标准的对称性实体创建梁单元和赋予属性&#xff1b; 3、为模型中壳体部分创建梁单元和赋予属性&#xff1b;

上位机图像处理和嵌入式模块部署(qmacvisual三个特色)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 了解了qmacvisual的配置之后&#xff0c;正常来说&#xff0c;我们需要了解下不同插件的功能是什么。不过我们不用着急&#xff0c;可以继续学习下…

分布式事务(SeataClient)

问题场景 元数据 库存 100订单记录为空下单操作 @AutowiredRestTemplate restTemplate;/*** 下单** @return*/@Transactional // 开启事务 异常后触发数据库回滚操作@Overridepublic Order create(Order order) {// 插入订单orderMapper.insert(order);// 扣减库存 MultiValu…

Python 弱引用全解析:深入探讨对象引用机制!

目录 前言 弱引用的概述 弱引用的原理 使用 WeakRef 类创建弱引用 使用 WeakValueDictionary 类创建弱引用字典 实际应用场景 1. 解决循环引用问题 2. 对象缓存 总结 前言 在Python编程中&#xff0c;弱引用&#xff08;Weak Reference&#xff09;是一种特殊的引用方式…

折线图 温度变化曲线图

代码详情介绍 导入必要的库&#xff1a; matplotlib.pyplot&#xff1a;用于绘图。 matplotlib.font_manager&#xff1a;用于设置中文字体。 datetime&#xff1a;用于处理日期和时间。 random&#xff1a;用于生成随机数。 numpy&#xff1a;用于生成arange函数的刻度。 设置…

【kubernetes】关于k8s集群如何将pod调度到指定node节点?

目录 一、k8s的watch机制 二、scheduler的调度策略 Predicate&#xff08;预选策略&#xff09; 常见算法&#xff1a; priorities&#xff08;优选策略&#xff09;常见的算法有&#xff1a; 三、k8s的标签管理之增删改查 四、k8s的将pod调度到指定node的方法 方案一&am…

RK356X RK3588 单独编译kernel 与烧录

RK356X RK3588 单独编译kernel 与烧录 可以快速提高我们开发与调试速度 网上可查到的方法如下&#xff1a; RK3568 Android12&#xff1a; 1.添加kernel-4.19/makekernel.sh #!/bin/sh make -j24 ARCHarm64 CC../prebuilts/clang/host/linux-x86/clang-r416183b/bin/clang …

EasyRecovery易恢复2024免激活安装包下载

EasyRecovery易恢复是一款功能强大的数据恢复软件。这款软件由全球著名数据厂商Kroll Ontrack出品&#xff0c;可以恢复被删除的文件、文件夹&#xff0c;以及被格式化的磁盘等数据。无论是硬盘、U盘、SD卡还是其他移动设备&#xff0c;EasyRecovery易恢复都能通过其专业的数据…

全连接神经网络算法原理(激活函数、前向传播、梯度下降法、损失函数、反向传播)

文章目录 前言1、全连接神经网络的整体结构&#xff1a;全连接神经网络模型是由输入层、隐藏层、输出层所组成&#xff0c;全连接神经网络结构如下图所示&#xff1a;全连接神经网络的每一层都是由一个一个的神经元所组成的&#xff0c;因此只要搞清楚神经元的本质就可以搞清楚…

MetaQTL:元分析基础教程

MetaQTL 基础知识 在遥远的海洋中&#xff0c;每个岛屿都藏着无尽的宝藏&#xff0c;而探险家们争相寻找地图&#xff0c;以期揭开宝藏的秘密。 现实世界中&#xff0c;我们的基因组就像那片广阔的海洋&#xff0c;而隐藏在其中的宝藏就是控制我们身高、健康、甚至是我们性格的…

MM配置2-给公司代码分配工厂

配置步骤&#xff0c;如下图&#xff1a;在弹出的对话框中将工厂分配给相应的公司代码 保存完成

UDP通信发送和接收 || UDP实现全双工通信

recvfrom ssize_t recvfrom(int sockfd, void *buf, size_t len, int flags, struct sockaddr *src_addr, socklen_t *addrlen); 功能: 从套接字中接收数据 参数: sockfd:套接字文件描述符 buf:存放数据空间首地址 …

java的运算符

整形和浮点型相比&#xff0c;浮点型的范围更大&#xff0c;所以在Java中正常条件下都是整形隐式转换为浮点型(任意整形都可以隐式转换为double或者float)&#xff0c;浮点型不能隐式转换为整形。 1.算术运算符 1. 基本四则运算符&#xff1a;加减乘除模( - * / %) 加减乘都…

mfc110u.dll丢失的解决方法,5招搞定mfc110u.dll丢失问题

mfc110u.dll是一个动态链接库文件&#xff0c;它是Microsoft Foundation Class&#xff08;MFC&#xff09;库的一部分。MFC是微软公司为Visual C开发人员提供的一个类库&#xff0c;用于简化Windows应用程序的开发过程。mfc110u.dll文件包含了MFC库中的一些功能和类&#xff0…

口碑营销:品牌如何维护良好口碑?

企业的品牌传播最有效的方式莫过用户的口碑&#xff0c;互联网的发展为企业的品牌传播引入了驱动力&#xff0c;愈来愈多的企业花费更多的资源开展网络口碑的建设和维护&#xff0c;那么企业如何维护好网络口碑&#xff1f; 1、持续传递优质的品牌内容 内容是营销推广的支撑点&…

MySQL进阶之(四)InnoDB数据存储结构之行格式

四、InnoDB数据存储结构之行格式 4.1 行格式的语法4.2 COMPACT 行格式4.2.1 记录的额外信息01、变长字段长度列表02、NULL 值列表03、记录头信息 4.2.2 记录的真实数据 4.3 Dynamic 和 Compressed 行格式4.3.1 字段的长度限制4.3.2 行溢出4.3.3 Dynamic 和 Compressed 行格式 4…