Swin Transformer实战图像分类(Windows下,无需用到Conda,亲测有效)

目录

前言

一、从官网拿到源码,然后配置自己缺少的环境。

针对可能遇到的错误:

二、数据集获取与处理

2.1 数据集下载

2.2 数据集处理

三、下载预训练权重

四、修改部分参数配置

4.1 修改config.py

4.2 修改build.py

4.3 修改units.py

4.4 修改main.py

4.5 修改其他的地方

4.6 将最后结果以折线图的形式呈现出来

五、其他可修改的地方

六、运行代码


前言

关于Swin Transformer的讲解和实战,实际上网络上已经有很多了。不过有一些代码跑起来可能有一些问题(有一些确实有点问题,或者没头没尾的)。

最初的时候,我通过一些调研,参照网上的一些教程,跑的时候也遇到了一些问题,但是最后确实是成功了。下面我就详细地来讲述应该怎么做。

关于Swin Transformer的基础知识就不再赘述了。相信想到用Swin Transformer来实战的同学肯定已经多多少少对其有一定了解了。

在此,我说一下我的实战的思路:

从官网拿到代码,然后改改,换成自己的数据集,加载它的预训练权重,然后让代码跑起来。

如果你的coding能力确实比较强,那么你完全可以从官网上找到部分Swin Transformer的Model部分核心代码,然后数据处理部分、跑模型的部分都自己来写,这样做也完全OK。但是对能力要求较高,并且对模型的理解要求也比较高。比如,参考某B站Up主的视频,它的代码就是这样子的(在别人的文章里应该都能看的到,我就不重复了)。

而我们在这里就介绍傻瓜式的操作。

我的环境:

  • Win 10
  • Python 3.8
  • Pytorch/torchvision 1.13.1+cu116
  • NVIDIA GeForce RTX 3060(CUDA 11.7.102)
  • Pycharm Community;

OK,开始。

一、从官网拿到源码,然后配置自己缺少的环境。

论文: https://arxiv.org/abs/2103.14030

代码: https://github.com/microsoft/Swin-Transformer

注意下,这里的fused window process、还有apex等不安装也是可以跑通的。它们的安装不影响代码的运行。如果你最后对性能有很高的要求,那你再去下载,我们这里主要是学习,然后让它先跑起来。

你要最起码确保在data文件夹下、models文件夹下和最外层的所有py文件都没有依赖报错(就是导入包的报错)

就是如上图那些的一些包,你给它都下载好不报错就行了。或者你新搞个虚拟环境,然后重新装一下就行,怎么搞虚拟环境可以参考这篇文章【正在更新中...】。

针对可能遇到的错误:

注意,这个错误只是可能遇到,不是一定会遇到。它和你下载的Pytorch的版本有关系。并且你的分类数如果大于5,应该是不会报错的。

那我们需要做什么呢?就是你可能需要改一下你的accuracy函数。

说一下这个函数的入口在哪找,因为这个函数并不是Swin-transformer的函数,它是Pytorch内置的文件函数,所以它原本是只读的(只是有写保护,并非不可更改)。那么我们从哪里找呢?

找到main.py->函数validate,有一行

acc1, acc5 = accuracy(output, target, topk=(1, 5))

鼠标点击accuracy函数,然后按ctrl B就可以了(转到定义)

然后在metrics.py(这是个只读文件)里的这个函数accuracy,它应该是这个样子的(如下图):

但有些小伙伴该函数的第一行是这个样子的:

maxk = max(topk)

然后就导致如果你训练过程中,如果分类数比较少(比如二分类),那它就会报类似于“索引k超出范围”这样的错误。你把它这一行给改成上面截图的那个样子就行:

maxk = min(max(topk), output.size()[1])  //应该是这个样子的,否则二分类会报错

OK了家人们,现在我们环境配置就这样说完了。拿到项目和配置环境相信都是最基本的,没有什么好说的了哈。

二、数据集获取与处理

注意,我们是要用预训练权重去跑我们自己是数据集,所以不要傻乎乎的去下载ImgNet 1K,更不要傻乎乎地去下载ImgNet 22K哈哈哈,这些都是官方在一开始训练swin transformer的时候所用到的数据集,如果我们用预训练权重来去训练自己的训练集的话,是不需要下载这些东西的了。

