Exponential Moving Average
`EMACallback` and `EMASchedule` have a fast fused implementation using PyTorch ForEach methods from PyTorch’s `_multi_tensor` optimizers. The fused EMA step is 1 to 8.4 times faster than a standard Python EMA step via a for loop (benchmark details are in the Footnotes).
Model | For Loop Step | ForEach Step | Speedup |
---|---|---|---|
XResNet50 | 8.04µs | 2.37µs | 2.3x |
XSE-ResNeXt50 | 10.7µs | 2.34µs | 3.5x |
XResNet101 | 15.3µs | 4.27µs | 2.5x |
ConvNeXt Tiny | 8.14µs | 963ns | 7.4x |
ViT Patch16 Small | 7.37µs | 778ns | 8.4x |
DeBERTa Base | 10.0µs | 4.79µs | 1.0x |
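The two update styles compared in the table can be sketched as follows. This is a minimal illustration, not the library's actual implementation: the function names `ema_step_loop` and `ema_step_foreach` are hypothetical, and the `torch._foreach_*` functions are private PyTorch APIs that may change between releases.

```python
import torch

def ema_step_loop(ema_params, model_params, decay):
    # Standard approach: update one tensor at a time in a Python loop.
    for ema_p, model_p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(model_p, alpha=1 - decay)

def ema_step_foreach(ema_params, model_params, decay):
    # Fused approach: each call processes the whole list of tensors.
    # torch._foreach_* functions are private and may change between releases.
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, model_params, alpha=1 - decay)

torch.manual_seed(0)
model_params = [torch.randn(4, 4), torch.randn(8)]  # stand-ins for model weights
ema_loop = [torch.zeros_like(p) for p in model_params]
ema_fused = [torch.zeros_like(p) for p in model_params]

ema_step_loop(ema_loop, model_params, decay=0.9998)
ema_step_foreach(ema_fused, model_params, decay=0.9998)

# Both variants should produce the same EMA weights.
results_match = all(torch.allclose(a, b) for a, b in zip(ema_loop, ema_fused))
```

The loop version issues separate operations for every tensor, while the ForEach version hands the whole list to PyTorch in two calls, which is where the reduced per-step overhead is expected to come from.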
EMACallback
EMACallback (decay:float=0.9998, start:Numeric=0, ema_device:torch.device|str|None=None, validate_ema:bool=True, replace_weights:bool=False, foreach:bool|None=None, resume:bool=False, all_parameters:bool=False, all_buffers:bool=False, skip_ema:bool=True)
Exponential Moving Average (EMA) of model weights with a fused update step
Parameter | Type | Default | Details |
---|---|---|---|
decay | float | 0.9998 | EMA decay value |
start | Numeric | 0 | Start EMA in percent of training steps (float) or epochs (int, index 0) |
ema_device | torch.device \| str \| None | None | Device to store EMA weights. Defaults to model device |
validate_ema | bool | True | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
replace_weights | bool | False | Replace model weights with EMA weights when finished training. If false, sets Learner.model_ema to EMA weights |
foreach | bool \| None | None | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and CUDA device detected |
resume | bool | False | Resume from EMA weights from previous training saved to Learner.model_ema |
all_parameters | bool | False | Apply EMA step to all parameters (true) or only those with requires_grad (false) |
all_buffers | bool | False | Apply EMA step to all buffers (true) or only persistent model buffers (false) |
skip_ema | bool | True | Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step |
`EMACallback` is inspired by `ModelEmaV2` from PyTorch Image Models (timm), and should match the TensorFlow EMA implementation.
The `decay` default of 0.9998 means that on each EMA model update, `EMACallback` will keep 99.98% of the prior EMA weights and update 0.02% towards the training model weights.
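The arithmetic behind a single update can be sketched in plain Python. The helper `ema_update` and the sample values are hypothetical and operate on lists of floats rather than tensors. The half-life calculation is a general property of exponential averaging, not something specific to this callback.

```python
import math

def ema_update(ema_weights, model_weights, decay=0.9998):
    # ema <- decay * ema + (1 - decay) * weight, element by element
    return [decay * e + (1 - decay) * w for e, w in zip(ema_weights, model_weights)]

ema = [1.0, -2.0]    # prior EMA weights (made-up values)
model = [2.0, 0.0]   # current training weights (made-up values)
ema = ema_update(ema, model)
# Each EMA weight has moved 0.02% of the way towards its model weight.

# Number of updates after which an old contribution's weight has halved.
half_life = math.log(0.5) / math.log(0.9998)
```

With the default decay, an old set of weights still carries about half of its influence several thousand updates later, which is why a high decay gives a very smooth average.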
To prevent `EMACallback` from using GPU memory, set `ema_device='cpu'`. Because validating with EMA weights requires `ema_device` to match the model device, EMA validation will then need to be performed manually post-training or via a custom callback.
To use the fused EMA step, set `foreach=True`; set `foreach=False` to disable it. When `foreach` is left at its default of `None`, `EMACallback` will automatically select the fused method if PyTorch 1.12+ is in use and a CUDA device is detected.
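The selection rule described above can be sketched as a small function. This is an assumption-laden illustration: `resolve_foreach` is a hypothetical helper, and how the real callback detects the PyTorch version and CUDA device is not shown in this documentation.

```python
def resolve_foreach(foreach, torch_version, has_cuda):
    # An explicit True/False always wins; None means "decide automatically".
    if foreach is not None:
        return foreach
    # Keep only "major.minor", dropping any local suffix such as "+cu117".
    release = torch_version.split("+")[0]
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) >= (1, 12) and has_cuda

auto_new = resolve_foreach(None, "1.13.1+cu117", has_cuda=True)   # fused step chosen
auto_old = resolve_foreach(None, "1.11.0", has_cuda=True)         # falls back to the for loop
auto_cpu = resolve_foreach(None, "1.13.1", has_cuda=False)        # falls back to the for loop
forced_off = resolve_foreach(False, "1.13.1", has_cuda=True)      # user override
```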
If `all_parameters=False`, only parameters with `requires_grad` are included in the EMA calculation.
If `all_buffers=False`, only persistent buffers are included in the EMA calculation.
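One way to picture the two filters is the sketch below. The helper `ema_tensor_names` and the `Toy` module are hypothetical, and the persistent-buffer check relies on the fact that `state_dict()` omits non-persistent buffers. The callback itself may select tensors differently.

```python
import torch
from torch import nn

def ema_tensor_names(model, all_parameters=False, all_buffers=False):
    # Parameters: everything, or only those that are being trained.
    params = [n for n, p in model.named_parameters()
              if all_parameters or p.requires_grad]
    # state_dict() contains persistent buffers but omits non-persistent ones.
    persistent = set(model.state_dict().keys())
    buffers = [n for n, _ in model.named_buffers()
               if all_buffers or n in persistent]
    return params, buffers

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.frozen = nn.Linear(2, 2)
        self.head = nn.Linear(2, 1)
        self.bn = nn.BatchNorm1d(2)
        # A non-persistent buffer is not saved in the state dict.
        self.register_buffer("scratch", torch.zeros(1), persistent=False)
        for p in self.frozen.parameters():
            p.requires_grad_(False)

model = Toy()
default_params, default_buffers = ema_tensor_names(model)
every_param, every_buffer = ema_tensor_names(model, all_parameters=True, all_buffers=True)
```

With the defaults, the frozen layer's weights and the `scratch` buffer are left out; enabling both flags brings them in.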
If `skip_ema=True` (the default), the EMA calculation is not applied when any other callback raises a `CancelBatchException` or `CancelStepException`. This handles cases where the `fastai.callback.fp16.MixedPrecision` AMP scaler or `fastai.callback.training.GradientAccumulation` skips the optimizer update step: the model weights have not changed, so the EMA step should not be calculated. This behavior can be turned off if needed, but in general this argument should be left unchanged.
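The effect of `skip_ema` under gradient accumulation can be illustrated with a toy counter. This is only a simulation under stated assumptions: `count_ema_updates` is a hypothetical function, and it models a skipped optimizer step with a simple modulo check rather than fastai's exception mechanism.

```python
def count_ema_updates(n_batches, accum_steps, skip_ema=True):
    # Hypothetical loop: the optimizer steps once every `accum_steps` batches.
    ema_updates = 0
    for batch in range(1, n_batches + 1):
        optimizer_stepped = batch % accum_steps == 0
        # With skip_ema, the EMA only updates when the weights actually changed.
        if optimizer_stepped or not skip_ema:
            ema_updates += 1
    return ema_updates

with_skip = count_ema_updates(n_batches=8, accum_steps=4, skip_ema=True)
without_skip = count_ema_updates(n_batches=8, accum_steps=4, skip_ema=False)
```

Without skipping, the EMA would repeatedly average in unchanged weights, pulling it towards the current model faster than the configured decay implies.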
EMASchedule
EMASchedule (start_decay:float=0.9, final_decay:float=0.9998, start:Numeric=0, finish:Numeric=0.3, schedule:Callable[...,_Annealer]=<function SchedCos>, ema_device:torch.device|str|None=None, validate_ema:bool=True, replace_weights:bool=False, foreach:bool|None=None, resume:bool=False, all_parameters:bool=False, all_buffers:bool=False, skip_ema:bool=True)
Exponential Moving Average (EMA) of model weights with a warmup schedule and fused update step
Parameter | Type | Default | Details |
---|---|---|---|
start_decay | float | 0.9 | Initial EMA decay value |
final_decay | float | 0.9998 | Final EMA decay value |
start | Numeric | 0 | Start EMA warmup in percent of training steps (float) or epochs (int, index 0) |
finish | Numeric | 0.3 | Finish EMA warmup in percent of training steps (float) or epochs (int, index 0) |
schedule | Callable[…, _Annealer] | SchedCos | EMA decay warmup schedule |
ema_device | torch.device \| str \| None | None | Device to store EMA weights. Defaults to model device |
validate_ema | bool | True | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
replace_weights | bool | False | Replace model weights with EMA weights when finished training. If false, set Learner.model_ema to EMA weights |
foreach | bool \| None | None | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and CUDA device detected |
resume | bool | False | Resume from EMA weights from previous training saved to Learner.model_ema |
all_parameters | bool | False | Apply EMA step to all parameters (true) or only those with requires_grad (false) |
all_buffers | bool | False | Apply EMA step to all buffers (true) or only persistent model buffers (false) |
skip_ema | bool | True | Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step |
`EMASchedule` extends `EMACallback` by adding a schedulable EMA decay value, which moves from an initial value of `start_decay` to `final_decay` and then stays at `final_decay` for the rest of training. The change in the EMA decay occurs between `start` and `finish`.
The EMA `schedule` can be one of `SchedCos` (the default), `SchedLin`, `SchedExp`, `SchedPoly`, or a custom fastai annealer based schedule. `SchedPoly` must be passed as a partial function: `partial(SchedPoly, power=0.5)`.
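The shape of the default warmup can be sketched in plain Python. This is an illustration under assumptions: `sched_cos` is a stand-alone cosine annealing function written for this example, and `ema_decay_at` is a hypothetical helper using the default arguments with `start` and `finish` as fractions of training. How the callback behaves before `start` is not modeled here.

```python
import math

def sched_cos(start, end, pos):
    # Cosine annealing from `start` (pos=0) to `end` (pos=1).
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def ema_decay_at(pct, start_decay=0.9, final_decay=0.9998, start=0.0, finish=0.3):
    # pct is the fraction of training completed, in [0, 1].
    if pct <= start:
        # Assumption: this sketch simply returns start_decay up to `start`.
        return start_decay
    if pct >= finish:
        # After the warmup, the decay stays at its final value.
        return final_decay
    pos = (pct - start) / (finish - start)
    return sched_cos(start_decay, final_decay, pos)

decay_begin = ema_decay_at(0.0)    # start of the warmup
decay_middle = ema_decay_at(0.15)  # halfway through the warmup
decay_after = ema_decay_at(0.5)    # warmup finished
```

A low initial decay lets the EMA weights track the rapidly changing early model, and the decay then rises smoothly to its final value by 30% of training.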
`EMASchedule` does not support resumed training while EMA scheduling is in progress. This is due to fastai not fully supporting resumable training. `EMASchedule` can resume training after the schedule period is finished.
`EMASchedule` supports logging to Weights & Biases and TensorBoard via the `LogDispatch` callback. If either the `fastai.callback.wandb.WandbCallback` or the `fastai.callback.tensorboard.TensorBoardCallback` is added to `Learner`, `EMASchedule` will automatically log the current EMA decay rate as `ema_decay`.
Footnotes
`EMACallback` performance was benchmarked on a GeForce 3080 Ti using PyTorch 1.13.1, CUDA 11.7, Mixed Precision, and Channels Last (except DeBERTa and ViT). Results may differ on other models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the EMA step.