新闻  |   论坛  |   博客  |   在线研讨会
量化训练及精度调优经验分享
地平线开发者 | 2024-11-15 18:33:53    阅读:11   发布文章

本文提纲:

  1. fx 和 eager 两种量化训练方式介绍

  2. 量化训练的流程介绍:以 mmdet 的 yolov3 为例

  3. 常用的精度调优 debug 工具介绍

  4. 案例分析:模型精度调优经验分享


第一部分:fx 和 eager 两种量化训练方式介绍

首先介绍一下量化训练的原理。

img

上图为单个神经元的计算,计算形式是加权求和,再经过非线性激活后得到输出,这个输出又可以作为下一个神经元的输入继续运输,所以神经网络的基础运算是矩阵的乘法。如果神经元的计算全部采用 float32 的形式,模型的内存占用和数据搬运都会很占资源。如果用 int8 替换 float32,内存的搬运效率能提高 75%,充分展示了量化的有效性。由于两个 int8 相乘会超出 int8 的表示范围,为了防止溢出,累加器使用 int32 类型的,累加后的结果会再次 requantized 到 int8;

量化的目标就是在尽可能不影响模型精度的情况下降低模型的功耗,实现模型压缩效果,常见的量化方式有后量化训练 PTQ 和量化感知训练 QAT。

aW1hZ2U=.png

量化感知训练其实是一种伪量化的过程,即在训练过程中模拟浮点转定点的量化过程,数据虽然都是表示为 float32,但实际的值会间隔地受到量化参数的限制。具体方法是在某些 op 前插入伪量化节点(fake quantization nodes),伪量化节点有两个作用:

1.在训练时,用以统计流经该 op 的数据的最大最小值,便于在部署量化模型时对节点进行量化

2.伪量化节点参与模型训练的前向推理过程,因此会模型训练中导入了量化损失,但伪量化节点是不参与梯度更新过程的。

aW1hZ2U=.png

上图是模型学习量化损失的示意图, 正常的量化流程是 quantize->mul(int)->dequantize,而伪量化是对原先的 float 先 quantize 到 int,再 dequantize 到 float,这个步骤用于模拟量化过程中 round 操作所带来的误差,用这个误差再去进行前向运算。上图可以比较直观的表示引起误差的原因,从左到右数第 4 个黑点表示一个浮点数,quantize 后映射到 253,dequantize 后取到了第 5 个黑点,这就引起了误差。

地平线基于 PyTorch 开发的 horizon_plugin_pytorch 量化训练工具,同时支持 Eager 和 fx 两种模式。

aW1hZ2U=.png

eager 模式的使用方式建议参考用户手册 -4.2 量化感知训练章节(4.2.2。 快速上手中有完整的快速上手示例,各使用阶段注意事项建议参考 4.2.3。 使用指南)。fx 模式的相关 API 介绍请参考用户手册 -4.2.3.4.2。 主要接口参数说明章节


第二部分:量化训练的流程介绍:以 mmdet 的 yolov3 为例QAT 流程介绍准备好浮点模型,加载训好的浮点权重
    model = build_detector(
       cfg.model,
       train_cfg=cfg.get('train_cfg'),
       test_cfg=cfg.get('test_cfg'))
   model.init_weights()# 加载config里的 init_cfg
设置 BPU 架构
set_march(March.BAYES)
算子融合(eager 模式需要,fx 可省略)
    # qat: run fuse_module to fuse conv+bn/relu/add op
   model.backbone.fuse_modules()
   model.neck.fuse_modules()
   model.bbox_head.fuse_modules()
设置量化配置
  • 整个 model 使用默认的 qconfig

  • 模型的输出,配置高精度输出

  • det 模型 head 输出的 loss 损失函数的 qconfig 设置为 None

    # qat: set qconfig for float model
   model.qconfig = get_default_qat_qconfig()
   # qat: set default_qat_out_qconfig for last conv
   for m in model.bbox_head.convs_pred:
       m.qconfig = get_default_qat_out_qconfig()
   # qat: set None for loss qconfig, loss should be quantized
   model.bbox_head.loss_cls.qconfig = None
   model.bbox_head.loss_conf.qconfig = None
   model.bbox_head.loss_xy.qconfig = None
   model.bbox_head.loss_wh.qconfig = None
将浮点模型转换为 qat 模型(示例使用 eager 模式)
    qat_model = prepare_qat(model)
   qat_model.to(torch.device("cuda:1"))
