```python
if parse(fastai.__version__) >= parse('2.7.0'):
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward after_cancel_backward after_backward before_step after_step after_cancel_batch after_batch".split()
else:
    _inner_loop = "before_draw before_batch after_pred after_loss before_backward before_step after_step after_cancel_batch after_batch".split()
```
Simple Profiler
Since Simple Profiler changes the fastai data loading loop, it is not imported by any of the fastxtend all imports. It needs to be imported separately:
```python
from fastxtend.callback import simpleprofiler
```
Simple Profiler is currently untested on distributed training.
Jump to usage examples.
Events
fastai callbacks do not have an event which is called directly before drawing a batch. Simple Profiler adds a new callback event called `before_draw`.
With Simple Profiler imported, a callback can implement actions on the following events (a minimal example callback follows the list):
- `after_create`: called after the `Learner` is created.
- `before_fit`: called before starting training or inference, ideal for initial setup.
- `before_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.
- `before_train`: called at the beginning of the training part of an epoch.
- `before_draw`: called at the beginning of each batch, just before drawing said batch.
- `before_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes into the model (changing the input with techniques like mixup, for instance).
- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.
- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training, for instance).
- `before_backward`: called after the loss has been computed, but only in training mode (i.e. when the backward pass will be used).
- `before_step`: called after the backward pass, but before the update of the parameters. It can be used to make any change to the gradients before said update (gradient clipping, for instance).
- `after_step`: called after the step and before the gradients are zeroed.
- `after_batch`: called at the end of a batch, for any clean-up before the next one.
- `after_train`: called at the end of the training phase of an epoch.
- `before_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.
- `after_validate`: called at the end of the validation part of an epoch.
- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.
- `after_fit`: called at the end of training, for final clean-up.
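For example, a small callback could use `before_draw` together with `before_batch` to measure how long the dataloader takes to yield each batch. This is only an illustrative sketch (the class name and `order` value are arbitrary, not fastxtend conventions), and it assumes `fastxtend.callback.simpleprofiler` has been imported. Note that, per the `Learner.all_batches` patch shown later, `before_draw` is only emitted while `SimpleProfilerCallback` is attached to the `Learner`.

```python
import time
from fastai.callback.core import Callback

class DrawTimerCallback(Callback):
    "Illustrative only: record how long each batch takes to draw."
    order = -10  # arbitrary value; run early on each event

    def before_fit(self):
        self.draw_times = []

    def before_draw(self):
        # timestamp taken just before the dataloader yields the next batch
        self._draw_start = time.perf_counter()

    def before_batch(self):
        # the batch has now been drawn; store the elapsed wall-clock time
        self.draw_times.append(time.perf_counter() - self._draw_start)
```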
Implement before_draw
To add `before_draw` as a callable event, it first needs to be added to both the `_inner_loop` and `_events` lists of fastai events (fastai 2.7.0 adds new backward events).
```python
if parse(fastai.__version__) >= parse('2.7.0'):
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward after_cancel_backward after_backward before_step after_cancel_step after_step \
        after_cancel_batch after_batch after_cancel_train after_train before_validate after_cancel_validate \
        after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit')
else:
    _events = L.split('after_create before_fit before_epoch before_train before_draw before_batch after_pred after_loss \
        before_backward before_step after_cancel_step after_step after_cancel_batch after_batch after_cancel_train \
        after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
        after_epoch after_cancel_fit after_fit')

mk_class('event', **_events.map_dict(),
         doc="All possible events as attributes to get tab-completion and typo-proofing")
```
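With `before_draw` included in `_events`, the new event should then be reachable through fastai's `event` class like any built-in event; a quick sanity check, assuming the patched lists above are in place:

```python
# mk_class maps each event name to itself, so the attribute simply returns the string
assert event.before_draw == 'before_draw'
```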
Next, `Callback` needs to be modified to be aware of the new event.
```python
@patch
def __call__(self:Callback, event_name):
    "Call `self.{event_name}` if it's defined"
    _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
           (self.run_valid and not getattr(self, 'training', False)))
    res = None
    if self.run and _run:
        try: res = getattr(self, event_name, noop)()
        except (CancelBatchException, CancelEpochException, CancelFitException, CancelStepException, CancelTrainException, CancelValidException): raise
        except Exception as e:
            e.args = [f'Exception occured in `{self.__class__.__name__}` when calling event `{event_name}`:\n\t{e.args[0]}']
            raise
    if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
    return res
```
Then `Learner._call_one` needs to be patched to be aware of `before_draw`.
```python
@patch
def _call_one(self:Learner, event_name):
    if not hasattr(event, event_name): raise Exception(f'missing {event_name}')
    for cb in self.cbs.sorted('order'): cb(event_name)
```
Finally, `Learner.all_batches` can be modified to call `before_draw` when iterating through a dataloader.
```python
@patch
def all_batches(self:Learner):
    self.n_iter = len(self.dl)
    if hasattr(self, 'simple_profiler'):
        self.it = iter(self.dl)
        for i in range(self.n_iter):
            self("before_draw")
            self.one_batch(i, next(self.it))
        del(self.it)
    else:
        for o in enumerate(self.dl): self.one_batch(*o)
```
Learner.all_batches
Learner.all_batches ()
While testing hasn't shown any negative side effects of this approach, `all_batches` only uses the new batch drawing implementation if `SimpleProfilerCallback` is in the list of callbacks, and reverts back to the original method if not.
SimpleProfilerPostCallback
SimpleProfilerPostCallback (samples_per_second=True)
Pair with `SimpleProfilerCallback` to profile training performance. Removes itself after training is over.
SimpleProfilerCallback
SimpleProfilerCallback (show_report=True, plain=False, markdown=False, save_csv=False, csv_name='simple_profile.csv', logger_callback='wandb')
Adds a simple profiler to the fastai `Learner`. Optionally shows a formatted report or saves unformatted results as a csv.
Pair with `SimpleProfilerPostCallback` to profile training performance.
Post fit, access the report & results via `Learner.simple_profile_report` & `Learner.simple_profile_results`.
| | Type | Default | Details |
|---|---|---|---|
| show_report | bool | True | Display formatted report post profile |
| plain | bool | False | For Jupyter Notebooks, display plain report |
| markdown | bool | False | Display markdown formatted report |
| save_csv | bool | False | Save raw results to csv |
| csv_name | str | simple_profile.csv | CSV save location |
| logger_callback | str | wandb | Log report and samples/second to logger_callback using Callback.name |
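For reference, the two callbacks can also be attached by hand rather than through the `Learner.profile` convenience method documented next. A minimal sketch, assuming `dls` and `model` are already defined:

```python
from fastxtend.callback.simpleprofiler import SimpleProfilerCallback, SimpleProfilerPostCallback

# manually pair the profiler callbacks; Learner.profile does this pairing for you
learn = Learner(dls, model,
                cbs=[SimpleProfilerCallback(show_report=True),
                     SimpleProfilerPostCallback()])
learn.fit_one_cycle(1, 3e-3)
```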
Learner.profile
Learner.profile (show_report=True, plain=False, markdown=False, save_csv=False, csv_name='simple_profile.csv', samples_per_second=True, logger_callback='wandb')
Run Simple Profiler when training. Simple Profiler removes itself when finished.
| | Type | Default | Details |
|---|---|---|---|
| show_report | bool | True | Display formatted report post profile |
| plain | bool | False | For Jupyter Notebooks, display plain report |
| markdown | bool | False | Display markdown formatted report |
| save_csv | bool | False | Save raw results to csv |
| csv_name | str | simple_profile.csv | CSV save location |
| samples_per_second | bool | True | Log samples/second for all actions & steps |
| logger_callback | str | wandb | Log report and samples/second to logger_callback using Callback.name |
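A minimal usage sketch, assuming `dls` and `model` are already defined. After fitting, the report and raw results mentioned above remain available on the `Learner` (both are DataFrames, which the wandb logging code further below relies on):

```python
learn = Learner(dls, model).profile(save_csv=True, csv_name='simple_profile.csv')
learn.fit_one_cycle(1, 3e-3)

report  = learn.simple_profile_report   # formatted report
results = learn.simple_profile_results  # raw results, also written to simple_profile.csv
```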
Output
The Simple Profiler report contains the following items, divided into three phases (Fit, Train, & Valid):
Fit:
- `fit`: total time fitting the model takes.
- `epoch`: duration of both training and validation epochs. Often epoch total time is the same amount of elapsed time as fit.
- `train`: duration of each training epoch.
- `valid`: duration of each validation epoch.

Train:
- `draw`: time spent waiting for a batch to be drawn. Measured from `before_draw` to `before_batch`. Ideally this value should be as close to zero as possible.
- `batch`: total duration of all batch steps except drawing the batch. Measured from `before_batch` to `after_batch`.
- `forward`: duration of the forward pass and any additional batch modifications. Measured from `before_batch` to `after_pred`.
- `loss`: duration of calculating loss. Measured from `after_pred` to `after_loss`.
- `backward`: duration of the backward pass. Measured from `before_backward` to `before_step`.
- `opt_step`: duration of the optimizer step. Measured from `before_step` to `after_step`.
- `zero_grad`: duration of the zero_grad step. Measured from `after_step` to `after_batch`.

Valid:
- `draw`: time spent waiting for a batch to be drawn. Measured from `before_draw` to `before_batch`. Ideally this value should be as close to zero as possible.
- `batch`: total duration of all batch steps except drawing the batch. Measured from `before_batch` to `after_batch`.
- `predict`: duration of the prediction pass and any additional batch modifications. Measured from `before_batch` to `after_pred`.
- `loss`: duration of calculating loss. Measured from `after_pred` to `after_loss`.
Examples
The example is trained on Imagenette with an image size of 256 and batch size of 64, on a SageMaker Studio Lab T4 instance with four CPUs.
```python
learn = Learner(dls, xse_resnet50(n_out=dls.c), metrics=Accuracy()).to_fp16().profile()
learn.fit_one_cycle(2, 3e-3)
```
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 1.526267 | 1.588631 | 0.509554 | 02:43 |
| 1 | 1.044853 | 0.949273 | 0.705732 | 02:46 |
| Phase | Action | Step | Mean Duration | Duration Std Dev | Number of Calls | Samples/Second | Total Time | Percent of Total |
|---|---|---|---|---|---|---|---|---|
| fit | fit | | - | - | 1 | - | 330.2 s | 100% |
| | epoch | | 165.1 s | 1.160 s | 2 | - | 330.2 s | 100% |
| | train | | 148.6 s | 1.113 s | 2 | 66 | 297.1 s | 90% |
| | valid | | 16.56 s | 47.87ms | 2 | 3,019 | 33.11 s | 10% |
| train | draw | | 7.803ms | 54.23ms | 294 | 0 | 2.294 s | 1% |
| | batch | | 967.9ms | 10.47ms | 294 | 66 | 284.6 s | 86% |
| | | forward | 21.23ms | 19.30ms | 294 | 3,290 | 6.242 s | 2% |
| | | loss | 974.6µs | 204.6µs | 294 | 68,140 | 286.5ms | 0% |
| | | backward | 387.2ms | 18.61ms | 294 | 167 | 113.8 s | 34% |
| | | opt_step | 556.4ms | 6.388ms | 294 | 115 | 163.6 s | 50% |
| | | zero_grad | 1.968ms | 122.2µs | 294 | - | 578.7ms | 0% |
| valid | draw | | 13.44ms | 75.85ms | 124 | -680 | 1.667 s | 1% |
| | batch | | 18.33ms | 5.751ms | 124 | 3,699 | 2.273 s | 1% |
| | | predict | 17.37ms | 5.533ms | 124 | - | 2.154 s | 1% |
| | | loss | 836.1µs | 228.3µs | 124 | 80,358 | 103.7ms | 0% |
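To dig into the numbers programmatically, the stored report can be filtered like any DataFrame. The column names below (`'Phase'`) are assumed from the rendered table above, not a documented API; check `learn.simple_profile_report.columns` for the actual names before relying on them:

```python
# hypothetical post-fit inspection -- column names are assumed from the displayed report
report = learn.simple_profile_report
print(report.columns.tolist())              # confirm the real column names first
print(report[report['Phase'] == 'train'])   # e.g. look at just the training phase
```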
New Training Loop
The `show_training_loop` output below shows where the new `before_draw` event fits into the training loop.
```python
learn = synth_learner()
learn.show_training_loop()
```
```
Start Fit
- before_fit : [TrainEvalCallback, Recorder, ProgressCallback]
Start Epoch Loop
- before_epoch : [Recorder, ProgressCallback]
Start Train
- before_train : [TrainEvalCallback, Recorder, ProgressCallback]
Start Batch Loop
- before_draw : []
- before_batch : [CastToTensor]
- after_pred : []
- after_loss : []
- before_backward: []
- before_step : []
- after_step : []
- after_cancel_batch: []
- after_batch : [TrainEvalCallback, Recorder, ProgressCallback]
End Batch Loop
End Train
- after_cancel_train: [Recorder]
- after_train : [Recorder, ProgressCallback]
Start Valid
- before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
Start Batch Loop
- **CBs same as train batch**: []
End Batch Loop
End Valid
- after_cancel_validate: [Recorder]
- after_validate : [Recorder, ProgressCallback]
End Epoch Loop
- after_cancel_epoch: []
- after_epoch : [Recorder]
End Fit
- after_cancel_fit: []
- after_fit : [ProgressCallback]
```
Simple Profiler Wandb Logging
Logs samples/second for draw, batch, forward, loss, backward, and opt_step steps as wandb charts.
Also logs two tables to the active wandb run:
- `simple_profile_report`: formatted report from the Simple Profiler callback
- `simple_profile_results`: raw results from the Simple Profiler callback
```python
try:
    import wandb

    @patch
    def _wandb_log_after_batch(self:SimpleProfilerCallback):
        train_vals = {f'samples_per_second/train_{action}': self._train_samples_per_second(action) for action in _train[:-1]}
        wandb.log(train_vals, self.learn.wandb._wandb_step+1)

    @patch
    def _wandb_log_after_fit(self:SimpleProfilerCallback):
        report = wandb.Table(dataframe=self.learn.simple_profile_report)
        results = wandb.Table(dataframe=self.learn.simple_profile_results)

        wandb.log({"simple_profile_report": report})
        wandb.log({"simple_profile_results": results})
        # ensure sync
        wandb.log({})
except:
    pass
```
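To actually send the profiler output to Weights & Biases, the profiler is paired with fastai's `WandbCallback` and an active run. A rough sketch, where the project name and `dls`/`model` are placeholders:

```python
import wandb
from fastai.callback.wandb import WandbCallback

wandb.init(project='fastxtend-profiling')  # placeholder project name

# logger_callback='wandb' is the default and shown only for clarity
learn = Learner(dls, model, cbs=WandbCallback()).profile(logger_callback='wandb')
learn.fit_one_cycle(1, 3e-3)
```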
Extend to other Loggers
To extend to new loggers, follow the Weights & Biases code above and create patches for `SimpleProfilerCallback` that add a `_{Callback.name}_log_after_batch` and a `_{Callback.name}_log_after_fit` method, where `Callback.name` is the name of the logger callback. Then, to use it, pass `logger_callback='{Callback.name}'` to `Learner.profile()`.
`SimpleProfilerCallback` sets its `_log_after_batch` method to `f'_{self.logger_callback}_log_after_batch'`, which should match the patched method.

```python
self._log_after_batch = getattr(self, f'_{self.logger_callback}_log_after_batch', noop)
```

`SimpleProfilerCallback._log_after_fit` behaves the same way.
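As a concrete, purely hypothetical sketch, a logger callback whose `Callback.name` resolves to `my_logger` could be supported as follows. The action names are hard-coded to mirror the `_train[:-1]` loop in the wandb patch above, `_train_samples_per_second` is the same helper that patch uses, and everything named `my_logger` is a placeholder rather than a real fastai or fastxtend API:

```python
import logging
from fastcore.all import patch
from fastxtend.callback.simpleprofiler import SimpleProfilerCallback

@patch
def _my_logger_log_after_batch(self:SimpleProfilerCallback):
    # mirror the wandb patch: log per-action samples/second after each training batch
    train_vals = {f'samples_per_second/train_{action}': self._train_samples_per_second(action)
                  for action in ('draw', 'batch', 'forward', 'loss', 'backward', 'opt_step')}
    logging.info("simple profiler samples/second: %s", train_vals)

@patch
def _my_logger_log_after_fit(self:SimpleProfilerCallback):
    # log the final report and raw results once training finishes
    logging.info("simple profiler report:\n%s", self.learn.simple_profile_report)
    logging.info("simple profiler results:\n%s", self.learn.simple_profile_results)
```

Profiling would then be started with `learn.profile(logger_callback='my_logger')`, assuming the corresponding logger callback is attached to the `Learner`.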