【Week-P4】CNN猴痘病识别

文章目录

  • 一、环境配置
  • 二、准备数据
  • 三、搭建网络结构
  • 四、开始训练
  • 五、查看训练结果
  • 六、总结
    • 2.3 ⭐`torch.utils.data.DataLoader()`参数详解
    • 6.1 `print()`常用的三种输出格式
    • 6.2 修改网络结构
      • 6.2.1 增加池化、卷积和bn层
      • 6.2.2 增加卷积、bn、卷积、bn

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
    在这里插入图片描述
  • 本周的代码相对于上周增加指定图片预测与保存并加载模型这个两个模块,在学习这个两知识点后,时间有余的同学请自由探索更佳的模型结构以提升模型是识别准确率,模型的搭建是深度学习程度的重点。

一、环境配置

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets

import os,PIL,pathlib

import sys
from datetime import datetime
print("---------------------1.配置环境------------------")
print("Start time: ", datetime.today())
print("Pytorch version: " + torch.__version__)
print("Python version: " + sys.version)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

在这里插入图片描述

二、准备数据

2.1 打印classeNames列表,显示每个文件所属的类别名称
2.2 打印归一化后的类别名称,01
2.3 划分数据集,划分为训练集&测试集,torch.utils.data.DataLoader()参数详解
2.4 检查数据集的shape

  • 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象
  • 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
  • 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classNames
  • 第四步:打印classNames列表,显示每个文件所属的类别名称。
import os,PIL,random,pathlib
print("------------2.1 打印classeNames列表,显示每个文件所属的类别名称------------")
total_datadir= './4-data/'
data_dir = pathlib.Path(total_datadir)

data_paths = list(total_datadir.glob('*'))
classNames = [str(path).split("\\")[1] for path in data_paths]
print("classNames: ", classNames)

print("------------2.2 打印归一化后的类别名称,0或1------------")
# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
print("total_data: ", total_data)
print("total_data.class_to_idx: ", total_data.class_to_idx)

print("------------2.3 划分数据集,划分为训练集&测试集------------")
train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print( f"train_dataset: {train_dataset}, test_dataset: {test_dataset}")
print( f"train_size: {train_size}, test_size: {test_size}")

print("------------2.4 检查数据集的shape------------")
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

在这里插入图片描述

三、搭建网络结构

print("------------3 搭建简单CNN网络------------")
import torch.nn.functional as F

class Network_bn(nn.Module):
    def __init__(self):
        super(Network_bn, self).__init__()
        """
        nn.Conv2d()函数:
        第一个参数(in_channels)是输入的channel数量
        第二个参数(out_channels)是输出的channel数量
        第三个参数(kernel_size)是卷积核大小
        第四个参数(stride)是步长,默认为1
        第五个参数(padding)是填充大小,默认为0
        """
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2,2)
        self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn5 = nn.BatchNorm2d(24)
        self.fc1 = nn.Linear(24*50*50, len(classNames))

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))      
        x = F.relu(self.bn2(self.conv2(x)))     
        x = self.pool(x)                        
        x = F.relu(self.bn4(self.conv4(x)))     
        x = F.relu(self.bn5(self.conv5(x)))  
        x = self.pool(x)                        
        x = x.view(-1, 24*50*50)
        x = self.fc1(x)

        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = Network_bn().to(device)
model

在这里插入图片描述

四、开始训练

4.1 设置超参数
4.2 编写训练函数
4.3 编写测试函数
4.4 开始正式训练,epochs==20

print("------------4.1 设置超参数------------")
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4 # 学习率
opt        = torch.optim.SGD(model.parameters(),lr=learn_rate)

print("------------4.2 编写训练函数------------")
# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)   # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss
    
print("------------4.3 编写测试函数------------")
def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)          # 批次数目,313(10000/32=312.5,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss
    
print("------------4.4 开始正式训练,epochs==20------------")
epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

在这里插入图片描述

五、查看训练结果

5.1 Loss与Accuracy图
5.2 指定图片进行预测
5.3 保存并加载模型

print("------------5.1 Loss与Accuracy图------------")
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

