StableAdam Optimizer

With fastai native and fused ForEach implementations

StableAdamW was introduced by Wortsman et al. in Stable and low-precision training for large-scale vision-language models. StableAdamW is an AdamW-Adafactor hybrid, porting Adafactor’s update clipping into AdamW as a per-parameter learning rate modification. StableAdamW’s update clipping outperforms gradient clipping on downstream tasks while avoiding training instability.
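
The per-parameter clipping can be sketched as a simplified single-tensor Adam step. This is a minimal illustration following the paper’s formulation: the function name `stableadamw_step` is hypothetical, and details such as epsilon placement and bias correction may differ from the fastxtend implementation.

```python
import torch

def stableadamw_step(p, grad, exp_avg, exp_avg_sq, *, step, lr,
                     beta1=0.9, beta2=0.99, eps=1e-5, wd=0.01):
    "Simplified single-tensor StableAdamW step (illustration only)"
    # Standard Adam first and second moment updates
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias-corrected second moment
    v_hat = exp_avg_sq / (1 - beta2**step)

    # Adafactor-style update clipping: measure the per-tensor RMS of the
    # raw update and shrink the learning rate whenever it exceeds 1
    rms = (grad.pow(2) / v_hat.clamp_min(eps**2)).mean().sqrt()
    lr_t = lr / max(1.0, rms.item())

    # Decoupled (AdamW-style) weight decay using the clipped learning rate
    p.mul_(1 - lr_t * wd)

    # Adam parameter update with the clipped learning rate
    m_hat = exp_avg / (1 - beta1**step)
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr_t)
```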

Note

In addition to StableAdamW’s decoupled weight decay, this implementation also supports Adam-style L2 weight decay (StableAdam), which is not included in the paper.

In addition to a fastai native implementation, StableAdam has a fused ForEach implementation. See the Fused Optimizer documentation for more details.


source

StableAdam

 StableAdam (params:Union[torch.Tensor, Iterable[torch.Tensor],
             MutableSequence[torch.Tensor], fastcore.foundation.L,
             fastcore.basics.fastuple], lr:float, mom:float=0.9,
             sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.01,
             decouple_wd:bool=True, foreach:bool=False)

A fastai StableAdam/StableAdamW optimizer with a fused ForEach implementation

|  | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] |  | Model parameters or parameter groups |
| lr | float |  | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (StableAdamW) or L2 regularization (StableAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| **Returns** | **Optimizer \| StableAdamForEachOptimizer** |  |  |
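
A minimal construction example is shown below. This is a hedged sketch: the import path is an assumption and may differ from how fastxtend actually exposes the optimizer.

```python
import torch
from torch import nn
from torch.nn import functional as F
# Assumed import path; adjust to match how fastxtend exposes its optimizers
from fastxtend.optimizer.stableadam import StableAdam

model = nn.Linear(4, 2)
xb, yb = torch.randn(16, 4), torch.randn(16, 2)

# StableAdamW (decouple_wd=True is the default) with the fused ForEach step
opt = StableAdam(list(model.parameters()), lr=1e-3, sqr_mom=0.99, wd=1e-2, foreach=True)

loss = F.mse_loss(model(xb), yb)
loss.backward()
opt.step()
opt.zero_grad()
```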

source

stableadam

 stableadam (mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05,
             wd:float=0.01, decouple_wd:bool=True, foreach:bool=False)

Partial function for the StableAdam/StableAdamW optimizer with a fused ForEach implementation

|  | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (StableAdamW) or L2 regularization (StableAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| **Returns** | **Optimizer \| StableAdamForEachOptimizer** |  |  |
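
A usage sketch with a fastai Learner is shown below. This is hedged: the wildcard import path is an assumption; the key point is that stableadam is called with optimizer hyperparameters and the result is passed as opt_func.

```python
from fastai.vision.all import *
# Assumed import; adjust to however fastxtend's optimizers are exposed
from fastxtend.optimizer.all import *

path = untar_data(URLs.MNIST_SAMPLE)
dls = ImageDataLoaders.from_folder(path)

# stableadam(...) returns an opt_func partial; sqr_mom=0.99 follows the
# hyperparameter note below, foreach=True selects the fused implementation
learn = vision_learner(dls, resnet18, metrics=accuracy,
                       opt_func=stableadam(sqr_mom=0.99, foreach=True))
learn.fit_one_cycle(1, lr_max=1e-3)
```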

Hyperparameters

Hyperparameter notes from Wortsman et al.:

StableAdamW’s hyperparameters should match AdamW’s, with β2 (sqr_mom in fastai optimizers) set to a higher value, such as 0.99, for best performance.