A Callback which converts a fastai `Learner` and input to channels_last format.

When combined with mixed precision, image models trained in channels last format on NVIDIA Tensor Cores can achieve an 8% to 35% performance increase over the default contiguous format.

The channels last memory format is only implemented for four-dimensional NCHW tensors, and not all PyTorch operators have been converted to support it. See the PyTorch tutorial "(Beta) Channels Last Memory Format in PyTorch" for more details.
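Channels last changes only how a 4D tensor's elements are laid out in memory, not its logical NCHW shape. The stdlib-only sketch below computes element strides for both layouts; the helper names are illustrative and are not part of fastai or PyTorch. For shape `(2, 3, 4, 5)` the channels last strides match what PyTorch reports for `torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last).stride()`.

```python
def nchw_strides(n, c, h, w):
    """Element strides of a contiguous (row-major) NCHW tensor."""
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n, c, h, w):
    """Element strides when the same logical NCHW tensor is stored as NHWC.

    The logical dimension order stays (N, C, H, W); only the physical
    layout changes, so C becomes the fastest-moving dimension.
    """
    return (h * w * c, 1, w * c, c)


def offset(strides, idx):
    """Flat memory offset of a logical (n, c, h, w) index."""
    return sum(s * i for s, i in zip(strides, idx))


shape = (2, 3, 4, 5)  # N, C, H, W
cf = nchw_strides(*shape)
cl = channels_last_strides(*shape)
print(cf)  # (60, 20, 5, 1)
print(cl)  # (60, 1, 15, 3)

# Neighbouring channels of one pixel: 20 elements apart in NCHW, adjacent in NHWC.
print(offset(cf, (0, 1, 0, 0)) - offset(cf, (0, 0, 0, 0)))  # 20
print(offset(cl, (0, 1, 0, 0)) - offset(cl, (0, 0, 0, 0)))  # 1
```

Keeping all channels of a pixel next to each other is what lets Tensor Core convolution kernels read them in one contiguous run.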

Channels last format can raise an error if `torch.backends.cudnn.benchmark = False`, for example when it is set by fastai's `no_random` context manager. If this occurs, use the `less_random` context manager instead. This allows reproducible training on the same GPU, PyTorch, and CUDA setup, at the expense of reduced reproducibility if any of those change.
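The idea behind this workaround can be illustrated in plain PyTorch: seed the random number generators and ask cuDNN for deterministic algorithms, but leave `torch.backends.cudnn.benchmark` enabled. The context manager below is a hypothetical sketch of that idea; its name and exact behaviour are assumptions, not the implementation of `no_random` or `less_random`. It only touches documented PyTorch flags and restores them on exit.

```python
import random
from contextlib import contextmanager

import torch


@contextmanager
def seeded_with_benchmark(seed: int = 42):
    """Hypothetical helper (not fastai's `less_random`): seed the RNGs and
    request deterministic cuDNN algorithms while keeping cudnn.benchmark on,
    then restore the previous cuDNN flags on exit."""
    cudnn = torch.backends.cudnn
    prev = (cudnn.deterministic, cudnn.benchmark)
    random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    cudnn.deterministic = True
    cudnn.benchmark = True  # leave autotuning enabled for channels last
    try:
        yield
    finally:
        cudnn.deterministic, cudnn.benchmark = prev


before = (torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark)
with seeded_with_benchmark(0):
    inside = (torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark)
    a = torch.rand(3)
with seeded_with_benchmark(0):
    b = torch.rand(3)
after = (torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark)

print(inside)             # (True, True)
print(torch.equal(a, b))  # True
print(after == before)    # True
```

Because the autotuner picks kernels per device and library version, this kind of setup is only expected to repeat results on an unchanged GPU, PyTorch, and CUDA stack.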

class ChannelsLast

ChannelsLast(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_cancel_backward=None, after_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

Channels last training using PyTorch's Channels Last Memory Format (beta)

When a PyTorch model is set to channels last format, PyTorch will automatically convert any compatible NCHW input tensors to NHWC format. Because `ChannelsLast` casts the model to channels last format, no changes to dataloaders or inputs are required.
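In plain PyTorch rather than fastai, the essential step looks like the sketch below: only the model is cast with `Module.to(memory_format=torch.channels_last)`, and an ordinary NCHW batch is passed in unchanged. The toy model is an assumption for illustration; the final check compares against an unconverted copy to show that the memory format changes the layout, not the results.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
# Toy stand-in for an image model (illustrative only).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)
reference = copy.deepcopy(model)  # untouched contiguous-format copy

# Cast the model only: 4D parameters are re-laid out as NHWC in memory,
# while their logical (out, in, kH, kW) shapes stay the same.
model = model.to(memory_format=torch.channels_last)
print(model[0].weight.shape)  # torch.Size([8, 3, 3, 3])
print(model[0].weight.is_contiguous(memory_format=torch.channels_last))  # True

# The batch stays an ordinary contiguous NCHW tensor, as a dataloader yields it.
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    y = model(x)
    y_ref = reference(x)

print(y.shape)  # torch.Size([2, 4, 16, 16])
print(torch.allclose(y, y_ref, atol=1e-4))  # True
```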

Convenience Methods


Learner.to_channelslast(to_fp16:bool=True, init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)

Set Learner and inputs to channels_last format and Mixed Precision by default

| | Type | Default | Details |
|---|---|---|---|
| `to_fp16` | `bool` | `True` | Add `MixedPrecision` callback. Required for full channels last performance |
| Valid Keyword Arguments | | | |
| `init_scale` | `float` | `65536.0` | Argument passed to `GradScaler.__init__` |
| `growth_factor` | `float` | `2.0` | Argument passed to `GradScaler.__init__` |
| `backoff_factor` | `float` | `0.5` | Argument passed to `GradScaler.__init__` |
| `growth_interval` | `int` | `2000` | Argument passed to `GradScaler.__init__` |
| `enabled` | `bool` | `True` | Argument passed to `GradScaler.__init__` |
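`to_channelslast` sets this up through fastai callbacks, but the training configuration it targets (a channels last model trained under mixed precision with a `GradScaler` built from the keyword arguments above) can be sketched in plain PyTorch. This is an illustrative training step, not fastai's implementation; `make_scaler`, `train_step` and the toy model are hypothetical names. Without a CUDA device, autocast and the scaler are disabled, so the step simply runs in fp32.

```python
import torch
import torch.nn.functional as F
from torch import nn

try:  # newer PyTorch releases expose a device-generic GradScaler
    from torch.amp import GradScaler

    def make_scaler(**kwargs):
        return GradScaler("cuda", **kwargs)
except ImportError:  # older releases only ship the CUDA-specific class
    from torch.cuda.amp import GradScaler

    def make_scaler(**kwargs):
        return GradScaler(**kwargs)

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # fp16 mixed precision on Tensor Cores needs a GPU

# Toy model cast to channels last (illustrative only).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
).to(device, memory_format=torch.channels_last)

opt = torch.optim.SGD(model.parameters(), lr=0.1)
# Same values as the defaults listed for `to_channelslast`.
scaler = make_scaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5,
                     growth_interval=2000, enabled=use_amp)


def train_step(xb, yb):
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = F.cross_entropy(model(xb), yb)
    scaler.scale(loss).backward()  # returns the loss unscaled when disabled
    scaler.step(opt)               # skips the step if gradients overflowed
    scaler.update()                # adjusts the scale for the next step
    opt.zero_grad()
    return loss.item()


xb = torch.randn(4, 3, 16, 16, device=device)
yb = torch.randint(0, 2, (4,), device=device)
losses = [train_step(xb, yb) for _ in range(3)]
```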



Set Learner and inputs to contiguous_format (default format), optionally to single precision
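Reverting is the inverse cast. In plain PyTorch, a model's 4D parameters can be returned to the default layout with `torch.contiguous_format`, as in the small sketch below. The toy layer is illustrative, and the sketch covers only the memory-format half, not the optional switch back to single precision.

```python
import torch
from torch import nn

# Toy layer cast to channels last (illustrative only).
layer = nn.Conv2d(3, 8, kernel_size=3).to(memory_format=torch.channels_last)
was_channels_last = layer.weight.is_contiguous(memory_format=torch.channels_last)
print(was_channels_last)             # True
print(layer.weight.is_contiguous())  # False

# Cast back to the default (contiguous) memory format; shapes are unchanged.
layer = layer.to(memory_format=torch.contiguous_format)
print(layer.weight.is_contiguous())  # True
print(layer.weight.shape)            # torch.Size([8, 3, 3, 3])
```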