Exponential Moving Average

Exponential Moving Average (EMA) of model weights with a fused update step

EMACallback and EMAWarmupCallback have a fast fused implementation using PyTorch ForEach methods from PyTorch's _multi_tensor optimizers. The fused EMA step is 1 to 18 times faster¹ than a standard Python EMA step via a for loop.

Table 1: For Loop EMA Step vs Fused ForEach EMA Step

| Model             | For Loop Step | ForEach Step | Speedup |
|-------------------|---------------|--------------|---------|
| XResNet50         | 12.07ms       | 678.3µs      | 18x     |
| XSE-ResNext50     | 11.25ms       | 613.4µs      | 15x     |
| XResNet101        | 13.96ms       | 1.193ms      | 10x     |
| ConvNext Tiny     | 8.244ms       | 764.0µs      | 9x      |
| ViT Patch16 Small | 10.22ms       | 650.8µs      | 14x     |
| DeBERTa Base      | 9.646ms       | 4.630ms      | 1x      |
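
To illustrate where the speedup comes from, here is a minimal sketch of the two update styles. It is not fastxtend's exact implementation; ema_params and model_params stand for matching lists of parameter tensors.

import torch

@torch.no_grad()
def ema_step_for_loop(ema_params, model_params, decay=0.9998):
    # Standard Python EMA step: a pair of small kernel launches per parameter tensor
    for ema_p, model_p in zip(ema_params, model_params):
        ema_p.mul_(decay).add_(model_p, alpha=1-decay)

@torch.no_grad()
def ema_step_foreach(ema_params, model_params, decay=0.9998):
    # Fused EMA step: two multi-tensor ForEach calls cover every parameter at once
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, model_params, alpha=1-decay)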

source

EMACallback

 EMACallback (decay:float=0.9998, start_epoch:Number=0,
              ema_device:torch.device|str|None=None,
              validate_ema:bool=True, replace_weights:bool=False,
              foreach:bool|None=None, resume:bool=False)

Exponential Moving Average (EMA) of model weights with a fused update step

|                 | Type                         | Default | Details |
|-----------------|------------------------------|---------|---------|
| decay           | float                        | 0.9998  | EMA decay value |
| start_epoch     | Number                       | 0       | Epoch to start EMA in percent of training steps (float) or epochs (int, index 0) |
| ema_device      | torch.device \| str \| None  | None    | Device to store EMA weights. Defaults to model device |
| validate_ema    | bool                         | True    | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
| replace_weights | bool                         | False   | Replace model weights with EMA weights when finished training. If false, set Learner.model_ema to EMA weights |
| foreach         | bool \| None                 | None    | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and Cuda device detected |
| resume          | bool                         | False   | Resume from EMA weights from previous training saved to Learner.model_ema |

EMACallback is inspired by ModelEmaV2 from PyTorch Image Models (timm) and should match the TensorFlow EMA implementation.

The decay default of 0.9998 means that on each EMA update, EMACallback will keep 99.98% of the prior EMA weights and shift them 0.02% toward the current training model weights.
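
As a quick worked example of that update, with illustrative numbers only:

decay = 0.9998
ema_weight, model_weight = 0.50, 1.00
# Weighted average: keep 99.98% of the prior EMA value, move 0.02% toward the model value
ema_weight = decay * ema_weight + (1 - decay) * model_weight
print(ema_weight)  # ≈ 0.5001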

To prevent EMACallback from using GPU memory, set ema_device='cpu'. EMA validation will then need to be performed manually post-training or via a custom callback.
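
For example, a sketch of a CPU-offloaded EMA (the import path is an assumption; validate_ema is disabled because the EMA weights no longer live on the model device):

from fastxtend.callback.ema import EMACallback  # import path is an assumption

# Store EMA weights on the CPU to save GPU memory; validate_ema must be False
# because the EMA weights no longer match the model device
cb = EMACallback(ema_device='cpu', validate_ema=False)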

To force the fused EMA step, set foreach=True, or set foreach=False to disable it. By default, EMACallback automatically selects the fused method when running on PyTorch 1.12+ with a CUDA device detected.
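
Putting it together, a typical training setup might look like the following sketch. Here dls and model are placeholders for an existing DataLoaders and model, and the fastxtend import path is an assumption:

from fastai.vision.all import *
from fastxtend.callback.ema import EMACallback  # import path is an assumption

# dls and model are assumed to be an existing fastai DataLoaders and nn.Module
learn = Learner(dls, model, metrics=accuracy,
                cbs=EMACallback(decay=0.9998))
learn.fit_one_cycle(5, 3e-3)

# Unless replace_weights=True, the EMA weights are available on learn.model_ema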


source

EMAWarmupCallback

 EMAWarmupCallback (start_decay:float=0.9, final_decay:float=0.9998,
                    start_epoch:Number=0, final_epoch:Number=0.3,
                    schedule:Callable[...,_Annealer]=<function SchedCos>,
                    ema_device:torch.device|str|None=None,
                    validate_ema:bool=True, replace_weights:bool=False,
                    foreach:bool|None=None, resume:bool=False,
                    logger_callback:str='wandb')

Exponential Moving Average (EMA) of model weights with a warmup schedule and fused update step

|                 | Type                         | Default  | Details |
|-----------------|------------------------------|----------|---------|
| start_decay     | float                        | 0.9      | Initial EMA decay value |
| final_decay     | float                        | 0.9998   | Final EMA decay value |
| start_epoch     | Number                       | 0        | Epoch to start EMA warmup in percent of training steps (float) or epochs (int, index 0) |
| final_epoch     | Number                       | 0.3      | Epoch to finish EMA warmup in percent of training steps (float) or epochs (int, index 0) |
| schedule        | Callable[…, _Annealer]       | SchedCos | EMA decay warmup schedule |
| ema_device      | torch.device \| str \| None  | None     | Device to store EMA weights. Defaults to model device |
| validate_ema    | bool                         | True     | Run validation metrics using EMA weights instead of model weights. If true, ema_device must match model device |
| replace_weights | bool                         | False    | Replace model weights with EMA weights when finished training. If false, set Learner.model_ema to EMA weights |
| foreach         | bool \| None                 | None     | Fuse EMA update step with PyTorch ForEach methods or use a standard for loop. Defaults to true if PyTorch 1.12+ and Cuda device detected |
| resume          | bool                         | False    | Resume from EMA weights from previous training saved to Learner.model_ema |
| logger_callback | str                          | wandb    | Log EMA decay to logger_callback using Callback.name if available |

EMAWarmupCallback extends EMACallback by adding a schedulable EMA decay, which warms up from an initial value of start_decay to final_decay and then stays at final_decay for the rest of training. The change in EMA decay occurs between start_epoch and final_epoch.

The EMA warmup schedule can be one of SchedCos (the default), SchedLin, SchedExp, SchedPoly, or a custom fastai annealer-based schedule. SchedPoly must be passed as a partial function: partial(SchedPoly, power=0.5).
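
For example, to warm up the decay with a square-root polynomial schedule (a sketch; the fastxtend import path is an assumption):

from functools import partial
from fastai.callback.schedule import SchedPoly
from fastxtend.callback.ema import EMAWarmupCallback  # import path is an assumption

# Warm the EMA decay from 0.9 to 0.9998 over the first 30% of training
cb = EMAWarmupCallback(start_decay=0.9, final_decay=0.9998,
                       start_epoch=0, final_epoch=0.3,
                       schedule=partial(SchedPoly, power=0.5))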

EMA Warmup Wandb Logging

# Patch Weights & Biases logging onto EMAWarmupCallback, but only if wandb is installed
try:
    import wandb

    @patch
    def _wandb_log_ema_decay(self:EMAWarmupCallback, decay:float):
        # Log the current EMA decay using WandbCallback's step counter
        wandb.log({'ema_decay': decay}, self.learn.wandb._wandb_step+1)
except ImportError:
    pass

Extend to other Loggers

To extend to new loggers, follow the Weights & Biases code above and create a patch for EMAWarmupCallback which adds a _{Callback.name}_log_ema_decay method, where Callback.name is the name of the logger callback.

Then to use, pass logger_callback='{Callback.name}' to EMAWarmupCallback.

EMAWarmupCallback resolves its _log_ema_decay method from f'_{self.logger_callback}_log_ema_decay', so the patched method name must match.

self._log_ema_decay = getattr(self, f'_{self.logger_callback}_log_ema_decay', noop)
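
For example, here is a hypothetical patch for a logger callback whose Callback.name is 'my_logger' and which stores its client on the Learner as learn.my_logger. Both names and the log_scalar call are made up for illustration; the fastxtend import path is an assumption.

from fastcore.basics import patch
from fastxtend.callback.ema import EMAWarmupCallback  # import path is an assumption

@patch
def _my_logger_log_ema_decay(self:EMAWarmupCallback, decay:float):
    # 'log_scalar' is an illustrative method on the hypothetical logger client
    self.learn.my_logger.log_scalar('ema_decay', decay)

Passing logger_callback='my_logger' to EMAWarmupCallback would then route EMA decay logging through this patch.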

Footnotes

  1. EMACallback performance was benchmarked on a GeForce 3080 Ti using PyTorch 1.12.1, Mixed Precision, and Channels Last (except DeBERTa). Results may differ on other models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the EMA step and rounded down to the nearest whole number.↩︎