专栏感知QAT入门 代码一键执行

QAT入门 代码一键执行

kotei左文亮2024-11-11
57
0
        本文主要根据OE包文档里的QAT流程来编写,版本是v1.1.68,以 torchvision 中的 MobileNetV2 模型为例,介绍QAT流程中每个阶段的具体操作和详细代码,读者可安装好docker环境之后,复制里面的代码一键执行。出于流程展示的执行速度考虑,我们使用了 cifar-10 数据集。为了快速展示整体方法和流程,本文未进行精细地调参提高模型精度,用户实际使用过程中可以根据需求进行调参。地平线工具链OE文档里,有一部分代码是需要自己写的,此文已经做了添加,也可自己进行修改。

 

一、浮点训练

import os

import copy

import numpy as np

import torch

import torch.nn as nn

import torchvision.transforms as transforms from torch

import Tensor from torch.quantization

import DeQuantStub from torchvision.datasets

import CIFAR10 from torchvision.models.mobilenetv2

import MobileNetV2 from torch.utils import data

from typing import Optional, Callable, List, Tuple

from horizon_plugin_pytorch.functional import rgb2centered_yuv

import torch.quantization

from horizon_plugin_pytorch.march import March, set_march

from horizon_plugin_pytorch.quantization import ( QuantStub, convert_fx, prepare_qat_fx, set_fake_quantize, FakeQuantState, check_model, compile_model, perf_model, visualize_model, )

from horizon_plugin_pytorch.quantization.qconfig import ( default_calib_8bit_fake_quant_qconfig, default_qat_8bit_fake_quant_qconfig, default_qat_8bit_weight_32bit_out_fake_quant_qconfig, default_calib_8bit_weight_32bit_out_fake_quant_qconfig, )

import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

 

 

用户可根据需要修改以下参数

 

# 1. 模型 ckpt 和编译产出物的保存路径

model_path = "model/mobilenetv2"

# 2. 数据集下载和保存的路径

data_path = "/open_explorer"

# 3. 训练时使用的 batch_size

train_batch_size = 256

# 4. 预测时使用的 batch_size

#eval_batch_size = 256

eval_batch_size = 32

# 5. 训练的 epoch 数

epoch_num = 30

# 6. 模型保存和执行计算使用的

device device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") )


准备数据集, cifar-10 数据集是本人提前下载好的保存在目录“/open_explore”之下。请注意 collate_fn 中对 rgb2centered_yuv 的使用。

def prepare_data_loaders( data_path: str, train_batch_size: int, eval_batch_size: int ) -> Tuple[data.DataLoader, data.DataLoader]:

normalize = transforms.Normalize(mean=0.0, std=128.0)

 

      def collate_fn(batch):

             batched_img = torch.stack( [ torch.from_numpy(np.array(example[0], np.uint8, copy=True)) for example in batch ] ).permute(0, 3, 1, 2)                             batched_target = torch.tensor([example[1] for example in batch])

             batched_img = rgb2centered_yuv(batched_img)

             batched_img = normalize(batched_img.float())

             return batched_img, batched_target

 

     train_dataset = CIFAR10( data_path, True, transforms.Compose( [ transforms.RandomHorizontalFlip(), transforms.RandAugment(), ] ), download=False, )

     eval_dataset = CIFAR10( data_path, False, download=False, )

     train_data_loader = data.DataLoader( train_dataset, batch_size=train_batch_size, sampler=data.RandomSampler(train_dataset), num_workers=8, collate_fn=collate_fn, pin_memory=True, )

     eval_data_loader = data.DataLoader( eval_dataset, batch_size=eval_batch_size, sampler=data.SequentialSampler(eval_dataset), num_workers=8, collate_fn=collate_fn, pin_memory=True, )

      return train_data_loader, eval_data_loader



