Adan: ADAptive Nesterov Momentum Optimizer

With fastai native, fused ForEach, and fused TorchScript implementations

Adan was introduced by Xie et al in Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. Adan uses an efficient Nesterov momentum estimation method to avoid the extra computation and memory overhead of calculating the gradient at the extrapolation point.

Nadam also estimates Nesterov momentum, but in contrast it only estimates the first-order gradient moment, while Adan estimates both the first- and second-order moments.

For consistency with other fastai optimizers, the coefficients beta1, beta2, and beta3 are inverted from the paper values (β_fastai = 1 − β_paper), e.g. β1=0.98 instead of β1=0.02.

Note

This implementation of Adan does not contain the restart condition, as it is mostly unused in the paper.

In addition to a fastai native implementation, Adan has fused ForEach and TorchScript implementations. See the Fused Optimizer documentation for more details.


source

Adan

 Adan (params:Union[torch.Tensor,Iterable[torch.Tensor],
       fastcore.foundation.L,fastcore.basics.fastuple], lr:float,
       beta1:float=0.98, beta2:float=0.92, beta3:float=0.99, eps:float=1e-08,
       wd:float=0.02, paper_init:bool=False, foreach:bool=False,
       jit:bool=False)

A fastai Adan optimizer with optional ForEach and TorchScript implementations

|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters or parameter groups |
| lr | float |  | Default learning rate |
| beta1 | float | 0.98 | Gradient moving average (β1) coefficient |
| beta2 | float | 0.92 | Gradient difference moving average (β2) coefficient |
| beta3 | float | 0.99 | Gradient squared moving average (β3) coefficient |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.02 | True weight decay |
| paper_init | bool | False | Initialize prior gradient with current gradient per paper, or zeroes |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| AdanForEachOptimizer \| JitOptimizer |  |  |
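
For illustration, the snippet below sketches constructing Adan directly from a model's parameters and taking a single optimization step. The toy model, learning rate, and import path are assumptions for the example; in typical fastai training code the adan convenience function documented below is used instead.

```python
import torch
import torch.nn as nn
from fastxtend.optimizer.all import Adan  # import path is an assumption

# toy model and batch, purely for illustration
model = nn.Linear(10, 2)
xb = torch.randn(8, 10)

# lr is required; the remaining arguments fall back to the defaults listed above
opt = Adan(model.parameters(), lr=1e-2)

loss = model(xb).pow(2).mean()
loss.backward()
opt.step()       # fastai-style optimizer step
opt.zero_grad()  # clear gradients before the next batch
```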

source

adan

 adan (beta1:float=0.98, beta2:float=0.92, beta3:float=0.99,
       eps:float=1e-08, wd:float=0.02, paper_init:bool=False,
       foreach:bool=False, jit:bool=False)

Partial function for the Adan optimizer with fused ForEach and TorchScript implementations

|  | Type | Default | Details |
|---|---|---|---|
| beta1 | float | 0.98 | Gradient moving average (β1) coefficient |
| beta2 | float | 0.92 | Gradient difference moving average (β2) coefficient |
| beta3 | float | 0.99 | Gradient squared moving average (β3) coefficient |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.02 | True weight decay |
| paper_init | bool | False | Initialize prior gradient with current gradient per paper, or zeroes |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| AdanForEachOptimizer \| JitOptimizer |  |  |
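
As a usage sketch, adan can be passed to a fastai Learner as opt_func, here selecting the fused ForEach implementation. The dataset, model, and import path are assumptions for the example, not requirements of the optimizer.

```python
from fastai.vision.all import *
from fastxtend.optimizer.all import *  # import path is an assumption

# small example dataset and model, purely for illustration
path = untar_data(URLs.MNIST_SAMPLE)
dls = ImageDataLoaders.from_folder(path)

# `adan(...)` returns a partial that the Learner calls to build the optimizer
learn = vision_learner(dls, resnet18, metrics=accuracy,
                       opt_func=adan(foreach=True))
learn.fit_one_cycle(1, lr_max=8e-3)
```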

source

AdanLargeBatchLR

 AdanLargeBatchLR (bs:int)

Square root rule for scaling Adan learning rate for large-batch training
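
Below is a minimal sketch of the square-root scaling idea behind this helper, assuming a reference batch size of 256 and a base learning rate of 1e-2. The exact constants baked into AdanLargeBatchLR may differ, so treat this as an illustration of the rule rather than the implementation.

```python
import math

def sqrt_scaled_lr(bs: int, base_lr: float = 1e-2, base_bs: int = 256) -> float:
    "Scale `base_lr` by the square root of the batch size ratio (illustrative constants)"
    return base_lr * math.sqrt(bs / base_bs)

sqrt_scaled_lr(1024)  # quadrupling the batch size doubles the learning rate: 2e-2
```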

Hyperparameters

Hyperparameter notes from Xie et al:

  1. beta2 is the least sensitive Adan hyperparameter; the default of 0.92 works for the majority of tasks
  2. Xie et al primarily tune beta3 (between 0.9-0.999) before beta1 (between 0.9-0.98) for different tasks
  3. Adan pairs well with large learning rates. The paper and GitHub repository report learning rates up to 3x larger than Lamb and up to 5-10x larger than Adam
  4. Xie et al use the default weight decay of 0.02 for all tasks except fine-tuning BERT (wd=0.01) and reinforcement learning (wd=0); see the sketch after this list
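
Building on these notes, a hedged sketch of passing tuned hyperparameters through the adan partial. The beta values below are illustrative points inside the ranges above, not recommendations from the paper, while wd=0.01 mirrors the paper's fine-tuning setting.

```python
# tune beta3 first (0.9-0.999), then beta1 (0.9-0.98); values here are illustrative
# wd=0.01 mirrors the paper's fine-tuning setting
opt_func = adan(beta1=0.92, beta3=0.95, wd=0.01)
```

Pass the resulting opt_func to a Learner as in the earlier example.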
Note

With paper_init=True, fastxtend’s Adan matches Xie et al’s Adan implementation.

Training Speed

Important

ForEach and TorchScript optimizers have only been tested on PyTorch 1.12 and are not guaranteed to work on older versions.

One critique of Adan is that, when using a standard PyTorch implementation, Adan is significantly slower than AdamW: between 41 and 97 percent slower on the models benchmarked in Table 2 (a) below.

As shown in Table 1, fastxtend’s fused ForEach Adan is 95 to 284 percent faster1 than a standard PyTorch implementation.

Table 1: Increase in Adan opt_step Speed vs Native Optimizer

| Model | Native Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
|---|---|---|---|---|---|
| XResNet18 | 36ms | 17ms | 112% | 23ms | 60% |
| XResNet50 | 78ms | 41ms | 95% | 50ms | 59% |
| XSE-ResNext50 | 109ms | 52ms | 108% | 74ms | 48% |
| XResNet101 | 131ms | 59ms | 120% | 75ms | 75% |
| DeBERTa Base | 53ms | 13ms | 284% | 22ms | 137% |

While the Adan ForEach implementation is still 17 to 82 percent slower than the AdamW ForEach implementation, the gap is significantly smaller both in absolute terms and as a percentage of total training time: as shown in Table 2, Adan ForEach is 6ms to 11ms slower than AdamW ForEach, compared to 11ms to 43ms for native Adan versus native AdamW.

Table 2: AdamW vs Adan Training Speed

(a) Native Implementation

| Model | AdamW Step | Adan Step | Slowdown |
|---|---|---|---|
| XResNet18 | 25ms | 36ms | 44% |
| XResNet50 | 55ms | 78ms | 41% |
| XSE-ResNext50 | 72ms | 109ms | 52% |
| XResNet101 | 88ms | 131ms | 51% |
| DeBERTa Base | 27ms | 53ms | 97% |

(b) Fused ForEach Implementation

| Model | AdamW Step | Adan Step | Slowdown |
|---|---|---|---|
| XResNet18 | 13ms | 17ms | 38% |
| XResNet50 | 31ms | 41ms | 28% |
| XSE-ResNext50 | 43ms | 52ms | 17% |
| XResNet101 | 48ms | 59ms | 27% |
| DeBERTa Base | 6.9ms | 13ms | 82% |

Footnotes

  1. Benchmarked on a GeForce 3080 Ti using PyTorch 1.12.1, Mixed Precision, Channels Last (except DeBERTa), and fastxtend’s Simple Profiler Callback. Results may differ on other models, hardware, and across benchmarking runs. Speedup and slowdown are calculated from the total time spent on the optimization step.