Pytorch网络模型训练

现有网络模型的使用与修改

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别

print(vgg16_true)

以下是vgg16预训练模型的输出 

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

预训练模型的输出从1000类别转为10类别

import torchvision
from torch import nn
# 因为数据集过大,所以注释掉此行代码
# train_data = torchvision.datasets.ImageNet("./data_image_net", split='train', download=True,
#                                            transform=torchvision.transforms.ToTensor())

vgg16_false = torchvision.models.vgg16(pretrained=False)        # 加载一个未预训练的模型
vgg16_true = torchvision.models.vgg16(pretrained=True)
# 把数据分为了1000个类别

print(vgg16_true)

# vgg16_true.add_module("add_linear", nn.Linear(1000, 10))
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
# 在预训练模型的最后添加了一个新的全连接层,用于将最后的输出转化为10个类别
print(vgg16_true)

print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
# 未预训练模型的最后一层的输出特征数更改为了10
print(vgg16_false)

网络模型的保存与读取

加载未预训练的模型

vgg16 = torchvision.models.vgg16(pretrained=False)

方式一

# 保存方式1  保存的模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pyth")

#读取方式1
model = torch.load("vgg16_method1.pth")

方式二

# 保存方式2  不再保存模型结构,而是保存模型的参数为字典结构    推荐
torch.save(vgg16.state_dict(), "vgg16_method2.pyth")

# 方式2,加载模型
# model = torch.load("vgg16_method2.pth")     #这样输出的是字典类型
# print(model)
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))      # 将其恢复为网络模型
print(vgg16)

完整的模型训练套路

准备数据集

# 准备数据集
train_data = torchvision.datasets.CIFAR10("../data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                          download=True)

train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为{}".format(train_data_size))    # 50000
print("测试数据集的长度为{}".format(test_data_size))     # 10000

# 利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

创建网络模型

# 创建网络模型  神经网络的代码在train_module文件
tudui = Tudui()

train_module文件

# 搭建神经网络
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        # 简化操作,并且按顺序进行操作
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

构建损失函数

# 损失函数
loss_fn = nn.CrossEntropyLoss()

构建优化器

# 优化器
# 如果学习率过大,模型可能会在最小值附近震荡而无法收敛;如果学习率过小,模型训练可能会过于缓慢
learning_rate = 0.01
# 使用随机梯度下降算法来更新模型的权重
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

设置训练集参数

# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10

添加tensorboard

# 将数据写入 TensorBoard 可视化的日志文件中
writer = SummaryWriter("../logs_train")

训练步骤

# tudui.train()
for data in train_dataloader:
    imgs, targets = data
    outputs = tudui(imgs)
    loss = loss_fn(outputs, targets)

    # 优化器优化模型
    optimizer.zero_grad()
    # 将优化器中的梯度缓存(如果有的话)清零
    loss.backward()
    # 计算损失函数(loss)相对于模型参数的梯度
    optimizer.step()

    total_train_step = total_train_step + 1
    if total_train_step % 100 == 0:
        # .item()是将tensor张量变为正常的数字
        print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))
        # loss.item()是当前步骤的损失值
        writer.add_scalar("train_loss", loss.item(), total_train_step)
        # 使用add_scalar可以将一个标量添加到之前的所有标量值中,
        # 这样就可以在TensorBoard中绘制一个标量随时间变化的图表

测试步骤

# 测试步骤开始
# tudui.eval()
total_test_loss = 0
total_accuracy = 0
# 不会对以下的代码进行调优
with torch.no_grad():
    for data in test_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)
        total_test_loss = total_test_loss + loss.item()
        # argmax(1)是横向看,argmax(0)是纵向看
        accuracy = (outputs.argmax(1) == targets).sum()
        # argmax在找到模型预测的最大概率对应的类别
        # 预测正确的个数
        total_accuracy = total_accuracy + accuracy

