
Deep learning for NeuroImaging in Python.

Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.

class nidl.callbacks.model_probing.ModelProbing(train_dataloader: DataLoader, test_dataloader: DataLoader, probe_name: str | None = None, every_n_train_epochs: int | None = 1, every_n_val_epochs: int | None = None, on_test_epoch_start: bool = False, on_test_epoch_end: bool = False, prog_bar: bool = True)

Bases: ABC, Callback

Callback implementing the basic logic for probing an embedding model:

  1. Embeds the input data (training and test) through the estimator using its transform_step method.

  2. Trains the probe on the training embeddings.

  3. Tests the probe on the test embeddings and logs the metrics.

This callback is abstract: subclasses must implement the fit, predict and log_metrics methods (e.g. for Ridge regression, KNN regression, logistic regression, KNN classification, …).
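A minimal sketch of the subclassing contract, assuming only numpy. To keep it runnable without nidl or Lightning, it defines a hypothetical stand-in base class `ProbeInterface` instead of importing `ModelProbing`; `KNNClassificationProbe` is likewise an illustrative name, not part of nidl. Its `log_metrics` stores the accuracy in a dict, whereas a real subclass would more likely call `pl_module.log`.

```python
from abc import ABC, abstractmethod

import numpy as np


class ProbeInterface(ABC):
    """Hypothetical stand-in for the abstract methods of ModelProbing."""

    @abstractmethod
    def fit(self, X, y):
        """Fit the probe on training embeddings X and labels y."""

    @abstractmethod
    def predict(self, X):
        """Return predictions for embeddings X."""

    @abstractmethod
    def log_metrics(self, pl_module, y_pred, y_true):
        """Log metrics comparing y_pred with y_true."""


class KNNClassificationProbe(ProbeInterface):
    """Illustrative k-nearest-neighbour classification probe (numpy only)."""

    def __init__(self, n_neighbors=3):
        self.n_neighbors = n_neighbors
        self.logged = {}

    def fit(self, X, y):
        # Fitting a KNN probe just stores the training embeddings and labels.
        self.X_train_ = np.asarray(X, dtype=float)
        self.y_train_ = np.asarray(y)
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        # Squared Euclidean distances, shape (n_test, n_train).
        d2 = ((X[:, None, :] - self.X_train_[None, :, :]) ** 2).sum(axis=-1)
        nn_idx = np.argsort(d2, axis=1)[:, : self.n_neighbors]
        preds = []
        for neighbour_labels in self.y_train_[nn_idx]:
            labels, counts = np.unique(neighbour_labels, return_counts=True)
            preds.append(labels[np.argmax(counts)])  # majority vote
        return np.array(preds)

    def log_metrics(self, pl_module, y_pred, y_true):
        # A real callback would call pl_module.log(...); storing the value
        # keeps this sketch runnable without Lightning.
        correct = np.asarray(y_pred) == np.asarray(y_true)
        self.logged["accuracy"] = float(np.mean(correct))


# Two well-separated clusters of 2-D "embeddings".
X_train = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
y_train = np.array([0, 0, 1, 1])
X_test = np.array([[0.05, 0.0], [4.9, 5.0]])
y_test = np.array([0, 1])

probe = KNNClassificationProbe(n_neighbors=1).fit(X_train, y_train)
y_pred = probe.predict(X_test)
probe.log_metrics(None, y_pred, y_test)
print(probe.logged)  # {'accuracy': 1.0}
```

Any of the other probes listed above (Ridge, logistic regression, …) would follow the same pattern: only the three abstract methods change.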

Parameters:

train_dataloader : torch.utils.data.DataLoader

Training dataloader yielding batches of the form (X, y), used to compute the training embeddings and train the probe.

test_dataloader : torch.utils.data.DataLoader

Test dataloader yielding batches of the form (X, y), used to compute the test embeddings and evaluate the probe.

probe_name : str or None, default=None

Name of the probe displayed when logging the results. It will appear as <probe_name>/<metric_name> for each metric. If None, only <metric_name> is displayed.

every_n_train_epochs : int or None, default=1

Run the linear probing every n training epochs, where n is this value. Disabled if None.

every_n_val_epochs : int or None, default=None

Run the linear probing every n validation epochs, where n is this value. Disabled if None.

on_test_epoch_start : bool, default=False

Whether to run the linear probing at the start of the test epoch.

on_test_epoch_end : bool, default=False

Whether to run the linear probing at the end of the test epoch.

prog_bar : bool, default=True

Whether to display the metrics in the progress bar.
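The two helpers below illustrate one plausible reading of `probe_name` and the `every_n_*` parameters. Both `metric_key` and `should_probe` are hypothetical names, and the rule that probing fires after epochs n, 2n, 3n, … (counting from 1, while Lightning's epoch index starts at 0) is an assumption; the nidl source is the authority on the exact condition.

```python
def metric_key(metric_name, probe_name=None):
    """Return '<probe_name>/<metric_name>', or '<metric_name>' if probe_name is None."""
    if probe_name is None:
        return metric_name
    return f"{probe_name}/{metric_name}"


def should_probe(epoch_idx, every_n_epochs):
    """Decide whether to probe after the epoch with 0-based index epoch_idx.

    Assumption: probing fires after epochs n, 2n, 3n, ... (1-based), and
    every_n_epochs=None disables it, mirroring the documented defaults.
    """
    if every_n_epochs is None:
        return False
    return (epoch_idx + 1) % every_n_epochs == 0


print(metric_key("accuracy", probe_name="knn"))  # knn/accuracy
print(metric_key("accuracy"))                    # accuracy
# With every_n_train_epochs=3, probing runs after 0-based epochs 2, 5 and 8.
print([e for e in range(10) if should_probe(e, 3)])  # [2, 5, 8]
```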

extract_features(pl_module, dataloader)

Extract features from a dataloader with the BaseEstimator.

By default, it applies the transform_step logic to each batch to obtain the embeddings along with their labels. The input dataloader should yield batches of the form (X, y), where X is the input data and y is the label.

Parameters:

pl_module : BaseEstimator

The BaseEstimator module that implements the ‘transform_step’.

dataloader : torch.utils.data.DataLoader

The dataloader to extract features from. It should yield batches of the form (X, y) where X is the input data and y is the label.

Returns:

tuple of (X, y)

Tuple of numpy arrays (X, y) where X is the extracted features and y is the corresponding labels.
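A minimal sketch of this extraction loop, assuming batches that are already numpy arrays. The `transform` argument is a hypothetical stand-in for the estimator's transform_step; with real torch tensors, each output would also need converting with `.detach().cpu().numpy()`, ideally under `torch.no_grad()`.

```python
import numpy as np


def extract_features(transform, dataloader):
    """Embed every (X, y) batch and stack the results into numpy arrays.

    `transform` is a hypothetical stand-in for the estimator's
    transform_step; `dataloader` is any iterable of (X, y) batches.
    """
    features, labels = [], []
    for X_batch, y_batch in dataloader:
        features.append(np.asarray(transform(X_batch)))
        labels.append(np.asarray(y_batch))
    # Concatenate along the sample axis: (n_samples, n_features), (n_samples,)
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


# Two batches of 2 samples with 3 input features each.
batches = [
    (np.ones((2, 3)), np.array([0, 1])),
    (2 * np.ones((2, 3)), np.array([1, 0])),
]
# Toy "encoder": sum the inputs into a single-feature embedding.
X, y = extract_features(lambda xb: xb.sum(axis=1, keepdims=True), batches)
print(X.shape, y.shape)  # (4, 1) (4,)
```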

abstract fit(X, y)

Fit the probe on the embeddings and labels of the training data.

linear_probing(pl_module: BaseEstimator)

Perform the linear probing on the given estimator.

This method performs the following steps:

  1. Extracts the features from the training and test dataloaders.

  2. Fits the probe on the training features and labels.

  3. Makes predictions on the test features.

  4. Computes and logs the metrics.
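These four steps can be sketched end to end with a closed-form ridge regression probe, one of the probe types mentioned in the class description. `RidgeProbe`, this standalone `linear_probing` function and the identity `embed` function are illustrative stand-ins using numpy only; the documented method logs the metrics through Lightning rather than returning them.

```python
import numpy as np


class RidgeProbe:
    """Closed-form ridge regression probe (illustrative, not nidl's API)."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    @staticmethod
    def _add_bias(X):
        # Append a column of ones so the model learns an intercept.
        return np.hstack([X, np.ones((X.shape[0], 1))])

    def fit(self, X, y):
        Xb = self._add_bias(X)
        penalty = self.alpha * np.eye(Xb.shape[1])
        penalty[-1, -1] = 0.0  # do not penalise the intercept
        # Solve (Xb^T Xb + penalty) w = Xb^T y.
        self.coef_ = np.linalg.solve(Xb.T @ Xb + penalty, Xb.T @ y)
        return self

    def predict(self, X):
        return self._add_bias(X) @ self.coef_


def linear_probing(embed, train_batches, test_batches, probe):
    """Run the four documented steps; `embed` stands in for transform_step."""

    def extract(batches):
        X = np.concatenate([embed(xb) for xb, _ in batches])
        y = np.concatenate([yb for _, yb in batches])
        return X, y

    X_train, y_train = extract(train_batches)  # 1) extract features
    X_test, y_test = extract(test_batches)
    probe.fit(X_train, y_train)                # 2) fit the probe
    y_pred = probe.predict(X_test)             # 3) predict on the test set
    # 4) compute the metrics (returned here instead of logged)
    return {"mae": float(np.mean(np.abs(y_pred - y_test)))}


rng = np.random.default_rng(0)
true_w = np.array([1.0, -2.0, 0.5])


def make_batches(n_batches, batch_size=8):
    # Noiseless linear target y = X @ true_w + 3, so a linear probe fits it.
    batches = []
    for _ in range(n_batches):
        xb = rng.normal(size=(batch_size, 3))
        batches.append((xb, xb @ true_w + 3.0))
    return batches


metrics = linear_probing(lambda xb: xb, make_batches(4), make_batches(2),
                         RidgeProbe(alpha=1e-6))
print(metrics)  # MAE close to zero for this noiseless target
```

The tiny `alpha` keeps the solve well conditioned while barely biasing the coefficients; on real, noisy embeddings it would typically be tuned.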

Parameters:

pl_module : BaseEstimator

The BaseEstimator module that implements the ‘transform_step’.

Raises:

ValueError : If the pl_module does not inherit from BaseEstimator or from TransformerMixin.
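The check described under Raises can be sketched as a pair of `isinstance` tests. `BaseEstimator` and `TransformerMixin` below are empty stand-ins defined only so the snippet runs on its own; they are not nidl's classes, and `check_probeable` and its error message are hypothetical.

```python
class BaseEstimator:
    """Hypothetical stand-in for nidl's BaseEstimator."""


class TransformerMixin:
    """Hypothetical stand-in for nidl's TransformerMixin."""

    def transform_step(self, batch):
        raise NotImplementedError


def check_probeable(pl_module):
    """Raise ValueError unless pl_module inherits from both base classes."""
    if not (isinstance(pl_module, BaseEstimator)
            and isinstance(pl_module, TransformerMixin)):
        raise ValueError(
            "pl_module must inherit from BaseEstimator and TransformerMixin, "
            f"got {type(pl_module).__name__}."
        )


class Encoder(BaseEstimator, TransformerMixin):
    pass


check_probeable(Encoder())            # passes silently
try:
    check_probeable(BaseEstimator())  # missing TransformerMixin
except ValueError as err:
    print(err)
```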

abstract log_metrics(pl_module, y_pred, y_true)

Log the metrics given the predictions and the true labels.

on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the training loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

on_validation_epoch_end(trainer, pl_module)

Called when the val epoch ends.

abstract predict(X)

Predict with the fitted probe on new data X.


© 2025, nidl developers