专栏算法工具链【参考算法】地平线轨迹预测参考算法 DenseTNT-V1.2.1

【参考算法】地平线轨迹预测参考算法 DenseTNT-V1.2.1

芯链情报局2023-07-31
256
4
多年来,轨迹预测在社交机器人和自动驾驶汽车等自主系统中受到了极大的关注。 它旨在根据过去的轨迹和周围环境(包括地形和障碍物等静态因素以及周围移动智能体等动态因素)预测车辆、行人和骑自行车者等道路使用者的未来轨迹。工具链提供的轨迹预测参考算法为DenseTNT,通过vectornet对global graph构建,根据预测goals来预测未来轨迹,本文为该算法的介绍文档,以下为正文。

该示例为参考算法,仅作为在J5上模型部署的设计参考,非量产算法

0 性能精度指标

model

minFDE6(float/int)

minADE6(float/int)

MR(float/int)

batch1 帧率(J5/双核)

batch30 帧率(J5/双核)

batch1 latency(ms)

batch30 latency(ms)

Densetnt

1.29/1.30

0.737/0.741

0.101/0.102

1826

83.56

1.73

27.28

模型参数:

dataset

input_shape

topk

nms_threshold

argoverse1

见下方注释1

150

2m

注1:
"traj_feat": (N, 9, 19, 32) #[batch,轨迹点特征,轨迹点,车辆数]
"lane_feat":(N, 11, 9, 64) #[batch,道路矢量特征,矢量点,道路信息]
"instance_mask":(N, 1, 1, 96)#轨迹和道路的mask
"goals_2d":(N, 2, 1, 2048)#goal点坐标
"goals_2d_mask": (N, 1, 1, 2048)#goal点mask

1 模型介绍

DenseTNT为goals-based的轨迹预测模型,通过密集的目标状态集预测目标未来的状态,根据目标生成轨迹,整体pipeline如下图:

  • Context encoder:使用vectorNet 对高精地图和车辆信息进行编码,得到要预测的车辆的全局特征;
  • Goal encoder&Select:对密集的目标做编码,生成goals目标得分,筛选topk个goals;
  • Complete trajectory:经过MLP层和自注意力操作,得到整条预测轨迹的state,最后经过NMS得到6个目标以及对应轨迹。

1.1 模型改动点

在网络结构上,相比于官方实现,我们做了如下更改:
  1. subgraph中去除了subgraphLayer每一层中的maxpool计算,便于性能上提升

  2. encoder部分取消LaneGCN

  3. 去除Lane scoring和dense goal的概率估计

  4. 后处理部分goal的筛选使用NMS

1.2 源码说明

Config文件

configs/traj_pred/densetnt_argoverse1.py 为 densetnt的配置文件,定义了模型结构、数据集加载,和整套训练流程,所需参数的说明在算子定义中会给出。配置文件主要内容包括:
#train数据处理
train_set=dict(
...
)
# 数据加载相关定义
dataloader=dict(...)
val_data_loader=dict(...)
#callbacks 定义
stat_callback = dict(...)
ckpt_callback = dict(...)
val_callback = dict(...)
#训练策略配置
float_trainer=dict(...)
calibration_trainer=dict(...)
qat_trainer=dict(...)
int_infer_trainer=dict(...)
#编译设置
compile_cfg = dict(...)
# predictor
float_predictor = dict(...)
calibration_predictor= dict(...)
qat_predictor = dict(...)
int_infer_predictor= dict(...)

注: 如果需要复现精度,config中的训练策略最好不要修改。否则可能会有意外的训练情况出现。

Vectornet

上图为VectorNet的架构,主要由以下构成:
  • Input vectors

  • polyline subgraphs

  • global interaction graph

Input vectors

对于graph中的轨迹或道路信息都可以通过不同的采样方式得到向量表示,如下图:

不同的实例具有不同的特征,以下为agent轨迹特征和道路特征说明。

traj_feat(轨迹特征):[batch,轨迹点特征,轨迹点,车辆数]-->[1, 9, 19, 32]
Argoverse中的轨迹信息包括几种不同的类型:
1.Agent (需要预测的车辆)
2.AV (自车)
3.Others 其他车辆
Argoverse 1中每个轨迹是一共时长2秒,50个轨迹点。前20个点作为输入值,预测Agent的后30个轨迹点。轨迹信息的纬度目前是9维.
lane_feat(道路特征):[batch,道路矢量特征,矢量点,道路信息]-->[1, 11, 9, 64]