print("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
# 测试集上的总损失
writer.add_scalar("test_loss", total_test_loss, total_test_step)
writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
total_test_step = total_test_step + 1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/116171.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

论文浅尝 | ChatKBQA:基于微调大语言模型的知识图谱问答框架

第一作者:罗浩然,北京邮电大学博士研究生,研究方向为知识图谱与大语言模型协同推理 OpenKG地址:http://openkg.cn/tool/bupt-chatkbqa GitHub地址:https://github.com/LHRLAB/ChatKBQA 论文链接:https://ar…

小程序如何设置用户同意服务协议并上传头像和昵称

为了保护用户权益和提供更好的用户体验,设置一些必填项和必读协议是非常必要的。首先,用户必须阅读服务协议。服务协议是明确规定用户和商家之间权益和义务的文件。通过要求用户在下单前必须同意协议,可以确保用户在使用服务之前了解并同意相…

Android studio新版本多渠道打包配置

最近公司套壳app比较多 功能也都一样只有地址,和app名字还有icon不一样 签名文件也是一样的,所以就研究了多渠道打包 配置如下: 在app下build.gradle配置 因为最新版as中禁用了BuildConfig 所以我们需要手动配置一下 android { //TODO 其他省略buildFe…

3 函数的升级-上

常量与宏回顾 C中的const常量可以替代常数定义,如: "Const int a 8; --> 等价于 #define a 8 " 宏在预编译阶段处理,而c const常量则在编译阶段处理,比宏 更为安全。 C中,我们可以用宏代码片段去实现某个函数&…

0006Java安卓程序设计-ssm基于Android的校园二手商品交易平台

文章目录 **摘** **要****目** **录**系统设计开发环境 编程技术交流、源码分享、模板分享、网课教程 🐧裙:776871563 摘 要 随着毕业季的来临以及当代大学生的消费力购买力的不断增强,我们的寝室中囤积了很多二手商品,有很多是…

正点原子嵌入式linux驱动开发——Linux CAN驱动

CAN是目前应用非常广泛的现场总线之一,主要应用于汽车电子和工业领域,尤其是汽车领域,汽车上大量的传感器与模块都是通过CAN总线连接起来的。CAN总线目前是自动化领域发展的热点技术之一,由于其高可靠性,CAN总线目前广…

阿里云安全恶意程序检测

阿里云安全恶意程序检测 赛题理解赛题介绍赛题说明数据说明评测指标 赛题分析数据特征解题思路 数据探索数据特征类型数据分布箱型图 变量取值分布缺失值异常值分析训练集的tid特征标签分布测试集数据探索同上 数据集联合分析file_id分析API分析 特征工程与基线模型构造特征与特…

【前端周报11.03】

前端周汇报11.03 那我们接着上一周的继续往下进行推进上周总结本周工作下周内容 那我们接着上一周的继续往下进行推进 上周总结 上一周的话我其实最主要的工作还是进行了一系列的调研主要的话是针对于我们未来要做的小程序的项目的,为未来开发这个小程序做好一系列…

leetcode:26. 删除有序数组中的重复项(python3解法)

难度:简单 给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。 考虑 nums 的唯一元素的数…

多目标跟踪算法 实时检测 - opencv 深度学习 机器视觉 计算机竞赛

文章目录 0 前言2 先上成果3 多目标跟踪的两种方法3.1 方法13.2 方法2 4 Tracking By Detecting的跟踪过程4.1 存在的问题4.2 基于轨迹预测的跟踪方式 5 训练代码6 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 深度学习多目标跟踪 …

【入门Flink】- 04Flink部署模式和运行模式【偏概念】

部署模式 在一些应用场景中,对于集群资源分配和占用的方式,可能会有特定的需求。Flink为各种场景提供了不同的部署模式,主要有以下三种:会话模式(Session Mode)、单作业模式(Per-Job Mode&…

Ubuntu20.04下安装Redis环境

apt安装Redis环境 更新apt-get安装镜像源 安装Redis sudo apt-get install -y redis-server设置密码 # 编辑Redis的配置文件redis.conf,如果不知道配置文件的位置可以执行whereis redis.conf查看 sudo vim /etc/redis/redis.conf取消文件中的requirepass注释&am…

设计模式(22)享元模式

一、介绍: 1、定义:享元模式(Flyweight Pattern)主要用于减少创建对象的数量,以减少内存占用和提高性能。这种类型的设计模式属于结构型模式,它提供了减少对象数量从而改善应用所需的对象结构的方式。 2、…

memcpy()之小端模式

函数原型 void memcpy(voiddestin, const void *src, size_t n); 功能 由src指向地址为起始地址的连续n个字节的数据复制到以destin指向地址为起始地址的空间内。 头文件 #include<string.h> 返回值 函数返回一个指向dest的指针。 例1&#xff1a;如果用来复制字…

FPGA高端项目:图像采集+GTP+UDP架构,高速接口以太网视频传输,提供2套工程源码加QT上位机源码和技术支持

目录 1、前言免责声明本项目特点 2、相关方案推荐我这里已有的 GT 高速接口解决方案我这里已有的以太网方案 3、设计思路框架设计框图视频源选择OV5640摄像头配置及采集动态彩条视频数据组包GTP 全网最细解读GTP 基本结构GTP 发送和接收处理流程GTP 的参考时钟GTP 发送接口GTP …

【计算机网络】运输层

概述运输层服务 运输层协议为运行在不同主机上的应用程序提供了逻辑通信功能。 运输层协议是在端系统中而不是在路由器中实现的。 运输层和网络层的关系&#xff1a; 网络层提供主机之间的逻辑通信&#xff0c;而运输层为**运行在不同主机上的应用程序&#xff08;进程&#…

做读书笔记时的一个高效小技巧

你好&#xff0c;我是 EarlGrey&#xff0c;一名双语学习者&#xff0c;会一点编程&#xff0c;目前已翻译出版《Python 无师自通》、《Python 并行编程手册》等书籍。 在这里&#xff0c;我会持续和大家分享好书、好工具和高效生活、工作技巧&#xff0c;欢迎大家一起提升认知…

【CesiumJS】(1)Hello world

介绍 Cesium 起源于2011年&#xff0c;初衷是航空软件公司(Analytical Graphics, Inc.)的一个团队要制作世界上最准确、性能最高且具有时间动态性的虚拟地球。取名"Cesium"是因为元素铯Cesium让原子钟非常准确&#xff08;1967年&#xff0c;人们依据铯原子的振动而对…

Android Studio打包AAR

注意 依赖的Android Studio版本为4.2.2 更高的Android Studio版本使用方法可能有所不同&#xff0c;gradle的版本和gradle plugins的版本都会影响使用方式。 基于此&#xff0c;本文只能作为参考&#xff0c;而不能作为唯一答案&#xff0c;如果要完全依赖本文&#xff0c;则…

GPT与人类共生:解析AI助手的兴起

随着GPT模型的崭新应用&#xff0c;如百度的​1​和CSDN的​2​&#xff0c;以及AI助手的普及&#xff0c;人们开始讨论AI对就业市场和互联网公司的潜在影响。本文将探讨GPT和AI助手的共生关系&#xff0c;以及我们如何使用它们&#xff0c;以及使用的平台和动机。 GPT和AI助手…