Adan: ADAptive Nesterov Momentum Optimizer
Adan was introduced by Xie et al. in *Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models*. Adan uses an efficient Nesterov momentum estimation method to avoid the extra computation and memory overhead of calculating the gradient at the extrapolation point.
Nadam also estimates Nesterov momentum, but in contrast to Adan it only estimates the first-order gradient moment, while Adan estimates both the first- and second-order moments.
For consistency with other fastai optimizers, the coefficients `beta1`, `beta2`, and `beta3` are inverted from the paper's values (i.e. one minus the paper value), e.g. β1=0.98 instead of β1=0.02.
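For reference, a sketch of Adan's update in the paper's notation (each β below therefore corresponds to one minus the fastxtend default above); see Xie et al. for the exact algorithm, including the restart condition and implementation details such as bias correction and the precise placement of ε:

$$
\begin{aligned}
m_k &= (1-\beta_1)\,m_{k-1} + \beta_1\, g_k \\
v_k &= (1-\beta_2)\,v_{k-1} + \beta_2\,(g_k - g_{k-1}) \\
n_k &= (1-\beta_3)\,n_{k-1} + \beta_3\,\big[g_k + (1-\beta_2)(g_k - g_{k-1})\big]^2 \\
\eta_k &= \eta \,/\, \big(\sqrt{n_k} + \epsilon\big) \\
\theta_{k+1} &= (1 + \lambda\eta)^{-1}\big[\theta_k - \eta_k \circ \big(m_k + (1-\beta_2)\, v_k\big)\big]
\end{aligned}
$$

where $g_k$ is the gradient at step $k$, $\eta$ the learning rate, and $\lambda$ the weight decay.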
In addition to a fastai native implementation, Adan has fused ForEach and TorchScript implementations. See the Fused Optimizer documentation for more details.
Adan
Adan (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, beta1:float=0.98, beta2:float=0.92, beta3:float=0.99, eps:float=1e-08, wd:float=0.02, paper_init:bool=False, foreach:bool=False, jit:bool=False)
A fastai Adan optimizer with optional ForEach and TorchScript implementations
|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters or parameter groups |
| lr | float |  | Default learning rate |
| beta1 | float | 0.98 | Gradient moving average (β1) coefficient |
| beta2 | float | 0.92 | Gradient difference moving average (β2) coefficient |
| beta3 | float | 0.99 | Gradient squared moving average (β3) coefficient |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.02 | True weight decay |
| paper_init | bool | False | Initialize prior gradient with current gradient per paper, or zeroes |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| AdanForEachOptimizer \| JitOptimizer |  |  |
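A minimal usage sketch, assuming `model` is a standard `torch.nn.Module` defined elsewhere; the `foreach` and `jit` flags select which fused step implementation, if any, is used:

```python
# minimal sketch: construct the Adan optimizer directly on a model's parameters
# lr=3e-3 is an illustrative value, not a recommendation
opt = Adan(model.parameters(), lr=3e-3, wd=0.02, foreach=True)
```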
adan
adan (beta1:float=0.98, beta2:float=0.92, beta3:float=0.99, eps:float=1e-08, wd:float=0.02, paper_init:bool=False, foreach:bool=False, jit:bool=False)
Partial function for the Adan optimizer with fused ForEach and TorchScript implementations
|  | Type | Default | Details |
|---|---|---|---|
| beta1 | float | 0.98 | Gradient moving average (β1) coefficient |
| beta2 | float | 0.92 | Gradient difference moving average (β2) coefficient |
| beta3 | float | 0.99 | Gradient squared moving average (β3) coefficient |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.02 | True weight decay |
| paper_init | bool | False | Initialize prior gradient with current gradient per paper, or zeroes |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| AdanForEachOptimizer \| JitOptimizer |  |  |
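A minimal sketch of passing the `adan` partial to a fastai `Learner` as `opt_func`; `dls` and `model` are assumed to come from a standard fastai setup:

```python
# minimal sketch: use adan as the Learner's opt_func
# `dls` (DataLoaders) and `model` are assumed to be defined elsewhere
learn = Learner(dls, model, opt_func=adan(foreach=True))
learn.fit_one_cycle(5, lr_max=8e-3)  # lr_max is an illustrative value
```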
AdanLargeBatchLR
AdanLargeBatchLR (bs:int)
Square root rule for scaling the Adan learning rate for large-batch training
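A hedged usage sketch: compute a scaled learning rate from the batch size and pass it to training. The batch size of 2048 is an illustrative value, and `learn` is assumed to be set up with `adan` as in the example above:

```python
# hedged sketch: square root scaling of the Adan learning rate for large batches
lr = AdanLargeBatchLR(bs=2048)      # bs=2048 is an illustrative batch size
learn.fit_one_cycle(5, lr_max=lr)   # `learn` configured with adan as shown above
```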
Hyperparameters
Hyperparameter notes from Xie et al.:

- `beta2` is the least sensitive Adan hyperparameter; the default of 0.92 works for the majority of tasks
- Xie et al. primarily tune `beta3` (between 0.9 and 0.999) before `beta1` (between 0.9 and 0.98) for different tasks (see the sketch after this list)
- Adan pairs well with large learning rates. The paper and GitHub repository report learning rates up to 3x larger than `Lamb` and up to 5-10x larger than `Adam`
- Xie et al. use the default weight decay of 0.02 for all tasks except fine-tuning BERT (`wd=0.01`) and reinforcement learning (`wd=0`)
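As an example of the tuning order above, a minimal sketch that sweeps `beta3` before `beta1` via the `adan` partial (the specific values are illustrative, not recommendations):

```python
# minimal sketch: tune beta3 first (0.9-0.999), keeping the other defaults
opt_func = adan(beta3=0.95, foreach=True)
# then, with beta3 fixed, sweep beta1 (0.9-0.98)
opt_func = adan(beta1=0.96, beta3=0.95, foreach=True)
```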
Training Speed
One critique of Adan is that a standard PyTorch implementation is significantly slower than AdamW: between 41 and 97 percent slower on the models benchmarked in Table 2 (a) below.

As shown in Table 1, fastxtend’s fused ForEach Adan is 95 to 284 percent faster¹ than a standard PyTorch implementation.
Table 1: Adan optimizer step time, fastxtend native vs. fused implementations

| Model | Native Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
|---|---|---|---|---|---|
| XResNet18 | 36ms | 17ms | 112% | 23ms | 60% |
| XResNet50 | 78ms | 41ms | 95% | 50ms | 59% |
| XSE-ResNext50 | 109ms | 52ms | 108% | 74ms | 48% |
| XResNet101 | 131ms | 59ms | 120% | 75ms | 75% |
| DeBERTa Base | 53ms | 13ms | 284% | 22ms | 137% |
While the Adan ForEach implementation is still 17 to 82 percent slower than the AdamW ForEach implementation, the difference, both in absolute terms and as a percentage of total training time, is significantly smaller: Adan ForEach is 6ms to 11ms slower than AdamW ForEach, compared with 11ms to 43ms for native Adan, as shown in Table 2.
Table 2 (a): AdamW vs. Adan step time, native implementations

| Model | AdamW Step | Adan Step | Slowdown |
|---|---|---|---|
| XResNet18 | 25ms | 36ms | 44% |
| XResNet50 | 55ms | 78ms | 41% |
| XSE-ResNext50 | 72ms | 109ms | 52% |
| XResNet101 | 88ms | 131ms | 51% |
| DeBERTa Base | 27ms | 53ms | 97% |
Table 2 (b): AdamW vs. Adan step time, ForEach implementations

| Model | AdamW Step | Adan Step | Slowdown |
|---|---|---|---|
| XResNet18 | 13ms | 17ms | 38% |
| XResNet50 | 31ms | 41ms | 28% |
| XSE-ResNext50 | 43ms | 52ms | 17% |
| XResNet101 | 48ms | 59ms | 27% |
| DeBERTa Base | 6.9ms | 13ms | 82% |
Footnotes
1. Benchmarked on a GeForce 3080 Ti using PyTorch 1.12.1, Mixed Precision, Channels Last (except DeBERTa), and fastxtend’s Simple Profiler Callback. Results may differ on other models, hardware, and across benchmarking runs. Speedup and slowdown are calculated from the total time spent on the optimization step.