PyTorch学习笔记:新冠肺炎X光分类

前言

目的是要了解pytorch如何完成模型训练
https://github.com/TingsongYu/PyTorch-Tutorial-2nd参考的学习笔记


数据准备

由于本案例目的是pytorch流程学习,为了简化学习过程,数据仅选择了4张图片,分为2类,正常与新冠,训练集2张,
验证集2张。标签信息存储于TXT文件中。具体目录结构如下:

注意:covid-19的图可以找到但是no-finding两张图没有找到
covid-19-1
covid-19-2
no-finding的图随便照两张看着正常的,别问我哪个是正常的,我也不知道(❍ᴥ❍ʋ),需要改名字为00001215_000.png00001215_001.png

├─imgs
│  ├─covid-19
│  │      auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg
│  │      ryct.2020200028.fig1a.jpeg
│  │
│  └─no-finding
│         00001215_000.png
│         00001215_001.png
│
└─labels
       train.txt
       valid.txt

创建标签文件:

创建 train.txt 和 valid.txt 文件,并填入图片路径和标签信息

  • train.txt:
covid-19/auntminnie-a-2020_01_28_23_51_6665_2020_01_28_Vietnam_coronavirus.jpeg 1
no-finding/00001215_000.png 0

  • valid.txt:
covid-19/ryct.2020200028.fig1a.jpeg 1
no-finding/00001215_001.png 0

完整代码示例:

以下是准备数据集、定义模型和训练模型的完整代码示例:

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 自定义数据集类
class COVID19Dataset(Dataset):
    def __init__(self, img_dir, label_file, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        self.img_labels = []

        with open(label_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                self.img_labels.append(line.strip().split())

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, img_path)
        image = Image.open(img_path).convert('RGB')
        label = int(label)

        if self.transform:
            image = self.transform(image)

        return image, label

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((8, 8)),
    transforms.ToTensor()
])

# 创建数据集和数据加载器
train_dataset = COVID19Dataset(img_dir='imgs', label_file='labels/train.txt', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

valid_dataset = COVID19Dataset(img_dir='imgs', label_file='labels/valid.txt', transform=transform)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False)

# 定义简单卷积神经网络
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, kernel_size=3)  # 输入通道为3(RGB),输出通道为1,卷积核大小为3x3
        self.fc1 = nn.Linear(1 * 6 * 6, 2)  # 全连接层,输入大小为6*6*1,输出大小为2(2类)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = x.view(-1, 1 * 6 * 6)  # 展平操作
        x = self.fc1(x)
        return x

model = SimpleCNN()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练函数
def train(model, train_loader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 10 == 9:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {running_loss / 10:.6f}')
            running_loss = 0.0

# 验证函数
def validate(model, valid_loader, criterion):
    model.eval()
    validation_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in valid_loader:
            output = model(data)
            validation_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    validation_loss /= len(valid_loader.dataset)
    print(f'\nValidation set: Average loss: {validation_loss:.4f}, Accuracy: {correct}/{len(valid_loader.dataset)} ({100. * correct / len(valid_loader.dataset):.0f}%)\n')

# 训练和验证
for epoch in range(1, 11):
    train(model, train_loader, criterion, optimizer, epoch)
    validate(model, valid_loader, criterion)

效果展示:

由于数据量少,随机性非常大,大家多运行几次,观察结果。不过本案例结果完全不重要!)可以观看Average loss变化,Accuracy由于训练数据过少几乎不会变化
在这里插入图片描述

知识点总结

1. 数据

  • Q:要知道pytorch需要模型的格式
    A:需要编写代码完成数据的读取,转换成模型能够读取的格式。在 PyTorch 中,读取数据通常通过自定义 Dataset 类和内置的 DataLoader 来实现。这种方法既灵活又高效,适用于各种类型的数据集。
  • Q:自己如何编写Dataset?
    A:编写一个自定义的 Dataset 类,需要继承 torch.utils.data.Dataset 并实现三个方法:__init____len__ __getitem__

2. 模型

