MultiLoss

A loss wrapper and callback to calculate and log individual losses as fastxtend metrics.

MultiLoss and MultiTargetLoss are two simple wrappers that combine multiple losses and let MultiLossCallback log each individual loss as a metric. See the Example section for how to use them with fastai.

Via MixHandlerX, both are compatible with MixUp, CutMix, CutMixUp, and CutMixUpAugment.



MultiLoss

 MultiLoss (loss_funcs:Listy[Callable[...,nn.Module]|FunctionType],
            weights:Listy[Numeric|Tensor]|None=None,
            loss_kwargs:Listy[dict[str,Any]]|None=None,
            loss_names:Listy[str]|None=None, reduction:str|None='mean')

Combine multiple loss_funcs on one prediction & target via reduction, with optional weighting.

Log loss_funcs as metrics via MultiLossCallback, optionally using loss_names.

| Parameter | Type | Default | Details |
|---|---|---|---|
| loss_funcs | Listy[Callable[…, nn.Module] \| FunctionType] | | Uninitialized loss functions or classes. Must support PyTorch reduction string. |
| weights | Listy[Numeric \| Tensor] \| None | None | Weight per loss. Defaults to uniform weighting. |
| loss_kwargs | Listy[dict[str, Any]] \| None | None | kwargs to pass to each loss function. Defaults to None. |
| loss_names | Listy[str] \| None | None | Loss names to log using MultiLossCallback. Defaults to loss __name__. |
| reduction | str \| None | mean | PyTorch loss reduction |

MultiLoss is a simple multiple loss wrapper which allows logging each individual loss as a metric using the MultiLossCallback.

Pass uninitialized loss functions to loss_funcs, optional per loss weighting via weights, any loss arguments via a list of dictionaries in loss_kwargs, and optional names for each individual loss via loss_names.

If passed, weights, loss_kwargs, & loss_names must each be an iterable of the same length as loss_funcs.

Output from each loss function must be the same shape.
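The weighting and naming rules above can be illustrated with a minimal plain-Python sketch. It is not fastxtend's implementation: `combine_losses`, `squared_error`, and `absolute_error` are hypothetical stand-ins, and it assumes that uniform weighting means a weight of 1 per loss and that each logged value is the weighted loss. The latter is consistent with the Example table, where train_loss equals train_mse_loss plus train_l1_loss up to rounding.

```python
# Hypothetical stand-ins for losses called with reduction='none':
# each returns one value per element, so all outputs share a shape.
def squared_error(pred, targ):
    return [(p - t) ** 2 for p, t in zip(pred, targ)]

def absolute_error(pred, targ):
    return [abs(p - t) for p, t in zip(pred, targ)]

def mean(values):
    return sum(values) / len(values)

def combine_losses(loss_funcs, pred, targ, weights=None, loss_names=None):
    """Apply every loss to the same prediction and target, then sum the
    weighted, mean-reduced results. Returns (total, per-loss values)."""
    n = len(loss_funcs)
    # Assumption: uniform weighting means a weight of 1.0 per loss
    weights = [1.0] * n if weights is None else list(weights)
    # Fall back to each function's __name__ when no names are given
    names = [f.__name__ for f in loss_funcs] if loss_names is None else list(loss_names)
    if len(weights) != n or len(names) != n:
        raise ValueError("weights and loss_names must match loss_funcs in length")
    weighted = {name: w * mean(f(pred, targ))
                for name, f, w in zip(names, loss_funcs, weights)}
    return sum(weighted.values()), weighted

total, parts = combine_losses([squared_error, absolute_error],
                              pred=[0.0, 2.0], targ=[1.0, 0.0],
                              weights=[1, 3.5],
                              loss_names=['mse_loss', 'l1_loss'])
```

Here `parts` holds one weighted value per named loss, which is the kind of per-loss quantity a logging callback can report alongside the combined total.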



MultiTargetLoss

 MultiTargetLoss (loss_funcs:Listy[Callable[...,nn.Module]|FunctionType],
                  weights:Listy[Numeric|Tensor]|None=None,
                  loss_kwargs:Listy[dict[str,Any]]|None=None,
                  loss_names:Listy[str]|None=None,
                  reduction:str|None='mean')

Combine loss_funcs from multiple predictions & targets via reduction, with optional weighting.

Log loss_funcs as metrics via MultiLossCallback, optionally using loss_names.

| Parameter | Type | Default | Details |
|---|---|---|---|
| loss_funcs | Listy[Callable[…, nn.Module] \| FunctionType] | | Uninitialized loss functions or classes. One per prediction and target. Must support PyTorch reduction string. |
| weights | Listy[Numeric \| Tensor] \| None | None | Weight per loss. Defaults to uniform weighting. |
| loss_kwargs | Listy[dict[str, Any]] \| None | None | kwargs to pass to each loss function. Defaults to None. |
| loss_names | Listy[str] \| None | None | Loss names to log using MultiLossCallback. Defaults to loss __name__. |
| reduction | str \| None | mean | PyTorch loss reduction |

MultiTargetLoss is the multiple-target version of MultiLoss: it applies a single loss to each prediction and target pair. Like MultiLoss, it is a simple multiple loss wrapper which allows logging each individual loss as a metric using the MultiLossCallback.

Pass uninitialized loss functions to loss_funcs, optional per loss weighting via weights, any loss arguments via a list of dictionaries in loss_kwargs, and optional names for each individual loss via loss_names.

If passed, weights, loss_kwargs, & loss_names must each be an iterable of the same length as loss_funcs.

Output from each loss function must be the same shape.
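To make the one-loss-per-target pairing concrete, here is a minimal plain-Python sketch. It is not fastxtend's implementation: `combine_multi_target` and the two loss functions are hypothetical, and the sketch assumes the i-th loss is applied to the i-th prediction and i-th target, with a default weight of 1 per loss.

```python
# Hypothetical stand-ins for losses called with reduction='none'
def squared_error(pred, targ):
    return [(p - t) ** 2 for p, t in zip(pred, targ)]

def absolute_error(pred, targ):
    return [abs(p - t) for p, t in zip(pred, targ)]

def mean(values):
    return sum(values) / len(values)

def combine_multi_target(loss_funcs, preds, targs, weights=None):
    """Pair loss i with prediction i and target i, then sum the
    weighted, mean-reduced results. Returns (total, per-loss values)."""
    n = len(loss_funcs)
    if len(preds) != n or len(targs) != n:
        raise ValueError("need exactly one prediction and one target per loss")
    # Assumption: uniform weighting means a weight of 1.0 per loss
    weights = [1.0] * n if weights is None else list(weights)
    if len(weights) != n:
        raise ValueError("weights must match loss_funcs in length")
    weighted = [w * mean(f(p, t))
                for f, w, p, t in zip(loss_funcs, weights, preds, targs)]
    return sum(weighted), weighted

# Two model outputs, each scored against its own target by its own loss
preds = ([0.0, 2.0], [1.0, 3.0])
targs = ([1.0, 0.0], [1.0, 1.0])
total, parts = combine_multi_target([squared_error, absolute_error],
                                    preds, targs, weights=[1.0, 0.5])
```

The only structural difference from the single-target case is the extra `zip` over predictions and targets, which is why the two wrappers can share the same arguments.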



MixHandlerX

 MixHandlerX (alpha:float=0.5, interp_label:bool|None=None)

A handler class for implementing MixUp style scheduling. Like fastai’s MixHandler but supports MultiLoss.

| Parameter | Type | Default | Details |
|---|---|---|---|
| alpha | float | 0.5 | Alpha & beta parametrization for Beta distribution |
| interp_label | bool \| None | None | Blend or stack labels. Defaults to loss_func.y_int if None |

If interp_label is false, then labels will be blended together. Use with losses that prefer floats as labels such as BCE.

If interp_label is true, then MixHandlerX will call the loss function twice, once with each label, and blend the losses together. Use with losses that prefer class integers as labels such as CE.

If interp_label is None, then it is set via loss_func.y_int.
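The two label-handling strategies described above, blending the labels versus calling the loss twice and blending the losses, can be sketched in plain Python. This is an illustration only: the function names are hypothetical, `lam` stands in for the mixing coefficient that MixUp-style callbacks draw from a Beta distribution, and the assumption that `lam` weights the first label is a convention chosen for this sketch.

```python
import math

def lerp(a, b, lam):
    # Convention assumed here: lam weights the first argument
    return lam * a + (1 - lam) * b

def bce(prob, label):
    # Binary cross entropy for one prediction; accepts float labels
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))

def cross_entropy(probs, class_idx):
    # Cross entropy for one prediction; needs an integer class index
    return -math.log(probs[class_idx])

def mix_by_blending_labels(prob, label_a, label_b, lam):
    # Float labels: blend the labels, then call the loss once
    return bce(prob, lerp(label_a, label_b, lam))

def mix_by_blending_losses(probs, class_a, class_b, lam):
    # Integer labels cannot be blended: call the loss once per label
    # and blend the two resulting losses instead
    return lerp(cross_entropy(probs, class_a),
                cross_entropy(probs, class_b), lam)

lam = 0.7
loss_float = mix_by_blending_labels(0.8, 1.0, 0.0, lam)
loss_int = mix_by_blending_losses([0.1, 0.6, 0.3], 1, 2, lam)
```

Because binary cross entropy is linear in the label, blending the labels gives the same value as blending two BCE calls; the second strategy exists for losses whose labels are class indices and therefore cannot be averaged.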

Note

MixHandlerX is defined here to prevent a circular import between multiloss and cutmixup modules.



MultiLossCallback

 MultiLossCallback (beta:float=0.98, reduction:str|None='mean')

Callback to automatically log and name MultiLoss losses as fastxtend metrics

| Parameter | Type | Default | Details |
|---|---|---|---|
| beta | float | 0.98 | Smoothing beta |
| reduction | str \| None | mean | Override loss reduction for logging |
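As an illustration of what a smoothing beta typically controls, the sketch below implements a debiased exponential moving average in plain Python. It assumes the callback smooths training losses the same way fastai's AvgSmoothLoss does; that is an assumption about the callback, not something stated here, and `SmoothedLoss` is a hypothetical class.

```python
class SmoothedLoss:
    """Debiased exponential moving average of a stream of loss values."""

    def __init__(self, beta=0.98):
        self.beta = beta
        self.count = 0
        self.val = 0.0

    def update(self, loss):
        # Keep a beta share of the running value, add a (1 - beta) share
        # of the newest loss
        self.count += 1
        self.val = self.beta * self.val + (1 - self.beta) * loss

    @property
    def value(self):
        # Divide out the bias from starting at 0.0; only meaningful
        # after at least one update
        return self.val / (1 - self.beta ** self.count)

smoother = SmoothedLoss(beta=0.98)
smoother.update(1.0)
first = smoother.value   # equals the first loss after debiasing
smoother.update(3.0)
second = smoother.value  # weighted toward recent values, still debiased
```

A beta closer to 1 averages over more batches and gives a smoother but slower-moving logged value.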

Example

with no_random():
    mloss = MultiLoss(loss_funcs=[nn.MSELoss, nn.L1Loss],
                      weights=[1, 3.5],
                      loss_names=['mse_loss', 'l1_loss'])


    learn = synth_learner(n_trn=5, loss_func=mloss, metrics=RMSE(), cbs=MultiLossCallback)
    learn.fit(5)

| epoch | train_loss | train_mse_loss | train_l1_loss | valid_loss | valid_mse_loss | valid_l1_loss | valid_rmse | time |
|---|---|---|---|---|---|---|---|---|
| 0 | 23.598301 | 12.719514 | 10.878788 | 17.910727 | 9.067028 | 8.843699 | 3.011151 | 00:00 |
| 1 | 22.448792 | 11.937573 | 10.511218 | 15.481797 | 7.464430 | 8.017367 | 2.732111 | 00:00 |
| 2 | 20.827835 | 10.837888 | 9.989948 | 12.756706 | 5.756156 | 7.000550 | 2.399199 | 00:00 |
| 3 | 19.028177 | 9.657351 | 9.370827 | 10.031281 | 4.145008 | 5.886274 | 2.035929 | 00:00 |
| 4 | 17.167393 | 8.481768 | 8.685625 | 7.581020 | 2.787561 | 4.793459 | 1.669599 | 00:00 |