Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.callbacks.MetricsCallback

class nidl.callbacks.MetricsCallback(metrics, needs=None, compute_per_training_step=True, compute_per_val_step=False, compute_per_test_step=False, compute_on_cpu=False, every_n_train_steps=1, every_n_train_epochs=None, every_n_val_epochs=1, on_test_end=False, prog_bar=True)[source]

Bases: Callback

Callback to compute and log metrics during training, validation and test of a PL model.

This callback will:

  1. Collect the model’s outputs after each training/validation/test step.

  2. Compute the required metrics on the collected outputs, either batch-wise or epoch-wise (depending on the use-case).

  3. Log the metrics after each iteration or epoch.

It handles multi-GPU / distributed training setups and is compatible with torchmetrics metrics, scikit-learn metrics, and custom metric functions.

Parameters:
metrics : dict[str, Callable] or list[Callable] or Callable

A list of metrics (callables including torchmetrics.Metric and sklearn metrics) to be computed on the model’s outputs. If a dict is provided, the keys will be used as the metric names when logging. If a Callable or list[Callable] is provided, the metric names will be inferred from the function or class names.
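When metrics is not a dict, name inference could work along these lines. This is a minimal sketch under stated assumptions (hypothetical helper, not nidl's actual implementation): plain functions expose `__name__`, while class instances fall back to their class name.

```python
def infer_metric_names(metrics):
    """Map metrics to names: dict keys as-is, function/class names otherwise."""
    if isinstance(metrics, dict):
        return dict(metrics)
    if callable(metrics):
        metrics = [metrics]
    named = {}
    for metric in metrics:
        # Plain functions expose __name__; class instances (e.g. a
        # torchmetrics.Metric) fall back to their class name.
        name = getattr(metric, "__name__", type(metric).__name__)
        named[name] = metric
    return named


def accuracy(preds, targets):
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


class F1Score:
    def __call__(self, preds, targets):
        raise NotImplementedError


names = list(infer_metric_names([accuracy, F1Score()]))
# keys inferred: "accuracy" (function name) and "F1Score" (class name)
```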

needs : NeedsType, optional

A mapping defining which outputs from the model are needed to compute the metrics.

It can be either:

  • None: the needed arguments are inferred automatically from the metric signatures and parsed from the model outputs. It will raise an error if the required arguments are not found in the outputs or if there are ambiguities.

  • list[str | Callable] or dict[str, str | Callable]: applies to all metrics.

    For example, positional arguments:

    ["logits", "labels"]
    

    if the metric needs preds and targets as positional arguments, and these are found in the model outputs under the keys "logits" and "labels".

    Or keyword arguments:

    {
        "preds": lambda outputs: outputs["logits"].softmax(dim=-1),
        "targets": "labels",
    }
    

    if the metric needs preds and targets as keyword arguments, and these are found in the model outputs under the keys "logits" (with pre-processing required) and "labels" (used as-is), respectively.

  • dict[str, list[str | Callable] | dict[str, str | Callable]]: per-metric overrides, keyed by metric name.

    For example:

    {
        "Accuracy": [
            lambda outputs: outputs["logits"].softmax(dim=-1),
            "labels",
        ],
        "Alignment": {"z1": "Z1", "z2": "Z2"},
    }
    

    if different metrics need different arguments from the outputs. The same logic applies per metric as above.
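The resolution of a needs spec against a model's outputs dict can be sketched as follows. This is a hypothetical helper illustrating the semantics described above, not the library's actual implementation; `round` stands in for the softmax pre-processing of the doc's example so the sketch stays torch-free.

```python
def resolve_needs(needs, outputs):
    """Turn a needs spec into (args, kwargs) for a metric call.

    Each entry is either a key into `outputs` or a callable applied to
    the whole outputs dict (for on-the-fly pre-processing).
    """
    def fetch(spec):
        return spec(outputs) if callable(spec) else outputs[spec]

    if isinstance(needs, dict):  # keyword arguments
        return (), {name: fetch(spec) for name, spec in needs.items()}
    return tuple(fetch(spec) for spec in needs), {}  # positional arguments


outputs = {"logits": [0.2, 0.8], "labels": [0, 1]}

# Positional form: metric(outputs["logits"], outputs["labels"])
args, kwargs = resolve_needs(["logits", "labels"], outputs)

# Keyword form with pre-processing: metric(preds=..., targets=...)
args2, kwargs2 = resolve_needs(
    {"preds": lambda out: [round(x) for x in out["logits"]], "targets": "labels"},
    outputs,
)
```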

compute_per_training_step : bool, default=True

Ignored for torchmetrics.Metric instances, which always handle per-step updates internally. For other metrics (e.g. sklearn metrics or custom functions), whether to compute the metrics at each training step (batch) or only at the end of the epoch. If True, metrics are computed at each training step and averaged at the end of the epoch. This is useful for metrics that can be computed batch-wise (e.g. accuracy) or for efficiency. If False, all needed outputs are collected and the metric is computed only once, ensuring exact results but requiring more memory.
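The difference between the two modes can be illustrated in plain Python (hypothetical accuracy function). Note that averaging per-batch values is only approximate when batch sizes differ, whereas computing once on all collected outputs is exact:

```python
def accuracy(preds, targets):
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


# Two batches of unequal size.
batches = [
    ([1, 0, 1], [1, 1, 1]),  # batch accuracy: 2/3
    ([0], [0]),              # batch accuracy: 1.0
]

# compute_per_training_step=True: per-batch values averaged at epoch end.
per_step = sum(accuracy(p, t) for p, t in batches) / len(batches)

# compute_per_training_step=False: exact value on all collected outputs.
all_preds = [x for p, _ in batches for x in p]
all_targets = [x for _, t in batches for x in t]
exact = accuracy(all_preds, all_targets)
# per_step is (2/3 + 1) / 2 = 5/6, while exact is 3/4.
```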

compute_per_val_step : bool, default=False

Same as compute_per_training_step but for validation.

compute_per_test_step : bool, default=False

Same as compute_per_training_step but for test.

compute_on_cpu : bool, default=False

Whether to move the collected outputs to CPU for metrics computation. This is useful to avoid GPU memory issues when using metrics that require all outputs to be in memory at once. If False, outputs are kept on the original device (GPU or CPU).

every_n_train_steps : int or None, default=1

Frequency (in training steps) to compute and log metrics during training. If 0 or None, metrics are not computed during training. This is mutually exclusive with every_n_train_epochs.

every_n_train_epochs : int or None, default=None

Frequency (in epochs) to compute and log metrics during training. If 0 or None, metrics are not computed during training. This is mutually exclusive with every_n_train_steps.

every_n_val_epochs : int or None, default=1

Frequency (in epochs) to compute and log metrics during validation. If 0 or None, metrics are not computed during validation.

on_test_end : bool, default=False

Whether to compute and log metrics at the end of the test.

prog_bar : bool, default=True

Whether to display the metrics in the progress bar.

Notes

We assume that the model’s step methods (training_step, validation_step, test_step) return a dict of outputs that contain all the necessary information to compute the metrics. Some keys in this dict should match those specified in the needs argument. The values should be tensors or numpy arrays.
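A step method honoring this contract might look like the following sketch. The module here is a plain stand-in for a LightningModule, and `model`/`criterion` are hypothetical names; the only point is the shape of the returned dict, whose keys must match those referenced in needs (or in the metric signatures).

```python
class DummyModule:
    """Minimal stand-in for a LightningModule (sketch, not the real API)."""

    def model(self, x):  # hypothetical forward pass
        return [v * 2 for v in x]

    def criterion(self, logits, y):  # hypothetical loss
        return sum(abs(l, ) if False else abs(l - t) for l, t in zip(logits, y))

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        # Keys below must match those referenced in `needs` / metric signatures.
        return {"loss": loss, "preds": logits, "targets": y}


outputs = DummyModule().training_step(([1.0, 2.0], [2.0, 4.0]), 0)
# outputs has keys "loss", "preds", "targets"
```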

If scikit-learn metrics or custom functions are used and compute_per_(training|val|test)_step is False, the entire outputs needed for metric computations must fit in memory (either CPU or GPU). We advise using compute_on_cpu=True in this case to avoid GPU memory issues.

We recommend using torchmetrics.Metric instances whenever possible, as they handle memory efficiently in a distributed fashion.

Examples

A simple use-case for classification metrics during training and validation. We assume that the model’s training_step and validation_step return the logits as “preds” and targets as “targets” in their outputs dictionary (matching the signatures of the metrics):

>>> from nidl.callbacks import MetricsCallback
>>> from torchmetrics import Accuracy, F1Score
>>> metrics = {"acc": Accuracy(task="binary"), "f1": F1Score(task="binary")}
>>> metrics_callback = MetricsCallback(
...     metrics=metrics,
...     every_n_train_epochs=1,
...     every_n_val_epochs=1,
...     on_test_end=True,
... )

Another use-case for self-supervised metrics during training only. We assume that the model’s training_step returns the embeddings “Z1” and “Z2” in its outputs dictionary and the metrics require keyword arguments “z1” and “z2”:

>>> from nidl.callbacks import MetricsCallback
>>> from nidl.metrics.ssl import alignment_score
>>> metrics = {"alignment": alignment_score}
>>> metrics_callback = MetricsCallback(
...     metrics=metrics,
...     needs={
...         "alignment": {"z1": "Z1", "z2": "Z2"},
...     },
...     every_n_train_epochs=1,
...     every_n_val_epochs=None,
...     on_test_end=False,
... )
__init__(metrics, needs=None, compute_per_training_step=True, compute_per_val_step=False, compute_per_test_step=False, compute_on_cpu=False, every_n_train_steps=1, every_n_train_epochs=None, every_n_val_epochs=1, on_test_end=False, prog_bar=True)[source]
compute_metrics_and_log(trainer, pl_module, collector, on_step, on_epoch, reset=True)[source]

Compute and log metrics using the collected outputs.

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the test batch ends.

on_test_epoch_end(trainer, pl_module)[source]

Called when the test epoch ends.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Called when the train batch ends.

Note:

The value outputs["loss"] here will be the loss returned from training_step, normalized w.r.t. accumulate_grad_batches.

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the validation batch ends.

on_validation_epoch_end(trainer, pl_module)[source]

Called when the val epoch ends.

Examples using nidl.callbacks.MetricsCallback

Visualization of metrics during training of PyTorch-Lightning models