ForEach Optimizers
fastxtend ForEach optimizers are adapted from the PyTorch ForEach `_multi_tensor` implementations. Depending on the model, they are 21 to 293 percent faster than fastai's native optimizers.
The primary difference between PyTorch's ForEach implementations and fastxtend's is that fastxtend's ForEach optimizers apply per-parameter weight decay in one optimizer step, instead of requiring a separate weight decay parameter group and a non-weight decay parameter group. This also allows seamless support for fastai's discriminative learning rates.
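As a rough illustration (not fastxtend's actual code), decoupled per-parameter weight decay can be fused into the same horizontal step. The function name `sgd_wd_foreach_step` and its signature are assumptions for this sketch:

```python
# Sketch only: a fused SGD step with decoupled, per-parameter weight decay.
# `sgd_wd_foreach_step` is a hypothetical name, not a fastxtend function.
import torch
from torch import Tensor

def sgd_wd_foreach_step(params:list[Tensor], grads:list[Tensor],
                        lr:float, wds:list[float]):
    # decoupled (true) weight decay: p *= 1 - lr*wd, applied per parameter
    torch._foreach_mul_(params, [1 - lr*wd for wd in wds])
    # then the horizontally fused gradient step: p -= lr * grad
    torch._foreach_add_(params, grads, alpha=-lr)

params = [torch.ones(2), torch.ones(2)]
grads = [torch.ones(2), torch.ones(2)]
sgd_wd_foreach_step(params, grads, lr=0.1, wds=[0.0, 0.5])
```

Because the weight decay values are passed as a scalar list, each parameter can carry its own decay (and, by extension, its own learning-rate-scaled decay for discriminative learning rates) without splitting parameters into separate groups.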
Unlike fastai optimizers, which are made of multiple stepper callbacks and share one `Optimizer`, ForEach optimizers require an optimizer-specific `ForEachOptimizer` implementation.
Currently `SGD`, `Adam`, `RAdam`, `Lamb`, and `Ranger` have ForEach implementations.
:::{.callout-important}
ForEach optimizers have only been tested on PyTorch 1.12+ and are not guaranteed to work on older versions.
:::
ForEach optimizers are faster due to horizontal fusion across multiple parameters. Using `xresnet50` and the simplest form of SGD as an example, a ForEach optimizer constructs a list of all 167 `params` and their `grads` before performing one horizontally fused step.
```python
def simple_sgd_foreach(params:list[Tensor], grads:list[Tensor], lr:float):
    torch._foreach_add_(params, grads, alpha=-lr)
```
In contrast, a standard PyTorch optimizer would call the simple SGD step 167 times, once per parameter:
```python
def simple_sgd_standard(param:Tensor, lr:float):
    param.add_(param.grad, alpha=-lr)
```
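A quick sanity check (not from the fastxtend docs) that the fused and per-parameter versions produce identical updates:

```python
import torch

lr = 0.1
params_fused = [torch.randn(3) for _ in range(4)]
params_loop = [p.clone() for p in params_fused]
grads = [torch.randn(3) for _ in range(4)]

# horizontally fused step over all parameters at once
torch._foreach_add_(params_fused, grads, alpha=-lr)

# standard step, one parameter at a time
for p, g in zip(params_loop, grads):
    p.add_(g, alpha=-lr)

assert all(torch.allclose(a, b) for a, b in zip(params_fused, params_loop))
```

The fused version issues far fewer kernel launches, which is where the speedup comes from.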
ForEach optimizers are tested for equality with fastai optimizers over 25 steps using nbdev's GitHub CI.
ForEachOptimizer
ForEachOptimizer (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=True, **defaults)
Base ForEach optimizer class, updating `params` with `opt_step` instead of `Optimizer.cbs`

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults | | | |
SGDForEachOptimizer
SGDForEachOptimizer (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=True, **defaults)
A `ForEachOptimizer` with a modified step for `sgd_foreach_step`

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults | | | |
AdamForEachOptimizer
AdamForEachOptimizer (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=True, **defaults)
A `ForEachOptimizer` with a modified step for `adam_foreach_step`

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults | | | |
RAdamForEachOptimizer
RAdamForEachOptimizer (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=True, **defaults)
A `ForEachOptimizer` with a modified step for `radam_foreach_step`

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults | | | |
LambForEachOptimizer
LambForEachOptimizer (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=True, **defaults)
A `ForEachOptimizer` with a modified step for `lamb_foreach_step`

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults | | | |
RangerForEachOptimizer
RangerForEachOptimizer (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=True, **defaults)
A `ForEachOptimizer` with a modified `LookAhead` step for `ranger_foreach_step`

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults | | | |