
The MultiBatchProcessor class in hat

Solved
beyondaa · 2024-12-17

Hello, in this class `batches` contains the data of multiple batches, and each batch's data produces a loss through the module. I can see that `backward()` is called on each batch's loss, but the grad update is only performed once after the loop finishes. My question: is that final gradient update done with the gradient computed from the last batch's loss only, or with accumulated gradients? I couldn't work this part out from the code. Could you explain?


class MultiBatchProcessor(BatchProcessorMixin):
    """
    Processor that can forward and backward multiple batches within a single
    training step (before `optimizer.step()`).

    It is useful for:

    (1) Training a multitask model on single-task annotation samples, where
    each task forwards and backwards its own batch sequentially within one
    multitask training step.

    (2) Training on a GPU with limited memory while wanting a larger effective
    batch size: you can forward and backward multiple batches within one
    training step.

    .. note::

        Example multitask: vehicle, person, and traffic light detection.
        Single-task annotation means only vehicle bounding boxes are annotated
        on an image that contains vehicle, person, and traffic light objects.

    .. note::

        Multiple batches should be organized in tuple format, e.g.

        * `batch = (batch1, batch2, ...)`

        If not, it will be treated as a single batch, e.g.

        * `batch = dict(inputs=xx, target=xx)`

        * `batch = [inputs, target]`

        See code below for extra explanation.

    It is more general in usage than `BasicBatchProcessor`: batch and model
    outputs can be in any format, but note that a tuple batch is interpreted
    as multiple batches.

    It is hardware independent: runs on cpu (device None) or gpu
    (device is gpu id).

    It is suitable for training (need_grad_update) and validation
    (not need_grad_update).

    Args:
        need_grad_update: Whether gradient update is needed, True for
            training, False for validation.
        batch_transforms: Config of batch transforms.
        inverse_transforms: Config of transforms used to inverse-transform
            inference results.
        loss_collector: A callable object used to collect loss Tensors in
            model outputs.
        enable_amp: Whether to train with `Automatic Mixed Precision`.
        enable_amp_dtype: The dtype of amp, float16 or bfloat16.
        enable_apex: Whether to train with `Apex`.
        enable_channels_last: Whether to train with `channels_last`.
        channels_last_keys: Keys in batch that need converting to
            channels_last. If None, all 4d tensors in batch data are
            converted to channels_last.
        delay_sync: Whether to delay grad sync when training with DDP.
            Refer to the DDP.no_sync() API.
        empty_cache: Whether to execute torch.cuda.empty_cache() after each
            forward and backward run.
        grad_scaler: An instance ``scaler`` of :class:`GradScaler` that
            helps perform the steps of gradient scaling conveniently.
        grad_accumulation_step: The step of grad accumulation.
            Gradient accumulation means multiple backward passes are
            performed before updating the parameters, so that the model's
            parameters are updated every several steps instead of after
            every single batch.
    """  # noqa

    def __init__(
        self,
        need_grad_update: bool,
        batch_transforms: Optional[List] = None,
        inverse_transforms: Optional[List] = None,
        loss_collector: Callable = None,
        enable_amp: bool = False,
        enable_amp_dtype: torch.dtype = torch.float16,
        enable_apex: bool = False,
        enable_channels_last: bool = False,
        channels_last_keys: Optional[Sequence[str]] = None,
        delay_sync: bool = False,
        empty_cache: bool = False,
        grad_scaler: torch.cuda.amp.GradScaler = None,
        grad_accumulation_step: Union[int, str] = 1,
    ):
        if need_grad_update:
            assert (
                loss_collector is not None
            ), "Provide `loss_collector` when need_grad_update"
            assert callable(loss_collector)
        if enable_amp and enable_apex:
            raise RuntimeError(
                "enable_amp and enable_apex cannot be true together."
            )
        if enable_apex and apex is None:
            raise ModuleNotFoundError("Apex is required.")

        self.ga_step = grad_accumulation_step
        if delay_sync or self.ga_step > 1:
            torch_version = torch.__version__
            assert (enable_apex and apex is not None) or LooseVersion(
                torch_version
            ) >= LooseVersion("1.10.2"), (
                "Delay sync or grad accumulation needs apex enabled "
                "or a higher version of torch."
            )

        if enable_amp_dtype == torch.bfloat16:
            if not torch.cuda.is_bf16_supported():
                raise RuntimeError(
                    "current gpu devices do not support bfloat16."
                )

        self.need_grad_update = need_grad_update
        self.loss_collector = loss_collector
        self.enable_amp_dtype = enable_amp_dtype
        self.enable_apex = enable_apex
        self.enable_channels_last = enable_channels_last
        self.channels_last_keys = channels_last_keys
        self.delay_sync = delay_sync
        self.empty_cache = empty_cache
        if grad_scaler is not None:
            self.grad_scaler = grad_scaler
        else:
            self.grad_scaler = GradScaler(enabled=enable_amp)
        self.enable_amp = self.grad_scaler.is_enabled()
        if enable_amp:
            assert self.enable_amp, (
                "When grad_scaler is not None, enable_amp does not work. "
                "You set enable_amp to {}, but the enable_amp of "
                "grad_scaler is {}. Please check your config!!"
            ).format(enable_amp, self.enable_amp)
        if batch_transforms:
            if isinstance(batch_transforms, (list, tuple)):
                batch_transforms = torchvision.transforms.Compose(
                    batch_transforms
                )
            self.transforms = batch_transforms
        else:
            self.transforms = None
        self.inverse_transforms = inverse_transforms

    def __call__(
        self,
        step_id: int,
        batch: Union[Tuple[Any], List[Any], object],
        model: torch.nn.Module,
        device: Union[int, None],
        optimizer=None,
        storage: EventStorage = None,
        batch_begin_callback: Callable = None,
        batch_end_callback: Callable = None,
        backward_begin_callback: Callable = None,
        backward_end_callback: Callable = None,
        optimizer_step_begin_callback: Callable = None,
        forward_begin_callback: Callable = None,
        forward_end_callback: Callable = None,
        profiler: Optional[Union[BaseProfiler, str]] = None,
    ):
        assert self.need_grad_update == model.training, (
            "%s vs. %s, set model to training/eval mode by "
            "model.train()/model.eval() when need_grad_update or not"
            % (self.need_grad_update, model.training)
        )

        if profiler is None:
            profiler = PassThroughProfiler()

        # 0. reset grad
        if self.need_grad_update:
            with profiler.profile("optimizer_zero_grad"):
                optimizer.zero_grad(set_to_none=True)

        if isinstance(batch, tuple):
            # Means that `batch_data` contains multiple batches, e.g.
            # (1) contains task specific batches of a `multitask model`
            # batch_data = (
            #    [task1_data1, task1_data2, ...],   # task1 batch
            #    [task2_data1, task2_data2, ...],   # task2 batch
            #    [task3_data1, task3_data2, ...],   # can be list/tuple of objs
            #    task4_data                         # or just a single obj
            #    ...
            # )
            #
            # (2) contains multiple batches for a single task model
            # batch_data = (
            #    [batch1_data1, batch1_data2, ...],
            #    [batch2_data1, batch2_data2, ...], # can be list/tuple of objs
            #    data1                              # or just a single obj
            #    ...
            # )
            batches = batch
        else:
            # Means that `data` just contains a single batch, e.g.
            # (1) is a single obj
            # batch_data = task_data  # e.g. a dict(inputs=xx, target=xx)
            #
            # (2) is a list (NOT A TUPLE) of objs
            # batch_data = [task_data1, task_data2, ...]
            #
            # convert to tuple
            batches = (batch,)

        # for each batch in multiple batches
        last_batch_idx = len(batches) - 1
        for idx, batch_i in enumerate(batches):
            if batch_begin_callback is not None:
                batch_begin_callback(batch=batch_i, batch_idx=idx)

            if device is not None:
                batch_i = to_cuda(batch_i, device, non_blocking=True)
            else:
                # run on cpu
                pass

            if isinstance(batch_i, Tuple) and len(batch_i) == 2:
                profile_suffix = batch_i[1]
            else:
                profile_suffix = idx

            if self.transforms is not None:
                with profiler.profile(f"batch_transforms_{profile_suffix}"):
                    batch_i = (self.transforms(batch_i[0]), batch_i[1])

            if self.enable_channels_last:
                batch_i = convert_memory_format(
                    batch_i, self.channels_last_keys, torch.channels_last
                )

            # 1. forward
            grad_decorator = (
                torch.enable_grad if self.need_grad_update else torch.no_grad
            )
            if not self.enable_apex:
                auto_cast = autocast(
                    enabled=self.enable_amp, dtype=self.enable_amp_dtype
                )
            else:
                auto_cast = localcontext()

            if forward_begin_callback is not None:
                forward_begin_callback(batch=batch_i, model=model)

            with profiler.profile(f"model_forward_{profile_suffix}"):
                with auto_cast:
                    with grad_decorator():
                        if self.delay_sync and idx != last_batch_idx:
                            # delay sync grad until last batch in mt tuple.
                            if hasattr(model, "disable_allreduce"):
                                model.disable_allreduce()
                        elif (
                            self.ga_step > 1
                            and step_id % self.ga_step != 0
                            and idx != last_batch_idx
                        ):
                            # delay sync grad by grad accumulation step.
                            if hasattr(model, "disable_allreduce"):
                                model.disable_allreduce()
                        else:
                            # only support enable_allreduce in apex
                            if hasattr(model, "enable_allreduce"):
                                model.enable_allreduce()
                        # model outputs can be in any format
                        model_outs = model(*_as_list(batch_i))

            if self.inverse_transforms is not None:
                model_outs = self.inverse_transforms(model_outs, batch_i)

            if forward_end_callback is not None:
                forward_end_callback(model_outs=model_outs, batch_idx=idx)

            # 2. filter out loss Tensors in model outputs
            if self.loss_collector is not None:
                losses = self.loss_collector(model_outs)
            else:
                losses = None

            if self.empty_cache:
                torch.cuda.empty_cache()

            # 3. backward
            if self.need_grad_update:
                # Not allowed to backward each loss independently, so sum them
                loss = sum(
                    [loss for loss in _as_list(losses) if loss is not None]
                )
                assert isinstance(loss, torch.Tensor), type(loss)
                loss_scalar = loss.sum()

                if backward_begin_callback:
                    backward_begin_callback(batch=batch_i)

                # when grad_scaler is not enabled, equivalent to loss.backward()
                with profiler.profile(f"model_backward_{profile_suffix}"):
                    if self.enable_apex:
                        with apex.amp.scale_loss(
                            loss_scalar, optimizer
                        ) as loss_s:
                            loss_s.backward()
                    else:
                        if self.delay_sync and idx != last_batch_idx:
                            # delay sync grad until last batch in mt tuple.
                            with model.no_sync():
                                self.grad_scaler.scale(loss_scalar).backward()
                        elif (
                            self.ga_step > 1
                            and step_id % self.ga_step != 0
                            and idx != last_batch_idx
                        ):
                            # delay sync grad by grad accumulation step.
                            with model.no_sync():
                                self.grad_scaler.scale(loss_scalar).backward()
                        else:
                            self.grad_scaler.scale(loss_scalar).backward()
                if backward_end_callback:
                    backward_end_callback(batch=batch_i, batch_idx=idx)

            if batch_end_callback is not None:
                batch_end_callback(
                    batch=batch_i, losses=losses, model_outs=model_outs
                )

        # 4. update grad
        if self.need_grad_update:
            if optimizer_step_begin_callback is not None:
                optimizer_step_begin_callback(grad_scaler=self.grad_scaler)
            # when grad_scaler is not enabled, equivalent to optimizer.step()
            with profiler.profile("optimizer_step"):
                if self.enable_apex:
                    optimizer.step()
                else:
                    self.grad_scaler.step(optimizer)
                    self.grad_scaler.update()

        if self.empty_cache:
            torch.cuda.empty_cache()

        if self.enable_amp:
            storage.put(
                "grad_scaler", self.grad_scaler.state_dict(), always_dict=True
            )
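
Stripped of AMP, Apex, DDP, and callbacks, the control flow above reduces to the standard accumulate-then-step pattern: zero the grads once, forward/backward every sub-batch (gradients add into `.grad`), then step once. A minimal self-contained sketch (the `Linear` model and the data here are illustrative, not part of hat):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Tuple format => multiple batches, as MultiBatchProcessor expects
batches = (torch.ones(2, 4), 2 * torch.ones(2, 4))

optimizer.zero_grad(set_to_none=True)   # 0. reset grad once per training step
for batch_i in batches:                 # forward/backward every sub-batch
    loss = model(batch_i).mean()        # 1. forward + 2. collect loss
    loss.backward()                     # 3. backward: grads add into .grad
# d(mean(Wx + b))/dW is the batch mean of x: [1,1,1,1] then [2,2,2,2], summed
assert torch.equal(model.weight.grad, torch.tensor([[3.0, 3.0, 3.0, 3.0]]))
optimizer.step()                        # 4. single update with summed grads
```

So the final `optimizer.step()` uses the sum of all sub-batch gradients, not just the last batch's gradient.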

Tags: 算法工具链 · 征程5

Comments (1)
  • Huanghui

    Got it, let me take a look first. I'll reply to you shortly!

    2024-12-17
    • beyondaa replying to Huanghui:

      Each sub-batch's data comes in, a loss is computed, and backward() is called. But the grad update happens after the for loop. So does the grad update use only the last loss, or the gradients from each of the preceding batches in turn? I don't understand the mechanism here. Could you explain?
      2024-12-17
    • beyondaa replying to Huanghui:

      My understanding is that if the update is accumulated, and the first batch's gradient comes out as +1 while the second batch's comes out as -1, wouldn't the two cancel after accumulation, leaving no gradient at all?

      2024-12-17
    • Huanghui replying to beyondaa:

      Yes, the gradients are accumulated. When you compute the loss for each sub-batch of the data and call backward(), you are actually computing that sub-batch's gradient contribution to the model parameters, and these gradients are accumulated as they are computed. Although the grad update happens after the for loop, it uses the gradient information computed and accumulated from every sub-batch. On each backward() call, the newly computed gradients are added on top of the gradients that already exist. When the loop finishes and the grad update is executed, the model parameters are updated with the gradients accumulated over all sub-batches, not just the last batch's gradient. The benefit is that more data samples are used to adjust the model parameters, making the updates more stable and accurate, which usually yields better model performance and convergence.

      2024-12-17
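
Huanghui's explanation can be checked with a few lines of plain PyTorch (a standalone sketch, independent of hat): successive `backward()` calls add into `.grad`, so a +1 and a -1 contribution do indeed sum to zero.

```python
import torch

# PyTorch accumulates gradients across backward() calls until zero_grad()
w = torch.tensor([1.0], requires_grad=True)

(1.0 * w).sum().backward()    # "batch 1": d(loss)/dw = +1
assert w.grad.item() == 1.0

(-1.0 * w).sum().backward()   # "batch 2": d(loss)/dw = -1, added on top
assert w.grad.item() == 0.0   # +1 and -1 accumulate to zero
```

That cancellation is intended behavior: by linearity, accumulating per-batch gradients equals the gradient of the summed losses, exactly as if one large batch had been used, so opposing contributions cancel just as they would inside a single batch.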
    • beyondaa replying to Huanghui:

      What about my question, then? My understanding is that if the update is accumulated, and the first batch's gradient comes out as +1 while the second batch's comes out as -1, wouldn't the two cancel after accumulation, leaving no gradient at all?

      2024-12-17