可参考:
从“卷积”、到“图像卷积操作”、再到“卷积神经网络”,“卷积”意义的3次改变_哔哩哔哩_bilibili

  • Q: 卷积层,全连接层的作用是什么?
    A: 卷积层提取特征,全连接层进行分类。
    1. 卷积层
    • 卷积层的作用是提取输入图像的特征。
    • 使用 3x3 的卷积核进行卷积操作,可以捕捉到局部的空间特征。
    • 卷积操作后的输出会产生一个新的特征图,这个特征图是卷积层提取到的特征表示。
    1. 全连接层
    • 全连接层的作用是将卷积层提取到的特征进行进一步的处理,最终输出分类结果。
    • 在这个例子中,全连接层有两个神经元,分别输出两个分类的概率。
    • 全连接层的输入被限制在 8x8,这意味着输入的特征图经过扁平化(flatten)后被映射到一个 8x8 的向量。

3. 优化

  • Q:根据什么规则对模型的参数进行更新学习呢?
    A:常用的方法:交叉熵损失函数(CrossEntropyLoss)、随机梯度下降法(SGD)和按固定步长下降学习率策略(StepLR)

4. 迭代

  • Q:怎么进行模型迭代?
    A: 有了模型参数更新的必备组件,接下来需要一遍又一遍地给模型喂数据,监控模型训练状态,这时候就需要for循环,不断地从dataloader里取出数据进行前向传播,反向传播,参数更新,观察loss、acc,周而复始。

总结

详细内容https://github.com/TingsongYu/PyTorch-Tutorial-2nd可查看,这是一篇读书笔记,与代码实现的分享。后续的笔记会以Q-A解决一些问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/648039.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

解决鼠标滚动时element-ui下拉框错位的问题

问题描述:elementUi的el-select下拉选择框,打开之后,直到失去焦点才会自动关闭。 在有滚动条的弹窗中使用时就会出现打开下拉框,滚动弹窗,el-select下拉框会超出弹窗范围的问题. 解决方案: 1、先在util文件夹下创建个hideSelect.js文件,代码…

《德米安:彷徨少年时》

文前 我之所愿无非是尝试依本性而生活, 却缘何如此之难? 强盗 疏于独立思考和自我评判的人只能顺应现成的世俗法则,让生活变轻松。其他人则有自己的戒条:正派人惯常做的事于他可能是禁忌,而他自认合理的或许遭他人唾…

GM Bali,OKLink受邀参与Polygon AggIsland大会

5月16日-17日,OKLink 受到生态合作伙伴 Polygon 的特别邀请,来到巴厘岛参与以 AggIsland 为主题的大会活动并发表演讲,详细介绍 OKLink 为 Polygon 所带来的包括多个浏览器和数据解析等方面的成果,并与 Polygon 一起,对…

深入解析BGP:互联网路由协议的全貌与应用

BGP(Border Gateway Protocol)是互联网上用于在自治系统(AS)之间交换路由信息的协议。它负责决定数据包的最佳路径以及路由的选择。以下是BGP的一些关键特点和工作原理的详细内容: BGP的特点: 1.路径矢量型…

stm32-PWM输出比较配置

配置流程 1.RCC开启时钟 2.时钟源选择和配置时基单元 这一部分上一篇有写,可以参考一下上一篇的内容,此处不多赘述了。 原文链接:https://blog.csdn.net/m0_74246768/article/details/139048136 3.配置输出比较单…

Ubuntu server 24 源码安装Quagga 支持动态路由协议ospf bgp

1 下载:GitHub - Quagga/quagga: Quagga Tracking repository - Master is at http://git.savannah.gnu.org/cgit/quagga.git 2 安装 #安装依赖包 sudo apt install gcc make libreadline-dev pkg-config #解压 tar zxvf quagga-1.2.4.tar.gz cd quagga-1.2.4/sudo ./co…

Spring Boot 项目统一异常处理

在 Spring Boot 项目开发中,异常处理是一个非常重要的环节。良好的异常处理不仅能提高应用的健壮性,还能提升用户体验。本文将介绍如何在 Spring Boot 项目中实现统一异常处理。 统一异常处理有以下几个优点: 提高代码可维护性:…

