Fused Optimizers
fastxtend’s fused optimizers are 21 to 293 percent faster drop-in replacements for fastai native optimizers.
Like fastai optimizers, fastxtend fused optimizers support both discriminative learning rates across multiple parameter groups and per-parameter weight decay without any extra setup.
While all fastai optimizers have vertically fused TorchScript implementations, only a subset have horizontally fused ForEach[^1] implementations. These optimizers (SGD, Adam, RAdam, LAMB, and Ranger) usually outperform their TorchScript counterparts in all but the tiniest models. fastxtend also has ForEach implementations of Adan, Lion, Sophia, and StableAdam.
fastxtend also adds full fastai support for bitsandbytes 8-bit optimizers[^2]. 8-bit optimizers can reduce optimizer memory usage by up to 75% compared to 32-bit optimizers. A subset of optimizers is supported: SGD, Adam, LARS, LAMB, and Lion.
ForEach and TorchScript optimizers have only been tested on PyTorch 1.12+ and are not guaranteed to work on older versions.
Documentation for individual optimizers is lightly adapted from the fastai optimizer documentation. Docments and type hints have been upstreamed to fastai.
For implementation details, see the ForEach, TorchScript, or 8-bit documentation.
Fused Performance
As shown in Table 1, ForEach optimizers are 21 to 293 percent faster[^3] at the AdamW optimizer step than the fastai implementations across benchmarked models. Complex optimizers without ForEach implementations, such as QHAdam, are up to 137 percent faster using TorchScript implementations.
Model | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
---|---|---|---|---|---|
XResNet18 | 26ms | 12ms | 109% | 20ms | 29% |
XResNet50 | 56ms | 32ms | 74% | 46ms | 20% |
XSE-ResNeXt50 | 72ms | 43ms | 68% | 61ms | 18% |
XResNet101 | 88ms | 47ms | 84% | 68ms | 30% |
DeBERTa Base | 27ms | 6.9ms | 293% | 19ms | 46% |
This speedup persists with single or multiple parameter groups, although more groups can lead to a small decrease in optimizer step speed, as shown by DeBERTa in Table 2.
Model | Layers | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
---|---|---|---|---|---|---|
XResNet18 | 2 | 25ms | 12ms | 103% | 19ms | 30% |
XResNet50 | 2 | 56ms | 32ms | 76% | 46ms | 24% |
XSE-ResNeXt50 | 2 | 72ms | 45ms | 85% | 61ms | 29% |
XResNet101 | 2 | 87ms | 47ms | 60% | 67ms | 17% |
ConvNeXt Tiny | 2 | 125ms | 102ms | 22% | 115ms | 9.4% |
ConvNeXt Small | 2 | 200ms | 165ms | 21% | 181ms | 10% |
ViT Patch16 Small | 2 | 62ms | 38ms | 62% | 52ms | 20% |
DeBERTa Base | 4 | 27ms | 7.7ms | 254% | 19ms | 47% |
Examples
For backwards compatibility, all fastxtend optimizers return a fastai native optimizer by default. To use a fused version, set `foreach=True` or `jit=True`.
```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# Use ForEach AdamW
opt_func = adam(foreach=True)

# Or use TorchScript AdamW
opt_func = adam(jit=True)

# Or use bitsandbytes' 8-bit AdamW
opt_func = adam(eightbit=True)

Learner(..., opt_func=opt_func)
```
Or import fused optimizers independent of other fastxtend features.
```python
from fastai.vision.all import *
from fastxtend.optimizer.all import *

Learner(..., opt_func=partial(Adam, foreach=True))
```
`adam(...)` is a fastxtend convenience method equivalent to `partial(Adam, ...)`. fastxtend adds lowercase convenience methods for all fastai optimizers.
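As a quick illustration of that equivalence, the sketch below requests the same ForEach AdamW implementation both ways, then passes it to a `Learner` using the placeholder style of the examples above; the learning-rate slice is purely illustrative.

```python
from functools import partial
from fastai.vision.all import *
from fastxtend.vision.all import *

# Equivalent ways to request the ForEach AdamW implementation
opt_func = adam(foreach=True)
opt_func = partial(Adam, foreach=True)

# Discriminative learning rates and per-parameter weight decay work as usual
learn = Learner(..., opt_func=opt_func)
learn.fit_one_cycle(5, lr_max=slice(1e-5, 1e-3))
```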
SGD Optimizer
Stochastic gradient descent, optionally with momentum.
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).

8-bit SGD only supports L2 weight decay (`decouple_wd=False`) and requires momentum (`mom>0`).
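A minimal sketch of an 8-bit SGD configuration that satisfies these constraints; the momentum value is illustrative and the `Learner` arguments follow the placeholder style above.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# 8-bit SGD needs L2 weight decay (decouple_wd=False) and non-zero momentum
opt_func = sgd(mom=0.9, decouple_wd=False, eightbit=True)
Learner(..., opt_func=opt_func)
```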
SGD
SGD (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.0, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False, eightbit:bool=False, **eightbitargs)
A fastai SGD/SGDW optimizer with fused ForEach, TorchScript, & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.0 | Gradient moving average (β1) coefficient |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (SGDW) or L2 regularization (SGD) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| SGDForEachOptimizer \| JitOptimizer \| SGD8bitOptimizer | | |
sgd
sgd (mom:float=0.0, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False, eightbit:bool=False, **eightbitargs)
Partial function for the SGD/SGDW optimizer with fused ForEach, TorchScript, & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.0 | Gradient moving average (β1) coefficient |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (SGDW) or L2 regularization (SGD) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| SGDForEachOptimizer \| JitOptimizer \| SGD8bitOptimizer | | |
RMSProp Optimizer
RMSProp was introduced by Geoffrey Hinton in his course. What is named `sqr_mom` here is the `alpha` in the course.
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).

8-bit RMSProp only supports L2 weight decay (`decouple_wd=False`) and does not support momentum (`mom=0`).
The order of the `mom` and `sqr_mom` hyperparameters has been swapped from fastai to follow the order of all the other fastai and fastxtend optimizers.
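For example, a sketch of both supported variants; the hyperparameter values are illustrative.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# mom comes before sqr_mom, matching the other fastai and fastxtend optimizers
opt_func = rmsprop(mom=0.9, sqr_mom=0.99, jit=True)

# 8-bit RMSProp: L2 weight decay only and no momentum
opt_func = rmsprop(mom=0.0, decouple_wd=False, eightbit=True)
```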
RMSProp
RMSProp (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False, eightbit:bool=False, **eightbitargs)
A fastai RMSProp/RMSPropW optimizer with fused TorchScript and 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.0 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization. Ignored if eightbit=True |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| JitOptimizer \| RMSProp8bitOptimizer | | |
rmsprop
rmsprop (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False, eightbit:bool=False, **eightbitargs)
Partial function for the RMSProp/RMSPropW optimizer with fused TorchScript and 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.0 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (RMSPropW) or L2 regularization (RMSProp) |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| JitOptimizer \| RMSProp8bitOptimizer | | |
Adam Optimizer
Adam was introduced by Diederik P. Kingma and Jimmy Ba in Adam: A Method for Stochastic Optimization. For consistency across optimizers, fastai renamed the paper’s `beta1` and `beta2` to `mom` and `sqr_mom`. Note that the defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, and 1e-5 for `eps`). Experimentation suggests these values perform better in a wide range of situations.
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).

8-bit Adam only supports true weight decay (`decouple_wd=True`).
Don’t forget that `eps` is a hyperparameter you can change. Some models won’t train without a very high `eps` like 0.1 (intuitively, the higher `eps` is, the closer Adam is to normal SGD). The usual default of 1e-8 is often too extreme, in the sense that Adam doesn’t manage to get results as good as with SGD.
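For instance, a sketch of raising `eps` for a model that struggles to train with the default; the value shown is the one suggested above.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# A higher eps moves Adam closer to plain SGD behavior
opt_func = adam(eps=1e-1, foreach=True)
```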
Adam
Adam (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.01, decouple_wd:bool=True, foreach:bool=False, jit:bool=False, eightbit:bool=False, **eightbitargs)
A fastai Adam/AdamW optimizer with fused ForEach, TorchScript, & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (AdamW) or L2 regularization (Adam) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| AdamForEachOptimizer \| JitOptimizer \| AdamW8bitOptimizer | | |
adam
adam (mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.01, decouple_wd:bool=True, foreach:bool=False, jit:bool=False, eightbit:bool=False, **eightbitargs)
Partial function for the Adam/AdamW optimizer with fused ForEach, TorchScript, & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (AdamW) or L2 regularization (Adam) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation |
| eightbitargs | | | |
| Returns | Optimizer \| AdamForEachOptimizer \| JitOptimizer \| AdamW8bitOptimizer | | |
RAdam Optimizer
RAdam (for rectified Adam) was introduced by Liu et al. in On the Variance of the Adaptive Learning Rate and Beyond to slightly modify the Adam optimizer to be more stable at the beginning of training (and thus not require a long warmup). They use an estimate of the variance of the moving average of the squared gradients (the term in the denominator of traditional Adam) and rescale this moving average by this term before performing the update.
The native fastai implementation also incorporates SAdam; set `beta` to enable it (same definition as in the paper). fastxtend’s ForEach and TorchScript implementations do not support `beta` and SAdam.
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).
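A sketch of selecting RAdamW with the ForEach implementation; the weight decay value is illustrative.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# True (decoupled) weight decay, i.e. RAdamW, with the ForEach step
opt_func = radam(wd=1e-2, decouple_wd=True, foreach=True)
```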
RAdam
RAdam (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.0, beta:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
A fastai RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| beta | float | 0.0 | Set to enable SAdam with native fastai RAdam |
| decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| RAdamForEachOptimizer \| JitOptimizer | | |
radam
radam (mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.0, beta:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Partial function for the RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| beta | float | 0.0 | Set to enable SAdam with native fastai RAdam |
| decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| RAdamForEachOptimizer \| JitOptimizer | | |
QHAdam Optimizer
QHAdam (for Quasi-Hyperbolic Adam) was introduced by Ma & Yarats in Quasi-Hyperbolic Momentum and Adam for Deep Learning as a “computationally cheap, intuitive to interpret, and simple to implement” optimizer. Additional code can be found in their qhoptim repo. QHAdam is based on QH-Momentum, which introduces the immediate discount factor `nu`, encapsulating plain SGD (`nu = 0`) and momentum (`nu = 1`). QH-Momentum is defined below, where g_t+1 is the update of the moment. An interpretation of QHM is as a nu-weighted average of the momentum update step and the plain SGD update step.
θ_t+1 ← θ_t − lr * [(1 − nu) · ∇L_t(θ_t) + nu · g_t+1]
QHAdam takes the concept behind QHM above and applies it to Adam, replacing both of Adam’s moment estimators with quasi-hyperbolic terms.
The paper’s suggested default parameters are `mom = 0.999`, `sqr_mom = 0.999`, `nu_1 = 0.7`, and `nu_2 = 1.0`. When training is not stable, setting `nu_2 < 1` can improve stability by imposing a tighter step size bound. Note that QHAdam recovers Adam when `nu_1 = nu_2 = 1.0`, RMSProp (Hinton et al., 2012) when `nu_1 = 0` and `nu_2 = 1`, and NAdam (Dozat, 2016) when `nu_1 = mom` and `nu_2 = 1`.
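A sketch using the paper’s suggested defaults with the fused TorchScript step, plus the setting that recovers Adam.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# Paper-suggested defaults with the TorchScript implementation
opt_func = qhadam(mom=0.999, sqr_mom=0.999, nu_1=0.7, nu_2=1.0, jit=True)

# nu_1 = nu_2 = 1.0 recovers Adam
opt_func = qhadam(nu_1=1.0, nu_2=1.0, jit=True)
```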
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).
QHAdam
QHAdam (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.999, sqr_mom:float=0.999, nu_1:float=0.7, nu_2:float=1.0, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
A fastai QHAdam/QHAdamW optimizer with a fused TorchScript implementation
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.999 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.999 | Gradient squared moving average (β2) coefficient |
| nu_1 | float | 0.7 | QH immediate discount factor |
| nu_2 | float | 1.0 | QH momentum discount factor |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (QHAdamW) or L2 regularization (QHAdam) |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| JitOptimizer | | |
qhadam
qhadam (mom:float=0.999, sqr_mom:float=0.999, nu_1:float=0.7, nu_2:float=1.0, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
Partial function for the QHAdam/QHAdamW optimizer with a fused TorchScript implementation
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.999 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.999 | Gradient squared moving average (β2) coefficient |
| nu_1 | float | 0.7 | QH immediate discount factor |
| nu_2 | float | 1.0 | QH momentum discount factor |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (QHAdamW) or L2 regularization (QHAdam) |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Optimizer \| JitOptimizer | | |
LARS/LARC Optimizer
The LARS optimizer was first introduced in Large Batch Training of Convolutional Networks, then refined in its LARC variant (original LARS is with `clip=False`). A learning rate is computed for each individual layer with a certain `trust_coefficient`, then clipped to always be less than `lr`.
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).

The 8-bit implementation is LARS only (`clip=False`) and only supports L2 weight decay (`decouple_wd=False`).
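A sketch of both variants under these constraints.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# LARC (layerwise learning rates clipped by lr) with the TorchScript step
opt_func = larc(clip=True, jit=True)

# 8-bit LARS: clip=False and L2 weight decay only
opt_func = larc(clip=False, decouple_wd=False, eightbit=True)
```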
Larc
Larc (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, clip:bool=True, trust_coeff:float=0.02, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False, eightbit:bool=False, **eightbitargs)
A fastai LARC/LARS optimizer with fused TorchScript & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| clip | bool | True | LARC if clip=True, LARS if clip=False |
| trust_coeff | float | 0.02 | Trust coefficient for calculating layerwise LR |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization. Ignored if eightbit=True |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation. Only supports LARS: clip=False |
| eightbitargs | | | |
| Returns | Optimizer \| JitOptimizer \| LARS8bitOptimizer | | |
larc
larc (mom:float=0.9, clip:bool=True, trust_coeff:float=0.02, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False, eightbit:bool=False, **eightbitargs)
Partial function for the LARC/LARS optimizer with fused TorchScript & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| clip | bool | True | LARC if clip=True, LARS if clip=False |
| trust_coeff | float | 0.02 | Trust coefficient for calculating layerwise LR |
| eps | float | 1e-08 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation. Only supports LARS |
| eightbitargs | | | |
| Returns | Optimizer \| JitOptimizer \| LARS8bitOptimizer | | |
LAMB Optimizer
LAMB was introduced in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. Intuitively, it’s LARC applied to Adam. As in `Adam`, the paper’s `beta1` and `beta2` are renamed to `mom` and `sqr_mom`. Note that the defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, and 1e-5 for `eps`). Experimentation suggests these values perform better in a wide range of situations.
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).

8-bit LAMB only supports true weight decay (`decouple_wd=True`).
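A sketch of the 8-bit and ForEach variants; the weight decay value is illustrative.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# 8-bit LAMB requires true (decoupled) weight decay
opt_func = lamb(decouple_wd=True, eightbit=True)

# ForEach LAMB with a small weight decay
opt_func = lamb(wd=1e-2, foreach=True)
```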
Lamb
Lamb (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False, eightbit:bool=False, **eightbitargs)
A fastai LAMB optimizer with fused ForEach, TorchScript, & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization. Ignored if eightbit=True |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation. Only supports true weight decay |
| eightbitargs | | | |
| Returns | Optimizer \| LambForEachOptimizer \| JitOptimizer \| LAMB8bitOptimizer | | |
lamb
lamb (mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False, eightbit:bool=False, **eightbitargs)
Partial function for the LAMB optimizer with fused ForEach, TorchScript, & 8-bit implementations
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.0 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay or L2 regularization |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| eightbit | bool | False | Use fused 8-bit implementation. Only supports true weight decay |
| eightbitargs | | | |
| Returns | Optimizer \| LambForEachOptimizer \| JitOptimizer \| LAMB8bitOptimizer | | |
Ranger Optimizer
Lookahead was introduced by Zhang et al. in Lookahead Optimizer: k steps forward, 1 step back. With Lookahead, the final weights (slow weights) are a moving average of the normal weights (fast weights). Every k steps, Lookahead modifies the current weights using a moving average of the fast weights (the normal weights) and the slow weights (a copy of the weights from k steps ago). The slow weights act as a stability mechanism.
Ranger was introduced by Less Wright in New Deep Learning Optimizer, Ranger: Synergistic combination of RAdam + Lookahead for the best of both. It combines RAdam and Lookahead together in one optimizer and reduces the need for hyperparameter tuning due to a combination of RAdam’s warmup heuristic and Lookahead’s interpolation of parameter weights.
Ranger performs best on vision tasks when paired with the `fit_flat_cos` or `fit_flat_varied` schedulers.
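A sketch pairing ForEach Ranger with `fit_flat_cos`; the epoch count and learning rate are illustrative, and the `Learner` arguments follow the placeholder style above.

```python
from fastai.vision.all import *
from fastxtend.vision.all import *

# Ranger (RAdam + Lookahead) with the fused ForEach implementation,
# paired with a flat-then-cosine schedule
learn = Learner(..., opt_func=ranger(foreach=True))
learn.fit_flat_cos(10, lr=1e-3)
```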
Optional weight decay of `wd` is applied as true weight decay (decay the weights directly) if `decouple_wd=True`, else as L2 regularization (add the decay to the gradients).
While fastai’s `Lookahead` can be applied to any optimizer, fastxtend’s `JitLookahead` requires a custom-written TorchScript callback and `ForEachOptimizer` requires a custom Lookahead optimizer step. Currently, Ranger with RAdam is the only TorchScript and ForEach optimizer with Lookahead support.
Ranger
Ranger (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.95, sqr_mom:float=0.99, eps:float=1e-06, wd:float=0.01, k:int=6, alpha:float=0.5, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Convenience method for `Lookahead` with `RAdam` fused ForEach and TorchScript implementations
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.95 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-06 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| k | int | 6 | How often to conduct Lookahead step |
| alpha | float | 0.5 | Slow weight moving average coefficient |
| decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Lookahead \| RangerForEachOptimizer \| JitLookahead | | |
ranger
ranger (mom:float=0.95, sqr_mom:float=0.99, eps:float=1e-06, wd:float=0.01, k:int=6, alpha:float=0.5, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Partial function of the convenience method for `Lookahead` with `RAdam` fused ForEach and TorchScript implementations
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.95 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-06 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| k | int | 6 | How often to conduct Lookahead step |
| alpha | float | 0.5 | Slow weight moving average coefficient |
| decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| jit | bool | False | Use fused TorchScript implementation |
| Returns | Lookahead \| RangerForEachOptimizer \| JitLookahead | | |
Footnotes
[^1]: fastxtend ForEach optimizers are adapted from the PyTorch ForEach `_multi_tensor` implementations, but seamlessly work with fastai features.

[^2]: While it is possible to use bitsandbytes optimizers with fastai via `fastai.optimizer.OptimWrapper`, this doesn’t provide compatibility with all fastai optimizer features. fastxtend adds full fastai compatibility to bitsandbytes 8-bit optimizers, including per-parameter weight decay, automatic weight decay exclusion for normalization and bias terms, and discriminative learning rate support.

[^3]: All optimizers were benchmarked on a GeForce 3080 Ti using PyTorch 1.12.1, CUDA 11.6, Mixed Precision, Channels Last (except ViT and DeBERTa), and fastxtend’s Simple Profiler Callback. Results may differ with other optimizers, models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the optimization step.