神经网络入门实战:(二十二)只训练 (多层网络的) 指定层 / (单层网络的) 指定参数

只训练 (多层网络的) 指定层 / (单层网络的) 指定参数

在训练的时候,有时候并不需要将网络层每次都从头训练,对于只训练指定层的情况,可以通过设置该层参数的 requires_grad =True ,其余层通过设置 requires_grad = False 来冻结(不更新权重)。

【注意】:此时就不要使用 self.model = nn.Sequential(...) 函数将所有层都放在一起,变成一个整体了。

1)多层网络

在定义模型时,在 __init__ 函数后, forward 函数前,加上下方这段代码:

# 初始化时先冻结所有层
for name, param in self.named_parameters():
    param.requires_grad = False

# 然后单独设置指定层的 requires_grad 为 True
for param in self.layer_name.parameters(): # layer_name 就是要单独训练的层的自定义名字
    param.requires_grad = True

随后在创建优化器时,只输入指定层的参数。【如果想要训练所有层的话,优化器的输入一般是 model_Instance.parameters()

optimizer = torch.optim.SGD(model_instance.layer_name.parameters(), lr=0.01)

具体示例:

CIFAR10 配套网络模型为例(只训练第三个卷积层):

import torch
import torch.nn as nn
import torch.optim as optim

class CIFAR10_NET(nn.Module):
    def __init__(self):
        super(CIFAR10_NET, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 5, padding=2)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 32, 5, padding=2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, 5, padding=2)  # 这是我们要训练的层
        self.pool3 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1024, 64)  # 注意:这里的1024是基于输入图像大小和前面的卷积+池化层计算得出的
        self.fc2 = nn.Linear(64, 10)

        # 初始化时先冻结所有层
        for name, param in self.named_parameters():
            param.requires_grad = False

        # 然后单独设置第三个卷积层的requires_grad为True
        for param in self.conv3.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3(x)))  # 第三个卷积层
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化模型
model_instance = CIFAR10_NET()

# 创建优化器,只包含第三个卷积层的参数
optimizer = optim.SGD(model_instance.conv3.parameters(), lr=0.01)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 示例输入和目标(这里需要真实的CIFAR-10数据,但这里只是示例)
inputs = torch.randn(1, 3, 32, 32)  # 假设输入图像大小为32x32
targets = torch.tensor([1])  # 假设目标类别为1

# 训练循环(简化版)
model.train()
for epoch in range(10):  # 假设训练10个epoch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

2)从官网或本地加载预训练模型,并修改或添加某个层,并且只训练该层

由于这种情况,代码里不会再完整的定义一遍网络模型,所以无法直接在模型定义时选择该指定哪一层训练。

此时需要先导入预训练模型,然后加入下方代码进行选择:

for name, param in model_name.named_parameters(): # model_name为模型名
    # layer_name为指定层名
    if 'layer_name' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False  # 冻结其他所有参数

然后再指定优化器的输入参数即可:

model_instance = model_name() # 实例化模型
optimizer = optim.SGD(model_instance.layer_name.parameters(), lr=0.01)

具体示例:

仍以 CIFAR10 配套网络模型为例(将最后的 nn.Linear(64, 10) 层改为 nn.Linear(64, 20) )。并且原先的网络模型已经训练好了并保存在本地了。

import torch
import torch.nn as nn
import torch.optim as optim

CIFAR10_NET_new = torch.load("E:\\5_NN_model\\CIFAR10_NET.pth") # 原先将整个预训练模型保存在了本地
CIFAR10_NET_new.fc2 = nn.Linear(64, 20) # 修改最后一个全连接层,其名字为: fc2 

for name, param in CIFAR10_NET_new.named_parameters():
    # 如果参数属于修改后的全连接层,则设置为可训练
    if 'fc2' in name:
        param.requires_grad = True
    else:
        param.requires_grad = False  # 冻结其他所有参数

model_instance = CIFAR10_NET_new() # 实例化模型
        
# 创建优化器,只包含修改后的全连接层的参数
optimizer = optim.SGD(model_instance.fc2.parameters(), lr=0.01)

# 损失函数
criterion = nn.CrossEntropyLoss()

# 示例输入和目标(这里需要真实的CIFAR-10数据,但这里只是示例)
inputs = torch.randn(1, 3, 32, 32)  # 假设输入图像大小为32x32
targets = torch.tensor([1])  # 假设目标类别为1

# 训练循环(简化版)
model.train()
for epoch in range(10):  # 假设训练10个epoch
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

2)单层网络

此单层网络需要是使用张量类型的。

如果只训练某个参数,那么令其 requires_grad 属性为 True ,其他的参数该属性设置为 False 即可。

