Line-by-Line Analysis of the Open-Source BEVFormer Algorithm (Part 2): Decoder and Det
Horizon Developer | 2024-09-11 11:48:24

Preface:

Plenty of material is available for understanding the overall BEVFormer framework, but detailed walkthroughs of the algorithm's code are scarce. The goal of this series is therefore to give readers a more concrete understanding of the algorithm by working through the implementation details and the tensor shape transformations involved.

In this series we parse the public BEVFormer codebase (the open-source algorithm) line by line, using the code to explain how BEVFormer works and to pin down its details, so that readers can build perception algorithms on top of this framework. At the end of the series, for Horizon users, we will also point out the modifications Horizon's reference algorithm makes on top of the open-source code and the reasoning behind them, as a reference during deployment.

The public codebase is well organized and builds its models through a registry, so the call relationships between modules are clearly laid out in the config files under configs/bevformer. We use bevformer_tiny.py as the example for this walkthrough. The Encoder part has already been published, see "Line-by-Line Analysis of the Open-Source BEVFormer Algorithm (Part 1): Encoder"; this article focuses on BEVFormer's Decoder and Det parts.
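
For reference, the decoder-related portion of bevformer_tiny.py looks roughly like the excerpt below. It is abridged and reconstructed here as a sketch of the structure rather than the verbatim file; _dim_ = 256 and _ffn_dim_ = 512 are the values we recall for the tiny config and should be checked against the actual file.

_dim_ = 256
_ffn_dim_ = _dim_ * 2
decoder = dict(
    type='DetectionTransformerDecoder',
    num_layers=6,
    return_intermediate=True,
    transformerlayers=dict(
        type='DetrTransformerDecoderLayer',
        attn_cfgs=[
            dict(type='MultiheadAttention',           # self-attention among the object queries (section 2.1)
                 embed_dims=_dim_, num_heads=8, dropout=0.1),
            dict(type='CustomMSDeformableAttention',  # cross-attention into bev_embed (section 2.2)
                 embed_dims=_dim_, num_levels=1),
        ],
        feedforward_channels=_ffn_dim_,
        ffn_dropout=0.1,
        operation_order=('self_attn', 'norm', 'cross_attn',
                         'norm', 'ffn', 'norm')))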

The parsing and explanation of the code are carried mainly by the code comments.

1 PerceptionTransformer

Functions:

  • Passes the bev_embed produced by the encoder into the decoder;

  • Splits the query_embedding defined in BEVFormer along the channel dimension into query_pos and query with equal channel counts, and passes them into the decoder;

  • Feeds query_pos through the linear layer reference_points to generate reference_points and passes them into the decoder; inside the decoder's CustomMSDeformableAttention these reference_points serve as the base sampling points for fusing bev_embed, playing a role similar to Region Proposals in two-stage object detection (see the sketch right after this list);

  • Returns inter_states and inter_references, which the cls_branches and reg_branches use to produce object classes and bboxes.
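
The following minimal sketch shows where object_query_embed comes from and how the initial reference_points are produced from query_pos. It assumes bevformer_tiny's 900 queries and embed_dims=256; the standalone variable names here are illustrative, not the original attribute names.

import torch
import torch.nn as nn

embed_dims, num_query, bs = 256, 900, 1
# The head's query_embedding stores query_pos and query concatenated along the channel dim.
query_embedding = nn.Embedding(num_query, embed_dims * 2)        # weight: [900, 512]
reference_points_layer = nn.Linear(embed_dims, 3)                # maps query_pos -> (x, y, z)

object_query_embed = query_embedding.weight                      # [900, 512]
query_pos, query = torch.split(object_query_embed, embed_dims, dim=1)  # 2 x [900, 256]
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)            # [1, 900, 256]
query = query.unsqueeze(0).expand(bs, -1, -1)                    # [1, 900, 256]

reference_points = reference_points_layer(query_pos).sigmoid()   # [1, 900, 3], normalized to (0, 1)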

Analysis:

#See "Line-by-Line Analysis of the Open-Source BEVFormer Algorithm (Part 1): Encoder" for details; this produces bev_embed
#In the decoder, CustomMSDeformableAttention fuses bev_embed with the query
bev_embed = self.get_bev_features(
   mlvl_feats,
   bev_queries,
   bev_h,
   bev_w,
   grid_length=grid_length,
   bev_pos=bev_pos,
   prev_bev=prev_bev,
   **kwargs)  # bev_embed shape: bs, bev_h*bev_w, embed_dims

bs = mlvl_feats[0].size(0)
#object_query_embed:torch.Size([900, 512])
#query_pos:torch.Size([900, 256])
#query:torch.Size([900, 256])
query_pos, query = torch.split(
   object_query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
#reference_points:torch.Size([1, 900, 3])
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points

#query:torch.Size([900, 1, 256])
query = query.permute(1, 0, 2)
#query_pos:torch.Size([900, 1, 256])
query_pos = query_pos.permute(1, 0, 2)
#bev_embed:torch.Size([50*50, 1, 256])
bev_embed = bev_embed.permute(1, 0, 2)

#Enter the decoder module!
inter_states, inter_references = self.decoder(
   query=query,
   key=None,
   value=bev_embed,
   query_pos=query_pos,
   reference_points=reference_points,
   reg_branches=reg_branches,
   cls_branches=cls_branches,
   spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
   level_start_index=torch.tensor([0], device=query.device),
   **kwargs)
#Return inter_states, inter_references,
#later fed to the cls_branches and reg_branches to produce object classes and bboxes
inter_references_out = inter_references

return bev_embed, inter_states, init_reference_out, inter_references_out
2 DetectionTransformerDecoder

Functions:

  • Loops through 6 identical DetrTransformerDecoderLayer modules, each consisting of ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'); every layer outputs an output and refined reference_points (see the sketch right after this list);

  • After all 6 DetrTransformerDecoderLayer iterations are done, returns the output and reference_points of all 6 layers.
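
Before the parsed code, here is a minimal sketch of inverse_sigmoid and the per-layer refinement of the reference points. The helper follows the usual mmdet-style clamped implementation; treat it as an illustration rather than the verbatim source, and the random tensors are placeholders.

import torch

def inverse_sigmoid(x, eps=1e-5):
    # Maps values in (0, 1) back to logit space; clamping keeps the logs finite.
    x = x.clamp(min=0, max=1)
    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))

# Per-layer refinement: the reg branch predicts an offset in logit space, which is
# added to the previous reference point and squashed back into (0, 1).
reference_points = torch.rand(1, 900, 3)      # previous layer's reference_points
delta = torch.randn(1, 900, 2)                # tmp[..., :2] from reg_branches[lid]
new_xy = (delta + inverse_sigmoid(reference_points[..., :2])).sigmoid()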

Analysis:

#output:torch.Size([900, 1, 256])
output = query
intermediate = []
intermediate_reference_points = []
#Loop through 6 identical DetrTransformerDecoderLayer modules
for lid, layer in enumerate(self.layers):
   #reference_points_input:torch.Size([1, 900, 1, 2])
   #These reference_points serve in the decoder's CustomMSDeformableAttention as the base sampling points for fusing bev_embed
   reference_points_input = reference_points[..., :2].unsqueeze(
       2)  # BS NUM_QUERY NUM_LEVEL 2
   #Enter one DetrTransformerDecoderLayer
   output = layer(
       output,
       *args,
       reference_points=reference_points_input,
       key_padding_mask=key_padding_mask,
       **kwargs)
   #output:torch.Size([1, 900, 256])
   output = output.permute(1, 0, 2)

   if reg_branches is not None:
       #tmp:torch.Size([1, 900, 10])
       tmp = reg_branches[lid](output)

       assert reference_points.shape[-1] == 3
       #new_reference_points: torch.Size([1, 900, 3])
       new_reference_points = torch.zeros_like(reference_points)
       new_reference_points[..., :2] = tmp[
           ..., :2] + inverse_sigmoid(reference_points[..., :2])
       new_reference_points[..., 2:3] = tmp[
           ..., 4:5] + inverse_sigmoid(reference_points[..., 2:3])

       new_reference_points = new_reference_points.sigmoid()

       reference_points = new_reference_points.detach()

   output = output.permute(1, 0, 2)
   if self.return_intermediate:
       intermediate.append(output)
       intermediate_reference_points.append(reference_points)
       
#After all 6 DetrTransformerDecoderLayer iterations are done, return the output and reference_points of all 6 layers.
if self.return_intermediate:
   return torch.stack(intermediate), torch.stack(
       intermediate_reference_points)

return output, reference_points

The structure of the reference_points produced by the highlighted code is shown in the figure below; inverse_sigmoid(pt_reference_points) is simply the pre-sigmoid reference_points, i.e. Linear(query_pos).

2.1 MultiheadAttention

Functions:

  • Multi-head self-attention over the object_query, as shown in the figure below (see also the sketch right after this list).
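
The expansion below steps through PyTorch's internals line by line; functionally, the whole block boils down to a single nn.MultiheadAttention call, roughly as in this sketch (shapes follow bevformer_tiny; the random tensors are placeholders, and the wrapper's dropout on the residual path is omitted):

import torch
import torch.nn as nn

embed_dims, num_heads, num_query, bs = 256, 8, 900, 1
self_attn = nn.MultiheadAttention(embed_dims, num_heads, dropout=0.1)   # batch_first=False

query = torch.rand(num_query, bs, embed_dims)       # object queries, [900, 1, 256]
query_pos = torch.rand(num_query, bs, embed_dims)   # learned positional embedding

# q and k carry the positional embedding, v does not; the residual uses the raw query.
out, _ = self_attn(query + query_pos, query + query_pos, query)
query = query + out                                  # [900, 1, 256]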

Analysis:

embed_dim = 256
kdim = embed_dim
vdim = embed_dim
qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim  # True
num_heads = 8
dropout = 0.1
batch_first = False
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
factory_kwargs = {'device': 'cuda', 'dtype': None}
in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
bias_k = bias_v = None
add_zero_attn = False
out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=True, **factory_kwargs)
attn_mask = None  # no attention mask is used here

if batch_first:
    query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if not qkv_same_embed_dim:
   # attn_output, attn_output_weights = F.multi_head_attention_forward(
   #     query, key, value, self.embed_dim, self.num_heads,
   #     self.in_proj_weight, self.in_proj_bias,
   #     self.bias_k, self.bias_v, self.add_zero_attn,
   #     self.dropout, self.out_proj.weight, self.out_proj.bias,
   #     training=self.training,
   #     key_padding_mask=key_padding_mask, need_weights=need_weights,
   #     attn_mask=attn_mask, use_separate_proj_weight=True,
   #     q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
   #     v_proj_weight=self.v_proj_weight)
   pass
else:
   attn_output, attn_output_weights = F.multi_head_attention_forward(
        query, key, value, embed_dim, num_heads, in_proj_weight, in_proj_bias,
        bias_k, bias_v, add_zero_attn, dropout, out_proj.weight, out_proj.bias,
        training=True, key_padding_mask=None, need_weights=True, attn_mask=attn_mask)
   # -------------------------------F.multi_head_attention_forward start----------------------------
   out_proj_weight = out_proj.weight
   out_proj_bias = out_proj.bias
   key = key
   value = value
   embed_dim_to_check = embed_dim
   use_separate_proj_weight = False
   training = True
   key_padding_mask = None
   need_weights = True
   q_proj_weight, k_proj_weight, v_proj_weight = None, None, None
   static_k, static_v = None, None

   # set up shape vars
   tgt_len, bsz, embed_dim = query.shape  # torch.Size([900, 1, 256])
   src_len, _, _ = key.shape
   assert embed_dim == embed_dim_to_check, \
       f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
   if isinstance(embed_dim, torch.Tensor):
   #     # embed_dim can be a tensor when JIT tracing
   #     head_dim = embed_dim.div(mhaf_num_heads, rounding_mode='trunc')
       pass
   else:
       head_dim = embed_dim // num_heads
   assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"

   if not use_separate_proj_weight:
       # q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
       # -----------_in_projection_packed start-----------
       # q, k, v, w, b = query, mhaf_key, mhaf_value, mhaf_in_proj_weight, mhaf_in_proj_bias
       # E = query.size(-1)
       if key is value:
           # if query is mhaf_key:
           #     # self-attention
           #     return linear(query, mhaf_in_proj_weight, mhaf_in_proj_bias).chunk(3, dim=-1)
           # else:
           #     # encoder-decoder attention
           #     w_q, w_kv = mhaf_in_proj_weight.split([E, E * 2])
           #     if mhaf_in_proj_bias is None:
           #         b_q = b_kv = None
           #     else:
           #         b_q, b_kv = mhaf_in_proj_bias.split([E, E * 2])
           #     return (linear(query, w_q, b_q),) + linear(mhaf_key, w_kv, b_kv).chunk(2, dim=-1)
           pass
       else:
           w_q, w_k, w_v = in_proj_weight.chunk(3)
           if in_proj_bias is None:
               # b_q = b_k = b_v = None
               pass
           else:
               b_q, b_k, b_v = in_proj_bias.chunk(3)
           # return linear(query, w_q, b_q), linear(mhaf_key, w_k, b_k), linear(mhaf_value, w_v, b_v)
           # F.linear(x, A, b): return x @ A.T + b
           query, key, value = F.linear(query, w_q, b_q), F.linear(key, w_k, b_k), F.linear(value, w_v, b_v)
           #                                   query + pt_query_pos      query + pt_query_pos                 query
   # ------------_in_projection_packed end------------
   # else:
   #     assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
   #     assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
   #     assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
   #     if in_proj_bias is None:
   #         b_q = b_k = b_v = None
   #     else:
   #         b_q, b_k, b_v = in_proj_bias.chunk(3)
   #     q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v)

   #
   # reshape q, k, v for multihead attention and make em batch first
   query = query.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 1, 256] -> [900, 8, 32] -> [8, 900, 32]
   if static_k is None:
       key = key.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 8, 32] -> [8, 900, 32]
   # else:
   #     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
   #     assert mhaf_static_k.size(0) == bsz * mhaf_num_heads, \
   #         f"expecting static_k.size(0) of {bsz * mhaf_num_heads}, but got {mhaf_static_k.size(0)}"
   #     assert mhaf_static_k.size(2) == head_dim, \
   #         f"expecting static_k.size(2) of {head_dim}, but got {mhaf_static_k.size(2)}"
   #     mhaf_key = mhaf_static_k
   if static_v is None:
       value = value.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)  # [900, 8, 32] -> [8, 900, 32]
   # else:
   #     # TODO finish disentangling control flow so we don't do in-projections when statics are passed
   #     assert mhaf_static_v.size(0) == bsz * mhaf_num_heads, \
   #         f"expecting static_v.size(0) of {bsz * mhaf_num_heads}, but got {mhaf_static_v.size(0)}"
   #     assert mhaf_static_v.size(2) == head_dim, \
   #         f"expecting static_v.size(2) of {head_dim}, but got {mhaf_static_v.size(2)}"
   #     mhaf_value = mhaf_static_v

   # update source sequence length after adjustments
   src_len = key.size(1)

   attn_output, attn_output_weights = _scaled_dot_product_attention(query, key, value, attn_mask, dropout)
   # ------------_scaled_dot_product_attention start------------
   # q: Tensor,
   # k: Tensor,
   # v: Tensor,
   # attn_mask: Optional[Tensor] = None,
   # dropout_p: float = 0.0,
   B, Nt, E = query.shape  # torch.Size([8, 900, 32]), mhaf_key and mhaf_value is same shape.
   query = query / math.sqrt(E)
   # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
   attn = torch.bmm(query, key.transpose(-2, -1))  # [8, 900, 32] @ [8, 32, 900] -> [8, 900, 900]
   # if mhaf_attn_mask is not None:
   #     attn += mhaf_attn_mask
   attn = F.softmax(attn, dim=-1)
   if dropout > 0.0:
       attn = F.dropout(attn, p=dropout)
   # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
   output = torch.bmm(attn, value)  # [8, 900, 900] @ [8, 900, 32] -> # torch.Size([8, 900, 32])
   # return output, attn
   attn_output, attn_output_weights = output, attn
   # -------------_scaled_dot_product_attention end-------------
   # tgt_len: 900  # [8, 900, 32]->[900, 8, 32]->[900, 1, 256]
   attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)  # torch.Size([900, 1, 256])
   attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)  # nn.Linear

out = attn_output
# ------------------------------self.attn end------------------------------

# return identity + self.dropout_layer(self.proj_drop(out))
query = identity + dropout_layer(proj_drop(out))  # residual connection back onto the original query
# torch.Size([900, 1, 256]) + torch.Size([900, 1, 256])
2.2 CustomMSDeformableAttention

Functions:

  • Uses deformable attention to fuse the bev_embed output by the encoder module into the object_query, as shown in the figure below (see the sketch right after this list);

  • Outputs this layer's output, which becomes the input to the next DetrTransformerDecoderLayer and is also used to generate this layer's corresponding reference_points.
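
For intuition, a simplified single-level version of multi_scale_deformable_attn_pytorch is sketched below. Shapes follow bevformer_tiny (one BEV level of 50x50); random tensors stand in for the real inputs, and the multi-level handling of the original function is omitted.

import torch
import torch.nn.functional as F

bs, num_query, num_heads, head_dim = 1, 900, 8, 32
num_points, bev_h, bev_w = 4, 50, 50

value = torch.rand(bs, bev_h * bev_w, num_heads, head_dim)                    # bev_embed split into heads
sampling_locations = torch.rand(bs, num_query, num_heads, 1, num_points, 2)   # normalized to [0, 1]
attention_weights = torch.rand(bs, num_query, num_heads, 1, num_points).softmax(-1)

# Reshape value into a 2D BEV grid per head, then bilinearly sample it at the offset locations.
value_l = value.permute(0, 2, 3, 1).reshape(bs * num_heads, head_dim, bev_h, bev_w)
grid = 2 * sampling_locations[:, :, :, 0] - 1                                  # grid_sample expects [-1, 1]
grid = grid.permute(0, 2, 1, 3, 4).reshape(bs * num_heads, num_query, num_points, 2)
sampled = F.grid_sample(value_l, grid, mode='bilinear', align_corners=False)   # [8, 32, 900, 4]

# Weighted sum of the sampled features over the sampling points, then merge the heads.
w = attention_weights.permute(0, 2, 1, 3, 4).reshape(bs * num_heads, 1, num_query, num_points)
output = (sampled * w).sum(-1)                                                  # [8, 32, 900]
output = output.reshape(bs, num_heads * head_dim, num_query).permute(0, 2, 1)   # [1, 900, 256]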

Analysis:

#-------------------------CustomMSDeformableAttention init(in part)---------------------------------
sampling_offsets = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points * 2).cuda()
attention_weights = nn.Linear(ca_embed_dims, ca_num_heads * ca_num_levels * ca_num_points).cuda()
value_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
output_proj = nn.Linear(ca_embed_dims, ca_embed_dims).cuda()
#-------------------------CustomMSDeformableAttention init(in part)---------------------------------
if value is None:
   value = query

if identity is None:
   identity = query
if query_pos is not None:
   query = query + query_pos
if not self.batch_first:
   # change to (bs, num_query ,embed_dims)
   #query:torch.Size([1, 900, 256])
   query = query.permute(1, 0, 2)
   #value (i.e. bev_embed): torch.Size([1, 50*50, 256])
   value = value.permute(1, 0, 2)

bs, num_query, _ = query.shape
bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value

#value (i.e. bev_embed): torch.Size([1, 50*50, 256])
value = self.value_proj(value)
if key_padding_mask is not None:
   value = value.masked_fill(key_padding_mask[..., None], 0.0)
#value: torch.Size([1, 50*50, 8, 32]), split in preparation for multi-head attention
value = value.view(bs, num_value, self.num_heads, -1)

sampling_offsets = self.sampling_offsets(query).view(
   bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
#    1,    900,          8,            1,             4,             2
attention_weights = self.attention_weights(query).view(
   bs, num_query, self.num_heads, self.num_levels * self.num_points)
#    1,    900,          8,                      4,            
attention_weights = attention_weights.softmax(-1)

#attention_weights: torch.Size([1, 900, 8, 1, 4])
attention_weights = attention_weights.view(bs, num_query,
                                           self.num_heads,
                                           self.num_levels,
                                           self.num_points)
#reference_points:torch.Size([1, 900, 1, 2])                                            
if reference_points.shape[-1] == 2:
   offset_normalizer = torch.stack(
       [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
   #sampling_locations:torch.Size([1, 900, 8, 1, 4, 2])
   sampling_locations = reference_points[:, :, None, :, None, :] \
       + sampling_offsets \
       / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
   sampling_locations = reference_points[:, :, None, :, None, :2] \
       + sampling_offsets / self.num_points \
       * reference_points[:, :, None, :, None, 2:] \
       * 0.5
else:
   raise ValueError(
       f'Last dim of reference_points must be'
       f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:

   # using fp16 deformable attention is unstable because it performs many sum operations
   if value.dtype == torch.float16:
       MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
   else:
       MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
   output = MultiScaleDeformableAttnFunction.apply(
       value, spatial_shapes, level_start_index, sampling_locations,
       attention_weights, self.im2col_step)
else:
   #output:torch.Size([1, 900, 256])
   #Deformable attention: the query extracts useful information from value (bev_embed)
   output = multi_scale_deformable_attn_pytorch(
       value, spatial_shapes, sampling_locations, attention_weights)
       
#output:torch.Size([1, 900, 256])
output = self.output_proj(output)

if not self.batch_first:
   # (num_query, bs ,embed_dims)
   output = output.permute(1, 0, 2)

return self.dropout(output) + identity
3 cls_branches & reg_branches

Functions:

  • Uses the decoder outputs bev_embed, inter_states (the outputs of all 6 layers), init_reference_out (the initial reference_points generated from query_pos), and inter_references_out (the reference_points of all 6 layers) to generate object classes and bboxes (a sketch of how these branches are built follows this list);

  • Produces outs containing bev_embed, all_cls_scores, and all_bbox_preds; all_cls_scores and all_bbox_preds are used to compute the loss and backpropagate gradients, while bev_embed can also serve tasks such as segmentation, i.e. semantic segmentation in the BEV view.
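
Before the parsed code, the sketch below shows how the cls_branches and reg_branches are typically constructed in BEVFormerHead. The values embed_dims=256, 10 classes, code_size=10, 6 decoder layers, and num_reg_fcs=2 are assumptions taken from the common nuScenes setup; check the actual head for the exact definitions.

import copy
import torch.nn as nn

embed_dims, num_classes, code_size, num_layers, num_reg_fcs = 256, 10, 10, 6, 2

cls_branch = []
for _ in range(num_reg_fcs):
    cls_branch += [nn.Linear(embed_dims, embed_dims), nn.LayerNorm(embed_dims), nn.ReLU(inplace=True)]
cls_branch.append(nn.Linear(embed_dims, num_classes))        # per layer: [bs, 900, 256] -> [bs, 900, 10]
cls_branch = nn.Sequential(*cls_branch)

reg_branch = []
for _ in range(num_reg_fcs):
    reg_branch += [nn.Linear(embed_dims, embed_dims), nn.ReLU()]
reg_branch.append(nn.Linear(embed_dims, code_size))          # (cx, cy, w, l, cz, h, sin, cos, vx, vy)
reg_branch = nn.Sequential(*reg_branch)

# With box refinement enabled, each of the 6 decoder layers gets its own copy of the branches.
cls_branches = nn.ModuleList(copy.deepcopy(cls_branch) for _ in range(num_layers))
reg_branches = nn.ModuleList(copy.deepcopy(reg_branch) for _ in range(num_layers))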

Analysis:

#For the meaning of the variables below, see "Line-by-Line Analysis of the Open-Source BEVFormer Algorithm (Part 1): Encoder"
bs, num_cam, _, _, _ = mlvl_feats[0].shape
dtype = mlvl_feats[0].dtype
object_query_embeds = self.query_embedding.weight.to(dtype)
bev_queries = self.bev_embedding.weight.to(dtype)
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                       device=bev_queries.device).to(dtype)
bev_pos = self.positional_encoding(bev_mask).to(dtype)

if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
   return self.transformer.get_bev_features(
       mlvl_feats,
       bev_queries,
       self.bev_h,
       self.bev_w,
       grid_length=(self.real_h / self.bev_h,
                       self.real_w / self.bev_w),
       bev_pos=bev_pos,
       img_metas=img_metas,
       prev_bev=prev_bev,
   )
else:
   #outputs is the final result of feeding object_query_embeds, bev_pos, bev_queries,
   #img_metas and mlvl_feats through the encoder and decoder modules
   #outputs: bev_embed, inter_states, init_reference_out, inter_references_out
   outputs = self.transformer(
       mlvl_feats,
       bev_queries,
       object_query_embeds,
       self.bev_h,
       self.bev_w,
       grid_length=(self.real_h / self.bev_h,
                       self.real_w / self.bev_w),
       bev_pos=bev_pos,
       reg_branches=self.reg_branches if self.with_box_refine else None,  # noqa:E501
       cls_branches=self.cls_branches if self.as_two_stage else None,
       img_metas=img_metas,
       prev_bev=prev_bev
)

bev_embed, hs, init_reference, inter_references = outputs
hs = hs.permute(0, 2, 1, 3)
outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
   if lvl == 0:
       reference = init_reference
   else:
       reference = inter_references[lvl - 1]
   reference = inverse_sigmoid(reference)
   #outputs_class:torch.Size([1, 900, 10])
   outputs_class = self.cls_branches[lvl](hs[lvl])
   #tmp:torch.Size([1, 900, 10])
   tmp = self.reg_branches[lvl](hs[lvl])

   # TODO: check the shape of reference
   assert reference.shape[-1] == 3
   tmp[..., 0:2] += reference[..., 0:2]
   tmp[..., 0:2] = tmp[..., 0:2].sigmoid()
   tmp[..., 4:5] += reference[..., 2:3]
   tmp[..., 4:5] = tmp[..., 4:5].sigmoid()
   #The " * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]" below
   #restores the x, y, z coordinates of the bbox centers to the real-world scale
   tmp[..., 0:1] = (tmp[..., 0:1] * (self.pc_range[3] -
                       self.pc_range[0]) + self.pc_range[0])
   tmp[..., 1:2] = (tmp[..., 1:2] * (self.pc_range[4] -
                       self.pc_range[1]) + self.pc_range[1])
   tmp[..., 4:5] = (tmp[..., 4:5] * (self.pc_range[5] -
                       self.pc_range[2]) + self.pc_range[2])

   # TODO: check if using sigmoid
   outputs_coord = tmp
   outputs_classes.append(outputs_class)
   outputs_coords.append(outputs_coord)
#outputs_classes:torch.Size([6, 1, 900, 10])
outputs_classes = torch.stack(outputs_classes)
#outputs_coords:torch.Size([6, 1, 900, 10])
outputs_coords = torch.stack(outputs_coords)

outs = {
   'bev_embed': bev_embed,
   'all_cls_scores': outputs_classes,
   'all_bbox_preds': outputs_coords,
   'enc_cls_scores': None,
   'enc_bbox_preds': None,
}

#Once outs is produced, it is used together with the class labels and bbox labels to compute the loss;
#gradients are then backpropagated to update the model's learnable parameters:
#the various linear layers, object_query_embeds, bev_queries, bev_pos, etc.
return outs

The structure of tmp[..., 0:2] and tmp[..., 4:5] produced by the highlighted code is shown in the figure below; they are essentially the reference_points generated in DetectionTransformerDecoder.
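
As a small worked example of the de-normalization above (assuming the usual nuScenes point_cloud_range; check your config for the actual values):

pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]   # assumed nuScenes point_cloud_range
cx_norm = 0.75                                     # tmp[..., 0:1] after sigmoid, in (0, 1)
cx_real = cx_norm * (pc_range[3] - pc_range[0]) + pc_range[0]   # 0.75 * 102.4 - 51.2 = 25.6 m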

Conclusion:

This concludes the line-by-line analysis of BEVFormer's Encoder and Decoder parts. If there is demand, a follow-up article on the loss computation can be published as well; that part is fairly basic, and interested readers can also study it on their own alongside the source code.


