【地平线 征程 6 工具链进阶教程】QAT 训练常见问题和排查

【地平线 征程 6 工具链进阶教程】QAT 训练常见问题和排查

一、引言

对于使用地平线 征程 6 各平台的用户而言,QAT 量化感知训练是模型量化部署流程中重要的环节,同时也是量化调优的难点之一。在 QAT 训练过程中,受模型结构、量化配置、训练超参、硬件特性适配等多因素影响,可能出现训练 Loss 不收敛、前向反向过程存在 NAN/INF 异常值、量化精度指标相比校准阶段回退或无法通过 QAT 训练使指标达标等各类异常问题。

本文作为前序 征程 6 平台 QAT 精度调优教程的补充,总结地平线 征程 6 系列工具链 QAT 训练环节的各类典型异常问题以及定位解决的流程和思路。阅读前,请提前学习 征程 6 平台 QAT 精度调优教程:

  • 征程 6 E/M 工具链 QAT 精度调优:https://developer\。horizon\.auto/blog/13132
  • 征程 6 H/P 工具链 QAT 精度调优:https://developer\。horizon\.auto/blog/13157

二、QAT 训练问题定位

流程简述:

  1. 和 qat 训练相同的量化配置,检查 calib 阶段精度情况,如 calib 精度较差则回头通过 debug 工具定位 calib 阶段量化问题
  2. 在 1 的基础上,如果 calib 精度没问题,则取消 qat 链路的 prepare,其他配置不变的情况下对模型做 fine-tune,或者直接使用 qat 链路的参数配置对浮点模型做 fine-tune,如指标明显下降则浮点大概率未收敛,优化浮点训练
  3. 在 2 的基础上,设置_FLOAT 取消伪量化来进行 fine-tune,模型相当于未引入量化误差,应和直接 fine-tune 浮点情况一致,如指标明显下降则需要检查是否 qat 链路和流程问题。特殊情况:训练数据分布差异大的情况,训练早期数据可能出现极大值,而 plugin 工具对部分算子有 clip 操作进行数值约束,可能造成_FLOAT 训练初期精度差的现象,但随着迭代次数增加,_FLOAT 精度能对齐浮点
  4. 在 3 的基础上,设置全 int16 进行 qat 训练,如果精度指标下降或不达预期,需要结合固定校准激活 scale 以及 with_bn 的做法,设置不同 lr 进行实验
  5. 在 4 的基础上,回归到混合精度进行 qat 训练,如果精度指标下降或不达预期,需要增加更高精度量化的算子,结合固定校准激活 scale 以及 withbn 的做法,设置不同 lr 进行实验。特殊情况:如果模型中有训练辅助头或者 loss 计算被量化,不符合我们量化适配预期,可能造成全 int16 下训练精度达标但是混合精度不达标的现象

三、常见问题和原因

问题现象问题原因
征程 6 尝试开启 POT 量化后,QAT 训练指标下降,或训不上去pot scale 加大了 qat fine-tune 调优的难度
QAT 链路去掉 prepare qat 环节,fine-tune 指标就掉浮点权重未收敛,往往取浮点训练中间权重易出现
1. Calib 指标正常,QAT 训练指标下降,越训越低2. 全 int16 精度配置,QAT 训练指标正常;混合精度配置 QAT 训练初始 loss 高,越训越差1. 未按照标准建议流程设计调优实验,如 QAT 训练未尝试固定 Calib 阶段的激活 scale2. QAT 训练超参学习率未调整到位3. 训练过程中的训练辅助分支或 loss 相关计算被量化4. Calib 阶段的精度调优未做扎实,模型有量化敏感算子
1. QAT 训练前向或反向过程存在 NAN/INF 异常值2. QAT 训练的 Loss 值极大,出现异常值浮点模型或者 Calib 模型中就存在异常值,Calib 阶段的精度调优未做扎实,异常值未排查出,一般异常值可能为数据 mask 操作引入

四、解决思路和方案

4.1 浮点模型权重未训收敛,QAT 本质也是小学习率 fine-tune,指标越训越低

解决方案:用浮点初始学习率的 1/10~1/100 来微调浮点,观察指标是否出现大幅度波动,如出现则未收敛,需要增加浮点权重的迭代次数。此外还需要对齐浮点训练和 qat 训练各个实验的训练配置,例如是否提交集群、机器数和卡数配置、训练集以及 dataloader 配置等