我们这里,就以猫狗数据集为例来为大家说下数据集怎么整。

我们就以Kaggle猫狗大战数据集为例。

2.1 数据集下载

传送门:Dogs vs. Cats Redux: Kernels Edition | Kaggle​​​​​​

具体怎么下载,可以参考这篇文章:【正在更新中】

或者也可以用百度网盘分享的链接下载:【正在更新中】

简单来说,就是你需要先登录,同意它比赛的Rules,然后点击那个Download就可以下载了。数据集不是很大,所以直接从浏览器上下载感觉问题也不大。

2.2 数据集处理

OK,那么数据集下载完毕之后,我们要对数据集进行处理分类。

那个csv表格没有用的,它是给你提交的样例。我们需要拿的是test数据集合train数据集,解压缩。

我们的目标,就是把它整成这样的形式:

图①

简单来说,也就是把下载拿到的训练集分成训练集(train)和验证集(val),然后猫的目录下都存猫,狗的目录下都存狗。

怎么整呢?我的做法是,首先新建一个文件夹,命名为dataset,然后把从网上下载的train和test的压缩包都解压到该目录下。再新建一个main.py文件用于执行划分训练集和验证集的代码。那么,解压后的文件结构就是:

 |——dataset
      |——test
          |—— 0.png
          |—— 1.png
          |—— 2.png
          |——  ...
      |——train
          |——cat //这个是自己提前建好的空目录
          |——dog //这个也是提前建好的空目录
          |——cat.1.png
          |——cat.2.png
          |——cat.3.png
          |——...
          |——dog.1.png
          |——dog.2.png
          |——dog.3.png
          |——...
      |——val
          |——cat //这个是自己提前建好的空目录
          |——dog //这个也是提前建好的空目录
      |——main.py

然后在main.py中执行以下代码:

import os
import shutil
dir_train = '.\\train'
dir_val = '.\\val'
for file in os.listdir(dir_train):
    if file.startswith('cat') and file.endswith(".jpg"):
        num = int(file.split('.')[1])
        if num <= 9999:
            shutil.move(os.path.join(dir_train, file), os.path.join(dir_train, 'cat', file))
        else:
            shutil.move(os.path.join(dir_train, file), os.path.join(dir_val, 'cat', file))
    elif file.startswith('dog') and file.endswith(".jpg"):
        num = int(file.split('.')[1])
        if num <= 9999:
            shutil.move(os.path.join(dir_train, file), os.path.join(dir_train, 'dog', file))
        else:
            shutil.move(os.path.join(dir_train, file), os.path.join(dir_val, 'dog', file))

解释下这个代码是做什么的。这个代码就是把train目录文件下的编号是0-9999的猫和狗作为训练集(就是放到train目录下),然后把10000-12500的猫和狗放到val文件下。这样,我们在训练集,就总共有20000张图片,然后我们放了5000张图片留作验证使用。(当然你要是觉得这个比例不够好,可以自己再去调,该上面的两个9999就可以)

运行这段代码后,得到的文件夹结构就是像图①所示的那样了。

三、下载预训练权重

接下来,下载预训练配置权重。

从modelhub.md文件中,下载预训练权重。下载哪个都行,当然了,有些模型需要配置额外的环境。

我们这里就以最简单的Swin-T为例,我们下载的也是Image 1K的model的预训练权重。那么,下载下来的预训练权重的名称就是swin_tiny_patch4_window7_224.pth

OK,不管哪个,下载下来,然后放到swin-Transformer的根路径中,就像这样:

四、修改部分参数配置

由于我们是Windows操作系统,并且按照我们上面数据集来看,我们是二分类,所以我们需要修改一些参数配置。

4.1 修改config.py

_C.DATA.DATA_PATH = 'dataset'
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.RESUME ='swin_tiny_patch4_window7_224.pth'
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 2

上面这些参数是必然要改的。相信大家也能看懂这些参数是什么意思哈。

稍微解释下这个 (了解下即可)

_C.DATA.DATASET = 'imagenet' 

这个的意思,我的理解就是采用什么数据集来去训练的,它有俩选择,要么是“imagenet”,要么是"imagenet22K",对应下面build.py的截图中的if和elsif。

4.2 修改build.py

build.py文件在data文件夹里。

修改这个地方:

把原来的1000改成上图所示,或者直接改成2都行。

