Adan: ADAptive Nesterov Momentum Optimizer
Adan was introduced by Xie et al in Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models. Adan uses an efficient Nesterov momentum estimation method to avoid the extra computation and memory overhead of calculating the gradient at the extrapolation point.
Nadam also estimates Nesterov momentum, but in contrast it only estimates the first-order gradient moment, while Adan estimates both the first- and second-order moments.
For consistency with other fastai optimizers, the coefficients beta1, beta2, and beta3 have been inverted from the paper values, e.g. β1=0.98 instead of β1=0.02.
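The relationship between the two conventions can be sketched as follows. This is an illustrative snippet, not part of fastxtend: the helper name `paper_to_fastai_beta` is hypothetical, and the paper-style defaults listed here are assumed to be the ones from the original Adan implementation.

```python
def paper_to_fastai_beta(beta_paper: float) -> float:
    """Convert a paper-style coefficient (weight on the newest value)
    to a fastai-style coefficient (weight on the running average).
    Hypothetical helper for illustration only."""
    return 1.0 - beta_paper

# Assumed paper-style defaults for (beta1, beta2, beta3)
paper_betas = {"beta1": 0.02, "beta2": 0.08, "beta3": 0.01}

# Rounded to two decimals to hide floating point noise
fastai_betas = {k: round(paper_to_fastai_beta(v), 2) for k, v in paper_betas.items()}
print(fastai_betas)
```

The converted values should line up with the defaults listed in the parameter tables on this page (0.98, 0.92, and 0.99).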
This implementation of Adan does not contain the restart condition, as it is mostly unused in the paper.
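To make the update concrete, here is a minimal sketch of a single Adan step for one scalar parameter, without the restart condition. It is written from the paper's update rule using the fastai-style coefficients (values close to 1), and is not fastxtend's actual code. The function name `adan_step` is hypothetical, bias correction is left out for brevity, and the weight decay form (dividing by `1 + lr*wd`) follows the paper and is an assumption about how the decay is applied.

```python
import math

def adan_step(p, g, g_prev, m, v, n, lr,
              beta1=0.98, beta2=0.92, beta3=0.99, eps=1e-8, wd=0.02):
    """One Adan update for a single scalar parameter (illustrative sketch)."""
    diff = g - g_prev                        # gradient difference
    m = beta1 * m + (1 - beta1) * g          # gradient moving average
    v = beta2 * v + (1 - beta2) * diff       # gradient difference moving average
    u = g + beta2 * diff                     # gradient plus scaled difference
    n = beta3 * n + (1 - beta3) * u * u      # squared moving average
    step = lr / (math.sqrt(n) + eps)         # per-parameter step size
    # Update, then apply decoupled weight decay (paper form, assumed here)
    p = (p - step * (m + beta2 * v)) / (1 + lr * wd)
    return p, m, v, n

# First step with the prior gradient set to the current gradient,
# so the initial gradient difference is zero
p, m, v, n = 1.0, 0.0, 0.0, 0.0
g = 0.5
p, m, v, n = adan_step(p, g, g_prev=g, m=m, v=v, n=n, lr=0.1)
print(p, m, v, n)
```

Passing `g_prev=g` on the first step corresponds to initializing the prior gradient with the current gradient, while `g_prev=0.0` corresponds to initializing it with zeroes.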
In addition to a fastai native implementation, Adan has fused ForEach and TorchScript implementations. See the Fused Optimizer documentation for more details.
Adan
Adan (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, beta1:float=0.98, beta2:float=0.92, beta3:float=0.99, eps:float=1e-08, wd:float=0.02, paper_init:bool=False, foreach:bool=False, jit:bool=False)
A fastai Adan optimizer with optional ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
params | Listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
beta1 | float | 0.98 | Gradient moving average (β1) coefficient |
beta2 | float | 0.92 | Gradient difference moving average (β2) coefficient |
beta3 | float | 0.99 | Gradient squared moving average (β3) coefficient |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.02 | True weight decay |
paper_init | bool | False | Initialize prior gradient with current gradient per paper, or zeroes |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer \| AdanForEachOptimizer \| JitOptimizer | | |
adan
adan (beta1:float=0.98, beta2:float=0.92, beta3:float=0.99, eps:float=1e-08, wd:float=0.02, paper_init:bool=False, foreach:bool=False, jit:bool=False)
Partial function for the Adan optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
beta1 | float | 0.98 | Gradient moving average (β1) coefficient |
beta2 | float | 0.92 | Gradient difference moving average (β2) coefficient |
beta3 | float | 0.99 | Gradient squared moving average (β3) coefficient |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.02 | True weight decay |
paper_init | bool | False | Initialize prior gradient with current gradient per paper, or zeroes |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer \| AdanForEachOptimizer \| JitOptimizer | | |
AdanLargeBatchLR
AdanLargeBatchLR (bs:int)
Square root rule for scaling Adan learning rate for large-batch training
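A square root scaling rule can be sketched as below. This is a generic illustration rather than the body of AdanLargeBatchLR: the function name `sqrt_scaled_lr` is hypothetical, and the reference learning rate (`base_lr=6.25e-3`) and reference batch size (`base_bs=256`) are assumptions chosen for the example, so check the fastxtend source for the constants it actually uses.

```python
import math

def sqrt_scaled_lr(bs: int, base_lr: float = 6.25e-3, base_bs: int = 256) -> float:
    """Scale a reference learning rate by the square root of the batch size ratio.
    base_lr and base_bs are assumed reference values, not confirmed constants."""
    return base_lr * math.sqrt(bs / base_bs)

# Quadrupling the batch size doubles the learning rate under this rule
for bs in (256, 1024, 4096):
    print(bs, sqrt_scaled_lr(bs))
```

Under a square root rule the learning rate grows more slowly than the batch size, in contrast to linear scaling, where it would grow in proportion.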
Hyperparameters
Hyperparameter notes from Xie et al:
- beta2 is the least sensitive Adan hyperparameter; the default of 0.92 works for the majority of tasks
- Xie et al primarily tune beta3 (between 0.9-0.999) before beta1 (between 0.9-0.98) for different tasks
- Adan pairs well with large learning rates. Paper and GitHub report up to 3x larger than Lamb and up to 5-10x larger than AdamW
- Xie et al use the default weight decay of 0.02 for all tasks except fine-tuning BERT (wd=0.01) and reinforcement learning (wd=0)
With paper_init=True, fastxtend’s Adan matches Xie et al’s Adan implementation.
Training Speed
ForEach and TorchScript optimizers have only been tested on PyTorch 1.12+ and are not guaranteed to work on older versions.
One critique of Adan is that the original PyTorch implementation was significantly slower than AdamW: between 41 and 97 percent slower on tested models. However, Xie et al’s implementation has since been refactored to decrease memory usage and CUDA operations. These improvements have been ported to the fastxtend versions. This improved Adan implementation is benchmarked in Table 2 below.
As shown in Table 1, fastxtend’s fused ForEach Adan is 36 to 401 percent faster¹ than a standard PyTorch implementation.
Table 1: Adan optimizer step time, native vs fused ForEach and JIT implementations

Model | Layers | Native Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
---|---|---|---|---|---|---|
XResNet18 | 1 | 31ms | 13ms | 150% | 27ms | 18% |
XResNet50 | 1 | 69ms | 33ms | 108% | 56ms | 24% |
XSE-ResNeXt50 | 1 | 94ms | 45ms | 110% | 75ms | 25% |
XResNet101 | 1 | 115ms | 46ms | 148% | 89ms | 29% |
ConvNeXt Tiny | 2 | 139ms | 102ms | 36% | 124ms | 11% |
ConvNeXt Small | 2 | 225ms | 162ms | 39% | 198ms | 13% |
ViT Patch16 Small | 2 | 76ms | 46ms | 65% | 62ms | 21% |
DeBERTa Base | 1 | 42ms | 8.4ms | 401% | 28ms | 46% |
With the fused implementation, Adan ForEach steps are only 1.6 to 17 percent slower than AdamW ForEach steps, and the gap is significantly smaller both in absolute time and as a percentage of total training time. An Adan ForEach step is 0.6ms to 4ms slower than an AdamW ForEach step across measured models, compared with 6ms to 29ms for native Adan as shown in Table 2.
Table 2: AdamW vs Adan native optimizer step time

Model | AdamW Step | Adan Step | Slowdown |
---|---|---|---|
XResNet18 | 25ms | 31ms | 27% |
XResNet50 | 53ms | 69ms | 28% |
XSE-ResNeXt50 | 71ms | 94ms | 35% |
XResNet101 | 85ms | 115ms | 33% |
ConvNeXt Tiny | 124ms | 139ms | 12% |
ConvNeXt Small | 196ms | 225ms | 15% |
ViT Patch16 Small | 63ms | 76ms | 20% |
DeBERTa Base | 26ms | 42ms | 62% |
Table 3: AdamW vs Adan ForEach optimizer step time

Model | AdamW ForEach Step | Adan ForEach Step | Slowdown |
---|---|---|---|
XResNet18 | 13ms | 13ms | 5.0% |
XResNet50 | 31ms | 33ms | 2.3% |
XSE-ResNeXt50 | 43ms | 45ms | 5.9% |
XResNet101 | 43ms | 46ms | 8.4% |
ConvNeXt Tiny | 100ms | 102ms | 1.6% |
ConvNeXt Small | 159ms | 162ms | 1.6% |
ViT Patch16 Small | 44ms | 46ms | 2.9% |
DeBERTa Base | 7.2ms | 8.4ms | 17% |
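The speedup and slowdown columns in the tables above appear to be relative differences in optimizer step time. A minimal sketch of that arithmetic follows, assuming speedup is the slower time divided by the faster time minus one; the function names are hypothetical. Because the step times shown in the tables are rounded, recomputing from them will not reproduce every listed percentage exactly.

```python
def speedup_pct(baseline_ms: float, faster_ms: float) -> float:
    """Percent speedup of a faster step relative to a baseline step (assumed formula)."""
    return (baseline_ms / faster_ms - 1) * 100

def slowdown_pct(reference_ms: float, slower_ms: float) -> float:
    """Percent slowdown of a slower step relative to a reference step (assumed formula)."""
    return (slower_ms / reference_ms - 1) * 100

# DeBERTa Base, native Adan step vs ForEach Adan step
print(round(speedup_pct(42, 8.4)))

# DeBERTa Base, native AdamW step vs native Adan step
print(round(slowdown_pct(26, 42)))
```

For DeBERTa Base this gives roughly 400 percent and 62 percent, close to the listed 401% and 62%.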
Footnotes
1. Benchmarked on a GeForce 3080 Ti using PyTorch 1.13.1, CUDA 11.7, Mixed Precision, Channels Last (except DeBERTa and ViT), and fastxtend’s Simple Profiler Callback. Results may differ on other models, hardware, and across benchmarking runs. Speedup and slowdown are calculated from the total time spent on the optimization step.