4.2 征程 6B/P 的 POT 策略可以提高部署模型一致性表现以及对性能有收益,但 pot scale 加大了 qat fine-tune 调优的难度

解决方案:

  1. 对于 征程 6B/H/P,默认 int16 下开启 POT,int8 关闭 POT;对于 征程 6E/M,默认不开启 POT。因此客户侧选择关闭 POT 再做验证,如果模型为全 int 量化,建议直接全局关闭 POT:
# 全局整体关闭 from horizon_plugin_pytorch.quantization.fake_quantize_base import FakeQuantizeBase FakeQuantizeBase._enable_fp16_compute = False calib_model = prepare(...) # 指定算子关闭,prepare后调整 calib_model = prepare(...) calib_model.conv1.qconfig = Qconfig( activation=FakeQuantize.with_kwargs( ..., _enable_fp16_compute=False ) ) # 指定算子关闭,prepare前调整 disable_pot_ops = { "conv1": Qconfig( activation=FakeQuantize.with_kwargs( ..., _enable_fp16_compute=False ) ) } calib_model = prepare( ..., custom_qconfig_mapping=disable_pot_ops )
  1. calib 阶段通过敏感度定位到具体 POT 量化敏感算子,针对性优化 POT 精度。目前未出现过 calib pot 精度好但是 qat pot 精度差的现象

4.3 calib 阶段的量化精度未做扎实,主要量化问题仍存在:敏感算子误差未解除、INF 值或者极大的 scale 值影响

解决方案:

  1. 大部分模型进入到 qat 前 calib 精度应达到的标准:calib 精度至少需要达到浮点的 90%,建议达到浮点的 95%。结合调优建议 https://developer\。horizon\.auto/blog/13132 以及 https://developer\。horizon\.auto/blog/13157 ; 少数模型 calib 精度没有达到上述标准但通过 qat 也可以训回来,需要结合具体模型和 case 来看,通常来说如果 qat 训练指标上升但未达标,第一时间仍然是优化 calib 精度。
  2. calib 阶段做扎实敏感度分析:
    1. 敏感度怎么解读:https://developer\。horizon\.auto/blog/13132 以及 https://developer\。horizon\.auto/blog/13157
    2. 什么是标准的敏感度:
      1. 敏感度值呈现梯度分布(主要),例如排序前 20% 算子应贡献 80% 的量化误差,下面敏感度靠前和最后的值基本一致,为常见的错误敏感度
      2. 敏感算子分布合理,敏感算子应集中在模型关键结构,例如任务头、多尺度特征融合模块、backbone 高维特征层等,不应该均匀地分布在整个网络
      3. 敏感度值的大小取决于模型输出物理特征或者语义特征,以及 badcase 的典型性,只要敏感度符合上面特征则可以协助我们定位量化问题,不用过多关注敏感度本身的数值量级
      4. 模型在 badcase 上的整体敏感度应高于单算子的敏感度;不同敏感度 metric 下的排序结果应接近
    3. 敏感度不正常时排查和解决方式:
      1. 在 calib 链路正确的前提下,确定 analysis_model 需要比对的量化结构和 baseline_model 是可以对的上的,下面提供一种方式
      2. 在跑敏感度时检查后处理算子是否正确处理:对于 topk、nms、sort 和 argmax 这类算子在跑敏感度时应从 forward 中去除;对于 sigmoid 算子在跑敏感度时应加在 forward 中
      3. 检查 badcase 是否数据异常,是否为校准数据集中的脏数据
      4. 对于多类别的回归模型,评测链路的后处理中一般会有根据置信度阈值筛选的操作,例如 bbox 框通过 score 按照不同的阈值进行过滤,这部分过滤的操作也需要加入到 forward 当中进行敏感度分析