4.3 修改units.py

因为我们的预训练权重中最后是对1000类别进行分类的,而我们是二分类,所以我们需要在load_checkpoint这个函数中,把输出头的类别数给修改了。具体如下:

在上图的位置添加以下代码:

    if checkpoint['model']['head.weight'].shape[0] == 1000:
        checkpoint['model']['head.weight'] = torch.nn.Parameter(
            torch.nn.init.xavier_uniform_(torch.empty(config.MODEL.NUM_CLASSES, 768))
        )
        checkpoint['model']['head.bias'] = torch.nn.Parameter(torch.randn(config.MODEL.NUM_CLASSES))

4.4 修改main.py

将init_process_group函数修改如下:

torch.distributed.init_process_group(backend='gloo', init_method='env://', world_size=1, rank=0)

想要在windows环境下跑通代码,那么前面的backend得换成“gloo”。然后后面的init_method我的用'env://'是可以跑通的,有的小伙伴可能跑不通,我看网上也有用‘file://tmp/somefile’的(对我来说行不通),大家如果最终是因为这行代码而导致代码跑不起来,可以专门去网上看看解决方案。反正上述代码我的是跑起来了。

4.5 修改其他的地方

还有些其他奇奇怪怪的情况,在这里就不一一列举,上网一查都可以查到。比如

from torch._six import inf

已经不再用,而是直接从torch里导入,即

from torch import inf

即可。

但是这些问题不是每个人都会遇到,就比如我自己配置的时候并没有遇到这些问题,我给别人配置的时候就遇到了这些问题。大家自己如果遇到了再去查就行。

4.6 将最后结果以折线图的形式呈现出来

先说一下我们最后画出来的图是什么样子的。就是横坐标是迭代次数,纵坐标左边是每次Train的avg_loss(平均的loss)大小,右边是Test时每次的avg_acc@1(平均准确率)。

修改五个地方。

在util.py文件中修改一处:

也就是添加以下代码,就是新添加一个函数。

from matplotlib import pyplot as plt


def plot_curve(x, y1, y2, label1, label2):
    # 创建折线图
    fig, ax1 = plt.subplots()

    # 绘制第一条折线(范围在[0,1])
    ax1.plot(x, y1, 'b-', label=label1)
    ax1.set_xlabel('train_epoch')
    ax1.set_ylabel(label1, color='b')
    ax1.set_ylim([0, 8])
    ax2 = ax1.twinx()
    ax2.set_ylabel(label2, color='r')
    ax2.set_ylim([0, 1])
    ax2.plot(x, y2, 'r-', label=label2)
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    lines = lines1 + lines2
    labels = labels1 + labels2
    ax1.legend(lines, labels)
    plt.show()

然后在main.py中修改四处:

第一处:

可以在一开始的地方加上两个list用于存储每个epoch的avg_loss和avg_acc@1.

第二处:

train_one_epoch函数的最后增加一行:

list_train_loss.append(loss_meter.avg)

就是说每次训练完都把avg_loss存储进去。

第三处:

同理,在validate函数中也添加一行代码:

list_val_acc.append(acc1_meter.avg)

作用和上面同理。

第四处:

在main函数的最后,加上这两行:

    num_x = range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS)
    plot_curve(num_x, list_train_loss, list_val_acc, "train_loss", "val_acc@1")

它的作用是最后显示出图示来。

但需要注意下哈,我在这里画出来的图,默认只能够画出来你从断点开始到最后结束的那一段。就是你checkpoint之前的是画不出来的,所以尽量一次性训练完就好。

五、其他可修改的地方

再说说其他可以修改的地方:

如果你觉得它打印的频次太多了,你可以修改参数:

# Frequency to logging info
_C.PRINT_FREQ = 10

同样如果你觉得你训练次数过多或者过少,你可以修改参数:

_C.TRAIN.EPOCHS = 100

如果你不想把你自身训练的每一轮模型都保存,你可以修改参数:

# Frequency to save checkpoint
_C.SAVE_FREQ = 1

为什么修改这些参数大家可以从源码中去找哈。

六、运行代码

差不多都搞完了,我们将我们的代码跑起来:

如果不想用命令行,可以点击右上角的编辑配置:

然后在参数这一栏:

添加以下参数:

