TorchScript Optimizers
fastxtend TorchScript optimizers are adapted from fastai optimizers and modified to compile with TorchScript. Depending on the model and optimizer, they are 10 to 137 percent faster than native fastai optimizers, with complex optimizers such as QHAdam receiving the largest performance increase.
Unlike fastai optimizers, which are composed of multiple stepper callbacks, TorchScript optimizers require a single per-optimizer step function so TorchScript can fuse the operations into as few CUDA calls as possible. All fastai optimizers have TorchScript implementations.
TorchScript optimizers have only been tested on PyTorch 1.12+ with NVFuser and are not guaranteed to work on older versions.
TorchScript optimizers are faster due to vertical fusion across multiple CUDA calls. Using xresnet50 and SGD with momentum as an example, a TorchScript fused SGD step would (hopefully) fuse all three CUDA calls (`mul`, `add`, and `add`) into one or two CUDA kernels, resulting in 167 or 334 CUDA calls:
```python
@torch.jit.script
def sgd_momentum_jit(param:Tensor, grad:Tensor, grad_avg:Tensor, lr:float, mom:float):
    grad_avg = grad_avg.mul(mom).add(grad)
    param = param.add(grad_avg, alpha=-lr)
```
In contrast, a standard PyTorch optimizer would call the SGD with momentum step 167 times, for a total of 501 inplace CUDA kernel calls:
```python
def simple_momentum_standard(param:Tensor, grad_avg:Tensor, lr:float, mom:float):
    grad_avg.mul_(mom).add_(param.grad)
    param.add_(grad_avg, alpha=-lr)
```
TorchScript optimizers are tested to be equal to fastai optimizers for 25 steps using nbdev’s GitHub CI.
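As an illustrative sketch of such an equality check (not the actual CI test), the fused and in-place steps above can be compared over 25 iterations. The return values on the fused step are an assumption added here so the out-of-place version can be driven from a loop:

```python
import torch
from torch import Tensor

@torch.jit.script
def sgd_momentum_jit(param:Tensor, grad:Tensor, grad_avg:Tensor, lr:float, mom:float):
    # Out-of-place ops that TorchScript can fuse into one or two kernels;
    # returning the updated tensors is an assumption for this sketch
    grad_avg = grad_avg.mul(mom).add(grad)
    param = param.add(grad_avg, alpha=-lr)
    return param, grad_avg

def sgd_momentum_standard(param:Tensor, grad:Tensor, grad_avg:Tensor, lr:float, mom:float):
    # Three separate inplace kernel calls per parameter tensor
    grad_avg.mul_(mom).add_(grad)
    param.add_(grad_avg, alpha=-lr)

torch.manual_seed(42)
p_jit = torch.randn(10); p_std = p_jit.clone()
avg_jit = torch.zeros(10); avg_std = torch.zeros(10)
for _ in range(25):
    grad = torch.randn(10)
    p_jit, avg_jit = sgd_momentum_jit(p_jit, grad, avg_jit, 1e-2, 0.9)
    sgd_momentum_standard(p_std, grad, avg_std, 1e-2, 0.9)

# Both formulations compute identical updates
assert torch.allclose(p_jit, p_std)
```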
JitOptimizer
JitOptimizer (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=False, **defaults)
An Optimizer with a modified step for TorchScript optimizers.

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | JitOptimizer optimizer step |
| decouple_wd | bool | False | Use decoupled weight decay or L2 regularization, if applicable |
| defaults | | | |
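The internals of JitOptimizer are not shown here, but the idea — one scripted step function applied per parameter, instead of a chain of stepper callbacks — can be sketched with a hypothetical toy class (`ToyJitOptimizer` and `sgd_step` are illustrative names, not fastxtend's API):

```python
import torch

@torch.jit.script
def sgd_step(param: torch.Tensor, grad: torch.Tensor, lr: float):
    # A single scripted step TorchScript can fuse; plain SGD for brevity
    return param.add(grad, alpha=-lr)

class ToyJitOptimizer:
    # Hypothetical sketch of the JitOptimizer pattern, not the real class
    def __init__(self, params, opt_step, **defaults):
        self.params, self.opt_step, self.defaults = list(params), opt_step, defaults
    def step(self):
        for p in self.params:
            if p.grad is not None:
                p.data = self.opt_step(p.data, p.grad, self.defaults['lr'])

torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
opt = ToyJitOptimizer(model.parameters(), sgd_step, lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
before = model.weight.detach().clone()
opt.step()
assert not torch.equal(model.weight.detach(), before)
```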
sgd_jit_step
rmsprop_jit_step
adam_jit_step
radam_jit_step
qhadam_jit_step
larc_jit_step
lamb_jit_step
JitLookahead
JitLookahead (params:Listified[Tensor], opt_step:Callable, decouple_wd:bool=False, **defaults)
A JitOptimizer with a modified step for Lookahead TorchScript optimizers.

| | Type | Default | Details |
|---|---|---|---|
| params | Listified[Tensor] | | Model parameters |
| opt_step | Callable | | JitLookahead optimizer step |
| decouple_wd | bool | False | Use decoupled weight decay or L2 regularization, if applicable |
| defaults | | | |
ranger_jit_step
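For reference, the Lookahead mechanism that JitLookahead builds on maintains a copy of slow weights that periodically interpolate toward the fast weights. A minimal plain-PyTorch sketch of that slow-weight update (`lookahead_update` is an illustrative name, not fastxtend's implementation):

```python
import torch

def lookahead_update(param: torch.Tensor, slow: torch.Tensor, alpha: float):
    # Every k fast steps: slow weights move a fraction alpha toward the
    # fast weights, then the fast weights are reset to the slow weights
    slow.add_(param.sub(slow), alpha=alpha)
    param.copy_(slow)

param = torch.full((3,), 4.0)  # fast weights after k optimizer steps
slow = torch.zeros(3)          # slow weights
lookahead_update(param, slow, alpha=0.5)
# slow moved halfway from 0 to 4, and param was reset to slow
assert torch.allclose(param, torch.full((3,), 2.0))
```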