# Step1 将calib model作为analysis_model并设置评测状态 self.analysis_model = analysis_model_convert_pipeline( copy.deepcopy(model) ) self.analysis_model.eval() set_fake_quantize(self.analysis_model, FakeQuantState.VALIDATION) # Step2 复制analysis_model并设置成浮点状态作为baseline_model self.baseline_model = copy.deepcopy(self.analysis_model) set_fake_quantize(self.baseline_model, FakeQuantState._FLOAT) # 或者使用FakeQuantState._CALIBRATION_V2
  1. 异常值或极大 scale 出现场景和解决方式:
    1. Attn mask、其他数据 mask 等设定的 mask value。优化方式:手动设置较小值
    2. clamp 算子未指定最大最小值,或者最大最小值数值过大或过小:
    3. norm 算子的拆分算子如 mul/input_mean/var_mean,实际对精度的影响需要结合敏感度分析。优化方式:手动截断大数值,以稀疏分布的截断换取更高的量化分辨率;征程 6P 平台尝试设置 fp16/fp32 高精度,征程 6M 尝试使用 QAT 训练,来进一步优化舍入误差
    4. gemm(尤其指 matmul)层输出。优化方式:数据归一化,增加 norm 层;有物理含义结合物理含义进行放缩;结合算子支持情况和目标部署平台,切换量化精度,例如 int16->fp16、fp16->int16(在某些值域内 int16 表示范围更大)、fp16->fp32
    5. scale 出现 nan 值:
      1. 原因:模型中有 tensor 包含 inf/nan 值
      2. 定位方式:日志中会打印建议把 check_nan_scale 打开,打开后可具体找到 nan scale 出现的代码位置,再逐步定位到 inf/nan 出现的最开始的位置
      3. 案例:参数初始化时未考虑到边界情况,通过添加 clamp 即可避免出现 inf/nan
# 修改前 depth = torch.clamp(proj_pt[..., 2:3], min=1e-9) # 修改后 depth = torch.clamp(proj_pt[..., 2:3], min=1e-2)
from horizon_plugin_pytorch.quantization.fake_quantize_base import FakeQuantizeBase FakeQuantizeBase.check_nan_scale = "forward" calib_model = prepare(...)
def sigmoid_inverse(y): epsilon = 1e-6 y = torch.clamp(y, min=epsilon, max=1.0 - epsilon) x = torch.log(y / (1.0 - y)) return x
  1. 结合 calib 阶段调优流程和建议:https://developer\。horizon\.auto/blog/13132 以及 https://developer\。horizon\.auto/blog/13157

4.4 qat 适配过程存在问题,一些辅助头分支或者 loss 相关计算被量化

解决方案:

  1. 关注 qat train 过程中 miss_key/unexpected_key 的信息,被量化的辅助头或者 loss 计算往往会以 miss_key 形式提醒,大部分情况被显式量化的操作在对比 calib 和 qat 各自生成的 model_check_result 生成物后可被发现
  2. 梳理 qat train 阶段计算图 fx_graph 调用关系,重点关注模型输出,可使用可视化工具:
from horizon_plugin_pytorch.fx.visualize import visualize visualize(model=qat_calib_model, output_path='./horizonn_calib_model_vis.onnx')
  1. 对比 calib train 和 qat train 过程中生成的 fx_graph,重点关注模型任务头输出部分 trace 到的算子和操作

4.5 未按照标准建议流程设计调优实验,未首先采取例如固定校准激活 scale、取消 warmup、使用较小的固定 lr 微调等优化手段

解决方案:

  1. qat 训练初始 loss 大或出现异常值,初始精度掉点,应按如下流程设计实验:
    1. 检查是否在训练时使用了特殊的数据增强策略例如旋转和马赛克等,应当去除;warmup 应当去除
    2. 去掉 prepare 的步骤,用 qat pipeline fine-tune 浮点,实验是否正常。如异常需检查训练配置如优化器和 lr_updater
    3. 保留 qat 的配置,设置_FLOAT 或_CALIBRATION_V2(不会引入工具 clip 误差)状态关闭伪量化节点,实验是否正常,正常精度应基本对齐浮点,则继续固定校准激活 scale 的实验
    4. 固定校准激活 scale,lr 设置为 0,实验是否正常,正常精度应对齐 calib
    5. 固定校准激活 scale,使用较小的固定 lr 微调,进行实验
  2. qat 训练 loss 收敛慢,精度相比 calib 无变化(训不上去),实验建议:
    1. 优先调整 BN 状态控制,默认 fuse_bn(无需开启 sync_bn),安排 with_bn 的实验(sync_bn 对齐浮点训练配置)
    2. 取消固定校准激活 scale,和固定 scale 做对比
    3. 取消 warmup
    4. 采取较大的 lr
    5. 延长 qat 迭代次数

4.6 其他

解决方案:结合问题定位流程定位到问题阶段和现象,综合参考上述解决方案,不排除可能存在 bug,需同步到地平线工具链团队解决