PyTorch如何修改模型(魔改)

文章目录

  • PyTorch如何修改模型(魔改)
    • 1.修改模型层(模型框架⭐)
      • 1.1通过继承修改模型
      • 1.2通过组合修改模型(重点学👀)
      • 1.3通过猴子补丁修改模型
    • 2.添加外部输入
    • 3.添加额外输出
    • 参考

PyTorch如何修改模型(魔改)

对模型缝缝补补、修修改改,是我们必须要掌握的技能,本文详细介绍了如何修改PyTorch模型?也就是我们经常说的如何魔改。👍

PyTorch 的模型是一个 torch.nn.Module 的某个子类的对象,修改模型实际就等价于修改某个类,对面向对象熟悉的同学应该知道,对类做修改有两个经典的方法:组合继承

1.修改模型层(模型框架⭐)

1.1通过继承修改模型

首先创建自己需要的模型类,然后其父类指向需要被修改的模型,这时自己的模型则具有完备的父类行为,最后在子类中实现魔改的逻辑。其大致的框架代码如下所示:

from torchvision.models import ResNet

class CustomizedResNet(ResNet):

    def __init__(self):
        super().__init__()
        ...
        
    def forward(self, x):
        ...

下面这个例子,将对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet
import torch

# 定义一个自定义的ResNet类,继承自torchvision的ResNet类
class CustomizedResNet(ResNet):
    def __init__(self, block, layers, num_classes=2):
        """
        初始化函数
        block: ResNet中的基本块类型,可以是BasicBlock或Bottleneck
        layers: 每个层级的基本块数量,是一个列表
        num_classes: 输出的类别数量,默认为2
        """
        # 调用父类的初始化方法
        super().__init__(block, layers, num_classes)
        # 重新定义全连接层,改变输出的特征数量
        self.fc = torch.nn.Linear(int(512 * block.expansion * 1.875), num_classes)

    def forward(self, x):
        # 以下是ResNet的前向传播过程
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # 通过四个残差层
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        # 将四个残差层的输出进行拼接
        x = torch.cat(
            [self.avgpool(x1),
             self.avgpool(x2),
             self.avgpool(x3),
             self.avgpool(x4),], dim=1)

        # 将拼接后的张量展平
        x = torch.flatten(x, 1)
        # 通过全连接层,得到最终的输出
        x = self.fc(x)

        return x

# 创建不同版本的ResNet模型
new_resnet34 = CustomizedResNet(BasicBlock, [3, 4, 6, 3], num_classes=1)
new_resnet50 = CustomizedResNet(Bottleneck, [3, 4, 6, 3], num_classes=1)
new_resnet101 = CustomizedResNet(Bottleneck, [3, 4, 23, 3], num_classes=1)
new_resnet200 = CustomizedResNet(Bottleneck, [3, 24, 36, 3], num_classes=1)

1.2通过组合修改模型(重点学👀)

在面向对象编程中,可能听说过「组合优于继承」,在模型修改的场景中其实也是这样,大多数情况下我们可能都适用组合而非继承。

首先依然需要创建模型的类,但这个类不再继承自魔改的类,而是直接继承 PyTorch 的模型基类 torch.nn.Module,然后将需要魔改的类作为类变量融入到模型中,下面是大致的框架代码:

from torchvision.models import resnet18
import torch.nn as nn

class CustomizedResNet(nn.Module):

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        ...

    def forward(self, x):
        ...

my_resnet18 = CustomizedResNet(resnet18)

同样,实现对 ResNet 进行魔改,把 ResNet 的 4 个 stage 输出的特征连接起来,然后通过一个全连接层后输出一个标量。

from torchvision.models import resnet50

class CustomizedResNet(torch.nn.Module):
    def __init__(self, backbone, num_classes=2):
        super().__init__()
        self.backbone = backbone
        self.fc = torch.nn.Linear(3840, num_classes)

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x1 = self.backbone.layer1(x)
        x2 = self.backbone.layer2(x1)
        x3 = self.backbone.layer3(x2)
        x4 = self.backbone.layer4(x3)

        x = torch.cat(
            [
                self.backbone.avgpool(x1),
                self.backbone.avgpool(x2),
                self.backbone.avgpool(x3),
                self.backbone.avgpool(x4),
            ],
            dim=1,
        )

        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

