Exponential Moving Average

Exponential Moving Average (EMA) of model weights with a fused update step

EMACallback and EMAWarmupCallback have a fast fused implementation using the PyTorch ForEach methods from PyTorch’s _multi_tensor optimizers. The fused EMA step is 1 to 8.4 times faster (see footnote 1) than a standard EMA step written as a Python for loop.

Table 1: For Loop EMA Step vs Fused ForEach EMA Step

Model               For Loop Step   ForEach Step   Speedup
XResNet50           8.04µs          2.37µs         2.3x
XSE-ResNeXt50       10.7µs          2.34µs         3.5x
XResNet101          15.3µs          4.27µs         2.5x
ConvNeXt Tiny       8.14µs          963ns          7.4x
ViT Patch16 Small   7.37µs          778ns          8.4x
DeBERTa Base        10.0µs          4.79µs         1.0x

EMACallback

 EMACallback (decay:float=0.9998, start:Numeric=0,
              ema_device:torch.device|str|None=None,
              validate_ema:bool=True, replace_weights:bool=False,
              foreach:bool|None=None, resume:bool=False,
              all_parameters:bool=False, all_buffers:bool=False,
              skip_ema:bool=True)

Exponential Moving Average (EMA) of model weights with a fused update step

Parameter Type Default Details
decay float 0.9998 EMA decay value
start Numeric 0 Start EMA in percent of training steps (float) or epochs (int, index 0)
ema_device torch.device | str | None None Device to store EMA weights. Defaults to model device
validate_ema bool True Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device
replace_weights bool False Replace model weights with EMA weights when finished training. If false, sets Learner.model_ema to EMA weights
foreach bool | None None Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ is installed and a CUDA device is detected
resume bool False Resume from EMA weights from previous training saved to Learner.model_ema
all_parameters bool False Apply EMA step to all parameters or only those with requires_grad
all_buffers bool False Apply EMA step to all buffers or only persistent model buffers
skip_ema bool True Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step

EMACallback is inspired by ModelEmaV2 from PyTorch Image Models (timm), and should match the TensorFlow EMA implementation.

The decay default of 0.9998 means that on each EMA update, EMACallback keeps 99.98% of the prior EMA weights and moves 0.02% towards the training model weights.
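The arithmetic can be illustrated on a single scalar weight. This is a minimal sketch of the standard EMA update rule, not EMACallback's actual code; the function name `ema_update` is hypothetical.

```python
def ema_update(ema_weight: float, model_weight: float, decay: float = 0.9998) -> float:
    """One EMA step for a single scalar weight (illustrative only)."""
    # Keep `decay` of the prior EMA value and move (1 - decay) towards the model value
    return decay * ema_weight + (1.0 - decay) * model_weight


ema = 1.0    # prior EMA weight
model = 0.0  # current training model weight
ema = ema_update(ema, model)
print(ema)  # 0.9998, i.e. 99.98% of the prior EMA weight is kept
```

In the real callback the same rule would be applied to every tracked parameter tensor rather than to one float.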

To prevent EMACallback from using GPU memory, set ema_device='cpu'. Because validate_ema requires ema_device to match the model device, EMA validation will then need to be performed manually post-training or via a custom callback.

To use the fused EMA step, set foreach=True; set foreach=False to disable it. With the default of foreach=None, EMACallback automatically selects the fused method if PyTorch 1.12+ is installed and a CUDA device is detected.
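The default selection described above could be sketched as follows. This is an assumption about the decision logic based only on the description here, not the library's code; `resolve_foreach` and its arguments are hypothetical names.

```python
from typing import Optional


def resolve_foreach(foreach: Optional[bool], torch_version: str, cuda_available: bool) -> bool:
    """Decide whether to use the fused ForEach EMA step (hypothetical helper)."""
    # An explicit True or False from the user always wins
    if foreach is not None:
        return foreach
    # Drop any local build suffix such as "+cu117", then compare (major, minor)
    release = torch_version.split("+")[0]
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) >= (1, 12) and cuda_available


print(resolve_foreach(None, "1.13.1+cu117", True))  # fused step selected
print(resolve_foreach(None, "1.11.0", True))        # falls back to the for loop
```

With PyTorch installed, the last two arguments would typically come from `torch.__version__` and `torch.cuda.is_available()`.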

If all_parameters=False, only parameters with requires_grad are included in the EMA calculation.

If all_buffers=False, only persistent buffers are included in the EMA calculation.

If skip_ema=True (the default), then the EMA calculation will not apply if any other callback raises a CancelBatchException or CancelStepException. This is intended to handle the fastai.callback.fp16.MixedPrecision AMP scaler and fastai.callback.training.GradientAccumulation skipping the optimizer update step: when the step is skipped, the model weights have not changed, so the EMA step should not be calculated. If needed, this behavior can be turned off with skip_ema=False, but in general this argument should be left unchanged.
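The effect of skip_ema can be seen in a toy simulation on one scalar weight. This is only a sketch of the idea, with skipped optimizer steps represented by a list of booleans instead of fastai's exceptions; `ema_with_skips` is a hypothetical name.

```python
def ema_with_skips(ema, model_values, step_skipped, decay, skip_ema=True):
    """Apply EMA over a sequence of batches, optionally ignoring skipped steps.

    model_values[i] is the scalar model weight after batch i, and
    step_skipped[i] is True when the optimizer step for batch i was skipped.
    Returns the final EMA value and the number of EMA updates applied.
    """
    n_updates = 0
    for value, skipped in zip(model_values, step_skipped):
        if skip_ema and skipped:
            continue  # weights did not change, so leave the EMA untouched
        ema = decay * ema + (1.0 - decay) * value
        n_updates += 1
    return ema, n_updates


# The second batch skips the optimizer step, so the model weight stays at 1.0
print(ema_with_skips(0.0, [1.0, 1.0], [False, True], decay=0.5))                  # (0.5, 1)
print(ema_with_skips(0.0, [1.0, 1.0], [False, True], decay=0.5, skip_ema=False))  # (0.75, 2)
```

Without the skip, the EMA is pulled towards the unchanged weights a second time, which makes the effective decay faster than configured.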


EMAWarmupCallback

 EMAWarmupCallback (start_decay:float=0.9, final_decay:float=0.9998,
                    start:Numeric=0, finish:Numeric=0.3,
                    schedule:Callable[...,_Annealer]=<function SchedCos>,
                    ema_device:torch.device|str|None=None,
                    validate_ema:bool=True, replace_weights:bool=False,
                    foreach:bool|None=None, resume:bool=False,
                    all_parameters:bool=False, all_buffers:bool=False,
                    skip_ema:bool=True, logger_callback:str='wandb')

Exponential Moving Average (EMA) of model weights with a warmup schedule and fused update step

Parameter Type Default Details
start_decay float 0.9 Initial EMA decay value
final_decay float 0.9998 Final EMA decay value
start Numeric 0 Start EMA warmup in percent of training steps (float) or epochs (int, index 0)
finish Numeric 0.3 Finish EMA warmup in percent of training steps (float) or epochs (int, index 0)
schedule Callable[…, _Annealer] SchedCos EMA decay warmup schedule
ema_device torch.device | str | None None Device to store EMA weights. Defaults to model device
validate_ema bool True Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device
replace_weights bool False Replace model weights with EMA weights when finished training. If false, sets Learner.model_ema to EMA weights
foreach bool | None None Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ is installed and a CUDA device is detected
resume bool False Resume from EMA weights from previous training saved to Learner.model_ema
all_parameters bool False Apply EMA step to all parameters or only those with requires_grad
all_buffers bool False Apply EMA step to all buffers or only persistent model buffers
skip_ema bool True Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step
logger_callback str wandb Log EMA decay to logger_callback using Callback.name if available

EMAWarmupCallback extends EMACallback by adding a schedulable EMA decay value. The decay moves from an initial value of start_decay to final_decay, then stays at final_decay for the rest of training. The change in the EMA decay occurs between start and finish.

The EMA warmup schedule can be one of SchedCos (the default), SchedLin, SchedExp, SchedPoly, or a custom schedule based on a fastai annealer. SchedPoly must be passed as a partial function: partial(SchedPoly, power=0.5).
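As a rough illustration of the default warmup, the sketch below applies cosine annealing to the decay between start and finish using the default argument values. It reimplements the cosine formula in plain Python rather than calling fastai, handles only the float (percent of training) form of start and finish, and returns start_decay before the warmup begins as a simplification; `warmup_decay` is a hypothetical name and the callback's internals may differ.

```python
import math


def sched_cos(start: float, end: float, pos: float) -> float:
    """Cosine annealing from start to end as pos goes from 0 to 1."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def warmup_decay(pct_train: float, start_decay: float = 0.9, final_decay: float = 0.9998,
                 start: float = 0.0, finish: float = 0.3) -> float:
    """EMA decay at a given fraction of training (illustrative sketch only)."""
    if pct_train <= start:
        return start_decay   # warmup has not begun (simplification)
    if pct_train >= finish:
        return final_decay   # warmup finished, decay stays fixed
    # Rescale training progress to a 0-1 position inside the warmup window
    pos = (pct_train - start) / (finish - start)
    return sched_cos(start_decay, final_decay, pos)


for pct in (0.0, 0.15, 0.3, 1.0):
    print(pct, round(warmup_decay(pct), 4))
```

Halfway through the warmup window the decay sits at the midpoint of start_decay and final_decay, and from finish onwards it remains at final_decay.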

Warning

EMAWarmupCallback does not support resumed training while EMA warmup is in progress. This is due to fastai not fully supporting resumable training.

EMAWarmupCallback can resume training after the warmup period has finished.

Weights & Biases Logging

If Weights & Biases is installed and the WandbCallback is added to Learner, EMAWarmupCallback will automatically log the current EMA decay rate to Weights & Biases as ema_decay.

Extend to other Loggers

To extend to new loggers, follow the Weights & Biases code below and patch EMAWarmupCallback with a method named _{Callback.name}_log_ema_decay, where Callback.name is the name of the logger callback.

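from fastcore.basics import patch

try:
    import wandb

    # Patch a wandb-specific logging method onto EMAWarmupCallback
    @patch
    def _wandb_log_ema_decay(self:EMAWarmupCallback, decay:float):
        # Log against WandbCallback's step counter, offset by one
        wandb.log({'ema_decay': decay}, self.learn.wandb._wandb_step+1)
except ImportError:
    # wandb is not installed, so no logging method is added
    pass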

Then to use, pass logger_callback='{Callback.name}' to EMAWarmupCallback.

EMAWarmupCallback sets its _log_ema_decay method to f'_{self.logger_callback}_log_ema_decay', which should match the patched method.

self._log_ema_decay = getattr(self, f'_{self.logger_callback}_log_ema_decay', noop)
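The name-based lookup can be tried in isolation with a small stand-in class. This sketch does not use fastai or fastcore: `WarmupLoggerDemo`, the "memory" logger, and the local `noop` are all hypothetical and exist only to show how a matching method is picked up and how a missing one falls back to doing nothing.

```python
def noop(*args, **kwargs):
    """Fallback that ignores its arguments (stand-in for fastcore's noop)."""
    return None


class WarmupLoggerDemo:
    """Toy class that resolves its logging method by logger name."""

    def __init__(self, logger_callback: str):
        self.logger_callback = logger_callback
        self.logged = []
        # Look up a method matching the logger name, or fall back to noop
        self._log_ema_decay = getattr(self, f"_{logger_callback}_log_ema_decay", noop)

    def _memory_log_ema_decay(self, decay: float):
        # Hypothetical logger that just records values in a list
        self.logged.append(decay)


demo = WarmupLoggerDemo("memory")
demo._log_ema_decay(0.95)        # dispatches to _memory_log_ema_decay
missing = WarmupLoggerDemo("tensorboard")
missing._log_ema_decay(0.95)     # no matching method, silently does nothing
print(demo.logged, missing.logged)
```

A misspelled logger_callback value therefore fails silently, so the name passed in needs to match the patched method exactly.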

Footnotes

  1. EMACallback performance was benchmarked on a GeForce 3080 Ti using PyTorch 1.13.1, CUDA 11.7, Mixed Precision, and Channels Last (except DeBERTa and ViT). Results may differ on other models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the EMA step.