PyTorch Compile
torch.compile into fastai
The CompilerCallback and DynamoExplainCallback provide an easy-to-use torch.compile integration for fastai.
For more information on torch.compile, please read PyTorch's getting started guide. For troubleshooting torch.compile, refer to the PyTorch Nightly guide.
This module is not imported by any of the fastxtend all imports. You must import it separately, after importing fastai and fastxtend, as it modifies model saving, loading, and training:
```python
from fastxtend.callback import compiler
# or
from fastxtend.callback.compiler import *
```

To use, create a fastai.learner.Learner with a torch.compile-compatible model, then either call compile on the Learner or pass CompilerCallback to the Learner or fit method callbacks.
```python
Learner(...).compile()
# or
Learner(..., cbs=CompilerCallback())
```

CompileMode
CompileMode (value, names=None, module=None, qualname=None, type=None, start=1)
All valid torch.compile modes for tab-completion and typo-proofing
Currently, the ‘reduce-overhead’ mode doesn’t appear to work with all models, and ‘max-autotune’ shouldn’t be used per Compile troubleshooting and gotchas.
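For example, a minimal sketch of selecting a mode, where `Learner(...)` stands in for a fully constructed learner and the mode strings are the ones torch.compile accepts:

```python
# Compile with the default mode (equivalent to mode=None)
learn = Learner(...).compile(mode='default')

# 'reduce-overhead' may speed up smaller models, but doesn't yet work with all models
learn = Learner(..., cbs=CompilerCallback(mode='reduce-overhead'))
```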
MatMulPrecision
MatMulPrecision (value, names=None, module=None, qualname=None, type=None, start=1)
All valid matmul_precision modes for tab-completion and typo-proofing
CompilerCallback
CompilerCallback (fullgraph:bool=False, dynamic:bool=False, backend:str|Callable='inductor', mode:str|CompileMode|None=None, options:Dict[str,Union[str,int,bool]]|None=None, matmul_precision:str|MatMulPrecision='high', recompile:bool=False, verbose:bool=False)
A callback for using torch.compile (beta) with fastai
| | Type | Default | Details |
|---|---|---|---|
| fullgraph | bool | False | Prevent breaking model into subgraphs |
| dynamic | bool | False | Use dynamic shape tracing |
| backend | str | Callable | inductor | torch.compile backend to use |
| mode | str | CompileMode | None | None | torch.compile mode to use |
| options | Dict[str, Union[str, int, bool]] | None | None | Extra options to pass to compile backend |
| matmul_precision | str | MatMulPrecision | high | Set Ampere and newer matmul precision |
| recompile | bool | False | Force a compiled model to recompile. Use when freezing/unfreezing a compiled model. |
| verbose | bool | False | Verbose output |
Using torch.compile with mode='max-autotune' is under active development and might fail. See Compile troubleshooting and gotchas for more details.
By default, CompilerCallback will set matmul ops to use TensorFloat32 for supported GPUs, which is the recommended setting for torch.compile. Set matmul_precision='highest' to turn off or matmul_precision='medium' to enable bfloat16 mode.
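For example, a minimal sketch of the three matmul_precision settings, with `Learner(...)` standing in for a fully constructed learner; the strings match torch.set_float32_matmul_precision:

```python
# Default: TensorFloat32 matmul, the recommended setting for torch.compile
learn = Learner(...).compile(matmul_precision='high')

# Full float32 matmul precision, disabling TensorFloat32
learn = Learner(...).compile(matmul_precision='highest')

# Allow bfloat16 matmul for more speed at lower precision
learn = Learner(...).compile(matmul_precision='medium')
```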
fastxtend provides the compile convenience method for easily enabling torch.compile. Alternatively, you can pass CompilerCallback to the cbs argument of fastai.learner.Learner or a fit method.
```python
learn = Learner(..., cbs=CompilerCallback())
learn.fine_tune(1)
```

DynamoExplainCallback
DynamoExplainCallback (print_results:bool=True, out_guards:bool=False, ops_per_graph:bool=False, break_reasons:bool=False)
A callback to automate finding graph breaks with PyTorch Compile’s Dynamo Explain
| | Type | Default | Details |
|---|---|---|---|
| print_results | bool | True | Print enabled torch._dynamo.explain output(s) |
| out_guards | bool | False | Print the out_guards output |
| ops_per_graph | bool | False | Print the ops_per_graph output |
| break_reasons | bool | False | Print the break_reasons output |
DynamoExplainCallback automates finding graph breaks using torch._dynamo.explain, per the Identifying the cause of a graph break section in the PyTorch Compile FAQ. DynamoExplainCallback uses one batch from the validation dataloader to generate the _dynamo.explain report(s) and then cancels training.
To use, pass DynamoExplainCallback to the cbs argument of the fastai.learner.Learner or fit method.
```python
learn = Learner(..., cbs=DynamoExplainCallback())
learn.fit(1)
```

By default, DynamoExplainCallback prints the basic explanation output from _dynamo.explain, with arguments to enable printing out_guards and/or ops_per_graph.
All _dynamo.explain outputs are stored as attributes in the callback for later reference. For example, to view the out_guards after running Learner with DynamoExplainCallback:
```python
# PyTorch 2.0
print(learn.dynamo_explain.out_guards)

# PyTorch 2.1
print(learn.dynamo_explain.explain_output.out_guards)
```

Convenience Method
fastxtend adds a convenience method to fastai.learner.Learner to easily enable torch.compile.
Learner.compile
Learner.compile (fullgraph:bool=False, dynamic:bool=False, backend:Union[str,Callable]='inductor', mode:Union[str,CompileMode,None]=None, options:Optional[Dict[str,Union[str,int,bool]]]=None, matmul_precision:Union[str,MatMulPrecision]='high', recompile:bool=False, verbose:bool=False)
Set Learner to compile model using torch.compile via CompilerCallback
| | Type | Default | Details |
|---|---|---|---|
| fullgraph | bool | False | Prevent breaking model into subgraphs |
| dynamic | bool | False | Use dynamic shape tracing. Set to False if PyTorch < 2.1 |
| backend | str | Callable | inductor | torch.compile backend to use |
| mode | str | CompileMode | None | None | torch.compile mode to use |
| options | Dict[str, Union[str, int, bool]] | None | None | Extra options to pass to compile backend |
| matmul_precision | str | MatMulPrecision | high | Set Ampere and newer matmul precision |
| recompile | bool | False | Force a compiled model to recompile. Use when freezing/unfreezing a compiled model. |
| verbose | bool | False | Verbose output |
compile only sets dynamic when using PyTorch 2.1 or later; for PyTorch 2.0 it is hardcoded to False. You can override this by setting dynamic directly via CompilerCallback.
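For example, a minimal sketch of forcing dynamic shape tracing on PyTorch 2.0 by constructing the callback directly:

```python
# Bypass compile's PyTorch version check by passing CompilerCallback yourself
learn = Learner(..., cbs=CompilerCallback(dynamic=True))
```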
To use, call the compile method after initializing a fastai.learner.Learner.

```python
learn = Learner(...).compile()
learn.fine_tune(1)
```

Compatibility Patches
These patches integrate torch.compile with fastai exporting, loading, freezing, unfreezing, and fine-tuning.
Exporting and Loading
Learner.export
Learner.export (fname:Union[str,os.PathLike,BinaryIO,IO[bytes]]='export.pkl', pickle_module:Any=pickle, pickle_protocol:int=2)
Export the content of self without the items and the optimizer state for inference
| | Type | Default | Details |
|---|---|---|---|
| fname | FILE_LIKE | export.pkl | Learner export file name, path, bytes, or IO |
| pickle_module | Any | pickle | Module used for pickling metadata and objects |
| pickle_protocol | int | 2 | Pickle protocol used |
load_learner
load_learner (fname:Union[str,os.PathLike,BinaryIO,IO[bytes]], cpu:bool=True, pickle_module=pickle)
Load a Learner object in fname, by default putting it on the cpu
| | Type | Default | Details |
|---|---|---|---|
| fname | FILE_LIKE | | File name, path, bytes, or IO |
| cpu | bool | True | Load model to CPU |
| pickle_module | module | pickle | Module used for unpickling metadata and objects |
By default, load_learner will remove the CompilerCallback.
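For example, a minimal sketch of the export and load round trip with a compiled Learner, using the default 'export.pkl' file name:

```python
learn = Learner(...).compile()
learn.fit(1)

# The patched export handles the compiled model when saving for inference
learn.export('export.pkl')

# Reload for inference; by default the CompilerCallback is removed
inf_learn = load_learner('export.pkl')
```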
Freezing and Unfreezing
Learner.freeze_to
Learner.freeze_to (n:int)
Freeze parameter groups up to n
Freezing and unfreezing a compiled model works, but the model needs to be recompiled afterward. freeze_to will either set CompilerCallback to recompile the model or warn users that they need to manually recompile.
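For example, a sketch of a freeze/unfreeze cycle with a compiled Learner; freeze and unfreeze call freeze_to under the hood, so CompilerCallback recompiles (or warns) after each:

```python
learn = Learner(...).compile()

learn.freeze()      # freeze_to(-1); the model is recompiled before the next fit
learn.fit_one_cycle(1)

learn.unfreeze()    # freeze_to(0); recompilation is needed again
learn.fit_one_cycle(1)
```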
Training
Learner.fine_tune
Learner.fine_tune (epochs:int, base_lr:float=0.002, freeze_epochs:int=1, lr_mult:Union[int,float]=100, pct_start:float=0.3, div:Union[int,float]=5.0, compile_frozen:bool=False, lr_max=None, div_final=100000.0, wd=None, moms=None, cbs=None, reset_opt=False, start_epoch=0)
Fine tune with Learner.freeze for freeze_epochs, then with Learner.unfreeze for epochs, using discriminative LR.
| | Type | Default | Details |
|---|---|---|---|
| epochs | int | | Number of unfrozen epochs to train |
| base_lr | float | 0.002 | Base learning rate, model head unfrozen learning rate |
| freeze_epochs | int | 1 | Number of frozen epochs to train |
| lr_mult | Numeric | 100 | Model stem unfrozen learning rate: base_lr/lr_mult |
| pct_start | float | 0.3 | Start unfrozen learning rate cosine annealing |
| div | Numeric | 5.0 | Initial unfrozen learning rate: base_lr/div |
| compile_frozen | bool | False | Compile model during frozen finetuning if CompilerCallback is used |
| lr_max | NoneType | None | |
| div_final | float | 100000.0 | |
| wd | NoneType | None | |
| moms | NoneType | None | |
| cbs | NoneType | None | |
| reset_opt | bool | False | |
| start_epoch | int | 0 |
By default, fine_tune will not compile the model during the frozen freeze_epochs, but this can be overridden by passing compile_frozen=True. If the model is already compiled, this has no effect.
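For example, a minimal sketch of both behaviors, with `Learner(...)` standing in for a fully constructed transfer-learning learner:

```python
learn = Learner(...).compile()

# Default: the frozen epoch runs uncompiled, the unfrozen epochs run compiled
learn.fine_tune(2)

# Alternatively, compile during the frozen epoch as well
learn.fine_tune(2, compile_frozen=True)
```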