new_resnet50 = CustomizedResNet(resnet50())

1.3通过猴子补丁修改模型

最简单粗暴的方法:猴子补丁(Monkey Patch)。之所以叫猴子补丁,是因为这种方法从程序设计的角度上来说,是具有破坏性的。而且这种方法仅能实现一些简单的修改需求,所以还是推荐使用继承或组合去修改我们的模型。😉

猴子补丁修改模型非常简单粗暴,直接使用需要修改的模型创建对象,然后直接对对象的属性做出修改。下面是把 ResNet34 的输出从 1000 改为 1 的简单例子:

from torchvision.models import resnet50
import torch.nn as nn

model = resnet50()
model.fc = nn.Linear(2048, 1)

还有一个例子,以 PyTorch 官方视觉库 torchvision 预定义好的模型 ResNet50 为例,修改模型的某一层或者某几层。先观察一下它的网络结构:

import torch
import torch.nn as nn
from collections import OrderedDict
import torchvision.models as models
net = models.resnet50()
print(net)

假设要用这个模型去做一个10分类的问题,就应该修改模型的 fc 层,将其输出节点数替换为10。另外,想再加一层全连接层。可以做如下修改:

classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
                          ('relu1', nn.ReLU()), 
                          ('dropout1',nn.Dropout(0.5)),
                          ('fc2', nn.Linear(128, 10)),
                          ('output', nn.Softmax(dim=1))
                          ]))

net.fc = classifier

这里的操作相当于将模型(net)最后名称为“fc”的层替换成了名称为“classifier”的结构。

2.添加外部输入

有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。比如在CNN网络中,我们除了输入图像,还需要同时输入图像对应的其他信息,这时候就需要在已有的CNN网络中添加额外的输入变量。基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加输入和后续层之间的连接关系,从而完成模型的修改。

以 torchvision 的 resnet50 模型为基础,任务还是10分类任务。不同点在于,我们希望利用已有的模型结构,在倒数第二层增加一个额外的输入变量 add_variable 来辅助预测。具体实现如下:

