0.简介
github项目链接 论文链接
Argoverse 1.1验证集的预期性能是:
Models minADE minFDE MR
HiVT-64 0.69 1.03 0.10
HiVT-128 0.66 0.97 0.09
1. 拉取代码仓库
git clone https://github.com/ZikangZhou/HiVT.git
cd HiVT
2. 创建conda环境
conda create -n HiVT python=3.8
conda activate HiVT
conda install pytorch==1.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
conda install pytorch-geometric==1.7.2 -c rusty1s -c conda-forge
conda install pytorch-lightning==1.5.2 -c conda-forge
3. 下载Argoverse预测数据集
Argoverse Motion Forecasting Dataset v1.1
下载后解压成下面的形式
Argoverse1.1
├── train/
| └── data/
| ├── 1.csv
| ├── 2.csv
| ├── ...
└── val/
└── data/
├── 1.csv
├── 2.csv
4. 安装Argoverse API
如果需要安装在HiVT的conda环境里,忽略下面连接中创建conda环境的操作
Ubuntu Argoverse API安装
5. 训练
如果数据集在home路径(~/Argoverse1.1/)下,执行
python train.py --root ~/Argoverse1.1/ --embed_dim 64 //To train HiVT-64
python train.py --root ~/Argoverse1.1/ --embed_dim 128 //To train HiVT-128
6. 常见错误
6.1 RuntimeError(‘received %d items of ancdata’ %
训练过程中报该错误,是因为文件读取太多,修改文件读取限制
ulimit -n //查看读取文件限制数量
ulimit -n 65536 //修改读取文件限制数量为65536