一、背景
最近需要用Vision Transformer(ViT)完成图像分类任务,因此查到了WZMIAOMIAO的GitHub,里面有各种图像处理的方法。而图像处理的前期工作就是获取大量的数据集,用于训练模型参数,以准确识别或分类我们的目标图像。
因此,这里以下载花分类数据集为例,并使用python程序,自动将数据集分为训练集和测试集,原理是通用的,我们可以用此方法,制作我们自己的数据集,并自动将其分类。
二、环境配置
系统:Windows 11(为了方便,我并没有切换到ubuntu系统)
为成功运行程序,我是新建了一个conda环境,conda名称为Vit。
Anaconda3
python3.8
pycharm(IDE)
具体指令如下:
# 打开Anaconda Prompt
conda create -n Vit python=3.8
conda activate Vit
三、下载数据集并自动分为训练集和测试集
先下载deep-learning-for-image-processing整个项目,保存在E:\manipulator_programming\ViT
文件夹。
项目链接:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
然后根据链接https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
,下载花分类数据集。
花分类数据集使用教程:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/data_set
具体步骤为:
(1)在data_set文件夹下创建新文件夹"flower_data"
(2)点击链接下载花分类数据集 https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
(3)解压数据集到flower_data文件夹下
这一步一定要注意文件夹的层级结构,删除多余的文件,包括压缩文件,不然执行第4步脚步时容易报错。
├── data_set
├── flower_data
├── flower_photos
├──daisy
├──dandelion
├──roses
├──sunflowers
├──tulips
├──LICENSE.txt
├── README.md
└── split_data.py
小tip,如何在CSDN的HTML文档下输入空格字符:$ +空格$
,
(4)执行"split_data.py"脚本自动将数据集划分成训练集train和验证集val,完整代码如下:
import os
from shutil import copy, rmtree
import random
def mk_file(file_path: str):
if os.path.exists(file_path):
# 如果文件夹存在,则先删除原文件夹在重新创建
rmtree(file_path)
os.makedirs(file_path)
def main():
# 保证随机可复现
random.seed(0)
# 将数据集中10%的数据划分到验证集中
split_rate = 0.1
# 指向你解压后的flower_photos文件夹
cwd = os.getcwd()
data_root = os.path.join(cwd, "flower_data")
origin_flower_path = os.path.join(data_root, "flower_photos")
assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)
flower_class = [cla for cla in os.listdir(origin_flower_path)
if os.path.isdir(os.path.join(origin_flower_path, cla))]
# 建立保存训练集的文件夹
train_root = os.path.join(data_root, "train")
mk_file(train_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(train_root, cla))
# 建立保存验证集的文件夹
val_root = os.path.join(data_root, "val")
mk_file(val_root)
for cla in flower_class:
# 建立每个类别对应的文件夹
mk_file(os.path.join(val_root, cla))
for cla in flower_class:
cla_path = os.path.join(origin_flower_path, cla)
images = os.listdir(cla_path)
num = len(images)
# 随机采样验证集的索引
eval_index = random.sample(images, k=int(num*split_rate))
for index, image in enumerate(images):
if image in eval_index:
# 将分配至验证集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(val_root, cla)
copy(image_path, new_path)
else:
# 将分配至训练集中的文件复制到相应目录
image_path = os.path.join(cla_path, image)
new_path = os.path.join(train_root, cla)
copy(image_path, new_path)
print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="") # processing bar
print()
print("processing done!")
if __name__ == '__main__':
main()
我们可以根据这个框架,进行适当修改,将自己的数据集自动分为训练集和测试集。
至此,数据集中的10%被复制到val文件夹下,90%被复制到train文件夹下,完美!!!
结果如下:
分享一张花分类数据集中好看的tulips郁金香图片。