PyTorch Compile

Integrating torch.compile into fastai.

The CompilerCallback and DynamoExplainCallback provide an easy to use torch.compile integration for fastai.

For more information on torch.compile, please read PyTorch’s getting started guide. For troubleshooting torch.compile, refer to this PyTorch Nightly guide.
This module is not imported via any of fastxtend’s all imports. You must import it separately, after importing fastai and fastxtend, as it modifies model saving, loading, and training:
```python
from fastxtend.callback import compiler
# or
from fastxtend.callback.compiler import *
```
To use, create a fastai.learner.Learner with a torch.compile compatible model and call compile on the Learner, or pass CompilerCallback to the Learner or fit method callbacks.
```python
learn = Learner(...).compile()
# or
learn = Learner(..., cbs=CompilerCallback())
```
CompileMode
CompileMode (value, names=None, module=None, qualname=None, type=None, start=1)
All valid torch.compile modes for tab-completion and typo-proofing
Currently, the ‘reduce-overhead’ mode doesn’t appear to work with all models, and ‘max-autotune’ shouldn’t be used per Compile troubleshooting and gotchas.
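For example, mode accepts either a plain string or a CompileMode member. A minimal sketch, assuming the member names mirror the torch.compile mode strings:

```python
from fastxtend.callback.compiler import CompileMode

# the enum member is tab-completable and typo-proof (assumed member name)
learn = Learner(...).compile(mode=CompileMode.default)
# a plain string works as well
learn = Learner(...).compile(mode='default')
```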
MatMulPrecision
MatMulPrecision (value, names=None, module=None, qualname=None, type=None, start=1)
All valid matmul_precision modes for tab-completion and typo-proofing
CompilerCallback
CompilerCallback (fullgraph:bool=False, dynamic:bool=False, backend:str|Callable='inductor', mode:str|CompileMode|None=None, options:Dict[str,Union[str,int,bool]]|None=None, matmul_precision:str|MatMulPrecision='high', recompile:bool=False, verbose:bool=False)
A callback for using torch.compile (beta) with fastai
| | Type | Default | Details |
|---|---|---|---|
| fullgraph | bool | False | Prevent breaking model into subgraphs |
| dynamic | bool | False | Use dynamic shape tracing |
| backend | str \| Callable | inductor | torch.compile backend to use |
| mode | str \| CompileMode \| None | None | torch.compile mode to use |
| options | Dict[str, Union[str, int, bool]] \| None | None | Extra options to pass to compile backend |
| matmul_precision | str \| MatMulPrecision | high | Set Ampere and newer matmul precision |
| recompile | bool | False | Force a compiled model to recompile. Use when freezing/unfreezing a compiled model. |
| verbose | bool | False | Verbose output |
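A sketch of a fuller configuration. The options dictionary is backend-specific; 'triton.cudagraphs' is one of the inductor options listed in the torch.compile docstring, used here purely as an example:

```python
# full-graph compilation with the inductor backend plus an extra
# backend-specific option, with verbose output enabled
cb = CompilerCallback(fullgraph=True, backend='inductor',
                      options={'triton.cudagraphs': True}, verbose=True)
learn = Learner(..., cbs=cb)
```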
Using torch.compile with mode='max-autotune' is under active development and might fail. See Compile troubleshooting and gotchas for more details.

By default, CompilerCallback will set matmul ops to use TensorFloat32 for supported GPUs, which is the recommended setting for torch.compile. Set matmul_precision='highest' to turn this off, or matmul_precision='medium' to enable bfloat16 mode.
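These values correspond to PyTorch's global float32 matmul precision setting. A minimal sketch of the equivalent manual calls, assuming the callback wraps torch.set_float32_matmul_precision:

```python
import torch

# 'high' enables TensorFloat32 matmuls (the callback's default)
torch.set_float32_matmul_precision('high')
# 'highest' keeps full float32 matmul precision
torch.set_float32_matmul_precision('highest')
# 'medium' additionally allows bfloat16 matmuls
torch.set_float32_matmul_precision('medium')
```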
fastxtend provides the compile convenience method for easily enabling torch.compile. Or you can pass CompilerCallback to the cbs argument of the fastai.learner.Learner or a fit method.
```python
learn = Learner(..., cbs=CompilerCallback())
learn.fine_tune(1)
```
DynamoExplainCallback
DynamoExplainCallback (print_results:bool=True, out_guards:bool=False, ops_per_graph:bool=False, break_reasons:bool=False)
A callback to automate finding graph breaks with PyTorch Compile’s Dynamo Explain
| | Type | Default | Details |
|---|---|---|---|
| print_results | bool | True | Print enabled torch._dynamo.explain output(s) |
| out_guards | bool | False | Print the out_guards output |
| ops_per_graph | bool | False | Print the ops_per_graph output |
| break_reasons | bool | False | Print the break_reasons output |
DynamoExplainCallback automates finding graph breaks using torch._dynamo.explain, per the Identifying the cause of a graph break section in the PyTorch Compile FAQ. DynamoExplainCallback uses one batch from the validation dataloader to generate the _dynamo.explain report(s) and then cancels training.
To use, pass DynamoExplainCallback to the cbs argument of the fastai.learner.Learner or fit method.
```python
learn = Learner(..., cbs=DynamoExplainCallback())
learn.fit(1)
```
By default, DynamoExplainCallback prints the basic explanation output from _dynamo.explain, with arguments to enable printing out_guards and/or ops_per_graph.
All _dynamo.explain outputs are stored as attributes in the callback for later reference. For example, to view the out_guards after running Learner with DynamoExplainCallback:
```python
# PyTorch 2.0
print(learn.dynamo_explain.out_guards)
# PyTorch 2.1
print(learn.dynamo_explain.explain_output.out_guards)
```
Convenience Method
fastxtend adds a convenience method to fastai.learner.Learner to easily enable torch.compile.
Learner.compile
Learner.compile (fullgraph:bool=False, dynamic:bool=False, backend:Union[str,Callable]='inductor', mode:Union[str,CompileMode,None]=None, options:Optional[Dict[str,Union[str,int,bool]]]=None, matmul_precision:Union[str,MatMulPrecision]='high', recompile:bool=False, verbose:bool=False)
Set Learner to compile model using torch.compile via CompilerCallback
| | Type | Default | Details |
|---|---|---|---|
| fullgraph | bool | False | Prevent breaking model into subgraphs |
| dynamic | bool | False | Use dynamic shape tracing. Sets to False if PyTorch < 2.1 |
| backend | str \| Callable | inductor | torch.compile backend to use |
| mode | str \| CompileMode \| None | None | torch.compile mode to use |
| options | Dict[str, Union[str, int, bool]] \| None | None | Extra options to pass to compile backend |
| matmul_precision | str \| MatMulPrecision | high | Set Ampere and newer matmul precision |
| recompile | bool | False | Force a compiled model to recompile. Use when freezing/unfreezing a compiled model. |
| verbose | bool | False | Verbose output |
compile only sets dynamic if using PyTorch 2.1 or later; for PyTorch 2.0 it is hardcoded to False. You can override this by setting dynamic directly via CompilerCallback.
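For example, a sketch of forcing dynamic shape tracing on PyTorch 2.0 by constructing the callback directly:

```python
# bypasses the version check in Learner.compile
learn = Learner(..., cbs=CompilerCallback(dynamic=True))
```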
To use, call the compile method after initializing a fastai.learner.Learner.
```python
learn = Learner(...).compile()
learn.fine_tune(1)
```
Compatibility Patches
These patches integrate torch.compile with fastai exporting, loading, freezing, unfreezing, and fine tuning.
Exporting and Loading
Learner.export
Learner.export (fname:Union[str,os.PathLike,BinaryIO,IO[bytes]]='export.pkl', pickle_module:Any=pickle, pickle_protocol:int=2)
Export the content of self without the items and the optimizer state for inference
| | Type | Default | Details |
|---|---|---|---|
| fname | FILE_LIKE | export.pkl | Learner export file name, path, bytes, or IO |
| pickle_module | Any | pickle | Module used for pickling metadata and objects |
| pickle_protocol | int | 2 | Pickle protocol used |
load_learner
load_learner (fname:Union[str,os.PathLike,BinaryIO,IO[bytes]], cpu:bool=True, pickle_module=pickle)
Load a Learner object in fname, by default putting it on the cpu
| | Type | Default | Details |
|---|---|---|---|
| fname | FILE_LIKE | | File name, path, bytes, or IO |
| cpu | bool | True | Load model to CPU |
| pickle_module | module | pickle | Module used for unpickling metadata and objects |
By default, load_learner will remove the CompilerCallback.
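A minimal sketch of the export and load round trip with a compiled Learner:

```python
learn = Learner(...).compile()
learn.fine_tune(1)
learn.export('export.pkl')

# load_learner strips CompilerCallback by default, so the loaded
# Learner runs uncompiled and is ready for inference
learn_inf = load_learner('export.pkl')
```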
Freezing and Unfreezing
Learner.freeze_to
Learner.freeze_to (n:int)
Freeze parameter groups up to n
Freezing and unfreezing models works, but the model needs to be recompiled afterwards. freeze_to will set CompilerCallback to recompile the model, or warn users that they need to manually recompile.
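For example, a sketch of a typical transfer learning loop with a compiled model:

```python
learn = Learner(...).compile()
learn.freeze()          # freeze_to marks the model for recompilation
learn.fit_one_cycle(1)
learn.unfreeze()        # triggers another recompile (or a warning)
learn.fit_one_cycle(1)
```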
Training
Learner.fine_tune
Learner.fine_tune (epochs:int, base_lr:float=0.002, freeze_epochs:int=1, lr_mult:Union[int,float]=100, pct_start:float=0.3, div:Union[int,float]=5.0, compile_frozen:bool=False, lr_max=None, div_final=100000.0, wd=None, moms=None, cbs=None, reset_opt=False, start_epoch=0)
Fine tune with Learner.freeze for freeze_epochs, then with Learner.unfreeze for epochs, using discriminative LR.
| | Type | Default | Details |
|---|---|---|---|
| epochs | int | | Number of unfrozen epochs to train |
| base_lr | float | 0.002 | Base learning rate, model head unfrozen learning rate |
| freeze_epochs | int | 1 | Number of frozen epochs to train |
| lr_mult | Numeric | 100 | Model stem unfrozen learning rate: base_lr/lr_mult |
| pct_start | float | 0.3 | Start unfrozen learning rate cosine annealing |
| div | Numeric | 5.0 | Initial unfrozen learning rate: base_lr/div |
| compile_frozen | bool | False | Compile model during frozen finetuning if CompilerCallback is used |
| lr_max | NoneType | None | |
| div_final | float | 100000.0 | |
| wd | NoneType | None | |
| moms | NoneType | None | |
| cbs | NoneType | None | |
| reset_opt | bool | False | |
| start_epoch | int | 0 | |
By default, fine_tune will not compile the freeze_epochs, but this can be overridden by passing compile_frozen=True. If the model is already compiled, this will have no effect.
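For example, a sketch of also compiling the frozen fine-tuning epochs:

```python
learn = Learner(...).compile()
# compile during the frozen epochs too, instead of only the unfrozen ones
learn.fine_tune(5, compile_frozen=True)
```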