目前运营的社交平台账号:
- CSDN 【雪天鱼】: 雪天鱼-CSDN博客
- 哔哩哔哩 【雪天鱼】: 雪天鱼个人主页-bilibili.com
可能后续有更新,也可能没有更新,谨慎参考
- V1.0 24-02-13 ViT 代码的基本训练, 预测推理脚本运行
1 学习目标
- 能用官方的 ViT 预训练模型在 imagenet1k 上进行预测推理 完成
- 在 ImageNet-1K 的完整验证集上验证下载的官方 ViT 预训练模型的准确率
未处理的问题:
- 官方的 ViT 预处理模型训练时的图片数据预处理方法是什么?
Github pytorch实现的 ViT 代码下载:deep-learning-for-image-processing/pytorch_classification/vision_transformer at master · WZMIAOMIAO/deep-learning-for-image-processing · GitHub
Note: 非官方仓库代码,但 vit_model.py 即ViT 模型定义代码是用的被 TIMM 采用的代码。
已经处理好的 ImageNet1K数据集网盘链接:
链接:https://pan.baidu.com/s/1sYMIwqkNldmqpaJqDK8lSQ?pwd=2024
提取码:2024
2 运行 flops.py
(不重要,可跳过)
先安装fvcore包: pip install fvcore
然后点击运行会出错,报错为:
ValueError: Invalid type <class 'numpy.int32'> for the flop count! Please use a wider type to avoid overflow.
点击红框中的位置进入到 jit_handles.py
文件中,修改 14~19行代码如下:
try:
from math import prod
except ImportError:
from numpy import prod as prodnp
def prod(x):
return int(prodnp(x))
然后再重新运行 flops.py 无报错。结果为:
Self-Attention FLOPs: 60129542144
# 中间有一些红色字体的 warnings
Multi-Head Attention FLOPs: 68719476736
3 训练—train.py
从 vit_model 中导入想要训练的 ViT版本, 把默认导入的 vit_base_patch16_224_in21k
给注释掉,确保加载的预训练权重和实例化的模型class一致。
from vit_model import vit_base_patch16_224 as create_model
运行脚本,默认训练10 epochs, 每轮都会将训练好的权重文件保存至 weights 目录下
模型有 327 MB
用tensorboard 打开 runs 目录下的训练log,如下图所示:
4 预测推理—predict.py
现在我们用训练好的模型进行预测推理,自己从数据集或者网上选择一张图作为输入,预测结果如下图所示:
5 在 ImageNet1K 数据集上进行预测推理
我们可以直接加载官方预训练模型在 ImageNet1K 数据集上进行预测推理,需要准备 imagenet 1k的类别索引 json文件,这里我们从github下载即可:
https://github.com/raghakot/keras-vis/blob/master/resources/imagenet_class_index.json
然后准备好部分的 imagenet1K 数据集作为输入的预测图片,最终效果如下图所示:
在进行 data_transform
预处理之后,输入图片数据的最大值为 1,最小值为 -0.97
6 其他未整理的学习资料
- pytorch实现的ViT的详细思路讲解: GitHub - FrancescoSaverioZuppichini/ViT: Implementing Vi(sion)T(transformer)
- 对应的中文翻译:Vision Transformer(ViT)PyTorch代码全解析(附图解)_vit代码-CSDN博客