StableAdam Optimizer
StableAdamW was introduced by Wortsman et al in *Stable and low-precision training for large-scale vision-language models*. StableAdamW is an AdamW-Adafactor hybrid, porting Adafactor’s update clipping into AdamW as a per-parameter learning rate modification. StableAdamW’s update clipping outperforms gradient clipping on downstream tasks while avoiding model training instability.
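Concretely, the clipping works per parameter tensor: the RMS of g²/max(v̂, eps²) is computed, and the learning rate is divided by max(1, RMS). Below is a minimal PyTorch sketch of a single StableAdamW update following the paper’s algorithm. It is illustrative only, not fastxtend’s actual implementation, which is built on the fastai optimizer framework.

```python
import torch

def stableadamw_step(p, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
                     beta1=0.9, beta2=0.99, eps=1e-5, wd=0.01):
    "One StableAdamW update for a single parameter tensor (simplified sketch)."
    # Standard Adam moment updates
    exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

    # Bias-corrected moments
    exp_avg_hat = exp_avg / (1 - beta1**step)
    exp_avg_sq_hat = exp_avg_sq / (1 - beta2**step)

    # Update clipping: per-tensor RMS of g²/max(v̂, eps²) rescales the learning rate
    rms = grad.pow(2).div(exp_avg_sq_hat.clamp(min=eps**2)).mean().sqrt().item()
    lr_t = lr / max(1.0, rms)

    # Decoupled (true) weight decay, scaled by the clipped learning rate
    p.mul_(1 - lr_t*wd)

    # AdamW step with the clipped learning rate
    p.addcdiv_(exp_avg_hat, exp_avg_sq_hat.sqrt().add_(eps), value=-lr_t)
```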
This implementation of StableAdam also includes Adam’s L2 weight decay, which is not covered in the paper.
In addition to a fastai-native implementation, StableAdam has a fused ForEach implementation. See the Fused Optimizer documentation for more details.
StableAdam
StableAdam (params:Union[torch.Tensor,Iterable[torch.Tensor],MutableSequence[torch.Tensor],fastcore.foundation.L,fastcore.basics.fastuple], lr:float, mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.01, decouple_wd:bool=True, foreach:bool=False)
A fastai StableAdam/StableAdamW optimizer with a fused ForEach implementation
| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters or parameter groups |
| lr | float | | Default learning rate |
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (StableAdamW) or L2 regularization (StableAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| **Returns** | **Optimizer \| StableAdamForEachOptimizer** | | |
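Since StableAdam follows the fastai opt_func signature, it can be instantiated directly on model parameters. A short sketch, assuming StableAdam has been imported from fastxtend and using a placeholder model:

```python
from functools import partial
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder model

# Instantiate directly on model parameters
opt = StableAdam(model.parameters(), lr=1e-3)

# Or hand it to a fastai Learner, modifying defaults with partial
# learn = Learner(dls, model, opt_func=partial(StableAdam, foreach=True))
```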
stableadam
stableadam (mom:float=0.9, sqr_mom:float=0.99, eps:float=1e-05, wd:float=0.01, decouple_wd:bool=True, foreach:bool=False)
Partial function for the StableAdam/StableAdamW optimizer with a fused ForEach implementation
| | Type | Default | Details |
|---|---|---|---|
| mom | float | 0.9 | Gradient moving average (β1) coefficient |
| sqr_mom | float | 0.99 | Gradient squared moving average (β2) coefficient |
| eps | float | 1e-05 | Added for numerical stability |
| wd | float | 0.01 | Optional weight decay (true or L2) |
| decouple_wd | bool | True | Apply true weight decay (StableAdamW) or L2 regularization (StableAdam) |
| foreach | bool | False | Use fused ForEach implementation |
| **Returns** | **Optimizer \| StableAdamForEachOptimizer** | | |
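When training with fastai, the partial function is typically passed to the Learner as opt_func. A minimal sketch, assuming fastxtend’s wildcard import exposes stableadam and that dls and model are an existing DataLoaders and model:

```python
from fastai.vision.all import *
from fastxtend.vision.all import *  # assumed import path

# dls and model are placeholders for your DataLoaders and model
learn = Learner(dls, model, opt_func=stableadam(foreach=True))
learn.fit_one_cycle(5, 3e-3)
```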
Hyperparameters
Hyperparameter notes from Wortsman et al:
StableAdamW hyperparameters should be the same as AdamW’s, with β2 (`sqr_mom` for fastai optimizers) set to a higher value such as 0.99 for best performance.
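In practice, this means existing AdamW hyperparameters can usually be carried over unchanged, raising `sqr_mom` if it was set below 0.99, e.g.:

```python
# reuse AdamW hyperparameters, with β2 (sqr_mom) raised to 0.99
opt_func = stableadam(mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01)
```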