对浮点模型做必要的改造,以支持量化相关操作。模型改造必要的操作有:
在模型输入前插入 QuantStub在模型输出后插入 DequantStub
改造模型时需要注意:
插入的 QuantStub 和 DequantStub 必须注册为模型的子模块,否则将无法正确处理它们的量化状态
多个输入仅在 scale 相同时可以共享 QuantStub,否则请为每个输入定义单独的 QuantStub
若需要将上板时输入的数据来源指定为 "pyramid",请手动设置对应 QuantStub 的 scale 参数为 1/128
也可以使用 torch.quantization.QuantStub,但是仅有 horizon_plugin_pytorch.quantization.QuantStub 支持通过参数手动固定 scale
改造后的模型可以无缝加载改造前模型的参数,因此若已有训练好的浮点模型,直接加载即可,否则需要正常进行浮点训练。

class FxQATReadyMobileNetV2(MobileNetV2):

        def __init__( self, num_classes: int = 10, width_mult: float = 1.0, inverted_residual_setting: Optional[List[List[int]]] = None, round_nearest: int = 8, ):                super().__init__( num_classes, width_mult, inverted_residual_setting, round_nearest ) self.quant = QuantStub(scale=1 / 128) self.dequant = DeQuantStub()

        def forward(self, x: Tensor) -> Tensor:

              x = self.quant(x)

              x = super().forward(x)

              x = self.dequant(x)

              return x

 

训练一个eopch的代码,此处OE文档里没有详细给出,需要自己编写,下面的代码可以使用。

def train_one_epoch(model,optimizer,train_loader,device):

       for i, (data, target) in enumerate(train_loader):

             optimizer.zero_grad()

             data = data.to(device)

             target =target.to(device)

             output = model(data)

             loss = nn.CrossEntropyLoss()(output, target)

             loss.backward()

             optimizer.step()

 

 

#检测精度代码


class AverageMeter(object):

"""Computes and stores the average and current value"""

       def __init__(self, name, fmt=':f'):

             self.name = name

             self.fmt = fmt

             self.reset()

       def reset(self):

             self.val = 0

             self.avg = 0 self.sum = 0 self.count = 0

       def update(self, val, n=1):

             self.val = val

             self.sum += val * n

             self.count += n

             self.avg = self.sum / self.count

       def __str__(self): fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' return fmtstr.format(**self.__dict__)

 

def accuracy(output, target, topk=(1,)):

"""Computes the accuracy over the k top predictions for the specified values of k"""

      with torch.no_grad():

              maxk = max(topk)

              batch_size = target.size(0)

              _, pred = output.topk(maxk, 1, True, True)

              pred = pred.t()

              correct = pred.eq(target.view(1, -1).expand_as(pred))

              res = []

               for k in topk:

                     correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)

                     res.append(correct_k.mul_(100.0 / batch_size))

               return res

 

def evaluate(model, criterion, data_loader, neval_batches):

       model.eval()

       top1 = AverageMeter('Acc@1', ':6.2f')

       top5 = AverageMeter('Acc@5', ':6.2f')

       cnt = 0

       with torch.no_grad():

               for image, target in data_loader:

                    image = image.to(device)

                    target =target.to(device)

                    output = model(image)

                    loss = criterion(output, target)

                    cnt += 1

                    acc1, acc5 = accuracy(output, target, topk=(1, 5))

                    print('.', end = '')

                    top1.update(acc1[0], image.size(0))

                    top5.update(acc5[0], image.size(0))

                    if cnt >= neval_batches:

                           return top1, top5

 

#训练,保存模型。

if not os.path.exists(model_path):

      os.makedirs(model_path, exist_ok=True)

 

# 浮点模型初始化

float_model = FxQATReadyMobileNetV2()

 

# 准备数据集

train_data_loader, eval_data_loader = prepare_data_loaders( data_path, train_batch_size, eval_batch_size )

# 由于模型的最后一层和预训练模型不一致,需要进行浮点 finetune

optimizer = torch.optim.Adam( float_model.parameters(), lr=0.001, weight_decay=1e-3 )

best_acc = 0

float_model=float_model.to(device)

