ForEach Optimizers

Fused fastai optimizers using PyTorch ForEach methods for improved performance

fastxtend ForEach optimizers are adapted from the PyTorch ForEach _multi_tensor implementations. Depending on the model, they are 21 to 293 percent faster than fastai's native optimizers.

The primary difference between PyTorch's ForEach implementations and fastxtend's is that fastxtend's ForEach optimizers apply per-parameter weight decay in a single optimizer step, instead of requiring a separate weight decay parameter group and a non-weight decay parameter group. This also allows seamless support for fastai's discriminative learning rates.
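
To make the idea concrete, here is a minimal hedged sketch (a hypothetical helper, not fastxtend's actual implementation) of fusing per-parameter decoupled weight decay into a single SGD step, using a per-parameter flag instead of separate parameter groups:

```python
import torch

def sgd_decoupled_wd_foreach(params, grads, do_wd, lr, wd):
    # Hypothetical sketch: do_wd flags which parameters receive weight decay.
    # Decoupled (true) weight decay multiplies params by 1 - lr*wd, fused
    # across only the flagged parameters.
    wd_params = [p for p, d in zip(params, do_wd) if d]
    if wd != 0 and wd_params:
        torch._foreach_mul_(wd_params, 1 - lr * wd)
    # One horizontally fused gradient step over every parameter
    torch._foreach_add_(params, grads, alpha=-lr)
```

Because the decay is applied in place on a filtered view of the same parameter list, no second parameter group or second optimizer step is needed.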

Unlike fastai optimizers, which are built from multiple stepper callbacks and share a single Optimizer class, each ForEach optimizer requires its own optimizer-specific ForEachOptimizer implementation.

Currently, SGD, Adam, RAdam, Lamb, and Ranger have ForEach implementations.

Important

ForEach optimizers have only been tested on PyTorch 1.12 and are not guaranteed to work on older versions.

ForEach optimizers are faster due to horizontal fusion across multiple parameters. Using xresnet50 and the simplest form of SGD as an example, a ForEach optimizer would construct a list of all 167 params and their grads before performing one horizontally fused step.

```python
def simple_sgd_foreach(params:list[Tensor], grads:list[Tensor], lr:float):
    torch._foreach_add_(params, grads, alpha=-lr)
```

In contrast, a standard PyTorch optimizer would call the simple SGD step 167 times:

```python
def simple_sgd_standard(param:Tensor, lr:float):
    param.add_(param.grad, alpha=-lr)
```
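
The two approaches produce identical updates. A minimal runnable comparison (using plain tensors in place of model parameters and gradients) demonstrates this:

```python
import torch

torch.manual_seed(0)
# Two identical sets of "parameters" and their gradients
params_fused = [torch.randn(3) for _ in range(5)]
params_loop = [p.clone() for p in params_fused]
grads = [torch.randn(3) for _ in range(5)]
lr = 0.1

# Horizontally fused step: one call updates every parameter
torch._foreach_add_(params_fused, grads, alpha=-lr)

# Standard step: one call per parameter
for p, g in zip(params_loop, grads):
    p.add_(g, alpha=-lr)

assert all(torch.allclose(a, b) for a, b in zip(params_fused, params_loop))
```

The speedup comes entirely from replacing many small kernel launches with a few fused ones; the arithmetic is unchanged.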

ForEach optimizers are tested for equality with fastai optimizers over 25 optimizer steps using nbdev's GitHub CI.
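
A hedged sketch of this equality-testing approach, using torch.optim.SGD as a stand-in reference optimizer (the CI tests compare against fastai optimizers, which may not be installed here):

```python
import torch

torch.manual_seed(42)
# Identical starting weights for the fused and reference optimizers
w_foreach = [torch.randn(4) for _ in range(3)]
w_standard = [w.clone().requires_grad_() for w in w_foreach]
lr = 0.01
opt = torch.optim.SGD(w_standard, lr=lr)

for step in range(25):
    # Shared random gradients for both optimizers each step
    grads = [torch.randn_like(w) for w in w_foreach]
    torch._foreach_add_(w_foreach, grads, alpha=-lr)  # fused step
    for w, g in zip(w_standard, grads):
        w.grad = g
    opt.step()
    opt.zero_grad()

# After 25 steps, both weight sets should match
assert all(torch.allclose(a, b) for a, b in zip(w_foreach, w_standard))
```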


source

ForEachOptimizer

 ForEachOptimizer (params:listified[Tensor], opt_step:Callable,
                   decouple_wd:bool=True, **defaults)

Base foreach optimizer class, updating params with opt_step instead of Optimizer.cbs

|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters |
| opt_step | Callable |  | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults |  |  |  |

source

SGDForEachOptimizer

 SGDForEachOptimizer (params:listified[Tensor], opt_step:Callable,
                      decouple_wd:bool=True, **defaults)

A ForEachOptimizer with a modified step for sgd_foreach_step

|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters |
| opt_step | Callable |  | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults |  |  |  |

source

AdamForEachOptimizer

 AdamForEachOptimizer (params:listified[Tensor], opt_step:Callable,
                       decouple_wd:bool=True, **defaults)

A ForEachOptimizer with a modified step for adam_foreach_step

|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters |
| opt_step | Callable |  | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults |  |  |  |

source

RAdamForEachOptimizer

 RAdamForEachOptimizer (params:listified[Tensor], opt_step:Callable,
                        decouple_wd:bool=True, **defaults)

A ForEachOptimizer with a modified step for radam_foreach_step

|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters |
| opt_step | Callable |  | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults |  |  |  |

source

LambForEachOptimizer

 LambForEachOptimizer (params:listified[Tensor], opt_step:Callable,
                       decouple_wd:bool=True, **defaults)

A ForEachOptimizer with a modified step for lamb_foreach_step

|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters |
| opt_step | Callable |  | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults |  |  |  |

source

RangerForEachOptimizer

 RangerForEachOptimizer (params:listified[Tensor], opt_step:Callable,
                         decouple_wd:bool=True, **defaults)

A ForEachOptimizer with a modified LookAhead step for ranger_foreach_step

|  | Type | Default | Details |
|---|---|---|---|
| params | listified[Tensor] |  | Model parameters |
| opt_step | Callable |  | ForEachOptimizer optimizer step |
| decouple_wd | bool | True | Use true weight decay or L2 regularization, if applicable |
| defaults |  |  |  |