深度学习第J6周:ResNeXt-50实战解析

目录

一、模型结构介绍

二、前期准备

三、模型

 三、训练运行

3.1训练

3.2指定图片进行预测


🍨 本文为[🔗365天深度学习训练营]内部限免文章(版权归 *K同学啊* 所有)
🍖 作者:[K同学啊]

 📌 本周任务:
●阅读ResNeXt论文,了解作者的构建思路
●对比我们之前介绍的ResNet50V2、DenseNet算法
●使用ResNeXt-50算法完成猴痘病识别

一、模型结构介绍

ResNeXt是由何凯明团队在2017年CVPR会议上提出来的新型图像分类网络。ResNeXt是ResNet的升级版,在ResNet的基础上,引入了cardinality的概念,类似于ResNet,ResNeXt也有ResNeXt-50,ResNeXt-101的版本。ResNeXt论文名为:Aggregated Residual Transformations for Deep Neural Networks.pdf

这篇文章介绍了一种用于图像分类的简单而有效的网络架构,称为Aggregated Residual Transformations for Deep Neural Networks。该网络采用了VGG/ResNets的策略,通过重复层来增加深度和宽度,并利用分裂-变换-合并策略以易于扩展的方式进行转换。文章还提出了一个新的维度——“基数”,它是指转换集合的大小,可以在保持复杂性不变的情况下提高分类准确性。作者在ImageNet-1K数据集上进行了实证研究,证明了这种方法的有效性。

 下图是ResNet(左)与ResNeXt(右)block的差异。在ResNet中,输入的具有256个通道的特征经过1×1卷积压缩4倍到64个通道,之后3×3的卷积核用于处理特征,经1×1卷积扩大通道数与原特征残差连接后输出。

ResNeXt也是相同的处理策略,但在ResNeXt中,输入的具有256个通道的特征被分为32个组,每组被压缩64倍到4个通道后进行处理。32个组相加后与原特征残差连接后输出。这里cardinatity指的是一个block中所具有的相同分支的数目。

 分组卷积

ResNeXt中采用的分组卷机简单来说就是将特征图分为不同的组,再对每组特征图分别进行卷积,这个操作可以有效的降低计算量。
在分组卷积中,每个卷积核只处理部分通道,比如下图中,红色卷积核只处理红色的通道,绿色卷积核只处理绿色通道,黄色卷积核只处理黄色通道。此时每个卷积核有2个通道,每个卷积核生成一张特征图。

总结来说就是:ResNeXt-50网络简单讲就是在ResNet结构的基础上采用了聚合残差结构局部连接结构,同时引入了Random ErasingMixup等数据增强和正则化方法 

二、前期准备

大致模板和以前一样,以后不再详细列,样例可见:深度学习第J4周:ResNet与DenseNet结合探索_牛大了2023的博客-CSDN博客

配置gpu+导入数据集

import os,PIL,random,pathlib
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
 
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
print(device)
 
data_dir = './data/'
data_dir = pathlib.Path(data_dir)
 
data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
print(classeNames)
 
image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)

数据预处理+划分数据集

train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(  # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
 
test_transform = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),  # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(  # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])
 
total_data = datasets.ImageFolder("./data/", transform=train_transforms)
print(total_data.class_to_idx)
 
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
 
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=0)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=0)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

三、模型