for nepoch in range(epoch_num):

      float_model.train()

      train_one_epoch( float_model, optimizer, train_data_loader, device )

      # 浮点精度测试

      float_model.eval()

      top1, top5 = evaluate(float_model, nn.CrossEntropyLoss(), eval_data_loader, eval_batch_size)

      print( "Float Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format( nepoch, top1.avg, top5.avg ) )

      if top1.avg > best_acc:

         best_acc = top1.avg

         # 保存最佳浮点模型参数

         torch.save( float_model.state_dict(), os.path.join(model_path, "float-checkpoint.ckpt"), )

 

二、校准模型

参数可根据需要修改以下  

# 1. Calibration 时使用的 batch_size

calib_batch_size = 256

# 2. Validation 时使用的 batch_size

eval_batch_size = 8

# 3. Calibration 使用的数据量,配置为 inf 以使用全部数据

num_examples = float("inf")

# 4. 目标硬件平台的代号

march = March.BAYES

 

 

#校准模型代码如下

# 在进行模型转化前,必须设置好模型将要执行的硬件平台

set_march(march)

float_params_dict = torch.load('./model/mobilenetv2/float-checkpoint.ckpt')

float_model.load_state_dict(float_params_dict)

# 将模型转化为 Calibration 状态,以统计各处数据的数值分布特征

calib_model = prepare_qat_fx(

# 输出模型会共享输入模型的 attributes,为不影响 float_model 的后续使用, # 此处进行了 deepcopy

       copy.deepcopy(float_model), { "": default_calib_8bit_fake_quant_qconfig, "module_name": {

       # 在模型的输出层为 Conv 或 Linear 时,可以使用 out_qconfig # 配置为高精度输出

               "classifier": default_calib_8bit_weight_32bit_out_fake_quant_qconfig, }, }, ).to( device )

# prepare_qat_fx 接口不保证输出模型的 device 和输入模型完全一致

# 准备数据集

calib_data_loader, eval_data_loader = prepare_data_loaders( data_path, calib_batch_size, eval_batch_size )

# 执行 Calibration 过程(不需要 backward)

# 注意此处对模型状态的控制,模型需要处于 eval 状态以使 Bn 的行为符合要求

calib_model.eval()

set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)

with torch.no_grad():

        cnt = 0

        for image, target in calib_data_loader:

              image, target = image.to(device), target.to(device)

              calib_model(image)

             print(".", end="", flush=True)

             cnt += image.size(0)

             if cnt >= num_examples:

                 break

             print()

# 测试伪量化精度

# 注意此处对模型状态的控制

calib_model.eval()

set_fake_quantize(calib_model, FakeQuantState.VALIDATION)

#top1, top5 = evaluate(calib_model, eval_data_loader,device,)

top1, top5 = evaluate(calib_model, nn.CrossEntropyLoss(), eval_data_loader, eval_batch_size)

