A loss wrapper and callback to calculate and log individual losses as fastxtend metrics.
MultiLoss and MultiTargetLoss are two simple loss wrappers that allow logging individual losses as metrics using MultiLossCallback. See the example below for how to use them with fastai.
Combine multiple loss_funcs on one prediction & target via reduction, with optional weighting.
Log loss_funcs as metrics via MultiLossCallback, optionally using loss_names.
|  | Type | Default | Details |
|---|---|---|---|
| loss_funcs | Listy[Callable[..., nn.Module] \| FunctionType] |  | Uninitialized loss functions or classes. Must support PyTorch reduction string. |
| weights | Listy[Numeric \| Tensor] \| None | None | Weight per loss. Defaults to uniform weighting. |
| loss_kwargs | Listy[dict[str, Any]] \| None | None | kwargs to pass to each loss function. Defaults to None. |
| loss_names | Listy[str] \| None | None | Loss names to log using MultiLossCallback. Defaults to loss `__name__`. |
| reduction | str \| None | mean | PyTorch loss reduction |
MultiLoss is a simple multiple loss wrapper which allows logging each individual loss as a metric using the MultiLossCallback.
Pass uninitialized loss functions to loss_funcs, optional per loss weighting via weights, any loss arguments via a list of dictionaries in loss_kwargs, and optional names for each individual loss via loss_names.
If passed, weights, loss_kwargs, & loss_names must be an iterable of the same length as loss_funcs.
Output from each loss function must be the same shape.
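The weighted combination MultiLoss performs can be sketched in plain Python. This is a conceptual sketch only, not fastxtend's implementation: the real class operates on PyTorch tensors and reduction strings, and the `multi_loss` function and float losses here are illustrative.

```python
# Conceptual sketch of MultiLoss: every loss is computed on the same
# prediction & target, optionally weighted, then summed into one loss.
# The individual losses are kept so a callback could log them as metrics.
def multi_loss(loss_funcs, pred, targ, weights=None):
    weights = weights or [1] * len(loss_funcs)  # uniform weighting by default
    assert len(weights) == len(loss_funcs), "weights must match loss_funcs in length"
    losses = [w * f(pred, targ) for f, w in zip(loss_funcs, weights)]
    return sum(losses), losses  # total loss plus per-loss values for logging

mse = lambda p, t: (p - t) ** 2
mae = lambda p, t: abs(p - t)
total, individual = multi_loss([mse, mae], pred=3.0, targ=1.0, weights=[1, 0.5])
# total = 1*4.0 + 0.5*2.0 = 5.0
```

Because each weighted loss is computed before summing, the wrapper can report every component separately while the optimizer only sees the combined scalar.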
Combine loss_funcs from multiple predictions & targets via reduction, with optional weighting.
Log loss_funcs as metrics via MultiLossCallback, optionally using loss_names.
|  | Type | Default | Details |
|---|---|---|---|
| loss_funcs | Listy[Callable[..., nn.Module] \| FunctionType] |  | Uninitialized loss functions or classes. One per prediction and target. Must support PyTorch reduction string. |
| weights | Listy[Numeric \| Tensor] \| None | None | Weight per loss. Defaults to uniform weighting. |
| loss_kwargs | Listy[dict[str, Any]] \| None | None | kwargs to pass to each loss function. Defaults to None. |
| loss_names | Listy[str] \| None | None | Loss names to log using MultiLossCallback. Defaults to loss `__name__`. |
| reduction | str \| None | mean | PyTorch loss reduction |
MultiTargetLoss is the single-loss-per-target version of MultiLoss: it applies one loss function to each prediction and target pair. It is a simple multiple loss wrapper which allows logging each individual loss as a metric using the MultiLossCallback.
Pass uninitialized loss functions to loss_funcs, optional per loss weighting via weights, any loss arguments via a list of dictionaries in loss_kwargs, and optional names for each individual loss via loss_names.
If passed, weights, loss_kwargs, & loss_names must be an iterable of the same length as loss_funcs.
Output from each loss function must be the same shape.
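The per-target combination can likewise be sketched in plain Python. Again, this is an illustrative sketch with hypothetical names, not fastxtend's tensor-based implementation.

```python
# Conceptual sketch of MultiTargetLoss: one loss function per
# (prediction, target) pair, optionally weighted, summed into one loss.
def multi_target_loss(loss_funcs, preds, targs, weights=None):
    assert len(loss_funcs) == len(preds) == len(targs), \
        "one loss function per prediction & target"
    weights = weights or [1] * len(loss_funcs)  # uniform weighting by default
    losses = [w * f(p, t) for f, p, t, w in zip(loss_funcs, preds, targs, weights)]
    return sum(losses), losses  # total loss plus per-loss values for logging

mse = lambda p, t: (p - t) ** 2
mae = lambda p, t: abs(p - t)
total, individual = multi_target_loss([mse, mae], preds=[3.0, 2.0], targs=[1.0, 0.0])
# total = 4.0 + 2.0 = 6.0
```

The difference from MultiLoss is only in what each loss sees: MultiLoss applies every loss to the same prediction and target, while MultiTargetLoss pairs each loss with its own prediction and target.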
A handler class for implementing MixUp style scheduling. Like fastai’s MixHandler but supports MultiLoss.
|  | Type | Default | Details |
|---|---|---|---|
| alpha | float | 0.5 | Alpha & beta parametrization for Beta distribution |
| interp_label | bool \| None | None | Blend or stack labels. Defaults to loss_func.y_int if None |
If interp_label is false, the labels are blended together before the loss is computed. Use with losses that prefer floats as labels, such as BCE.
If interp_label is true, MixHandlerX calls the loss function twice, once with each label, and blends the resulting losses together. Use with losses that prefer class integers as labels, such as CE.
If interp_label is None, then it is set via loss_func.y_int.
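The two interp_label modes can be sketched with plain floats. This is a conceptual sketch of the blending logic, not MixHandlerX's implementation; `mixup_loss` and the float loss are illustrative, and `lam` stands in for the Beta-distributed mixing coefficient MixUp draws per batch.

```python
# Conceptual sketch of the two interp_label modes.
def mixup_loss(loss_func, pred, y1, y2, lam, interp_label):
    if interp_label:
        # Integer-label losses (e.g. CE): call the loss twice, blend the losses
        return lam * loss_func(pred, y1) + (1 - lam) * loss_func(pred, y2)
    # Float-label losses (e.g. BCE): blend the labels, call the loss once
    return loss_func(pred, lam * y1 + (1 - lam) * y2)

mae = lambda p, t: abs(p - t)
blend_losses = mixup_loss(mae, pred=0.5, y1=0.0, y2=1.0, lam=0.75, interp_label=True)
blend_labels = mixup_loss(mae, pred=0.5, y1=0.0, y2=1.0, lam=0.75, interp_label=False)
# interp_label=True:  0.75*0.5 + 0.25*0.5 = 0.5
# interp_label=False: blended label = 0.25, so loss = 0.25
```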
Note: MixHandlerX is defined here to prevent a circular import between the multiloss and cutmixup modules.