Getting Started with fastxtend and FFCV

Use fastxtend’s FFCV integration to accelerate fastai training

fastxtend integrates FFCV with fastai. You can now use the speed of the highly optimized FFCV DataLoader natively with fastai batch transforms, callbacks, and other DataLoader features.

fastxtend accelerates fastai

FFCV is “a drop-in data loading system that dramatically increases data throughput in model training.” It accelerates [1] DataLoader throughput by combining Numba compiled item transforms with a custom data format and cached data loading.

FFCV has a getting started tutorial which pairs well with this guide, providing additional context and depth.

While fastxtend’s FFCV integration is currently in beta, it is fully functional. Expect new features and quality of life improvements in future fastxtend releases.

The Imagenette benchmark uses fused optimizers, the Progressive Resizing callback, and the integrated FFCV DataLoader. To run the benchmark on your own machine, see the example scripts for details on how to replicate it.

Installing fastai, FFCV, and fastxtend

The easiest way to install fastai, fastxtend, and FFCV is to use Conda or Miniconda [2] on Linux (or WSL):

conda create -n fastxtend python=3.11 "pytorch>=2.1" \
torchvision pytorch-cuda=12.1 fastai pkg-config \
libjpeg-turbo opencv tqdm terminaltables psutil numpy \
"numba>=0.57" timm kornia -c pytorch -c nvidia \
-c fastai -c huggingface -c conda-forge

# Switch to the newly created conda environment
conda activate fastxtend

replacing pytorch-cuda=12.1 [3] with your preferred supported version of CUDA. In rare [4] cases, you may need to add the compilers package to the conda install.

And then install fastxtend via pip:

# Install fastxtend with Vision & FFCV support
pip install fastxtend[ffcv]

Or to install with all of fastxtend’s features:

conda create -n fastxtend python=3.11 "pytorch>=2.1" torchvision \
torchaudio pytorch-cuda=12.1 fastai nbdev pkg-config libjpeg-turbo \
opencv tqdm psutil terminaltables numpy "numba>=0.57" librosa timm \
kornia rich typer wandb "transformers>=4.34" "tokenizers>=0.14" \
"datasets>=2.14" ipykernel ipywidgets "matplotlib<3.8" -c pytorch \
-c nvidia -c fastai -c huggingface -c conda-forge

# Switch to the newly created conda environment
conda activate fastxtend

# install fastxtend and FFCV
pip install fastxtend[all]

If you are using Windows [5], please follow the FFCV Windows installation guide, then install fastxtend via pip.

If you want to use notebooks with fastxtend, it’s recommended to add one of jupyter or jupyterlab along with ipykernel to the conda install packages.

cupy is not listed in the conda packages as it’s only needed if you want to use FFCV’s NormalizeImage on the GPU. It’s recommended to use fastai’s Normalize instead. See the Adding Required & Optional Batch Transforms section for more details.

Importing FFCV via fastxtend

fastxtend’s FFCV integration has been designed to use __all__ to safely import everything needed to use FFCV with fastai and fastxtend.

Run:

from fastai.vision.all import *
from fastxtend.vision.all import *
from fastxtend.ffcv.all import *

and you are ready to go.

Note: Importing Transforms & Operations

fastxtend’s FFCV transforms and operations are imported under the ft prefix, since they sometimes overlap with fastai batch transforms.

For example, the FFCV augmentation RandomErasing is a Numba-compiled FFCV version of the fastai batch transform fastai.vision.augment.RandomErasing.

You can also import the FFCV integration individually. This will require you to mix imports from fastxtend.ffcv and ffcv.

from ffcv.fields.decoders import CenterCropDecoder
from fastxtend.ffcv.loader import Loader, OrderOption
from fastxtend.ffcv.transforms import RandomHorizontalFlip
# etc

However, during the beta it’s recommended to use from fastxtend.ffcv.all import *, as imports may change between fastxtend releases.

Creating a FFCV Dataset

Before we can start training with FFCV, our dataset needs to be converted into FFCV’s custom beton format. This can be done using the DatasetWriter.

fastxtend provides the rgb_dataset_to_ffcv convenience method for easy FFCV image dataset creation. rgb_dataset_to_ffcv expects a PyTorch compatible Dataset or any Python iterator.

First, create an Imagenette dataset using the fastai DataBlock API.

path = URLs.IMAGENETTE_320
source = untar_data(path)

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label)
dset = dblock.datasets(source)

Next, use rgb_dataset_to_ffcv to create two FFCV files: one for the training dataset and one for the validation dataset.

path = Path.home()/'.cache/fastxtend'
path.mkdir(exist_ok=True)

rgb_dataset_to_ffcv(dset.train, path/'imagenette_320_train.ffcv')

rgb_dataset_to_ffcv(dset.valid, path/'imagenette_320_valid.ffcv')

If you have more (or less) memory, you can increase (or decrease) DatasetWriter’s chunk_size from the default of 100.
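For example, a minimal sketch which lowers write-time memory usage, assuming rgb_dataset_to_ffcv exposes a chunk_size argument that is passed through to DatasetWriter:

# Hypothetical: lower chunk_size to reduce peak memory while writing
rgb_dataset_to_ffcv(dset.train, path/'imagenette_320_train.ffcv', chunk_size=50)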

If Imagenette was not already resized, we could pass max_resolution or min_resolution to resize the images. To recreate Imagenette 320 from the full size dataset, pass min_resolution=320:

rgb_dataset_to_ffcv(
    dset.train, path/'imagenette_320_train.ffcv', min_resolution=320)

By default, rgb_dataset_to_ffcv uses Pillow with the LANCZOS resample method to resize images, while DatasetWriter uses OpenCV with INTER_AREA.

To accelerate image resizing, you’ll probably want Pillow-SIMD installed.
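One way to do this in the activated conda environment is shown below. These are the usual Pillow-SIMD installation steps, but check the Pillow-SIMD documentation for the current recommended instructions:

# Replace Pillow with the SIMD-accelerated drop-in fork
pip uninstall -y pillow
pip install pillow-simd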

Creating a fastxtend Loader

fastxtend adds fastai features to FFCV’s Loader, including one_batch, show_batch, show_results, and support for batch transforms, to name a few.

Currently, fastai.data.block.DataBlock cannot be used to create a fastxtend Loader, so we’ll have to build the Loader from scratch.

For reference, here is the fastai DataBlock we’ll be recreating [6] using the fastxtend Loader.

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label,
                   item_tfms=[RandomResizedCrop(224), FlipItem(0.5)],
                   batch_tfms=[*aug_transforms(do_flip=False), Normalize(*imagenet_stats)])
dls = dblock.dataloaders(source, bs=64, num_workers=num_cpus())

Unlike fastai, fastxtend’s FFCV integration (and FFCV itself) does not automatically select between training and validation versions of transforms, automatically create a validation pipeline, or automatically reorder transforms and operations.

You are responsible for adding the correct decoders, transforms, and operations to the correct pipelines in the correct order.

Setting Up Pipelines

FFCV uses pipelines to declare what input fields to read, how to decode them, and which operations and transforms to apply on them.

The dataloader will need three pipelines: one for the training images, one for the validation images, and a shared pipeline for the labels.

We need to make sure that decoders, transforms, and operations are all in the correct order, as they will be executed sequentially.

Training Pipeline

Reading an FFCV dataset requires an FFCV decoder. FFCV has multiple decoders, but we’ll use RandomResizedCropRGBImageDecoder, which integrates a random resized crop into image loading.

And we’ll add RandomHorizontalFlip to flip the image.

train_pipe = [
    RandomResizedCropRGBImageDecoder(output_size=(224,224)),
    ft.RandomHorizontalFlip(0.5)
]

fastxtend provides multiple FFCV transforms, including existing FFCV transforms with harmonized arguments, fastai transforms implemented as FFCV transforms, and additional FFCV transforms.

After passing through FFCV’s Numba compiled transforms, the Imagenette images are still in CPU memory as NumPy arrays. Before we can pass them to our model, they need to be converted to fastai.torch_core.TensorImage and moved to the GPU.

We’ll extend our training pipeline by adding the ToTensorImage and ToDevice operations.

train_pipe.extend([ft.ToTensorImage(), ft.ToDevice()])

Loader will now asynchronously transfer each training image batch to the GPU.

FFCV has ToTensor, ToTorchImage, and ToDevice operations for converting NumPy arrays to PyTorch tensors and moving them to the GPU. These are compatible with PyTorch dataloaders, but they are not compatible with fastai, as they strip the tensor types required for fastai features such as batch transforms, callbacks, and plotting.

Use fastxtend’s ToTensorImage and ToDevice for compatibility with fastai features.

Validation Pipeline

With the training pipeline finalized, it’s time to create the validation pipeline.

FFCV and fastxtend currently have one validation image decoder: CenterCropRGBImageDecoder, which resizes and center crops the validation image. This is identical to the validation behavior of fastai.vision.augment.Resize.

valid_pipe = [
    CenterCropRGBImageDecoder(output_size=(224,224), ratio=1),
    ft.ToTensorImage(),
    ft.ToDevice()
]

Like train_pipe, we use ToTensorImage to convert to the correct tensor type and ToDevice to asynchronously transfer each batch to the GPU.

Label Pipeline

Now we have image pipelines for our training and validation datasets, but we still need to create our label pipeline.

Since this is a single-label dataset, we have integers as labels, so we’ll use an IntDecoder and convert to TensorCategory, followed by squeezing the extra dimension [7] with Squeeze and using ToDevice to transfer to the GPU.

label_pipe = [
    IntDecoder(), ft.ToTensorCategory(),
    ft.Squeeze(), ft.ToDevice()
]

Adding Required & Optional Batch Transforms

After Loader finishes processing the pipelines, the images are batched on the GPU, but will be in uint8 format and unnormalized.

FFCV has operations to handle both [8], but using them will convert the images from TensorImage to Tensor, limiting compatibility with other fastai features, such as callbacks, plotting, etc.

Since fastxtend’s Loader supports fastai GPU batch transforms, we’ll use them instead.

Required Batch Transform

To convert the uint8 tensors to float and normalize the images, we’ll use fastai.data.transforms.IntToFloatTensor and fastai.data.transforms.Normalize to preserve tensor types and metadata.

fastai.data.transforms.IntToFloatTensor is a required batch transform (batch_tfms) when training on image data for fastai feature compatibility.

Unlike the FFCV transforms we’ve used so far, the fastai batch transforms passed to Loader will automatically reorder themselves and use type dispatch to apply to the correct tensor types.

batch_tfms = [
    IntToFloatTensor,
    Normalize.from_stats(*imagenet_stats)
]

It is recommended, but not required, to normalize an image batch.

Optional Batch Transforms

It’s also possible to add any fastai batch transform to Loader’s batch_tfms, such as fastai.vision.augment.aug_transforms or affine_transforms:

batch_tfms = [
    IntToFloatTensor,
    *aug_transforms(),
    Normalize.from_stats(*imagenet_stats)
]

Creating the Loader

With the image pipelines, label pipeline, and batch transforms set up, we can now create our dataloaders, one Loader for train and one for valid.

Training Loader

Starting with the training Loader, we can manually set all the training-specific arguments:

Loader(path/'imagenette_320_train.ffcv',
    batch_size=64,
    num_workers=num_cpus(),
    os_cache=True,
    order=OrderOption.RANDOM,
    drop_last=True,
    pipelines={'image': train_pipe, 'label': label_pipe},
    batch_tfms=batch_tfms,
    batches_ahead=2,
    device='cuda',
    split_idx=0,
    n_inp=1
)

Loader Arguments

There are a handful of important Loader arguments which need further explanation:

  • order: Controls how much memory is used for dataset caching and whether the dataset is randomly shuffled. Can be one of RANDOM, QUASI_RANDOM, or SEQUENTIAL. See the note below for more details. Defaults to SEQUENTIAL, which is unrandomized.

  • os_cache: By default, FFCV will attempt to cache the entire dataset into RAM using the operating system’s caching. This can be changed by setting os_cache=False or setting the environment variable ‘FFCV_DEFAULT_CACHE_PROCESS’ to “True” or “1”. If os_cache=False then order must be set to QUASI_RANDOM for the training Loader.

  • num_workers: If not set, will use all CPU cores up to 16 by default.

  • batches_ahead: Controls how many batches ahead of time the Loader processes. Increasing it uses more memory, both CPU and GPU. Defaults to 2.

  • n_inp: Controls how many inputs are passed to the model. By default, set to the number of pipelines minus one.

  • drop_last: Whether to drop the last partial batch. By default, will set to True if order is RANDOM or QUASI_RANDOM, False if SEQUENTIAL.

  • async_tfms: Asynchronously apply batch_tfms before the batch is drawn. Can accelerate training if GPU compute isn’t fully saturated (95% or less) or if only using IntToFloatTensor and Normalize.

  • device: The device to place the processed batches of data on. Defaults to fastai.torch_core.default_device if not set.

  • split_idx: This tells the fastai batch transforms what dataset they are operating on. By default will use 0 (train) if order is RANDOM or QUASI_RANDOM, 1 (valid) if SEQUENTIAL.

Note: Order Memory Usage

Each order option requires differing amounts of system memory.

  • RANDOM caches the entire dataset in memory for fast random sampling. RANDOM uses the most memory.

  • QUASI_RANDOM caches a subset of the dataset at a time in memory and randomly samples from the subset. Use when the entire dataset cannot fit into memory.

  • SEQUENTIAL requires the least memory. It loads a few samples ahead of time. As the name suggests, it is not random, and is primarily intended for validation.
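For example, if the dataset were too large to cache in RAM, the training Loader could disable the operating system cache and switch to QUASI_RANDOM sampling. A sketch using the arguments described above and the pipelines created earlier:

Loader(path/'imagenette_320_train.ffcv',
    batch_size=64,
    os_cache=False,
    order=OrderOption.QUASI_RANDOM,
    pipelines={'image': train_pipe, 'label': label_pipe},
    batch_tfms=batch_tfms
)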

With Loader, fastai batch transforms can either be computed synchronously or asynchronously.

Synchronous transforms behave like the fastai dataloader: a batch is called for by the training loop, batch transforms are computed and applied to the batch, and then the batch is passed to the model for training. Asynchronous batch transforms are computed and applied to batches ahead of time, in parallel with model training. With asynchronous batch transforms there isn’t any delay between calling for a batch and training on the batch.

Because asynchronous transforms are computed in parallel with the training loop, they can slightly slow down individual training steps. However, by eliminating the wait for batch transforms to be computed, asynchronous transforms usually increase overall training speed.

Try asynchronous transforms by setting async_tfms = True if GPU compute isn’t fully saturated (95% or less) or if only using the required transforms IntToFloatTensor and Normalize.
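As a sketch, enabling them only requires adding the async_tfms argument to the training Loader created above:

Loader(path/'imagenette_320_train.ffcv',
    batch_size=64,
    order=OrderOption.RANDOM,
    pipelines={'image': train_pipe, 'label': label_pipe},
    batch_tfms=batch_tfms,
    async_tfms=True
)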

Simplified Training Loader

Since Imagenette 320 is small enough to load both the training and validation images into RAM [9], the only arguments for the training Loader that must be set are:

Loader(path/'imagenette_320_train.ffcv',
    batch_size=64,
    order=OrderOption.RANDOM,
    pipelines={'image': train_pipe, 'label': label_pipe},
    batch_tfms=batch_tfms,
    batches_ahead=2
)

Validation Loader

Next is the validation Loader.

Loader(path/'imagenette_320_valid.ffcv',
    batch_size=64,
    pipelines={'image': valid_pipe, 'label': label_pipe},
    batch_tfms=batch_tfms,
    batches_ahead=2
)

This example only sets the required arguments, relying on the Loader defaults for the rest.

Putting it all Together

The last step is to wrap both the training and validation Loader in a fastai.data.core.DataLoaders.

The example below shows all the steps covered so far in a single codeblock.

loaders = {}
for name in ['train', 'valid']:
    label_pipe = [
        IntDecoder(), ft.ToTensorCategory(),
        ft.Squeeze(), ft.ToDevice()
    ]

    if name=='train':
        image_pipe = [
            RandomResizedCropRGBImageDecoder(output_size=(224,224), scale=(0.35, 1)),
            ft.RandomHorizontalFlip(), ft.ToTensorImage(), ft.ToDevice()
        ]
        order = OrderOption.RANDOM
    else:
        image_pipe = [
            CenterCropRGBImageDecoder(output_size=(224,224), ratio=1),
            ft.ToTensorImage(), ft.ToDevice()
        ]
        order = OrderOption.SEQUENTIAL

    batch_tfms = [IntToFloatTensor, *aug_transforms(), Normalize.from_stats(*imagenet_stats)]

    loaders[name] = Loader(path/f'imagenette_320_{name}.ffcv',
                        batch_size=64 if name=='train' else 128,
                        order=order,
                        pipelines={'image': image_pipe, 'label': label_pipe},
                        batch_tfms=batch_tfms,
                        batches_ahead=1,
                        seed=42
                    )

dls = DataLoaders(loaders['train'], loaders['valid'])
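Since fastxtend’s Loader adds fastai conveniences such as one_batch and show_batch, we can sanity check the pipelines and batch transforms before training. A quick sketch; the batch types should be TensorImage and TensorCategory thanks to the fastxtend operations above:

# Draw a single batch and confirm the tensor types and shapes
xb, yb = dls.train.one_batch()
print(type(xb), xb.shape, type(yb), yb.shape)

# Decode and plot a few samples
dls.train.show_batch()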

Training with a fastxtend Loader

With the DataLoaders created, the only thing left to do is create a fastai.learner.Learner with our model, optimizer, loss function, metrics, and callbacks of choice.

Here we create the setup which should train Imagenette to ~92.5% in ~226 seconds on a 3080 Ti, depending on randomness and hardware.

with less_random():
    learn = Learner(dls, xresnext50(n_out=10), opt_func=ranger(foreach=True),
                    loss_func=nn.CrossEntropyLoss(label_smoothing=0.1), metrics=Accuracy(),
                    cbs=ProgressiveResize(increase_by=16)).to_channelslast()

The first batch will be slower, as Numba needs to compile and FFCV needs to allocate memory for each transform and operation.

Once the compilation is over, we will benefit from FFCV’s accelerated data loading.

with less_random():
    learn.fit_flat_cos(20, 8e-3)

Progressively increase the initial image size of [112, 112] by 16 pixels every 0.8333 epochs for 7 resizes. Starting at epoch 10 and finishing at epoch 15 for a final training size of [224, 224].
epoch train_loss valid_loss accuracy time
0 1.739842 1.804805 0.515924 00:12
1 1.483479 1.330434 0.662930 00:08
2 1.317108 1.193714 0.724076 00:08
3 1.215840 1.404873 0.640255 00:08
4 1.114017 1.206947 0.715924 00:08
5 1.068557 0.985433 0.817070 00:08
6 1.017356 1.111050 0.760255 00:08
7 0.985094 1.069334 0.767898 00:08
8 0.960843 0.889831 0.847389 00:08
9 0.911082 1.085570 0.755159 00:08
10 0.895687 0.927597 0.836688 00:09
11 0.867939 0.798519 0.885860 00:09
12 0.870780 0.957848 0.807388 00:11
13 0.843433 0.894144 0.846624 00:13
14 0.820972 0.793127 0.884076 00:15
15 0.797982 0.948408 0.824204 00:17
16 0.766880 0.769455 0.896815 00:17
17 0.716646 0.749211 0.906242 00:17
18 0.671945 0.715331 0.923057 00:17
19 0.649383 0.699832 0.926115 00:17

Inference with FFCV and fastai

While the fastxtend Loader and custom FFCV file format are great for accelerated training, they are not as useful for inference. In this tutorial, we will use the fastai DataBlock API for inference.

FFCV is hardcoded to use OpenCV’s INTER_AREA when resizing images, while fastai uses Pillow. This means we cannot use the default fastai pipeline with fastai.vision.data.ImageBlock for inference.

fastxtend’s FFCV Inference module provides an FFAIImageBlock and an FFAICenterCrop item transform, which use OpenCV for resizing images.

Any class or method with the prefix FFAI is intended for inference after training with the fastxtend Loader.

We can create an inference dataloader using the fastai DataBlock API which will create images identically to the Loader pipeline we created earlier in this tutorial.

inference_dblock = DataBlock(blocks=(FFAIImageBlock, CategoryBlock),
                             splitter=GrandparentSplitter(valid_name='val'),
                             get_items=get_image_files, get_y=parent_label,
                             item_tfms=[FFAICenterCrop(224, ratio=1)],
                             batch_tfms=[*aug_transforms(), Normalize(*imagenet_stats)])
inference_dls = inference_dblock.dataloaders(source, bs=64, num_workers=num_cpus())

Then, assuming we had a folder of test images in '/test', we’d set the Learner’s dataloaders to the new inference_dls and perform inference as normal.

learn.dls = inference_dls
test_dl = learn.dls.test_dl(get_image_files('/test'))
preds, _ = learn.get_preds(dl=test_dl)

For more details on inference, check out my Inference with fastai tutorial.

Reproducibility and Other Limitations

One downside of FFCV is that it provides less reproducibility than most dataloader solutions.

While Loader has a seed argument, it currently only affects the order data is loaded in. With a couple of exceptions, the Numba transforms are neither seeded nor reproducible. These transforms are also independent across pipelines.

This means many image-to-image training tasks, such as image segmentation, cannot easily use Loader as the inputs and outputs will not be identically resized, flipped, etc.

However, this FFCV limitation might be resolved in the near future, as Meta AI recently announced FFCV-SSL, a fork of FFCV which, among other things, has reproducible transforms.

Identical fastai and FFCV Dataloaders

The fastai dataloader and FFCV dataloader we created in this tutorial do not produce identical validation images [10].

This is due to fastai.vision.augment.RandomResizedCrop adding padding and squishing validation images, while FFCV’s CenterCropRGBImageDecoder creates validation images via center crop.

There’s an easy way to create a fastai dataloader with identical behavior via the DataBlock API: we just have to create it twice.

First, create a dataloader like we did before using RandomResizedCrop. Then create a second dataloader, except with fastai.vision.augment.Resize. Resize creates a center crop during validation, just like CenterCropRGBImageDecoder with ratio=1. Finally, set the first DataLoaders’ valid dataloader to the second DataLoaders’ valid dataloader. We can then optionally delete the second DataLoaders.

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label,
                   item_tfms=[RandomResizedCrop(224), FlipItem(0.5)],
                   batch_tfms=[*aug_transforms(do_flip=False),
                               Normalize.from_stats(*imagenet_stats)])
dls = dblock.dataloaders(source, bs=64, num_workers=num_cpus())

vblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files, get_y=parent_label,
                   item_tfms=Resize(224),
                   batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])
vls = vblock.dataloaders(source, bs=64, num_workers=num_cpus())

dls.valid = vls.valid
vls = None

Footnotes

  1. MosaicML found that using FFCV led to a ~1.85x increase in throughput, from ~17,800 images/sec to ~30,000 images/sec on a 2x 32-core CPU and 8x A100 system.↩︎

  2. Miniconda with the faster libmamba solver is recommended.↩︎

  3. If you want to include a full CUDA install, I find that specifying the CUDA version label -c nvidia/label/cuda-12.1.0 usually results in better environment solving. Sometimes the label doesn’t have all the packages, so you’ll need to add them manually to the conda install packages nvidia::missing_package.↩︎

  4. The FFCV Linux installation guide states the compilers package is rarely needed.↩︎

  5. fastxtend Windows support is currently untested, but it should work. It is recommended to use WSL on Windows.↩︎

  6. This will not be a one-to-one recreation, as fastai.vision.augment.RandomResizedCrop adds padding and squishes the validation images, while the fastxtend Loader will use a standard center crop for validation. See the Identical fastai and FFCV Dataloaders section for how to create an identical fastai dataloader.↩︎

  7. Since this is a single label problem, we need to remove the added dimension after the batch is created. If this were a multi-labeled dataset, we’d skip the squeezing step.↩︎

  8. CuPy is required to use FFCV’s NormalizeImage on the GPU. Add cupy to the conda install packages if using it.↩︎

  9. Assuming your machine has more than 8GB of RAM.↩︎

  10. Ignoring the OpenCV vs Pillow resizing differences.↩︎