Callbacks which add a simple profiler to fastai. Inspired by PyTorch Lightning's SimpleProfiler.

Since Simple Profiler changes the fastai data loading loop, it is not imported by any of the fastxtend `all` imports and needs to be imported separately:

from fastxtend.callback import simpleprofiler

Jump to the Examples section below for usage examples.

Events

Fastai callbacks do not have an event which is called directly before drawing a batch. Simple Profiler adds a new callback event called before_draw.

With Simple Profiler imported, a callback can implement actions on the following events:

  • after_create: called after the Learner is created
  • before_fit: called before starting training or inference, ideal for initial setup.
  • before_epoch: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
  • before_train: called at the beginning of the training part of an epoch.
  • before_draw: called at the beginning of each batch, just before drawing said batch (a minimal example follows this list).
  • before_batch: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).
  • after_pred: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.
  • after_loss: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
  • before_backward: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used).
  • before_step: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).
  • after_step: called after the step and before the gradients are zeroed.
  • after_batch: called at the end of a batch, for any clean-up before the next one.
  • after_train: called at the end of the training phase of an epoch.
  • before_validate: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
  • after_validate: called at the end of the validation part of an epoch.
  • after_epoch: called at the end of an epoch, for any clean-up before the next one.
  • after_fit: called at the end of training, for final clean-up.
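
For example, a minimal (hypothetical) callback could use the new event to time how long each batch takes to draw, similar to what SimpleProfilerCallback does internally:

from time import perf_counter
from fastai.callback.core import Callback

class DrawTimer(Callback):
    "Hypothetical sketch: time the gap between requesting and receiving each batch"
    def before_draw(self):  self.draw_start = perf_counter()
    def before_batch(self): print(f'batch drawn in {perf_counter()-self.draw_start:.4f}s')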

Implement before_draw

To add before_draw as a callable event, it first needs to be added to both of fastai's event lists: _inner_loop and _events.

import fastai
from packaging.version import parse
from fastcore.basics import mk_class, patch, noop
from fastcore.foundation import L
from fastai.callback.core import Callback
from fastai.learner import Learner

# fastai >=2.7.0 includes the after_cancel_backward & after_backward events,
# so both lists are built to match the installed version
if parse(fastai.__version__) >= parse('2.7.0'):
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch".split()
else:
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()

if parse(fastai.__version__) >= parse('2.7.0'):
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \
        after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \
        after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')
else:
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \
        after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
        after_epoch after_cancel_fit after_fit')

# recreate the `event` namespace class so `event.before_draw` exists
mk_class('event', **_events.map_dict(),
         doc="All possible events as attributes to get tab-completion and typo-proofing")

Next, Callback needs to be modified to be aware of the new event.

Callback.__call__[source]

Callback.__call__(event_name)

Call self.{event_name} if it's defined

@patch
def __call__(self:Callback, event_name):
    "Call `self.{event_name}` if it's defined"
    # inner-loop events (including the new before_draw) respect run_train/run_valid
    _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
            (self.run_valid and not getattr(self, 'training', False)))
    res = None
    if self.run and _run: res = getattr(self, event_name, noop)()
    if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
    return res
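
Because before_draw is part of _inner_loop, the patched __call__ applies the same run_train/run_valid filtering to it as to the other batch events. For example, this hypothetical callback would skip before_draw during validation:

class TrainOnlyDrawLogger(Callback):
    run_valid = False  # inner-loop events, including before_draw, won't fire during validation
    def before_draw(self): print('drawing a training batch')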

Then Learner._call_one needs to be patched to be aware of before_draw.

@patch
def _call_one(self:Learner, event_name):
    if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
    for cb in self.cbs.sorted('order'): cb(event_name)

Finally, Learner.all_batches can be modified to call before_draw when iterating through a dataloader.

Learner.all_batches[source]

Learner.all_batches()

@patch
def all_batches(self:Learner):
    self.n_iter = len(self.dl)
    if hasattr(self, 'simple_profiler'):
        # iterate manually so before_draw fires just before each batch is drawn
        self.it = iter(self.dl)
        for i in range(self.n_iter):
            self("before_draw")
            self.one_batch(i, next(self.it))
        del self.it
    else:
        # fall back to the original fastai implementation when not profiling
        for o in enumerate(self.dl): self.one_batch(*o)

While testing hasn't shown any negative side effects from this approach, all_batches only uses the new batch-drawing implementation if SimpleProfilerCallback is in the list of callbacks, and falls back to the original method if not.

class SimpleProfilerPostCallback[source]

SimpleProfilerPostCallback(after_create=None, before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, before_step=None, after_cancel_step=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

Pair with SimpleProfilerCallback to profile training performance. Removes itself after training is over.

class SimpleProfilerCallback[source]

SimpleProfilerCallback(show_report=True, plain=False, markdown=False, save_csv=False, csv_name='simple_profile.csv') :: Callback

Adds a simple profiler to the fastai Learner, optionally showing a formatted report or saving unformatted results as a csv.

Pair with SimpleProfilerPostCallback to profile training performance.

After fitting, access the report & results via Learner.simple_profile_report & Learner.simple_profile_results.
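
For example, after a profiled fit (a minimal sketch):

report  = learn.simple_profile_report   # formatted report, as displayed after training
results = learn.simple_profile_results  # raw timings, useful for further analysis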

Convenience Method

Learner.profile[source]

Learner.profile(show_report=True, plain=False, markdown=False, save_csv=False, csv_name='simple_profile.csv')

Run Simple Profiler when training. Simple Profiler removes itself when finished.
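
A minimal sketch, assuming dls and a model are already defined, which also saves the raw results to csv:

learn = Learner(dls, xresnet18(n_out=dls.c), metrics=accuracy)
learn = learn.profile(save_csv=True, csv_name='simple_profile.csv')
learn.fit_one_cycle(1, 3e-3)  # report prints & csv saves once training finishes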

Output

The Simple Profiler report contains the following items, divided into three phases (fit, train, & valid):

Fit:

  • fit: total time taken to fit the model.
  • epoch: duration of both training and validation epochs. Total epoch time is usually the same as the elapsed fit time.
  • train: duration of each training epoch.
  • validate: duration of each validation epoch.

Train:

  • draw: time spent waiting for a batch to be drawn. Measured from before_draw to before_batch. With default prefetching settings, this should ideally be near zero.
  • batch: total duration of all batch steps sans drawing the batch. Measured from before_batch to after_batch.
  • pred: duration of the forward pass and any additional batch modifications. Measured from before_batch to after_pred.
  • loss: duration of calculating the loss. Measured from after_pred to after_loss.
  • backward: duration of the backward pass. Measured from before_backward to before_step.
  • step: duration of the optimizer step. Measured from before_step to after_step.
  • zero_grad: duration of the zero_grad step. Measured from after_step to after_batch.

Valid:

  • draw: time spent waiting for a batch to be drawn. Measured from before_draw to before_batch. With default prefetching settings, this should ideally be near zero.
  • batch: total duration of all batch steps sans drawing the batch. Measured from before_batch to after_batch.
  • pred: duration of the forward pass and any additional batch modifications. Measured from before_batch to after_pred.
  • loss: duration of calculating the loss. Measured from after_pred to after_loss.

Examples

Both examples are trained on Imagenette with an image size of 256 and batch size of 64 on a Colab P100 4CPU instance.
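
The dataloaders for both runs were built roughly like the following sketch (exact paths and transforms are assumptions, not shown in the original runs):

from fastai.vision.all import *
from fastxtend.callback import simpleprofiler

path = untar_data(URLs.IMAGENETTE_320)
dls = ImageDataLoaders.from_folder(path, valid='val', item_tfms=Resize(256), bs=64)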

learn = Learner(dls, xse_resnet50(n_out=dls.c), metrics=accuracy).to_fp16().profile()
learn.fit_one_cycle(2, 3e-3)
epoch   train_loss   valid_loss   accuracy   time
0       1.624050     2.352882     0.376051   03:25
1       1.184010     1.147916     0.643567   03:19
Simple Profiler Results
Phase   Action      Mean Duration   Duration Std Dev   Number of Calls   Total Time   Percent of Total
fit     fit         -               -                    1               404.7 s      100%
        epoch       202.4 s         2.721 s              2               404.7 s      100%
        train       178.4 s         2.020 s              2               356.7 s       88%
        validate    23.99 s         699.9 ms             2               47.98 s       12%
train   batch       1.203 s         293.3 ms           294               353.7 s       87%
        step        726.8 ms        35.05 ms           294               213.7 s       53%
        backward    411.3 ms        159.6 ms           294               120.9 s       30%
        pred        32.90 ms        107.5 ms           294               9.673 s        2%
        draw        28.49 ms        78.12 ms           294               8.375 s        2%
        zero_grad   2.437 ms        324.4 µs           294               716.4 ms       0%
        loss        958.6 µs        107.4 µs           294               281.8 ms       0%
valid   batch       72.83 ms        176.0 ms           124               9.031 s        2%
        pred        40.13 ms        126.7 ms           124               4.976 s        1%
        draw        31.58 ms        121.3 ms           124               3.916 s        1%
        loss        967.8 µs        1.034 ms           124               120.0 ms       0%

When training an XResNet18, the total time spent drawing batches during training increases to 20 seconds, twenty percent of the total fit time. This is due to Colab's four-core CPU and slow disk not being able to prefetch quickly enough, causing the training process to wait on drawing each batch. In contrast, in the XSEResNet50 training above, total training draw time was only 8 seconds, two percent of the total fit time.

learn = Learner(dls, xresnet18(n_out=dls.c), metrics=accuracy).to_fp16().profile()
learn.fit_one_cycle(2, 3e-3)
epoch   train_loss   valid_loss   accuracy   time
0       1.552181     1.569250     0.480764   00:49
1       1.148098     1.137456     0.638217   00:49
Simple Profiler Results
Phase   Action      Mean Duration   Duration Std Dev   Number of Calls   Total Time   Percent of Total
fit     fit         -               -                    1               99.46 s      100%
        epoch       49.73 s         91.12 ms             2               99.46 s      100%
        train       34.84 s         7.996 ms             2               69.68 s       70%
        validate    14.89 s         82.42 ms             2               29.78 s       30%
train   batch       228.9 ms        102.6 ms           294               67.30 s       68%
        step        130.7 ms        9.819 ms           294               38.43 s       39%
        draw        68.27 ms        102.4 ms           294               20.07 s       20%
        pred        16.81 ms        6.889 ms           294               4.942 s        5%
        backward    9.573 ms        2.601 ms           294               2.815 s        3%
        zero_grad   1.990 ms        2.550 ms           294               585.0 ms       1%
        loss        1.378 ms        2.065 ms           294               405.0 ms       0%
valid   batch       209.4 ms        282.0 ms           124               25.96 s       26%
        draw        192.7 ms        282.0 ms           124               23.89 s       24%
        pred        15.27 ms        12.77 ms           124               1.893 s        2%
        loss        1.269 ms        2.028 ms           124               157.4 ms       0%
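
When draw dominates a profile like this, the usual mitigations are faster storage or, on machines with more CPU cores than this four-core instance, more DataLoader workers. A hypothetical tweak to the setup sketch above (the num_workers value is illustrative):

dls = ImageDataLoaders.from_folder(path, valid='val', item_tfms=Resize(256),
                                   bs=64, num_workers=8)  # more prefetch workers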

New Training Loop

The show_training_loop output below shows where the new before_draw event fits into the training loop.

from fastai.test_utils import synth_learner
learn = synth_learner()
learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_draw    : []
         - before_batch   : []
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : [Recorder]
End Fit
 - after_cancel_fit: []
 - after_fit      : [ProgressCallback]

Simple Profiler Wandb Logging

Automatically logs Simple Profiler output to wandb.

Logs two tables to the active wandb run:

  • simple_profile_report: formatted report from Simple Profiler Callback
  • simple_profile_results: raw results from Simple Profiler Callback
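
A hypothetical end-to-end sketch (this doc doesn't show the exact import that enables the wandb logging; it is assumed to load with the simpleprofiler import):

import wandb
from fastai.callback.wandb import WandbCallback
from fastxtend.callback import simpleprofiler

wandb.init(project='profiler-demo')  # hypothetical project name
learn = Learner(dls, xresnet18(n_out=dls.c), cbs=WandbCallback()).profile()
learn.fit_one_cycle(2, 3e-3)
# after fit: simple_profile_report & simple_profile_results tables appear in the run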