分类网络搭建示例

搭建CNN网络

本章我们来学习一下如何搭建网络,初始化方法,模型的保存,预训练模型的加载方法。本专栏需要搭建的是对分类性能的测试,所以这里我们只以VGG为例。

请注意,这里定义的只是一个简陋的版本,后续一些经典网络的学习,我们会在另外单独去开一个专栏讲解。

1. 网络搭建

在PyTorch中,你可以使用 torchvision.models 中的 vgg16 来加载预定义的VGG16模型,也可以手动定义。以下是手动定义的一个简化版本:

import torch
import torch.nn as nn

class VGG16(nn.Module):
    def __init__(self, num_classes=1000):
        super(VGG16, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

2. 初始化方法

在这里,我们不再手动初始化每一层,因为PyTorch的默认初始化通常足够好。你可以选择手动初始化,如果需要,可以使用 torch.nn.init 中的不同方法。

3. 模型的保存

使用 torch.save 保存VGG16模型:

vgg16 = VGG16()

torch.save(vgg16.state_dict(), 'vgg16_model.pth')

4. 预训练模型的加载

要加载预训练的VGG16模型,你可以使用 torchvision.models 中的 vgg16(pretrained=True),或者手动加载预训练权重:

vgg16 = VGG16()

vgg16.load_state_dict(torch.load('pretrained_vgg16.pth'))

请确保路径 'pretrained_vgg16.pth' 是你预训练模型文件的实际路径。你可以从PyTorch的官方模型库或其他来源下载预训练权重。

上面是最简单的一种模型全部加载的方式,但也有一些情况下,只是想加载其中一部分层的参数。剩下一部分由于已经改变参数了,无法加载预训练模型,所以要选择随机初始化。 、

这里我们来观察网络怎么去表示的:

if __name__ == "__main__":
    model = VGG16()
    for name, value in model.named_parameters():
        print(name)

下面就是控制台打印出的部分信息。 

这两行的输出就是打印网络层的名字,实际上加载预训练模型时,也是按照这个名字来加载的。

# 加载预训练 VGG16 模型的参数
pretrained_dict = torch.load('pretrained_vgg16.pth')

# 剔除预训练模型中全连接层的参数
pretrained_dict.pop('classifier.0.weight')
pretrained_dict.pop('classifier.0.bias')
pretrained_dict.pop('classifier.3.weight')
pretrained_dict.pop('classifier.3.bias')
pretrained_dict.pop('classifier.6.weight')
pretrained_dict.pop('classifier.6.bias')

# 获取自定义模型的参数字典
model_dict = model.state_dict()

# 更新自定义模型的参数字典,加载预训练模型的参数值
model_dict.update(pretrained_dict)

# 加载更新后的参数字典到自定义模型中
model.load_state_dict(model_dict)

自己定义的一些层是不会出现在pretrained_dict中,因此会将其剔除,从而只加载了 pretrained_dict中有的层。

总结

本章只是对网络的定义进行一个简单的示例,具体的部分我们会在另外一个专栏讲解,这里只是为了让读者了解网络定义的流程。在实际项目中,通常需要更详细的网络结构,包括适当的初始化方法、损失函数的选择、优化器的设置等。如果读者了解掌握了基本的网络定义过程,你可以在本专栏中深入讲解这些方面,以及如何训练和评估模型等内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/134891.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

什么是数据库事务、事务的ACID、怎么设置/禁止自动提交?

数据库事务及ACID 数据库事务是指作为单个逻辑工作单元执行的一组操作。这组操作要么全部成功地执行,要么全部不执行,不允许出现部分执行的情况。数据库事务通常需要满足ACID属性,即原子性(Atomicity)、一致性&#x…

第2关:还原键盘输入(list)

题目&#xff1a; 知识点&#xff1a; 列表list相较于数组&#xff1a; 优势&#xff1a;可在任意指定位置插入或者删除元素而不影响列表其他地方 。 劣势&#xff1a;无法直接进行下标索引&#xff0c;需要迭代器it逐个遍历。 代码&#xff1a; #include <iostream>…

企业级信息化系统 ERP、OA、CRM、EAM、WMS、MES、PM

微服务架构&#xff0c;前端采用微应用架构&#xff0c;可做到不同服务使用不同数据库独立运行。全平台采用基于模型驱动的设计模式&#xff0c;并在前后端留有大量的代码植入入口&#xff0c;方便开发者对平台进行改造扩充。企业信息中心开发ERP、OA、CRM、EAM、WMS、MES、PM等…

R系组播调优方案

修改/etc/sysctl.conf添加如下内容&#xff1a; Vim /etc/sysctl.con net.ipv4.ip_forward1 net.ipv4.ip_nonlocal_bind1 net.ipv4.conf.all.rp_filter0 net.ipv4.conf.default.rp_filter0 net.bridge.bridge-nf-call-arptables 0 net.bridge.bridge-nf-call-ip6tables 0 …

【踩坑】Putty报错: Can’t agree a key change algorithm

原因可能是putty版本太老了&#xff0c;更新putty就好了 下载地址&#xff1a;https://www.chiark.greenend.org.uk/~sgtatham/putty/latest.html 根据需要选择自己想要下载的版本&#xff0c;我是下载的如下图所示的版本。 另外&#xff0c;了解了一下Putty是用来远程连接…

用excel计算一个矩阵的逆矩阵

假设我们的原矩阵是一个3*3的矩阵&#xff1a; 125346789 我们现在要求该矩阵的逆矩阵&#xff1a; 鼠标点到其它空白的地方&#xff0c;用来存放计算结果&#xff1a; 插入-》函数&#xff1a; 选择MINVERSE函数&#xff0c;这个就是求逆矩阵的函数&#xff1a; 点击“继续…

怎么改变容易紧张的性格?

容易紧张的性格是比较通俗的说法&#xff0c;在艾森克人格测试中&#xff0c;容易紧张的性格就属于神经症人格&#xff0c;神经质不是神-经-病&#xff0c;而是一种人格特征&#xff0c;这种特征包括&#xff1a;敏感&#xff0c;情绪不稳定&#xff0c;易焦虑和紧张。有兴趣的…

中国电子学会2023年09月份青少年软件编程Python等级考试试卷六级真题(含答案)

2023-09 Python六级真题 分数&#xff1a;100 题数&#xff1a;38 测试时长&#xff1a;60min 一、单选题(共25题&#xff0c;共50分) 1. 以下选项中&#xff0c;不是tkinter变量类型的是&#xff1f;&#xff08;D &#xff09;(2分) A.IntVar() B.StringVar() C.Do…

Redis缓存穿透、击穿和雪崩

文章目录 前言一、缓存穿透&#xff08;查不到&#xff09;1.概念2.解决方案布隆过滤器缓存空对象 二、缓存击穿&#xff08;量太大&#xff0c;缓存过期&#xff01;&#xff09;1.概述2.解决方案1.设置热点数据永不过期2.加互斥锁 三、缓存雪崩1.概念2.解决方案1.redis高可用…

Xilinx DDR3 MIG系列——ddr3控制器的时钟架构

本节目录 一、ddr3控制器的时钟架构 1、PLL输入时钟——系统时钟system_clk 2、PLL输出时钟——sync_pulse、mem_refclk、freq_refclk、MMCM1的输入时钟 3、MMCM1的输入时钟和输出时钟 4、MMCM2的输入时钟和输出时钟一、ddr3控制器的时钟架构 对于FPGA开发来说,调用IP或者移植…

PHP开源自动化平台CRUD代码生成器

生成CRUD&#xff08;创建、读取、更新、删除&#xff09;代码的实现方式有很多种&#xff0c; 一、实现方式 1. 定义数据模型&#xff1a;首先需要定义数据模型&#xff0c;包括表结构、字段以及数据类型等。 2. 自动生成数据库表&#xff1a;根据数据模型&#xff0c;使用数…

利用爬虫采集外卖数据进行竞争对手分析

目录 一、引言 二、准备工作 三、爬取数据 四、数据处理与存储 五、竞争对手分析 六、结论与展望 一、引言 在当今的数字化时代&#xff0c;数据已经成为企业成功的关键因素之一。对于餐饮外卖行业来说&#xff0c;数据的收集和分析尤为重要。通过对竞争对手的数据进行采…

【LeetCode刷题笔记】滑动窗口

992. K 个不同整数的子数组 解题思路: 滑动窗口 , 题目问题转化为: 求 「最多存在 K 个不同整数的子数组的个数」 与 「最多存在 K - 1 个不同整数的子数组的个数」 之差, 就是题目所求的 「恰好存在 K 个不同整数的子数组的个数」 , 最终问题就变成求解滑动窗口内,以 R …

webpack工作原理

目录 合并代码模块化webpack 的打包webpack 的结构webpack 的源码addEntry 和 _addModuleChainbuildModuleCompilation 的钩子产出构建结果 了解 webpack 实现原理&#xff0c;掌握 webpack 基础的工作流程&#xff0c;在平时使用 webpack 遇见问题时&#xff0c;能够帮助我们洞…

2015年计网408

第33题 通过 POP3 协议接收邮件时, 使用的传输层服务类型是( ) A. 无连接不可靠的数据传输服务 B. 无连接可靠的数据传输服务 C. 有连接不可靠的数据传输服务 D. 有连接可靠的数据传输服务 本题考察邮件接收协议POP3使用的运输层服务类型。 如图所示。接收方用户代理使用pop…

Typora-PicGo-七牛云图床

Typora-PicGo-七牛云图床 问题描述&#xff1a; 每次使用Typora写完笔记后&#xff0c;想要将笔记上传至CSDN会发现一个问题&#xff0c;由于没有配置图床&#xff0c;笔记中的图片需要一张一张的上传到CSDN&#xff0c;非常麻烦&#xff0c;若使用PicGo并搭配七牛云的10G免费…

Spring Security使用总结五,加密用户密码,不再使用明文保存密码

上一章我们成功的注册了一个新用户&#xff0c;按照正常逻辑来说&#xff0c;这一章应该是登录了&#xff0c;但是我们也看到了&#xff0c;这数据库保存的居然是明文密码&#xff0c;这谁受得了&#xff0c;这要是用户信息泄露了&#xff0c;这不让人一锅端了啊&#xff0c;还…

Java编程--单例模式(饿汉模式/懒汉模式)/阻塞队列

前言 逆水行舟&#xff0c;不进则退&#xff01;&#xff01;&#xff01; 目录 单例模式 饿汉模式&#xff1a; 懒汉模式&#xff1a; 什么是阻塞队列 什么是高内聚 低耦合 阻塞队列的实现 单例模式 单例模式&#xff08;Singleton Pattern&#xff09;是一种常见…

代码随想录算法训练营第四十八天丨 动态规划part11

123.买卖股票的最佳时机III 思路 这道题目相对 121.买卖股票的最佳时机 (opens new window)和 122.买卖股票的最佳时机II (opens new window)难了不少。 关键在于至多买卖两次&#xff0c;这意味着可以买卖一次&#xff0c;可以买卖两次&#xff0c;也可以不买卖。 接来下我…

(SpringBoot)第五章:SpringBoot创建和使用

文章目录 一&#xff1a;Spring和SpringBoot&#xff08;1&#xff09;Spring已解决和未解决的问题&#xff08;2&#xff09;SpringBoot 二&#xff1a;Spring项目的创建&#xff08;1&#xff09;IDEA创建&#xff08;2&#xff09;网页端创建 三&#xff1a;项目目录介绍及运…