Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类

目录

1.ResNet残差网络

1.1 ResNet定义

 1.2 ResNet 几种网络配置

 1.3 ResNet50网络结构

1.3.1 前几层卷积和池化

1.3.2 残差块:构建深度残差网络

1.3.3 ResNet主体:堆叠多个残差块

1.4 迁移学习猫狗二分类实战

1.4.1 迁移学习

1.4.2 模型训练

1.4.3 模型预测


1.ResNet残差网络

1.1 ResNet定义

深度学习在图像分类、目标检测、语音识别等领域取得了重大突破,但是随着网络层数的增加,梯度消失和梯度爆炸问题逐渐凸显。随着层数的增加,梯度信息在反向传播过程中逐渐变小,导致网络难以收敛。同时,梯度爆炸问题也会导致网络的参数更新过大,无法正常收敛。

为了解决这些问题,ResNet提出了一个创新的思路:引入残差块(Residual Block)。残差块的设计允许网络学习残差映射,从而减轻了梯度消失问题,使得网络更容易训练。

下图是一个基本残差块。它的操作是把某层输入跳跃连接到下一层乃至更深层的激活层之前,同本层输出一起经过激活函数输出。
 

24353e89d9c84a17babbbf4ebe90630b.png

 1.2 ResNet 几种网络配置

如下图:

 1.3 ResNet50网络结构

ResNet-50是一个具有50个卷积层的深度残差网络。它的网络结构非常复杂,但我们可以将其分为以下几个模块:

1.3.1 前几层卷积和池化

import torch
import torch.nn as nn

class ResNet50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNet50, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

1.3.2 残差块:构建深度残差网络

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

1.3.3 ResNet主体:堆叠多个残差块

在ResNet-50中,我们堆叠了多个残差块来构建整个网络。每个残差块会将输入的特征图进行处理,并输出更加丰富的特征图。堆叠多个残差块允许网络在深度方向上进行信息的层层提取,从而获得更高级的语义信息。代码如下:

class ResNet50(nn.Module):
    def __init__(self, num_classes=1000):
        # ... 前几层代码 ...

        # 4个残差块的block1
        self.layer1 = self._make_layer(ResidualBlock, 64, 3, stride=1)
        # 4个残差块的block2
        self.layer2 = self._make_layer(ResidualBlock, 128, 4, stride=2)
        # 4个残差块的block3
        self.layer3 = self._make_layer(ResidualBlock, 256, 6, stride=2)
        # 4个残差块的block4
        self.layer4 = self._make_layer(ResidualBlock, 512, 3, stride=2)

 利用make_layer函数实现对基本残差块Bottleneck的堆叠。代码如下:

def _make_layer(self, block, channel, block_num, stride=1):
    """
        block: 堆叠的基本模块
        channel: 每个stage中堆叠模块的第一个卷积的卷积核个数,对resnet50分别是:64,128,256,512
        block_num: 当期stage堆叠block个数
        stride: 默认卷积步长
    """
        downsample = None   # 用于控制shorcut路的
        if stride != 1 or self.in_channel != channel*block.expansion:   # 对resnet50:conv2中特征图尺寸H,W不需要下采样/2,但是通道数x4,因此shortcut通道数也需要x4。对其余conv3,4,5,既要特征图尺寸H,W/2,又要shortcut维度x4
            downsample = nn.Sequential(
                nn.Conv2d(in_channels=self.in_channel, out_channels=channel*block.expansion, kernel_size=1, stride=stride, bias=False), # out_channels决定输出通道数x4,stride决定特征图尺寸H,W/2
                nn.BatchNorm2d(num_features=channel*block.expansion))

        layers = []  # 每一个convi_x的结构保存在一个layers列表中,i={2,3,4,5}
        layers.append(block(in_channel=self.in_channel, out_channel=channel, downsample=downsample, stride=stride)) # 定义convi_x中的第一个残差块,只有第一个需要设置downsample和stride
        self.in_channel = channel*block.expansion   # 在下一次调用_make_layer函数的时候,self.in_channel已经x4

        for _ in range(1, block_num):  # 通过循环堆叠其余残差块(堆叠了剩余的block_num-1个)
            layers.append(block(in_channel=self.in_channel, out_channel=channel))

        return nn.Sequential(*layers)   # '*'的作用是将list转换为非关键字参数传入

1.4 迁移学习猫狗二分类实战

1.4.1 迁移学习

迁移学习(Transfer Learning)是一种机器学习和深度学习技术,它允许我们将一个任务学到的知识或特征迁移到另一个相关的任务中,从而加速模型的训练和提高性能。在迁移学习中,我们通常利用已经在大规模数据集上预训练好的模型(称为源任务模型),将其权重用于新的任务(称为目标任务),而不是从头开始训练一个全新的模型。

迁移学习的核心思想是:在解决一个新任务之前,我们可以先从已经学习过的任务中获取一些通用的特征或知识,并将这些特征或知识迁移到新任务中。这样做的好处在于,源任务模型已经在大规模数据集上进行了充分训练,学到了很多通用的特征,例如边缘检测、纹理等,这些特征对于许多任务都是有用的。

