博客專(zhuān)欄

EEPW首頁(yè) > 博客 > 地平線 3D 目標(biāo)檢測(cè) Bevformer 參考算法 V2.0

地平線 3D 目標(biāo)檢測(cè) Bevformer 參考算法 V2.0

發(fā)布人:地平線開(kāi)發(fā)者 時(shí)間:2025-02-08 來(lái)源:工程師 發(fā)布文章

該示例為參考算法,僅作為在 征程 6 上模型部署的設(shè)計(jì)參考,非量產(chǎn)算法

簡(jiǎn)介

BEVFormer 是當(dāng)前熱門(mén)的自動(dòng)駕駛系統(tǒng)中的 3D 視覺(jué)感知任務(wù)模型。BEVFormer 是一個(gè)端到端的框架,BEVFormer 可以直接從原始圖像數(shù)據(jù)生成 BEV 特征,無(wú)需依賴(lài)于傳統(tǒng)的圖像處理流程。它通過(guò)利用 Transformer 架構(gòu)和注意力機(jī)制,有效地從多攝像頭圖像中學(xué)習(xí)生成高質(zhì)量的鳥(niǎo)瞰圖(Bird's-Eye-View, BEV)特征表示。相較于其他的 BEV 轉(zhuǎn)換方式:

  1. 時(shí)空注意力機(jī)制:模型結(jié)合了空間交叉注意力(Spatial Cross-Attention, SCA)和時(shí)間自注意力(Temporal Self-Attention, TSA),使網(wǎng)絡(luò)能夠同時(shí)考慮空間和時(shí)間維度上的信息。融合歷史 bev 特征來(lái)提升預(yù)設(shè)的 BEV 空間中的 query 的自學(xué)能力,得到 bev 特征。

  2. Deformable attn:通過(guò)對(duì)每個(gè)目標(biāo)生成幾個(gè)采樣點(diǎn)和采樣點(diǎn)的 offset 來(lái)提取采樣點(diǎn)周?chē)闹匾卣鳎粗魂P(guān)注和目標(biāo)相關(guān)的特征,減少計(jì)算量。

  3. transformer 架構(gòu):能夠有效捕捉序列中的長(zhǎng)期依賴(lài)關(guān)系,適用于處理圖像序列。

性能精度指標(biāo)

模型參數(shù):

圖片


性能精度表現(xiàn):

image.png

模型介紹

圖片

·公版 BEVFormer 模型主要可以分為以下幾個(gè)關(guān)鍵部分:

  1. Backbone 網(wǎng)絡(luò):用于從多視角攝像頭圖像中提取特征,本文為 tiny 版本,因此為 ResNet50。

  2. 時(shí)空特征提取:BEVFormer 通過(guò)引入時(shí)間和空間特征來(lái)學(xué)習(xí) BEV 特征。具體來(lái)說(shuō),模型包括:

  3. Temporal Self-Attention(時(shí)間自注意力):利用前一時(shí)刻的 BEV 特征作為歷史特征,通過(guò)自注意力機(jī)制來(lái)計(jì)算當(dāng)前時(shí)刻的 BEV 特征。

  4. Spatial Cross-Attention(空間交叉注意力):進(jìn)行空間特征注意力,融合多視角圖像特征。

  5. Deformable Attention(可變形注意力):BEVFormer 使用可變形注意力機(jī)制來(lái)加速運(yùn)算,提高模型對(duì)不同視角圖像特征的適應(yīng)性。

  6. BEV 特征生成:通過(guò)時(shí)空特征的融合,完成環(huán)視圖像特征向 BEV 特征的建模。

  7. Decoder:設(shè)計(jì)用于 3D 物體檢測(cè)的端到端網(wǎng)絡(luò)結(jié)構(gòu),基于 2D 檢測(cè)器 Deformable DETR 進(jìn)行改進(jìn),以適應(yīng) 3D 空間的檢測(cè)任務(wù)。

地平線部署說(shuō)明

