创建和探索VGG16模型

        PyTorch在torchvision库中提供了一组训练好的模型。这些模型大多数接受一个称为 pretrained 的参数,当这个参数为True 时,它会下载为ImageNet 分类问题调整好的权重。让我们看一下创建 VGG16模型的代码片段:

from torchvision import models
vgg = models.vggl6(pretrained=True)

        现在有了所有权重已经预训练好且可马上使用的VGG16模型。当代码第一次运行时,可能需要几分钟,这取决于网络速度。权重的大小可能在500MB左右。我们可以通过打印快速查看下 VGG16模型。当使用现代架构时,理解这些网络的实现方式非常有用。我们来看看这个模型:

VGG(
    (features): Sequential(
        (0):Conv2d(3,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (1):ReLU (inplace)
        (2):Conv2d(64,64,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (3):ReLU(inplace)
        (4):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))
        (5):Conv2d(64,128,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (6):ReLU(inplace)
        (7):Conv2d(128,128,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (8):ReLU(inplace)
        (9):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))
        (10):Conv2d(128,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (11):ReLU(inplace)
        (12):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (13):ReLU(inplace)
        (14):Conv2d(256,256,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (15):ReLU(inplace)
        (16):MaxPool2d(size=(2,2),stride=(2,2)dilation=(1,1))
        (17):Conv2d(256,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (18):ReLU(inplace)
        (19):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (20):ReLU(inplace)
        (21):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (22):ReLU(inplace)
        (23):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))
        (24):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (25):ReLU(inplace)
        (26):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (27):ReLU(inplace)
        (28):Conv2d(512,512,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        (29):ReLU(inplace)
        (30):MaxPool2d(size=(2,2),stride=(2,2),dilation=(1,1))
    )
    (classifier):Sequential(
        (0):Linear(25088>4096)
        (1):ReLU(inplace)
        (2):Dropout(p=0.5)
        (3):Linear(4096->4096)
        (4):ReLU (inplace)
        (5):Dropout(p=0.5)
        (6):Linear(4096>1000)
    )
)

        模型摘要包含了两个序列模型:features和classifiers。features和sequentia1模型包含了将要冻结的层。

冻结层

        下面冻结包含卷积块的features模型的所有层。冻结层中的权重将阻止更新这些卷积块的权重。由于模型的权重被训练用来识别许多重要的特征,因而我们的算法从第一个迭代开时就具有了这样的能力。使用最初为不同用例训练的模型权重的能力,被称为迁移学习。现在看一下如何冻结层的权重或参数:

for param in vgg.features.parameters():param.requires_grad = False

        该代码阻止优化器更新权重。

微调VGG16模型

        VGG16模型被训练为针对1000个类别进行分类,但没有训练为针对狗和猫进行分类。因此,需要将最后一层的输出特征从1000改为2。以下代码片段执行此操作:

vgg.classifier[6].out_features = 2

        vgg.classifier可以访问序列模型中的所有层,第6个元素将包含最后一个层。当训练VGG16模型时,只需要训练分类器参数。因此,我们只将classifier.parameters传入优化器,如下所示:

optimizer=
optim.SGD(vgg.classifier.parameters(),lr=0.0001,momentum=0.5)

训练VGG16模型

        我们已经创建了模型和优化器。由于使用的是Dogs vs. Cats数据集,因此可以使用相同的数据加载器和train函数来训练模型。请记住,当训练模型时,只有分类器内的参数会发生变化。下面的代码片段对模型进行了20轮的训练,在验证集上达到了98.45%的准确率:

train_losses, train_accuracy =[],[]
val_losses, val_accuracy =[],[]
for epoch in range(l,20):
    epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')
    val_epoch_loss,val_epoch_accuracy=
        fit(epoch,vgg,valid_data_loader,phase='validation')
    train_losses.append(epoch_loss)
    train_accuracy.append(epoch_accuracy)
    val_losses.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)

        将训练和验证的损失可视化,如图5.19所示。

        将训练和验证的准确率可视化,如图5.20所示:

        我们可以应用一些技巧,例如数据增强和使用不同的dropout值来改进模型的泛化能力。以下代码片段将 VGG分类器模块中的dropout值从0.5更改为0.2并训练模型:

for layer in vgg.classifier.children():
    if(type(layer)== nn.Dropout):
        layer.p=0.2
#训练
train_losses,train_accuracy = [][]
val_losses, val accuracy =[],[ ]
for epoch in range(1,3):
    epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')
    val_epoch_loss,val_epoch_accuracy=
        fit(epoch,vgg,valid_data_loader,phase='validation')
    train_losses.append(epoch_loss)
    train_accuracy.append(epoch_accuracy)
    val_losses.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)

        通过几轮的训练,模型得到了些许改进。还可以尝试使用不同的dropout值。改进模型泛化能力的另一个重要技巧是添加更多数据或进行数据增强。我们将通过随机地水平翻转图像或以小角度旋转图像来进行数据增强。torchvision转换为数据增强提供了不同的功能,它们可以动态地进行,每轮都发生变化。我们使用以下代码实现数据增强:

train transform =transforms.Compose([transforms,Resize((224,224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
        ])
train = ImageFolder('dogsandcats/train/',train_transform)
valid = ImageFolder('dogsandcats/valid/',simple_transform)
#训练
train_losses,train_accuracy=[][]
val_losses,val_accuracy = [],[]
for epoch in range(1,3):
    epoch_loss,epoch_accuracy=fit(epoch,vgg,train_data_loader,phase='training')
    val_epoch_loss,val_epoch_accuracy=
        fit(epoch,vgg,valid_data_loader,phase='validation')
    train_losses.append(epoch_loss)
    train_accuracy.append(epoch_accuracy)
    val_losses.append(val_epoch_loss)
    val_accuracy.append(val_epoch_accuracy)

        前面的代码输出如下:

#结果
training loss is 0.041 and training accuracy is 22657/23000 98.51
validation loss is 0.043 and validation accuracy is 1969/2000 98.45
training loss is 0.04 and training accuracy is 22697/23000 98.68 
validation loss is 0.043 and validation accuracy is 1970/2000 98.5

        使用增强数据训练模型仅运行两轮就将模型准确率提高了0.1%;可以再运行几轮以进一步改进模型。如果大家在阅读本书时一直在训练这些模型,将意识到每轮的训练可能需要几分钟,具体取决于运行的GPU。让我们看一下可以在几秒钟内训练一轮的技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/740433.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

猫头虎分享已解决Bug:Array Index Out of Bounds Exception

🐯 猫头虎分享已解决Bug:Array Index Out of Bounds Exception 🐯 摘要 大家好,我是猫头虎,今天我们要聊聊后端开发中经常遇到的一个问题:Array Index Out of Bounds Exception,即 java.lang.…

计算机系统基础知识(上)

目录 计算机系统的概述 计算机的硬件 处理器 存储器 总线 接口 外部设备 计算机的软件 操作系统 数据库 文件系统 计算机系统的概述 如图所示计算机系统分为软件和硬件:硬件包括:输入输出设备、存储器,处理器 软件则包括系统软件和…

北邮《计算机网络》MAC子层笔记

文章目录 缩写复习MAC层所在层次动态分配信道算法们的简要介绍信道的五条基本假设多路访问的协议(理论上的协议)aloha协议CSMA协议其他冲突避免协议无线局域网协议 ,MACA 以太网协议802.3(实际协议,刚刚是理论&#xf…

C++ 内存分配可视化

GitHub - archibate/mallocvis: allocation visualization in svg graph 正常连续内存分配 #include <vector>int main() {// 堆mallocstd::vector<int> memory;for (int i 0; i < 1000; i) {memory.emplace_back(i*10);}return 0; } 主动内存分配释放 #in…

计算机组成原理 | CPU子系统(1)基本概述

基本结构模型 运算与缓存部件 数据寄存部件 PSW不是很清楚 存储器是什么&#xff1f;属于那个结构里&#xff1f; 时序处理部件 cpu是大脑&#xff0c;控制器是神经元 ①通过硬件产生控制信号 ②通过软件产生控制信号 外频&#xff08;系统时钟信号&#xff09;&#xff0c;…

Modbus协议在工业自动化中的应用

Modbus协议介绍 Modbus是一种常用的工业现场总线通信协议,被广泛应用于工业自动化领域。它是一种简单、易实现的主从式通信协议,具有高度的可靠性和通用性。本文将从Modbus协议的基本概念、通信模式、数据格式、常见应用场景等方面进行全面介绍,并通过图文并茂的方式帮助读者更…

Linux运行jar包:Invalid or corrupt jarfile

你们好&#xff0c;我是金金金。 场景 maven打包springboot项目得到一个jar包&#xff0c;我通过xshell上传到虚拟机环境里面&#xff0c;试图运行它&#xff0c;结果Invalid or corrupt jarfile&#xff1a;jar 文件无效或损坏 排查 jdk版本是否一致&#xff1f;结果&#xf…

C++精解【6】

文章目录 eigenMatrix基础例编译时固定尺寸运行指定大小 OpenCV概述 eigen Matrix 基础 所有矩阵和向量都是Matrix模板类的对象。向量也是矩阵&#xff0c;单行或单列。Matrix模板类6个参数&#xff0c;常用就3个参数&#xff0c;其它3个参数有默认值。 Matrix<typename…

【uniapp】uniapp开发微信小程序入门教程

HBuilderx中uniapp开发微信小程序入门教程 一、 环境搭建 1. HBuilderx下载安装 HBuilderx下载安装地址 2. 微信开发者工具下载安装 微信开发者工地址具下载安装 二、创建uniapp项目 选择&#xff1a;文件>新建>项目>uni-app 输入项目名称>选择默认模板>…

2024 CISCN 华东北分区赛-Ahisec

Ahisec战队 WEB python-1 break 源码如下&#xff1a; # -*- coding: UTF-8 -*-from flask import Flask, request,render_template,render_template_stringapp Flask(__name__)def blacklist(name):blacklists ["print","cat","flag",&q…

通过高德api查询所有店铺地址信息

通过高德api查询所有店铺地址电话信息 需求&#xff1a;通过高德api查询所有店铺地址信息需求分析具体实现1、申请高德appkey2、下载types city 字典值3、具体代码调用 需求&#xff1a;通过高德api查询所有店铺地址信息 需求分析 查询现有高德api发现现有接口关键字搜索API服…

ai智能写作一键生成的软件盘点,4款宝藏!

在信息爆炸的时代&#xff0c;内容创作已成为各行各业的刚需。然而&#xff0c;对于许多创作者来说&#xff0c;如何高效、高质量地输出内容却是一个不小的挑战。幸运的是&#xff0c;随着人工智能技术的飞速发展&#xff0c;AI智能写作软件应运而生&#xff0c;它们凭借一键生…

在Vue表单中设置缺省值

有个需求&#xff0c;在新增记录的时候&#xff0c;打开新增页面&#xff0c;员工姓名处获取到当前登录用户的用户名&#xff0c;并将其设置为缺省值。 /** 新增按钮操作 */handleAdd() {this.reset();this.open true;// this.form.employeeName this.$store.state.user.name…

【Spine学习15】变换约束

变换约束&#xff1a;能让一个骨骼受另一个骨骼的变化影响。 1、选择m创建一个变换约束&#xff1a; 2、点击这个约束&#xff0c; 将移动数值拉的越满&#xff0c;m越接近s骨骼 当约束为0也就是默认的时候&#xff0c;m骨骼将不会受影响&#xff0c;变换约束可有可无。 tips…

基于Pytorch框架构建AlexNet模型

Pytorch 一、判断环境1.导入必要的库2.判断环境 二、定义字典1.定义字典 三、处理图像数据集1.导入必要的模块2.定义变量3.删除隐藏文件/文件夹 四、加载数据集1.加载训练数据集2.加载测试数据集3.定义训练数据集和测试集路径4.加载训练集和测试集5.创建训练集和测试集数据加载…

Java基础:IO流

目录 一、定义 1.引言 2.分类 &#xff08;1&#xff09;按照流的方向分 &#xff08;2&#xff09;按操作文件的类型分 3.体系结构 二、字节流&#xff08;以操作本地文件为例&#xff09; 1. FileOutputStream 类 &#xff08;1&#xff09;定义 &#xff08;2&am…

每日一题——Python代码实现PAT甲级1059 Prime Factors(举一反三+思想解读+逐步优化)五千字好文

一个认为一切根源都是“自己不够强”的INTJ 个人主页&#xff1a;用哲学编程-CSDN博客专栏&#xff1a;每日一题——举一反三Python编程学习Python内置函数 Python-3.12.0文档解读 目录 我的写法 代码点评 时间复杂度分析 空间复杂度分析 改进建议 我要更强 时间复杂度…

大学生综合能力测评系统(安装+讲解+源码)

【毕设者】大学生综合能力测评系统(安装讲解源码) 分为管理员老师学生端 技术栈 后端: SpringBoot Mysql MybatisPlus 前端: Vue Element 功能截图: 给你安装运行

从WebM到MP3:利用Python和wxPython提取音乐的魔法

前言 有没有遇到过这样的问题&#xff1a;你有一个包含多首歌曲的WebM视频文件&#xff0c;但你只想提取其中的每一首歌曲&#xff0c;并将它们保存为单独的MP3文件&#xff1f;这听起来可能有些复杂&#xff0c;但借助Python和几个强大的库&#xff0c;这个任务变得异常简单。…

开源的网络瑞士军刀「GitHub 热点速览」

上周的开源热搜项目可谓是精彩纷呈&#xff0c;主打的就一个方便快捷、开箱即用&#xff01;这款无需安装、点开就用的网络瑞士军刀 CyberChef&#xff0c;试用后你就会感叹它的功能齐全和干净的界面。不喜欢 GitHub 的英文界面&#xff1f;GitHub 网站汉化插件 github-chinese…