道路信息是从HD map里获得,需下载相应的map。道路信息目前有11维:
polyline subgraphs
从几何意义看,车道线包含多个控制点,交叉路口是个多边形(带多个顶点),交通标志是一个点,所有这些都可被近似–多个顶点多边形。polyline为graph的多边形或一个node,为node的集合 p={v0,v1...vp},一条折线最后凝练出一个特征向量,每个特征向量就是一个node。

polyline是全连接的,同一条折线(polyline)上的节点构成一张子图(subgraph),为了减少计算量,对traj和lane分别做subgraph提取,计算分为以下3个步骤:

step1: 对lane和traj做 encoder,encoder为MLP操作:

step2:将多层subgraph(layer num=3)做maxpool:

注:此处去除subgraphLayer每一层中的maxpool堆叠,仅最后做maxpooling,便于性能上提升

step3:将traj_feat和lane_feat做Cat

global interaction graph

为提取不同实例之间的交互特征,构建全局特征,将各个polyline node 经过一个GNN计算得到全局特征。

gobal_graph为self-attention操作,具体实现见:hat/models/base_modules/attention.py的HorizonMultiheadAttention,相较于公版的注意力计算,为适配J5,地平线版本的实现为4维Attention,逻辑与公版相同。

Densetnt

Densetnt为轨迹预测部分,通过Vectornet提取的map特征(graph_feat),预测目标点得分,选取最优的K个终点(K=150),然后输出预测的轨迹。主要由以下步骤构成:

step1:构建Dense goals的SubGraph,融合graph_feat
step2:获取goal得分

获得得分的做法是self-attention,goals_feats为query,traj_feats为key和value。
经过softmax后获得输出goals的得分。

step3: 经过一个linear后根据Score,选择topk个goals,k=150
step4:生成预测轨迹

在经过2个MLP层后做自注意力,然后经过2层MLP得到整条预测轨迹的state

模型输出为goals_preds, traj_preds, pred_goals:

Loss

loss为goals loss和target loss。其中goals loss为nll_loss;target loss为smooth_l1_loss

Postprocess

使用NMS筛选goals,threshold=2,代码如下:

后处理输出为6个goals和6条轨迹traj(一条轨迹30个轨迹点)。

2 浮点模型训练

2.1 Before Start

2.1.1 环境部署

DenseTNT示例集成在OE包中,获取方式见:J5芯片算法工具链OpenExplorer 版本发布
lidar_multitask 示例位于ddk/samples/ai_toolchain/horizon_model_train_sample下,其结构为:

release_models获取路径见:scripts/configs/traj_pred/README.md

拉取docker环境

如需本地离线安装HAT,我们提供了训练环境的whl包,路径在ddk/package/host/ai_toolchain

2.1.2 数据下载

在开始训练模型之前,第一步是需要准备好数据集,可以在Argoverse 1 数据集 下载。 需要下载Argoverse Motion Forecasting v1.1的TrainingValidation 和Argoverse HD Maps的Miami and Pittsburgh

下载后,解压并按照如下方式组织文件夹结构:

2.1.3 数据打包

为了提升训练的速度,需要对数据信息文件做了一个打包,将其转换成lmdb格式的数据集。只需要运行下面的脚本,就可以成功实现格式转换:

--src-data-dir为解压后的argoverse数据集目录;
--target-data为打包后数据集的存储目录;
--num-worker为执行线程数
数据集打包命令执行完毕后会在target-data-dir下生成train_lmdb和val_lmdb,train_lmdb和val_lmdb就是打包之后的训练数据集和验证数据集为config中的data_rootdir。

2.1.4 config配置

在进行模型训练和验证之前,需要对configs文件中的部分参数进行配置,一般情况下,我们需要配置以下参数:

  • device_ids、batch_size_per_gpu:根据实际硬件配置进行device_ids和每个gpu的batchsize的配置;

  • ckpt_dir:浮点、calib、量化训练的权重路径配置,权重下载链接在config文件夹下的README中;

  • data_rootdir:打包的lmdb数据集路径配置;

  • map_path :map_files文件夹的路径配置;

  • infer_cfg:指定模型输入,在infer脚本使用时需配置

2.2 浮点模型训练