公版 bevformer 在 征程 6 上部署相比于 征程 5 來(lái)說(shuō)更簡(jiǎn)單了,需要考慮的因素更少。征程 6 對(duì)非 4 維的支持可以和 4 維的同等效率,因此 征程 6 支持公版的注意力實(shí)現(xiàn),不再限制維度,因此無(wú)需對(duì)維度做 Reshape,可直接支持公版寫(xiě)法。但需注意的是公版的 bev_mask 會(huì)導(dǎo)致動(dòng)態(tài) shape。征程 6 不支持動(dòng)態(tài)輸入,因此 bev_mask 無(wú)法使用。在精度上,我們修復(fù)了公版的 bug 已獲得了精度上的提升,同時(shí)通過(guò)對(duì)關(guān)鍵層做 int16 的量化精度配置以保障 1%以?xún)?nèi)的量化精度損失。

下面將部署優(yōu)化對(duì)應(yīng)的改動(dòng)點(diǎn)以及量化配置依次說(shuō)明。

性能優(yōu)化

改動(dòng)點(diǎn) 1:

將 attention 層的 mean 替換為 conv 計(jì)算,使性能上獲得提升。

/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/attention.py

self.query_reduce_mean = nn.Conv2d(
   self.num_bev_queue * self.reduce_align_num,
   self.reduce_align_num,
   1,
   bias=False,)

# init query_reduce_mean weight
query_reduce_mean_weight = torch.zeros(
   self.query_reduce_mean.weight.size(),      
   dtype=self.query_reduce_mean.weight.dtype,
)
for i in range(self.reduce_align_num):
   for j in range(self.num_bev_queue):
       query_reduce_mean_weight[i, j * self.reduce_align_num + i] = (
           1 / self.num_bev_queue
       )
self.query_reduce_mean.weight = torch.nn.Parameter(
   query_reduce_mean_weight, requires_grad=False
)

改動(dòng)點(diǎn) 2:

公版中,在 Encoder 的空間融合模塊,會(huì)根據(jù) bev_mask 計(jì)算有效的 query 和 reference_points,輸出 queries_rebatch 和 reference_points_rebatch,作用為減少交互的數(shù)據(jù)量,提升模型運(yùn)行性能。對(duì)于稀疏的 query 做 crossattn 后再將 query 放回到 bev_feature 中。

以上提取稀疏 query 步驟的主要算子為 gather,放回 bev_feature 步驟的主要算子為 scatter。由于工具鏈對(duì)這兩個(gè)算子暫未支持(gather 算子 930 已支持)而且 bev_mask 為動(dòng)態(tài)的,為了提升模型的運(yùn)行性能,工具鏈提供了 gridsample 算子的替換方式,index 計(jì)算只與內(nèi)外參有關(guān),因此作為前處理,將計(jì)算好的 index 作為模型輸入即可。

gather

gather 為根據(jù) bevmask 來(lái)提取稀疏 query,降低 cross attn 的數(shù)據(jù)量,提升運(yùn)行效率。