另外在训练过程中,只使用 w2.data -= learning_rate * w2.grad.data 来更新 w2 的值即可。

具体示例:

以线性模型 y = w1*x^2+w2*x+b 为例(只训练 w2 )。

import torch

x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 初始化权重
w1 = torch.tensor([0.5], requires_grad=False)  # 不需要训练
w2 = torch.tensor([0.0], requires_grad=True)  # 需要训练,初始化为0.0以更清楚地看到训练效果
b = torch.tensor([0.5], requires_grad=False)  # 不需要训练

# 定义前向传播函数
def forward(x):
    return x * x * w1 + x * w2 + b

# 定义损失函数(均方误差)
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# 训练过程
learning_rate = 0.01
for epoch in range(100):
    total_loss = 0.0
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()  # 反向传播,计算梯度

        # 只更新w2
        w2.data -= learning_rate * w2.grad.data
        w2.grad.zero_()# 清零梯度,为下一次迭代做准备

        total_loss += l.item()

    average_loss = total_loss / len(x_data)
    print(f"Epoch {epoch+1}, Average Loss: {average_loss}")

上一篇下一篇
神经网络入门实战(二十一)待发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/944276.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

WebRTC服务质量(10)- Pacer机制(02) RoundRobinPacketQueue

WebRTC服务质量(01)- Qos概述 WebRTC服务质量(02)- RTP协议 WebRTC服务质量(03)- RTCP协议 WebRTC服务质量(04)- 重传机制(01) RTX NACK概述 WebRTC服务质量(…

硬件设计-时钟振荡器

目录 摘要 壳式晶振 正常工作条件 摘要 本章主要介绍了晶振的分类、各项参数的意义、特点,同时也介绍了时钟抖动的成因、测量 方法、消除措施和典型滤波电路,使得我们可以正确地选择和使用晶振。 壳式晶振 如图 所示,壳式晶振的名字来源于…

Redis基础知识分享(含5种数据类型介绍+增删改查操作)

一、redis基本介绍 1.redis的启动 服务端启动 pythonubuntu:~$ redis-server客户端启动 pythonubuntu:~$ redis-cli <127.0.0.1:6379> exit pythonubuntu:~$ redis-cli --raw //(支持中文的启动方式) <127.0.0.1:6379> exit2.redis基本操作 ping发送给服务器…

sql字段值转字段

表alertlabel中记录变字段 如何用alertlabel表得到下面数据 实现的sql语句 select a.AlertID, (select Value from alertlabel where AlertIDa.AlertID and Labelhost) as host, (select Value from alertlabel where AlertIDa.AlertID and Labeljob) as job from (select …

llamafactory报错:双卡4090GPU,训练qwen2.5:7B、14B时报错GPU显存不足(out of memory),轻松搞定~~~

实际问题场景&#xff1a; 使用llamafactory进行微调qwen2.5 7B和14B的大模型时&#xff0c;会出现out of memory的报错。尝试使用降低batch_size&#xff08;原本是2&#xff0c;现在降到1&#xff09;的方式&#xff0c;可以让qwen2.5:7B跑起来&#xff0c;但时不时会不稳定…

【hackmyvm】hacked靶机wp

tags: HMVrootkitDiamorphine Type: wp 1. 基本信息^toc 文章目录 1. 基本信息^toc2. 信息收集2.1. 端口扫描2.2. 目录扫描2.3. 获取参数 3. 提权 靶机链接 https://hackmyvm.eu/machines/machine.php?vmHacked 作者 sml 难度 ⭐️⭐️⭐️⭐️️ 2. 信息收集 2.1. 端口扫描…

.NET平台用C#通过字节流动态操作Excel文件

在.NET开发中&#xff0c;通过字节流动态操作Excel文件提供了一种高效且灵活的方式处理数据。这种方法允许开发者直接在内存中创建、修改和保存Excel文档&#xff0c;无需依赖直接的文件储存、读取操作&#xff0c;从而提高了程序的性能和安全性。使用流技术处理Excel不仅简化了…

应用层1——C/S、P2P、DNS域名系统

目录 一、网络应用模型 1、C/S 2、p2p模型 二、域名解析系统DNS 1、为什么有DNS系统&#xff1f; 2、域名的特点 3、DNS域名系统原理 4、递归查询、迭代查询 5、常用的根域名与顶级域名 一、网络应用模型 1、C/S 客户/服务器模型 客户请求服务&#xff0c;服务器提供…

【疑难杂症】 HarmonyOS NEXT中Axios库的响应拦截器无法拦截424状态码怎么办?

今天在开发一个HarmonyOS NEXT的应用的时候&#xff0c;发现http接口如果返回的状态码是424时&#xff0c;我在axios中定义的拦截器失效了。直接走到了业务调用的catch中。 问题表现&#xff1a; 我的拦截器代码如下&#xff1a; 解决办法&#xff1a; 先说解决办法&#xff…

在Windows上读写Linux磁盘镜像的一种方法

背景 嵌入式开发中&#xff0c;经常会把系统的Linux磁盘镜像保存到Windows上&#xff0c;以便上传到网盘备份或发送给工厂&#xff0c;但是如果想读取/修改镜像中的某个文件&#xff0c;一般有2种方案&#xff1a; 直接访问 就是用虚拟磁盘软件将镜像文件挂载成磁盘&#xf…

ffmpeg之显示一个yuv照片

显示YUV图片的步骤 1.初始化SDL库 目的&#xff1a;确保SDL库正确初始化&#xff0c;以便可以使用其窗口、渲染和事件处理功能。操作&#xff1a;调用 SDL_Init(SDL_INIT_VIDEO) 来初始化SDL的视频子系统。 2.创建窗口用于显示YUV图像&#xff1a; 目的&#xff1a;创建一个…

Windows下播放文件作为麦克风声源的一种方式

近期测试一种外语的ASR识别成功率&#xff0c;样本素材是懂这门语言的同事录制的mp3文件。测试client端原本是从麦克风拾音生成媒体流的。 这样&#xff0c;就需要想办法把mp3文件转换为测试client的输入声音。物理方式上&#xff0c;可以用一根音频线&#xff0c;把电…

如何在网页端使用 IDE 高效地阅读 GitHub 源码?

如何在网页端使用 IDE 高效地阅读 GitHub 源码&#xff1f; 前言什么是 GitHub1s&#xff1f;使用 GitHub1s 阅读 browser-use 项目源码步骤 1: 打开 GitHub 项目页面步骤 2: 修改 URL 使用 GitHub1s步骤 3: 浏览文件结构步骤 4: 使用代码高亮和智能补全功能步骤 5: 快速跳转和…

Microsoft word@【标题样式】应用不生效(主要表现为在导航窗格不显示)

背景 随笔。Microsoft word 2013基础使用&#xff0c;仅做参考和积累。 问题 Microsoft word 2013&#xff0c;对段落标题文字应用【标题样式】不生效&#xff08;主要表现为在导航窗格不显示&#xff09;。 图1 图2 观察图1和图2&#xff0c;发现图1的文字在应用【标题一】样…

2021.12.28基于UDP同信的相关流程

作业 1、将TCP的CS模型再敲一遍 服务器 #include <myhead.h> #define PORT 8888 #define IP "192.168.124.123" int main(int argc, const char *argv[]) {//创建套接字//绑定本机IP和端口号//监听客户端请求//接收客户端连接请求//收发消息//创建套接字int…

OpenCV和PyQt的应用

1.创建一个 PyQt 应用程序&#xff0c;该应用程序能够&#xff1a; 使用 OpenCV 加载一张图像。在 PyQt 的窗口中显示这张图像。提供四个按钮&#xff08;QPushButton&#xff09;&#xff1a; 一个用于将图像转换为灰度图一个用于将图像恢复为原始彩色图一个用于将图像进行翻…

kibana启动报错:Invalid character in header content [“kbn-name“]

启动时候kibana报错&#xff1a; 打开 kibana配置文件&#xff0c;config/kibana.yml&#xff0c;配置上server.name即可&#xff0c;如下&#xff1a;

Pandas08

Pandas01 Pandas02 Pandas03 Pandas04 Pandas05 Pandas06 Pandas07 文章目录 内容回顾同期群分析1.1 同期群分析概念1.2 案例代码 数据分析报告数据分析工作内容数据分析简历说明用户生命周期标签1 什么是生命周期标签2 如何计算生命周期标签 内容回顾 TGI 偏好分析 TGI 目标…

网页数据的解析提取之Beautiful Soup

前面博客介绍了正则表达式的相关用法&#xff0c;只是一旦正则表达式写得有问题&#xff0c;得到的结果就可能不是我们想要的了。而且每一个网页都有一定的特殊结构和层级关系&#xff0c;很多节点都用id或 class 作区分所以借助它们的结构和属性来提取不也可以吗? 本篇博客我…

电脑缺失sxs.dll文件要怎么解决?

一、文件丢失问题&#xff1a;以sxs.dll文件缺失为例 当你在运行某个程序时&#xff0c;如果系统提示“找不到sxs.dll文件”&#xff0c;这意味着你的系统中缺少了一个名为sxs.dll的动态链接库文件。sxs.dll文件通常与Microsoft的.NET Framework相关&#xff0c;是许多应用程序…