Simple Profiler

Callbacks which add a simple profiler to fastai. Inspired by PyTorch Lightning’s SimpleProfiler.

Since Simple Profiler changes the fastai data loading loop, it is not imported by any of the fastxtend all imports. It needs to be imported separately:

from fastxtend.callback import simpleprofiler
Note

Simple Profiler is currently untested on distributed training.

Jump to the Examples section below for usage examples.

Events

Fastai callbacks do not have an event which is called directly before drawing a batch. Simple Profiler adds a new callback event called before_draw.

With Simple Profiler imported, a callback can implement actions on the following events:

  • after_create: called after the Learner is created
  • before_fit: called before starting training or inference, ideal for initial setup.
  • before_epoch: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
  • before_train: called at the beginning of the training part of an epoch.
  • before_draw: called at the beginning of each batch, just before drawing said batch.
  • before_batch: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).
  • after_pred: called after computing the output of the model on the batch. It can be used to change that output before it’s fed to the loss.
  • after_loss: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).
  • before_backward: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used)
  • before_step: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).
  • after_step: called after the step and before the gradients are zeroed.
  • after_batch: called at the end of a batch, for any clean-up before the next one.
  • after_train: called at the end of the training phase of an epoch.
  • before_validate: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
  • after_validate: called at the end of the validation part of an epoch.
  • after_epoch: called at the end of an epoch, for any clean-up before the next one.
  • after_fit: called at the end of training, for final clean-up.
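As a quick illustration, the sketch below is a hypothetical callback (not part of fastxtend) that uses the new before_draw event to record how long drawing each batch takes. Note that, as described under Learner.all_batches below, before_draw only fires while SimpleProfilerCallback is attached to the Learner.

from time import perf_counter
from fastai.callback.core import Callback

class DrawTimer(Callback):
    "Hypothetical example: record how long each batch takes to draw using `before_draw`"
    def before_fit(self):   self.draw_times = []
    def before_draw(self):  self._draw_start = perf_counter()
    def before_batch(self): self.draw_times.append(perf_counter() - self._draw_start)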

Implement before_draw

To add before_draw as a callable event, it first needs to be added to both the _inner_loop and _events lists of fastai events (fastai 2.7.0 added new backward events, hence the version checks).

# Imports the patches below rely on (in the fastxtend source these come from the module's own imports)
import fastai
from packaging.version import parse
from fastcore.basics import mk_class, patch, noop
from fastcore.foundation import L
from fastai.callback.core import *   # Callback, event, and the Cancel*Exception classes
from fastai.learner import Learner

if parse(fastai.__version__) >= parse('2.7.0'):
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch".split()
else:
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()
if parse(fastai.__version__) >= parse('2.7.0'):
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \
        after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \
        after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')
else:
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \
        after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
        after_epoch after_cancel_fit after_fit')

mk_class('event', **_events.map_dict(),
         doc="All possible events as attributes to get tab-completion and typo-proofing")

Next, Callback needs to be modified to be aware of the new event.

@patch
def __call__(self:Callback, event_name):
    "Call `self.{event_name}` if it's defined"
    _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
            (self.run_valid and not getattr(self, 'training', False)))
    res = None
    if self.run and _run:
        try: res = getattr(self, event_name, noop)()
        except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
        except Exception as e:
            e.args = [f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}']
            raise
    if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
    return res

Then Learner._call_one needs to be patched so it is aware of the before_draw event.

@patch
def _call_one(self:Learner, event_name):
    if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
    for cb in self.cbs.sorted('order'): cb(event_name)

Finally, Learner.all_batches can be modified to call before_draw when iterating through a dataloader.

@patch
def all_batches(self:Learner):
    self.n_iter = len(self.dl)
    if hasattr(self, 'simple_profiler'):
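        # Simple Profiler attached: draw batches manually so `before_draw`
        # fires immediately before each batch is requested from the dataloader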
        self.it = iter(self.dl)
        for i in range(self.n_iter):
            self("before_draw")
            self.one_batch(i, next(self.it))
        del(self.it)
    else:
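        # no profiler attached: fall back to the original fastai batch loop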
        for o in enumerate(self.dl): self.one_batch(*o)

source

Learner.all_batches

 Learner.all_batches ()

While testing hasn’t shown any negative side effects from this approach, all_batches only uses the new batch-drawing implementation if SimpleProfilerCallback is in the list of callbacks, and reverts to the original method if not.


source

SimpleProfilerPostCallback

 SimpleProfilerPostCallback (samples_per_second=True)

Pair with SimpleProfilerCallback to profile training performance. Removes itself after training is over.


source

SimpleProfilerCallback

 SimpleProfilerCallback (show_report=True, plain=False, markdown=False,
                         save_csv=False, csv_name='simple_profile.csv',
                         logger_callback='wandb')

Adds a simple profiler to the fastai Learner, optionally showing a formatted report or saving the unformatted results as a csv.

Pair with SimpleProfilerPostCallback to profile training performance.

Post fit, access report & results via Learner.simple_profile_report & Learner.simple_profile_results.

                 Type  Default             Details
show_report      bool  True                Display formatted report post profile
plain            bool  False               For Jupyter Notebooks, display plain report
markdown         bool  False               Display markdown formatted report
save_csv         bool  False               Save raw results to csv
csv_name         str   simple_profile.csv  CSV save location
logger_callback  str   wandb               Log report and samples/second to logger_callback using Callback.name
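For example, the callback pair can be added directly when creating a Learner. This is a usage sketch assuming dls and model are already defined; in practice the Learner.profile convenience method documented below is simpler.

learn = Learner(dls, model,
                cbs=[SimpleProfilerCallback(show_report=False, save_csv=True),
                     SimpleProfilerPostCallback()])
learn.fit_one_cycle(2, 3e-3)

# after fit, the formatted report and raw results are DataFrames on the Learner
report  = learn.simple_profile_report
results = learn.simple_profile_results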

source

Learner.profile

 Learner.profile (show_report=True, plain=False, markdown=False,
                  save_csv=False, csv_name='simple_profile.csv',
                  samples_per_second=True, logger_callback='wandb')

Run Simple Profiler when training. Simple Profiler removes itself when finished.

                    Type  Default             Details
show_report         bool  True                Display formatted report post profile
plain               bool  False               For Jupyter Notebooks, display plain report
markdown            bool  False               Display markdown formatted report
save_csv            bool  False               Save raw results to csv
csv_name            str   simple_profile.csv  CSV save location
samples_per_second  bool  True                Log samples/second for all actions & steps
logger_callback     str   wandb               Log report and samples/second to logger_callback using Callback.name

Output

The Simple Profiler report contains the following items, divided into three phases (Fit, Train, & Valid).

Fit:

  • fit: total time fitting the model takes.
  • epoch: duration of both the training and validation epochs. The total epoch time is usually nearly the same as the elapsed fit time.
  • train: duration of each training epoch.
  • valid: duration of each validation epoch.

Train:

  • draw: time spent waiting for a batch to be drawn. Measured from before_draw to before_batch. Ideally this value should be as close to zero as possible.
  • batch: total duration of all batch steps except drawing the batch. Measured from before_batch to after_batch.
  • forward: duration of the forward pass and any additional batch modifications. Measured from before_batch to after_pred.
  • loss: duration of calculating loss. Measured from after_pred to after_loss.
  • backward: duration of the backward pass. Measured from before_backward to before_step.
  • opt_step: duration of the optimizer step. Measured from before_step to after_step.
  • zero_grad: duration of the zero_grad step. Measured from after_step to after_batch.

Valid:

  • draw: time spent waiting for a batch to be drawn. Measured from before_draw to before_batch. Ideally this value should be as close to zero as possible.
  • batch: total duration of all batch steps except drawing the batch. Measured from before_batch to after_batch.
  • predict: duration of the prediction pass and any additional batch modifications. Measured from before_batch to after_pred.
  • loss: duration of calculating loss. Measured from after_pred to after_loss.

Examples

The example below trains on Imagenette at an image size of 256 with a batch size of 64, on a SageMaker Studio Lab instance with a T4 GPU and four CPUs.

learn = Learner(dls, xse_resnet50(n_out=dls.c), metrics=Accuracy()).to_fp16().profile()
learn.fit_one_cycle(2, 3e-3)
epoch  train_loss  valid_loss  accuracy  time
0      1.526267    1.588631    0.509554  02:43
1      1.044853    0.949273    0.705732  02:46
Simple Profiler Results
Phase  Action     Mean Duration  Duration Std Dev  Number of Calls  Samples/Second  Total Time  Percent of Total
fit    fit        -              -                 1                -               330.2 s     100%
       epoch      165.1 s        1.160 s           2                -               330.2 s     100%
       train      148.6 s        1.113 s           2                66              297.1 s     90%
       valid      16.56 s        47.87ms           2                3,019           33.11 s     10%
train  draw       7.803ms        54.23ms           294              0               2.294 s     1%
       batch      967.9ms        10.47ms           294              66              284.6 s     86%
       forward    21.23ms        19.30ms           294              3,290           6.242 s     2%
       loss       974.6µs        204.6µs           294              68,140          286.5ms     0%
       backward   387.2ms        18.61ms           294              167             113.8 s     34%
       opt_step   556.4ms        6.388ms           294              115             163.6 s     50%
       zero_grad  1.968ms        122.2µs           294              -               578.7ms     0%
valid  draw       13.44ms        75.85ms           124              -680            1.667 s     1%
       batch      18.33ms        5.751ms           124              3,699           2.273 s     1%
       predict    17.37ms        5.533ms           124              -               2.154 s     1%
       loss       836.1µs        228.3µs           124              80,358          103.7ms     0%

New Training Loop

The show_training_loop output below shows where the new before_draw event fits into the training loop.

from fastai.test_utils import synth_learner  # needed if running this example standalone

learn = synth_learner()
learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_draw    : []
         - before_batch   : [CastToTensor]
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : [Recorder]
End Fit
 - after_cancel_fit: []
 - after_fit      : [ProgressCallback]

Simple Profiler Wandb Logging

Logs samples/second for draw, batch, forward, loss, backward, and opt_step steps as wandb charts.

Also logs two tables to the active wandb run:

  • simple_profile_report: formatted report from Simple Profiler Callback
  • simple_profile_results: raw results from Simple Profiler Callback

try:
    import wandb

    @patch
    def _wandb_log_after_batch(self:SimpleProfilerCallback):
        train_vals = {f'samples_per_second/train_{action}': self._train_samples_per_second(action) for action in _train[:-1]}
        wandb.log(train_vals, self.learn.wandb._wandb_step+1)

    @patch
    def _wandb_log_after_fit(self:SimpleProfilerCallback):
        report = wandb.Table(dataframe=self.learn.simple_profile_report)
        results = wandb.Table(dataframe=self.learn.simple_profile_results)

        wandb.log({"simple_profile_report": report})
        wandb.log({"simple_profile_results": results})
        wandb.log({}) # ensure sync
except:
    pass

Extend to other Loggers

To extend to a new logger, follow the Weights & Biases code above and create patches for SimpleProfilerCallback that add _{Callback.name}_log_after_batch and _{Callback.name}_log_after_fit methods, where Callback.name is the name of the logger callback.

Then to use, pass logger_callback='{Callback.name}' to Learner.profile().

SimpleProfilerCallback sets its _log_after_batch method to f'_{self.logger_callback}_log_after_batch', which should match the patched method.

self._log_after_batch = getattr(self, f'_{self.logger_callback}_log_after_batch', noop)

SimpleProfilerCallback sets _log_after_fit the same way.
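For instance, here is a hypothetical sketch for a logger callback whose Callback.name is 'mylogger'. The mylogger.log_metrics and mylogger.log_table calls stand in for whatever API the logger actually provides, while _train and _train_samples_per_second are the same internals used by the wandb patches above.

try:
    import mylogger  # hypothetical logging library

    @patch
    def _mylogger_log_after_batch(self:SimpleProfilerCallback):
        # samples/second for every training step except zero_grad, mirroring the wandb patch
        train_vals = {f'samples_per_second/train_{action}': self._train_samples_per_second(action) for action in _train[:-1]}
        mylogger.log_metrics(train_vals)  # hypothetical call

    @patch
    def _mylogger_log_after_fit(self:SimpleProfilerCallback):
        # log the formatted report and raw results once training finishes
        mylogger.log_table('simple_profile_report',  self.learn.simple_profile_report)   # hypothetical call
        mylogger.log_table('simple_profile_results', self.learn.simple_profile_results)  # hypothetical call
except ImportError:
    pass

With the patches defined, attach the logger callback to the Learner as usual and call Learner.profile(logger_callback='mylogger').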