print("------------5.2 指定图片进行预测------------")
from PIL import Image 

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    # plt.imshow(test_img)  # 展示预测的图片

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    
    model.eval()
    output = model(img)

    _,pred = torch.max(output,1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')
    
# 预测训练集中的某张照片
predict_one_image(image_path='./4-data/Monkeypox/M01_01_00.jpg', 
                  model=model, 
                  transform=train_transforms, 
                  classes=classes)
                  
print("------------5.3 保存并加载模型------------")
# 模型保存
PATH = './model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)

# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))

在这里插入图片描述

六、总结

2.3 ⭐torch.utils.data.DataLoader()参数详解

torch.utils.data.DataLoaderPyTorch 中用于加载和管理数据的一个实用工具类。它允许你以小批次的方式迭代你的数据集,这对于训练神经网络和其他机器学习任务非常有用。DataLoader 构造函数接受多个参数,下面是一些常用的参数及其解释:

  1. dataset(必需参数):这是你的数据集对象,通常是 torch.utils.data.Dataset 的子类,它包含了你的数据样本。
  2. batch_size(可选参数):指定每个小批次中包含的样本数。默认值为 1
  3. shuffle(可选参数):如果设置为 True,则在每个 epoch 开始时对数据进行洗牌,以随机打乱样本的顺序。这对于训练数据的随机性很重要,以避免模型学习到数据的顺序性。默认值为 False
  4. num_workers(可选参数):用于数据加载的子进程数量。通常,将其设置为大于 0 的值可以加>快数据加载速度,特别是当数据集很大时。默认值为 0,表示在主进程中加载数据
  5. pin_memory(可选参数):如果设置为 True,则数据加载到 GPU 时会将数据存储在 CUDA 的锁页内存中,这可以加速数据传输到 GPU。默认值为 False
  6. drop_last(可选参数):如果设置为 True,则在最后一个小批次可能包含样本数小于 batch_size 时,丢弃该小批次。这在某些情况下很有用,以确保所有小批次具有相同的大小。默认值为 False
  7. timeout(可选参数):如果设置为正整数,它定义了每个子进程在等待数据加载器传递数据时的超时时间(以秒为单位),这可以用于避免子进程卡住的情况。默认值为 0,表示没有超时限制
  8. worker_init_fn(可选参数):一个可选的函数,用于初始化每个子进程的状态。这对于设置每个子进程的随机种子或其他初始化操作很有用。

6.1 print()常用的三种输出格式

    1. 带格式输出,{0}是指输出的第0个元素,同理{1}为第1个元素,{2}为第2个… 可以不按顺序排列
      print( "Hello {0}, I'm {2}, I,m {1} year old".format("world", age, name) )
    1. 使用类型输出,指定输出类型
      print( "I am %s, today is %d year"%(name, year) )
    1. f字符串,{}中为元素,是.format的简化形式
      print( f"Today is {year}")  

6.2 修改网络结构

6.2.1 增加池化、卷积和bn层

在这里插入图片描述
在这里插入图片描述
训练结果如下:
在这里插入图片描述
在这里插入图片描述
训练结果表明:修改网络结构之后,test_accuracy反而从85.3%降低到82.5%,说明此次修改结构不能提升test_accuracy的值。

6.2.2 增加卷积、bn、卷积、bn

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
训练结果表明:与6.2.1的修改方法相比,test_accuracy从82.5%提升到84.4%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/289543.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

界面控件DevExpress Blazor Grid v23.2 - 支持全新的单元格编辑模式

DevExpress Blazor UI组件使用了C#为Blazor Server和Blazor WebAssembly创建高影响力的用户体验,这个UI自建库提供了一套全面的原生Blazor UI组件(包括Pivot Grid、调度程序、图表、数据编辑器和报表等)。 在这篇文章中,我们将介…

Navicat 技术干货 | 如何查看关系型数据库(MySQL、PostgreSQL、SQL Server、 Oracle)查询的运行时间

在数据库优化中,理解和监控查询运行时间是至关重要的。无论你是数据库管理员、开发人员或是参与性能调优的人员,知道如何查看查询运行时间能为你的数据库操作提供有价值的参考。本文中,我们将探索几款热门的关系数据库(如 MySQL、…

❀记忆冒泡、选择和插入排序算法思想在bash里运用❀

目录 冒泡排序算法:) 选择排序算法:) 插入排序算法:) 冒泡排序算法:) 思想:依次比较相邻两个元素,重复的进行直到没有相邻元素需要交换,排序完成。 #!/bin/bash arr(12 324 543 213 65 64 1 3 45) #定义一个数组 n${#arr[*]} #获取数组…

DHTMLX Spreadsheet v5.1.1 Crack

DHTMLX Spreadsheet 5.1 具有新主题、简化的数字格式本地化、与框架的实时集成演示等 推出 DHTMLX Spreadsheet v5.1。新版本提供了一组有用的功能,这对开发人员和最终用户都有吸引力。 首先,新的电子表格版本提供了 4 个内置主题,可以根据您…

【完整代码】网上书店信息管理系统--基于Mysql数据库与java

网上书店信息管理系统 一、需求分析(一)设计系统的意义以及用途(二)实现的功能1.用户模块:1、全部图书浏览2、图书搜索3、购物车管理和订单查看4、修改密码 2.书店管理员模块1、图书类别管理2、图书管理3、全部订单查看…

【React】01-React 入门到放弃系列

01-React 入门 背景入门学生成绩录入的表单 小结 背景 由于捣鼓一些项目需要用到React,找了一些文档入门实践了一番。本篇文章以一个学生成绩录入的表单为例子,记录React 入门的一些基础操作。 入门 该操作的前提是本地安装了NodeJS环境。根据官网给的…