1.4.2 模型训练

首先,我们需要准备用于猫狗二分类的数据集。数据集可以从Kaggle上下载,其中包含了大量的猫和狗的图片。

在下载数据集后,我们需要将数据集划分为训练集和测试集。训练集文件夹命名为train,其中建立两个文件夹分别为cat和dog,每个文件夹里存放相应类别的图片。测试集命名为test,同理。然后我们使用ResNet50网络模型,在我们的计算机上使用GPU进行训练并保存我们的模型,训练完成后在测试集上验证模型预测的正确率。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50

# 设置随机种子
torch.manual_seed(42)

# 定义超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10

# 定义数据转换
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 加载数据集
train_dataset = ImageFolder("train", transform=transform)
test_dataset = ImageFolder("test", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# 加载预训练的ResNet-50模型
model = resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # 替换最后一层全连接层,以适应二分类问题

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item()}")
torch.save(model,'model/c.pth')
# 测试模型
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        print(outputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        break

    print(f"Accuracy on test images: {(correct / total) * 100}%")

1.4.3 模型预测

首先加载我们保存的模型,这里我们进行单张图片的预测,并把预测结果打印日志。

import cv2 as cv
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import torchvision.transforms as transforms
import  torch
from PIL import Image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=torch.load('model/c.pth')
print(model)
model.to(device)

test_image_path = 'test/dogs/dog.4001.jpg'  # Replace with your test image path
image = Image.open(test_image_path)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
input_tensor = transform(image).unsqueeze(0).to(device)  # Add a batch dimension and move to GPU

# Set the model to evaluation mode
model.eval()


with torch.no_grad():
    outputs = model(input_tensor)
    _, predicted = torch.max(outputs, 1)
    predicted_label = predicted.item()


label=['猫','狗']
print(label[predicted_label])
plt.axis('off')
plt.imshow(image)
plt.show()

运行截图

至此这篇文章到此结束。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/42806.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

(css)滚动条样式

(css)滚动条样式 效果: /*滚动条整体样式*/ ::-webkit-scrollbar {width: 2px;/*高宽分别对应横竖滚动条的尺寸*/height: 10px; } ::-webkit-scrollbar-thumb {/*滚动条里面小方块*/border-radius: 10px;width: 2px;height: 60px;background: linear-gradient(0deg,…

CentOS7系统MBR、GRUB2、内核启动流程报错问题

目录 🥩Linux启动流程 🥩MBR修复 🍭1、模拟损坏 🍭2、重启测试 🍭3、修复MBR 🍭4、测试系统 🥩GRUB2修复 🍭1、模拟损坏 🍭2、修复GRUB2 🍭3、测试系统 &…

03. 自定义镜像 Dockerfile

目录 1、前言 2、构建镜像的方式 2.1、docker commit 2.1.1、先查看下当前的容器 2.1.2、生成该容器镜像 2.1.3、查看镜像列表 2.2、Dockerfile 2.2.1、创建Dockerfile文件 2.2.2、编写Dockerfile文件 2.2.3、构建镜像 2.2.4、使用该镜像生成容器 3、Dockerfile 3…

GO内存模型(同步机制)

文章目录 概念1. 先行发生 编译器重排同步机制init函数协程的创建channelsync 包1. sync.mutex2. sync.rwmutex3. sync.once atomic 参考文献 概念 1. 先行发生 The happens before relation is defined as the transitive closure of the union of the sequenced before and …

【微信小程序】使用iView组件库中的icons资源

要在微信小程序中使用iView组件库中的icons资源,需要先下载并引入iView组件库,并按照iView的文档进行配置和使用。 以下是一般的使用步骤: 下载iView组件库的源码或使用npm安装iView。 在小程序项目的app.json文件中添加iView组件库的引入配…

PHP中常用数组排序算法

一:冒泡排序 1:算法步骤 比较相邻项的值,如果前者比后者大,交换顺序。 进行一轮比较后,最后一个值为最大的值。 进行下一轮比较,比上次少比较一项。 以此类推,比较剩下最后一项的时候&#…

【Linux进程】进程控制(上) {进程创建:fork的用法,fork的工作流程,写时拷贝;进程终止:3种退出情况,退出码,常见的退出方法}

一、进程创建 1.1 fork的初步认识和基本使用 在linux中fork函数是非常重要的函数&#xff0c;它从已存在进程中创建一个新进程。新进程为子进程&#xff0c;而原进程为父进程。 #include <unistd.h> pid_t fork(void);返回值&#xff1a;子进程中返回0&#xff0c;父进…

ORB-SLAM2学习笔记5之EuRoc、TUM和KITTI开源数据运行ROS版ORB-SLAM2生成轨迹

文章目录 0 引言1 数据预处理1.1 EuRoc数据1.2 TUM数据1.3 KITTI数据 2 代码修改2.1 单目2.2 双目2.3 RGB-D 3 运行ROS版ORB-SLAM23.1 单目3.2 双目3.3 RGB-D ORB-SLAM2学习笔记系列&#xff1a; 0 引言 ORB-SLAM2学习笔记1已成功编译安装ROS版本ORB-SLAM2到本地&#xff0c;本…

SQL高级教程第三章

SQL CREATE DATABASE 语句 CREATE DATABASE 语句 CREATE DATABASE 用于创建数据库。 SQL CREATE DATABASE 语法 CREATE DATABASE database_name SQL CREATE DATABASE 实例 现在我们希望创建一个名为 "my_db" 的数据库。 我们使用下面的 CREATE DATABASE 语句&…

2023云曦期中复现

目录 SIGNIN 新猫和老鼠 baby_sql SIGNIN 签到抓包 新猫和老鼠 看到反序列化 来分析一下 <?php //flag is in flag.php highlight_file(__FILE__); error_reporting(0);class mouse { public $v;public function __toString(){echo "Good. You caught the mouse:&…

5.1.tensorRT基础(2)-正确导出onnx的介绍,使得onnx问题尽量少

目录 前言1. 正确导出ONNX总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程&#xff0c;之前有看过一遍&#xff0c;但是没有做笔记&#xff0c;很多东西也忘了。这次重新撸一遍&#xff0c;顺便记记笔记。 本次课程学习 tensorRT 基础-正确导出 onnx 的介绍&#xff0…

飞书ChatGPT机器人 – 打造智能问答助手实现无障碍交流

文章目录 前言环境列表1.飞书设置2.克隆feishu-chatgpt项目3.配置config.yaml文件4.运行feishu-chatgpt项目5.安装cpolar内网穿透6.固定公网地址7.机器人权限配置8.创建版本9.创建测试企业10. 机器人测试 前言 在飞书中创建chatGPT机器人并且对话&#xff0c;在下面操作步骤中…

基于DeepFace模型设计的人脸识别软件

完整资料进入【数字空间】查看——baidu搜索"writebug" 人脸识别软件(无外部API) V2.0 基于DeepFace模型设计的人脸识别软件 V1.0 基于PCA模型设计的人脸识别软件 V2.0 更新时间&#xff1a;2018-08-15 在观看了吴恩达老师的“深度学习课程”&#xff0c;了解了深…

2023/7/23周报

目录 摘要 论文阅读 1、题目和现存问题 2、问题阐述及相关定义 3、LGDL模型框架 4、实验准备 5、实验过程 深度学习 1、GCN简单分类任务 2、文献引用数据分类案例 3、将时序型数据构建为图数据格式 总结 摘要 本周在论文阅读上&#xff0c;对基于图神经网络与深度…

【蓝牙AVDTP A2DP协议】

蓝牙AVDTP A2DP 一.AVDTP1.1 AVDTP概念1.2 Source Sink整体框架1.3 AVDTP术语1.3.2 Stream1.3.2 SRC and Sink1.3.3 INT and ACP1.3.4 SEP&#xff1a; 1.4 AVDTP体系1.4.1 体系概括1.4.2 Transport Services 1.5 Signaling Procedures1.5.1 General Requirements1.5.2 Transac…

关于Arduino IDE库文件存放路径问题总结(双版本)

在开发过程中,如果不注意,库文件存放路径很乱,如果在转移系统环境时,容易忘记备份。编译过程中出现多个可用引用包的位置,为了解决这些问题,要明白各文件夹的默认路径在哪,区别在哪,如有了解不对的地方请指正。 IDE安装目录(默认C盘,自定义可以其他盘符下)IDE升级可…

2023华为OD统一考试(B卷)题库清单(持续收录中)以及考点说明

目录 专栏导读2023 B卷 “新加题”&#xff08;100分值&#xff09;2023Q2 100分2023Q2 200分2023Q1 100分2023Q1 200分2022Q4 100分2022Q4 200分牛客练习题 专栏导读 本专栏收录于《华为OD机试&#xff08;JAVA&#xff09;真题&#xff08;A卷B卷&#xff09;》。 刷的越多&…

PALO ALTO NETWORKS 的新一代防火墙如何保护企业安全

轻松采用创新技术、阻止网络攻击得逞并专注更重要的工作 IT 的快速发展已改变网络边界的面貌。数据无处不在&#xff0c;用户可随时随地从各类设备访问这些数据。同时&#xff0c;IT 团队正在采用云、分析和自动化来加速新应用的交付以及推动业务发展。这些根本性的转变带来了…

11 简单的Thymeleaf语法

11.1 spring-boot环境准备 重要依赖&#xff1a; <!--thymeleaf--> <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-thymeleaf</artifactId> </dependency> 11.2 转发消息不转义 就是如…

Vue3状态管理库Pinia——核心概念(Store、State、Getter、Action)

个人简介 &#x1f440;个人主页&#xff1a; 前端杂货铺 &#x1f64b;‍♂️学习方向&#xff1a; 主攻前端方向&#xff0c;正逐渐往全干发展 &#x1f4c3;个人状态&#xff1a; 研发工程师&#xff0c;现效力于中国工业软件事业 &#x1f680;人生格言&#xff1a; 积跬步…