class Model(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc_add = nn.Linear(1001, 10, bias=True)
        self.output = nn.Softmax(dim=1)
        
    def forward(self, x, add_variable):
        x = self.net(x)
        x = torch.cat((self.dropout(self.relu(x)),
                       add_variable.unsqueeze(1)),1)
        x = self.fc_add(x)
        x = self.output(x)
        return x

这里的实现要点是通过torch.cat实现了tensor的拼接。torchvision 中的 resnet50 输出是一个1000维的 tensor,通过修改 forward 函数,先将 1000 维的 tensor 通过激活函数层和dropout层,再和外部输入变量"add_variable"拼接,最后通过全连接层映射到指定的输出维度 10。

另外这里对外部输入变量"add_variable"进行 unsqueeze 操作是为了和 net 输出的 tensor 保持维度一致,常用于 add_variable 是单一数值 (scalar) 的情况,此时 add_variable 的维度是 (batch_size, ),需要在第二维补充维数1,从而可以和 tensor 进行torch.cat操作。
unsqueeze与sequeeze语法说明

最后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()

另外别忘了,训练中在输入数据的时候要给两个inputs:

outputs = model(inputs, add_var)

3.添加额外输出

有时候在模型训练中,除了模型最后的输出外,我们需要输出模型某一中间层的结果,以施加额外的监督,获得更好的中间层结果。基本的思路是修改模型定义中 forward 函数的 return 变量。

依然以 resnet50 做 10 分类任务为例,在已经定义好的模型结构上,同时输出 1000 维的倒数第二层和 10 维的最后一层结果。具体实现如下:

class Model(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(1000, 10, bias=True)
        self.output = nn.Softmax(dim=1)
        
    def forward(self, x, add_variable):
        x1000 = self.net(x)
        x10 = self.dropout(self.relu(x1000))
        x10 = self.fc1(x10)
        x10 = self.output(x10)
        return x10, x1000

之后,对我们修改好的模型结构进行实例化,就可以使用了:

net = models.resnet50()
model = Model(net).cuda()

out10, out1000 = model(inputs, add_var)

参考

  • Chenglu’s Log

  • Pytorch修改预训练模型的方法汇总

😃😃😃

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/590934.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一加12/11/10/Ace2/Ace3手机上锁回锁BL无限重启黑屏9008模式救砖

一加12/11/10/Ace2/Ace3手机官方都支持解锁BL,搞机的用户也比较多,相对于其他品牌来说,并没有做出限制,这也可能是搞机党最后的救命稻草。而厌倦了root搞机的用户,就习惯性回锁BL,希望彻底变回官方原来的样…

QT:输入类控件的使用

LineEdit 录入个人信息 #include "widget.h" #include "ui_widget.h" #include <QDebug> #include <QString>Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget) {ui->setupUi(this);// 初始化输入框ui->lineEdit…

数据结构与算法之经典排序算法

一、简单排序 在我们的程序中&#xff0c;排序是非常常见的一种需求&#xff0c;提供一些数据元素&#xff0c;把这些数据元素按照一定的规则进行排序。比如查询一些订单按照订单的日期进行排序&#xff0c;再比如查询一些商品&#xff0c;按照商品的价格进行排序等等。所以&a…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-10.1-NXP SDK 移植

前言&#xff1a; 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM&#xff08;MX6U&#xff09;裸机篇”视频的学习笔记&#xff0c;在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

随便聊一下 显控科技 控制屏 通过 RS485 接口 上位机 通讯 说明

系统搭建&#xff1a; 1、自己研发的一个小系统&#xff08;采集信号&#xff0c;将采集的信号数字化&#xff09;通过COM口&#xff0c;连接显控屏 COM3 口采用 485 协议送到显控屏&#xff08;显控科技&#xff09;的显示屏展示出来&#xff09;。 2、显控屏 将 展示的数据…

【C++ | 关键字】C++ 关键字介绍

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; ⏰发布时间⏰&#xff1a;2024-05-04 0…

鸿蒙内核源码分析(汇编传参篇) | 如何传递复杂的参数

汇编如何传复杂的参数? 汇编基础篇 中很详细的介绍了一段具有代表性很经典的汇编代码&#xff0c;有循环&#xff0c;有判断&#xff0c;有运算&#xff0c;有多级函数调用。但有一个问题没有涉及&#xff0c;就是很复杂的参数如何处理? 在实际开发过程中函数参数往往是很复…

小程序账号设置以及request请求的封装

一般开发在小程序时&#xff0c;都会有测试版和正式版&#xff0c;这样在开发时会比较方便。 在开发时。产品经理都会给到测试账号和正式账号&#xff0c;后端给的接口也都会有测试环境用到的接口和正式环境用到的接口。 这里讲一讲我这边如何去做的。 1.在更目录随便命名一…

吴恩达2022机器学习专项课程(一)正则化(正则化成本函数正则化线性回归正则化逻辑回归)

目录 一.正则化1.1 正则化的好处1.2 正则化的实现方式 二.正则化改进线性回归的成本函数2.1 正则化后的成本函数的意义2.2 λ参数的作用2.3 不同λ对算法的影响2.4 为什么参数b没有正则化项 三.正则化线性回归的梯度下降3.1 为什么正则化可以在梯度下降迭代中减小w3.2 导数的计…

如何在Mac上恢复格式化硬盘的数据?

“嗨&#xff0c;我格式化了我的一个Mac硬盘&#xff0c;而没有使用Time Machine备份数据。这个硬盘被未知病毒感染了&#xff0c;所以我把它格式化为出厂设置。但是&#xff0c;我忘了备份我的文件。现在&#xff0c;我想恢复格式化的硬盘驱动器并恢复我的文档&#xff0c;您能…

面试算法题精讲:最长公共子串

面试算法题精讲&#xff1a;最长公共子串 最长公共子串问题是指给定两个字符串S1和S2&#xff0c;求它们的公共子串中最长的那一个。其实就是求两个字符串的最长重复子串。 最朴素的算法就是枚举S1和S2的每一对子串&#xff0c;然后判断它们是否相等&#xff0c;时间复杂度是…

手搓堆(C语言)

Heap.h #pragma once#include <stdio.h> #include <stdlib.h> #include <assert.h> #include <stdbool.h> #include <string.h> typedef int HPDataType; typedef struct Heap {HPDataType* a;int size;int capacity; }Heap;//初始化 void Heap…

Java Jackson-jr 库是干什么用的

Jackson-jr 是一个轻量级的Java JSON 处理库。这个库被设计用来替代 Jackson 的复杂性。对比 Jackson 的复杂 API&#xff0c;Jackson-jr 的启动速度更快&#xff0c;包大小更小。 虽然Jackson databind&#xff08;如ObjectMapper&#xff09;是通用数据绑定的良好选择&#…

如何从Mac电脑恢复任何删除的视频

Microsoft Office是包括Mac用户在内的人们在世界各地创建文档时使用的最佳软件之一。该软件允许您创建任何类型的文件&#xff0c;如演示文稿、帐户文件和书面文件。您可以使用 MS Office 来完成。所有Microsoft文档都可以在Mac上使用。大多数情况下&#xff0c;您处理文档&…

苹果CEO对未来一代人工智能投资持乐观态度

尽管在动荡的第二季度&#xff0c;苹果的收入和iPhone销量有所下降&#xff0c;但其新兴的人工智能技术可能会带来急需的提振。 在5月2日的电话财报会议上&#xff0c;苹果公布季度收入为908亿美元&#xff0c;比去年下降4%。iPhone的收入也下降了10%&#xff0c;至460亿美元。…

《Python编程从入门到实践》day19

#昨日知识点回顾 使用unittest模块测试单元和类 #今日知识点学习 第12章 武装飞船 12.1 规划项目 游戏《外星人入侵》 12.2 安装pygame 终端管理器执行 pip install pygame 12.3 开始游戏项目 12.3.1 创建Pygame窗口及响应用户输入 import sysimport pygameclass…

SpringCloud微服务项目创建流程

为了模拟微服务场景&#xff0c;学习中为了方便&#xff0c;先创建一个父工程&#xff0c;后续的工程都以这个工程为准&#xff0c;实用maven聚合和继承&#xff0c;统一管理子工程的版本和配置。 后续使用中只需要只有配置和版本需要自己规定之外没有其它区别。 微服务中分为…

电脑数据怎么拷贝到u盘?操作指南与数据丢失防范

在数字时代&#xff0c;数据的传输与备份已成为我们日常生活和工作中不可或缺的一部分。U盘作为一种便捷、高效的移动存储设备&#xff0c;广泛应用于各种数据拷贝场景。无论是个人文件的备份&#xff0c;还是工作资料的传输&#xff0c;U盘都发挥着举足轻重的作用。那么&#…

【刷题(1)】链表

一、链表问题基础 移动:head=head.next 移动到最后:head.next=null 停止:if not xx: 相当于if not null: 取值:a.val 赋值:b.next=a 遍历:while head: head=head.next 或者: if not head: head=head.next 递归:单链表: def func(pre,cur): return func(pre.next,cur.…

Windows安装Ubuntu24详细教程

从官网下载ISO镜像&#xff1a; 使用VMWare创建新的虚拟机&#xff1a; 选择刚才下载的ISO镜像&#xff1a; 填写账号和密码&#xff1a; 选择安装位置和虚拟机名称&#xff0c;因为我装这个主要是为了QT开发的&#xff0c;所以名字叫ubuntu24_for_qt&#xff1a; 处理…