Linux系统之GoAccess实时Web日志分析工具的基本使用

Linux系统之GoAccess实时Web日志分析工具的基本使用 一、GoAccess介绍1.1 GoAccess简介1.2 GoAccess功能1.3 Web日志格式 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本3.3 检查系统镜像源3.4 更新软件列表…

夏老师小课堂(7) 免费撸Harmony0S应用开发者高级认证

点击上方 “机械电气电机杂谈 ” → 点击右上角“...” → 点选“设为星标 ★”,为加上机械电气电机杂谈星标,以后找夏老师就方便啦!你的星标就是我更新动力,星标越多,更新越快,干货越多! 关注…

24年湖南教资认定即将开始,别被照片卡审!

24年湖南教资认定即将开始,别被照片卡审!

springboot vue 开源 会员收银系统 (4) 门店模块开发

前言 完整版演示 前面我们对会员系统 springboot vue 开源 会员收银系统 (3) 会员管理的开发 实现了简单的会员添加 下面我们将从会员模块进行延伸 门店模块的开发 首先我们先分析一下常见门店的管理模式 常见的管理形式为总公司 - 区域管理(若干个门店&#xff…

简单操作一单利润500+,最新快手缺货赔付玩法,【找店教程+详细教程】

在如今快速变化的时代,寻找充满创新的收入来源已经成为了一种趋势。这不仅是为了实现财务的自由,更是为了在生活中拥有更多的选择权。一项革新的实践——利用手机进行快手缺货赔付单号的操作,已经成为许多人稳定“下车”的一个新途径。 据了…

英语学习笔记28——Where are they?

Where are they? 他们在哪里? 课文部分

【模拟面试问答】深入解析力扣163题:缺失的区间(线性扫描与双指针法详解)

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容,和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣! 推荐:数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航: LeetCode解锁100…

2024中青杯数学建模竞赛A题人工智能视域下养老辅助系统的构建思路代码论文分析

2024中青杯数学建模A题论文和代码已完成,代码为A题全部问题的代码,论文包括摘要、问题重述、问题分析、模型假设、符号说明、模型的建立和求解(问题1模型的建立和求解、问题2模型的建立和求解、问题3模型的建立和求解)、模型的评价…

浅谈网络通信(1)

文章目录 一、认识一些网络基础概念1.1、ip地址1.2、端口号1.3、协议1.4、协议分层1.5、协议分层的2种方式1.5.1、OSI七层模型1.5.2、TCP/IP五层模型[!]1.5.2.1、TCP/IP五层协议各层的含义及功能 二、网络中数据传输的基本流程——封装、分用2.1、封装2.2、分用2.2.1、5元组 三…

edge浏览器的网页复制

一些网页往往禁止复制粘贴,本文方法如下: 网址最前面加上 read: (此方法适用于Microsoft Edge 浏览器)在此网站网址前加上read:进入阅读器模式即可

AI办公自动化:用kimi批量将word文档部分文件名保存到Excel中

文件夹中有很多个word文档,现在只要英文部分的文件名,保存到一个Excel文件中。 可以在kimi中输入提示词: 你是一个Python编程专家,要完成一个编写Python脚本的任务,具体步骤如下: 打开文件夹:…

群晖nas连接(路由器设置)--群晖配置下文

目录 前言 本文目的与核心 一、打开IPV6和关闭防火墙 路由器后台 二、打开群晖查看是否有ipv6和记住ipv4地址 群晖后台界面 三、路由器设置端口转发 路由器后台 四、打开DDNS-GO的配置页面查看是否配置生效成功 群晖另一个配置后台 五、访问测试 前言 群晖配置上…

Thinkphp5内核宠物领养平台H5源码

源码介绍 Thinkphp5内核流浪猫流浪狗宠物领养平台H5源码 可封装APP,适合做猫狗宠物类的发信息发布,当然懂的修改一下,做其他信息发布也是可以的。 源码预览 源码下载 https://download.csdn.net/download/huayula/89361685