开始 qat 训练
  1. 可以复用浮点的 train_detector,替换 model 即可

train_detector(
       qat_model,
       datasets,
       cfg,
       distributed=distributed,
       validate=(not args.no_validate),
       timestamp=timestamp,
       meta=meta)
qat 模型转定点(需要 load 训练好的 qat 模型权重)
quantized_model = convert(qat_model.eval())
deploy_model 和 example_input 准备
    deploy_model = DeployModel(
       quantized_model.backbone, quantized_model.neck,
       quantized_model.bbox_head).to(torch.device("cuda:1"))
   example_input = torch.randn(size=(24, 3, 320, 320), device=torch.device("cuda:1"))
Trace 模型构建静态 graph,进行编译
  • eval()使 bn、dropout 等处于正确的状态

  • 编译只能在 cpu 上做

  • check_model 用于检查算子是否能全部跑在 bpu 上,建议提前检查

    traced_model = torch.jit.trace(deploy_model.eval(), example_input)
   traced_model.to(torch.device("cpu"))
   example_input.to(torch.device("cpu"))
   check_model(traced_model, example_input, advice=1)
   compile_model(traced_model, [example_input], opt=0, hbm="model.hbm")
如果 qat 精度不达标,如何插入 calibration?
1. 准备好浮点模型,加载训好的浮点权重
2. 设置BPU架构
3. 算子融合(eager模式需要,fx可省略)
4. 设置model的量化配置
-----------------calib_model-------------------
calib_model = prepare_qat(float_model)
calib_model.eval() # 使bn、dropout等处于正确的状态
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION) # 不进行伪量化操作,仅观测算子输入输出统计量,更新scale
#校准训练(可复用浮点的train_detector,替换model即可)
train_detector(
       calib_model,
       datasets,
       cfg,
       distributed=distributed,
       validate=(not args.no_validate),
       timestamp=timestamp,
       meta=meta)
#校准精度验证
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
val(calib_model,val_dataloader,device)
-----------此时calib_model里的scale已经更新了-------------------------
qat_model = prepare_qat(float_model)
-----------qat_model加载calib训练好的模型权重,开始qat训练-----------------------------------------------
train_detector(
       qat_model,
       datasets,
       cfg,
       distributed=distributed,
       validate=(not args.no_validate),
       timestamp=timestamp,
       meta=meta)

伪量化节点(fake quantize)的三种状态:

  • CALIBRATION 模式:即不进行伪量化操作,仅观测算子输入输出统计量,更新 scale

  • QAT 模式:观测统计量并进行伪量化操作。

  • VALIDATION 模式:不会观测统计量,仅进行伪量化操作。

以下常见误操作会导致一些异常现象:

  1. calibration 之前模型设置为 train()的状态,且未使用set_fake_quantize,等于是在跑 QAT 训练;

  2. calibration 之前模型设置为 eval()的状态,且未使用set_fake_quantize,会导致 scale 一直处于初始状态,全为 1,calib 不起作用。

  3. calibration 之前模型设置为 eval()的状态,且正确使用了set_fake_quantize,但是在这之后又设置了一遍 model.eval(),这将导致 fake_quant 未处于训练状态,scale 一直处于初始状态,全为 1;


对 mobilenet_v2 模型做 qat 训练的设置量化节点设置

关键代码:

from horizon_plugin_pytorch.quantization import QuantStub

self.quant = QuantStub(scale=1/128) # 一般 pyramid 输入的 Quant 层,需要手动设置 scale=1/128def fuse_modules(self):

x = self.quant(x)
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from horizon_plugin_pytorch.quantization import QuantStub

from ..builder import BACKBONES
from ..utils import InvertedResidual, make_divisible
import torch