Docker入门教程(详解)

Docker容器化 一 入门 1. 引言 (1)单机部署 场景: 将多个应用部署一台服务器上。 问题 每个应用软件,都会消耗物理资源,共用计算机资源,彼此之间会形成竞争关系。 (2)多机部署 …

PHP反序列化漏洞利用及修复,示例代码讲解

您提到的PHP反序列化漏洞是一个重要的网络安全问题。在我的网络安全工程师的角色下,我可以提供关于此问题的深入分析。 PHP反序列化漏洞通常发生在当不可信的数据被反序列化时。序列化是将数据结构或对象状态转换为可存储或可传输的格式的过程,而反序列…

Springcloud 微服务实战笔记 Ribbon

使用 Configurationpublic class CustomConfiguration {BeanLoadBalanced // 开启负载均衡能力public RestTemplate restTemplate() {return new RestTemplate();}}可看到使用Ribbon,非常简单,只需将LoadBalanced注解加在RestTemplate的Bean上&#xff0…

AIGC时代-GPT-4和DALL·E 3的结合

在当今这个快速发展的数字时代,人工智能(AI)已经成为了我们生活中不可或缺的一部分。从简单的自动化任务到复杂的决策制定,AI的应用范围日益扩大。而在这个广阔的领域中,有两个特别引人注目的名字:GPT-4和D…

C/C++汇编学习(二)——学习使用IDA pro

学习使用IDA Pro是一项很有价值的技能,特别是对于那些对逆向工程和软件安全分析感兴趣的人。以下是一些基本步骤和概念,帮助你熟悉IDA Pro的界面和操作。 1. 熟悉IDA Pro界面和基本操作 主界面布局 IDA Pro的主界面包含多个组件,每个组件都…

pytest-yaml 测试平台-4.生成allure报告,报告反馈企业微信、钉钉、飞书通知

前言 定时任务执行完成后生成可视化allure报告,并把结果发到企业微信,钉钉,飞书通知群里。 生成allure报告 添加定时任务 执行完成后生成allure报告 查看报告详情 报告会显示详细的request 和 response 详细信息 也可以查看log日志 …

Vue3-35-路由-路由守卫的简单认识

什么是路由守卫 路由守卫,就是在 路由跳转 的过程中, 可以进行一些拦截,做一些逻辑判断, 控制该路由是否可以正常跳转的函数。常用的路由守卫有三个 : beforeEach() : 前置守卫,在路由 跳转前 就会被拦截&…

kube-promethues配置钉钉告警

kube-promethues配置钉钉告警 前置:k8s部署kube-promethues 一.配置钉钉机器人 打开钉钉的智能群助手,点击添加机器人 选择自定义机器人 勾选加签,复制后保存 复制webhook地址后点击保存 二.编写dingtalk的yaml部署文件 vi dingta…

74HC595驱动数码管程序

数码管的驱动分静态扫描和动态扫描两种,使用最多的是动态扫描,优点是使用较少的MCU的IO口就能驱动较多位数的数码管。数码管动态扫描驱动电路很多,其中最常见的是74HC164驱动数码管,这种电路一般用三极管作位选信号,用…

pytest conftest通过fixture实现变量共享

conftest.py scope"module" 只对当前执行的python文件 作用 pytest.fixture(scope"module") def global_variable():my_dict {}yield my_dict test_case7.py import pytestlist1 []def test_case001(global_variable):data1 123global_variable.u…

人工智能如何重塑金融服务业

在体验优先的世界中识别金融服务业中的AI使用场景 人工智能(AI)作为主要行业的大型组织的重要业务驱动力,持续受到关注。众所周知,传统金融服务业在采用新技术方面相对滞后,一些组织使用的还是上世纪50年代和60年代发…

华为云Sys-default、Sys-WebServer和Sys-FullAccess安全组配置规则

华为云服务器默认安全组可选Sys-default、Sys-WebServer或Sys-FullAccess。default是默认安全组规则,只开放了22和3389端口;Sys-WebServer适用于Web网站开发场景,开放了80和443端口;Sys-FullAccess开放了全部端口。阿腾云atengyun…

快速搭建知识付费小程序,3分钟即可开启知识变现之旅

产品服务 线上线下课程传播 线上线下活动管理 项目撮合交易 找商机找合作 一对一线下交流 企业文化宣传 企业产品销售 更多服务 实时行业资讯 动态学习交流 分销代理推广 独立知识店铺 覆盖全行业 个人IP打造 独立小程序 私域运营解决方案 公域引流 营销转化 …

使用jieba库进行中文分词和去除停用词

jieba.lcut jieba.lcut()和jieba.lcut_for_search()是jieba库中的两个分词函数,它们的功能和参数略有不同。 jieba.lcut()方法接受三个参数:需要分词的字符串,是否使用全模式(默认为False)以及是否使用HMM模型&…