1.centos7 安装显卡驱动、cuda、cudnn-CSDN博客
2.安装conda python库-CSDN博客
3.Cenots Swin-Transformer-Object-Detection环境配置-CSDN博客
步骤1:准备待训练的coco数据集
下载地址:https://download.csdn.net/download/malingyu/88519420
https://download.csdn.net/download/malingyu/88519411
说明:由于数据集比较大,分开两个资源下载
在项目跟目录,新建目录data/coco,将下载的资源直接放到文件夹中
复制test2017,分布为train2017、val2017。
步骤2:修改tools/tran.py文件
其中config添加上默认的路径
def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument('--config',default='../configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py', help='train config file path')
步骤3:修改文件configs/_base_/default_runtime.py
添加上下载好的模型路径。
步骤4.修改文件configs/_base_/dataset/coco_instance.py
补充好data_root,和后面的文件夹路径
dataset_type = 'CocoDataset'
data_root = '../data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
evaluation = dict(metric=['bbox', 'segm'])
步骤5:进入tools目录
执行python tran.py文件
运行成功,可以进行数据的训练。
报错问题:TypeError: FormatCode() got an unexpected keyword argument ‘verify’
原因:yapf版本过高
由0.40.2 切换成 0.40.1问题解决
pip install yapf==0.40.1