Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.
nidl.callbacks.MetricsCallback¶
- class nidl.callbacks.MetricsCallback(metrics, needs=None, compute_per_training_step=True, compute_per_val_step=False, compute_per_test_step=False, compute_on_cpu=False, every_n_train_steps=1, every_n_train_epochs=None, every_n_val_epochs=1, on_test_end=False, prog_bar=True)[source]¶
Bases: Callback
Callback to compute and log metrics during training, validation and test of a PL model.
This callback will:
Collect the model’s outputs after each training/validation/test step.
Compute the required metrics on the collected outputs, either batch-wise or epoch-wise (depending on the use-case).
Log the metrics after each iteration or epoch.
It handles multi-GPU / distributed training setups, and it is compatible with torchmetrics metrics, scikit-learn metrics, and custom metric functions.
- Parameters:
- metrics : dict[str, Callable] or list[Callable] or Callable
A list of metrics (callables, including torchmetrics.Metric and sklearn metrics) to be computed on the model's outputs. If a dict is provided, the keys are used as the metric names when logging. If a Callable or list[Callable] is provided, the metric names are inferred from the function or class names.
- needs : NeedsType, optional
A mapping defining which outputs from the model are needed to compute the metrics.
It can be either:
- None: the needed arguments are inferred automatically from the metric signatures and parsed from the model outputs. An error is raised if the required arguments are not found in the outputs or if there are ambiguities.
- list[str | Callable] or dict[str, str | Callable]: applies to all metrics. For example, positional arguments:
["logits", "labels"]
if the metric needs preds and targets as positional arguments, and these are found in the model outputs under the keys "logits" and "labels". Or keyword arguments:
{ "preds": lambda outputs: outputs["logits"].softmax(dim=-1), "targets": "labels", }
if the metric needs preds and targets as keyword arguments, and these are found in the model outputs under the keys "logits" (with pre-processing required) and "labels" (used as-is), respectively.
- dict[str, list[str | Callable] | dict[str, str | Callable]]: per-metric overrides, keyed by metric name. For example:
{ "Accuracy": [ lambda outputs: outputs["logits"].softmax(dim=-1), "labels", ], "Alignment": {"z1": "Z1", "z2": "Z2"}, }
if different metrics need different arguments from the outputs. The same logic applies per metric as above.
- compute_per_training_step : bool, default=True
Ignored for torchmetrics.Metric instances, which always handle per-step updates internally. For other metrics (e.g. sklearn metrics or custom functions), whether to compute the metrics at each training step (batch) or only at the end of the epoch. If True, metrics are computed at each training step and averaged at the end of the epoch. This is useful for metrics that can be computed batch-wise (e.g. accuracy) or for efficiency. If False, all needed outputs are collected and the metric is computed only once, ensuring exact results but requiring more memory.
- compute_per_val_step : bool, default=False
Same as compute_per_training_step, but for validation.
- compute_per_test_step : bool, default=False
Same as compute_per_training_step, but for test.
- compute_on_cpu : bool, default=False
Whether to move the collected outputs to CPU for metrics computation. This is useful to avoid GPU memory issues when using metrics that require all outputs to be in memory at once. If False, outputs are kept on their original device (GPU or CPU).
- every_n_train_steps : int or None, default=1
Frequency (in training steps) at which to compute and log metrics during training. If 0 or None, metrics are not computed during training. This is mutually exclusive with every_n_train_epochs.
- every_n_train_epochs : int or None, default=None
Frequency (in epochs) at which to compute and log metrics during training. If 0 or None, metrics are not computed during training. This is mutually exclusive with every_n_train_steps.
- every_n_val_epochs : int or None, default=1
Frequency (in epochs) at which to compute and log metrics during validation. If 0 or None, metrics are not computed during validation.
- on_test_end : bool, default=False
Whether to compute and log metrics at the end of the test.
- prog_bar : bool, default=True
Whether to display the metrics in the progress bar.
Notes
We assume that the model’s step methods (training_step, validation_step, test_step) return a dict of outputs that contain all the necessary information to compute the metrics. Some keys in this dict should match those specified in the needs argument. The values should be tensors or numpy arrays.
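As a minimal sketch of this convention (plain Python, with a hypothetical `ToyModule` standing in for a real LightningModule), a step method returns a dict whose keys match those the callback will look up:

```python
# Hypothetical stand-in for a LightningModule; illustrates the expected
# shape of a step method's return value, not nidl's actual model code.
class ToyModule:
    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = [x * 2.0 for x in inputs]  # stand-in forward pass
        loss = sum((l - t) ** 2 for l, t in zip(logits, targets)) / len(targets)
        # Keys "preds" and "targets" are what a needs spec would refer to
        return {"loss": loss, "preds": logits, "targets": targets}

out = ToyModule().training_step(([1.0, 2.0], [2.0, 4.0]), 0)
```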
If scikit-learn metrics or custom functions are used and compute_per_(training|val|test)_step is False, the entire outputs needed for metric computations must fit in memory (either CPU or GPU). We advise using compute_on_cpu=True in this case to avoid GPU memory issues.
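The trade-off between per-step averaging and a single end-of-epoch computation can be illustrated in pure Python: with unequal batch sizes, the mean of per-batch accuracies differs from the accuracy over all collected outputs.

```python
# Illustration of compute_per_training_step semantics for a non-torchmetrics
# metric; the accuracy function and batches are made up for the example.
def accuracy(preds, targets):
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

batches = [
    ([1, 1, 1, 0], [1, 1, 1, 1]),  # batch of 4, accuracy 0.75
    ([0, 0], [1, 0]),              # batch of 2, accuracy 0.5
]

# compute_per_training_step=True analogue: per-batch values, then their mean
per_step = sum(accuracy(p, t) for p, t in batches) / len(batches)  # 0.625

# compute_per_training_step=False analogue: collect everything, compute once
all_preds = [x for p, _ in batches for x in p]
all_targets = [x for _, t in batches for x in t]
exact = accuracy(all_preds, all_targets)  # 4/6, the exact value
```

The exact variant is correct but requires keeping all outputs in memory, which is why `compute_on_cpu=True` is advised in that case.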
We recommend using torchmetrics.Metric instances whenever possible, as they handle memory efficiently in a distributed fashion.
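The reason torchmetrics.Metric instances are memory-efficient is their update/compute pattern: only small running states are kept, never the full output history. A simplified pure-Python sketch of that pattern (RunningAccuracy is illustrative only, not a torchmetrics class):

```python
# Sketch of the update/compute pattern: per-batch update of compact state,
# one final compute; no per-sample outputs are retained.
class RunningAccuracy:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(targets)

    def compute(self):
        return self.correct / self.total

metric = RunningAccuracy()
metric.update([1, 0, 1], [1, 1, 1])
metric.update([1], [1])
result = metric.compute()  # 3 correct out of 4
```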
Examples
A simple use-case for classification metrics during training and validation. We assume that the model’s training_step and validation_step return the logits as “preds” and targets as “targets” in their outputs dictionary (matching the signatures of the metrics):
>>> from nidl.callbacks import MetricsCallback
>>> from torchmetrics import Accuracy, F1Score
>>> metrics = {"acc": Accuracy(task="binary"), "f1": F1Score(task="binary")}
>>> metrics_callback = MetricsCallback(
...     metrics=metrics,
...     every_n_train_epochs=1,
...     every_n_val_epochs=1,
...     on_test_end=True,
... )
Another use-case for self-supervised metrics during training only. We assume that the model’s training_step returns the embeddings “Z1” and “Z2” in its outputs dictionary and the metrics require keyword arguments “z1” and “z2”:
>>> from nidl.callbacks import MetricsCallback
>>> from nidl.metrics.ssl import alignment_score
>>> metrics = {"alignment": alignment_score}
>>> metrics_callback = MetricsCallback(
...     metrics=metrics,
...     needs={
...         "alignment": {"z1": "Z1", "z2": "Z2"},
...     },
...     every_n_train_epochs=1,
...     every_n_val_epochs=None,
...     on_test_end=False,
... )
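To make the needs resolution described above concrete, here is an illustrative pure-Python sketch (resolve_needs is hypothetical, not part of the nidl API) of how a list spec maps model outputs to positional arguments and a dict spec to keyword arguments, with callables applied to the outputs for pre-processing:

```python
# Hypothetical helper mirroring the documented needs semantics: each entry
# is either a key into the outputs dict or a callable applied to it.
def resolve_needs(spec, outputs):
    def pick(item):
        return item(outputs) if callable(item) else outputs[item]

    if isinstance(spec, dict):
        # dict spec -> keyword arguments for the metric
        return (), {name: pick(item) for name, item in spec.items()}
    # list spec -> positional arguments for the metric
    return tuple(pick(item) for item in spec), {}

outputs = {"logits": [0.1, 0.9], "labels": [0, 1]}

# Positional form, e.g. metric(preds, targets)
args, _ = resolve_needs(["logits", "labels"], outputs)

# Keyword form with a pre-processing callable on "logits"
_, kwargs = resolve_needs(
    {"preds": lambda o: [round(x) for x in o["logits"]], "targets": "labels"},
    outputs,
)
```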
- __init__(metrics, needs=None, compute_per_training_step=True, compute_per_val_step=False, compute_per_test_step=False, compute_on_cpu=False, every_n_train_steps=1, every_n_train_epochs=None, every_n_val_epochs=1, on_test_end=False, prog_bar=True)[source]¶
- compute_metrics_and_log(trainer, pl_module, collector, on_step, on_epoch, reset=True)[source]¶
Compute and log metrics using the collected outputs.
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]¶
Called when the test batch ends.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]¶
Called when the train batch ends.
- Note:
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
- on_train_epoch_end(trainer, pl_module)[source]¶
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
Examples using nidl.callbacks.MetricsCallback¶
Visualization of metrics during training of PyTorch-Lightning models