class BN_Conv2d(nn.Module):
    """
    BN_CONV_RELU
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bias=False):
        super(BN_Conv2d, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=groups, bias=bias),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return F.relu(self.seq(x))

class ResNeXt_Block(nn.Module):
    """
    ResNeXt block with group convolutions
    """

    def __init__(self, in_chnls, cardinality, group_depth, stride):
        super(ResNeXt_Block, self).__init__()
        self.group_chnls = cardinality * group_depth
        self.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)
        self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride, padding=1, groups=cardinality)
        self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls*2, 1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(self.group_chnls*2)
        self.short_cut = nn.Sequential(
            nn.Conv2d(in_chnls, self.group_chnls*2, 1, stride, 0, bias=False),
            nn.BatchNorm2d(self.group_chnls*2)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn(self.conv3(out))
        out += self.short_cut(x)
        return F.relu(out)

class ResNeXt(nn.Module):
    """
    ResNeXt builder
    """

    def __init__(self, layers: object, cardinality, group_depth, num_classes) -> object:
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.channels = 64
        self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)
        d1 = group_depth
        self.conv2 = self.___make_layers(d1, layers[0], stride=1)
        d2 = d1 * 2
        self.conv3 = self.___make_layers(d2, layers[1], stride=2)
        d3 = d2 * 2
        self.conv4 = self.___make_layers(d3, layers[2], stride=2)
        d4 = d3 * 2
        self.conv5 = self.___make_layers(d4, layers[3], stride=2)
        self.fc = nn.Linear(self.channels, num_classes)   # 224x224 input size

    def ___make_layers(self, d, blocks, stride):
        strides = [stride] + [1] * (blocks-1)
        layers = []
        for stride in strides:
            layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))
            self.channels = self.cardinality*d*2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool2d(out, 3, 2, 1)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = F.softmax(self.fc(out),dim=1)
        return out
# 定义完成,测试一下
model = ResNeXt([3, 4, 6, 3], 32, 4, 4)
model.to(device)

# 统计模型参数量以及其他指标
import torchsummary as summary
summary.summary(model, (3, 224, 224))

 三、训练运行

3.1训练

代码和以前的差不多,不再细说

 
# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)
 
    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
 
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
 
        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
 
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新
 
        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
 
    train_acc /= size
    train_loss /= num_batches
 
    return train_acc, train_loss
 
 
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)  # 批次数目
    test_loss, test_acc = 0, 0
 
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
 
            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)
 
            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()
 
    test_acc /= size
    test_loss /= num_batches
 
    return test_acc, test_loss

 跑十轮并保存模型

 
import copy
 
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
 
epochs = 10
 
train_loss = []
train_acc = []
test_loss = []
test_acc = []
 
best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标
 
for epoch in range(epochs):
    # 更新学习率(使用自定义学习率时使用)
    # adjust_learning_rate(optimizer, epoch, learn_rate)
 
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    # scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)
 
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
 
    # 保存最佳模型到 best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
 
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
 
    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
 
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,
                          epoch_test_acc * 100, epoch_test_loss, lr))
 
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
 
print('Done')

 打印训练记录图

import matplotlib.pyplot as plt
# 隐藏警告
import warnings
 
warnings.filterwarnings("ignore")  # 忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率
 
epochs_range = range(epochs)
 
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
 
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

3.2指定图片进行预测

把训练部分注释掉

 
from PIL import Image
 
classes = list(total_data.class_to_idx)
 
 
def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # 展示预测的图片
 
    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
 
    model.eval()
    output = model(img)
 
    _, pred = torch.max(output, 1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
 
 
# 预测训练集中的某张照片
predict_one_image(image_path='./data/Others/NM01_01_01.jpg',
                  model=model,
                  transform=train_transforms,
                  classes=classes)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/18800.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

移动端异构运算技术 - GPU OpenCL 编程(基础篇)

一、前言 随着移动端芯片性能的不断提升,在移动端上实时进行计算机图形学、深度学习模型推理等计算密集型任务不再是一个奢望。在移动端设备上,GPU 凭借其优秀的浮点运算性能,以及良好的 API 兼容性,成为移动端异构计算中非常重要…

Echarts使用本地JSON文件加载不出图表的解决方法以及Jquery访问本地JSON文件跨域的解决方法

前言 最近需要做一个大屏展示,需要用原生html5cssjs来写,所以去学了一下echarts的使用。在使用的过程中难免碰到许多BUG,百度那是必不可少的,可是这些人写的牛头不对马嘴,简直是标题党一大堆,令我作呕&…

fastjson 代码执行 (CNVD-2017-02833)

漏洞存在原因 在fastjson<1.2.24版本中&#xff0c;在解析json的过程中&#xff0c;支持使用autoType来实例化某一个具体的类&#xff0c;并调用该类的set/get方法来访问属性。而在1.24<fastjson<1.2.48版本中后增加了反序列化白名单。 漏洞复现过程如下 在vulfocu…

龙的画法图片

由龙老师画素描中国龙的方法,大概可以遵循以下步骤: 确定龙的姿态和比例:在纸上简单地画出龙的基本形状和姿态,包括身体的长度,颈部、腿和尾巴的位置和比例关系。 添加细节:在基本形状的基础上,开始添加一些细节,如龙的头部、眼睛、鼻子、嘴巴、爪子等。注意要保持姿态和比例…

UU跑腿“跑男失联”:同城即配服务赛道商业逆袭难?

五一假期&#xff0c;人们纷纷走出家门&#xff0c;要么扎堆奔向“远方”&#xff0c;要么、享受本地烟火气息。 据文化和旅游部数据中心测算&#xff0c;劳动节假期&#xff0c;全国国内旅游出游合计2.74亿人次&#xff0c;同比增长70.83%。 五一假日的郑州东站 面对人山人海…

Docker 应用部署-MySQL

一、安装MySQL 1搜索mysql镜像 docker search mysql 2拉取mysql镜像 docker pull mysql:8.0.20 3创建容器 通过下面的命令&#xff0c;创建容器并设置端口映射、目录映射 #在用户名目录下创建mysql目录用于存储mysql数据信息 mkdir /home/mysql cd /home/mysql #创建docker容…

Python图形化编程开源项目拼码狮PinMaShi

开源仓库 #项目地址 https://github.com/supercoderlee/pinmashi https://gitee.com/supercoderlee/pinmashiPinMaShi采用electron开发&#xff0c;图形化拖拽式编程有效降低编程难度&#xff0c;对Python编程的初学者非常友好&#xff1b;积木式编程加快Python程序的开发&…

微服务介绍 SpringCloud,服务拆分和远程调用,注册中心Eureka,负载均衡Ribbon,注册中心Nacos

一、微服务介绍 1.系统架构的演变 什么是微服务&#xff0c;微服务有哪些特征SpringCloud是什么SpringCloud与Dubbo的区别 1.单体架构 将业务的所有功能集中在一个项目中开发&#xff0c;打成一个包部署。当网站流量很小时&#xff0c;单体架构非常合适 1.单体架构 优点&a…

携带数据的Ajax POST请求

前端页面代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <title>发送ajax POST请求 看如何携带数据</title> <script type"text/javascript"> …

第8章 未执行缓存的强制清理操作导致显示异常解决方案

1 异常产生原因&#xff1a; 由于未为Role实体定义相就的缓存强制销毁器类&#xff1a;Services.Customers.Caching.RoleCacheEventConsumer,从而导致Services.Events.EventPublisher.PublishAsync<TEvent>(TEvent event)中的 consumers实例为0,如下图所示&#xff1a; 2…

Redis(连接池)

SpringBoor环境下使用redis连接池 依赖&#xff1a; <dependencies><dependency><groupId>com.yugabyte</groupId><artifactId>jedis</artifactId><version>2.9.0-yb-11</version></dependency><dependency><…

SpringBoot基础篇3(SpringBoot+Mybatis-plus案例)

环境搭建&#xff1a;配置起步依赖pom.xml和配置文件application.yml 1.创建模块时&#xff0c;勾选的依赖有springMVC和MySQL驱动 2.手动添加的依赖有&#xff1a;MyBatis-plus、Druid、lombok <dependencies><dependency><groupId>org.springframework.…

行为型模式-解释器模式

解释器模式 概述 如上图&#xff0c;设计一个软件用来进行加减计算。我们第一想法就是使用工具类&#xff0c;提供对应的加法和减法的工具方法。 //用于两个整数相加 public static int add(int a,int b){return a b; }//用于两个整数相加 public static int add(int a,int …

使用护眼灯台灯哪个牌子好用来保护眼睛?真正做到护眼台灯品牌

现在的家长很多人觉得家里已经有灯光了&#xff0c;没必要在买台灯。但事实上台灯有很多优点&#xff0c;尤其对于小孩子来说&#xff1a;1.提供更好的光线:台灯能够提供更加明亮的光线&#xff0c;有助于保护眼睛健康。2.提高工作效率:台灯光线舒适可提高工作效率或学习效率。…

CPU 架构(x86/ARM)简介

CPU 架构通过指令集的方式一般可分为 复杂指令集&#xff08;CISC&#xff09; 和 精简指令集&#xff08;RISC&#xff09; 两类&#xff0c;CISC 主要是 x86 架构&#xff0c;RISC 主要是 ARM 架构&#xff0c;还有 MIPS、RISC-V、PowerPC 等架构。 本文重点介绍 x86 和 ARM…

SpringBoot整合Nacos配置中心和注册中心

一、背景 公司项目中使用的Nacos作为服务的注册中心和配置中心&#xff0c;但是呢公司的这一套Nacos是经过封装了的&#xff0c;而且封装的不是很友好&#xff0c;想着自己搭建一套标注的Nacos配置中心和服务中心 二、Nacos配置中心和注册中心搭建 2.1 依赖引入 <!--注册…

【Linux】shell编程之循环语句

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、循环语句二、for循环语句1.for 语句的结构2.for语句应用示例 三、while 循环语句1.while 循环语句结构2.while语句应用示例 四、until 循环五、跳出循环六、死循…

新品发布全线添员,九号全力奔向“红海”深处?

5月10日&#xff0c;九号公司2023新品发布会声势达到顶峰。此次发布会的看点为九号电动2023产品线的更新&#xff0c;电动家族再添多员大将。 随着人们出行选择的多样化&#xff0c;国内短途出行工具发展迎来井喷期。在传统的电动两轮车市场上&#xff0c;雅迪、爱玛等品牌仍然…

今年这面试难度,我给跪了……

大家好&#xff0c;最近有不少小伙伴在后台留言&#xff0c;又得准备面试了&#xff0c;不知道从何下手&#xff01; 不论是跳槽涨薪&#xff0c;还是学习提升&#xff01;先给自己定一个小目标&#xff0c;然后再朝着目标去努力就完事儿了&#xff01; 为了帮大家节约时间&a…

关于cartographer建立正确关系树的理解

正确的TF关系map----odom----base_link----laser base_link是固定在机器人本体上的坐标系&#xff0c;通常选择飞控 其中map–odom 的链接是由cartographer中lua文件配置完成的 map_frame "map", tracking_frame "base_link", published_frame "b…