代碼路徑:<code>/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/<span style="caret-color: #000000; color: #000000; font-family: monospace; font-size: medium; font-style: normal; font-variant-caps: normal; font-weight: 400; letter-spacing: normal; orphans: auto; text-align: start; text-indent: 0px; text-transform: none; white-space: normal; widows: auto; word-spacing: 0px; -webkit-text-stroke-width: 0px; background-color: #e8e8e8; text-decoration: none; display: inline !important; float: none;">view_transformer.py</span>

        reference_points_cam = torch.clamp(
           reference_points_cam, min=-2.1, max=2.1
       )
       reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
       bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
       bev_mask_ori = bev_mask.clone()
       max_len = self.virtual_bev_h * self.virtual_bev_w
       queries_rebatch_grid = reference_points_cam.new_zeros(
           [B * self.numcam, self.virtual_bev_h, self.virtual_bev_w, 2]
       )
       for camera_idx, mask_per_img_bs in enumerate(bev_mask):
           for bs_id, mask_per_img in enumerate(mask_per_img_bs):
               temp_grid = (
                   torch.zeros(
                       (max_len, 2),
                       device=queries_rebatch_grid.device,
                       dtype=torch.float32,
                   )
                   - 1.5
               )
               index_query_per_img = (
                   mask_per_img.sum(-1).nonzero().squeeze(-1)
               )
               num_bev_points = index_query_per_img.shape[0]
               camera_idx_tensor_x = index_query_per_img % self.bev_w
               camera_idx_tensor_y = index_query_per_img // self.bev_w
               index_grid = torch.stack(
                   [
                       camera_idx_tensor_x / (self.bev_w - 1),
                       camera_idx_tensor_y / (self.bev_h - 1),
                   ],
                   dim=-1,
               )
               index_grid = index_grid * 2 - 1
               temp_grid[:num_bev_points] = index_grid
               temp_grid = temp_grid.reshape(
                   self.virtual_bev_h, self.virtual_bev_w, 2
               )
               queries_rebatch_grid[
                   bs_id * self.numcam + camera_idx
               ] = temp_grid
       reference_points_rebatch = (
           reference_points_cam.flatten(-2)
           .permute(1, 0, 3, 2)
           .flatten(0, 1)
           .reshape(B * self.numcam, D * 2, self.bev_h, self.bev_w)
       )
       reference_points_rebatch = (
           F.grid_sample(
               reference_points_rebatch,
               queries_rebatch_grid,
               mode="nearest",
               align_corners=True,
           )
           .flatten(-2)
           .permute(0, 2, 1)
           .reshape(B * self.numcam, max_len, D, 2)
       )

query_rebatch

        queries_rebatch = (
           query.unsqueeze(1)
           .repeat(1, self.num_cams, 1, 1)
           .reshape(
               bs * self.num_cams, self.bev_h, self.bev_w, self.embed_dims
           )
           .permute(0, 3, 1, 2)
       )
       queries_rebatch = F.grid_sample(
           queries_rebatch,
           queries_rebatch_grid,
           mode="nearest",
           align_corners=True,
       )
       queries_rebatch = queries_rebatch.flatten(-2).permute(0, 2, 1)
       reference_points_rebatch = reference_points_rebatch.flatten(
           -2
       ).unsqueeze(-2)

scatter

scatter 操作對(duì)經(jīng)過(guò) deformable_attention 后的 query 放入到 bevfeature 中,然后求平均。

代碼路徑為:/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/attention.py

        slots = self.restore_outputs(
           restore_bev_grid,
           queries_out,
           bev_pillar_counts,
           bs,
           queries_rebatch_grid,
       )
   def restore_outputs(
       self,
       restore_bev_grid: Tensor,
       queries_out: Tensor,
       counts: Tensor,
       bs: int,
       queries_rebatch_grid: Tensor,
   ):
       """Restore outputs to bev feature."""
       queries_out = queries_out.reshape(
           bs, self.num_cams, self.embed_dims, -1
       )
       queries_out = queries_out.permute(0, 2, 1, 3)
       queries_out = queries_out.reshape(
           bs,
           self.embed_dims,
           self.num_cams * queries_rebatch_grid.shape[1],
           queries_rebatch_grid.shape[2],
       )
       bev_queries = F.grid_sample(
           queries_out, restore_bev_grid, mode="nearest", align_corners=True
       )
       bev_queries = bev_queries.reshape(bs, -1, self.bev_h, self.bev_w)
       slots = self.query_reduce_sum(bev_queries).flatten(-2).permute(0, 2, 1)
       slots = self.mul_pillarweight.mul(slots, counts)
       return slots

