TorchScript Optimizers

Fused fastai optimizers compiled with TorchScript for improved performance

fastxtend TorchScript optimizers are adapted from fastai optimizers and modified so they can be compiled with TorchScript. They are 10 to 137 percent faster than fastai's native optimizers, depending on the model and optimizer, with complex optimizers like QHAdam receiving the largest performance increase.

Unlike fastai optimizers, which are made of multiple stepper callbacks, TorchScript optimizers require a per-optimizer step so TorchScript can fuse the operation into as few CUDA calls as possible. All fastai optimizers have TorchScript implementations.
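As a rough illustration of that structural difference, the sketch below contrasts a chain of small stepper callbacks with a single per-optimizer step. It uses plain Python floats instead of tensors, and every name in it (`avg_grad_cb`, `momentum_cb`, `callback_step`, `fused_step`, `State`) is hypothetical and chosen for illustration only; it is not the fastai or fastxtend implementation.

```python
from typing import Callable, Dict, List

# Hypothetical per-parameter state with keys: param, grad, grad_avg
State = Dict[str, float]

def avg_grad_cb(state: State, mom: float, **hypers: float) -> None:
    # Stepper 1: update the running gradient average (mul, then add).
    state["grad_avg"] = state["grad_avg"] * mom + state["grad"]

def momentum_cb(state: State, lr: float, **hypers: float) -> None:
    # Stepper 2: move the parameter along the averaged gradient (add).
    state["param"] = state["param"] - lr * state["grad_avg"]

def callback_step(state: State, cbs: List[Callable[..., None]], **hypers: float) -> None:
    # Callback style: each small stepper runs as its own separate call.
    for cb in cbs:
        cb(state, **hypers)

def fused_step(state: State, lr: float, mom: float) -> None:
    # Per-optimizer style: the whole update lives in one function,
    # which is the unit a compiler such as TorchScript could fuse.
    grad_avg = state["grad_avg"] * mom + state["grad"]
    state["grad_avg"] = grad_avg
    state["param"] = state["param"] - lr * grad_avg

a = {"param": 1.0, "grad": 0.5, "grad_avg": 0.0}
b = dict(a)
callback_step(a, [avg_grad_cb, momentum_cb], lr=0.1, mom=0.9)
fused_step(b, lr=0.1, mom=0.9)
```

Both styles compute the same update; the difference is only in how many separate calls the work is split across.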

Important

TorchScript optimizers have only been tested on PyTorch 1.12+ with NVFuser and are not guaranteed to work on older versions.

TorchScript optimizers are faster due to vertical fusion across multiple CUDA calls. Using xresnet50 and SGD with momentum as an example, a TorchScript fused SGD step would (hopefully) fuse all three CUDA calls (mul, add, and add) into one or two CUDA kernels, resulting in 167 or 334 CUDA calls in total.

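import torch
from torch import Tensor

@torch.jit.script
def sgd_momentum_jit(param:Tensor, grad:Tensor, grad_avg:Tensor, lr:float, mom:float):
    grad_avg = grad_avg.mul(mom).add(grad)   # out-of-place mul + add, fusible
    param = param.add(grad_avg, alpha=-lr)   # out-of-place add, fusible
    return param, grad_avg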

In contrast, a standard PyTorch optimizer would call the SGD with momentum step 167 times, for a total of 501 in-place CUDA kernel calls:

def simple_momentum_standard(param:Tensor, grad_avg:Tensor, lr:float, mom:float):
    # three separate in-place kernels: mul_, add_, add_
    grad_avg.mul_(mom).add_(param.grad)
    param.add_(grad_avg, alpha=-lr)
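The call counts quoted above follow from simple multiplication. The short sketch below reproduces them, assuming one optimizer step per parameter tensor and 167 such steps, as implied by the text; `kernel_launches` is a hypothetical helper introduced only for this arithmetic.

```python
def kernel_launches(num_steps: int, kernels_per_step: int) -> int:
    # Total launches = per-parameter steps times kernels launched per step.
    return num_steps * kernels_per_step

NUM_STEPS = 167  # per-parameter steps quoted in the text for xresnet50

unfused = kernel_launches(NUM_STEPS, 3)        # mul_, add_, add_ as separate kernels
fused_partial = kernel_launches(NUM_STEPS, 2)  # three ops fused into two kernels
fused_full = kernel_launches(NUM_STEPS, 1)     # three ops fused into one kernel

print(unfused, fused_partial, fused_full)  # 501 334 167
```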

TorchScript optimizers are tested to be equal to fastai optimizers for 25 steps using nbdev’s GitHub CI.
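A minimal sketch of what such an equality check could look like is shown below. It is not the project's actual test: the function names, tensor size, seed, and tolerances are all assumptions made for illustration. It compares an out-of-place TorchScript SGD with momentum step against an in-place eager one over 25 steps.

```python
import torch
from torch import Tensor

@torch.jit.script
def sgd_momentum_fused(param: Tensor, grad: Tensor, grad_avg: Tensor, lr: float, mom: float):
    # Out-of-place ops that TorchScript is free to fuse.
    grad_avg = grad_avg.mul(mom).add(grad)
    param = param.add(grad_avg, alpha=-lr)
    return param, grad_avg

def sgd_momentum_inplace(param: Tensor, grad: Tensor, grad_avg: Tensor, lr: float, mom: float) -> None:
    # Eager reference: three separate in-place ops.
    grad_avg.mul_(mom).add_(grad)
    param.add_(grad_avg, alpha=-lr)

def run_comparison(steps: int = 25, lr: float = 0.1, mom: float = 0.9) -> bool:
    torch.manual_seed(0)
    p_jit = torch.randn(16)
    p_ref = p_jit.clone()
    avg_jit = torch.zeros(16)
    avg_ref = torch.zeros(16)
    for _ in range(steps):
        grad = torch.randn(16)  # same synthetic gradient fed to both versions
        p_jit, avg_jit = sgd_momentum_fused(p_jit, grad, avg_jit, lr, mom)
        sgd_momentum_inplace(p_ref, grad, avg_ref, lr, mom)
    params_match = torch.allclose(p_jit, p_ref, rtol=1e-5, atol=1e-6)
    state_match = torch.allclose(avg_jit, avg_ref, rtol=1e-5, atol=1e-6)
    return params_match and state_match

result = run_comparison()
```

The tolerance-based comparison is deliberate: a fused kernel may round intermediate values slightly differently from three separate kernels, so exact bitwise equality is not assumed here.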


JitOptimizer

 JitOptimizer (params:Listified[Tensor], opt_step:Callable,
               decouple_wd:bool=False, **defaults)

An Optimizer with a modified step for TorchScript optimizers

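|             | Type              | Default | Details                                                        |
|-------------|-------------------|---------|----------------------------------------------------------------|
| params      | Listified[Tensor] |         | Model parameters                                               |
| opt_step    | Callable          |         | JitOptimizer optimizer step                                    |
| decouple_wd | bool              | False   | Use decoupled weight decay or L2 regularization, if applicable |
| defaults    |                   |         |                                                                |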

sgd_jit_step


rmsprop_jit_step


adam_jit_step


radam_jit_step


qhadam_jit_step


larc_jit_step


lamb_jit_step


JitLookahead

 JitLookahead (params:Listified[Tensor], opt_step:Callable,
               decouple_wd:bool=False, **defaults)

A JitOptimizer with a modified step for Lookahead TorchScript optimizers

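|             | Type              | Default | Details                                                        |
|-------------|-------------------|---------|----------------------------------------------------------------|
| params      | Listified[Tensor] |         | Model parameters                                               |
| opt_step    | Callable          |         | JitLookahead optimizer step                                    |
| decouple_wd | bool              | False   | Use decoupled weight decay or L2 regularization, if applicable |
| defaults    |                   |         |                                                                |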

ranger_jit_step