@BACKBONES.register_module()
class MobileNetV2(BaseModule):
   arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
                    [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
                    [6, 320, 1, 1]]

   def __init__(self,
                widen_factor=1.,
                out_indices=(1, 2, 4, 7),
                frozen_stages=-1,
                conv_cfg=None,
                norm_cfg=dict(type='BN'),
                act_cfg=dict(type='ReLU6'),
                norm_eval=False,
                with_cp=False,
                pretrained=None,
                init_cfg=None):
       super(MobileNetV2, self).__init__(init_cfg)
       # qat: model start with Quantization node
       # and set scale=1/128
       self.quant = QuantStub(scale=1/128) # 一般pyramid输入的Quant层,需要手动设置scale=1/128
       self.pretrained = pretrained
       assert not (init_cfg and pretrained), \
           'init_cfg and pretrained cannot be specified at the same time'
       if isinstance(pretrained, str):
           warnings.warn('DeprecationWarning: pretrained is deprecated, '
                         'please use "init_cfg" instead')
           self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
       elif pretrained is None:
           if init_cfg is None:
               self.init_cfg = [
                   dict(type='Kaiming', layer='Conv2d'),
                   dict(
                       type='Constant',
                       val=1,
                       layer=['_BatchNorm', 'GroupNorm'])
               ]
       else:
           raise TypeError('pretrained must be a str or None')

       self.widen_factor = widen_factor
       self.out_indices = out_indices
       if not set(out_indices).issubset(set(range(0, 8))):
           raise ValueError('out_indices must be a subset of range'
                            f'(0, 8). But received {out_indices}')

       if frozen_stages not in range(-1, 8):
           raise ValueError('frozen_stages must be in range(-1, 8). '
                            f'But received {frozen_stages}')
       self.out_indices = out_indices
       self.frozen_stages = frozen_stages
       self.conv_cfg = conv_cfg
       self.norm_cfg = norm_cfg
       self.act_cfg = act_cfg
       self.norm_eval = norm_eval
       self.with_cp = with_cp

       self.in_channels = make_divisible(32 * widen_factor, 8)

       self.conv1 = ConvModule(
           in_channels=3,
           out_channels=self.in_channels,
           kernel_size=3,
           stride=2,
           padding=1,
           conv_cfg=self.conv_cfg,
           norm_cfg=self.norm_cfg,
           act_cfg=self.act_cfg)

       self.layers = []

       for i, layer_cfg in enumerate(self.arch_settings):
           expand_ratio, channel, num_blocks, stride = layer_cfg
           out_channels = make_divisible(channel * widen_factor, 8)
           inverted_res_layer = self.make_layer(
               out_channels=out_channels,
               num_blocks=num_blocks,
               stride=stride,
               expand_ratio=expand_ratio)
           layer_name = f'layer{i + 1}'
           self.add_module(layer_name, inverted_res_layer)
           self.layers.append(layer_name)

       if widen_factor > 1.0:
           self.out_channel = int(1280 * widen_factor)
       else:
           self.out_channel = 1280

       layer = ConvModule(
           in_channels=self.in_channels,
           out_channels=self.out_channel,
           kernel_size=1,
           stride=1,
           padding=0,
           conv_cfg=self.conv_cfg,
           norm_cfg=self.norm_cfg,
           act_cfg=self.act_cfg)
       self.add_module('conv2', layer)
       self.layers.append('conv2')

   def make_layer(self, out_channels, num_blocks, stride, expand_ratio):
       """Stack InvertedResidual blocks to build a layer for MobileNetV2.

       Args:
           out_channels (int): out_channels of block.
           num_blocks (int): number of blocks.
           stride (int): stride of the first block. Default: 1
           expand_ratio (int): Expand the number of channels of the
               hidden layer in InvertedResidual by this ratio. Default: 6.
       """
       layers = []
       for i in range(num_blocks):
           if i >= 1:
               stride = 1
           layers.append(
               InvertedResidual(
                   self.in_channels,
                   out_channels,
                   mid_channels=int(round(self.in_channels * expand_ratio)),
                   stride=stride,
                   with_expand_conv=expand_ratio != 1,
                   conv_cfg=self.conv_cfg,
                   norm_cfg=self.norm_cfg,
                   act_cfg=self.act_cfg,
                   with_cp=self.with_cp))
           self.in_channels = out_channels

       return nn.Sequential(*layers)

   def _freeze_stages(self):
       if self.frozen_stages >= 0:
           for param in self.conv1.parameters():
               param.requires_grad = False
       for i in range(1, self.frozen_stages + 1):
           layer = getattr(self, f'layer{i}')
           layer.eval()
           for param in layer.parameters():
               param.requires_grad = False
   
   # qat: do fuse model
   def fuse_modules(self):
       self.conv1.fuse_modules()
       for layer_name in self.layers:
           layer = getattr(self, layer_name)
           if hasattr(layer, "fuse_modules"):
               layer.fuse_modules()
           elif isinstance(layer, nn.Sequential):
               for m in layer:
                   if hasattr(m, "fuse_modules"):
                       m.fuse_modules()

   def forward(self, x):
       """Forward function."""
       # qat: qat model start with QuantStub
       x = self.quant(x)
       x = self.conv1(x)
       outs = []
       for i, layer_name in enumerate(self.layers):
           layer = getattr(self, layer_name)
           x = layer(x)
           if i in self.out_indices:
               outs.append(x)
       return tuple(outs)

   def train(self, mode=True):
       """Convert the model into training mode while keep normalization layer
       frozen."""
       super(MobileNetV2, self).train(mode)
       self._freeze_stages()
       if mode and self.norm_eval:
           for m in self.modules():
               # trick: eval have effect on BatchNorm only
               if isinstance(m, _BatchNorm):
                   m.eval()


