Fused Optimizers
fastxtend’s fused optimizers are 21 to 293 percent faster, drop-in replacements for fastai native optimizers. Like fastai optimizers, fastxtend fused optimizers support both discriminative learning rates across multiple parameter groups and per-parameter weight decay without any extra setup.
While all fastai optimizers have vertically fused TorchScript implementations, only a subset have horizontally fused ForEach implementations. These optimizers, SGD, Adam, RAdam, Lamb, and Ranger, usually outperform their TorchScript counterparts in all but the tiniest models.
fastxtend ForEach optimizers are equivalent in performance to PyTorch ForEach optimizers with two parameter groups, one for applying weight decay and one for parameters without weight decay.
For implementation details, see the ForEach or TorchScript documentation.
Fused Performance
As shown in Tables 1 and 2, ForEach optimizers are 21 to 293 percent faster^{1} in optimizer step performance relative to fastai implementations across benchmarked models. Complex optimizers without ForEach implementations, such as QHAdam, are up to 137 percent faster using the TorchScript implementation.
Model | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
---|---|---|---|---|---|
XResNet18 | 26ms | 12ms | 109% | 20ms | 29% |
XResNet50 | 56ms | 32ms | 74% | 46ms | 20% |
XSE-ResNext50 | 72ms | 43ms | 68% | 61ms | 18% |
XResNet101 | 88ms | 47ms | 84% | 68ms | 30% |
DeBERTa Base | 27ms | 6.9ms | 293% | 19ms | 46% |
This speedup persists with single or multiple parameter groups, although more groups can lead to a small decrease in optimizer step speed, as shown by DeBERTa in Table 2.
Model | Layers | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup |
---|---|---|---|---|---|---|
XResNet18 | 2 | 25ms | 12ms | 103% | 19ms | 30% |
XResNet50 | 2 | 56ms | 32ms | 76% | 46ms | 24% |
XSE-ResNext50 | 2 | 72ms | 45ms | 60% | 61ms | 17% |
XResNet101 | 2 | 87ms | 47ms | 85% | 67ms | 29% |
ConvNext Tiny | 2 | 125ms | 102ms | 22% | 115ms | 9.4% |
ConvNext Small | 2 | 200ms | 165ms | 21% | 181ms | 10% |
ViT Patch16 Small | 2 | 62ms | 38ms | 62% | 52ms | 20% |
DeBERTa Base | 4 | 27ms | 7.7ms | 254% | 19ms | 47% |
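The speedup columns can be reproduced approximately from the step times. The sketch below assumes the speedup is the fastai step time divided by the fused step time, minus one, expressed as a percentage; `percent_speedup` is a hypothetical helper, not part of fastxtend. Because the tables list rounded step times, the result can differ from the published figure by a few points.

```python
def percent_speedup(baseline_ms: float, fused_ms: float) -> float:
    """Percent by which the fused step is faster than the baseline step.

    Assumed definition: (baseline / fused - 1) * 100.
    """
    if baseline_ms <= 0 or fused_ms <= 0:
        raise ValueError("step times must be positive")
    return (baseline_ms / fused_ms - 1.0) * 100.0


# Rounded step times taken from the DeBERTa Base row of Table 1.
deberta = percent_speedup(27, 6.9)
print(f"DeBERTa Base ForEach speedup: {deberta:.0f}%")
```

A step that takes half as long under this definition is a 100 percent speedup, which matches how the tables read.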
Examples
For backwards compatibility, all fastxtend optimizers return a fastai native optimizer by default. To use a fused version, set foreach=True or jit=True.
from fastai.vision.all import *
from fastxtend.vision.all import *
# Use ForEach AdamW
opt_func = adam(foreach=True)
# Or use TorchScript AdamW
opt_func = adam(jit=True)
Learner(..., opt_func=opt_func)
Or import fused optimizers independent of other fastxtend features.
from fastai.vision.all import *
from fastxtend.optimizer.all import *
Learner(..., opt_func=partial(Adam, foreach=True))
SGD Optimizer
Stochastic gradient descent, optionally with momentum.
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
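The difference between the two weight decay modes can be sketched on a single scalar parameter. This is a minimal illustration, not the fastxtend implementation: `sgd_step` and `run` are hypothetical helpers, and the momentum buffer is assumed to be an undampened running sum of gradients.

```python
def sgd_step(p, grad, grad_avg, lr, mom=0.9, wd=0.0, decouple_wd=True):
    """One SGD-with-momentum step on a scalar parameter (illustrative only)."""
    if wd != 0.0:
        if decouple_wd:
            p = p * (1.0 - lr * wd)   # true weight decay: shrink the weight directly
        else:
            grad = grad + wd * p      # L2 regularization: add the decay to the gradient
    grad_avg = mom * grad_avg + grad  # momentum buffer (assumed undampened)
    return p - lr * grad_avg, grad_avg


def run(decouple_wd, steps, p=1.0, grad=0.5, lr=0.1, wd=0.1):
    """Apply `steps` updates with a constant gradient and return the final weight."""
    grad_avg = 0.0
    for _ in range(steps):
        p, grad_avg = sgd_step(p, grad, grad_avg, lr, wd=wd, decouple_wd=decouple_wd)
    return p


print(run(True, 2), run(False, 2))
```

In this sketch the two modes agree after one step and diverge from the second step on, because the L2 term is carried forward in the momentum buffer while true weight decay is not.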
SGD
SGD (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.0, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
A fastai SGD/SGDW optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.0 | Gradient moving average (β1) coefficient |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (SGDW) or L2 regularization (SGD) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, SGDForEachOptimizer, or JitOptimizer | | |
sgd
sgd (mom:float=0.0, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Partial function for the SGD/SGDW optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.0 | Gradient moving average (β1) coefficient |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (SGDW) or L2 regularization (SGD) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, SGDForEachOptimizer, or JitOptimizer | | |
RMSProp Optimizer
RMSProp was introduced by Geoffrey Hinton in his course. What is named sqr_mom here is the alpha in the course.
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
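As a rough illustration of what sqr_mom and eps control, here is a textbook RMSProp step on a scalar parameter. It is a sketch under stated assumptions: `rmsprop_step` is a hypothetical helper, momentum (mom) and weight decay are omitted for brevity, and the fused fastxtend kernels may order the operations differently.

```python
import math


def rmsprop_step(p, grad, sqr_avg, lr, sqr_mom=0.99, eps=1e-8):
    """One textbook RMSProp step on a scalar parameter (no momentum, no weight decay)."""
    # Exponential moving average of the squared gradient; sqr_mom is Hinton's alpha.
    sqr_avg = sqr_mom * sqr_avg + (1.0 - sqr_mom) * grad * grad
    # Scale the gradient by the root of that average; eps guards against division by zero.
    p = p - lr * grad / (math.sqrt(sqr_avg) + eps)
    return p, sqr_avg


p, sqr_avg = rmsprop_step(p=1.0, grad=2.0, sqr_avg=0.0, lr=0.1)
print(p, sqr_avg)
```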
RMSProp
RMSProp (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
A fastai RMSProp/RMSPropW optimizer with a fused TorchScript implementation
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.0 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (RMSPropW) or L2 regularization (RMSProp) |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer or JitOptimizer | | |
rmsprop
rmsprop (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
Partial function for the RMSProp/RMSPropW optimizer with a fused TorchScript implementation
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.0 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (RMSPropW) or L2 regularization (RMSProp) |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer or JitOptimizer | | |
Adam Optimizer
Adam was introduced by Diederik P. Kingma and Jimmy Ba in Adam: A Method for Stochastic Optimization. For consistency across optimizers, fastai renamed beta1 and beta2 in the paper to mom and sqr_mom. Note that the defaults also differ from the paper (0.99 for sqr_mom or beta2, 1e-5 for eps). Those values seem to be better from experimentation in a wide range of situations.
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
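To make the roles of mom, sqr_mom, and eps concrete, the following is a textbook bias-corrected Adam step on a scalar parameter, using the defaults listed on this page. It is a sketch rather than the fastxtend code: `adam_step` and the `state` dictionary layout are hypothetical, and weight decay is left out.

```python
import math


def adam_step(p, grad, state, lr, mom=0.9, sqr_mom=0.99, eps=1e-5):
    """One bias-corrected Adam step on a scalar parameter (no weight decay)."""
    state["step"] += 1
    # Moving averages of the gradient (mom = beta1) and squared gradient (sqr_mom = beta2).
    state["grad_avg"] = mom * state["grad_avg"] + (1 - mom) * grad
    state["sqr_avg"] = sqr_mom * state["sqr_avg"] + (1 - sqr_mom) * grad ** 2
    # Bias correction compensates for both averages starting at zero.
    m_hat = state["grad_avg"] / (1 - mom ** state["step"])
    v_hat = state["sqr_avg"] / (1 - sqr_mom ** state["step"])
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


state = {"step": 0, "grad_avg": 0.0, "sqr_avg": 0.0}
p = adam_step(1.0, grad=2.0, state=state, lr=0.1)
print(p)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly lr regardless of the gradient's magnitude.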
Adam
Adam (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.01, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
A fastai Adam/AdamW optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.9 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-05 | Added for numerical stability |
wd | float | 0.01 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (AdamW) or L2 regularization (Adam) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, AdamForEachOptimizer, or JitOptimizer | | |
adam
adam (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Partial function for the Adam/AdamW optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.0 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (AdamW) or L2 regularization (Adam) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, AdamForEachOptimizer, or JitOptimizer | | |
RAdam Optimizer
RAdam (for rectified Adam) was introduced by Liu et al. in On the Variance of the Adaptive Learning Rate and Beyond to slightly modify the Adam optimizer to be more stable at the beginning of training (and thus not require a long warmup). They use an estimate of the variance of the moving average of the squared gradients (the term in the denominator of traditional Adam) and rescale this moving average by this term before performing the update.
The native fastai implementation also incorporates SAdam; set beta to enable this (definition same as in the paper).
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
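The rescaling described above can be sketched as the variance rectification term from the RAdam paper. This is a minimal illustration: `radam_rectification` is a hypothetical helper, the threshold of 4 follows the paper, and library implementations may use a slightly different cutoff.

```python
import math
from typing import Optional


def radam_rectification(step: int, sqr_mom: float = 0.99, threshold: float = 4.0) -> Optional[float]:
    """Return the RAdam rectification factor for `step`, or None while the
    variance estimate is considered untrustworthy (early in training)."""
    rho_inf = 2.0 / (1.0 - sqr_mom) - 1.0          # maximum length of the approximated SMA
    beta_t = sqr_mom ** step
    rho_t = rho_inf - 2.0 * step * beta_t / (1.0 - beta_t)
    if rho_t <= threshold:
        return None                                 # fall back to an unadapted (momentum-only) update
    numerator = (rho_t - 4.0) * (rho_t - 2.0) * rho_inf
    denominator = (rho_inf - 4.0) * (rho_inf - 2.0) * rho_t
    return math.sqrt(numerator / denominator)


for step in (1, 10, 100, 10_000):
    print(step, radam_rectification(step))
```

The factor is absent for the first few steps, then grows toward 1 as training proceeds, which is what lets RAdam behave like a built-in warmup.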
RAdam
RAdam (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.0, beta:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
A fastai RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.9 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-05 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
beta | float | 0.0 | Set to enable SAdam with native fastai RAdam |
decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, RAdamForEachOptimizer, or JitOptimizer | | |
radam
radam (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0, beta:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Partial function for the RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.0 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
beta | float | 0.0 | Set to enable SAdam with native fastai RAdam |
decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, RAdamForEachOptimizer, or JitOptimizer | | |
QHAdam Optimizer
QHAdam (for Quasi-Hyperbolic Adam) was introduced by Ma & Yarats in Quasi-Hyperbolic Momentum and Adam for Deep Learning as a “computationally cheap, intuitive to interpret, and simple to implement” optimizer. Additional code can be found in their qhoptim repo. QHAdam is based on QH-Momentum, which introduces the immediate discount factor nu, encapsulating plain SGD (nu = 0) and momentum (nu = 1). QH-Momentum is defined below, where g_{t+1} is the update of the moment. An interpretation of QHM is as a nu-weighted average of the momentum update step and the plain SGD update step.
θ_{t+1} ← θ_t − lr · [(1 − nu) · ∇L_t(θ_t) + nu · g_{t+1}]
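The QH-Momentum update above can be sketched on a scalar parameter. This is an illustrative sketch only: `qhm_step` is a hypothetical helper, and the momentum buffer is assumed to be the dampened exponential moving average used in the QHM paper.

```python
def qhm_step(p, grad, grad_avg, lr, mom=0.999, nu=0.7):
    """One QH-Momentum step on a scalar parameter (illustrative only)."""
    # g_{t+1}: dampened moving average of the gradient.
    grad_avg = mom * grad_avg + (1.0 - mom) * grad
    # nu-weighted average of the plain SGD step and the momentum step.
    p = p - lr * ((1.0 - nu) * grad + nu * grad_avg)
    return p, grad_avg


# nu = 0 reduces to plain SGD; nu = 1 reduces to (dampened) momentum.
p_sgd, _ = qhm_step(1.0, grad=0.5, grad_avg=0.0, lr=0.1, mom=0.9, nu=0.0)
p_mom, g_mom = qhm_step(1.0, grad=0.5, grad_avg=0.0, lr=0.1, mom=0.9, nu=1.0)
print(p_sgd, p_mom)
```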
QHAdam takes the concept behind QHM above and applies it to Adam, replacing both of Adam’s moment estimators with quasi-hyperbolic terms.
The paper’s suggested default parameters are mom = 0.999, sqr_mom = 0.999, nu_1 = 0.7, and nu_2 = 1.0. When training is not stable, it is possible that setting nu_2 < 1 can improve stability by imposing a tighter step size bound. Note that QHAdam recovers Adam when nu_1 = nu_2 = 1.0. QHAdam recovers RMSProp (Hinton et al., 2012) when nu_1 = 0 and nu_2 = 1, and NAdam (Dozat, 2016) when nu_1 = mom and nu_2 = 1.
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
QHAdam
QHAdam (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.999, sqr_mom:float=0.999, nu_1:float=0.7, nu_2:float=1.0, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
A fastai QHAdam/QHAdamW optimizer with a fused TorchScript implementation
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.999 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.999 | Gradient squared moving average (β2) coefficient |
nu_1 | float | 0.7 | QH immediate discount factor |
nu_2 | float | 1.0 | QH momentum discount factor |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (QHAdamW) or L2 regularization (QHAdam) |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer or JitOptimizer | | |
qhadam
qhadam (mom:float=0.0, sqr_mom:float=0.99, nu_1:float=0.7, nu_2:float=1.0, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
Partial function for the QHAdam/QHAdamW optimizer with a fused TorchScript implementation
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.0 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
nu_1 | float | 0.7 | QH immediate discount factor |
nu_2 | float | 1.0 | QH momentum discount factor |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay (QHAdamW) or L2 regularization (QHAdam) |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer or JitOptimizer | | |
LARS/LARC Optimizer
The LARS optimizer was first introduced in Large Batch Training of Convolutional Networks, then refined in its LARC variant (original LARS is with clip=False). A learning rate is computed for each individual layer with a certain trust_coeff, then clipped to be always less than lr.
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
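The layerwise learning rate can be sketched as follows. This is an assumption-laden illustration, not the fastxtend kernel: `larc_layer_lr` is a hypothetical helper, a layer is represented as a flat list of floats, and the formula follows the common LARS form (trust coefficient times the weight norm over the gradient norm), so the exact expression in the library may differ.

```python
import math


def larc_layer_lr(weights, grads, lr, trust_coeff=0.02, wd=0.0, eps=1e-8, clip=True):
    """Layerwise learning rate for one layer, LARC if clip=True else LARS."""
    p_norm = math.sqrt(sum(w * w for w in weights))
    g_norm = math.sqrt(sum(g * g for g in grads))
    # Scale the global lr by how large the weights are relative to their gradient.
    local_lr = lr * trust_coeff * p_norm / (g_norm + wd * p_norm + eps)
    # LARC clips the layerwise rate so it never exceeds the global lr.
    return min(lr, local_lr) if clip else local_lr


layer = [3.0, 4.0]                       # weight norm is 5
print(larc_layer_lr(layer, [3.0, 4.0], lr=0.1))               # large gradient: small local lr
print(larc_layer_lr(layer, [0.0, 0.01], lr=0.1))              # tiny gradient: clipped to lr
print(larc_layer_lr(layer, [0.0, 0.01], lr=0.1, clip=False))  # LARS: not clipped
```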
Larc
Larc (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, clip:bool=True, trust_coeff:float=0.02, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
A fastai LARC/LARS optimizer with a fused TorchScript implementation
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.9 | Gradient moving average (β1) coefficient |
clip | bool | True | LARC if clip=True, LARS if clip=False |
trust_coeff | float | 0.02 | Trust coefficient for calculating layerwise LR |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay or L2 regularization |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer or JitOptimizer | | |
larc
larc (mom:float=0.0, clip:bool=True, trust_coeff:float=0.02, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)
Partial function for the LARC/LARS optimizer with a fused TorchScript implementation
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.0 | Gradient moving average (β1) coefficient |
clip | bool | True | LARC if clip=True, LARS if clip=False |
trust_coeff | float | 0.02 | Trust coefficient for calculating layerwise LR |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay or L2 regularization |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer or JitOptimizer | | |
LAMB Optimizer
LAMB was introduced in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. Intuitively, it’s LARC applied to Adam. As in Adam, beta1 and beta2 in the paper are renamed to mom and sqr_mom. Note that the defaults also differ from the paper (0.99 for sqr_mom or beta2, 1e-5 for eps). Those values seem to be better from experimentation in a wide range of situations.
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
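The "LARC applied to Adam" intuition can be sketched as a trust ratio that rescales an Adam-style update per layer. This is a sketch under assumptions: `lamb_trust_ratio` and `lamb_apply` are hypothetical helpers, `update` stands for the bias-corrected Adam direction, root-mean-square norms are used, and the cap of 10 on the ratio is an assumed value rather than something stated on this page.

```python
import math


def lamb_trust_ratio(weights, update, max_ratio=10.0):
    """Ratio of the layer's weight norm to its update norm, capped at max_ratio."""
    r1 = math.sqrt(sum(w * w for w in weights) / len(weights))
    r2 = math.sqrt(sum(u * u for u in update) / len(update))
    if r1 == 0.0 or r2 == 0.0:
        return 1.0                      # fall back to the unscaled update
    return min(r1 / r2, max_ratio)


def lamb_apply(weights, update, lr):
    """Apply the Adam-style `update`, rescaled by the layer's trust ratio."""
    q = lamb_trust_ratio(weights, update)
    return [w - lr * q * u for w, u in zip(weights, update)]


print(lamb_trust_ratio([2.0, 2.0], [1.0, 1.0]))
print(lamb_apply([2.0, 2.0], [1.0, 1.0], lr=0.1))
```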
Lamb
Lamb (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.01, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
A fastai LAMB optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.9 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-05 | Added for numerical stability |
wd | float | 0.01 | Optional weight decay (true or L2). Paper default, fastai’s is 0. |
decouple_wd | bool | True | Apply true weight decay or L2 regularization |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, LambForEachOptimizer, or JitOptimizer | | |
lamb
lamb (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Partial function for the LAMB optimizer with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.0 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-08 | Added for numerical stability |
wd | float | 0.0 | Optional weight decay (true or L2) |
decouple_wd | bool | True | Apply true weight decay or L2 regularization |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Optimizer, LambForEachOptimizer, or JitOptimizer | | |
Ranger Optimizer
Lookahead was introduced by Zhang et al. in Lookahead Optimizer: k steps forward, 1 step back. With Lookahead, the final weights (slow weights) are a moving average of the normal weights (fast weights). Every k steps, Lookahead replaces the current weights with a moving average of the fast weights (normal weights) and the slow weights (the copy of old weights from k steps ago). Those slow weights act like a stability mechanism.
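The slow and fast weight interaction can be sketched on a scalar parameter. This is a minimal illustration: `lookahead_train` is a hypothetical helper, and plain SGD stands in for the inner optimizer (Ranger uses RAdam instead).

```python
def lookahead_train(fast, grads, lr=0.1, k=6, alpha=0.5):
    """Run an inner optimizer over `grads`, applying a Lookahead sync every k steps."""
    slow = fast                                  # slow weights start as a copy of the fast weights
    for step, grad in enumerate(grads, start=1):
        fast = fast - lr * grad                  # inner optimizer step (plain SGD in this sketch)
        if step % k == 0:
            slow = slow + alpha * (fast - slow)  # move slow weights toward the fast weights
            fast = slow                          # restart the fast weights from the slow weights
    return fast, slow


fast, slow = lookahead_train(1.0, grads=[1.0, 1.0], lr=0.1, k=2, alpha=0.5)
print(fast, slow)
```

With alpha = 0.5, each sync keeps only half of the progress the fast weights made over the last k steps, which is the stabilizing effect described above.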
Ranger was introduced by Less Wright in New Deep Learning Optimizer, Ranger: Synergistic combination of RAdam + Lookahead for the best of both. It combines RAdam and Lookahead together in one optimizer and reduces the need for hyperparameter tuning due to a combination of RAdam’s warmup heuristic and Lookahead’s interpolation of parameter weights.
Ranger performs best on vision tasks when paired with the fit_flat_cos or fit_flat_varied schedulers.
Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
Ranger
Ranger (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.95, sqr_mom:float=0.99, eps:float=1e-06, wd:float=0.01, k:int=6, alpha:float=0.5, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Convenience method for Lookahead with RAdam, with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
params | listified[Tensor] | | Model parameters or parameter groups |
lr | float | | Default learning rate |
mom | float | 0.95 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-06 | Added for numerical stability |
wd | float | 0.01 | Optional weight decay (true or L2) |
k | int | 6 | How often to conduct Lookahead step |
alpha | float | 0.5 | Slow weight moving average coefficient |
decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Lookahead, RangerForEachOptimizer, or JitLookahead | | |
ranger
ranger (mom:float=0.95, sqr_mom:float=0.99, eps:float=1e-06, wd:float=0.01, k:int=6, alpha:float=0.5, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)
Partial function of the convenience method for Lookahead with RAdam, with fused ForEach and TorchScript implementations
Parameter | Type | Default | Details |
---|---|---|---|
mom | float | 0.95 | Gradient moving average (β1) coefficient |
sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
eps | float | 1e-06 | Added for numerical stability |
wd | float | 0.01 | Optional weight decay (true or L2) |
k | int | 6 | How often to conduct Lookahead step |
alpha | float | 0.5 | Slow weight moving average coefficient |
decouple_wd | bool | True | Apply true weight decay (RAdamW) or L2 regularization (RAdam) |
foreach | bool | False | Use fused ForEach implementation |
jit | bool | False | Use fused TorchScript implementation |
Returns | Lookahead, RangerForEachOptimizer, or JitLookahead | | |
Footnotes
1. All optimizers benchmarked on a GeForce 3080 Ti using PyTorch 1.12.1, Mixed Precision, Channels Last (except ViT and DeBERTa), and fastxtend’s Simple Profiler Callback. Results may differ with other optimizers, models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the optimization step.