其中 restore_bev_grid,根據(jù) bevmask 反算回 bev_feature 的位置:

        restore_bev_grid = (
           reference_points_cam.new_zeros(
               B, self.max_camoverlap_num * self.bev_h, self.bev_w, 2
           )
           - 1.5
       )
       for bs_id, bev_mask_ in enumerate(bev_mask):
           bev_pillar_num_map = torch.zeros(
               (self.bev_h, self.bev_w), device=bev_mask_.device
           )
           count = bev_mask_.sum(-1) > 0
           camera_idxs, bev_pillar_idxs = torch.where(count)
           camera_idx_offset = 0
           for cam_id in range(self.numcam):
               camera_idx = torch.where(camera_idxs == cam_id)
               bev_pillar_idx_cam = bev_pillar_idxs[camera_idx[0]]
               num_camera_idx = len(camera_idx[0])
               camera_idx_tmp = camera_idx[0] - camera_idx_offset
               camare_tmp_idx_x = camera_idx_tmp % self.virtual_bev_w
               camare_tmp_idx_y = camera_idx_tmp // self.virtual_bev_w
               grid_x = camare_tmp_idx_x
               grid_y = cam_id * self.virtual_bev_h + camare_tmp_idx_y
               bev_pillar_idx_cam_x = bev_pillar_idx_cam % self.bev_w
               bev_pillar_idx_cam_y = bev_pillar_idx_cam // self.bev_w
               bev_pillar_num_map_tmp = bev_pillar_num_map[
                   bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
               ]
               grid_h = (
                   bev_pillar_num_map_tmp * self.bev_h + bev_pillar_idx_cam_y
               ).to(torch.int64)
               grid_w = (bev_pillar_idx_cam_x).to(torch.int64)
               restore_bev_grid[bs_id, grid_h, grid_w, 0] = grid_x / (
                   self.virtual_bev_w - 1
               )
               restore_bev_grid[bs_id, grid_h, grid_w, 1] = grid_y / (
                   self.numcam * self.virtual_bev_h - 1
               )
               bev_pillar_num_map[
                   bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
               ] = (
                   bev_pillar_num_map[
                       bev_pillar_idx_cam_y, bev_pillar_idx_cam_x
                   ]
                   + 1
               )
               camera_idx_offset = camera_idx_offset + num_camera_idx
       restore_bev_grid = restore_bev_grid * 2 - 1
精度優(yōu)化浮點(diǎn)精度

改動(dòng)點(diǎn) 3:

公版通過(guò) can_bus 初始化 ref 來(lái)做時(shí)序融合,然而這個(gè)時(shí)候 bev feat 并沒(méi)有對(duì)齊,在 attention 計(jì)算時(shí)不能簡(jiǎn)單的 concat 起來(lái)。因此我們換了一種時(shí)序?qū)R的方式,通過(guò)前后兩幀的 ego2global 坐標(biāo)系轉(zhuǎn)換矩陣將當(dāng)前幀的 bev 特征和上一幀對(duì)齊,此時(shí) ref 都是一樣的。(非 征程 6 不支持,為公版 bug),精度上獲得提升。

/usr/local/lib/python3.10/dist-packages/hat/models/task_modules/bevformer/view_transformer.py` `get_prev_bev` `get_fusion_ref
pre_scene = prev_meta["scene_token"]
for i in range(bs):
   if pre_scene[i] != cur_meta["meta"][i]["scene"]:
       prev_bev[i] = torch.zeros(
           (self.bev_h * self.bev_w, self.embed_dims),
           dtype=torch.float32,
           device=device,
       )
##公版:
shift_ref_2d = ref_2d.clone()
shift_ref_2d += shift[:, None, None, :]
bs, len_bev, num_bev_level, _ = ref_2d.shape
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
   bs*2, len_bev, num_bev_level, 2)
##地平線版本
shift_ref_2d = ref_2d.clone()
bs, len_bev, num_bev_level, _ = ref_2d.shape
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
   bs * 2, len_bev, num_bev_level, 2
)


改動(dòng)點(diǎn) 4:

修復(fù)了個(gè) tsa 公版的 batchsize 不等于 1 的 bug。BEVFormer/projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py at master · fundament

量化精度

為量化精度保證,我們將以下的算子配置為 int16 或 int32 輸出:

