
Deep learning for NeuroImaging in Python.

Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.

class nidl.callbacks.model_probing.ClassificationProbingCallback(train_dataloader, test_dataloader, probe, probe_name=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True)

Bases: ModelProbing

Perform classification on top of an embedding model.

Concretely, this callback:

  1. Embeds the training and test data through the torch model.

  2. Fits the classification probe on the embedded training data.

  3. Logs the main classification metrics of the probe on the embedded test data:

    • precision (macro)

    • recall (macro)

    • f1-score (weighted and macro)

    • accuracy (global)

    • balanced accuracy
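The three steps can be sketched outside of Lightning with plain PyTorch and scikit-learn. This is a minimal illustration under stated assumptions, not nidl's implementation: the helpers `embed` and `probe_once`, the toy two-class data and the `torch.nn.Linear` encoder are all hypothetical, and the sketch assumes the model maps a batch `X` directly to a 2D embedding.

```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def embed(model, dataloader):
    """Step 1 (hypothetical helper): embed every (X, y) batch and stack the results."""
    was_training = model.training
    model.eval()  # disable dropout / batch-norm updates while embedding
    features, labels = [], []
    for X, y in dataloader:
        features.append(model(X).cpu().numpy())
        labels.append(y.cpu().numpy())
    model.train(was_training)  # restore the original mode
    return np.concatenate(features), np.concatenate(labels)


def probe_once(model, train_dataloader, test_dataloader, probe):
    """Steps 1-2 (hypothetical helper): fit the probe on train, predict on test."""
    X_train, y_train = embed(model, train_dataloader)
    X_test, y_test = embed(model, test_dataloader)
    probe.fit(X_train, y_train)
    return probe.predict(X_test), y_test


# Toy data: two well-separated classes in 8 dimensions, shuffled.
torch.manual_seed(0)
X = torch.cat([torch.randn(50, 8) - 2.0, torch.randn(50, 8) + 2.0])
y = torch.cat([torch.zeros(50, dtype=torch.long), torch.ones(50, dtype=torch.long)])
perm = torch.randperm(100)
X, y = X[perm], y[perm]

train_dl = DataLoader(TensorDataset(X[:80], y[:80]), batch_size=16)
test_dl = DataLoader(TensorDataset(X[80:], y[80:]), batch_size=16)
encoder = torch.nn.Linear(8, 4)  # stand-in for the embedding model

y_pred, y_true = probe_once(encoder, train_dl, test_dl, LogisticRegression())
```

The resulting `y_pred` and `y_true` arrays are what the step 3 metrics would be computed from.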

Please refer to the User Guide for more details on the reported classification metrics.
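To make the metric list concrete, here is a hedged sketch computing each listed metric with the `sklearn.metrics` functions and prefixing the keys with the probe name, following the convention given for the `probe_name` parameter. The helper `probe_metrics` and the metric key names (`precision_macro`, `f1_weighted`, ...) are hypothetical; the keys nidl actually logs may differ.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)


def probe_metrics(probe, y_true, y_pred, probe_name=None):
    """Hypothetical helper: compute the listed metrics as "<prefix>/<metric_name>"."""
    # Fall back to the probe's class name when no explicit name is given.
    prefix = probe_name if probe_name is not None else type(probe).__name__
    scores = {
        "precision_macro": precision_score(y_true, y_pred, average="macro"),
        "recall_macro": recall_score(y_true, y_pred, average="macro"),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
        "f1_weighted": f1_score(y_true, y_pred, average="weighted"),
        "accuracy": accuracy_score(y_true, y_pred),
        "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
    }
    return {f"{prefix}/{name}": value for name, value in scores.items()}


y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]
metrics = probe_metrics(LogisticRegression(), y_true, y_pred)
```

Note that balanced accuracy is the unweighted mean of per-class recall, so it coincides with macro recall; it is reported separately because it is the usual headline metric for imbalanced labels.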

Parameters:

train_dataloader : torch.utils.data.DataLoader

Training dataloader yielding batches of the form (X, y). The batches are embedded and then used to train the probe.

test_dataloader : torch.utils.data.DataLoader

Test dataloader yielding batches of the form (X, y). The batches are embedded and then used to evaluate the probe.

probe : sklearn.base.BaseEstimator

The scikit-learn classifier trained on the embedded training data.

probe_name : str or None, default=None

Name of the probe displayed when logging the results. Each metric is logged as <probe_name>/<metric_name>. If None, the probe's class name is used instead, i.e. <probe_class_name>/<metric_name>.

every_n_train_epochs : int or None, default=1

Frequency, in training epochs, at which to run the probing. Disabled if None.

every_n_val_epochs : int or None, default=None

Frequency, in validation epochs, at which to run the probing. Disabled if None.

on_test_epoch_start : bool, default=False

Whether to run the probing at the start of the test epoch.

on_test_epoch_end : bool, default=False

Whether to run the probing at the end of the test epoch.

prog_bar : bool, default=True

Whether to display the metrics in the progress bar.
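The `every_n_train_epochs` and `every_n_val_epochs` options act as periodic triggers. Below is a minimal sketch of one plausible rule, assuming 0-indexed epoch counters and that probing fires after every n-th completed epoch; the function `should_probe` is hypothetical and nidl's exact condition is not documented on this page.

```python
from typing import Optional


def should_probe(epoch: int, every_n_epochs: Optional[int]) -> bool:
    """Hypothetical trigger: True after every n-th completed (0-indexed) epoch."""
    if every_n_epochs is None:  # None disables this trigger
        return False
    return (epoch + 1) % every_n_epochs == 0


# With every_n_epochs=1 probing would run after each epoch;
# with every_n_epochs=2 it would run after epochs 1, 3 and 5 (0-indexed).
probed = [epoch for epoch in range(6) if should_probe(epoch, 2)]
```

Under this reading, the default every_n_train_epochs=1 probes after every training epoch, while every_n_val_epochs=None leaves validation-time probing off unless explicitly enabled.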

log_metrics(pl_module, y_pred, y_true)

Log the metrics given the predictions and the true labels.

Examples

Model probing callback of embedding estimators



© 2025, nidl developers