Exponential Moving Average
EMACallback
and EMAWarmupCallback
have a fast fused implementation using the PyTorch ForEach methods that power PyTorch's _multi_tensor optimizers. The fused EMA step is 1 to 8.4 times faster¹ than a standard Python EMA step via a for loop.
| Model | For Loop Step | ForEach Step | Speedup |
|---|---|---|---|
| XResNet50 | 8.04µs | 2.37µs | 2.3x |
| XSE-ResNeXt50 | 10.7µs | 2.34µs | 3.5x |
| XResNet101 | 15.3µs | 4.27µs | 2.5x |
| ConvNeXt Tiny | 8.14µs | 963ns | 7.4x |
| ViT Patch16 Small | 7.37µs | 778ns | 8.4x |
| DeBERTa Base | 10.0µs | 4.79µs | 1.0x |
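The difference between the two approaches can be sketched as follows. This is a simplified illustration of a standard for loop EMA update versus a fused ForEach update, not fastxtend's exact implementation, and it assumes matching lists of EMA and model parameter tensors on the same device.

```python
import torch

def ema_step_for_loop(ema_params, model_params, decay=0.9998):
    # Standard EMA update, one parameter tensor at a time:
    # ema = decay * ema + (1 - decay) * model
    with torch.no_grad():
        for ema_p, model_p in zip(ema_params, model_params):
            ema_p.mul_(decay).add_(model_p, alpha=1 - decay)

def ema_step_foreach(ema_params, model_params, decay=0.9998):
    # Fused EMA update using PyTorch ForEach (_multi_tensor) ops, which
    # launch far fewer kernels than the per-tensor loop above
    with torch.no_grad():
        torch._foreach_mul_(ema_params, decay)
        torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
```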
EMACallback
EMACallback (decay:float=0.9998, start:Numeric=0, ema_device:torch.device|str|None=None, validate_ema:bool=True, replace_weights:bool=False, foreach:bool|None=None, resume:bool=False, all_parameters:bool=False, all_buffers:bool=False, skip_ema:bool=True)
Exponential Moving Average (EMA) of model weights with a fused update step
| | Type | Default | Details |
|---|---|---|---|
| decay | float | 0.9998 | EMA decay value |
| start | Numeric | 0 | Start EMA in percent of training steps (float) or epochs (int, index 0) |
| ema_device | torch.device \| str \| None | None | Device to store EMA weights. Defaults to model device |
| validate_ema | bool | True | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
| replace_weights | bool | False | Replace model weights with EMA weights when finished training. If false, sets Learner.model_ema to EMA weights |
| foreach | bool \| None | None | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and Cuda device detected |
| resume | bool | False | Resume from EMA weights from previous training saved to Learner.model_ema |
| all_parameters | bool | False | Apply EMA step to all parameters or only those with requires_grad |
| all_buffers | bool | False | Apply EMA step to persistent model buffers or all buffers |
| skip_ema | bool | True | Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step |
EMACallback
is inspired by ModelEmaV2
from PyTorch Image Models (timm), and should match the TensorFlow EMA implementation.
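A minimal usage sketch, assuming a vision task, an existing DataLoaders dls, and that the callback is importable from fastxtend.callback.ema (check your fastxtend version for the exact import path):

```python
from fastai.vision.all import *
from fastxtend.callback.ema import EMACallback  # import path assumed

# Maintain an EMA copy of the model weights while training
learn = vision_learner(dls, resnet50, metrics=accuracy,
                       cbs=EMACallback(decay=0.9998))
learn.fit_one_cycle(5, 3e-3)

# With replace_weights=False (the default), the trained weights stay in
# learn.model and the EMA weights are available as learn.model_ema
```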
The decay default of 0.9998 means each EMA model update by EMACallback will keep 99.98% of the prior EMA weights and move 0.02% towards the training model weights.
To prevent EMACallback from using GPU memory, set ema_device='cpu'. EMA validation will then need to be performed manually post-training or via a custom callback, as in the sketch below.
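Continuing the earlier sketch:

```python
# Store the EMA weights on the CPU so they use no GPU memory. Because the
# EMA copy is on a different device than the model, turn off validate_ema
# and compute EMA metrics after training instead.
learn = vision_learner(dls, resnet50, metrics=accuracy,
                       cbs=EMACallback(ema_device='cpu', validate_ema=False))
learn.fit_one_cycle(5, 3e-3)

# Post-training: move the EMA weights in learn.model_ema back to the model
# device and swap them into the model before calling learn.validate(). The
# exact swap depends on how fastxtend stores model_ema, so check the
# callback source before relying on this.
```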
To use the fused EMA step, set foreach=True, or set foreach=False to disable it. By default (foreach=None), EMACallback will automatically select the fused method if using PyTorch 1.12+ and a Cuda device is detected.
If all_parameters=False
, only parameters with requires_grad
are included in the EMA calculation.
If all_buffers=False
, only persistent buffers are included in the EMA calculation.
If skip_ema=True
(the default), then the EMA calculation will not apply if any other callback raises a CancelBatchException
or CancelStepException
. This is intended to handle the fastai.callback.fp16.MixedPrecision
AMP scaler and fastai.callback.training.GradientAccumulation
skipping the optimizer update step, which means the model weights won't have changed, so the EMA step should not be calculated. If needed, this behavior can be turned off, but in general this argument should be left unchanged.
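For example, combining EMACallback with fastai's mixed precision and gradient accumulation callbacks (continuing the earlier sketch):

```python
from fastai.callback.training import GradientAccumulation

# skip_ema=True (the default) keeps the EMA weights consistent when the AMP
# scaler or gradient accumulation skips an optimizer step: no weight update,
# no EMA update
learn = vision_learner(dls, resnet50, metrics=accuracy,
                       cbs=[EMACallback(), GradientAccumulation(n_acc=4)]
                       ).to_fp16()
learn.fit_one_cycle(5, 3e-3)
```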
EMAWarmupCallback
EMAWarmupCallback (start_decay:float=0.9, final_decay:float=0.9998, start:Numeric=0, finish:Numeric=0.3, schedule:Callable[...,_Annealer]=SchedCos, ema_device:torch.device|str|None=None, validate_ema:bool=True, replace_weights:bool=False, foreach:bool|None=None, resume:bool=False, all_parameters:bool=False, all_buffers:bool=False, skip_ema:bool=True, logger_callback:str='wandb')
Exponential Moving Average (EMA) of model weights with a warmup schedule and fused update step
| | Type | Default | Details |
|---|---|---|---|
| start_decay | float | 0.9 | Initial EMA decay value |
| final_decay | float | 0.9998 | Final EMA decay value |
| start | Numeric | 0 | Start EMA warmup in percent of training steps (float) or epochs (int, index 0) |
| finish | Numeric | 0.3 | Finish EMA warmup in percent of training steps (float) or epochs (int, index 0) |
| schedule | Callable[..., _Annealer] | SchedCos | EMA decay warmup schedule |
| ema_device | torch.device \| str \| None | None | Device to store EMA weights. Defaults to model device |
| validate_ema | bool | True | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
| replace_weights | bool | False | Replace model weights with EMA weights when finished training. If false, sets Learner.model_ema to EMA weights |
| foreach | bool \| None | None | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and Cuda device detected |
| resume | bool | False | Resume from EMA weights from previous training saved to Learner.model_ema |
| all_parameters | bool | False | Apply EMA step to all parameters or only those with requires_grad |
| all_buffers | bool | False | Apply EMA step to persistent model buffers or all buffers |
| skip_ema | bool | True | Skip EMA step if callbacks, such as GradientAccumulation or MixedPrecision, skip the Optimizer update step |
| logger_callback | str | wandb | Log EMA decay to logger_callback using Callback.name if available |
EMAWarmupCallback
extends EMACallback
by adding a schedulable EMA decay which warms up from an initial value of start_decay to final_decay, which is then used for the rest of training. The change in the EMA decay occurs between start and finish.
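A usage sketch, with the same assumptions as the EMACallback examples above:

```python
from fastxtend.callback.ema import EMAWarmupCallback  # import path assumed

# Warm the EMA decay from 0.9 to 0.9998 over the first 30% of training
# steps, then hold it at 0.9998 for the remainder of training
learn = vision_learner(dls, resnet50, metrics=accuracy,
                       cbs=EMAWarmupCallback(start_decay=0.9, final_decay=0.9998,
                                             start=0, finish=0.3))
learn.fit_one_cycle(20, 3e-3)
```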
The EMA warmup schedule
can be one of SchedCos
(the default), SchedLin
, SchedExp
, SchedPoly
, or a custom fastai annealer-based schedule. SchedPoly must be passed as a partial function: partial(SchedPoly, power=0.5).
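For example, to warm up the EMA decay with a square root polynomial schedule instead of the default cosine schedule:

```python
from functools import partial
from fastai.callback.schedule import SchedPoly

# SchedPoly takes a power argument, so it must be wrapped in a partial
ema_cb = EMAWarmupCallback(schedule=partial(SchedPoly, power=0.5))
```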
EMAWarmupCallback
does not support resumed training while EMA warmup is in progress. This is due to fastai not fully supporting resumable training.
EMAWarmupCallback
can resume training after the warmup period has finished.
Weights & Biases Logging
If Weights & Biases is installed and the WandbCallback
is added to Learner
, EMAWarmupCallback
will automatically log the current EMA decay rate to Weights & Biases as ema_decay
.
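A sketch of the combination, where the project name and training setup are placeholders:

```python
import wandb
from fastai.callback.wandb import WandbCallback

wandb.init(project='my-project')  # placeholder project name

# With both callbacks on the Learner, the scheduled EMA decay is logged to
# Weights & Biases under the 'ema_decay' key during training
learn = vision_learner(dls, resnet50, metrics=accuracy,
                       cbs=[WandbCallback(), EMAWarmupCallback()])
learn.fit_one_cycle(20, 3e-3)
```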
Extend to other Loggers
To extend to new loggers, follow the Weights & Biases code below and create a patch for EMAWarmupCallback which adds a _{Callback.name}_log_ema_decay method, where Callback.name is the name of the logger callback.

```python
try:
    import wandb

    @patch
    def _wandb_log_ema_decay(self:EMAWarmupCallback, decay:float):
        wandb.log({'ema_decay': decay}, self.learn.wandb._wandb_step+1)
except:
    pass
```

Then to use it, pass logger_callback='{Callback.name}' to EMAWarmupCallback.
EMAWarmupCallback
sets its _log_ema_decay
method to f'_{self.logger_callback}_log_ema_decay'
, which should match the patched method.
self._log_ema_decay = getattr(self, f'_{self.logger_callback}_log_ema_decay', noop)
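For example, a hypothetical patch for fastai's TensorBoardCallback might look like the following. The callback name 'tensor_board', the self.learn.tensor_board.writer attribute, and the use of the learner's train_iter counter as the logging step are assumptions, so verify them against your installed fastai version:

```python
from fastcore.basics import patch
from fastxtend.callback.ema import EMAWarmupCallback  # import path assumed

@patch
def _tensor_board_log_ema_decay(self:EMAWarmupCallback, decay:float):
    # Assumes TensorBoardCallback is attached to the Learner as
    # learn.tensor_board and exposes a SummaryWriter at .writer
    self.learn.tensor_board.writer.add_scalar('ema_decay', decay,
                                              self.learn.train_iter)

# Then select the patched logger by name:
# EMAWarmupCallback(logger_callback='tensor_board')
```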
Footnotes
1. EMACallback performance was benchmarked on a GeForce 3080 Ti using PyTorch 1.13.1, Cuda 11.7, Mixed Precision, and Channels Last (except DeBERTa and ViT). Results may differ on other models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the EMA step.