算子融合

7.5.5. 算子融合 — Horizon Open Explorer

aW1hZ2U=.png

举个例子:mmcv/cnn/bricks/conv_module.py
class ConvModule(nn.Module):
...
# qat: fuse conv + bn/relu
   def fuse_modules(self):
       fuse_list = None
       if self.with_norm:
           if self.with_activation:
               fuse_list = ["conv", self.norm_name, "activate"] # conv+bn+relu
           else:
               fuse_list = ["conv", self.norm_name] # conv+bn
       else:
           if self.with_activation:
               fuse_list = ["conv", "activate"] # conv+relu
       if fuse_list is not None:
           torch.quantization.fuse_modules(
               self,
               fuse_list,
               inplace=True,
               fuser_func=quantization.fuse_known_modules,
           )

eager 方案麻烦的是,基本每个模块都要手动去设置算子融合

反量化节点设置

mmdetection-master/mmdet/models/dense_heads/yolo_head.py

关键代码:

self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来

self.dequant.append(DeQuantStub())

def fuse_modules(self):

pred_map = self.dequant[i](self.convs_pred[i](x))
class YOLOV3Head(BaseDenseHead, BBoxTestMixin):

   def __init__(self,
                num_classes,
                in_channels,
                out_channels=(1024, 512, 256),
                anchor_generator=dict(
                    type='YOLOAnchorGenerator',
                    base_sizes=[[(116, 90), (156, 198), (373, 326)],
                                [(30, 61), (62, 45), (59, 119)],
                                [(10, 13), (16, 30), (33, 23)]],
                    strides=[32, 16, 8]),
                bbox_coder=dict(type='YOLOBBoxCoder'),
                featmap_strides=[32, 16, 8],
                one_hot_smoother=0.,
                conv_cfg=None,
                norm_cfg=dict(type='BN', requires_grad=True),
                # qat
                # act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
                act_cfg=dict(type='ReLU'),
                loss_cls=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=True,
                    loss_weight=1.0),
                loss_conf=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=True,
                    loss_weight=1.0),
                loss_xy=dict(
                    type='CrossEntropyLoss',
                    use_sigmoid=True,
                    loss_weight=1.0),
                loss_wh=dict(type='MSELoss', loss_weight=1.0),
                train_cfg=None,
                test_cfg=None,
                init_cfg=dict(
                    type='Normal', std=0.01,
                    override=dict(name='convs_pred'))):
       super(YOLOV3Head, self).__init__(init_cfg)
       # Check params
       assert (len(in_channels) == len(out_channels) == len(featmap_strides))

       self.num_classes = num_classes
       self.in_channels = in_channels
       self.out_channels = out_channels
       self.featmap_strides = featmap_strides
       self.train_cfg = train_cfg
       self.test_cfg = test_cfg
       if self.train_cfg:
           self.assigner = build_assigner(self.train_cfg.assigner)
           if hasattr(self.train_cfg, 'sampler'):
               sampler_cfg = self.train_cfg.sampler
           else:
               sampler_cfg = dict(type='PseudoSampler')
           self.sampler = build_sampler(sampler_cfg, context=self)
       self.fp16_enabled = False

       self.one_hot_smoother = one_hot_smoother

       self.conv_cfg = conv_cfg
       self.norm_cfg = norm_cfg
       self.act_cfg = act_cfg

       self.bbox_coder = build_bbox_coder(bbox_coder)

       self.prior_generator = build_prior_generator(anchor_generator)

       self.loss_cls = build_loss(loss_cls)
       self.loss_conf = build_loss(loss_conf)
       self.loss_xy = build_loss(loss_xy)
       self.loss_wh = build_loss(loss_wh)

       self.num_base_priors = self.prior_generator.num_base_priors[0]
       assert len(
           self.prior_generator.num_base_priors) == len(featmap_strides)
       self._init_layers()

   def _init_layers(self):
       self.convs_bridge = nn.ModuleList()
       self.convs_pred = nn.ModuleList()
       self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来
       for i in range(self.num_levels):
           conv_bridge = ConvModule(
               self.in_channels[i],
               self.out_channels[i],
               3,
               padding=1,
               conv_cfg=self.conv_cfg,
               norm_cfg=self.norm_cfg,
               act_cfg=self.act_cfg)
           conv_pred = nn.Conv2d(self.out_channels[i],
                                 self.num_base_priors * self.num_attrib, 1)

           self.convs_bridge.append(conv_bridge)
           self.convs_pred.append(conv_pred)
           self.dequant.append(DeQuantStub())

   def fuse_modules(self):
       for m in self.convs_bridge:
           m.fuse_modules()

   def forward(self, feats):
       """Forward features from the upstream network.

       Args:
           feats (tuple[Tensor]): Features from the upstream network, each is
               a 4D-tensor.

       Returns:
           tuple[Tensor]: A tuple of multi-level predication map, each is a
               4D-tensor of shape (batch_size, 5+num_classes, height, width).
       """

       assert len(feats) == self.num_levels
       pred_maps = []
       for i in range(self.num_levels):
           x = feats[i]
           x = self.convs_bridge[i](x)
           pred_map = self.dequant[i](self.convs_pred[i](x))
           pred_maps.append(pred_map)

       return tuple(pred_maps),


