BEVFormer Open-Source Algorithm, Line-by-Line Analysis (Part 1): Encoder
Horizon Developer | 2024-09-08 10:31:50

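Preface:

There is plenty of material on the overall BEVFormer framework, but detailed explanations of the code itself are scarce. This series therefore follows the implementation details and tracks how tensor dimensions change, to give readers a more concrete understanding of the algorithm.

In this series we analyze the public BEVFormer code (the open-source algorithm) line by line, to connect the code with BEVFormer's principles, pin down the algorithmic details, and help readers develop perception algorithms on this framework. At the end of the series, for Horizon users, I will also point out the changes the Horizon reference algorithm makes on top of the open-source algorithm and the reasoning behind them, as a reference for deployment.

The public code is well encapsulated and instantiates models through registries; the call relationships between modules are clearly visible in the config files under configs/bevformer. We use bevformer_tiny.py as the example.

The analysis is mostly carried in the code comments.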

model = dict(
    type='BEVFormer',
    use_grid_mask=True,
    video_test_mode=True,
    pretrained=dict(img='torchvision://resnet50'),  # pretrained weights
    img_backbone=dict(
        type='ResNet',  # backbone
        ...
        ),
    img_neck=dict(
        type='FPN',  # neck
        ...
        ),
    pts_bbox_head=dict(
        type='BEVFormerHead',  # detection head, contains the transformer
        ...,
        transformer=dict(
            type='PerceptionTransformer',  # the transformer
            ...,
            encoder=dict(
                type='BEVFormerEncoder',  # the encoder
                ...,
                transformerlayers=dict(
                    type='BEVFormerLayer',  # one encoder layer
                    attn_cfgs=[
                        dict(
                            type='TemporalSelfAttention',  # temporal self-attention
                            ...),
                        dict(
                            type='SpatialCrossAttention',  # spatial cross-attention
                            ...,
                            deformable_attention=dict(
                                type='MSDeformableAttention3D',
                                ...),
                            ...,
                        )
                    ],
                    ...),
                ),
            decoder=dict(
                type='DetectionTransformerDecoder',  # the decoder
                ...,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',  # one decoder layer
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',  # multi-head attention
                            ...),
                        dict(
                            type='CustomMSDeformableAttention',
                            ...),
                    ],
                    ),
                ),
            ),
        ),
    )

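As the config shows, in the forward pass the images from the 6 cameras go through 'ResNet' and 'FPN' to produce image features, and then through 'BEVFormerEncoder' and 'DetectionTransformerDecoder' inside the 'BEVFormerHead' module, which together complete the feature fusion. 'BEVFormerEncoder' consists of 'TemporalSelfAttention' followed by 'SpatialCrossAttention'; in bevformer_tiny this cascade is stacked 3 times.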

1 BEVFormer:

Purpose:

  • applies image augmentation with grid_mask;

  • extracts image features with ResNet (backbone) and FPN (neck);

  • passes the features into BEVFormerHead.

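To make the tensor bookkeeping of this section concrete, here is a minimal shape tracer in plain Python. It assumes the bevformer_tiny setting used in the comments below (queue length 3, 6 cameras, 480x800 inputs, a single ResNet stage with stride 32 and 2048 channels, an FPN with 256 output channels); `trace_img_feat_shapes` is a hypothetical name, and it only does arithmetic on shapes, no network is run.

```python
def trace_img_feat_shapes(bs=1, queue=3, num_cams=6, H=480, W=800,
                          stride=32, backbone_c=2048, neck_c=256):
    """Hypothetical helper: shape bookkeeping of BEVFormer.extract_img_feat
    for the bevformer_tiny setting (a single feature level is assumed)."""
    hist_len = queue - 1                    # frames t-2 and t-1
    h, w = H // stride, W // stride         # 480/32 = 15, 800/32 = 25
    shapes = {
        'img': (bs, queue, num_cams, 3, H, W),
        'prev_img': (bs, hist_len, num_cams, 3, H, W),
        'cur_img': (bs, num_cams, 3, H, W),
    }
    # History path: (bs*len_queue, N, C, H, W) is flattened again to (B*N, C, H, W)
    bn_hist = bs * hist_len * num_cams
    shapes['hist_backbone_in'] = (bn_hist, 3, H, W)
    shapes['hist_backbone_out'] = (bn_hist, backbone_c, h, w)
    shapes['hist_neck_out'] = (bn_hist, neck_c, h, w)
    # view(B/len_queue, len_queue, BN/B, C, H, W)
    shapes['hist_feat'] = (bs, hist_len, num_cams, neck_c, h, w)
    # Current frame: (bs*N, C, H, W) -> view(B, BN/B, C, H, W)
    bn_cur = bs * num_cams
    shapes['cur_neck_out'] = (bn_cur, neck_c, h, w)
    shapes['cur_feat'] = (bs, num_cams, neck_c, h, w)
    return shapes


for name, shape in trace_img_feat_shapes().items():
    print(f'{name:18s} {shape}')
```

Changing `stride` or the image size shows immediately how the 15x25 feature maps in the comments come about.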
Code walkthrough:

#img: (bs, queue=3, num_cams=6, C=3, H=480, W=800)
#Split the image queue into prev_img (history frames) and img (current frame)
len_queue = img.size(1)
prev_img = img[:, :-1, ...] #(1,2,6,3,480,800)
img = img[:, -1, ...] #(1,6,3,480,800)

#Use the frames in img_queue other than the current one to generate prev_bev
prev_img_metas = copy.deepcopy(img_metas)
prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
#------------------------obtain_history_bev start--------------------------------------
#Generate prev_bev from the 2 frames preceding the current frame, for later use in TSA
def obtain_history_bev(self, imgs_queue, img_metas_list):
   """Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
   """
   self.eval()
   with torch.no_grad():
       prev_bev = None
       #imgs_queue:torch.size(1,2,6,3,480,800)
       bs, len_queue, num_cams, C, H, W = imgs_queue.shape
       #imgs_queue:torch.size(2,6,3,480,800)
       imgs_queue = imgs_queue.reshape(bs*len_queue, num_cams, C, H, W)
       img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
#--------------------------------extract_feat start--------------------------------------
#Extract the image features and reshape them
@auto_fp16(apply_to=('img'))
def extract_feat(self, img, img_metas=None, len_queue=None):
   """Extract features from images and points."""
   img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
   return img_feats
#--------------------------------extract_feat end----------------------------------------
#-----------------------------extract_img_feat start---------------------------------------
def extract_img_feat(self, img, img_metas, len_queue=None):
   """Extract features of images."""
   # B=batch_size
   B = img.size(0)
   if img is not None:
       if img.dim() == 5 and img.size(0) == 1:
           img.squeeze_()
       elif img.dim() == 5 and img.size(0) > 1:
           B, N, C, H, W = img.size()
           img = img.reshape(B * N, C, H, W)
       #An augmentation technique: a mask occludes parts of the image so that the network
       #learns more features of each object, which helps avoid overfitting
       if self.use_grid_mask:
           #Called from obtain_history_bev():
           #img:torch.size(12,3,480,800), 12=2*6: queue=3-1=2 (frames t-2 and t-1, used to generate prev_bev)
           #Called from extract_feat() after obtain_history_bev:
           #img:torch.size(6,3,480,800), 6=1*6: queue=3-2=1 (current frame t, used with prev_bev to produce bev_query)
           img = self.grid_mask(img)
       #Called from obtain_history_bev():
       #img_feats:tuple(torch.Size(12,2048,15,25)), 12=2*6: queue=3-1=2, cam_num=6
       #Called from extract_feat() after obtain_history_bev:
       #img_feats:tuple(torch.Size(6,2048,15,25)), 6=1*6: queue=3-2=1, cam_num=6
       img_feats = self.img_backbone(img)
       if isinstance(img_feats, dict):
           img_feats = list(img_feats.values())
   else:
       return None
   if self.with_img_neck:
       #Likewise, img_feats:tuple(torch.Size(12,256,15,25)) / tuple(torch.Size(6,256,15,25))
       img_feats = self.img_neck(img_feats)
   img_feats_reshaped = []
   for img_feat in img_feats:
       BN, C, H, W = img_feat.size()
       #History frames (queue=3-1=2)
       if len_queue is not None:
           #img_feat is reshaped to:
           #torch.size(B/len_queue=1, len_queue=2, num_cams=6, C=256, H=15, W=25)
           img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
       #Current frame (queue=1)
       else:
           img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
   return img_feats_reshaped
#-----------------------------extract_img_feat end----------------------------------------
       #Back in obtain_history_bev (still inside torch.no_grad()):
       for i in range(len_queue):
           img_metas = [each[i] for each in img_metas_list]
           if not img_metas[0]['prev_bev_exists']:
               prev_bev = None
           # img_feats = self.extract_feat(img=img, img_metas=img_metas)
           # Slice img_feats along the queue dim (6-D -> 5-D), one list entry per scale
           img_feats = [each_scale[:, i] for each_scale in img_feats_list]
           #Generate prev_bev from img_feats and img_metas
           #Enter the next stage, BEVFormerHead (encoder + decoder), through pts_bbox_head
           prev_bev = self.pts_bbox_head(
               img_feats, img_metas, prev_bev, only_bev=True)
       self.train()
       return prev_bev
#------------------------obtain_history_bev end----------------------------------------
img_metas = [each[len_queue-1] for each in img_metas]
if not img_metas[0]['prev_bev_exists']:
   prev_bev = None
#Extract the features of the current frame
img_feats = self.extract_feat(img=img, img_metas=img_metas)
#obtain_history_bev uses the images and img_metas of t-2 and t-1 to generate prev_bev.
#Then the current-frame features img_feats, the prev_bev from obtain_history_bev, the img_metas
#of the current frame, and the ground-truth boxes and labels (gt_bboxes_3d, gt_labels_3d)
#are fed to forward_pts_train to compute the loss.
#forward_pts_train enters BEVFormerHead
losses = dict()
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                   gt_labels_3d, img_metas,
                                   gt_bboxes_ignore, prev_bev)
#------------------------------forward_pts_train start-----------------------------------------
def forward_pts_train(self,pts_feats,gt_bboxes_3d,gt_labels_3d,img_metas,gt_bboxes_ignore=None,prev_bev=None):
   #Enter the next stage, BEVFormerHead (encoder + decoder), through pts_bbox_head
   outs = self.pts_bbox_head(
       pts_feats, img_metas, prev_bev)
   loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
   losses = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
   return losses
#------------------------------forward_pts_train end-----------------------------------------
losses.update(losses_pts)
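The recurrence inside `obtain_history_bev` is easy to lose among the nested definitions. The sketch below keeps only its control flow: each history frame's BEV becomes the `prev_bev` of the next one, and the chain is cut whenever `prev_bev_exists` is False. `fake_bev_head` is a hypothetical stand-in for `pts_bbox_head(..., only_bev=True)` that merely records which frame produced a BEV and what it received; a single sample (no batch dimension) is assumed.

```python
def fake_bev_head(frame_idx, prev_bev):
    # Hypothetical stand-in for pts_bbox_head(img_feats, img_metas, prev_bev, only_bev=True)
    return {'frame': frame_idx, 'prev': prev_bev}


def obtain_history_bev_sketch(img_metas_per_frame):
    """Control flow of obtain_history_bev for one sample (no torch, no batch)."""
    prev_bev = None
    for i, meta in enumerate(img_metas_per_frame):
        if not meta['prev_bev_exists']:
            prev_bev = None          # scene boundary: the older BEV is unrelated
        prev_bev = fake_bev_head(i, prev_bev)
    return prev_bev


# t-2 starts a scene and t-1 continues it, so the BEV of t-1 is built on top of t-2
out = obtain_history_bev_sketch([{'prev_bev_exists': False},
                                 {'prev_bev_exists': True}])
print(out)  # {'frame': 1, 'prev': {'frame': 0, 'prev': None}}
```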
2 BEVFormerHead:

Purpose:

  • defines bev_queries (the BEV Queries in the paper), bev_pos (the positional information of each BEV query) and query_embeds, all used later;

  • computes pc_range, real_w and real_h for later use;

  • enters PerceptionTransformer from this module.

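The quantities computed here reduce to simple arithmetic. A minimal sketch, assuming the tiny setting (`pc_range` as in the comment below and a 50x50 BEV grid); `bev_cell_center` is a hypothetical helper that mirrors how normalized BEV reference points (cell index + 0.5, divided by the grid size) are later rescaled to meters.

```python
pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]  # x_min, y_min, z_min, x_max, y_max, z_max (m)
bev_h, bev_w = 50, 50

real_w = pc_range[3] - pc_range[0]   # 102.4 m along x
real_h = pc_range[4] - pc_range[1]   # 102.4 m along y
real_z = pc_range[5] - pc_range[2]   # 8.0 m, the pillar height used for ref_3d
grid_length = (real_h / bev_h, real_w / bev_w)  # metric size of one BEV cell


def bev_cell_center(col):
    """Hypothetical helper: metric x of the center of BEV column `col`."""
    normalized = (col + 0.5) / bev_w        # same normalization as the reference points
    return normalized * real_w + pc_range[0]


print(grid_length)                                      # roughly (2.048, 2.048)
print(bev_cell_center(0), bev_cell_center(bev_w - 1))   # roughly -50.176 and 50.176
```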
Code walkthrough:

#Define the spatial range covered by bev_query
self.bbox_coder = build_bbox_coder(bbox_coder)
#pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] (unit: m)
self.pc_range = self.bbox_coder.pc_range
self.real_w = self.pc_range[3] - self.pc_range[0]
self.real_h = self.pc_range[4] - self.pc_range[1]
bs, num_cam, _, _, _ = mlvl_feats[0].shape
#query_embedding and bev_embedding (the sources of object_query_embeds and bev_queries)
#are defined in _init_layers, which runs only once, when BEVFormerHead is initialized.
#So object_query_embeds, bev_queries and bev_pos are learnable and reused across iterations.
#object_query_embeds:torch.Size([900,512]) (num_query x embed_dims*2)
object_query_embeds = self.query_embedding.weight.to(dtype)
#bev_queries:torch.Size([50*50,256])
bev_queries = self.bev_embedding.weight.to(dtype)
bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),
                       device=bev_queries.device).to(dtype) #torch.Size([1,50,50])
#Use bev_mask to generate bev_pos, the positional encoding of bev_queries
#bev_pos:torch.Size([1, 256, 50, 50])
bev_pos = self.positional_encoding(bev_mask).to(dtype)
#------------------------BEVFormerHead._init_layers start---------------------------
self.bev_embedding = nn.Embedding(
   self.bev_h * self.bev_w, self.embed_dims)
self.query_embedding = nn.Embedding(self.num_query,
                                   self.embed_dims * 2)
positional_encoding=dict(
   type='LearnedPositionalEncoding',
   num_feats=_pos_dim_, #128
   row_num_embed=bev_h_, #50
   col_num_embed=bev_w_, #50
   )
#------------------------BEVFormerHead._init_layers end-----------------------------
#-----------------------positional_encoding start------------------------------------
pe_num_feats, pe_row_num_embed, pe_col_num_embed = pos_dim, bev_h, bev_w
pe_row_embed = nn.Embedding(pe_row_num_embed, pe_num_feats).cuda()
pe_col_embed = nn.Embedding(pe_col_num_embed, pe_num_feats).cuda()
pe_h, pe_w = bev_mask.shape[-2:]  # torch.Size([1, 50, 50])
pe_x = torch.arange(pe_w, device=bev_mask.device)
pe_y = torch.arange(pe_h, device=bev_mask.device)
pe_x_embed = pe_col_embed(pe_x)  # torch.Size([50, 128])
pe_y_embed = pe_row_embed(pe_y)  # torch.Size([50, 128])
pe_pos = torch.cat((pe_x_embed.unsqueeze(0).repeat(pe_h, 1, 1), pe_y_embed.unsqueeze(1).repeat(1, pe_w, 1)),
                   dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(bev_mask.shape[0], 1, 1, 1)
# return pos
bev_pos = pe_pos
#-----------------------positional_encoding end------------------------------------
#From here on we focus only on the Encoder part of BEVFormerHead
#The Encoder receives mlvl_feats, bev_queries, bev_pos and prev_bev
if only_bev:  # only use encoder to obtain BEV features, TODO: refine the workaround
   return self.transformer.get_bev_features(
       mlvl_feats,
       bev_queries,
       self.bev_h,
       self.bev_w,
       grid_length=(self.real_h / self.bev_h,
                       self.real_w / self.bev_w), #102.4/50=2.048
       bev_pos=bev_pos,
       img_metas=img_metas,
       prev_bev=prev_bev,
   )
....
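The unrolled positional-encoding code above can be reproduced in NumPy. This is a sketch assuming random tables in place of the learned `nn.Embedding` weights (`row_embed` for the 50 rows, `col_embed` for the 50 columns, 128 features each); `learned_positional_encoding` is a hypothetical name. It also shows that, after the `flatten(2).permute(2, 0, 1)` step in `get_bev_features`, query index `i * bev_w + j` carries the column embedding of `j` in its first 128 channels and the row embedding of `i` in its last 128.

```python
import numpy as np


def learned_positional_encoding(row_embed, col_embed, bs=1):
    """Hypothetical NumPy version of the unrolled LearnedPositionalEncoding above.
    row_embed: (bev_h, F) table; col_embed: (bev_w, F) table.
    Returns pos of shape (bs, 2*F, bev_h, bev_w)."""
    h, w = row_embed.shape[0], col_embed.shape[0]
    x_embed = col_embed[np.arange(w)]                      # (w, F): one vector per column
    y_embed = row_embed[np.arange(h)]                      # (h, F): one vector per row
    pos = np.concatenate(
        [np.repeat(x_embed[None], h, axis=0),              # (h, w, F), constant down each column
         np.repeat(y_embed[:, None], w, axis=1)],          # (h, w, F), constant along each row
        axis=-1)                                           # (h, w, 2F)
    pos = pos.transpose(2, 0, 1)[None]                     # (1, 2F, h, w)
    return np.repeat(pos, bs, axis=0)


rng = np.random.default_rng(0)
row_embed = rng.standard_normal((50, 128))   # stand-in for the learned row table
col_embed = rng.standard_normal((50, 128))   # stand-in for the learned column table
bev_pos = learned_positional_encoding(row_embed, col_embed)
# Same flattening as get_bev_features: (1, 256, 50, 50) -> (50*50, 1, 256)
bev_pos_q = bev_pos.reshape(1, 256, -1).transpose(2, 0, 1)
print(bev_pos.shape, bev_pos_q.shape)   # (1, 256, 50, 50) (2500, 1, 256)
```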
3 PerceptionTransformer:

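Purpose:

  • adds level_embeds and cams_embeds (feature-level and camera encodings) to img_feats;

  • fuses the CAN bus information into bev_queries;

  • if prev_bev exists, rotates prev_bev to align it with the current BEV;

  • enters the Encoder from this module.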

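Two bookkeeping steps of `get_bev_features` are easy to check in isolation: adding the camera and level encodings by broadcasting, and computing `level_start_index` for the flattened multi-level features. The NumPy sketch below mirrors those lines; the helper names are hypothetical, and the three-level shape list in the usage example is an invented illustration (bevformer_tiny has a single 15x25 level).

```python
import numpy as np


def add_cam_level_embeds(feat, cams_embeds, level_embed):
    """feat: (num_cam, bs, H*W, C); cams_embeds: (num_cam, C); level_embed: (C,).
    Broadcasting adds one vector per camera and one per feature level."""
    return feat + cams_embeds[:, None, None, :] + level_embed[None, None, None, :]


def level_start_index(spatial_shapes):
    """Start offset of each level once every (h, w) map is flattened and concatenated,
    i.e. cat((zeros(1), prod(1).cumsum(0)[:-1]))."""
    sizes = np.asarray(spatial_shapes, dtype=np.int64).prod(1)
    return np.concatenate([np.zeros(1, dtype=np.int64), np.cumsum(sizes)[:-1]])


feat = np.zeros((6, 1, 15 * 25, 256))
out = add_cam_level_embeds(feat, np.ones((6, 256)), np.full(256, 0.5))
print(out.shape, out[0, 0, 0, 0])                                     # (6, 1, 375, 256) 1.5
print(level_start_index([(15, 25)]).tolist())                         # [0]
print(level_start_index([(60, 100), (30, 50), (15, 25)]).tolist())    # [0, 6000, 7500]
```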
Code walkthrough:

def init_layers(self):
   """Initialize layers of the Detr3DTransformer."""
   #level_embeds, added to the image features
   self.level_embeds = nn.Parameter(torch.Tensor(
       self.num_feature_levels, self.embed_dims))
   #cams_embeds, added to the image features
   self.cams_embeds = nn.Parameter(
       torch.Tensor(self.num_cams, self.embed_dims))
   self.reference_points = nn.Linear(self.embed_dims, 3)
   #MLP mapping the 18-dim CAN bus vector to 256 dims (18 -> 128 -> 256)
   self.can_bus_mlp = nn.Sequential(
       nn.Linear(18, self.embed_dims // 2),
       nn.ReLU(inplace=True),
       nn.Linear(self.embed_dims // 2, self.embed_dims),
       nn.ReLU(inplace=True),
   )
   if self.can_bus_norm:
       self.can_bus_mlp.add_module('norm', nn.LayerNorm(self.embed_dims))
def get_bev_features(
       self,
       mlvl_feats,
       bev_queries,
       bev_h,
       bev_w,
       grid_length=[0.512, 0.512],
       bev_pos=None,
       prev_bev=None,
       **kwargs):

   #mlvl_feats:[torch.Size([1, 6, 256, 15, 25])]
   bs = mlvl_feats[0].size(0)
   #bev_queries:torch.Size([50*50,256])->torch.Size([50*50,1,256])
   bev_queries = bev_queries.unsqueeze(1).repeat(1, bs, 1)
   #bev_pos:torch.Size([1, 256, 50, 50]) -> torch.Size([50*50,1,256])
   bev_pos = bev_pos.flatten(2).permute(2, 0, 1)
#-------------------------------------grid_length start----------------------------------------------
   grid_length = (real_h / bev_h, real_w / bev_w)
#-------------------------------------grid_length  end------------------------------------------------
   # obtain rotation angle and shift with ego motion
   #x displacement of the ego vehicle relative to the previous frame
   delta_x = np.array([each['can_bus'][0]
                       for each in kwargs['img_metas']])
   #y displacement of the ego vehicle relative to the previous frame
   delta_y = np.array([each['can_bus'][1]
                       for each in kwargs['img_metas']])
   #ego heading angle, converted from radians to degrees
   ego_angle = np.array(
       [each['can_bus'][-2] / np.pi * 180 for each in kwargs['img_metas']])
   #compute shift, normalized by the BEV size (bev_h, bev_w)
   grid_length_y = grid_length[0]
   grid_length_x = grid_length[1]
   translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
   translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
   bev_angle = ego_angle - translation_angle
   #project the ego motion onto the BEV grid and normalize it
   #In fact shift is not needed: when prev_bev exists it is only rotated, not translated, into the
   #current BEV, so the ref_2d of prev_bev does not need to be shifted to the current time either
   shift_y = translation_length * \
       np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
   shift_x = translation_length * \
       np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
   shift_y = shift_y * self.use_shift
   shift_x = shift_x * self.use_shift
   shift = bev_queries.new_tensor(
       [shift_x, shift_y]).permute(1, 0)  # xy, bs -> bs, xy # torch.Size([1, 2]), normalized by bev_w / bev_h
   #When the first history frame (the 1st of the 2 in the queue, num_cam=6) enters, prev_bev is None; afterwards it is not None
   if prev_bev is not None:
       if prev_bev.shape[1] == bev_h * bev_w: # torch.Size([1, 50*50, 256])
           prev_bev = prev_bev.permute(1, 0, 2) # torch.Size([50*50,1,256])
       #rotate prev_bev to align it with the current bev
       #TODO: why is the rotation center (100, 100)?
       if self.rotate_prev_bev:
           for i in range(bs):
               # num_prev_bev = prev_bev.size(1)
               rotation_angle = kwargs['img_metas'][i]['can_bus'][-1]
               tmp_prev_bev = prev_bev[:, i].reshape(
                   bev_h, bev_w, -1).permute(2, 0, 1)
               tmp_prev_bev = rotate(tmp_prev_bev, rotation_angle,
                                       center=self.rotate_center)
               tmp_prev_bev = tmp_prev_bev.permute(1, 2, 0).reshape(
                   bev_h * bev_w, 1, -1)
               prev_bev[:, i] = tmp_prev_bev[:, 0]
   # add can bus signals
   can_bus = bev_queries.new_tensor(
       [each['can_bus'] for each in kwargs['img_metas']])  # [:, :] # torch.Size([1, 18])
   can_bus = self.can_bus_mlp(can_bus)[None, :, :] # torch.Size([1, 1, 256])
   #fuse the CAN bus information into bev_queries by broadcasting
   bev_queries = bev_queries + can_bus * self.use_can_bus
   feat_flatten = []
   spatial_shapes = []
   #mlvl_feats:[torch.Size([1, 6, 256, 15, 25])]
   for lvl, feat in enumerate(mlvl_feats):
       #feat:torch.Size([1, 6, 256, 15, 25])
       bs, num_cam, c, h, w = feat.shape
       spatial_shape = (h, w)
       #feat:torch.Size([1,6,256,15,25])->torch.Size([1,6,256,15*25])->torch.Size([6,1,15*25,256])
       feat = feat.flatten(3).permute(1, 0, 3, 2)
       #cams_embeds:torch.Size([6,256]), add the camera encoding
       if self.use_cams_embeds:
           feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
       #level_embeds:torch.Size([4,256]), add the level encoding; num_level=1 here
       feat = feat + self.level_embeds[None,
                                       None, lvl:lvl + 1, :].to(feat.dtype)
       spatial_shapes.append(spatial_shape)
       feat_flatten.append(feat)
   #concatenate the levels along the flattened H*W dimension (dim 2); with num_level=1 in the tiny version this is a no-op
   feat_flatten = torch.cat(feat_flatten, 2)
   spatial_shapes = torch.as_tensor(
       spatial_shapes, dtype=torch.long, device=bev_pos.device) # torch.Size([1, 2])
   # start index of each level after its (H, W) plane is flattened to H*W and all levels are concatenated
   level_start_index = torch.cat((spatial_shapes.new_zeros(
       (1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
   # torch.Size([6, 15*25, 1, 256])
   feat_flatten = feat_flatten.permute(
       0, 2, 1, 3)  # (num_cam, H*W, bs, embed_dims)
   #enter the encoder!
   bev_embed = self.encoder(
       bev_queries,
       feat_flatten,
       feat_flatten,
       bev_h=bev_h,
       bev_w=bev_w,
       bev_pos=bev_pos,
       spatial_shapes=spatial_shapes,
       level_start_index=level_start_index,
       prev_bev=prev_bev,
       shift=shift,
       **kwargs
   )
   return bev_embed
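The shift arithmetic in `get_bev_features` can be lifted out and checked on its own. This NumPy sketch copies its formulas (the function name `ego_shift` is hypothetical); following the comments above, `delta_x`/`delta_y` are the ego displacement since the previous frame in meters and `ego_angle_deg` is the ego heading in degrees. In the usage example, driving straight ahead by exactly one BEV cell gives a shift of 1/50 = 0.02 along the BEV y axis, whatever the absolute heading.

```python
import numpy as np


def ego_shift(delta_x, delta_y, ego_angle_deg, grid_length, bev_h, bev_w, use_shift=True):
    """Hypothetical re-implementation of the shift computation in get_bev_features.
    All array arguments have shape (bs,); returns (bs, 2) = (shift_x, shift_y),
    expressed as a fraction of the BEV width/height."""
    grid_length_y, grid_length_x = grid_length
    translation_length = np.sqrt(delta_x ** 2 + delta_y ** 2)
    translation_angle = np.arctan2(delta_y, delta_x) / np.pi * 180
    bev_angle = ego_angle_deg - translation_angle   # motion direction relative to the heading
    shift_y = translation_length * np.cos(bev_angle / 180 * np.pi) / grid_length_y / bev_h
    shift_x = translation_length * np.sin(bev_angle / 180 * np.pi) / grid_length_x / bev_w
    shift = np.stack([shift_x, shift_y], axis=0).T  # (2, bs) -> (bs, 2), like permute(1, 0)
    return shift * float(use_shift)


# heading 0 deg moving +x, and heading 90 deg moving +y: both are one cell straight ahead
s = ego_shift(np.array([2.048, 0.0]), np.array([0.0, 2.048]), np.array([0.0, 90.0]),
              grid_length=(2.048, 2.048), bev_h=50, bev_w=50)
print(np.round(s, 6))
```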
4 Encoder:

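Purpose:

  • generates ref_3d and ref_2d, the reference points sampled by the deformable attention in the later TSA and SCA modules;

  • uses the intrinsics and extrinsics of the 6 cameras to project the metric-scale ref_3d from the ego (lidar) frame into the pixel coordinates of each camera, determining which points fall into which cameras; this information is stored in bev_mask;

  • loops through 3 identically configured BEVFormerLayer modules; each BEVFormerLayer consists of ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm').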

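Before reading the code, it may help to see what `get_reference_points` produces. The NumPy sketch below mirrors both branches for a single sample (hypothetical function names; `Z=8` is `pc_range[5] - pc_range[2]`, and 4 points per pillar as in bevformer_tiny). Coordinates are cell centers normalized to [0, 1], and the flattened query index is `row * W + col`, with x varying fastest.

```python
import numpy as np


def get_reference_points_3d(H, W, Z=8.0, num_points_in_pillar=4, bs=1):
    """Hypothetical NumPy mirror of the dim='3d' branch.
    Returns (bs, num_points_in_pillar, H*W, 3) normalized (x, y, z) in [0, 1]."""
    D = num_points_in_pillar
    zs = np.broadcast_to(np.linspace(0.5, Z - 0.5, D).reshape(-1, 1, 1), (D, H, W)) / Z
    xs = np.broadcast_to(np.linspace(0.5, W - 0.5, W).reshape(1, 1, W), (D, H, W)) / W
    ys = np.broadcast_to(np.linspace(0.5, H - 0.5, H).reshape(1, H, 1), (D, H, W)) / H
    ref_3d = np.stack((xs, ys, zs), -1)        # (D, H, W, 3)
    ref_3d = ref_3d.reshape(D, H * W, 3)       # same result as permute/flatten/permute
    return np.repeat(ref_3d[None], bs, axis=0)


def get_reference_points_2d(H, W, bs=1):
    """Hypothetical NumPy mirror of the dim='2d' branch: (bs, H*W, 1, 2)."""
    ref_y, ref_x = np.meshgrid(np.linspace(0.5, H - 0.5, H),
                               np.linspace(0.5, W - 0.5, W), indexing='ij')
    ref_y = ref_y.reshape(-1)[None] / H
    ref_x = ref_x.reshape(-1)[None] / W
    ref_2d = np.stack((ref_x, ref_y), -1)      # (1, H*W, 2)
    return np.repeat(ref_2d, bs, axis=0)[:, :, None]


ref_3d = get_reference_points_3d(50, 50)
ref_2d = get_reference_points_2d(50, 50)
print(ref_3d.shape, ref_2d.shape)   # (1, 4, 2500, 3) (1, 2500, 1, 2)
```

Every pillar shares its (x, y) with the corresponding 2D reference point; only z differs between the 4 points.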
Code walkthrough:

output = bev_query
intermediate = []
ref_3d = self.get_reference_points(
   bev_h, bev_w, self.pc_range[5]-self.pc_range[2], self.num_points_in_pillar, dim='3d', bs=bev_query.size(1),  device=bev_query.device, dtype=bev_query.dtype)
ref_2d = self.get_reference_points(
   bev_h, bev_w, dim='2d', bs=bev_query.size(1), device=bev_query.device, dtype=bev_query.dtype)
#----------------------------get_reference_points start---------------------------------------
if dim == '3d':
   zs = torch.linspace(0.5, Z - 0.5, num_points_in_pillar, dtype=dtype,
                       device=device).view(-1, 1, 1).expand(num_points_in_pillar, H, W) / Z
   xs = torch.linspace(0.5, W - 0.5, W, dtype=dtype,
                       device=device).view(1, 1, W).expand(num_points_in_pillar, H, W) / W
   ys = torch.linspace(0.5, H - 0.5, H, dtype=dtype,
                       device=device).view(1, H, 1).expand(num_points_in_pillar, H, W) / H
   ref_3d = torch.stack((xs, ys, zs), -1)
   ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)
   ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) # torch.Size([1, 4, 50*50, 3])
# reference points on 2D bev plane, used in temporal self-attention (TSA).
elif dim == '2d':
   ref_y, ref_x = torch.meshgrid(
       torch.linspace(
           0.5, H - 0.5, H, dtype=dtype, device=device),
       torch.linspace(
           0.5, W - 0.5, W, dtype=dtype, device=device)
   )
   ref_y = ref_y.reshape(-1)[None] / H
   ref_x = ref_x.reshape(-1)[None] / W
   ref_2d = torch.stack((ref_x, ref_y), -1)
   #ref_2d:torch.Size([1, 50*50, 1, 2])
   ref_2d = ref_2d.repeat(bs, 1, 1).unsqueeze(2)
#----------------------------get_reference_points end---------------------------------------
reference_points_cam, bev_mask = self.point_sampling(
   ref_3d, self.pc_range, kwargs['img_metas'])
#----------------------------point_sampling start---------------------------------------------
allow_tf32 = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
lidar2img = []
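#lidar2img is the full lidar-to-image projection (camera intrinsics x lidar-to-camera extrinsics), one matrix per camera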
for img_meta in img_metas:
   lidar2img.append(img_meta['lidar2img'])
lidar2img = np.asarray(lidar2img)
lidar2img = reference_points.new_tensor(lidar2img)  # (B, N, 4, 4) # torch.Size([1, 6, 4, 4])
reference_points = reference_points.clone()
# reference_points = ref_3d # normalize 0~1 # torch.Size([1, 4, 50*50, 3])
# pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
# map reference_points into the ego (lidar) frame, rescaling them to metric units
reference_points[..., 0:1] = reference_points[..., 0:1] * \
   (pc_range[3] - pc_range[0]) + pc_range[0]
reference_points[..., 1:2] = reference_points[..., 1:2] * \
   (pc_range[4] - pc_range[1]) + pc_range[1]
reference_points[..., 2:3] = reference_points[..., 2:3] * \
   (pc_range[5] - pc_range[2]) + pc_range[2]
# convert to homogeneous coordinates  # bs, num_points_in_pillar, bev_h * bev_w, xyz1
reference_points = torch.cat(
   (reference_points, torch.ones_like(reference_points[..., :1])), -1)
# num_points_in_pillar, bs, bev_h * bev_w, xyz1
reference_points = reference_points.permute(1, 0, 2, 3)
D, B, num_query = reference_points.size()[:3]
num_cam = lidar2img.size(1)
# num_points_in_pillar, bs, num_cam, bev_h * bev_w, xyz1, None
reference_points = reference_points.view(
   D, B, 1, num_query, 4).repeat(1, 1, num_cam, 1, 1).unsqueeze(-1)
# num_points_in_pillar, bs, num_cam, bev_h * bev_w, 4, 4
lidar2img = lidar2img.view(
   1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1)
# num_points_in_pillar, bs, num_cam, bev_h * bev_w, xys1
reference_points_cam = torch.matmul(lidar2img.to(torch.float32),
                                   reference_points.to(torch.float32)).squeeze(-1)
eps = 1e-5
# keep only the points in front of the camera
bev_mask = (reference_points_cam[..., 2:3] > eps)
# divide by the homogeneous scale factor (the depth) to get image-plane coordinates
# num_points_in_pillar, bs, num_cam, bev_h * bev_w, uv
reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
   reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
# normalize the pixel coordinates by image width and height
reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
#filter once more by the camera's field of view (keep only points that land inside the image)
bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
           & (reference_points_cam[..., 1:2] < 1.0)
           & (reference_points_cam[..., 0:1] < 1.0)
           & (reference_points_cam[..., 0:1] > 0.0))
if digit_version(TORCH_VERSION) >= digit_version('1.8'):
   bev_mask = torch.nan_to_num(bev_mask)
else:
   bev_mask = bev_mask.new_tensor(
       np.nan_to_num(bev_mask.cpu().numpy()))
# num_cam, bs, bev_h * bev_w, num_points_in_pillar, uv
# torch.Size([6, 1, 50*50, 4, 2])
reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4)
# ( num_points_in_pillar, bs, num_cam, bev_h * bev_w,bool)
# ->(num_cam, bs, bev_h * bev_w, num_points_in_pillar)(bool)
bev_mask = bev_mask.permute(2, 1, 3, 0, 4).squeeze(-1)
torch.backends.cuda.matmul.allow_tf32 = allow_tf32
torch.backends.cudnn.allow_tf32 = allow_tf32
#----------------------------point_sampling end---------------------------------------------
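# bug: this code should be 'shift_ref_2d = ref_2d.clone()', we keep this bug for reproducing our results in paper.
shift_ref_2d = ref_2d  # .clone()
#Because shift_ref_2d aliases ref_2d, the in-place += below shifts ref_2d as well.
#This step is unnecessary: prev_bev and bev are already aligned, so the ref_2d of prev_bev can simply be identical to that of bev
shift_ref_2d += shift[:, None, None, :]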
# (num_query,bs,embed_dims) -> (bs,num_query,embed_dims)
# torch.Size([50*50, 1, 256]) -> torch.Size([1,50*50,256])
bev_query = bev_query.permute(1, 0, 2)
# torch.Size([50*50,1,256]) -> torch.Size([1,50*50,256])
bev_pos = bev_pos.permute(1, 0, 2)
# bs, bev_h * bev_w, None, xy # torch.Size([1, 50*50, 1, 2])
bs, len_bev, num_bev_level, _ = ref_2d.shape
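#If prev_bev exists, stack the previous prev_bev with the current bev_query and pass the result to the
#layers as prev_bev (used as value in TSA). Otherwise TSA stacks two copies of the current bev as value.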
if prev_bev is not None:
   prev_bev = prev_bev.permute(1, 0, 2)
   prev_bev = torch.stack(
       [prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1) #2 torch.Size([2, 50*50, 256])
   hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
       bs*2, len_bev, num_bev_level, 2) #2 torch.Size([2, 50*50, 1, 2])
else:
   hybird_ref_2d = torch.stack([ref_2d, ref_2d], 1).reshape(
       bs*2, len_bev, num_bev_level, 2)  
#Enter the 3 encoder layers! Each layer consists of ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')
for lid, layer in enumerate(self.layers):
   output = layer(
       bev_query,
       key,
       value,
       *args,
       bev_pos=bev_pos,
       ref_2d=hybird_ref_2d,
       ref_3d=ref_3d,
       bev_h=bev_h,
       bev_w=bev_w,
       spatial_shapes=spatial_shapes,  # torch.Size([1, 2])
       level_start_index=level_start_index,
       reference_points_cam=reference_points_cam,
       bev_mask=bev_mask,
       prev_bev=prev_bev,
       **kwargs)
   bev_query = output
# self.layers: [BEVFormerLayer, BEVFormerLayer, BEVFormerLayer]  # the 3 layers share one config, each with its own weights

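The core of `point_sampling` is one homogeneous projection followed by two visibility tests. The sketch below reproduces it in NumPy for one sample; `point_sampling_np` is a hypothetical name, and the test camera (a pinhole looking along the +x axis of the ego frame, focal length 400 px, 800x480 image) is invented purely for illustration, not a nuScenes calibration.

```python
import numpy as np


def point_sampling_np(ref_3d, pc_range, lidar2img, img_hw, eps=1e-5):
    """Hypothetical NumPy sketch of point_sampling for one sample.
    ref_3d: (D, Q, 3) normalized points; lidar2img: (N, 4, 4); img_hw: (H, W).
    Returns uv (N, Q, D, 2) normalized to [0, 1] and mask (N, Q, D)."""
    pts = np.array(ref_3d, dtype=float)                  # copy, like reference_points.clone()
    for i in range(3):                                   # back to metric ego/lidar coordinates
        pts[..., i] = pts[..., i] * (pc_range[i + 3] - pc_range[i]) + pc_range[i]
    pts = np.concatenate([pts, np.ones_like(pts[..., :1])], -1)   # homogeneous (D, Q, 4)
    cam = np.einsum('nij,dqj->dnqi', lidar2img, pts)     # (D, N, Q, 4)
    mask = cam[..., 2] > eps                             # in front of the camera
    uv = cam[..., :2] / np.maximum(cam[..., 2:3], eps)   # perspective divide
    uv[..., 0] /= img_hw[1]                              # u / image width
    uv[..., 1] /= img_hw[0]                              # v / image height
    mask &= (uv[..., 0] > 0) & (uv[..., 0] < 1) & (uv[..., 1] > 0) & (uv[..., 1] < 1)
    return uv.transpose(1, 2, 0, 3), mask.transpose(1, 2, 0)


# Invented test camera: looks along +x of the ego frame (cam z = x, cam x = -y, cam y = -z)
K = np.array([[400.0, 0.0, 400.0], [0.0, 400.0, 240.0], [0.0, 0.0, 1.0]])
R = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], [1.0, 0.0, 0.0]])
lidar2img = np.eye(4)
lidar2img[:3, :3] = K @ R
pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
# two normalized points with y = 0, z = 0: x = +10 m (ahead) and x = -10 m (behind)
ref = np.array([[[61.2 / 102.4, 0.5, 0.625], [41.2 / 102.4, 0.5, 0.625]]])
uv, mask = point_sampling_np(ref, pc_range, lidar2img[None], (480, 800))
print(uv[0, 0, 0], mask[0, :, 0])
```

The point ahead of the camera lands at the image center, while the point behind it is rejected by the depth test before the field-of-view test is even relevant.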
prev_bev has already been aligned with bev, so ref_2d does not need the extra +shift here: once prev_bev is aligned, its ref_2d can simply be identical to that of bev.

Public (open-source) algorithm:

shift_ref_2d = ref_2d  # .clone()
shift_ref_2d += shift[:, None, None, :]
bs, len_bev, num_bev_level, _ = ref_2d.shape
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
    bs*2, len_bev, num_bev_level, 2)

Horizon reference algorithm:

shift_ref_2d = ref_2d.clone()
bs, len_bev, num_bev_level, _ = ref_2d.shape
hybird_ref_2d = torch.stack([shift_ref_2d, ref_2d], 1).reshape(
    bs * 2, len_bev, num_bev_level, 2
)

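To see what is fed into TSA in each case, here is a NumPy sketch of building hybird_ref_2d (the identifier keeps the upstream spelling). It copies ref_2d before shifting, i.e. the `shift_ref_2d = ref_2d.clone()` form that the upstream bug comment describes as intended, so only the history entry moves. `build_hybrid_ref_2d` and its `apply_shift` flag are hypothetical; `apply_shift=False` corresponds to the reference algorithm, where both entries are the unshifted ref_2d.

```python
import numpy as np


def build_hybrid_ref_2d(ref_2d, shift, apply_shift):
    """Hypothetical helper. ref_2d: (bs, H*W, 1, 2); shift: (bs, 2).
    Returns (bs*2, H*W, 1, 2) ordered [history refs, current refs] per sample."""
    shift_ref_2d = ref_2d.copy()               # copy, so ref_2d itself is never modified
    if apply_shift:
        shift_ref_2d += shift[:, None, None, :]
    bs, len_bev, num_bev_level, _ = ref_2d.shape
    return np.stack([shift_ref_2d, ref_2d], 1).reshape(bs * 2, len_bev, num_bev_level, 2)


ref_2d = np.full((1, 4, 1, 2), 0.5)
shift = np.array([[0.0, 0.02]])                # e.g. one cell straight ahead on a 50x50 grid
shifted = build_hybrid_ref_2d(ref_2d, shift, apply_shift=True)
unshifted = build_hybrid_ref_2d(ref_2d, shift, apply_shift=False)
print(shifted[:, 0, 0], unshifted[:, 0, 0])
```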
4.1 TemporalSelfAttention:

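Purpose:

The Temporal Self-Attention (TSA) module is a self-attention module that takes the BEV features of the previous timestamp as prior knowledge. Compared with starting from freshly initialized queries alone, fusing these historical BEV features improves the modeling capability of the BEV queries.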

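The `init_weights` code in the walkthrough below initializes `sampling_offsets.bias` so that, before any training, every head looks in its own direction and its 4 points sit 1, 2, 3 and 4 BEV cells out along that direction (scaled so the larger component is 1). A NumPy sketch of that initialization, assuming the tiny settings (8 heads, 1 level, num_bev_queue=2, 4 points); `init_sampling_offset_bias` is a hypothetical name.

```python
import math

import numpy as np


def init_sampling_offset_bias(num_heads=8, num_levels=1, num_bev_queue=2, num_points=4):
    """Hypothetical NumPy mirror of the sampling_offsets.bias initialization in TSA.init_weights.
    Returns the bias as (num_heads, num_levels*num_bev_queue, num_points, 2)."""
    thetas = np.arange(num_heads, dtype=np.float32) * (2.0 * math.pi / num_heads)
    grid = np.stack([np.cos(thetas), np.sin(thetas)], -1)      # one direction per head
    grid = grid / np.abs(grid).max(-1, keepdims=True)          # larger component becomes +-1
    grid = np.tile(grid.reshape(num_heads, 1, 1, 2),
                   (1, num_levels * num_bev_queue, num_points, 1))
    for i in range(num_points):
        grid[:, :, i, :] *= i + 1                              # point i lies (i+1) steps out
    return grid


bias = init_sampling_offset_bias()
print(bias.shape, bias.size)   # (8, 2, 4, 2) 128
print(bias[0, 0])              # head 0 points along +x
```

The flattened size, 128, matches the output width of `sampling_offsets` (2*8*1*4*2).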
Code walkthrough:

#----------------------------------init start---------------------------------------------
#In __init__, TSA defines the following 4 linear layers; all of them are learned during training
#Predict the sampling offsets from the concatenated query: 256*2 -> 2*8*1*4*2
self.sampling_offsets = nn.Linear(
   embed_dims*self.num_bev_queue, num_bev_queue*num_heads * num_levels * num_points * 2)
#Predict the attention weights of the sampling points from the concatenated query: 256*2 -> 2*8*1*4
self.attention_weights = nn.Linear(embed_dims*self.num_bev_queue,
                                   num_bev_queue*num_heads * num_levels * num_points)
#linear projection of value
self.value_proj = nn.Linear(embed_dims, embed_dims)
#output linear projection
self.output_proj = nn.Linear(embed_dims, embed_dims)
self.init_weights()
#----------------------------------init end-----------------------------------------------
#-----------------------------init_weights start------------------------------------------
constant_init(self.sampling_offsets, 0.)
thetas = torch.arange(
   self.num_heads,
   dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
#torch.Size([8,1*2,4,2])
grid_init = (grid_init /
               grid_init.abs().max(-1, keepdim=True)[0]).view(
   self.num_heads, 1, 1,
   2).repeat(1, self.num_levels*self.num_bev_queue, self.num_points, 1)

for i in range(self.num_points):
   grid_init[:, :, i, :] *= i + 1
#Note the interesting initialization of sampling_offsets.bias: it effectively scatters a ring of sampling points around each reference point
self.sampling_offsets.bias.data = grid_init.view(-1)
constant_init(self.attention_weights, val=0., bias=0.)
xavier_init(self.value_proj, distribution='uniform', bias=0.)
xavier_init(self.output_proj, distribution='uniform', bias=0.)
self._is_init = True
#-----------------------------init_weights   end------------------------------------------
#Without prev_bev, stack two copies of the current bev_query as value
#With prev_bev, value was already built in the previous module (the encoder)
if value is None:
   assert self.batch_first
   # query:torch.Size([1, 50 * 50, 256])
   bs, len_bev, c = query.shape
   value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c) # torch.Size([2, 50 * 50, 256])
   
if identity is None:
   identity = query
#add the positional encoding to bev_query
if query_pos is not None:
   # bev_query + bev_pos
   query = query + query_pos
if not self.batch_first:
   # change to (bs, num_query ,embed_dims)
   query = query.permute(1, 0, 2)
   value = value.permute(1, 0, 2)
bs,  num_query, embed_dims = query.shape
_, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
assert self.num_bev_queue == 2

#Concatenate, along the feature dimension, prev_bev/bev_query (without query_pos) taken from the TSA value[:bs]
#with the current bev_query, and use the result to predict the offsets and weights
#query:torch.Size([1, 50 * 50, 512])
query = torch.cat([value[:bs], query], -1)
#value:torch.Size([2, 50 * 50, 256])
value = self.value_proj(value)
#fill masked positions of value with 0, e.g. positions left without features after prev_bev is rotated
if key_padding_mask is not None:
   value = value.masked_fill(key_padding_mask[..., None], 0.0)
#value:torch.Size([2, 50 * 50, 8, 32]), split 256 into 8*32 for multi-head attention
value = value.reshape(bs*self.num_bev_queue,
                       num_value, self.num_heads, -1)
#sampling_offsets:torch.Size([1, 50 * 50, 128])
sampling_offsets = self.sampling_offsets(query)
#sampling_offsets:torch.Size([1,50*50, 8,2,1,4,2])
sampling_offsets = sampling_offsets.view(
   bs, num_query, self.num_heads,  self.num_bev_queue, self.num_levels, self.num_points, 2)
attention_weights = self.attention_weights(query).view(
   bs, num_query,  self.num_heads, self.num_bev_queue, self.num_levels * self.num_points)
attention_weights = attention_weights.softmax(-1)

attention_weights = attention_weights.view(bs, num_query,
                                           self.num_heads,
                                           self.num_bev_queue,
                                           self.num_levels,
                                           self.num_points)
#attention_weights:torch.Size([1*2, 50*50, 8, 1, 4])
attention_weights = attention_weights.permute(0, 3, 1, 2, 4, 5)\
   .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points).contiguous()
#sampling_offsets:torch.Size([1*2, 50*50, 8, 1, 4, 2])
sampling_offsets = sampling_offsets.permute(0, 3, 1, 2, 4, 5, 6)\
   .reshape(bs*self.num_bev_queue, num_query, self.num_heads, self.num_levels, self.num_points, 2)

# In TSA, reference_points is hybird_ref_2d: torch.Size([2, 50*50, 1, 2])
# The next step normalizes the sampling offsets by the BEV size and adds them to the reference points
if reference_points.shape[-1] == 2:
   offset_normalizer = torch.stack(
       [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
   #sampling_locations:torch.Size([1*2, 50*50, 8, 1, 4, 2])
   sampling_locations = reference_points[:, :, None, :, None, :] \
       + sampling_offsets \
       / offset_normalizer[None, None, None, :, None, :]

elif reference_points.shape[-1] == 4:
   sampling_locations = reference_points[:, :, None, :, None, :2] \
       + sampling_offsets / self.num_points \
       * reference_points[:, :, None, :, None, 2:] \
       * 0.5
else:
   raise ValueError(
       f'Last dim of reference_points must be'
       f' 2 or 4, but get {reference_points.shape[-1]} instead.')
if torch.cuda.is_available() and value.is_cuda:
   # using fp16 deformable attention is unstable because it performs many sum operations
   if value.dtype == torch.float16:
       MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
   else:
       MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
   output = MultiScaleDeformableAttnFunction.apply(
       value, spatial_shapes, level_start_index, sampling_locations,
       attention_weights, self.im2col_step)
else:
   output = multi_scale_deformable_attn_pytorch(
       value, spatial_shapes, sampling_locations, attention_weights)
# -------------multi_scale_deformable_attn_pytorch start-------------
#value:torch.Size([2, 50 * 50, 8, 32])
msda_bs, _, num_heads, msda_embed_dims = value.shape  
#sampling_locations:torch.Size([2, 50*50, 8, 1, 4, 2])
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
#value_list:(torch.Size([2, 50*50, 8, 32]))
value_list = value.split([H_ * W_ for H_, W_ in spatial_shapes], dim=1)
sampling_grids = 2 * sampling_locations - 1  # map to the [-1, 1] grid coordinate range used by F.grid_sample
sampling_value_list = []
for level, (H_, W_) in enumerate(spatial_shapes):
   value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(msda_bs * num_heads, msda_embed_dims, H_, W_)
   # msda_bs, H_*W_, num_heads, msda_embed_dims ->
   # msda_bs, H_*W_, num_heads*msda_embed_dims ->
   # msda_bs, num_heads*msda_embed_dims, H_*W_ ->
   # msda_bs*num_heads, msda_embed_dims, H_, W_
   # torch.Size([2*8, 32, 50, 50])
   sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
   # msda_bs, num_queries, num_heads, num_points, 2 ->
   # msda_bs, num_heads, num_queries, num_points, 2 ->
   # msda_bs*num_heads, num_queries, num_points, 2
   # torch.Size([2*8, 50*50, 4, 2])
   sampling_value_l_ = F.grid_sample(
       value_l_,
       sampling_grid_l_,
       mode='bilinear',
       padding_mode='zeros',
       align_corners=False)
   # msda_bs*num_heads, msda_embed_dims, num_queries, num_points
   # torch.Size([2*8, 32, 50*50, 4])
   sampling_value_list.append(sampling_value_l_)
attention_weights = attention_weights.transpose(1, 2).reshape(msda_bs * num_heads, 1, num_queries, num_levels * num_points)
# (msda_bs, num_queries, num_heads, num_levels, num_points) ->
# (msda_bs, num_heads, num_queries, num_levels, num_points) ->
# (msda_bs*num_heads, 1, num_queries, num_levels*num_points)
# torch.Size([2*8, 1, 50*50, 1*4])
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
           attention_weights).sum(-1).view(msda_bs, num_heads * msda_embed_dims, num_queries)
# torch.Size([2, 256, 50*50])
# return output.transpose(1, 2).contiguous()
output = output.transpose(1, 2).contiguous()
# torch.Size([2, 50*50, 256]) # (bs*num_bev_queue, num_query, embed_dims)
# --------------multi_scale_deformable_attn_pytorch end--------------
# output shape (bs*num_bev_queue, num_query, embed_dims)
# (bs*num_bev_queue, num_query, embed_dims)-> (num_query, embed_dims, bs*num_bev_queue)
output = output.permute(1, 2, 0)
# fuse history value and current value
# (num_query, embed_dims, bs*num_bev_queue)-> (num_query, embed_dims, bs, num_bev_queue)
output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
output = output.mean(-1)
# (num_query, embed_dims, bs)-> (bs, num_query, embed_dims)
output = output.permute(2, 0, 1)
#output:torch.Size([1, 50*50, 256])
output = self.output_proj(output)
if not self.batch_first:
   output = output.permute(1, 0, 2)
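#residual-style connection: dropout helps prevent overfitting, and adding identity fuses information from the previous timestamp into the original bev_query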
return self.dropout(output) + identity
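The two steps at the end of Temporal Self-Attention can be checked outside the full model with a standalone sketch. It covers the pure-PyTorch deformable sampling core (the `multi_scale_deformable_attn_pytorch` block in the listing) and the averaging over `num_bev_queue`. The names `msda_core` and `fuse_bev_queue` are hypothetical. The sketch assumes that the batch axis of `output` is ordered as (bs, num_bev_queue), so the history BEV and the current query of each sample sit next to each other. The `view(num_query, embed_dims, bs, self.num_bev_queue)` call above relies on that same ordering.

```python
import torch
import torch.nn.functional as F


def msda_core(value, spatial_shapes, sampling_locations, attention_weights):
    """Sketch of the pure-PyTorch multi-scale deformable attention core.

    value:              (bs, num_keys, num_heads, head_dims), num_keys = sum(H * W)
    spatial_shapes:     list of (H, W) tuples, one per level
    sampling_locations: (bs, num_queries, num_heads, num_levels, num_points, 2), (x, y) in [0, 1]
    attention_weights:  (bs, num_queries, num_heads, num_levels, num_points)
    returns:            (bs, num_queries, num_heads * head_dims)
    """
    bs, _, num_heads, head_dims = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([h * w for h, w in spatial_shapes], dim=1)
    grids = 2 * sampling_locations - 1  # [0, 1] -> [-1, 1], the range grid_sample expects
    sampled = []
    for level, (h, w) in enumerate(spatial_shapes):
        # (bs, h*w, heads, dims) -> (bs*heads, dims, h, w)
        v = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, head_dims, h, w)
        # (bs, nq, heads, points, 2) -> (bs*heads, nq, points, 2)
        g = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bilinear sampling -> (bs*heads, dims, nq, points)
        sampled.append(F.grid_sample(v, g, mode='bilinear', padding_mode='zeros', align_corners=False))
    # (bs, nq, heads, levels, points) -> (bs*heads, 1, nq, levels*points)
    weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries, num_levels * num_points)
    # weighted sum over all sampling points of all levels -> (bs*heads, dims, nq)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * weights).sum(-1)
    return out.view(bs, num_heads * head_dims, num_queries).transpose(1, 2).contiguous()


def fuse_bev_queue(output, bs, num_bev_queue):
    """(bs*num_bev_queue, num_query, C) -> (bs, num_query, C) by averaging over the queue."""
    num_query, embed_dims = output.shape[1], output.shape[2]
    output = output.permute(1, 2, 0)  # (num_query, C, bs*num_bev_queue)
    # split the batch axis into (bs, num_bev_queue) and average history and current
    output = output.reshape(num_query, embed_dims, bs, num_bev_queue).mean(-1)
    return output.permute(2, 0, 1)  # (bs, num_query, C)
```

With `align_corners=False`, a normalized location of `(j + 0.5) / W` lands exactly on the center of pixel column `j`. With the tiny-config shapes from the comments, `msda_core` maps value `[2, 2500, 8, 32]`, locations `[2, 2500, 8, 1, 4, 2]` and weights `[2, 2500, 8, 1, 4]` to `[2, 2500, 256]`. `fuse_bev_queue(out, bs=1, num_bev_queue=2)` then returns `[1, 2500, 256]`.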
4.2 SpatialCrossAttention:

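Function:

Spatial Cross-Attention is the cross-attention between the BEV queries and the multi-camera image features. Each BEV query is lifted into a pillar of D = 4 anchor points at different heights. These points are projected into each of the 6 camera images (`reference_points_cam`), and `bev_mask` records which projections land inside an image. Each camera only interacts with the BEV queries whose pillars hit it. Deformable attention then samples image features around the projected reference points. The per-camera results are summed back into the BEV grid and divided by the number of cameras each query hit.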
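The `reference_points_cam` and `bev_mask` tensors used in the walkthrough are produced before this module is called. Each pillar's anchors are projected into every camera. A simplified sketch of that projection follows. `project_pillars` is a hypothetical helper, and the sketch makes three assumptions:

- the anchors are already metric 3D points in the lidar/ego frame;
- `lidar2img` is one 4x4 matrix per camera;
- a point counts as a hit only if it lies in front of the camera and its normalized image coordinates fall inside (0, 1).

```python
import torch


def project_pillars(ref_3d, lidar2img, img_h, img_w, eps=1e-5):
    """Project BEV pillar anchors into each camera (simplified sketch).

    ref_3d:    (bs, num_query, D, 3) anchor points in the lidar/ego frame
    lidar2img: (num_cams, bs, 4, 4) projection matrices (assumed layout)
    returns:   reference_points_cam (num_cams, bs, num_query, D, 2), (x, y) in [0, 1]
               bev_mask             (num_cams, bs, num_query, D), bool
    """
    ones = torch.ones_like(ref_3d[..., :1])
    pts = torch.cat([ref_3d, ones], dim=-1)  # homogeneous coordinates (bs, nq, D, 4)
    # (num_cams, bs, 1, 1, 4, 4) @ (1, bs, nq, D, 4, 1) -> (num_cams, bs, nq, D, 4)
    cam = (lidar2img[:, :, None, None] @ pts[None, ..., None]).squeeze(-1)
    depth = cam[..., 2:3]
    valid = depth > eps  # point must be in front of the camera
    uv = cam[..., 0:2] / torch.clamp(depth, min=eps)  # perspective divide -> pixels
    uv = uv / uv.new_tensor([img_w, img_h])  # pixels -> [0, 1]
    valid = valid & (uv > 0).all(-1, keepdim=True) & (uv < 1).all(-1, keepdim=True)
    return uv, valid.squeeze(-1)
```

The mask is computed per anchor. The walkthrough then treats a pillar as hitting a camera when any of its D anchors is valid, which is the role of `mask_per_img[0].sum(-1).nonzero()`.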

Walkthrough:

#value is given: value = feat_flatten, torch.Size([6, 15*25, 1, 256])
if value is None:
   value = key
if residual is None:
   inp_residual = query  # torch.Size([1, 50*50, 256])
   slots = torch.zeros_like(query)
#add positional encoding
if query_pos is not None:
   query = query + query_pos

bs, num_query, _ = query.size()

D = reference_points_cam.size(3) # reference_points_cam: torch.Size([6, 1, 50*50, 4, 2]), D=4
indexes = []
# bev_mask:torch.Size([6, 1, 50*50, 4])
for i, mask_per_img in enumerate(bev_mask):
   # indices (in the 50*50 BEV grid) of the pillars with at least one of their 4 points projecting into this image
   index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
   # torch.Size([num_nonzero_pillar])
   indexes.append(index_query_per_img)
max_len = max([len(each) for each in indexes])
# each camera only interacts with its corresponding BEV queries. This step can greatly save GPU memory.
queries_rebatch = query.new_zeros(
   [bs, self.num_cams, max_len, self.embed_dims])
reference_points_rebatch = reference_points_cam.new_zeros(
   [bs, self.num_cams, max_len, D, 2])
for j in range(bs):
   for i, reference_points_per_img in enumerate(reference_points_cam):  
       index_query_per_img = indexes[i]
       queries_rebatch[j, i, :len(index_query_per_img)] = query[j, index_query_per_img]
       reference_points_rebatch[j, i, :len(index_query_per_img)] = reference_points_per_img[j, index_query_per_img]
#key:torch.Size([6, 15*25, 1, 256])
num_cams, l, bs, embed_dims = key.shape
#key:torch.Size([6, 15*25, 256])
key = key.permute(2, 0, 1, 3).reshape(
   bs * self.num_cams, l, self.embed_dims)
#value:torch.Size([6, 15*25, 256])
value = value.permute(2, 0, 1, 3).reshape(
   bs * self.num_cams, l, self.embed_dims)

#query:torch.Size([1,6,max_len,256])->torch.Size([6,max_len,256])
#key and value:torch.Size([6,15*25,256])
#reference_points:torch.Size([6,max_len,4,2])
#spatial_shapes:torch.Size([1,2]) {15,25}
#level_start_index is not used by the PyTorch implementation below (there is only one feature level)
queries = self.deformable_attention(query=queries_rebatch.view(bs*self.num_cams, max_len, self.embed_dims), key=key, value=value,
                                   reference_points=reference_points_rebatch.view(bs*self.num_cams, max_len, D, 2), spatial_shapes=spatial_shapes,
                                   level_start_index=level_start_index).view(bs, self.num_cams, max_len, self.embed_dims)
# ------------------------multi_scale_deformable_attn_pytorch start-----------------------
if msda_identity is None:
   msda_identity = query #torch.Size([6,max_len,256])
#query:torch.Size([bs * num_cams, max_len, embed_dims])
msda_bs, max_len, _ = query.shape
# value:torch.Size([6, 15*25, 256])
msda_bs, num_value, _ = value.shape
assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
value = msda_value_proj(value)
#value:torch.Size([6,15*25,8,32])
value = value.view(msda_bs, num_value, msda_num_heads, -1)
#sampling_offsets:torch.Size([6,max_len,8,1,8,2])
sampling_offsets = msda_sampling_offsets(query).view(
   msda_bs, max_len, msda_num_heads, msda_num_levels, msda_num_points, 2)
#attention_weights:torch.Size([6,max_len,8,1*8])
attention_weights = msda_attention_weights(query).view(
   msda_bs, max_len, msda_num_heads, msda_num_levels * msda_num_points)
attention_weights = attention_weights.softmax(-1)
attention_weights = attention_weights.view(msda_bs, max_len,
                                           msda_num_heads,
                                           msda_num_levels,
                                           msda_num_points)
#reference_points:torch.Size([6, max_len(eg:606), 4, 2])                                            
if reference_points.shape[-1] == 2:
   """
   For each BEV query, it owns `num_Z_anchors` in 3D space that having different heights.
   After proejcting, each BEV query has `num_Z_anchors` reference points in each 2D image.
   For each referent point, we sample `num_points` sampling points.
   For `num_Z_anchors` reference points,  it has overall `num_points * num_Z_anchors` sampling points.
   """
   offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
   #reference_points:torch.Size([6, 606, 4, 2])
   msda_bs, max_len, num_Z_anchors, xy = reference_points.shape
   reference_points = reference_points[:, :, None, None, None, :, :]
   # [msda_bs, max_len, 1, 1, 1, num_Z_anchors, uv]
   sampling_offsets = sampling_offsets / offset_normalizer[None, None, None, :, None, :]
   # [msda_bs, max_len, msda_num_heads, msda_num_levels, msda_num_points, 2]
   msda_bs, max_len, msda_num_heads, msda_num_levels, msda_num_all_points, xy = sampling_offsets.shape
   # 6        606           8               1                  8            2
   sampling_offsets = sampling_offsets.view(
       msda_bs, max_len, msda_num_heads, msda_num_levels, msda_num_all_points // num_Z_anchors, num_Z_anchors, xy)
   #      6       606          8                 1                            2                        4        2
   sampling_locations = reference_points + sampling_offsets
   msda_bs, max_len, msda_num_heads, msda_num_levels, Z_anchor_points, num_Z_anchors, xy = sampling_locations.shape
   assert msda_num_all_points == Z_anchor_points * num_Z_anchors
   #sampling_locations:torch.Size([6,max_len(eg:606),8,1,8,2])
   sampling_locations = sampling_locations.view(msda_bs, max_len, msda_num_heads, msda_num_levels, msda_num_all_points, xy)

if torch.cuda.is_available() and value.is_cuda and use_msda_cuda:
   # both branches select the fp32 kernel, as in the original BEVFormer code
   if value.dtype == torch.float16:
       MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
   else:
       MultiScaleDeformableAttnFunction = MultiScaleDeformableAttnFunction_fp32
   output = MultiScaleDeformableAttnFunction.apply(
       value, spatial_shapes, level_start_index, sampling_locations,
       attention_weights, msda_im2col_step)
else:
   # value: torch.Size([6, 15*25, 8, 32])
   # spatial_shapes: torch.Size([1, 2])
   # sampling_locations: torch.Size([6, 606, 8, 1, 8, 2])
   # attention_weights: torch.Size([6, 606, 8, 1, 8])
   # output:torch.Size([6, 606, 256])
   output = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
# --------------multi_scale_deformable_attn_pytorch end--------------
queries = output.view(bs, num_cams, max_len, embed_dims)  # torch.Size([1, 6, 606, 256])
# slots: torch.Size([1, 50*50, 256])
for j in range(bs):
   for i, index_query_per_img in enumerate(indexes):
       # TODO: feature vectors from different cameras are simply added together
       slots[j, index_query_per_img] += queries[j, i, :len(index_query_per_img)]

# bev_mask: torch.Size([6, 1, 50*50, 4])
# slots: torch.Size([1, 50*50, 256])
count = bev_mask.sum(-1) > 0
count = count.permute(1, 2, 0).sum(-1)
count = torch.clamp(count, min=1.0)
slots = slots / count[..., None]
slots = output_proj(slots)
# return self.dropout(slots) + inp_residual
query = ca_dropout(slots) + inp_residual
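Two SCA-specific steps from the walkthrough can be isolated into small, testable functions:

- turning pixel offsets into normalized sampling locations grouped by Z anchor;
- the rebatch / scatter-average trick that lets each camera process only the queries that hit it.

The function names are hypothetical, and this is a sketch rather than the library code. As in the walkthrough, the per-camera hit indices are taken from the first batch element.

```python
import torch


def build_sampling_locations(reference_points, sampling_offsets, spatial_shapes):
    """Add pixel offsets to projected reference points, grouped by Z anchor.

    reference_points: (n, max_len, num_Z_anchors, 2), normalized (x, y) in [0, 1]
    sampling_offsets: (n, max_len, heads, levels, num_all_points, 2), in pixels
    spatial_shapes:   (levels, 2) integer tensor of (H, W)
    returns:          (n, max_len, heads, levels, num_all_points, 2)
    """
    n, max_len, num_z, _ = reference_points.shape
    _, _, heads, levels, num_all_points, _ = sampling_offsets.shape
    # offsets are (x, y) in pixels, so divide by (W, H) of each level
    normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
    offsets = sampling_offsets / normalizer[None, None, None, :, None, :]
    # split points into (points_per_anchor, num_Z_anchors): point p uses anchor p % num_z
    offsets = offsets.view(n, max_len, heads, levels, num_all_points // num_z, num_z, 2)
    locations = reference_points[:, :, None, None, None, :, :] + offsets
    return locations.view(n, max_len, heads, levels, num_all_points, 2)


def rebatch_queries(query, bev_mask):
    """Gather, per camera, only the BEV queries that hit it, zero-padded to max_len.

    query:    (bs, num_query, C)
    bev_mask: (num_cams, bs, num_query, D) bool
    returns:  (bs, num_cams, max_len, C) and the list of per-camera query indices
    """
    indexes = [m[0].sum(-1).nonzero().squeeze(-1) for m in bev_mask]  # hits of batch element 0
    max_len = max(len(idx) for idx in indexes)
    bs, _, channels = query.shape
    out = query.new_zeros(bs, len(indexes), max_len, channels)
    for i, idx in enumerate(indexes):
        out[:, i, :len(idx)] = query[:, idx]
    return out, indexes


def scatter_average(queries, indexes, bev_mask, num_query):
    """Sum each camera's outputs back into BEV slots, then divide by the number of hit cameras."""
    bs, _, _, channels = queries.shape
    slots = queries.new_zeros(bs, num_query, channels)
    for i, idx in enumerate(indexes):
        slots[:, idx] += queries[:, i, :len(idx)]
    count = (bev_mask.sum(-1) > 0).permute(1, 2, 0).sum(-1)  # (bs, num_query) hit cameras
    count = count.clamp(min=1).to(slots.dtype)  # queries with no hit keep zero slots
    return slots / count[..., None]
```

The last axis of `sampling_offsets` is viewed as `(num_points // num_Z_anchors, num_Z_anchors)`, so sampling point `p` belongs to Z anchor `p % num_Z_anchors`. In the rebatch step, each camera's tensor has length `max_len` instead of `50*50`, which is where the memory saving mentioned in the code comment comes from.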

To be continued...