print( "Calibration: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format( top1.avg, top5.avg ) )

 

# 保存 Calibration 模型参数

torch.save( calib_model.state_dict(), os.path.join(model_path, "calib-checkpoint.ckpt"), )

三、 QAT训

# 5 用户可根据需要修改以下参数

# 训练时使用的 batch_size

train_batch_size = 256

# Validation 时使用的 batch_size

eval_batch_size = 8

# 训练的 epoch 数

epoch_num = 3

# 数据集下载和保存的路径

data_path = "/open_explorer"

model_path = "model/mobilenetv2"

# 模型保存和执行计算使用的 device

device = ( torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") )

# 浮点模型初始化

float_model = FxQATReadyMobileNetV2()

# 将模型转化为 Calibration 状态,以统计各处数据的数值分布特征

calib_model = prepare_qat_fx(

  # 输出模型会共享输入模型的 attributes,为不影响 float_model 的后续使用,

   # 此处进行了 deepcopy

      copy.deepcopy(float_model),

      { "": default_calib_8bit_fake_quant_qconfig,

            "module_name":{

           # 在模型的输出层为 Conv 或 Linear 时,可以使用 out_qconfig

           # 配置为高精度输出

                     "classifier": default_calib_8bit_weight_32bit_out_fake_quant_qconfig, }, }, ).to( device )

# prepare_qat_fx 接口不保证输出模型的 device 和输入模型完全一致

####################### ###注意:校准模型一定要先加载校准参数,然后QAT模型再加载校准模型参数;QAT模型直接加载校准参数会报错 ###########################

calib_params_dict = torch.load('./model/mobilenetv2/calib-checkpoint.ckpt')

calib_model.load_state_dict(calib_params_dict)

# 将模型转为 QAT 状态

qat_model = prepare_qat_fx(

           copy.deepcopy(float_model), {

           "": default_qat_8bit_fake_quant_qconfig, "module_name": { "classifier":       

               default_qat_8bit_weight_32bit_out_fake_quant_qconfig, }, }, ).to(device)

 

#calib_params_dict = torch.load('./model/mobilenetv2/calib-checkpoint.ckpt')

# 加载 Calibration 模型中的量化参数

qat_model.load_state_dict(calib_model.state_dict())

#qat_model.load_state_dict('./model/mobilenetv2/calib-checkpoint.ckpt')

# 进行量化感知训练

# 作为一个 filetune 过程,量化感知训练一般需要设定较小的学习率

optimizer = torch.optim.Adam( qat_model.parameters(), lr=1e-3, weight_decay=1e-4 )

best_acc = 0

# 准备数据集

train_data_loader, eval_data_loader = prepare_data_loaders( data_path, train_batch_size, eval_batch_size )

 

for nepoch in range(epoch_num):

 

# 注意此处对 QAT 模型 training 状态的控制方法

       qat_model.train()

       set_fake_quantize(qat_model, FakeQuantState.QAT)

       train_one_epoch( qat_model, optimizer, train_data_loader, device, )

# 注意此处对 QAT 模型 eval 状态的控制方法

        qat_model.eval()

        set_fake_quantize(qat_model, FakeQuantState.VALIDATION)

         top1, top5 = evaluate( qat_model, nn.CrossEntropyLoss(), eval_data_loader, eval_batch_size, )

         print( "QAT Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format( nepoch, top1.avg, top5.avg ) )

         if top1.avg > best_acc:

             best_acc = top1.avg

             torch.save( qat_model.state_dict(), os.path.join(model_path, "qat-checkpoint.ckpt"), )

 

######################################################################

# 6 用户可根据需要修改以下参数

# 6.1. 使用哪个模型作为流程的输入,可以选择 calib_model 或 qat_model base_model = qat_model ######################################################################

# 将模型转为定点状态

quantized_model = convert_fx(base_model).to(device)

# 测试定点模型精度

# top1, top5 = evaluate( # quantized_model, # eval_data_loader, # device, # )

top1, top5 = evaluate( quantized_model, nn.CrossEntropyLoss(), eval_data_loader, eval_batch_size, )

print( "Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format( top1.avg, top5.avg ) )



四、模型编译量化

# 编译时启用的优化等级,可选 0~3,等级越高编译出的模型上板执行速度越快,

# 但编译过程会慢

compile_opt = "O1"

######################################################################

# 这里的 example_input 也可以是随机生成的数据,但是推荐使用真实数据,以提高

# 性能测试的准确性

example_input = next(iter(eval_data_loader))[0]

# 通过 trace 将模型序列化并生成计算图,注意模型和数据要放在 CPU 上

script_model = torch.jit.trace(quantized_model.cpu(), example_input)

print("script_model:")

print(script_model)

torch.jit.save(script_model, os.path.join(model_path, "int_model.pt"))

# 模型检查

check_model(script_model, [example_input])

# 8 模型编译,生成的 hbm 文件即为可部署的模型

ret = compile_model( script_model, [example_input], hbm=os.path.join(model_path, "model.hbm"), input_source="pyramid", opt=compile_opt, )


五、性能测试

ret = perf_model( script_model, [example_input], out_dir=os.path.join(model_path, "perf_out"), input_source="pyramid", opt=compile_opt, layer_details=True, )

# 10 模型可视化

visualize_model( script_model, [example_input], save_path=os.path.join(model_path, "model.svg"), show=False, )

 

六、产物解读

会产生如下文件,
其中“float-checkpoint.ckpt”是浮点训练后保存的浮点模型;
“calib-checkpoint.ckpt” 是浮点模型经过校准后的校准模型;
“qat-checkpoint.ckpt”是QAT训练后的模型;
“int_model.pt”是通过“trace”将模型序列化生成计算图。
“model.hbm”是经过编译后的模型文件,即部署在板端的模型。
“model.svg”是模型可视化文件,里面有关于模型算子的详细信息。
文件夹“perf_out”里是性能测试的结果
感知
技术深度解析征程5
评论0
0/1000