第三部分:常用的精度调优 debug 工具介绍

工具:集成接口量化配置检查模型可视化相似度对比统计量分步量化异构模型部署 device 检查

aW1hZ2U=.png


第四部分:模型精度调优分享模型精度调优时常遇到的问题:
  1. calib 模型的精度和 float 对齐,quantized 模型的精度损失较大

正常情况下,calib/qat 模型的精度和 quantized 模型的精度损失很小(1%), 如果偏差过大,可能是 calib/qat 的流程不对。

原因:calib 模型伪量化节点的状态不正确,导致 calib 阶段,测试的是 float 模型的精度,而 quantized 阶段,测试的是 calib 模型的精度,所以精度损失本质上还是量化精度的损失。

如何避免:

  • 正确设置 calib 训练和评测时的伪量化节点状态。

  • 让客户在 calib 的基础上,做 qat, 评测 qat 模型的精度。(客户的数据量大,qat 时间太长,一直没有选择 qat,导致这个问题被暴露出来了)

如何设置正确的 calib 伪量化节点的状态?(fx 和 eager 都是一样的)

http://model.aidi.hobot.cc/api/docs/horizon_plugin_pytorch/latest/html/user_guide/calibration.html

aW1hZ2U=.png

#加载浮点模型权重
   model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_float_131892.pth"))
   set_march(March.BAYES)
#校准配置
   calib_model = prepare_qat_fx(
       model,
       {"":default_calib_8bit_fake_quant_qconfig,
       "module_name":
           ...
       }).to(device)
   calib_model.to(device)
   #校准需要全程开启eval()状态
   calib_model.eval()
   #校准的训练阶段,设置伪量化节点模式为 CALIBRATION
   set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
   train(cfg, calib_model, device, distributed)
   #校准的评测阶段,设置伪量化节点的模式为 VALIDATION
   set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
   #加载校准的模型权重
   calib_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))
   #测试校准的精度
   run_test(cfg, calib_model, vis=args.vis, eval_score_iou=args.eval_score_iou, eval_all_depths=args.eval_all_depths) # 11.8650

注意:16 行的 train 在评测时,也要设置 FakeQuantState.VALIDATION,不然 scale 不生效,评测的指标也不对

常见问题:

  1. 数据校准之前模型设置为 train()的状态,且未使用set_fake_quantize,等于 caib 阶段是在跑 QAT 训练;

  2. 校准的评测阶段,未设置伪量化节点的模式为 VALIDATION, 实际评测的是 float 模型;

总结 2: 如果做 calib,一定要仔细检查伪量化节点状态和模型状态是否正确,避免不符合预期的结果