--cfg configs/swin/swin_tiny_patch4_window7_224.yaml  --local_rank 0  --batch-size=64 --data-path=data/dataset

解释下什么意思。

  • --cfg是参数配置文件,它在configs文件夹里,你添加的这个配置的文件名需要和你的预训练权重的名字是一样的。
  • --local_rank设置成0就行。
  • --data-path后面需要添加你数据文件的路径。我就把它放到了data/dataset文件下。
  • --batch-size看你GPU有多大了,如果不大,可以设置成16,如果比较大,也可以设置成64、128等。

然后就,跑起来了。

好啦,如果有什么问题,欢迎留言。同时,如果觉得我写的文章对你有帮助,那就点个收藏,或者赞和关注吧,我将会持续带来优质的分享。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/211297.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode的几道题

一、捡石头 292 思路就是&#xff1a; 谁面对4块石头的时候&#xff0c;谁就输&#xff08;因为每次就是1-3块石头&#xff0c;如果剩下4块石头&#xff0c;你怎么拿&#xff0c;我都能把剩下的拿走&#xff0c;所以你就要想尽办法让对面面对4块石头的倍数&#xff0c; 比如有…

python常用函数

1.len函数求字符串长度 例如 2.input函数为输入 input里边可以是任意类型的数据 但是它返回的值是一个字符串(即现在只能做出打印那些操作) 想做出其他操作的话,要强制类型转换 例,用str转换为字符串(类似的还有float),字符串可以互相拼接 所以要记得用了input函数后要强制…

十六进制数列求和

高精度数组的集大成 做的时候在和高中同学叙叙旧&#xff0c;差点寄掉 代码如下&#xff1a; #include<stdio.h> void expand(int len); const char hexadecimal[17] "0123456789ABCDEF"; int result[20], mid[20], l_result[100];int main(void) {char tm…

深度学习常见回归分支算法逐步分析,各种回归之间的优缺点,适用场景,举例演示

文章目录 1、线性回归&#xff08;Linear Regression&#xff09;1.1 优点1.2 缺点1.3 适用场景1.4 图例说明 2、多项式回归&#xff08;Polynomial Regression&#xff09;2.1 优点2.2 缺点2.3 适用场景2.4 图例说明 3、决策树回归&#xff08;Decision Tree Regression&#…

疫苗接种(链表练习)

很明显&#xff0c;数组也可以做&#xff0c;但是我想练习链表 这道题我上交的时候&#xff0c;同一份代码&#xff0c;三个编译器&#xff0c;三个成绩&#xff0c;有点搞心态 代码如下&#xff1a; #include<stdio.h> #include<math.h> #include<stdlib.h&…

线上CPU飙高问题排查!

https://v.douyin.com/iRTqH5ug/ linux top命令 top 命令是 Linux 下一个强大的实用程序&#xff0c;提供了系统资源使用情况的动态、实时概览。它显示了当前正在运行的进程信息&#xff0c;以及有关系统性能和资源利用情况的信息。 以下是 top 命令提供的关键信息的简要概述…

面试数据库八股文十问十答第一期

面试数据库八股文十问十答第一期 作者&#xff1a;程序员小白条&#xff0c;个人博客 1.MySQL常见索引、 MySQL常见索引有: 主键索引、唯一索引、普通索引、全文索引、组合索引(最左前缀)主键索引特点&#xff1a;唯一性&#xff0c;非空&#xff0c;自增&#xff08;如果使用…

Linux中的UDEV机制与守护进程

Linux中的UDEV守护进程 udev简介守护进程守护进程概念守护进程程序设计守护进程的应用守护进程和后台进程的区别 UDEV的配置文件自动挂载U盘 udev简介 udev是一个设备管理工具&#xff0c;udev以守护进程的形式运行&#xff0c;通过侦听内核发出来的uevent来管理/dev目录下的设…

cnpm 安装后无法使用怎么办?

问题的原因 cnpm 安装成功&#xff0c;但是却无法使用&#xff0c;一般分为两种情况&#xff0c;一种是提示无法执行命令&#xff0c;另一种是可以执行但是执行时报错&#xff0c;下面分别说明遇到这两种情况的解决方案。 解决方案 问题一&#xff1a;无法执行相关命令 首先…

零基础打靶—CTF4靶场

