For use in fastai 2.6.x or older. Import globally:
from fastxtend.vision.all import *
or individually:
from fastxtend.callback import casttotensor
Workaround for a bug in PyTorch where subclassed tensors, such as TensorBase, train up to ~20% slower than Tensor when passed to a model. Added to Learner by default if using fastai 2.6.x or older.
CastToTensorBackport is identical to the CastToTensor callback released in fastai 2.7.0.
CastToTensorBackport's order is right before MixedPrecision, so callbacks that make use of fastai's tensor subclasses can still use them.
If inputs are not a subclassed tensor or tuple of tensors, you may need to cast inputs in Learner.xb and Learner.yb to Tensor via your own callback or in the dataloader before Learner performs the forward pass.
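As a rough illustration of such a cast, the sketch below uses only PyTorch. MyTensor is a hypothetical subclass standing in for a fastai tensor type, and cast_to_tensor is an illustrative helper, not part of fastai or fastxtend. It handles a plain tuple of tensors; other containers would need their own handling.

```python
import torch
from torch import Tensor

class MyTensor(Tensor):
    """Hypothetical tensor subclass standing in for fastai's TensorBase."""

def cast_to_tensor(x):
    """Recursively cast tensor subclasses (and tuples of them) to plain Tensor."""
    if isinstance(x, tuple):
        return tuple(cast_to_tensor(item) for item in x)
    if isinstance(x, Tensor):
        # as_subclass returns a view of the same data with the requested type
        return x.as_subclass(Tensor)
    return x  # leave non-tensor values untouched

# Batches shaped like Learner.xb / Learner.yb: tuples of (subclassed) tensors
xb = (torch.zeros(2, 3).as_subclass(MyTensor),)
yb = (torch.ones(2).as_subclass(MyTensor),)
xb, yb = cast_to_tensor(xb), cast_to_tensor(yb)
```

In a custom fastai callback, a helper along these lines would presumably be applied to self.learn.xb and self.learn.yb in before_batch, so the cast happens before the forward pass.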
If the CastToTensorBackport workaround interferes with custom code, it can be removed:
learn = Learner(...)
learn.remove_cb(CastToTensorBackport)
If CastToTensorBackport is removed, you should verify your inputs are of type Tensor or implement a cast to Tensor via a custom callback or dataloader.
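One way to check input types is sketched below, again using only PyTorch. MyTensor and is_plain_tensor are hypothetical names used for illustration. The check compares the exact type because isinstance also returns True for subclasses.

```python
import torch
from torch import Tensor

class MyTensor(Tensor):
    """Hypothetical subclass standing in for a fastai tensor type."""

def is_plain_tensor(x):
    """Return True only if x (or every item of a tuple x) is exactly a Tensor."""
    if isinstance(x, tuple):
        return all(is_plain_tensor(item) for item in x)
    # isinstance(x, Tensor) would also accept subclasses, so compare the type
    return type(x) is Tensor

plain = torch.zeros(2)
sub = torch.zeros(2).as_subclass(MyTensor)

print(is_plain_tensor(plain))         # plain Tensor
print(is_plain_tensor(sub))           # subclassed tensor
print(is_plain_tensor((plain, sub)))  # tuple containing a subclass
```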