2.当量化精度损失超过大,如何调优?
  1. 使用 model_profiler() 这个集成接口,生成压缩包。

  2. 检查是否配置高精度输出、是否存在未融合的算子、是否共享 op、是否算子分布过大 int8 兜不住?

  • 注意:使用 debug 集成接口时,要保证浮点模型训练到位,并传入真实数据


3.多任务模型的精度调优建议
  1. qat 调优策略和常规模型一样,ptq+qat

  2. 如果只有一个 head 精度有损失,可以固定其他部分,单独使用这个 head 的数据做 calib


4.calib 和 qat 流程的正确衔接

calib:

    #加载浮点模型权重
   model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_float_131892.pth"))
   set_march(March.BAYES)
   #校准配置
   calib_model = prepare_qat_fx(
       model,
       {"":default_calib_8bit_fake_quant_qconfig,
       "module_name":
          ...
       }).to(device)
   calib_model.to(device)
   #校准需要全程开启eval()状态
   calib_model.eval()
   #校准的训练阶段,设置伪量化节点模式为 CALIBRATION
   set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
   train(cfg, calib_model, device, distributed)
   #校准的评测阶段,设置伪量化节点的模式为 VALIDATION
   set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
   #加载校准的模型权重
   calib_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))
   #测试校准的精度
   run_test(cfg, calib_model, vis=args.vis, eval_score_iou=args.eval_score_iou, eval_all_depths=args.eval_all_depths) # 11.8650

qat:

set_march(March.BAYES)
   qat_model = prepare_qat_fx(
       model,
       {"":default_qat_8bit_fake_quant_qconfig,
       "module_name":
           '''
       }).to(device)
   qat_model.to(device)
   #加载校准模型权重
   qat_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))
   #训练阶段,保证模型处于model.train()状态,这样伪量化节点也处于qat模式
   train(cfg, qat_model, device, distributed)
5.检查 conv 高精度输出

方式 1:查看 qconfig_info.txt,重点关注 DeQuantStub 附近的 conv 是不是 float32 输出

qconfig_info.txt

aW1hZ2U=.png

方式 2:打印 qat_model 的最后一层,查看该层是否有 (activation_post_process): FakeQuantize

高精度的 conv:

  (1): ConvModule2d(
   (0): Conv2d(
     64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
     (weight_fake_quant): FakeQuantize(
       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0,         scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])
       (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
     )
   )
 )
)

int8 的 conv

  (0): ConvModule2d(
   (0): ConvReLU2d(
     64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
     (weight_fake_quant): FakeQuantize(
       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0,         scale=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
               1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), zero_point=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
       (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
     )
     (activation_post_process): FakeQuantize(
       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1,         scale=tensor([1.]), zero_point=tensor([0])
       (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
     )
   )
6.检查共享 op

打开 qconfig_info.txt,后面标有(n)的就是共享的

aW1hZ2U=.png

特殊情况:layernorm 在 QAT 阶段是多个小量化算子拼接而成,module 的重复调用,也会产生大量 op 共享的问题

解决办法: 将 layernorm 替换为 batchnorm,测试了 float 精度,没有下降。

aW1hZ2U=.png

aW1hZ2U=.png

7.检查未融合的算子

打开 qconfig_info.txt,全局搜 BatchNorm2d 和 ReLU,如果前面有 conv,那就是没做算子融合

可以融合的算子:

  • conv+bn

  • conv+relu

  • conv+add

  • conv+bn+relu

  • conv+bn+add

  • conv+bn+relu+add

aW1hZ2U=.png

8.检查数据分布特别大的算子

打开 float 模型的统计量分布,一般是 model0_statistic.txt

有两个表,第一个表是按模型结构排列的;第二个表是按数据分布范围排列的

拖到第二个表,看前几行是那些 op

可以看到很多 conv 的分布很异常,使用的是 int8 量化

aW1hZ2U=.png

解决办法:

  1. 检查这些 conv 后面是否有 bn,添加 bn 后,数据能收敛一些

  2. 如果结构上已经加了 bn,数据分布还大,可以配置 int16 量化

  • int16 调这两个接口,default_qat_16bit_fake_quant_qconfig 和 default_calib_16bit_fake_quant_qconfig

  • 中间算子的写法和高精度输出类似 model.xx.qconfig = default_qat_16bit_fake_quant_qconfig ()



*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。

参与讨论
登录后参与讨论
推荐文章
最近访客