一、打靶的主要五大步骤 1.确定目标&#xff1a;在所有的靶场中&#xff0c;确定目标就是使用nmap进行ip扫描&#xff0c;确定ip即为目标&#xff0c;其他实战中确定目标的方式包括nmap进行扫描&#xff0c;但不局限于这个nmap。 2.常见的信息收集&#xff1a;比如平常挖洞使用…

jionlp :一款超级强大的Python 神器!轻松提取地址中的省、市、县

在日常数据处理中&#xff0c;如果你需要从一个完整的地址中提取出省、市、县三级地名&#xff0c;或者乡镇、村、社区两级详细地名&#xff0c;你可以使用一个第三方库来实现快速解析。在使用之前&#xff0c;你需要先安装这个库。 pip install jionlp -i https://pypi.douba…

LeetCode - 965. 单值二叉树(C语言,二叉树,配图)

二叉树每个节点都具有相同的值&#xff0c;我们就可以比较每个树的根节点与左右两个孩子节点的值是否相同&#xff0c;如果不同返回false&#xff0c;否则&#xff0c;返回true。 如果是叶子节点&#xff0c;不存在还孩子节点&#xff0c;则这个叶子节点为根的树是单值二叉树。…

【算法通关村】链表基础经典问题解析

【算法通关村】链表基础&经典问题解析 一.什么是链表 链表是一种通过指针将多个节点串联在一起的线性结构&#xff0c;每一个节点&#xff08;结点&#xff09;都由两部分组成&#xff0c;一个是数据域&#xff08;用来存储数据&#xff09;&#xff0c;一个是指针域&…

每日一练:冒泡排序

1. 概述 冒泡排序&#xff08;Bubble Sort&#xff09;也是一种简单直观的排序算法。它重复地走访过要排序的数列&#xff0c;一次比较两个元素&#xff0c;如果他们的顺序错误就把他们交换过来。走访数列的工作是重复地进行直到没有再需要交换&#xff0c;也就是说该数列已经排…

【集合篇】Java集合概述

Java 集合概述 集合与容器 容器&#xff08;Container&#xff09;是一个更广泛的术语&#xff0c;用于表示可以容纳、组织和管理其他对象的对象。它是一个更高层次的概念&#xff0c;包括集合&#xff08;Collection&#xff09;在内。集合&#xff08;Collection&#xff0…

CSS 选择器优先级,!important 也会被覆盖?

目录 1&#xff0c;重要性2&#xff0c;专用性3&#xff0c;源代码顺序 CSS 属性值的计算过程中。其中第2步层叠冲突只是简单说明了下&#xff0c;这篇文章来详细介绍。 层叠冲突更广泛的被称为 CSS选择器优先级计算。 为什么叫层叠冲突&#xff0c;可以理解为 CSS 是 Cascadi…

HarmonyOS开发工具安装

目录 下载与安装DevEco Studio DevEco Studio下载官网&#xff0c;点击下载 下载完成后&#xff0c;双击下载的“deveco-studio-xxxx.exe” 进入DevEco Studio安装向导 选择安装路径 如下安装选项界面勾选DevEco Studio后&#xff0c;单击“Next” 点击Install 安装完…

什么是Daily Scrum?

Daily Scrum&#xff08;每日站会&#xff09;&#xff0c;Scrum Master要确保这个会在每天都会开。这个会的目的就是检查正在做的东西和方式是否有利于完成Sprint目的&#xff0c;并及时做出必要的调整。 每日站会一般只开15分钟&#xff0c;为了让事情更简单些&#xff0c;这…

Python遥感开发之批量拼接

Python遥感开发之批量拼接 1 遥感图像无交错的批量拼接2 遥感图像有交错的批量拼接 前言&#xff1a;主要借助python实现遥感影像的批量拼接&#xff0c;遥感影像的批量拼接主要分为两种情况&#xff0c;一种是遥感图像无交错&#xff0c;另一种情况是遥感图像相互有交错。具体…

【送书活动三期】解决docker服务假死问题

工作中使用docker-compose部署容器&#xff0c;有时候会出现使用docker-compose stop或docker-compose down命令想停掉容器&#xff0c;但是依然无法停止或者一直卡顿在停止中的阶段&#xff0c;这种问题很让人头疼啊&#xff01; 目录 问题描述问题排查问题解决终极杀招-最粗暴…