Fused Optimizers

Fused fastai optimizers using ForEach methods and TorchScript

fastxtend’s fused optimizers are drop-in replacements for fastai’s native optimizers that speed up the optimizer step by 21 to 293 percent. Like fastai optimizers, fastxtend fused optimizers support both discriminative learning rates across multiple parameter groups and per-parameter weight decay without any extra setup.

While every fastai optimizer has a vertically fused TorchScript implementation in fastxtend, only a subset (SGD, Adam, RAdam, Lamb, and Ranger) also has a horizontally fused ForEach implementation. These ForEach versions usually outperform their TorchScript counterparts on all but the tiniest models.

fastxtend ForEach optimizers perform on par with PyTorch ForEach optimizers configured with two parameter groups: one for parameters with weight decay and one for parameters without weight decay.

Note

Documentation for the individual optimizers is lightly adapted from the fastai optimizer documentation.

fastxtend ForEach optimizers are adapted from the PyTorch ForEach _multi_tensor implementations, but seamlessly work with fastai features.

For implementation details, see the ForEach or TorchScript documentation.

Important

ForEach and TorchScript optimizers have only been tested on PyTorch 1.12 and are not guaranteed to work on older versions.

Fused Performance

As shown in Tables 1 and 2, ForEach optimizers are 21 to 293 percent faster¹ in optimizer step performance than the fastai implementations across the benchmarked models. Complex optimizers without ForEach implementations, such as QHAdam, are up to 137 percent faster with the TorchScript implementation.

Table 1: Increase in AdamW opt_step Speed vs fastai Native Optimizer
Model | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup
XResNet18 | 26 ms | 12 ms | 109% | 20 ms | 29%
XResNet50 | 56 ms | 32 ms | 74% | 46 ms | 20%
XSE-ResNext50 | 72 ms | 43 ms | 68% | 61 ms | 18%
XResNet101 | 88 ms | 47 ms | 84% | 68 ms | 30%
DeBERTa Base | 27 ms | 6.9 ms | 293% | 19 ms | 46%

This speedup persists with single or multiple parameter groups, although more groups can lead to a small decrease in optimizer step speed, as shown by DeBERTa in Table 2.

Table 2: Increase in AdamW opt_step Speed With Multiple Param Groups vs fastai Native Optimizer
Model | Param Groups | fastai Step | ForEach Step | ForEach Speedup | JIT Step | JIT Speedup
XResNet18 | 2 | 25 ms | 12 ms | 103% | 19 ms | 30%
XResNet50 | 2 | 56 ms | 32 ms | 76% | 46 ms | 24%
XSE-ResNext50 | 2 | 72 ms | 45 ms | 60% | 61 ms | 17%
XResNet101 | 2 | 87 ms | 47 ms | 85% | 67 ms | 29%
ConvNext Tiny | 2 | 125 ms | 102 ms | 22% | 115 ms | 9.4%
ConvNext Small | 2 | 200 ms | 165 ms | 21% | 181 ms | 10%
ViT Patch16 Small | 2 | 62 ms | 38 ms | 62% | 52 ms | 20%
DeBERTa Base | 4 | 27 ms | 7.7 ms | 254% | 19 ms | 47%
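The speedup columns appear to be percentage increases in speed, (baseline time / fused time − 1) × 100, so halving the step time is a 100 percent speedup. Below is a minimal sketch of that arithmetic; `speedup_percent` is a hypothetical helper, not part of fastxtend. Because the reported speedups are calculated from total optimizer-step time (see the footnote), applying the formula to the rounded per-step times only approximates the tables: the XResNet50 row of Table 1 gives 75 percent against the reported 74 percent.

```python
def speedup_percent(baseline_ms: float, fused_ms: float) -> float:
    """Percent increase in speed of a fused step relative to a baseline step."""
    # Halving the step time doubles throughput, which is a 100% speedup
    return (baseline_ms / fused_ms - 1.0) * 100.0


# Per-step AdamW times from Table 1: fastai native vs ForEach
print(f"XResNet50:    {speedup_percent(56, 32):.0f}%")   # 75%
print(f"DeBERTa Base: {speedup_percent(27, 6.9):.0f}%")  # 291%
```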

Examples

For backwards compatibility, all fastxtend optimizers return a fastai native optimizer by default. To use a fused version, set foreach=True or jit=True.

from fastai.vision.all import *
from fastxtend.vision.all import *

# Use ForEach AdamW
opt_func = adam(foreach=True)

# Or use TorchScript AdamW
opt_func = adam(jit=True)

Learner(..., opt_func=opt_func)

Or import fused optimizers independent of other fastxtend features.

from fastai.vision.all import *
from fastxtend.optimizer.all import *

Learner(..., opt_func=partial(Adam, foreach=True))

Note

adam(...) is a fastxtend convenience method equivalent to partial(Adam, ...). fastxtend adds lowercase convenience methods for all fastai optimizers.
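The relationship between the lowercase and capitalized names can be sketched with `functools.partial`: hyperparameters are bound now, and the Learner supplies the parameter groups and learning rate later. The `Adam` below is a hypothetical stand-in that only records its arguments, and this `adam` wrapper, which forwards only the keyword arguments you pass, is an illustration rather than fastxtend's source; the real partial functions carry their own documented defaults.

```python
from functools import partial


# Hypothetical stand-in for an optimizer constructor; it records its arguments
# instead of building an optimizer so the two call styles can be compared.
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01,
         decouple_wd=True, foreach=False, jit=False):
    return dict(params=params, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps,
                wd=wd, decouple_wd=decouple_wd, foreach=foreach, jit=jit)


# Hypothetical lowercase wrapper: binds only the keyword arguments given here
def adam(**kwargs):
    return partial(Adam, **kwargs)


opt_func_a = adam(foreach=True)
opt_func_b = partial(Adam, foreach=True)

# The Learner later calls opt_func with the parameter groups and a learning rate
print(opt_func_a(["p0", "p1"], lr=1e-3) == opt_func_b(["p0", "p1"], lr=1e-3))  # True
```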

SGD Optimizer

Stochastic gradient descent, optionally with momentum.

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).
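To make the difference between the two modes concrete, here is a minimal sketch of an SGD step on a single scalar weight. It assumes weight decay is applied before the momentum update and that momentum is undampened; `sgd_step` and `run` are hypothetical helpers, not fastxtend functions. Without momentum the two modes give the same result; with momentum, L2 regularization feeds the decay term into the momentum buffer, so the two diverge from the second step on.

```python
def sgd_step(p, g, state, lr, mom=0.0, wd=0.0, decouple_wd=True):
    """One SGD step on a scalar weight `p` with gradient `g` (sketch)."""
    if wd != 0:
        if decouple_wd:
            p = p * (1 - lr * wd)   # true weight decay: shrink the weight directly
        else:
            g = g + wd * p          # L2 regularization: add the decay to the gradient
    if mom != 0:
        # Undampened momentum buffer, kept in `state` between steps
        state["grad_avg"] = state.get("grad_avg", 0.0) * mom + g
        g = state["grad_avg"]
    return p - lr * g


def run(decouple_wd, mom, steps=2, p=1.0, g=0.5, lr=0.1, wd=0.1):
    """Apply `steps` updates with a constant gradient and return the final weight."""
    state = {}
    for _ in range(steps):
        p = sgd_step(p, g, state, lr, mom, wd, decouple_wd)
    return p


print(run(decouple_wd=True, mom=0.0), run(decouple_wd=False, mom=0.0))  # same up to rounding
print(run(decouple_wd=True, mom=0.9), run(decouple_wd=False, mom=0.9))  # differ
```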



SGD

 SGD (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
      lr:float, mom:float=0.0, wd:float=0.0, decouple_wd:bool=True,
      foreach:bool=False, jit:bool=False)

A fastai SGD/SGDW optimizer with fused ForEach and TorchScript implementations

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.0 Gradient moving average (β1) coefficient
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (SGDW) or L2 regularization (SGD)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | SGDForEachOptimizer | JitOptimizer


sgd

 sgd (mom:float=0.0, wd:float=0.0, decouple_wd:bool=True,
      foreach:bool=False, jit:bool=False)

Partial function for the SGD/SGDW optimizer with fused ForEach and TorchScript implementations

Type Default Details
mom float 0.0 Gradient moving average (β1) coefficient
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (SGDW) or L2 regularization (SGD)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | SGDForEachOptimizer | JitOptimizer

RMSProp Optimizer

RMSProp was introduced by Geoffrey Hinton in his course. What is named sqr_mom here is the alpha in the course.

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).

Note

The order of the mom and sqr_mom hyperparameters has been swapped from fastai to follow the order of all the other fastai and fastxtend optimizers.
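A minimal sketch of the RMSProp update on one scalar weight, with `sqr_mom` in the role of the course's alpha and `mom` listed before it, matching the argument order above. `rmsprop_step` is hypothetical, and treating momentum as an undampened average of the raw gradient is an assumption of this sketch, not a statement about fastxtend's kernels.

```python
import math


def rmsprop_step(p, g, state, lr, mom=0.0, sqr_mom=0.99, eps=1e-8):
    """One RMSProp step on a scalar weight `p` with gradient `g` (sketch)."""
    # Moving average of squared gradients (the course's alpha is sqr_mom here)
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * g * g
    if mom != 0:
        # Optional momentum: undampened moving average of the raw gradient
        state["grad_avg"] = mom * state.get("grad_avg", 0.0) + g
        g = state["grad_avg"]
    # Divide by the root of the squared-gradient average to normalize the step size
    return p - lr * g / (math.sqrt(state["sqr_avg"]) + eps)


state = {}
p = 1.0
for g in (2.0, 2.0, 2.0):
    p = rmsprop_step(p, g, state, lr=0.01)
    print(round(p, 4))
```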



RMSProp

 RMSProp (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
          lr:float, mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08,
          wd:float=0.0, decouple_wd:bool=True, jit:bool=False)

A fastai RMSProp/RMSPropW optimizer with a fused TorchScript implementation

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.0 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (RMSPropW) or L2 regularization (RMSProp)
jit bool False Use fused TorchScript implementation
Returns Optimizer | JitOptimizer


rmsprop

 rmsprop (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08,
          wd:float=0.0, decouple_wd:bool=True, jit:bool=False)

Partial function for the RMSProp/RMSPropW optimizer with a fused TorchScript implementation

Type Default Details
mom float 0.0 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (RMSPropW) or L2 regularization (RMSProp)
jit bool False Use fused TorchScript implementation
Returns Optimizer | JitOptimizer

Adam Optimizer

Adam was introduced by Diederik P. Kingma and Jimmy Ba in Adam: A Method for Stochastic Optimization. For consistency across optimizers, fastai renamed beta1 and beta2 in the paper to mom and sqr_mom. Note that the defaults also differ from the paper (0.99 for sqr_mom or beta2, 1e-5 for eps). Those values seem to be better from experimentation in a wide range of situations.

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).

Note

Don’t forget that eps is a hyperparameter you can change. Some models won’t train without a very high eps like 0.1 (intuitively, the higher eps is, the closer Adam is to normal SGD). The usual default of 1e-8 is often too extreme, in the sense that Adam doesn’t manage to get results as good as with SGD.
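A minimal sketch of one bias-corrected Adam step on a scalar weight illustrates the eps note: while the square root of the debiased second moment dominates the denominator, the first step is about lr in size; once eps dominates, the step shrinks toward (lr / eps) times the first moment, an SGD-with-momentum-like update. `adam_step` is hypothetical; it uses the Adam defaults documented below and places eps outside the square root as in the Adam paper.

```python
import math


def adam_step(p, g, state, lr, mom=0.9, sqr_mom=0.99, eps=1e-5):
    """One bias-corrected Adam step on a scalar weight `p` with gradient `g` (sketch)."""
    state["step"] = state.get("step", 0) + 1
    t = state["step"]
    state["grad_avg"] = mom * state.get("grad_avg", 0.0) + (1 - mom) * g
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * g * g
    m_hat = state["grad_avg"] / (1 - mom ** t)      # debiased first moment
    v_hat = state["sqr_avg"] / (1 - sqr_mom ** t)   # debiased second moment
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


# Small eps: on the first step sqrt(v_hat) = |g|, so the step is about lr
first_step = 1.0 - adam_step(1.0, 0.3, {}, lr=0.1, eps=1e-5)
# Very large eps dominates the denominator: an SGD-like step of about (lr / eps) * g
large_eps_step = 1.0 - adam_step(1.0, 0.3, {}, lr=0.1, eps=100.0)
print(first_step, large_eps_step)
```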



Adam

 Adam (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
       lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05,
       wd:float=0.01, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)

A fastai Adam/AdamW optimizer with fused ForEach and TorchScript implementations

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.9 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-05 Added for numerical stability
wd float 0.01 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (AdamW) or L2 regularization (Adam)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | AdamForEachOptimizer | JitOptimizer


adam

 adam (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0,
       decouple_wd:bool=True, foreach:bool=False, jit:bool=False)

Partial function for the Adam/AdamW optimizer with fused ForEach and TorchScript implementations

Type Default Details
mom float 0.0 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (AdamW) or L2 regularization (Adam)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | AdamForEachOptimizer | JitOptimizer

RAdam Optimizer

RAdam (for rectified Adam) was introduced by Liu et al. in On the Variance of the Adaptive Learning Rate and Beyond to slightly modify the Adam optimizer to be more stable at the beginning of training (and thus not require a long warmup). They use an estimate of the variance of the moving average of the squared gradients (the term in the denominator of traditional Adam) and rescale this moving average by this term before performing the update.
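The rectification can be sketched as the paper's variance-rectification term r_t, computed from rho_t, the length of the approximated simple moving average. This is a minimal sketch: `radam_rectifier` is hypothetical, and it assumes rho_t > 5 as the tractability threshold (the paper states the condition as rho_t > 4, and implementations differ). While the term is undefined early in training, RAdam takes an un-adapted momentum step instead, which is how it avoids the unstable first steps that a warmup would otherwise smooth over.

```python
import math


def radam_rectifier(step, sqr_mom=0.99, threshold=5.0):
    """RAdam's variance rectification term r_t, or None while rho_t <= threshold (sketch)."""
    rho_inf = 2.0 / (1.0 - sqr_mom) - 1.0      # maximum length of the approximated SMA
    beta_t = sqr_mom ** step
    rho_t = rho_inf - 2.0 * step * beta_t / (1.0 - beta_t)
    if rho_t <= threshold:
        # Variance of the adaptive term is intractable: use a plain momentum step instead
        return None
    return math.sqrt((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
                     / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))


for step in (1, 5, 10, 100, 1000):
    print(step, radam_rectifier(step))
```

The term grows toward 1 as training proceeds, so RAdam gradually hands control back to the full adaptive Adam step.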

The native fastai implementation also incorporates SAdam; set beta to enable this (definition same as in the paper).

Note

fastxtend’s ForEach and TorchScript implementations do not support beta or SAdam.

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).



RAdam

 RAdam (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
        lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05,
        wd:float=0.0, beta:float=0.0, decouple_wd:bool=True,
        foreach:bool=False, jit:bool=False)

A fastai RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.9 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-05 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
beta float 0.0 Set to enable SAdam with native fastai RAdam
decouple_wd bool True Apply true weight decay (RAdamW) or L2 regularization (RAdam)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | RAdamForEachOptimizer | JitOptimizer


radam

 radam (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0,
        beta:float=0.0, decouple_wd:bool=True, foreach:bool=False,
        jit:bool=False)

Partial function for the RAdam/RAdamW optimizer with fused ForEach and TorchScript implementations

Type Default Details
mom float 0.0 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
beta float 0.0 Set to enable SAdam with native fastai RAdam
decouple_wd bool True Apply true weight decay (RAdamW) or L2 regularization (RAdam)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | RAdamForEachOptimizer | JitOptimizer

QHAdam Optimizer

QHAdam (for Quasi-Hyperbolic Adam) was introduced by Ma & Yarats in Quasi-Hyperbolic Momentum and Adam for Deep Learning as a “computationally cheap, intuitive to interpret, and simple to implement” optimizer. Additional code can be found in their qhoptim repo. QHAdam is based on QH-Momentum (QHM), which introduces the immediate discount factor nu, encapsulating plain SGD (nu = 0) and momentum (nu = 1). QHM is defined below, where g_{t+1} is the updated momentum buffer. QHM can be interpreted as a nu-weighted average of the momentum update step and the plain SGD update step.

$$\theta_{t+1} \leftarrow \theta_t - \mathrm{lr} \cdot \left[ (1 - \nu) \cdot \nabla L_t(\theta_t) + \nu \cdot g_{t+1} \right]$$

QHAdam takes the concept behind QHM above and applies it to Adam, replacing both of Adam’s moment estimators with quasi-hyperbolic terms.

The paper’s suggested default parameters are mom = 0.999, sqr_mom = 0.999, nu_1 = 0.7, and nu_2 = 1.0. When training is not stable, setting nu_2 < 1 may improve stability by imposing a tighter step size bound. Note that QHAdam recovers Adam when nu_1 = nu_2 = 1.0. QHAdam recovers RMSProp (Hinton et al., 2012) when nu_1 = 0 and nu_2 = 1, and NAdam (Dozat, 2016) when nu_1 = mom and nu_2 = 1.
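A minimal sketch of one QHAdam step on a scalar weight, following the update rule in the QHAdam paper: each of Adam's debiased moment estimates is replaced by a nu-weighted mix of the current raw statistic and its moving average. `qhadam_step` is hypothetical and defaults to the paper's suggested hyperparameters; setting nu_1 = nu_2 = 1 reduces the numerator and denominator to Adam's, consistent with the recovery claim above.

```python
import math


def qhadam_step(p, g, state, lr, mom=0.999, sqr_mom=0.999, nu_1=0.7, nu_2=1.0, eps=1e-8):
    """One QHAdam step on a scalar weight `p` with gradient `g` (sketch)."""
    state["step"] = state.get("step", 0) + 1
    t = state["step"]
    state["grad_avg"] = mom * state.get("grad_avg", 0.0) + (1 - mom) * g
    state["sqr_avg"] = sqr_mom * state.get("sqr_avg", 0.0) + (1 - sqr_mom) * g * g
    m_hat = state["grad_avg"] / (1 - mom ** t)       # debiased first moment
    v_hat = state["sqr_avg"] / (1 - sqr_mom ** t)    # debiased second moment
    # Quasi-hyperbolic terms: nu-weighted mixes of the raw statistic and its average
    numerator = (1 - nu_1) * g + nu_1 * m_hat
    denominator = math.sqrt((1 - nu_2) * g * g + nu_2 * v_hat) + eps
    return p - lr * numerator / denominator


state, p = {}, 1.0
for g in (0.3, -0.1, 0.2):
    p = qhadam_step(p, g, state, lr=0.01)
    print(p)
```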

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).



QHAdam

 QHAdam (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
         lr:float, mom:float=0.999, sqr_mom:float=0.999, nu_1:float=0.7,
         nu_2:float=1.0, eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True,
         jit:bool=False)

A fastai QHAdam/QHAdamW optimizer with a fused TorchScript implementation

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.999 Gradient moving average (β1) coefficient
sqr_mom float 0.999 Gradient squared moving average (β2) coefficient
nu_1 float 0.7 QH immediate discount factor
nu_2 float 1.0 QH momentum discount factor
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (QHAdamW) or L2 regularization (QHAdam)
jit bool False Use fused TorchScript implementation
Returns Optimizer | JitOptimizer


qhadam

 qhadam (mom:float=0.0, sqr_mom:float=0.99, nu_1:float=0.7,
         nu_2:float=1.0, eps:float=1e-08, wd:float=0.0,
         decouple_wd:bool=True, jit:bool=False)

Partial function for the QHAdam/QHAdamW optimizer with a fused TorchScript implementation

Type Default Details
mom float 0.0 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
nu_1 float 0.7 QH immediate discount factor
nu_2 float 1.0 QH momentum discount factor
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay (QHAdamW) or L2 regularization (QHAdam)
jit bool False Use fused TorchScript implementation
Returns Optimizer | JitOptimizer

LARS/LARC Optimizer

The LARS optimizer was first introduced in Large Batch Training of Convolutional Networks, then refined in its LARC variant (the original LARS corresponds to clip=False). A learning rate is computed for each individual layer with a certain trust_coeff, then clipped so that it never exceeds lr.
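A minimal sketch of the layer-wise rate, using the LARS paper's local learning rate, lr × trust_coeff × ||w|| / (||g|| + wd × ||w||), on a layer given as a flat list of floats. `larc_layer_lr` and `l2_norm` are hypothetical helpers, and adding eps to the denominator this way is an assumption of the sketch.

```python
import math


def l2_norm(xs):
    return math.sqrt(sum(x * x for x in xs))


def larc_layer_lr(param, grad, lr, trust_coeff=0.02, wd=0.0, eps=1e-8, clip=True):
    """Layer-wise learning rate for one layer's weights and gradients (sketch)."""
    p_norm, g_norm = l2_norm(param), l2_norm(grad)
    # Trust-ratio-scaled rate: large weights with small gradients get a larger rate
    local_lr = lr * trust_coeff * p_norm / (g_norm + wd * p_norm + eps)
    # LARC (clip=True) caps the rate at the global lr; LARS (clip=False) does not
    return min(lr, local_lr) if clip else local_lr


weights, grads = [6.0, 8.0], [0.06, 0.08]                 # ||w|| = 10, ||g|| = 0.1
print(larc_layer_lr(weights, grads, lr=1.0))              # LARC: capped at lr, 1.0
print(larc_layer_lr(weights, grads, lr=1.0, clip=False))  # LARS: about 2.0
```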

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).



Larc

 Larc (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
       lr:float, mom:float=0.9, clip:bool=True, trust_coeff:float=0.02,
       eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True, jit:bool=False)

A fastai LARC/LARS optimizer with a fused TorchScript implementation

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.9 Gradient moving average (β1) coefficient
clip bool True LARC if clip=True, LARS if clip=False
trust_coeff float 0.02 Trust coefficient for calculating layerwise LR
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay or L2 regularization
jit bool False Use fused TorchScript implementation
Returns Optimizer | JitOptimizer


larc

 larc (mom:float=0.0, clip:bool=True, trust_coeff:float=0.02,
       eps:float=1e-08, wd:float=0.0, decouple_wd:bool=True,
       jit:bool=False)

Partial function for the LARC/LARS optimizer with a fused TorchScript implementation

Type Default Details
mom float 0.0 Gradient moving average (β1) coefficient
clip bool True LARC if clip=True, LARS if clip=False
trust_coeff float 0.02 Trust coefficient for calculating layerwise LR
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay or L2 regularization
jit bool False Use fused TorchScript implementation
Returns Optimizer | JitOptimizer

LAMB Optimizer

LAMB was introduced in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes. Intuitively, it’s LARC applied to Adam. As in Adam, beta1 and beta2 from the paper are renamed to mom and sqr_mom. Note that the defaults also differ from the paper (0.99 for sqr_mom or beta2, 1e-5 for eps). Those values seem to be better from experimentation in a wide range of situations.
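Since LAMB is LARC applied to Adam, its per-layer step can be sketched as an Adam-style update rescaled by the trust ratio ||w|| / ||u||, where u is that layer's update. This is a minimal sketch after the LAMB paper, not fastxtend's kernel: `lamb_layer_step` is hypothetical, the Adam-style update is taken as an input, weight decay is omitted, and falling back to a ratio of 1 when either norm is zero is a common implementation choice assumed here.

```python
import math


def l2_norm(xs):
    return math.sqrt(sum(x * x for x in xs))


def lamb_layer_step(param, update, lr):
    """Apply an Adam-style `update` to one layer, scaled by the trust ratio (sketch)."""
    w_norm, u_norm = l2_norm(param), l2_norm(update)
    # Trust ratio ||w|| / ||u||; fall back to 1 if either norm is zero
    trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    return [w - lr * trust_ratio * u for w, u in zip(param, update)]


layer = [3.0, 4.0]          # ||w|| = 5
adam_update = [1.0, 0.0]    # ||u|| = 1, e.g. m_hat / (sqrt(v_hat) + eps)
new_layer = lamb_layer_step(layer, adam_update, lr=0.1)
# The weight change has norm lr * ||w||, whatever the scale of the raw update
print(new_layer)  # [2.5, 4.0]
```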

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).



Lamb

 Lamb (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
       lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05,
       wd:float=0.01, decouple_wd:bool=True, foreach:bool=False, jit:bool=False)

A fastai LAMB optimizer with fused ForEach and TorchScript implementations

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.9 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-05 Added for numerical stability
wd float 0.01 Optional weight decay (true or L2). Paper default, fastai’s is 0.
decouple_wd bool True Apply true weight decay or L2 regularization
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | LambForEachOptimizer | JitOptimizer


lamb

 lamb (mom:float=0.0, sqr_mom:float=0.99, eps:float=1e-08, wd:float=0.0,
       decouple_wd:bool=True, foreach:bool=False, jit:bool=False)

Partial function for the LAMB optimizer with fused ForEach and TorchScript implementations

Type Default Details
mom float 0.0 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-08 Added for numerical stability
wd float 0.0 Optional weight decay (true or L2)
decouple_wd bool True Apply true weight decay or L2 regularization
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Optimizer | LambForEachOptimizer | JitOptimizer

Ranger Optimizer

Lookahead was introduced by Zhang et al. in Lookahead Optimizer: k steps forward, 1 step back. With Lookahead, the final weights (slow weights) are a moving average of the normal weights (fast weights). Every k steps, Lookahead modifies the current weights by a moving average of the fast weights (normal weights) with the slow weights (the copy of the weights from k steps ago). Those slow weights act as a stability mechanism.
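A minimal sketch of the Lookahead loop on a single scalar weight, with the inner optimizer passed in as a function. `lookahead` and `gd` are hypothetical helpers that ignore per-parameter optimizer state; k and alpha default to the Ranger values documented below.

```python
def lookahead(inner_step, weight, steps, k=6, alpha=0.5):
    """Run `steps` inner-optimizer updates on a scalar weight with Lookahead (sketch)."""
    fast = slow = weight
    for i in range(1, steps + 1):
        fast = inner_step(fast)                   # normal optimizer step on the fast weights
        if i % k == 0:
            slow = slow + alpha * (fast - slow)   # move the slow weights toward the fast ones
            fast = slow                           # restart the fast weights from the slow ones
    return fast, slow


# Inner optimizer: plain gradient descent on f(w) = w**2, whose gradient is 2w
def gd(w):
    return w - 0.1 * 2 * w


print(lookahead(gd, 1.0, steps=12))
```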

Ranger was introduced by Less Wright in New Deep Learning Optimizer, Ranger: Synergistic combination of RAdam + Lookahead for the best of both. It combines RAdam and Lookahead together in one optimizer and reduces the need for hyperparameter tuning due to a combination of RAdam’s warmup heuristic and Lookahead’s interpolation of parameter weights.

Important

While fastai’s Lookahead can be applied to any optimizer, fastxtend’s JitLookahead requires a custom-written TorchScript callback, and a ForEachOptimizer requires a custom Lookahead optimizer step. Currently, ranger (Lookahead with RAdam) is the only TorchScript and ForEach optimizer with Lookahead support.

Ranger performs best on vision tasks when paired with the fit_flat_cos or fit_flat_varied schedulers.

Warning

Ranger is the only fastxtend optimizer that is not backward compatible: fastxtend’s Ranger is equivalent to fastai’s ranger, while fastxtend’s lowercase ranger is a partial function that returns Ranger. Most fastai code should be unaffected by this change.

Optional weight decay of wd is applied as true weight decay (decay the weights directly) if decouple_wd=True, else as L2 regularization (add the decay to the gradients).



Ranger

 Ranger (params:Union[torch.Tensor,Iterable[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple],
         lr:float, mom:float=0.95, sqr_mom:float=0.99, eps:float=1e-06,
         wd:float=0.01, k:int=6, alpha:float=0.5, decouple_wd:bool=True,
         foreach:bool=False, jit:bool=False)

Convenience method for Lookahead with RAdam fused ForEach and TorchScript implementations

Type Default Details
params listified[Tensor] Model parameters or parameter groups
lr float Default learning rate
mom float 0.95 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-06 Added for numerical stability
wd float 0.01 Optional weight decay (true or L2)
k int 6 How often to conduct Lookahead step
alpha float 0.5 Slow weight moving average coefficient
decouple_wd bool True Apply true weight decay (RAdamW) or L2 regularization (RAdam)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Lookahead | RangerForEachOptimizer | JitLookahead


ranger

 ranger (mom:float=0.95, sqr_mom:float=0.99, eps:float=1e-06,
         wd:float=0.01, k:int=6, alpha:float=0.5, decouple_wd:bool=True,
         foreach:bool=False, jit:bool=False)

Partial function of the convenience method for Lookahead with RAdam fused ForEach and TorchScript implementations

Type Default Details
mom float 0.95 Gradient moving average (β1) coefficient
sqr_mom float 0.99 Gradient squared moving average (β2) coefficient
eps float 1e-06 Added for numerical stability
wd float 0.01 Optional weight decay (true or L2)
k int 6 How often to conduct Lookahead step
alpha float 0.5 Slow weight moving average coefficient
decouple_wd bool True Apply true weight decay (RAdamW) or L2 regularization (RAdam)
foreach bool False Use fused ForEach implementation
jit bool False Use fused TorchScript implementation
Returns Lookahead | RangerForEachOptimizer | JitLookahead

Footnotes

  1. All optimizers were benchmarked on a GeForce 3080 Ti using PyTorch 1.12.1, Mixed Precision, Channels Last (except ViT and DeBERTa), and fastxtend’s Simple Profiler Callback. Results may differ with other optimizers, models, hardware, and across benchmarking runs. Speedup is calculated from the total time spent on the optimization step.