修改网络的结构用于预训练

目录

一、模型准备

二、修改结构

1、在网络中添加一层

2、在classifier结点添加一个线性层

3、修改网络中的某一层(features 结点举例)

4、替换网络中的某一层结构(与第3点类似)

5、提取全连接层的输入特征数和输出特征数

6、删除网络层

7、指定某一层冻结

8、批量冻结(只训练最后一层)

9、查看哪些层冻结哪些没有

10、加载网络权重

三、自定义数据集加载


一、模型准备

import torch
import torch.nn as nn
from torchvision import models

model = models.vgg11(pretrained=False)

查看模型结构:

print(model)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): ReLU(inplace=True)
    (20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

二、修改结构

1、在网络中添加一层

model.features.add_module('last_layer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))
print(model)

2、在classifier结点添加一个线性层

model.classifier.add_module('Linear', nn.Linear(1000, 10))
print(model)

3、修改网络中的某一层(features 结点举例)

model.features[8] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
print(model)

4、替换网络中的某一层结构(与第3点类似)

直接提取这个结构,重新赋值一个结构即可

import torch
import torch.nn as nn
from torchvision import models

# 创建resnet50网络
model = models.resnet50(num_classes=1000)
print(model)

fc_in_features = model.fc.in_features
model.fc = torch.nn.Linear(fc_in_features, 2, bias=True)

5、提取全连接层的输入特征数和输出特征数

查看最后一层全连接的输入数量:

import torch
import torch.nn as nn
from torchvision import models

model = models.vgg11(pretrained=False)
print(model)

print(model.classifier[6].in_features)

其他的结构也类似,主要是写出结构里的参数即可。例如:全连接的in_features和out_features

6、删除网络层

使用一个空结构替换即可,即nn.Sequential()

model.classifier[6] = nn.Sequential()
print(model)

或者用切片来删除(删除后4层):

model.features = nn.Sequential(*list(model.features.children())[:-4])

net.classifier 对应 net.classifier.children()

net.features 对应 net.features.children()

7、指定某一层冻结

# 冻结指定层的预训练参数
model.classifier[3].weight.requires_grad = False

8、批量冻结(只训练最后一层)

import torch
import torch.nn as nn
from torchvision import models

model = models.vgg11(pretrained=False)
print(model)

for param in model.parameters():
    param.requires_grad_(False)
fc_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(fc_in_features, 2, bias=True)

9、查看哪些层冻结哪些没有

for i in model.parameters():
    if i.requires_grad:
        print(i)

10、加载网络权重


model.load_state_dict(torch.load('model.pth')) # 加载网络参数

三、自定义数据集加载

from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    def __init__(self, xxxx):
        super(MyDataset, self).__init__()
        pass
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        pass
        return image, classes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/754976.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springboot + Vue前后端项目(第二十一记)

项目实战第二十一记 写在前面1. springboot文件默认传输限制2. 安装视频插件包命令3. 前台Video.vue4. 创建视频播放组件videoDetail.vue5. 路由6. 效果图总结写在最后 写在前面 本篇主要讲解系统集成视频播放插件 1. springboot文件默认传输限制 在application.yml文件中添…

《昇思25天学习打卡营第2天|快速入门》

文章目录 前言:今日所学:1. 数据集处理2. 网络的构建3. 模型训练4. 保存模型5. 加载模型 总体代码与运行结果:1. 总体代码2. 运行结果 前言: 今天是学习打卡的第2天,今天的内容是对MindSpore的一个快速入门&#xff0…

Selenium IDE 的使用指南

Selenium IDE 的使用指南 在自动化测试的领域中,Selenium 是一个广为人知且强大的工具集。而 Selenium IDE 作为其中的一个组件,为测试人员提供了一种便捷且直观的方式来创建和执行自动化测试脚本。 一、Selenium IDE 简介 Selenium IDE 是一个用于录…

第十三章 常用类

一、包装类 1. 包装类的分类 (1)针对八种基本数据类型相应的引用类型—包装类 (2)有了类的特点,就可以调用类中的方法。 2. 包装类和基本数据的转换 jdk5 前的手动装箱和拆箱方式,装箱:基本…

【Qt】信号和槽机制

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; &#x1f525;c系列专栏&#xff1a;C/C零基础到精通 &#x1f525; 给大…

操作系统之《PV操作》【知识点+详细解题过程】

1、并发进程 &#xff1a; 并发的实质是一个处理器在几个进程之间的多路复用&#xff0c;并发是对有限的物理资源强制行使多用户共享&#xff0c;消除计算机部件之间的互等现象&#xff0c;以提高系统资源利用率。 &#xff08;1&#xff09;并发进程——互斥性&#xff1a; 进…

使用Jetpack Compose实现具有多选功能的图片网格

使用Jetpack Compose实现具有多选功能的图片网格 在现代应用中,多选功能是一项常见且重要的需求。例如,Google Photos允许用户轻松选择多个照片进行分享、添加到相册或删除。在本文中,我们将展示如何使用Jetpack Compose实现类似的多选行为,最终效果如下: 主要步骤 实现…

【redis】Redis AOF

1、AOF的基本概念 AOF持久化方式是通过保存Redis所执行的写命令来记录数据库状态的。AOF以日志的形式来记录每个写操作&#xff08;增量保存&#xff09;&#xff0c;将Redis执行过的所有写指令记录下来&#xff08;读操作不记录&#xff09;。AOF文件是一个只追加的文件&…

Redis 高级数据结构业务实践

0、前言 本文所有代码可见 > 【gitee code demo】 本文会涉及 hyperloglog 、GEO、bitmap、布隆过滤器的介绍和业务实践 1、HyperLogLog 1.1、功能 基数统计&#xff08;去重&#xff09; 1.2、redis api 命令作用案例PFADD key element [element ...]添加元素到keyPF…

PortSip测试

安装PBX 下载 免费下载 PortSIP PBX 安装PBX&#xff0c;安装后&#xff0c;运行 &#xff0c;默认用户是admin 密码是admin&#xff0c;然后配置IP 为192.168.0.189 设置域名为192.168.0.189 配置分机 添加分机&#xff0c;添加了10001、10002、9999 三个分机&#xff0c…

深度学习实验第T2周:彩色图片分类

>- **&#x1f368; 本文为[&#x1f517;365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客** >- **&#x1f356; 原作者&#xff1a;[K同学啊](https://mtyjkh.blog.csdn.net/)** 目录 一、前言 目标 二、我的环境&#…

【Linux进程通信】进程间通信介绍、匿名管道原理分析

目录 进程通信是什么&#xff1f; 进程通信的目的 进程通信的本质 匿名管道&#xff1a;基于文件级别的通信方式 站在文件描述符角度-深度理解管道原理 进程通信是什么&#xff1f; 进程通信就是两个或多个进程之间进行数据层面的交互。 进程通信的目的 1.数据传输&#x…

已解决java.security.acl.LastOwnerException:无法移除最后一个所有者的正确解决方法,亲测有效!!!

已解决java.security.acl.LastOwnerException&#xff1a;无法移除最后一个所有者的正确解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 目录 问题分析 出现问题的场景 报错原因 解决思路 解决方法 1. 检查当前所有者数量 2. 添加新的所有者 3. 维…

mac Canon打印机连接教程

官网下载安装驱动&#xff1a; 选择打印机类型和mac系统型号下载即可 Mac PS 打印机驱动程序 双击安装 系统偏好设置 点击“”添加&#xff1a; OK可打印玩耍&#xff01;&#xff01; 备注&#xff1a; 若需扫描&#xff0c;下载扫描程序&#xff1a; 备注&#xff1a;…

设置小蓝熊的CPU亲和性、CPU优先级再设置法环的CPU亲和性

# 适用于Windows系统 # 时间 : 2024-06-28 # 作者 : 三巧(https://blog.csdn.net/qq_39124701) # 文件名 : 设置小蓝熊的CPU亲和性、CPU优先级再设置法环的CPU亲和性.ps1 # 使用方法: 打开记事本&#xff0c;将所有代码复制到记事本中&#xff0c;保存文件时候修改文件后…

Hugging Face Accelerate 两个后端的故事:FSDP 与 DeepSpeed

社区中有两个流行的零冗余优化器 (Zero Redundancy Optimizer&#xff0c;ZeRO)算法实现&#xff0c;一个来自DeepSpeed&#xff0c;另一个来自PyTorch。Hugging FaceAccelerate对这两者都进行了集成并通过接口暴露出来&#xff0c;以供最终用户在训练/微调模型时自主选择其中之…

zabbix-server的搭建

zabbix-server的搭建 部署 zabbix 服务端(192.168.99.180) rpm -ivh https://mirrors.aliyun.com/zabbix/zabbix/5.0/rhel/7/x86_64/zabbix-release-5.0-1.el7.noarch.rpm cd /etc/yum.repos.d sed -i s#http://repo.zabbix.com#https://mirrors.aliyun.com/zabbix# zabbix.r…

关于FPGA对 DDR4 (MT40A256M16)的读写控制 4

关于FPGA对 DDR4 &#xff08;MT40A256M16&#xff09;的读写控制 4 语言 &#xff1a;Verilg HDL 、VHDL EDA工具&#xff1a;ISE、Vivado、Quartus II 关于FPGA对 DDR4 &#xff08;MT40A256M16&#xff09;的读写控制 4一、引言二、DDR4 SDRAM设备中模式寄存器重要的模式寄存…

Arduino - LED 矩阵

Arduino - LED 矩阵 Arduino - LED Matrix LED matrix display, also known as LED display, or dot matrix display, are wide-used. In this tutorial, we are going to learn: LED矩阵显示器&#xff0c;也称为LED显示器&#xff0c;或点阵显示器&#xff0c;应用广泛。在…

“Hello, World!“ 历史由来

布莱恩W.克尼汉&#xff08;Brian W. Kernighan&#xff09;—— Unix 和 C 语言背后的巨人 布莱恩W.克尼汉在 1942 年出生在加拿大多伦多&#xff0c;他在普林斯顿大学取得了电气工程的博士学位&#xff0c;2000 年之后取得普林斯顿大学计算机科学的教授教职。 1973 年&#…