
Deep learning for NeuroImaging in Python.

Note

This page is reference documentation. It only describes the class signature, not how to use it. Please refer to the gallery for the big picture.

class nidl.callbacks.model_probing.RegressionProbingCallback(train_dataloader, test_dataloader, probe, probe_name=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True)

Bases: ModelProbing

Perform regression on top of an embedding model.

Concretely, this callback:

  1. Embeds the input data through the estimator.

  2. Fits the regression probe on the embedded data.

  3. Logs the main regression metrics, including:

    • mean absolute error

    • median absolute error

    • root mean squared error

    • mean squared error

    • R² score

    • Pearson’s r

    • explained variance score

    If multiple regressors (target variables) are given (multivariate regression), each metric is computed per regressor and then averaged over regressors.
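The three steps above can be sketched outside of any training loop. This is a minimal illustration, not the callback's actual implementation: the `embed` function is a hypothetical stand-in for the embedding model, the data are synthetic, `Ridge` is just one possible scikit-learn probe, and the metric key names are chosen here for readability.

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.linear_model import Ridge
from sklearn.metrics import (
    explained_variance_score,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    r2_score,
)


def embed(X, W):
    """Hypothetical stand-in for the embedding model (fixed projection + tanh)."""
    return np.tanh(X @ W)


def regression_metrics(y_true, y_pred):
    """Compute each metric per regressor (column), then average over regressors."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Treat a single regressor as a one-column multivariate problem.
    if y_true.ndim == 1:
        y_true = y_true[:, None]
    if y_pred.ndim == 1:
        y_pred = y_pred[:, None]
    mse = mean_squared_error(y_true, y_pred, multioutput="raw_values")
    pearson = [
        pearsonr(y_true[:, j], y_pred[:, j])[0] for j in range(y_true.shape[1])
    ]
    # scikit-learn averages uniformly over outputs by default.
    return {
        "mean_absolute_error": float(mean_absolute_error(y_true, y_pred)),
        "median_absolute_error": float(median_absolute_error(y_true, y_pred)),
        "root_mean_squared_error": float(np.sqrt(mse).mean()),
        "mean_squared_error": float(mse.mean()),
        "r2": float(r2_score(y_true, y_pred)),
        "pearson_r": float(np.mean(pearson)),
        "explained_variance": float(explained_variance_score(y_true, y_pred)),
    }


# Synthetic data: 20 input features, 8-dimensional embedding, 2 regressors.
rng = np.random.default_rng(0)
W = rng.normal(size=(20, 8))
B = rng.normal(size=(8, 2))
X_train = rng.normal(size=(200, 20))
X_test = rng.normal(size=(50, 20))
y_train = embed(X_train, W) @ B + 0.1 * rng.normal(size=(200, 2))
y_test = embed(X_test, W) @ B + 0.1 * rng.normal(size=(50, 2))

# 1. Embed the inputs, 2. fit the probe on the embeddings, 3. compute the metrics.
probe = Ridge(alpha=1.0)
probe.fit(embed(X_train, W), y_train)
metrics = regression_metrics(y_test, probe.predict(embed(X_test, W)))
```

Note that the averaged root mean squared error is the mean of the per-regressor RMSE values, which in general differs from the square root of the averaged MSE.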

Parameters:

train_dataloader : torch.utils.data.DataLoader

Training dataloader yielding batches of the form (X, y). The inputs are embedded and the probe is trained on the resulting embeddings.

test_dataloader : torch.utils.data.DataLoader

Test dataloader yielding batches of the form (X, y). The inputs are embedded and the probe is tested on the resulting embeddings.

probe : sklearn.base.BaseEstimator

The scikit-learn regressor to be trained on the embedding.

probe_name : str or None, default=None

Name of the probe displayed when logging the results. It will appear as <probe_name>/<metric_name> for each metric. If None, <probe_class_name>/<metric_name> is displayed.
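The naming rule can be illustrated with a small helper. This is only a sketch: `metric_key` and `DummyProbe` are hypothetical names, and it assumes that `<probe_class_name>` refers to the Python class name of the probe object.

```python
def metric_key(probe, metric_name, probe_name=None):
    """Build the logged name '<prefix>/<metric_name>' (hypothetical helper)."""
    # Assumption: the fallback prefix is the probe's Python class name.
    prefix = probe_name if probe_name is not None else type(probe).__name__
    return f"{prefix}/{metric_name}"


class DummyProbe:
    """Placeholder standing in for a scikit-learn regressor."""


default_key = metric_key(DummyProbe(), "r2")
custom_key = metric_key(DummyProbe(), "r2", probe_name="age_probe")
```

Here `default_key` is `"DummyProbe/r2"` and `custom_key` is `"age_probe/r2"`.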

every_n_train_epochs : int or None, default=1

Interval, in training epochs, at which the probing is run. Disabled if None.

every_n_val_epochs : int or None, default=None

Interval, in validation epochs, at which the probing is run. Disabled if None.
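One plausible reading of the two `every_n_*_epochs` parameters is sketched below. The helper `should_probe` is hypothetical, and the counting convention (zero-based epoch index, probing triggered once every n-th epoch has completed) is an assumption that this page does not confirm.

```python
def should_probe(epoch, every_n_epochs):
    """Return True if probing would run at this zero-based epoch (assumed rule)."""
    if every_n_epochs is None:
        # None disables probing for this stage.
        return False
    return (epoch + 1) % every_n_epochs == 0


# With every_n_epochs=2, probing would run after the 2nd, 4th, 6th epoch, ...
probed_epochs = [epoch for epoch in range(6) if should_probe(epoch, 2)]
```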

on_test_epoch_start : bool, default=False

Whether to run the probing at the start of the test epoch.

on_test_epoch_end : bool, default=False

Whether to run the probing at the end of the test epoch.

prog_bar : bool, default=True

Whether to display the metrics in the progress bar.

log_metrics(pl_module, y_pred, y_true)

Log the metrics given the predictions and the true labels.


© 2025, nidl developers