Exponential Moving Average

Exponential Moving Average (EMA) of model weights with a fused update step

EMACallback and EMASchedule have a fast fused implementation built on the PyTorch ForEach methods used by PyTorch’s _multi_tensor optimizers. The fused EMA step is 1 to 8.4 times faster [1] than a standard Python EMA step implemented as a for loop.

Table 1: For Loop EMA Step vs Fused ForEach EMA Step
| Model | For Loop Step | ForEach Step | Speedup |
|---|---|---|---|
| XResNet50 | 8.04µs | 2.37µs | 2.3x |
| XSE-ResNeXt50 | 10.7µs | 2.34µs | 3.5x |
| XResNet101 | 15.3µs | 4.27µs | 2.5x |
| ConvNeXt Tiny | 8.14µs | 963ns | 7.4x |
| ViT Patch16 Small | 7.37µs | 778ns | 8.4x |
| DeBERTa Base | 10.0µs | 4.79µs | 1.0x |
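
The difference between the two kinds of step can be illustrated with a minimal sketch. It is not fastxtend's implementation: the function names are hypothetical, and it assumes the fused path relies on the private `torch._foreach_mul_` and `torch._foreach_add_` functions, which update a whole list of tensors per call instead of one tensor per Python iteration.

```python
import torch

def ema_step_loop(ema_tensors, model_tensors, decay):
    # Hypothetical helper: one in-place update per tensor, driven by a Python loop.
    for ema, weight in zip(ema_tensors, model_tensors):
        ema.mul_(decay).add_(weight, alpha=1 - decay)

def ema_step_foreach(ema_tensors, model_tensors, decay):
    # Hypothetical helper: the same math, applied to the whole list per call.
    torch._foreach_mul_(ema_tensors, decay)
    torch._foreach_add_(ema_tensors, model_tensors, alpha=1 - decay)

torch.manual_seed(0)
model_weights = [torch.randn(3, 3), torch.randn(5)]

# Two independent copies of the EMA state, one per method.
ema_loop = [torch.zeros_like(w) for w in model_weights]
ema_fused = [torch.zeros_like(w) for w in model_weights]

for _ in range(10):
    ema_step_loop(ema_loop, model_weights, decay=0.9998)
    ema_step_foreach(ema_fused, model_weights, decay=0.9998)

# Both methods should produce the same EMA weights.
same = all(torch.allclose(a, b) for a, b in zip(ema_loop, ema_fused))
print(same)
```

Both functions compute the same result. Any speedup from the ForEach version should mainly be expected on CUDA devices, which is consistent with foreach defaulting to true only when a CUDA device is detected.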


EMACallback

 EMACallback (decay:float=0.9998, start:Numeric=0,
              ema_device:torch.device|str|None=None,
              validate_ema:bool=True, replace_weights:bool=False,
              foreach:bool|None=None, resume:bool=False,
              all_parameters:bool=False, all_buffers:bool=False,
              skip_ema:bool=True)

Exponential Moving Average (EMA) of model weights with a fused update step

| Parameter | Type | Default | Details |
|---|---|---|---|
| decay | float | 0.9998 | EMA decay value |
| start | Numeric | 0 | Start EMA in percent of training steps (float) or epochs (int, index 0) |
| ema_device | torch.device \| str \| None | None | Device to store EMA weights. Defaults to model device |
| validate_ema | bool | True | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
| replace_weights | bool | False | Replace model weights with EMA weights when finished training. If false, sets Learner.model_ema to EMA weights |
| foreach | bool \| None | None | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and CUDA device detected |
| resume | bool | False | Resume from EMA weights from previous training saved to Learner.model_ema |
| all_parameters | bool | False | Apply EMA step to all parameters or only those with requires_grad |
| all_buffers | bool | False | Apply EMA step to persistent model buffers or all buffers |
| skip_ema | bool | True | Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step |

EMACallback is inspired by ModelEmaV2 from PyTorch Image Models (timm), and should match the TensorFlow EMA implementation.

The decay default of 0.9998 means that at each EMA update, EMACallback keeps 99.98% of the prior EMA weights and moves 0.02% towards the training model weights.
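
The arithmetic behind the default decay can be worked through on a single scalar weight. This is an illustrative sketch only: `ema_update` and `half_life` are hypothetical helpers, not part of EMACallback, and the half-life formula is the standard one for exponential averaging.

```python
import math

def ema_update(ema, weight, decay=0.9998):
    # Keep `decay` of the previous EMA value and move (1 - decay) towards `weight`.
    return decay * ema + (1 - decay) * weight

def half_life(decay):
    # Number of updates after which an old value's contribution has halved.
    return math.log(0.5) / math.log(decay)

# One update from an EMA value of 1.0 towards a model weight of 0.0.
after_one_step = ema_update(1.0, 0.0)

# Contribution of the initial EMA value after 1000 updates.
remaining_after_1000 = 0.9998 ** 1000

print(after_one_step, remaining_after_1000, half_life(0.9998))
```

With decay=0.9998 the half-life is roughly 3,465 updates, so the EMA weights respond slowly to changes in the training weights.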

To prevent EMACallback from using GPU memory, set ema_device='cpu'. Because validate_ema requires ema_device to match the model device, EMA validation will then need to be performed manually post-training or via a custom callback.

To use the fused EMA step, set foreach=True; set foreach=False to disable it. With the default of foreach=None, EMACallback automatically selects the fused method if PyTorch 1.12+ is installed and a CUDA device is detected.

If all_parameters=False, only parameters with requires_grad are included in the EMA calculation.

If all_buffers=False, only persistent buffers are included in the EMA calculation.
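
The two selections can be illustrated on a toy module. This is a sketch under assumptions, not EMACallback's code: the `Toy` module and its attribute names are hypothetical, and persistent buffers are identified here by checking which buffers appear in `state_dict()`.

```python
import torch
from torch import nn

class Toy(nn.Module):
    # Hypothetical module with a frozen layer and both kinds of buffer.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        self.frozen = nn.Linear(2, 2)
        for p in self.frozen.parameters():
            p.requires_grad_(False)
        self.register_buffer("running_stat", torch.zeros(2))               # persistent
        self.register_buffer("scratch", torch.zeros(2), persistent=False)  # non-persistent

model = Toy()

# Analogous to all_parameters=False: only parameters with requires_grad.
trainable_params = [n for n, p in model.named_parameters() if p.requires_grad]
# Analogous to all_parameters=True: every parameter, frozen or not.
every_param = [n for n, _ in model.named_parameters()]

# Analogous to all_buffers=False: only persistent buffers, which are saved in state_dict().
state_keys = set(model.state_dict().keys())
persistent_buffers = [n for n, _ in model.named_buffers() if n in state_keys]
# Analogous to all_buffers=True: every registered buffer.
every_buffer = [n for n, _ in model.named_buffers()]

print(trainable_params, persistent_buffers, every_buffer)
```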

If skip_ema=True (the default), the EMA step is skipped whenever another callback raises a CancelBatchException or CancelStepException. This handles the fastai.callback.fp16.MixedPrecision AMP scaler and fastai.callback.training.GradientAccumulation skipping the optimizer update step: when the optimizer step is skipped, the model weights have not changed, so the EMA step should not be calculated. This behavior can be turned off if needed, but in general this argument should be left unchanged.



EMASchedule

 EMASchedule (start_decay:float=0.9, final_decay:float=0.9998,
              start:Numeric=0, finish:Numeric=0.3,
              schedule:Callable[...,_Annealer]=<function SchedCos>,
              ema_device:torch.device|str|None=None,
              validate_ema:bool=True, replace_weights:bool=False,
              foreach:bool|None=None, resume:bool=False,
              all_parameters:bool=False, all_buffers:bool=False,
              skip_ema:bool=True)

Exponential Moving Average (EMA) of model weights with a warmup schedule and fused update step

| Parameter | Type | Default | Details |
|---|---|---|---|
| start_decay | float | 0.9 | Initial EMA decay value |
| final_decay | float | 0.9998 | Final EMA decay value |
| start | Numeric | 0 | Start EMA warmup in percent of training steps (float) or epochs (int, index 0) |
| finish | Numeric | 0.3 | Finish EMA warmup in percent of training steps (float) or epochs (int, index 0) |
| schedule | Callable[…, _Annealer] | SchedCos | EMA decay warmup schedule |
| ema_device | torch.device \| str \| None | None | Device to store EMA weights. Defaults to model device |
| validate_ema | bool | True | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
| replace_weights | bool | False | Replace model weights with EMA weights when finished training. If false, set Learner.model_ema to EMA weights |
| foreach | bool \| None | None | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and CUDA device detected |
| resume | bool | False | Resume from EMA weights from previous training saved to Learner.model_ema |
| all_parameters | bool | False | Apply EMA step to all parameters or only those with requires_grad |
| all_buffers | bool | False | Apply EMA step to persistent model buffers or all buffers |
| skip_ema | bool | True | Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step |

EMASchedule extends EMACallback by adding a schedulable EMA decay value, which moves from an initial value of start_decay to final_decay and then stays at final_decay for the rest of training. The change in the EMA decay occurs between start and finish.

The EMA schedule can be one of SchedCos (the default), SchedLin, SchedExp, SchedPoly, or a custom fastai annealer-based schedule. SchedPoly must be passed as a partial function: partial(SchedPoly, power=0.5).
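
The shape of the default warmup can be sketched in plain Python. This is an approximation under stated assumptions rather than EMASchedule's code: `sched_cos` mirrors the cosine annealing formula used by fastai's SchedCos, `ema_decay_at` is a hypothetical helper, only float (percent of training) values of start and finish are handled, and returning start_decay before start is an assumption.

```python
import math

def sched_cos(start, end, pos):
    # Cosine annealing from `start` (pos=0) to `end` (pos=1).
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def ema_decay_at(pct, start_decay=0.9, final_decay=0.9998, start=0.0, finish=0.3):
    # Hypothetical helper: EMA decay at `pct` (0 to 1) of total training steps.
    if pct <= start:
        return start_decay   # assumption: warmup has not begun yet
    if pct >= finish:
        return final_decay   # warmup finished, decay stays at its final value
    pos = (pct - start) / (finish - start)  # position within the warmup window
    return sched_cos(start_decay, final_decay, pos)

for pct in (0.0, 0.15, 0.3, 1.0):
    print(pct, ema_decay_at(pct))
```

With the default arguments the decay rises from 0.9 to 0.9998 over the first 30% of training, passing through the midpoint of the two values halfway through the warmup.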

Warning: Resumed Training

EMASchedule does not support resumed training while EMA scheduling is in progress. This is due to fastai not fully supporting resumable training.

EMASchedule can resume training after the schedule period is finished.

EMASchedule supports logging to Weights & Biases and TensorBoard via the LogDispatch callback. If either fastai.callback.wandb.WandbCallback or fastai.callback.tensorboard.TensorBoardCallback is added to Learner, EMASchedule will automatically log the current EMA decay rate as ema_decay.

Footnotes

  1. EMACallback performance was benchmarked on a GeForce 3080 Ti using PyTorch 1.13.1, CUDA 11.7, Mixed Precision, and Channels Last (except DeBERTa and ViT). Results may differ on other models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the EMA step.