A loss wrapper and callback to calculate and log individual losses as fastxtend metrics.

class MultiLoss[source]

MultiLoss(loss_funcs:listy[Callable[..., nn.Module] | FunctionType], weights:listified[Number] | None=None, loss_kwargs:listy[dict[str, Any]] | None=None, loss_names:listy[str] | None=None, reduction:str | None='mean') :: Module

Combine multiple loss_funcs on one prediction & target via reduction, with optional weighting.

Log loss_funcs as metrics via MultiLossCallback, optionally using loss_names.

| | Type | Default | Details |
|---|---|---|---|
| loss_funcs | listy[Callable[..., nn.Module] or FunctionType] | | Uninitialized loss functions or classes. Must support PyTorch reduction string. |
| weights | listified[Number] or None | None | Weight per loss. Defaults to uniform weighting. |
| loss_kwargs | listy[dict[str, Any]] or None | None | kwargs to pass to each loss function. Defaults to None. |
| loss_names | listy[str] or None | None | Loss names to log using MultiLossCallback. Defaults to loss `__name__`. |
| reduction | str or None | mean | PyTorch loss reduction |

MultiLoss is a simple wrapper for combining multiple loss functions on one prediction and target, and it allows each individual loss to be logged automatically using the MultiLossCallback.

Pass uninitialized loss functions to loss_funcs, optional per-loss weighting via weights, any loss-specific keyword arguments via a list of dictionaries in loss_kwargs, and optional names for each individual loss via loss_names.

If passed, weights, loss_kwargs, & loss_names must each be an iterable of the same length as loss_funcs.

Output from each loss function must be the same shape.
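
For example, here is a minimal sketch (import paths assumed) which combines nn.MSELoss with nn.SmoothL1Loss and passes SmoothL1Loss's beta argument through loss_kwargs. Because MultiLoss is a Module, it can be called on a prediction and target like a regular PyTorch loss, or passed to a fastai Learner as loss_func.

import torch
from torch import nn
from fastxtend.multiloss import MultiLoss  # import path assumed

# Weight Smooth L1 twice as heavily as MSE and pass SmoothL1Loss's
# `beta` argument via `loss_kwargs`
mloss = MultiLoss(loss_funcs=[nn.MSELoss, nn.SmoothL1Loss],
                  weights=[1, 2],
                  loss_kwargs=[{}, {'beta': 0.5}],
                  loss_names=['mse_loss', 'smooth_l1_loss'])

pred, targ = torch.randn(8, 1), torch.randn(8, 1)
loss = mloss(pred, targ)  # weighted sum of both losses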

class MultiTargetLoss[source]

MultiTargetLoss(loss_funcs:listy[Callable[..., nn.Module] | FunctionType], weights:listified[Number] | None=None, loss_kwargs:listy[dict[str, Any]] | None=None, loss_names:listy[str] | None=None, reduction:str | None='mean') :: MultiLoss

Combine loss_funcs from multiple predictions & targets via reduction, with optional weighting.

Log loss_funcs as metrics via MultiLossCallback, optionally using loss_names.

| | Type | Default | Details |
|---|---|---|---|
| loss_funcs | listy[Callable[..., nn.Module] or FunctionType] | | Uninitialized loss functions or classes. One per prediction and target. Must support PyTorch reduction string. |
| weights | listified[Number] or None | None | Weight per loss. Defaults to uniform weighting. |
| loss_kwargs | listy[dict[str, Any]] or None | None | kwargs to pass to each loss function. Defaults to None. |
| loss_names | listy[str] or None | None | Loss names to log using MultiLossCallback. Defaults to loss `__name__`. |
| reduction | str or None | mean | PyTorch loss reduction |

MultiTargetLoss is a multi-target version of MultiLoss which applies one loss function per prediction and target pair. Like MultiLoss, it allows each individual loss to be logged automatically using the MultiLossCallback.

Pass uninitialized loss functions to loss_funcs, optional per-loss weighting via weights, any loss-specific keyword arguments via a list of dictionaries in loss_kwargs, and optional names for each individual loss via loss_names.

If passed, weights, loss_kwargs, & loss_names must each be an iterable of the same length as loss_funcs.

Output from each loss function must be the same shape.
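
As a sketch (import path assumed), a two-headed model that predicts both a regression value and a class could pair nn.MSELoss with nn.CrossEntropyLoss, one loss per prediction and target pair:

from torch import nn
from fastxtend.multiloss import MultiTargetLoss  # import path assumed

# One loss per prediction & target pair: MSE for the regression head,
# cross entropy for the classification head, with the regression loss
# weighted twice as heavily
mtloss = MultiTargetLoss(loss_funcs=[nn.MSELoss, nn.CrossEntropyLoss],
                         weights=[2, 1],
                         loss_names=['reg_loss', 'class_loss'])

During training the model is expected to return one prediction per loss function, with the dataloaders supplying a matching target for each, and MultiLossCallback will log reg_loss and class_loss individually.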

class MultiLossCallback[source]

MultiLossCallback(beta:float=0.98, reduction:str | None='mean') :: Callback

Callback to automatically log and name MultiLoss losses as fastxtend metrics.

| | Type | Default | Details |
|---|---|---|---|
| beta | float | 0.98 | Smoothing beta |
| reduction | str or None | mean | Override loss reduction for logging |

Example

with no_random():
    # nn.MSELoss is logged as 'mse_loss' and nn.L1Loss as 'l1_loss'
    mloss = MultiLoss(loss_funcs=[nn.MSELoss, nn.L1Loss],
                      weights=[1, 3.5],
                      loss_names=['mse_loss', 'l1_loss'])

    learn = synth_learner(n_trn=5, loss_func=mloss, metrics=RMSE(), cbs=MultiLossCallback)
    learn.fit(5)
| epoch | train_loss | train_mse_loss | train_l1_loss | valid_loss | valid_mse_loss | valid_l1_loss | valid_rmse | time |
|---|---|---|---|---|---|---|---|---|
| 0 | 23.598301 | 12.719514 | 10.878788 | 17.910727 | 9.067028 | 8.843699 | 3.011151 | 00:00 |
| 1 | 22.448792 | 11.937573 | 10.511218 | 15.481797 | 7.464430 | 8.017367 | 2.732111 | 00:00 |
| 2 | 20.827835 | 10.837888 | 9.989948 | 12.756706 | 5.756156 | 7.000550 | 2.399199 | 00:00 |
| 3 | 19.028177 | 9.657351 | 9.370827 | 10.031281 | 4.145008 | 5.886274 | 2.035929 | 00:00 |
| 4 | 17.167393 | 8.481768 | 8.685625 | 7.581020 | 2.787561 | 4.793459 | 1.669599 | 00:00 |
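
Assuming fastai's standard Recorder attributes, the logged per-loss metrics can also be inspected programmatically after training:

# Metric names match the table columns above; `values` holds the
# per-epoch results, so individual losses can be read back after training
print(learn.recorder.metric_names)
print(learn.recorder.values[-1])  # final epoch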