Lion: EvoLved Sign Momentum Optimizer
Lion was introduced by Chen et al. in Symbolic Discovery of Optimization Algorithms. Lion only keeps track of the gradient moving average (momentum), which reduces memory usage compared to AdamW. Lion uses two momentum EMA factors: one for tracking momentum and another for using momentum in the update step. With the default hyperparameters, this allows up to ten times longer history for momentum tracking while leveraging more of the current gradient for the model update. Unlike most optimizers, Lion applies the same update magnitude to every parameter, calculated using the sign operation.
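The update itself is simple. Below is a minimal single-tensor sketch of the Lion step described by Chen et al.; the function and variable names are illustrative and do not reflect fastxtend's internal implementation.

```python
import torch

def lion_step(param, grad, exp_avg, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.1):
    "One Lion update for a single parameter tensor (illustrative sketch)"
    # true (decoupled) weight decay, scaled by the learning rate
    param.mul_(1 - lr * wd)
    # interpolate the momentum buffer with the current gradient using beta1,
    # then take the sign so every element is updated by exactly ±lr
    update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
    param.add_(update, alpha=-lr)
    # track the gradient moving average with the slower beta2 EMA
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
```

Because `beta2=0.99` is closer to one than `beta1=0.9`, the stored momentum averages over a roughly ten times longer history than the beta1 interpolation used for the update, which is the behavior noted above.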
In addition to a fastai native implementation, Lion has fused ForEach and bitsandbytes 8-bit implementations. See the Fused Optimizer and 8-bit Optimizer documentation for more details.
Lion
Lion (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, beta1:float=0.9, beta2:float=0.99, wd:float=0.1, foreach:bool=False, eightbit:bool=False, **eightbitargs)
A fastai Lion optimizer with fused ForEach and 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| beta1 | float | 0.9 | Update gradient moving average (β1) coefficient |
| beta2 | float | 0.99 | Gradient moving average (β2) coefficient |
| wd | float | 0.1 | True weight decay |
| foreach | bool | False | Use fused ForEach implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| LionForEachOptimizer \| Lion8bitOptimizer | | |
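For example, assuming fastxtend has already been imported (e.g. via `from fastxtend.vision.all import *`) and `model` is any PyTorch module, the optimizer can be constructed directly:

```python
# `model` is a placeholder for any nn.Module
opt = Lion(model.parameters(), lr=1e-4, foreach=True)
```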
lion
lion (beta1:float=0.9, beta2:float=0.99, wd:float=0.1, foreach:bool=False, eightbit:bool=False, **eightbitargs)
Partial function for the Lion optimizer with fused ForEach and 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| beta1 | float | 0.9 | Update gradient moving average (β1) coefficient |
| beta2 | float | 0.99 | Gradient moving average (β2) coefficient |
| wd | float | 0.1 | True weight decay |
| foreach | bool | False | Use fused ForEach implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| LionForEachOptimizer \| Lion8bitOptimizer | | |
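A hedged usage sketch with a fastai `Learner`; the DataLoaders and model are placeholders, and the learning rate follows the roughly 10X-smaller-than-AdamW guidance in the Hyperparameters section below:

```python
from fastai.vision.all import *
from fastxtend.vision.all import *  # assumes fastxtend's all-in-one vision import

# `dls` stands in for any DataLoaders; resnet18 is just an example model
learn = Learner(dls, resnet18(), opt_func=lion(foreach=True), metrics=accuracy)
learn.fit_one_cycle(5, lr_max=1e-4)
```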
Hyperparameters
Hyperparameter notes from Chen et al.:

- Except for language modeling, `beta1` and `beta2` are held at 0.9 and 0.99, respectively. When training T5, they set `beta1=0.95` and `beta2=0.98`.
- Due to the larger update norm from the sign operation, the Lion learning rate is typically 10X smaller than `AdamW`, with 3X smaller sometimes performing better.
- Since the effective weight decay is multiplied by the learning rate, weight decay should be increased by the same factor the learning rate is decreased (10X or 3X), as sketched in the example after this list.
- The optimal batch size for Lion is 4096 (vs AdamW's 256), but Lion still performs well at a batch size of 64 and matches or exceeds `AdamW` on all tested batch sizes.
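As a concrete example of the learning rate and weight decay notes above (the AdamW values are illustrative, not from the paper):

```python
# AdamW baseline (illustrative values)
adamw_lr, adamw_wd = 1e-3, 1e-2

# Lion: divide the learning rate by ~10 (or ~3) and multiply weight decay by
# the same factor, keeping the effective weight decay (lr * wd) unchanged
factor = 10
lion_lr = adamw_lr / factor  # 1e-4
lion_wd = adamw_wd * factor  # 1e-1
assert abs(adamw_lr * adamw_wd - lion_lr * lion_wd) < 1e-12
```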
Training Speed
The ForEach optimizer has only been tested on PyTorch 1.12+ and is not guaranteed to work on older versions.
As shown in Table 1, fastxtend's fused ForEach Lion is 13 to 195 percent faster¹ than a standard PyTorch implementation. This training speed advantage could increase in a future PyTorch release, as PyTorch doesn't have a ForEach implementation of `sign`², so the implementation falls back to a for loop in the middle of the Lion update step.
Model | Layers | Native Step | ForEach Step | ForEach Speedup |
---|---|---|---|---|
XResNet18 | 1 | 23ms | 13ms | 73% |
XResNet50 | 1 | 50ms | 34ms | 47% |
XSE-ResNeXt50 | 1 | 66ms | 47ms | 41% |
XResNet101 | 1 | 76ms | 48ms | 59% |
ConvNeXt Tiny | 2 | 118ms | 104ms | 13% |
ConvNeXt Small | 2 | 189ms | 164ms | 16% |
ViT Patch16 Small | 2 | 57ms | 45ms | 26% |
DeBERTa Base | 1 | 22ms | 7.5ms | 195% |
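To illustrate the for-loop fallback mentioned above, here is a minimal multi-tensor sketch using PyTorch's `torch._foreach_*` ops. It is an assumption-laden illustration, not fastxtend's actual ForEach implementation.

```python
import torch

def lion_foreach_step(params, grads, exp_avgs, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.1):
    "Illustrative fused Lion step over lists of tensors"
    # decoupled weight decay applied to all parameters in one fused call
    torch._foreach_mul_(params, 1 - lr * wd)
    # beta1-interpolated momentum for the update direction
    updates = torch._foreach_mul(exp_avgs, beta1)
    torch._foreach_add_(updates, grads, alpha=1 - beta1)
    # PyTorch has no ForEach sign op, so this step falls back to a Python loop
    for update in updates:
        update.sign_()
    torch._foreach_add_(params, updates, alpha=-lr)
    # update the momentum buffers with the slower beta2 EMA
    torch._foreach_mul_(exp_avgs, beta2)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)
```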
Due to a simpler update step and only tracking momentum, the native implementation of Lion is both faster than the native implementation of AdamW and uses less memory. However, since the ForEach implementation requires a for loop for the sign operation, Lion ForEach is equal to or slower than AdamW ForEach, although it should still use less memory.
Model | Native AdamW Step | Native Lion Step | Speedup |
---|---|---|---|
XResNet18 | 26ms | 23ms | 15% |
XResNet50 | 54ms | 50ms | 8.2% |
XSE-ResNeXt50 | 72ms | 66ms | 8.3% |
XResNet101 | 91ms | 76ms | 19% |
ConvNeXt Tiny | 125ms | 118ms | 6.1% |
ConvNeXt Small | 202ms | 189ms | 6.8% |
ViT Patch16 Small | 63ms | 57ms | 9.4% |
DeBERTa Base | 26ms | 22ms | 25% |
Model | ForEach AdamW Step | ForEach Lion Step | Slowdown |
---|---|---|---|
XResNet18 | 13ms | 13ms | 0.9% |
XResNet50 | 33ms | 34ms | 4.6% |
XSE-ResNeXt50 | 42ms | 47ms | 12% |
XResNet101 | 46ms | 48ms | 4.2% |
ConvNeXt Tiny | 102ms | 104ms | 2.5% |
ConvNeXt Small | 161ms | 164ms | 1.6% |
ViT Patch16 Small | 42ms | 45ms | 4.9% |
DeBERTa Base | 7.4ms | 7.5ms | 1.5% |
Footnotes
1. Benchmarked on a GeForce 3080 Ti using PyTorch 1.13.1, Cuda 11.7, Mixed Precision, Channels Last (except DeBERTa and ViT), and fastxtend's Simple Profiler Callback. Results may differ on other models, hardware, and across benchmarking runs. Speedup and slowdown are calculated from the total time spent on the optimization step.
2. Numerically equivalent approximations of `sign` using ForEach operators ended up using more memory and were a wash on training speed.