声明:本文主要参考开源资料进行学习整理,如有错漏,欢迎评论交流~
本文主要介绍时序模型在使用 horizon_plugin_profiler.QuantAnalysis 进行量化调试时的原理以及具体示例。
1. 时序模型分类
时序模型根据状态传递方式分为两类:
类型 | 定义方式 | 特点 |
|---|---|---|
显式时序 | forward(x, pre_h) 参数传递 | 状态由外部管理,调用者负责维护 |
隐式时序 | self.register_buffer(cached_h, ...) | 状态注册为模型属性,模型内部管理 |
2. QuantAnalysis 基本流程
3. cached_attrs 原理解析
作用
cached_attrs 用于确保 baseline 模型和 analysis 模型在每次推理前拥有相同的状态。
工作机制
源码位置:horizon_plugin_profiler/find_bad_case.py
实现细节
适用范围
类型 | 示例 | 能否用 cachedattrs |
|---|---|---|
Buffer | self.register_buffer(pre_h, ...) | ✅ 可以 |
Parameter | self.pre_h = nn.Parameter(...) | ✅ 可以 |
普通属性 | self.pre_h = torch.zeros(...) | ✅ 可以 |
forward 参数 | def forward(self, x, pre_h) | ❌ 不能 |
重要:cachedattrs 只能管理模型的属性,无法管理 forward 的参数!
4. 时序模型调试方案
显式时序模型
方案:时序状态由外部管理,cached_attrs 管理其他属性变量。
使用方式:
隐式时序模型
方案:时序状态注册为 buffer,由 cached_attrs 自动管理。
使用方式:
