Hugging Face Transformers Compatibility

Train Hugging Face Transformers models using fastai

fastxtend provides basic compatibility for training Hugging Face Transformers models using the fastai.learner.Learner.

blurr provides a complete Hugging Face Transformers integration with fastai, including working fastai datablocks, dataloaders, and other fastai methods.

In contrast, fastxtend only provides basic Learner compatibility.

fastxtend’s Transformers compatibility requires a minimum of PyTorch 2.0.

To use fastxtend’s compatibility, set up the Hugging Face dataset, dataloader, and model per the Transformers documentation, exchanging the PyTorch DataLoader for the HuggingFaceLoader. Then wrap the dataloaders in fastai.data.core.DataLoaders and create a Learner with the Hugging Face model, HuggingFaceLoss, and HuggingFaceCallback. This automatically sets up the compatibility and uses the Hugging Face model’s built-in loss.

Jump to the example section for a full tutorial.

from fastai.text.all import *
from fastxtend.text.all import *
from transformers import AutoModel

# load a task specific AutoModel, such as AutoModelForSequenceClassification
hf_model = AutoModel.from_pretrained("model-name")

# set up the dataset and then the dataloaders
train_dataset = dataset['train'].with_format('torch')
train_dataloader = HuggingFaceLoader(
    train_dataset, batch_size=batch_size,
    collate_fn=data_collator, shuffle=True,
    drop_last=True, num_workers=num_cpus()
)

# valid_dataloader definition cut for brevity
dls = DataLoaders(train_dataloader, valid_dataloader)

# other Learner arguments, such as opt_func and metrics, are omitted for brevity
learn = Learner(dls, hf_model, loss_func=HuggingFaceLoss(),
                cbs=HuggingFaceCallback()).to_bf16()

# save the model after training using Transformers
learn.hf_model.save_pretrained("trained-model-name")

To train with a different loss, pass a PyTorch-compatible loss to Learner as usual, and HuggingFaceCallback will use it instead of the model’s built-in loss.



HuggingFaceLoss

 HuggingFaceLoss (**kwargs)

To use the Hugging Face model’s built-in loss function, pass this loss to Learner.



HuggingFaceWrapper

 HuggingFaceWrapper (model:PreTrainedModel)

A minimal compatibility wrapper between a Hugging Face model and Learner

model (PreTrainedModel): Hugging Face compatible model

In practice, you won’t need to use HuggingFaceWrapper directly, as HuggingFaceCallback will automatically add it for you.
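As a rough illustration of what a wrapper like this does, the sketch below keeps the original model as hf_model and calls it with the batch dictionary unpacked into keyword arguments. It is a hypothetical plain-Python stand-in (DictInputWrapper and fake_model are illustrative names), not fastxtend’s HuggingFaceWrapper, which is a PyTorch module and may handle more cases.

```python
from typing import Any, Callable, Dict


class DictInputWrapper:
    """Hypothetical sketch: adapt a keyword-argument model to dictionary batches."""

    def __init__(self, model: Callable[..., Dict[str, Any]]):
        # Keep a reference to the original model, mirroring the hf_model attribute
        self.hf_model = model

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Unpack the batch dictionary into keyword arguments, as Transformers models expect
        return self.hf_model(**batch)


def fake_model(input_ids, attention_mask, labels=None):
    # Stand-in for a Transformers model: only returns a loss when labels are present
    output = {"logits": [len(input_ids)]}
    if labels is not None:
        output["loss"] = 0.25
    return output


wrapped = DictInputWrapper(fake_model)
out = wrapped({"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": 1})
print(out)  # {'logits': [3], 'loss': 0.25}
```

Storing the unwrapped model on the wrapper is what makes it easy to hand the original object back to Transformers methods such as save_pretrained.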



HuggingFaceCallback

 HuggingFaceCallback (labels:str|None='labels', loss:str='loss',
                      logits:str='logits', unwrap:bool=False)

Provides compatibility between fastai’s Learner, the Transformers model, & HuggingFaceLoader

labels (str | None, default 'labels'): Input batch labels key. Set to None if the input doesn’t contain labels
loss (str, default 'loss'): Model output loss key
logits (str, default 'logits'): Model output logits key
unwrap (bool, default False): After training completes, unwrap the Transformers model

HuggingFaceCallback automatically wraps a Transformers model with the HuggingFaceWrapper for compatibility with fastai.learner.Learner. The original Transformers model is accessible via Learner.hf_model.

If HuggingFaceLoss is passed to Learner, then HuggingFaceCallback will use the Hugging Face model’s built in loss.

If any other loss function is passed to Learner, HuggingFaceCallback will prevent the built-in loss from being calculated and will use the Learner loss function instead.

If labels=None, then HuggingFaceCallback will not attempt to assign a fastai target from the Hugging Face input batch. The default fastai and fastxtend metrics will not work without labels.

After training, HuggingFaceCallback can automatically unwrap the model if unwrap=True.
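The loss selection described above can be sketched in plain Python. This is a hypothetical outline, not fastxtend’s code: split_batch and select_loss are illustrative names, passing loss_func=None stands in for passing HuggingFaceLoss, and dropping the labels key before the forward pass is an assumed way to stop the model computing its built-in loss (standard Transformers task heads only compute a loss when labels are passed).

```python
from typing import Any, Callable, Dict, Optional, Tuple

Batch = Dict[str, Any]


def split_batch(batch: Batch, labels: Optional[str], use_model_loss: bool) -> Tuple[Batch, Any]:
    """Return (model inputs, fastai target) for one dictionary batch."""
    # labels=None: no target is extracted, so label-based metrics cannot run
    target = batch.get(labels) if labels is not None else None
    if use_model_loss or labels is None:
        inputs = dict(batch)  # keep labels so the model computes its own loss
    else:
        # assumption: removing labels keeps the model from computing its built-in loss
        inputs = {k: v for k, v in batch.items() if k != labels}
    return inputs, target


def select_loss(output: Batch, target: Any, loss_func: Optional[Callable[[Any, Any], Any]],
                loss: str = "loss", logits: str = "logits") -> Any:
    """Use the model's loss when loss_func is None, else apply loss_func to the logits."""
    if loss_func is None:  # stands in for passing HuggingFaceLoss to Learner
        return output[loss]
    return loss_func(output[logits], target)


batch = {"input_ids": [1, 2], "labels": 0}
inputs, target = split_batch(batch, labels="labels", use_model_loss=False)
print(inputs, target)  # {'input_ids': [1, 2]} 0
print(select_loss({"logits": 0.9, "loss": 0.3}, target, None))  # 0.3
print(select_loss({"logits": 0.9}, target, lambda pred, y: abs(pred - y)))  # 0.9
```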



HuggingFaceLoader

 HuggingFaceLoader (dataset:Dataset, batch_size:int,
                    shuffle:bool|None=None,
                    sampler:Sampler|Iterable|None=None,
                    batch_sampler:Sampler[Sequence]|Iterable[Sequence]|None=None,
                    num_workers:int=0, collate_fn:_collate_fn_t|None=None,
                    pin_memory:bool=False, drop_last:bool=False,
                    timeout:float=0,
                    worker_init_fn:_worker_init_fn_t|None=None,
                    multiprocessing_context=None, generator=None,
                    prefetch_factor:int|None=None,
                    persistent_workers:bool=False,
                    pin_memory_device:str='')

A minimal compatibility DataLoader between a Hugging Face dataset and Learner

dataset (Dataset): Dataset from which to load the data
batch_size (int): Batch size
shuffle (bool | None, default None): Randomize the order of data at each epoch. None is treated as False
sampler (Sampler | Iterable | None, default None): Determines how to draw samples from the dataset. Cannot be used with shuffle
batch_sampler (Sampler[Sequence] | Iterable[Sequence] | None, default None): Returns a batch of indices at a time. Cannot be used with batch_size, shuffle, sampler, or drop_last
num_workers (int, default 0): Number of processes to use for data loading. 0 means using the main process
collate_fn (_collate_fn_t | None, default None): Function that merges a list of samples into a mini-batch of Tensors. Used for map-style datasets
pin_memory (bool, default False): Copy Tensors into device/CUDA pinned memory before returning them
drop_last (bool, default False): Drop the last incomplete batch if the dataset size is not divisible by the batch size
timeout (float, default 0): Timeout value for collecting a batch from workers
worker_init_fn (_worker_init_fn_t | None, default None): Called on each worker subprocess with the worker id as input
multiprocessing_context (NoneType, default None)
generator (NoneType, default None)
prefetch_factor (int | None, default None): Number of batches loaded in advance by each worker
persistent_workers (bool, default False): If True, the data loader will not shut down the worker processes after a dataset has been consumed once
pin_memory_device (str, default ''): If pin_memory is True, the data loader copies Tensors into pinned memory for this device before returning them

Hugging Face Datasets, and thus DataLoaders, return dictionary objects while the fastai.learner.Learner expects tuples. HuggingFaceLoader is a PyTorch DataLoader which wraps the Hugging Face batch dictionary in a tuple for Learner compatibility. It is otherwise identical to a PyTorch DataLoader.
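The core idea, wrapping each dictionary batch in a tuple, can be sketched independently of PyTorch. tuple_batches below is a hypothetical helper that assumes a 1-tuple layout; fastxtend’s HuggingFaceLoader is itself a PyTorch DataLoader, and its exact tuple layout may differ.

```python
from typing import Any, Dict, Iterable, Iterator, Tuple


def tuple_batches(batches: Iterable[Dict[str, Any]]) -> Iterator[Tuple[Dict[str, Any]]]:
    """Yield each dictionary batch wrapped in a 1-tuple (assumed layout)."""
    for batch in batches:
        yield (batch,)


dict_batches = [{"input_ids": [[1, 2]], "labels": [0]}]
for b in tuple_batches(dict_batches):
    print(type(b).__name__, list(b[0]))  # tuple ['input_ids', 'labels']
```

Because the dictionary itself is passed through untouched, the model still receives every key the collator produced.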

Example

In this example, we’ll use Hugging Face Transformers along with fastai & fastxtend to train DistilRoBERTa on a subset of IMDb.

This example is based on the Transformers documentation sequence classification example.

Setup Transformer Objects

First, we’ll grab the DistilRoBERTa tokenizer and model from the Transformers Auto methods.

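from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
model = AutoModelForSequenceClassification.from_pretrained('distilroberta-base', num_labels=2)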

Next, download IMDb using the Datasets library’s load_dataset. In this example, we’ll use a random subset of IMDb: 5,000 training reviews and 1,000 test reviews.

imdb = load_dataset('imdb')
with less_random():
    imdb['train'] = imdb['train'].shuffle().select(range(5000))
    imdb['test'] = imdb['test'].shuffle().select(range(1000))

Next, we’ll tokenize the data using Dataset’s map method.

def tokenize_data(batch, tokenizer):
    return tokenizer(batch['text'], truncation=True)

imdb['train'] = imdb['train'].map(
    partial(tokenize_data, tokenizer=tokenizer),
    remove_columns='text', batched=True, batch_size=512, num_proc=num_cpus(),
)

imdb['test'] = imdb['test'].map(
    partial(tokenize_data, tokenizer=tokenizer),
    remove_columns='text', batched=True, batch_size=512, num_proc=num_cpus(),
)

Define the DataLoader

We need to use fastxtend’s HuggingFaceLoader instead of the PyTorch DataLoader. HuggingFaceLoader is a simple wrapper around a PyTorch DataLoader which returns the Hugging Face dictionary batches in tuples, as fastai.learner.Learner expects. It is otherwise identical to the PyTorch DataLoader.

After creating the train and valid HuggingFaceLoader, we need to wrap them in fastai.data.core.DataLoaders.

with less_random():
    train_dataloader = HuggingFaceLoader(
        imdb['train'].with_format('torch'), batch_size=16,
        collate_fn=DataCollatorWithPadding(tokenizer), shuffle=True,
        drop_last=True, num_workers=num_cpus()
    )

    valid_dataloader = HuggingFaceLoader(
        imdb['test'].with_format('torch'), batch_size=16,
        collate_fn=DataCollatorWithPadding(tokenizer), shuffle=False,
        drop_last=False, num_workers=num_cpus()
    )

    dls = DataLoaders(train_dataloader, valid_dataloader)
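With the subset sizes chosen above (5,000 training and 1,000 test reviews) and batch_size=16, drop_last decides what happens to the leftover samples. The helper below illustrates the arithmetic for a DataLoader using its default batch sampler; num_batches is a hypothetical name.

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch for a DataLoader with the default batch sampler."""
    if drop_last:
        return n_samples // batch_size  # the incomplete final batch is discarded
    return (n_samples + batch_size - 1) // batch_size  # ceiling division keeps it


print(num_batches(5000, 16, drop_last=True))   # 312 (8 samples dropped each epoch)
print(num_batches(1000, 16, drop_last=False))  # 63 (the last batch has 8 samples)
```

Because the training loader also shuffles, a different 8 reviews are dropped each epoch, while every validation review is evaluated.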

Create a Learner and Train

Finally, we’ll create the Learner to train DistilRoBERTa on IMDb. We’ll pass in the HuggingFaceCallback to cbs to handle loss function compatibility between Transformers and fastai.

Transformers models contain a built-in loss, which we’ll use by passing HuggingFaceLoss to loss_func.

HuggingFaceCallback expects the Transformers model output to contain logits and loss keys. If these exist under different names, pass the non-standard key names to HuggingFaceCallback via its logits and loss arguments.

If your input doesn’t have a labels key, perhaps because you are pretraining a causal language model, set labels=None.

We can now use any fastai and/or fastxtend callbacks, optimizers, or metrics to train our Transformers model as usual.

with less_random():
    learn = Learner(dls, model, loss_func=HuggingFaceLoss(),
                    opt_func=stableadam(foreach=True),
                    metrics=Accuracy(), cbs=HuggingFaceCallback).to_bf16()

    learn.fit_flat_warmup(3, lr=8e-4, wd=1e-2)
epoch  train_loss  valid_loss  accuracy  time
0      0.691708    0.690203    0.492000  00:38
1      0.510412    0.409681    0.854000  00:37
2      0.282954    0.300484    0.873000  00:38

If we want to use our own loss, such as nn.CrossEntropyLoss with label smoothing, we can pass any PyTorch-compatible loss function to Learner, and HuggingFaceCallback will automatically use it instead of DistilRoBERTa’s built-in loss function.
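For reference, label smoothing mixes the one-hot target with a uniform distribution over the classes before taking the cross entropy, which is how PyTorch documents the label_smoothing argument of nn.CrossEntropyLoss. The plain-Python sketch below computes it for a single example; it is an illustration only, not a replacement for the PyTorch loss.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def smoothed_cross_entropy(logits: List[float], target: int, smoothing: float = 0.1) -> float:
    """Cross entropy against (1 - smoothing) * one_hot + smoothing / num_classes."""
    log_probs = log_softmax(logits)
    num_classes = len(logits)
    soft_target = [smoothing / num_classes + (1.0 - smoothing) * (i == target)
                   for i in range(num_classes)]
    return -sum(q * lp for q, lp in zip(soft_target, log_probs))


# A confident, correct two-class prediction is penalized slightly more with smoothing
print(smoothed_cross_entropy([2.0, -1.0], target=0, smoothing=0.0))
print(smoothed_cross_entropy([2.0, -1.0], target=0, smoothing=0.1))
```

With smoothing=0.0 this reduces to the ordinary cross entropy, the negative log-probability of the target class.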

In this example, we also use fastxtend’s CompilerCallback via the Learner.compile convenience method to accelerate training throughput with torch.compile. After the model is compiled during the first epoch, training is faster and memory usage is reduced. In this small example the one-time compilation cost outweighs the speedup, but we’d want to compile DistilRoBERTa if training on the entirety of IMDb.

Compiling the model with compile(dynamic=True) requires a minimum of PyTorch 2.1. Dynamic shapes do not work in PyTorch 2.0.
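Dynamic shapes are relevant here because DataCollatorWithPadding, with its default settings, pads each batch only to the length of its longest sequence, so input shapes change from batch to batch. The hypothetical pad_to_longest helper below mimics that behavior in plain Python, assuming right padding; pad_id=0 and the token ids are placeholders, since the real values come from the tokenizer.

```python
from typing import Dict, List


def pad_to_longest(batch_ids: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad token ids to the longest sequence in the batch and build an attention mask."""
    longest = max(len(ids) for ids in batch_ids)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch_ids]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


first = pad_to_longest([[5, 7, 9], [5, 9]])
second = pad_to_longest([[5, 6, 7, 8, 9], [5, 8, 9]])
# Two batches from the same dataset end up with different sequence lengths
print(len(first["input_ids"][0]), len(second["input_ids"][0]))  # 3 5
```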

model = AutoModelForSequenceClassification.from_pretrained('distilroberta-base', num_labels=2)

with less_random():
    learn = Learner(dls, model, loss_func=nn.CrossEntropyLoss(label_smoothing=0.1),
                    opt_func=stableadam(foreach=True), metrics=Accuracy(),
                    cbs=HuggingFaceCallback).to_bf16().compile(dynamic=True)

    learn.fit_flat_warmup(3, lr=8e-4, wd=1e-2)
epoch  train_loss  valid_loss  accuracy  time
0      0.686346    0.677865    0.658000  01:25
1      0.423131    0.383354    0.886000  00:27
2      0.355547    0.374400    0.887000  00:27
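The epoch times reported in the two training runs above show the compile trade-off. The arithmetic below is a rough extrapolation under stated assumptions: compilation is a one-time overhead in the first epoch, later compiled epochs keep taking 27 seconds, and uncompiled epochs keep the average of the three reported times. Real runs will vary; break_even_epoch is a hypothetical helper.

```python
uncompiled = [38, 37, 38]  # seconds per epoch, run without compilation
compiled = [85, 27, 27]    # seconds per epoch, run with Learner.compile

print(sum(uncompiled), sum(compiled))  # 113 139


def break_even_epoch(max_epochs: int = 100) -> int:
    """Smallest epoch count where the compiled run is no slower overall (assumed timings)."""
    first, later = compiled[0], compiled[-1]
    for n in range(1, max_epochs + 1):
        compiled_total = first + later * (n - 1)
        # Compare against n * mean(uncompiled) using integers: scale both sides by len(uncompiled)
        if compiled_total * len(uncompiled) <= n * sum(uncompiled):
            return n
    raise ValueError("no break-even point within max_epochs")


print(break_even_epoch())  # 6
```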

Accessing and Saving the Model

The original Transformers model is accessible via Learner.model.hf_model or Learner.hf_model (both point to the same object).

We can use any Transformers method to save the model, such as save_pretrained.

learn.hf_model.save_pretrained(model_path)