1. 概述
ObserverBase 是 horizon_plugin_pytorch 量化框架中所有 Observer 的抽象基类。它定义了量化校准器的统一接口和核心功能,为各种量化策略(MinMax、MSE、KL等)提供了基础架构。
2. ABCMeta 深度解析
2.1 Python 元类机制
在 Python 中,类也是对象,类是由元类(metaclass) 创建的:
默认情况下,所有类都由元类创建。当指定 metaclass=ABCMeta 时,类的创建过程由 ABCMeta 控制。
示例如下:
2.2 @abstractmethod 装饰器
2.3 ObserverBase 中的应用
3. ObserverBase 完整源码
4. 核心属性详解
4.1 量化配置属性
4.2 统计量缓冲区
使用 register_buffer 注册的原因:
- 不参与梯度计算:统计量不是模型参数
- 随模型迁移设备:model.cuda() 时自动迁移
- 可保存到 state_dict:校准结果可持久化
5. 核心方法详解
5.1 __init__ - 初始化
参数说明:
参数 | 默认值 | 说明 |
|---|---|---|
averaging_constant | 0.01 | 移动平均系数,值越大当前batch权重越高 |
ch_axis | -1 | 通道轴,负数表示 per_tensor,非负表示 per_channel |
dtype | qint8 | 量化数据类型 |
qscheme | per_tensor_symmetric | 量化方案 |
quant_min/max | None | 自定义量化范围,None时根据dtype自动设置 |
is_sync_quantize | True | 多卡训练时是否同步统计量 |
关键校验逻辑:
5.2 forward - 更新统计信息(抽象方法)
设计意图:
- 子类必须实现此方法(由 ABCMeta 强制)
在校准阶段,每个 forward pass 收集激活值的统计信息
返回原始输入(不修改数据流)
典型实现模式:
5.3 calculate_qparams - 计算量化参数
核心计算逻辑(_compute_scale_symmetric):
5.4 sync_minmax - 多卡同步
原理:
- 使用 all_reduce 聚合多卡的统计量
- MIN 操作取所有卡的最小值
- MAX 操作取所有卡的最大值
确保多卡训练时校准结果一致
5.5 _load_from_state_dict - 状态加载
关键功能:
- 版本兼容(处理旧版名称 min_vals → min_val)
动态调整 buffer 大小
支持从校准模型加载参数到 QAT 模型
6. 类继承体系
7. 设计亮点
- 统一接口:所有 Observer 遵循相同的 API,便于替换和扩展
- 抽象基类约束:通过 ABCMeta 强制子类实现 forward 方法
- 状态持久化:统计量作为 buffer 保存,支持校准结果复用
- 分布式支持:内置多卡同步机制
- 版本兼容:_load_from_state_dict 处理历史版本兼容
- 灵活配置:支持多种量化方案、数据类型、scale 策略
8. 与 PyTorch 原生 Observer 的对比
特性 | PyTorch ObserverBase | Horizon ObserverBase |
|---|---|---|
量化方案 | 支持非对称量化 | 仅支持对称量化 |
scale 约束 | 无 | POT/FP16/KPOT 策略 |
分布式同步 | 需自行实现 | 内置 sync_minmax |
数据类型 | 标准 torch.dtype | 扩展 QuantDType (qint4等) |
版本管理 | 无 | _version 字段支持迁移 |

