A callback to cast model inputs to `Tensor` as a workaround for a PyTorch performance bug

For use with fastai 2.6.x or older. Import globally:

from fastxtend.vision.all import *

or individually:

from fastxtend.callback import casttotensor

class CastToTensorBackport[source]

CastToTensorBackport(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

Cast Subclassed Tensors to Tensor

Workaround for a performance bug in PyTorch where subclassed tensors, such as TensorBase, train up to ~20% slower than Tensor when passed to a model. CastToTensorBackport is added to Learner by default if using fastai 2.6.x or older.

CastToTensorBackport is identical to the CastToTensor callback released with fastai 2.7.0.

CastToTensorBackport's order is set so it runs right before MixedPrecision, so callbacks which make use of fastai's tensor subclasses can still use them.
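
For reference, here is a minimal sketch of the approach, assuming the backport mirrors fastai's CastToTensor; the helper and class names below are illustrative, not the actual fastxtend implementation:

import torch
from torch import Tensor
from fastcore.dispatch import cast
from fastai.callback.core import Callback
from fastai.callback.fp16 import MixedPrecision

def _cast_to_plain_tensor(x):
    # Recursively cast tensor subclasses (e.g. TensorBase, TensorImage) back to plain Tensor
    if isinstance(x, tuple): return tuple(_cast_to_plain_tensor(x_) for x_ in x)
    return cast(x, Tensor) if isinstance(x, torch.Tensor) else x

class CastToTensorSketch(Callback):
    "Illustrative only: cast subclassed tensors to Tensor before the forward pass"
    order = MixedPrecision.order - 1  # run right before MixedPrecision

    def before_batch(self):
        # Learner.xb and Learner.yb are tuples; replace any subclassed tensors with plain Tensor
        self.learn.xb = _cast_to_plain_tensor(self.learn.xb)
        self.learn.yb = _cast_to_plain_tensor(self.learn.yb)

Because the cast happens in before_batch just ahead of MixedPrecision, callbacks ordered earlier still see the subclassed tensors, while the model itself receives plain Tensor inputs.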

If your inputs are not a subclassed tensor or a tuple of tensors, you may need to cast the inputs in Learner.xb and Learner.yb to Tensor yourself, either via your own callback or in the dataloader, before Learner performs the forward pass, as sketched below.
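
As an illustration, here is a hedged sketch of such a callback for the case where the dataloader yields a dict of tensors as the input; the CastDictToTensor name and the dict-input scenario are hypothetical:

import torch
from torch import Tensor
from fastcore.dispatch import cast
from fastai.callback.core import Callback
from fastai.callback.fp16 import MixedPrecision

class CastDictToTensor(Callback):
    "Hypothetical example: cast tensor values inside a dict input to plain Tensor"
    order = MixedPrecision.order - 1  # same slot as CastToTensorBackport

    def before_batch(self):
        # Learner.xb is a tuple; here we assume it holds a single dict of tensors
        if len(self.xb) == 1 and isinstance(self.xb[0], dict):
            self.learn.xb = ({k: cast(v, Tensor) if isinstance(v, torch.Tensor) else v
                              for k, v in self.xb[0].items()},)

Pass it to the Learner like any other callback, e.g. Learner(..., cbs=CastDictToTensor()).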

If the CastToTensorBackport workaround interferes with custom code, it can be removed:

learn = Learner(...)
learn.remove_cb(CastToTensorBackport)

If CastToTensorBackport is removed, you should verify your inputs are of type Tensor or implement the cast to Tensor yourself via a custom callback or in the dataloader.
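
One way to verify, as a sketch, is a small callback that asserts the inputs reaching the model are plain Tensor; AssertPlainTensor is a hypothetical name, and the callback is ordered to run after any casting callbacks:

from torch import Tensor
from fastai.callback.core import Callback
from fastai.callback.fp16 import MixedPrecision

class AssertPlainTensor(Callback):
    "Hypothetical check: fail fast if a model input is still a tensor subclass"
    order = MixedPrecision.order + 1  # run after any casting callbacks

    def before_batch(self):
        for t in self.xb:
            if isinstance(t, Tensor):
                assert type(t) is Tensor, f"Model input is {type(t).__name__}, not Tensor"

learn = Learner(..., cbs=AssertPlainTensor())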