PyTorch Compile

Callbacks and patches to integrate torch.compile into fastai

The CompilerCallback and DynamoExplainCallback provide an easy-to-use torch.compile integration for fastai.

For more information on torch.compile, please read PyTorch’s getting started guide. For troubleshooting torch.compile, refer to the PyTorch Nightly guide.

This module is not imported by any fastxtend all import. You must import it separately, after importing fastai and fastxtend, as it modifies model saving, loading, and training:

from fastxtend.callback import compiler
# or
from fastxtend.callback.compiler import *

To use, create a fastai.learner.Learner with a torch.compile compatible model, then either call compile on the Learner or pass CompilerCallback to the cbs argument of Learner or a fit method.

Learner(...).compile()
# or
Learner(..., cbs=CompilerCallback())

source

CompileMode

 CompileMode (value, names=None, module=None, qualname=None, type=None,
              start=1)

All valid torch.compile modes for tab-completion and typo-proofing

Currently, the ‘reduce-overhead’ mode doesn’t appear to work with all models, and ‘max-autotune’ shouldn’t be used per Compile troubleshooting and gotchas.
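
For example, a minimal sketch of selecting a compile mode by string (a CompileMode member can be passed instead; Learner(...) stands in for your own Learner):

# compile with torch.compile's 'default' mode
Learner(...).compile(mode='default')
# or pass the mode to the callback directly
Learner(..., cbs=CompilerCallback(mode='default'))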


source

MatMulPrecision

 MatMulPrecision (value, names=None, module=None, qualname=None,
                  type=None, start=1)

All valid matmul_precision modes for tab-completion and typo-proofing


source

CompilerCallback

 CompilerCallback (fullgraph:bool=False, dynamic:bool=False,
                   backend:str|Callable='inductor',
                   mode:str|CompileMode|None=None,
                   options:Dict[str,Union[str,int,bool]]|None=None,
                   matmul_precision:str|MatMulPrecision='high',
                   recompile:bool=False, verbose:bool=False)

A callback for using torch.compile (beta) with fastai

Parameter Type Default Details
fullgraph bool False Prevent breaking model into subgraphs
dynamic bool False Use dynamic shape tracing
backend str | Callable inductor torch.compile backend to use
mode str | CompileMode | None None torch.compile mode to use
options Dict[str, Union[str, int, bool]] | None None Extra options to pass to compile backend
matmul_precision str | MatMulPrecision high Set Ampere and newer matmul precision
recompile bool False Force a compiled model to recompile. Use when freezing/unfreezing a compiled model.
verbose bool False Verbose output

Using torch.compile with mode='max-autotune' is under active development and might fail. See Compile troubleshooting and gotchas for more details.

By default, CompilerCallback will set matmul ops to use TensorFloat32 for supported GPUs, which is the recommended setting for torch.compile. Set matmul_precision='highest' to turn this off, or matmul_precision='medium' to enable bfloat16 mode.
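
For example, a sketch of keeping PyTorch’s default full float32 matmul precision while compiling (Learner(...) stands in for your own Learner):

# disable the TensorFloat32 matmul default set by CompilerCallback
Learner(...).compile(matmul_precision='highest')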

fastxtend provides the compile convenience method for easily enabling torch.compile. Alternatively, you can pass CompilerCallback to the cbs argument of fastai.learner.Learner or a fit method.

learn = Learner(..., cbs=CompilerCallback())
learn.fine_tune(1)

source

DynamoExplainCallback

 DynamoExplainCallback (print_results:bool=True, out_guards:bool=False,
                        ops_per_graph:bool=False,
                        break_reasons:bool=False)

A callback to automate finding graph breaks with PyTorch Compile’s Dynamo Explain

Parameter Type Default Details
print_results bool True Print enabled torch._dynamo.explain output(s)
out_guards bool False Print the out_guards output
ops_per_graph bool False Print the ops_per_graph output
break_reasons bool False Print the break_reasons output

DynamoExplainCallback automates finding graph breaks using torch._dynamo.explain, per the Identifying the cause of a graph break section of the PyTorch Compile FAQ. DynamoExplainCallback uses one batch from the validation dataloader¹ to generate the _dynamo.explain report(s) and then cancels training.

To use, pass DynamoExplainCallback to the cbs argument of the fastai.learner.Learner or fit method.

learn = Learner(..., cbs=DynamoExplainCallback())
learn.fit(1)

By default, DynamoExplainCallback prints the basic explanation output from _dynamo.explain, with arguments to enable printing the out_guards, ops_per_graph, and/or break_reasons outputs.
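
For example, a sketch enabling the additional outputs (Learner(...) stands in for your own setup):

learn = Learner(..., cbs=DynamoExplainCallback(out_guards=True, break_reasons=True))
learn.fit(1)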

All _dynamo.explain outputs are stored as attributes in the callback for later reference. For example, to view the out_guards after running Learner with DynamoExplainCallback:

# PyTorch 2.0
print(learn.dynamo_explain.out_guards)

# PyTorch 2.1
print(learn.dynamo_explain.explain_output.out_guards)

Convenience Method

fastxtend adds a convenience method to fastai.learner.Learner to easily enable torch.compile.


source

Learner.compile

 Learner.compile (fullgraph:bool=False, dynamic:bool=False,
                  backend:Union[str,Callable]='inductor',
                  mode:Union[str,CompileMode,NoneType]=None,
                  options:Optional[Dict[str,Union[str,int,bool]]]=None,
                  matmul_precision:Union[str,MatMulPrecision]='high',
                  recompile:bool=False, verbose:bool=False)

Set Learner to compile model using torch.compile via CompilerCallback

Parameter Type Default Details
fullgraph bool False Prevent breaking model into subgraphs
dynamic bool False Use dynamic shape tracing. Automatically set to False if PyTorch < 2.1
backend str | Callable inductor torch.compile backend to use
mode str | CompileMode | None None torch.compile mode to use
options Dict[str, Union[str, int, bool]] | None None Extra options to pass to compile backend
matmul_precision str | MatMulPrecision high Set Ampere and newer matmul precision
recompile bool False Force a compiled model to recompile. Use when freezing/unfreezing a compiled model.
verbose bool False Verbose output

compile only sets dynamic when using PyTorch 2.1 or later; for PyTorch 2.0 it is hardcoded to False. You can override this by setting dynamic directly via CompilerCallback.
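
A sketch of overriding this on PyTorch 2.0 by passing the callback directly (Learner(...) stands in for your own Learner):

# force dynamic shape tracing regardless of PyTorch version
learn = Learner(..., cbs=CompilerCallback(dynamic=True))
learn.fit(1)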

To use, call the compile method after initializing a fastai.learner.Learner.

learn = Learner(...).compile()
learn.fine_tune(1)

Compatibility Patches

These patches integrate torch.compile with fastai exporting, loading, freezing, unfreezing, and fine-tuning.

Exporting and Loading


source

Learner.export

 Learner.export (fname:Union[str,os.PathLike,BinaryIO,IO[bytes]]='export.pkl',
                 pickle_module:Any=pickle, pickle_protocol:int=2)

Export the content of self without the items and the optimizer state for inference

Parameter Type Default Details
fname FILE_LIKE export.pkl Learner export file name, path, bytes, or IO
pickle_module Any pickle Module used for pickling metadata and objects
pickle_protocol int 2 Pickle protocol used

source

load_learner

 load_learner (fname:Union[str,os.PathLike,BinaryIO,IO[bytes]],
               cpu:bool=True, pickle_module=pickle)

Load a Learner object in fname, by default putting it on the cpu

Parameter Type Default Details
fname FILE_LIKE File name, path, bytes, or IO
cpu bool True Load model to CPU
pickle_module module pickle Module used for unpickling metadata and objects

By default, load_learner will remove the CompilerCallback.
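
A minimal sketch of exporting a compiled Learner and loading it back for inference:

learn.export('export.pkl')
learn = load_learner('export.pkl', cpu=True)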

Freezing and Unfreezing


source

Learner.freeze_to

 Learner.freeze_to (n:int)

Freeze parameter groups up to n

Freezing and unfreezing a compiled model works, but the model needs to be recompiled afterward. freeze_to will either set CompilerCallback to recompile the model or warn users that they need to recompile manually.
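
A sketch of a freeze-then-unfreeze workflow with a compiled model (Learner(...) stands in for your own Learner):

learn = Learner(...).compile()
learn.freeze()    # freeze_to sets CompilerCallback to recompile, or warns
learn.fit_one_cycle(1)
learn.unfreeze()  # recompiles again before the next fit
learn.fit_one_cycle(1)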

Training


source

Learner.fine_tune

 Learner.fine_tune (epochs:int, base_lr:float=0.002, freeze_epochs:int=1,
                    lr_mult:Union[int,float]=100, pct_start:float=0.3,
                    div:Union[int,float]=5.0, compile_frozen:bool=False,
                    lr_max=None, div_final=100000.0, wd=None, moms=None,
                    cbs=None, reset_opt=False, start_epoch=0)

Fine tune with Learner.freeze for freeze_epochs, then with Learner.unfreeze for epochs, using discriminative LR.

Parameter Type Default Details
epochs int Number of unfrozen epochs to train
base_lr float 0.002 Base learning rate, model head unfrozen learning rate
freeze_epochs int 1 Number of frozen epochs to train
lr_mult Numeric 100 Model stem unfrozen learning rate: base_lr/lr_mult
pct_start float 0.3 Start unfrozen learning rate cosine annealing
div Numeric 5.0 Initial unfrozen learning rate: base_lr/div
compile_frozen bool False Compile model during frozen finetuning if CompilerCallback is used
lr_max NoneType None
div_final float 100000.0
wd NoneType None
moms NoneType None
cbs NoneType None
reset_opt bool False
start_epoch int 0

By default, fine_tune will not compile the model during the freeze_epochs, but this can be overridden by passing compile_frozen=True. If the model is already compiled, this has no effect.
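
For example, a sketch compiling during the frozen epoch as well:

learn = Learner(...).compile()
learn.fine_tune(1, compile_frozen=True)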

Footnotes

  1. Unless using the FFCV Loader, in which case it uses the training dataloader. This doesn’t affect seeded training, as FFCV dataloaders do not seed transforms, only dataset order.↩︎