在configs/traj_pred/densetnt_argoverse1.py下配置参数,需要将相关硬件配置device_ids和权重路径ckpt_dir数据集路径data_rootdir配置修改后使用以下命令训练浮点模型:

2.3 浮点模型精度验证

通过指定训好的float_checkpoint_path,使用以下命令验证已经训练好的模型精度:

验证完成后,会在终端输出float模型在验证集上的检测精度。

3 模型量化和编译

模型上板前需要将模型编译为.hbm文件, 可以使用compile的工具用来将量化模型编译成可以上板运行的hbm文件,因此首先需要将浮点模型量化,地平线对DenseTNT模型的量化采用horizon_plugin框架,通过Calibration+QAT量化训练和转换最终获得定点模型。

3.1 Calibration

为加速QAT训练收敛和获得最优量化精度,建议在QAT之前做calibration,其过程为通过batchsize个样本初始化量化参数,为QAT的量化训练提供一个更好的初始化参数,和浮点训练的方式一样,将checkpoint_path指定为训好的浮点权重路径。通过运行下面的脚本就可以开启模型的Calibration:

3.2 Calibration 模型精度验证

calibration完成以后,可以使用以下命令验证经过calib后模型的精度:

验证完成后,会在终端输出calib模型在验证集上的检测精度。

3.3 量化模型训练

Calibration完成后,就可以加载calib权重开启模型的量化训练。 量化训练其实是在浮点训练基础上的finetue,具体配置信息在config的qat_trainer中定义。量化训练的时候,初始学习率设置为浮点训练的十分之一,训练的epoch次数也大大减少。和浮点训练的方式一样,将checkpoint_path指定为训好的calibration权重路径。
通过运行下面的脚本就可以开启模型的qat训练:

3.4 量化模型精度验证

量化模型的精度验证,只需要运行以下命令:

qat模型的精度验证对象为插入伪量化节点后的模型(float32);quantized模型的精度验证对象为定点模型(int8),验证的精度是最终的int8模型的真正精度,这两个精度应该是十分接近的。

3.5 仿真上板精度验证

除了上述模型验证之外,我们还提供和上板完全一致的精度验证方法,可以通过下面的方式完成:

3.6 量化模型编译

在训练完成之后,可以使用compile的工具用来将量化模型编译成可以上板运行的hbm文件,同时该工具也能预估在BPU上的运行性能,可以采用以下脚本:
opt为优化等级,取值范围为0~3,数字越大优化等级越高,运行时间越长;
compile_perf脚本将生成.html文件和.hbm文件(compile文件目录下),.html文件为BPU上的运行性能,.hbm文件为上板实测文件。

4 其他工具

4.1 结果可视化

如果你希望可以看到训练出来的模型对于DenseTNT单帧的检测效果,我们的tools文件夹下提供了预测及可视化的脚本,你只需要运行以下脚本即可:

注:
需在infer_cfg中配置模型输入
由于开发机配置不同,plt.show可能不会正常显像,可以在/usr/local/lib/python3.8/dist-packages/hat/visualize/argoverse.py通过plt.savefig在保存的路径中查看

可视化示例:

brown:lane;blue:agent_history_traj;green:av_traj;yellow:other_trajs;purple:labels;red:preds

5 板端部署

5.1 上板性能实测

使用hrt_model_exec perf工具将生成的.hbm文件上板做BPU性能实测,hrt_model_exec perf参数如下:
算法工具链
评论3
0/1000
  • wylmine
    Lv.1

    感谢分享~ 可否请问一下中间结果嘛, calib pred和qat pred的minADE6. 我还未下载数据,所以暂未复现,想先问问结果怎么样~ 感谢感谢

    2023-08-14
    0
    1
    • 颜值即正义回复wylmine:
      您好,请参考:
      calib:'minADE': 0.8233378726330707, 'minFDE': 1.4106424627679135, 'MR': 0.12058570198105081

      qat:'minADE': 0.7427025280962253, 'minFDE': 1.2990297500595855, 'MR': 0.09988853422505953

      2023-08-15
      0
  • CC.
    Lv.1

    感谢分享,想请问一下生成hbm后续的板端部署环节是否有相关教程或示例?

    2025-04-02
    0
    0
  • Huanghui
    Lv.5

    这个就是QAT的参考算法,板端部署请参考ai_benchmark!

    2025-05-04
    0
    0