Getting Started with fastxtend and FFCV
fastxtend integrates FFCV with fastai. You can now use the speed of the highly optimized FFCV DataLoader natively with fastai batch transforms, callbacks, and other DataLoader features.
FFCV is “a drop-in data loading system that dramatically increases data throughput in model training.” It accelerates[^1] DataLoader throughput by combining Numba compiled item transforms with a custom data format and cached data loading.
FFCV has a getting started tutorial which pairs well with this guide, providing additional context and depth.
While fastxtend’s FFCV integration is currently in beta, it is fully functional. Expect new features and quality of life improvements in future fastxtend releases.
The Imagenette benchmark uses fused optimizers, Progressive Resizing callback, and the integrated FFCV DataLoader. To run the benchmark on your own machine, see the example scripts for details on how to replicate.
Installing fastai, FFCV, and fastxtend
The easiest way to install fastai, fastxtend, and FFCV is to use Conda or Miniconda[^2] on Linux (or WSL):
```bash
conda create -n fastxtend python=3.11 "pytorch>=2.1" torchvision pytorch-cuda=12.1 \
fastai pkg-config libjpeg-turbo opencv tqdm terminaltables psutil numpy "numba>=0.57" \
timm kornia -c pytorch -c nvidia -c fastai -c huggingface -c conda-forge

# Switch to the newly created conda environment
conda activate fastxtend
```
replacing `pytorch-cuda=12.1`[^3] with your preferred supported version of Cuda. In rare[^4] cases, you may need to add the `compilers` package to the conda install.
And then install fastxtend via pip:
```bash
# Install fastxtend with Vision & FFCV support
pip install fastxtend[ffcv]
```
Or to install with all of fastxtend’s features:
```bash
conda create -n fastxtend python=3.11 "pytorch>=2.1" torchvision torchaudio \
pytorch-cuda=12.1 fastai nbdev pkg-config libjpeg-turbo "numba>=0.57" librosa timm \
opencv tqdm psutil terminaltables numpy "transformers>=4.34" "tokenizers>=0.14" \
kornia rich typer wandb "datasets>=2.14" ipykernel ipywidgets "matplotlib<3.8" \
-c pytorch -c nvidia -c fastai -c huggingface -c conda-forge

# Switch to the newly created conda environment
conda activate fastxtend

# install fastxtend and FFCV
pip install fastxtend[all]
```
If you are using Windows[^5], please follow the FFCV Windows installation guide, then install fastxtend via pip.
If you want to use notebooks with fastxtend, it’s recommended to add one of `jupyter` or `jupyterlab` along with `ipykernel` to the conda install packages.
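For example, one way to do this (a sketch; these are the standard conda-forge package names, adjust to taste) is to install them into the environment after it is created:

```bash
# Hypothetical example: add notebook support to the fastxtend environment created above
conda install -n fastxtend jupyterlab ipykernel ipywidgets -c conda-forge
```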
`cupy` is not listed in the conda packages as it’s only needed if you want to use FFCV’s `NormalizeImage` on the GPU. It’s recommended to use fastai’s `Normalize` instead. See the adding batch transforms section for more details.
Importing FFCV via fastxtend
fastxtend’s FFCV integration has been designed to use `__all__` to safely import everything needed to use FFCV with fastai and fastxtend.
Run:

```python
from fastai.vision.all import *
from fastxtend.vision.all import *
from fastxtend.ffcv.all import *
```

and you are ready to go.
fastxtend’s FFCV transforms and operations are imported under the `fx` prefix, since they sometimes overlap with fastai batch transforms. For example, the FFCV augmentation `RandomErasing` is a Numba FFCV version of the fastai batch transform `fastai.vision.augment.RandomErasing`.
You can also import the FFCV integration individually. This will require you to mix imports from `fastxtend.ffcv` and `ffcv`.
```python
from ffcv.fields.decoders import CenterCropDecoder
from fastxtend.ffcv.loader import Loader, OrderOption
from fastxtend.ffcv.transforms import RandomHorizontalFlip
# etc
```
However, during the beta it’s recommended to use `from fastxtend.ffcv.all import *`, as imports may change between fastxtend releases.
Creating a FFCV Dataset
Before we can start training with FFCV, our dataset needs to be converted into FFCV’s custom beton format. This can be done using the `DatasetWriter`.

fastxtend provides the `rgb_dataset_to_ffcv` convenience method for easy FFCV image dataset creation. `rgb_dataset_to_ffcv` expects a PyTorch compatible `Dataset` or any Python iterator.
First, create an Imagenette dataset using the fastai DataBlock API.
```python
path = URLs.IMAGENETTE_320
source = untar_data(path)

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label)
dset = dblock.datasets(source)
```
Next, use `rgb_dataset_to_ffcv` to create two FFCV files: one for the training dataset and one for the validation dataset.
```python
path = Path.home()/'.cache/fastxtend'
path.mkdir(exist_ok=True)

rgb_dataset_to_ffcv(dset.train, path/'imagenette_320_train.ffcv')
rgb_dataset_to_ffcv(dset.valid, path/'imagenette_320_valid.ffcv')
```
If you have more (or less) memory, you can increase (or decrease) `DatasetWriter`’s `chunk_size` from the default of 100.
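For example, a minimal sketch assuming your fastxtend version’s `rgb_dataset_to_ffcv` forwards a `chunk_size` argument to `DatasetWriter` (check the function signature before relying on it):

```python
# Hypothetical: halve the chunk size on a memory-constrained machine
rgb_dataset_to_ffcv(dset.train, path/'imagenette_320_train.ffcv', chunk_size=50)
```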
If Imagenette was not already resized, we could pass `max_resolution` or `min_resolution` to resize the images. To recreate Imagenette 320 from the full size dataset, pass `min_resolution=320`:

```python
rgb_dataset_to_ffcv(dset.train, path/'imagenette_320_train.ffcv', min_resolution=320)
```
By default, `rgb_dataset_to_ffcv` will use Pillow and the `LANCZOS` resample method to resize the image, and `DatasetWriter` will use OpenCV with `INTER_AREA`.
To accelerate image resizing, you’ll probably want Pillow-SIMD installed.
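If you want to try it, the usual pattern is to replace Pillow with Pillow-SIMD inside the activated environment (a sketch; see the Pillow-SIMD README for current, platform-specific instructions):

```bash
# Swap Pillow for the SIMD-accelerated fork
pip uninstall -y pillow
pip install --upgrade --force-reinstall pillow-simd
```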
Creating a fastxtend Loader
fastxtend adds fastai features to FFCV’s Loader, including `one_batch`, `show_batch`, `show_results`, and support for batch transforms, to name a few.

Currently `fastai.data.block.DataBlock` is unsupported for creating a fastxtend `Loader`, so we’ll have to create it from scratch.

For reference, here is the fastai DataBlock we’ll be recreating[^6] using the fastxtend `Loader`.
```python
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label,
                   item_tfms=[RandomResizedCrop(224), FlipItem(0.5)],
                   batch_tfms=[*aug_transforms(do_flip=False), Normalize.from_stats(*imagenet_stats)])
dls = dblock.dataloaders(source, bs=64, num_workers=num_cpus())
```
Unlike fastai, fastxtend’s FFCV integration (and FFCV itself) does not automatically select between the training and validation versions of transforms. Neither does it automatically create a validation pipeline or automatically reorder transforms and operations. You are responsible for adding the correct decoders, transforms, and operations to the correct pipelines in the correct order.
Setting Up Pipelines
FFCV uses pipelines to declare which input fields to read, how to decode them, and which operations and transforms to apply to them. The dataloader will need three pipelines: one for the training images, one for the validation images, and a shared pipeline for the labels. We need to make sure that decoders, transforms, and operations are all in the correct order, as they will be executed sequentially.
Training Pipeline
Reading a FFCV dataset requires a FFCV decoder. FFCV has multiple decoders, but we’ll use `RandomResizedCropRGBImageDecoder`, which integrates a random resizing crop into image loading. And we’ll add `RandomHorizontalFlip` to flip the image.
```python
train_pipe = [
    RandomResizedCropRGBImageDecoder(output_size=(224,224)),
    ft.RandomHorizontalFlip(0.5)
]
```
fastxtend provides multiple FFCV transforms, including existing FFCV transforms with harmonized arguments, fastai transforms implemented as FFCV transforms, and additional FFCV transforms.
After passing through FFCV’s Numba compiled transforms, the Imagenette images are still in CPU memory as NumPy arrays. Before we can pass them to our model, they need to be converted to `fastai.torch_core.TensorImage` and moved to the GPU. We’ll extend our training pipeline by adding the `ToTensorImage` and `ToDevice` operations.
```python
train_pipe.extend([ft.ToTensorImage(), ft.ToDevice()])
```
`Loader` will now asynchronously transfer each training image batch to the GPU.

FFCV has `ToTensor`, `ToTorchImage`, and `ToDevice` operations for converting NumPy arrays to PyTorch Tensors and moving them to the GPU. These are compatible with PyTorch dataloaders, but they are not compatible with fastai, as they strip the types required for fastai features such as batch transforms, callbacks, plotting, etc. Use fastxtend’s `ToTensorImage` and `ToDevice` for compatibility with fastai features.
Validation Pipeline
With the training pipeline finalized, it’s time to create the validation pipeline. FFCV and fastxtend currently have one validation image decoder: `CenterCropRGBImageDecoder`, which resizes and center crops the validation image. This is identical to `fastai.vision.augment.Resize`’s validation behavior.
```python
valid_pipe = [
    CenterCropRGBImageDecoder(output_size=(224,224), ratio=1),
    ft.ToTensorImage(),
    ft.ToDevice()
]
```
Like `train_pipe`, we use `ToTensorImage` to convert to the correct tensor type and `ToDevice` to asynchronously transfer each batch to the GPU.
Label Pipeline
Now we have our image pipelines for our training and validation datasets, but we still need to create our label pipeline. Since this is a single label dataset, we have integers as labels, so we’ll use an `IntDecoder` and convert to `TensorCategory`, followed by squeezing the extra dimension[^7] with `Squeeze` and using `ToDevice` to transfer to the GPU.
```python
label_pipe = [
    IntDecoder(), ft.ToTensorCategory(),
    ft.Squeeze(), ft.ToDevice()
]
```
Adding Required & Optional Batch Transforms
After `Loader` finishes processing the pipelines, the images are batched on the GPU, but they will be in `uint8` format and unnormalized. FFCV has operations to handle both[^8], but using them will convert the images from `TensorImage` to `Tensor`, limiting compatibility with other fastai features, such as callbacks, plotting, etc. Since fastxtend’s `Loader` supports fastai GPU batch transforms, we’ll use them instead.
Required Batch Transform
To convert the `uint8` tensors to `float` and normalize the images, we’ll use `fastai.data.transforms.IntToFloatTensor` and `fastai.data.transforms.Normalize` to preserve tensor types and metadata.

`fastai.data.transforms.IntToFloatTensor` is a required batch transform (`batch_tfms`) when training on image data for fastai feature compatibility.

Unlike the FFCV transforms we’ve used so far, the fastai transforms in `Loader` will automatically reorder themselves into the correct order and use type dispatch to apply to the correct tensor types.
```python
batch_tfms = [
    IntToFloatTensor,
    Normalize.from_stats(*imagenet_stats)
]
```
It is recommended, but not required, to normalize an image batch.
Optional Batch Transforms
It’s also possible to add any fastai batch transform to `Loader`’s `batch_tfms`, such as `fastai.vision.augment.aug_transforms` or `affine_transforms`:
```python
batch_tfms = [
    IntToFloatTensor,
    *aug_transforms(),
    Normalize.from_stats(*imagenet_stats)
]
```
Creating the Loader
With the image pipelines, label pipeline, and batch transforms set up, we can now create our dataloaders: one `Loader` for train and one for valid.
Training Loader
Starting with the training `Loader`, we can manually set all the training specific arguments:
```python
Loader(path/'imagenette_320_train.ffcv',
       batch_size=64,
       num_workers=num_cpus(),
       os_cache=True,
       order=OrderOption.RANDOM,
       drop_last=True,
       pipelines={'image': train_pipe, 'label': label_pipe},
       batch_tfms=batch_tfms,
       batches_ahead=2,
       device='cuda',
       split_idx=0,
       n_inp=1
)
```
Loader Arguments
There are a handful of important `Loader` arguments which need further explanation:

- `order`: Controls how much memory is used for dataset caching and whether the dataset is randomly shuffled. Can be one of `RANDOM`, `QUASI_RANDOM`, or `SEQUENTIAL`. See the note below for more details. Defaults to `SEQUENTIAL`, which is unrandomized.
- `os_cache`: By default, FFCV will attempt to cache the entire dataset into RAM using the operating system’s caching. This can be changed by setting `os_cache=False` or setting the environment variable `FFCV_DEFAULT_CACHE_PROCESS` to “True” or “1”. If `os_cache=False` then `order` must be set to `QUASI_RANDOM` for the training `Loader` (see the sketch after this list).
- `num_workers`: If not set, will use all CPU cores up to 16 by default.
- `batches_ahead`: Controls the number of batches ahead the `Loader` works. Increasing uses more RAM, both CPU and GPU. Defaults to 2.
- `n_inp`: Controls which inputs to pass to the model. By default, set to the number of pipelines minus 1.
- `drop_last`: Whether to drop the last partial batch. By default, will be set to True if `order` is `RANDOM` or `QUASI_RANDOM`, False if `SEQUENTIAL`.
- `async_tfms`: Asynchronously apply `batch_tfms` before the batch is drawn. Can accelerate training if GPU compute isn’t fully saturated (95% or less) or if only using `IntToFloatTensor` and `Normalize`.
- `device`: The device to place the processed batches of data on. Defaults to `fastai.torch_core.default_device` if not set.
- `split_idx`: This tells the fastai batch transforms which dataset they are operating on. By default will use 0 (train) if `order` is `RANDOM` or `QUASI_RANDOM`, 1 (valid) if `SEQUENTIAL`.
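For example, here is a sketch of a training `Loader` for a dataset that does not fit into RAM, following the `os_cache` and `order` notes above (it reuses the `train_pipe`, `label_pipe`, and `batch_tfms` defined earlier):

```python
# Sketch: dataset larger than RAM, so disable OS caching and use QUASI_RANDOM sampling
Loader(path/'imagenette_320_train.ffcv',
       batch_size=64,
       os_cache=False,
       order=OrderOption.QUASI_RANDOM,
       pipelines={'image': train_pipe, 'label': label_pipe},
       batch_tfms=batch_tfms)
```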
Each `order` option requires differing amounts of system memory.

- `RANDOM` caches the entire dataset in memory for fast random sampling. `RANDOM` uses the most memory.
- `QUASI_RANDOM` caches a subset of the dataset at a time in memory and randomly samples from the subset. Use it when the entire dataset cannot fit into memory.
- `SEQUENTIAL` requires the least memory. It loads a few samples ahead of time. As the name suggests, it is not random, and is primarily for validation.
With `Loader`, fastai batch transforms can either be computed synchronously or asynchronously. Synchronous transforms behave like the fastai dataloader: a batch is called for by the training loop, batch transforms are computed and applied to the batch, and then the batch is passed to the model for training. Asynchronous batch transforms are computed and applied to batches ahead of time, in parallel with model training. With asynchronous batch transforms there isn’t any delay between calling for a batch and training on the batch.

Because asynchronous transforms are computed in parallel with the training loop, they can slow down individual training steps. However, by eliminating the delay from computing batch transforms, asynchronous transforms can increase overall training speed.

Try asynchronous transforms by setting `async_tfms=True` if GPU compute isn’t fully saturated (95% or less) or if only using the required transforms `IntToFloatTensor` and `Normalize`.
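For instance, enabling asynchronous batch transforms only requires adding `async_tfms=True` to the training `Loader` (a sketch based on the simplified example below):

```python
# Sketch: compute fastai batch transforms ahead of time, in parallel with training
Loader(path/'imagenette_320_train.ffcv',
       batch_size=64,
       order=OrderOption.RANDOM,
       pipelines={'image': train_pipe, 'label': label_pipe},
       batch_tfms=batch_tfms,
       async_tfms=True)
```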
Simplified Training Loader
Since Imagenette 320 is small enough to load both the training and validation images into RAM[^9], the only arguments for the training `Loader` that must be set are:
```python
Loader(path/'imagenette_320_train.ffcv',
       batch_size=64,
       order=OrderOption.RANDOM,
       pipelines={'image': train_pipe, 'label': label_pipe},
       batch_tfms=batch_tfms,
       batches_ahead=2
)
```
Validation Loader
Next is the validation `Loader`.
```python
Loader(path/'imagenette_320_valid.ffcv',
       batch_size=64,
       pipelines={'image': valid_pipe, 'label': label_pipe},
       batch_tfms=batch_tfms,
       batches_ahead=2
)
```
This example only sets the required arguments, relying on the `Loader` defaults for the rest.
Putting it all Together
The last step is to wrap both the training and validation `Loader` in a `fastai.data.core.DataLoaders`. The example below shows all the steps covered so far in a single codeblock.
```python
loaders = {}
for name in ['train', 'valid']:
    label_pipe = [
        IntDecoder(), fx.ToTensorCategory(),
        fx.Squeeze(), fx.ToDevice()
    ]
    if name=='train':
        image_pipe = [
            RandomResizedCropRGBImageDecoder(output_size=(224,224), scale=(0.35, 1)),
            fx.RandomHorizontalFlip(), fx.ToTensorImage(), fx.ToDevice()
        ]
        order = OrderOption.RANDOM
    else:
        image_pipe = [
            CenterCropRGBImageDecoder(output_size=(224,224), ratio=1),
            fx.ToTensorImage(), fx.ToDevice()
        ]
        order = OrderOption.SEQUENTIAL

    batch_tfms = [IntToFloatTensor, *aug_transforms(), Normalize.from_stats(*imagenet_stats)]

    loaders[name] = Loader(path/f'imagenette_320_{name}.ffcv',
                           batch_size=64 if name=='train' else 128,
                           order=order,
                           pipelines={'image': image_pipe, 'label': label_pipe},
                           batch_tfms=batch_tfms,
                           batches_ahead=1,
                           seed=42)

dls = DataLoaders(loaders['train'], loaders['valid'])
```
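Since `Loader` supports fastai’s `one_batch` and `show_batch`, a quick sanity check of the pipelines might look like this (a sketch; the plotted samples depend on your dataset):

```python
# Pull one batch to confirm shapes and tensor types, then plot a few decoded samples
xb, yb = dls.train.one_batch()
print(xb.shape, type(xb), yb.shape, type(yb))
dls.train.show_batch()
```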
Training with a fastxtend Loader
With the `DataLoaders` created, the only thing left to do is create a `fastai.learner.Learner` with our model, optimizer, loss function, metrics, and callbacks of choice. Here we create the setup which should train Imagenette to ~92.5% in ~226 seconds on a 3080 Ti, depending on randomness and hardware.
```python
with less_random():
    learn = Learner(dls, xresnext50(n_out=10), opt_func=ranger(foreach=True),
                    loss_func=nn.CrossEntropyLoss(label_smoothing=0.1), metrics=Accuracy(),
                    cbs=ProgressiveResize(increase_by=16)).to_channelslast()
```
The first batch will be slower, as Numba needs to compile and FFCV needs to allocate memory for each transform and operation.
Once the compilation is over, we will benefit from FFCV’s accelerated data loading.
```python
with less_random():
    learn.fit_flat_cos(20, 8e-3)
```
Progressively increase the initial image size of [112, 112] by 16 pixels every 0.8333 epochs for 7 resizes.
Starting at epoch 10 and finishing at epoch 15 for a final training size of [224, 224].
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.739842 | 1.804805 | 0.515924 | 00:12 |
1 | 1.483479 | 1.330434 | 0.662930 | 00:08 |
2 | 1.317108 | 1.193714 | 0.724076 | 00:08 |
3 | 1.215840 | 1.404873 | 0.640255 | 00:08 |
4 | 1.114017 | 1.206947 | 0.715924 | 00:08 |
5 | 1.068557 | 0.985433 | 0.817070 | 00:08 |
6 | 1.017356 | 1.111050 | 0.760255 | 00:08 |
7 | 0.985094 | 1.069334 | 0.767898 | 00:08 |
8 | 0.960843 | 0.889831 | 0.847389 | 00:08 |
9 | 0.911082 | 1.085570 | 0.755159 | 00:08 |
10 | 0.895687 | 0.927597 | 0.836688 | 00:09 |
11 | 0.867939 | 0.798519 | 0.885860 | 00:09 |
12 | 0.870780 | 0.957848 | 0.807388 | 00:11 |
13 | 0.843433 | 0.894144 | 0.846624 | 00:13 |
14 | 0.820972 | 0.793127 | 0.884076 | 00:15 |
15 | 0.797982 | 0.948408 | 0.824204 | 00:17 |
16 | 0.766880 | 0.769455 | 0.896815 | 00:17 |
17 | 0.716646 | 0.749211 | 0.906242 | 00:17 |
18 | 0.671945 | 0.715331 | 0.923057 | 00:17 |
19 | 0.649383 | 0.699832 | 0.926115 | 00:17 |
```python
with less_random():
    learn.fit_flat_cos(5, 8e-3)
```
Inference with FFCV and fastai
While the fastxtend `Loader` and custom FFCV file format are great for accelerated training, they are not as useful for inference. In this tutorial we will use the fastai DataBlock for inference.

Since FFCV uses OpenCV for resizing and fastai uses Pillow, we cannot use the default fastai pipeline. FFCV is hardcoded to use OpenCV’s `INTER_AREA` when resizing images, while fastai uses Pillow. This means we cannot use `fastai.vision.data.ImageBlock` for inference.
fastxtend’s FFCV Inference module provides a `FFAIImageBlock` and `FFAICenterCrop` item transform which use OpenCV for resizing images. Any class or method with the prefix FFAI is intended for inference after training with the fastxtend `Loader`.

We can create an inference dataloader using the fastai DataBlock API which will create images identically to the `Loader` pipeline we created earlier in this tutorial.
```python
inference_dblock = DataBlock(blocks=(FFAIImageBlock, CategoryBlock),
                             splitter=GrandparentSplitter(valid_name='val'),
                             get_items=get_image_files, get_y=parent_label,
                             item_tfms=[FFAICenterCrop(224, ratio=1)],
                             batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
inference_dls = inference_dblock.dataloaders(source, bs=64, num_workers=num_cpus())
```
Then, assuming we had a folder with test images in `'/test'`, we’d set the `learn` dataloader to the new `inference_dls` and perform inference like normal.
```python
learn.dls = inference_dls
test_dl = learn.dls.test_dl('/test')
preds, _ = learn.get_preds(dl=test_dl)
```
For more details on inference, check out my Inference with fastai tutorial.
Reproducibility and Other Limitations
One downside of FFCV is that it provides less reproducibility than most dataloader solutions. While `Loader` has a `seed` argument, it currently only affects the order data is loaded in. With a couple of exceptions, the Numba transforms are neither seeded nor reproducible. These transforms are also independent across pipelines.

This means many image-to-image training tasks, such as image segmentation, cannot easily use `Loader`, as the inputs and outputs will not be identically resized, flipped, etc.

However, this FFCV limitation might be resolved in the near future, as Meta AI recently announced FFCV-SSL, a fork of FFCV which, among other things, has reproducible transforms.
Identical fastai and FFCV Dataloaders
The fastai dataloader and FFCV dataloader we created in this tutorial do not produce identical validation images[^10]. This is due to `fastai.vision.augment.RandomResizedCrop` adding padding and squishing validation images, while FFCV’s `CenterCropRGBImageDecoder` creates validation images via a center crop.

There’s an easy way to create a fastai dataloader with identical behavior via the DataBlock API: we just have to create it twice. First, create a dataloader like we did before using `RandomResizedCrop`. Then create a second dataloader, except with `fastai.vision.augment.Resize`. `Resize` creates a center crop during validation, just like `CenterCropRGBImageDecoder` with `ratio=1`. Finally, set the first `DataLoaders`’ valid dataloader as our second valid dataloader. Then we can optionally delete the second `DataLoaders`.
```python
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label,
                   item_tfms=[RandomResizedCrop(224), FlipItem(0.5)],
                   batch_tfms=[*aug_transforms(do_flip=False),
                               Normalize.from_stats(*imagenet_stats)])
dls = dblock.dataloaders(source, bs=64, num_workers=num_cpus())

vblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label,
                   item_tfms=Resize(224),
                   batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
vls = vblock.dataloaders(source, bs=64, num_workers=num_cpus())

dls.valid = vls.valid
vls = None
```
Footnotes

[^1]: MosaicML found that using FFCV led to a ~1.85x increase in throughput, from ~17,800 images/sec to ~30,000 images/sec on a 2x 32-core CPU and 8x A100 system.
[^2]: Miniconda with the faster libmamba solver is recommended.
[^3]: If you want to include a full Cuda install, I find that specifying the Cuda version label `-c nvidia/label/cuda-12.1.0` usually results in better environment solving. Sometimes the label doesn’t have all the packages, so you’ll need to add them manually to the conda install packages: `nvidia::missing_package`.
[^4]: The FFCV Linux installation guide states the `compilers` package is rarely needed.
[^5]: fastxtend Windows support is currently untested, but it should work. It is recommended to use WSL on Windows.
[^6]: This will not be a one-to-one recreation, as `fastai.vision.augment.RandomResizedCrop` adds padding and squishes the validation images, while the fastxtend `Loader` uses a standard center crop for validation. See the Identical fastai and FFCV Dataloaders section for how to create an identical fastai dataloader.
[^7]: Since this is a single label problem, we need to remove the added dimension after the batch is created. If this were a multi-labeled dataset, we’d skip the squeezing step.
[^8]: CuPy is only required to use FFCV’s `NormalizeImage` on the GPU. Add `cupy` to the conda installation script if using it.
[^9]: Assuming your machine has more than 8GB of RAM.
[^10]: Ignoring the OpenCV vs Pillow resizing differences.