view_transformer:輸入節(jié)點(diǎn)做 int16 量化:

        int16_models = [
           self.quant_hybird_ref_2d,
           self.quant_norm_coords,
           self.quant_restore_bev_grid,
           self.quant_reference_points_rebatch,
           self.quant_queries_rebatch_grid,
       ]
       for m in int16_models:
           m.qconfig = qconfig_manager.get_qconfig(
               activation_qat_qkwargs={"dtype": qint16},
               activation_calibration_qkwargs={
                   "dtype": qint16,
               },
               activation_calibration_observer="mix",
           )

attention 層:最后兩個(gè) conv 和 add 開(kāi)啟 int16

    def set_qconfig(self) -> None:
       """Set the quantization configuration."""
       from hat.utils import qconfig_manager
       int16_module = [
           self.output_proj,
           self.add_res,
       ]

decoder 層:cls_branches、reg_branches 的 conv 配置為 int32 輸出;sigmoid 和 reference_points 配置為 int16

    def set_qconfig(self) -> None:
        """Set the quantization configuration."""
        from hat.utils import qconfig_manager
        for _, m in enumerate(self.cls_branches):
            m[0].qconfig = qconfig_manager.get_qconfig(
                activation_qat_qkwargs={"dtype": qint16},
                activation_calibration_qkwargs={
                    "dtype": qint16,
                },
                activation_calibration_observer="mix",
            )
            m[3].qconfig = qconfig_manager.get_qconfig(
                activation_qat_qkwargs={"dtype": qint16},
                activation_calibration_qkwargs={
                    "dtype": qint16,
                },
                activation_calibration_observer="mix",
            )
            m[-1].qconfig = qconfig_manager.get_default_qat_out_qconfig()
        self.reg_branches[-1][
            -1
        ].qconfig = qconfig_manager.get_default_qat_out_qconfig()
        self.query_embedding.qconfig = None
        int16_module = [
            self.reference_points,
            self.sigmoid,
        ]
        for m in int16_module:
            m.qconfig = qconfig_manager.get_qconfig(
                activation_qat_qkwargs={"dtype": qint16},
                activation_calibration_qkwargs={
                    "dtype": qint16,
                },
                activation_calibration_observer="mix",
            )
總結(jié)與建議訓(xùn)練建議
  • 浮點(diǎn)和公版一致即可

  • qat 訓(xùn)練需要將 lr 降低,下降策略建議使用 StepDecayLrUpdater。

部署建議
  • 建議 bev size 的選擇考慮性能影響。征程 6 相比于 征程 5 帶寬增大,但仍需注意 bevsize 過(guò)大導(dǎo)致訪存時(shí)間過(guò)長(zhǎng)對(duì)性能的影響,建議考慮實(shí)際部署情況選擇合適的 bevsize 做性能驗(yàn)證。

  • 使用 bevmask 來(lái)提升運(yùn)行性能,可參考 4.1 章節(jié)使用 gridsample 替換不支持的 scatter。

  • 在注意力機(jī)制中存在一些 ElementWise 操作,對(duì)于導(dǎo)致性能瓶頸的可以考慮 conv 替換,對(duì)于造成量化風(fēng)險(xiǎn)的可以根據(jù)敏感度分析結(jié)果合理選擇更高的量化精度,以確保注意力機(jī)制的部署。


本文通過(guò)對(duì) Bevformer 在地平線征程 6 上量化部署的優(yōu)化,使得模型在該計(jì)算方案上用低于 1%的量化精度損失,得到 latency 為 45.74ms 的部署性能,同時(shí),通過(guò) Bevformer 的部署經(jīng)驗(yàn),可以推廣到其他模型部署優(yōu)化,例如包含 MSDA 模型結(jié)構(gòu)、transformer-based BEV 的部署。

附錄
  1. 論文:https://arxiv.org/pdf/2203.17270

  2. 公版代碼:https://github.com/fundamentalvision/BEVFormer


*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。




相關(guān)推薦

技術(shù)專(zhuān)區(qū)

關(guān)閉