Note

This page is reference documentation: it describes only the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.callbacks.ModelProbingCV

class nidl.callbacks.ModelProbingCV(dataloader, probe, scoring=None, cv=None, n_jobs=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True, prefix_score='')[source]

Bases: Callback

Callback to probe the representation of an embedding estimator on a dataset using cross-validation.

It has the following logic:

  1. Embeds the input data through the estimator using transform_step.

  2. For each CV split, train the probe on the training embedding split and make predictions on the test embedding split.

  3. Compute and log the scores computed between the true and predicted targets for each CV split.

The probing can be performed at the end of training epochs, validation epochs, and/or at the start/end of the test epoch.

The metrics logged depend on the scoring parameter:

  • If a single score is provided, it logs fold{i}/test_score for each fold i.

  • If multiple scores are provided, it logs each score with its name, such as fold{i}/test_accuracy or fold{i}/test_auc for each fold i.

Eventually, a prefix_score can be added to the score names when logging, such as ridge_ or logreg_ (giving fold{i}/ridge_test_accuracy or fold{i}/logreg_test_accuracy).
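The per-fold names mirror the keys returned by scikit-learn's cross_validate; a minimal sketch of how multi-metric results map to logged names (the toy data and probe here are illustrative, not part of nidl):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

# Toy embeddings and binary targets standing in for extracted features.
rng = np.random.RandomState(0)
X = rng.randn(100, 8)
y = (X[:, 0] > 0).astype(int)

# Multi-metric scoring: cross_validate returns one array per metric,
# keyed "test_<name>", with one entry per fold.
scores = cross_validate(LogisticRegression(), X, y, cv=3,
                        scoring={"accuracy": "accuracy", "auc": "roc_auc"})

# The callback logs each fold entry under "fold{i}/test_<name>",
# optionally prefixed (e.g. "fold0/logreg_test_accuracy").
logged = {f"fold{i}/test_accuracy": v
          for i, v in enumerate(scores["test_accuracy"])}
```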

Parameters:
dataloader: torch.utils.data.DataLoader

Dataloader yielding batches in the form (X, y) for further embedding and cross-validation of the probe.

probe: sklearn.base.BaseEstimator

The probe model to be trained on the embedding. It must implement fit and predict methods.

scoring: str, callable, list, tuple, or dict, default=None

Strategy to evaluate the performance of the probe across cross-validation splits. The scores are logged into the LightningModule during training/validation/test according to the configuration of the callback.

If scoring represents a single score, one can use:

  • a single string naming a scorer;

  • a callable that returns a single value.

If scoring represents multiple scores, one can use:

  • a list or tuple of unique strings;

  • a callable returning a dictionary where the keys are the metric names and the values are the metric scores;

  • a dictionary with metric names as keys and callables as values.
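These forms follow scikit-learn's scoring conventions; a short sketch of each accepted shape (nothing here is nidl-specific):

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score, make_scorer

# Single score: one string naming a built-in scorer.
scoring = "accuracy"

# Multiple scores: a list or tuple of unique strings...
scoring = ["accuracy", "roc_auc"]

# ...or a dict mapping metric names to callables.
scoring = {
    "acc": make_scorer(accuracy_score),
    "bal_acc": make_scorer(balanced_accuracy_score),
}
```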

cv: int, cross-validation generator or an iterable, default=None

Determines the cross-validation splitting strategy. Possible inputs for cv are:

  • None, to use the default 5-fold cross validation,

  • int, to specify the number of folds in a (Stratified)KFold,

  • CV splitter,

  • An iterable yielding (train, test) splits as arrays of indices.

For int/None inputs, if the probe is a classifier and y is either binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used. These splitters are instantiated with shuffle=False so the splits will be the same across calls.
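The int/None resolution described above matches scikit-learn's check_cv helper; a small sketch (assuming the probe is a classifier in the first case):

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

y_binary = np.array([0, 1, 0, 1, 0, 1])
y_continuous = np.array([0.1, 0.7, 0.3, 0.9, 0.2, 0.8])

# Classifier + binary/multiclass target -> StratifiedKFold (shuffle=False).
cv_cls = check_cv(5, y_binary, classifier=True)

# Otherwise -> plain KFold.
cv_reg = check_cv(5, y_continuous, classifier=False)
```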

n_jobs: int, default=None

Number of jobs to run in parallel. Training the probe and computing the score are parallelized over the cross-validation splits. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.

every_n_train_epochs: int or None, default=1

Number of training epochs after which to run the probing. Disabled if None.

every_n_val_epochs: int or None, default=None

Number of validation epochs after which to run the probing. Disabled if None.

on_test_epoch_start: bool, default=False

Whether to run the linear probing at the start of the test epoch.

on_test_epoch_end: bool, default=False

Whether to run the linear probing at the end of the test epoch.

prog_bar: bool, default=True

Whether to display the metrics in the progress bar.

prefix_score: str, default=''

Prefix to add to the score name when logging, such as ridge_ or logreg_.

__init__(dataloader, probe, scoring=None, cv=None, n_jobs=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True, prefix_score='')[source]
static adapt_dataloader_for_ddp(dataloader, trainer)[source]

Wrap user dataloader with DistributedSampler if in DDP mode.
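A rough sketch of the re-wrapping idea (the helper below is illustrative, not the actual implementation): the dataset and batch size are kept, and a DistributedSampler is swapped in so each rank extracts embeddings on a disjoint shard.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

def wrap_for_ddp(dataloader, num_replicas, rank):
    # Illustrative stand-in for adapt_dataloader_for_ddp: same dataset and
    # batch size, but a DistributedSampler assigns each rank its own shard.
    sampler = DistributedSampler(dataloader.dataset, num_replicas=num_replicas,
                                 rank=rank, shuffle=False)
    return DataLoader(dataloader.dataset, batch_size=dataloader.batch_size,
                      sampler=sampler)

dataset = TensorDataset(torch.arange(8).float().unsqueeze(1), torch.arange(8))
loader = DataLoader(dataset, batch_size=2)
rank0 = wrap_for_ddp(loader, num_replicas=2, rank=0)
# With shuffle=False, rank 0 of 2 sees indices 0, 2, 4, 6.
```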

check_array(X, y)[source]

Check the input arrays for cross-validation.

Parameters:
X: np.ndarray

The input features.

y: np.ndarray

The input targets.

Returns:
tuple of (X, y)

The checked input features and targets.

cross_validate(X, y)[source]

Cross-validate the probe on the data embeddings.

extract_features(trainer, pl_module, dataloader)[source]

Extract features from a dataloader with the BaseEstimator.

By default, it uses the transform_step logic applied on each batch to get the embeddings with the labels. The input dataloader should yield batches of the form (X, y) where X is the input data and y is the label.

Parameters:
trainer: pl.Trainer

The pytorch-lightning trainer.

pl_module: BaseEstimator

The BaseEstimator module that implements the ‘transform_step’.

dataloader: torch.utils.data.DataLoader

The dataloader to extract features from. It should yield batches of the form (X, y) where X is the input data and y is the label.

Returns:
tuple of (z, y)

Tuple of numpy arrays (z, y) where z are the extracted features and y are the corresponding labels.
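A simplified sketch of this extraction loop (ToyEmbedder and the exact transform_step signature are assumptions for illustration):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

class ToyEmbedder:
    # Hypothetical stand-in for an estimator exposing `transform_step`.
    def transform_step(self, batch, batch_idx):
        x, _ = batch
        return 2.0 * x  # pretend this is a learned embedding

def extract_features(module, dataloader):
    # Accumulate per-batch embeddings and labels, then stack to numpy.
    zs, ys = [], []
    with torch.no_grad():
        for idx, (x, y) in enumerate(dataloader):
            zs.append(module.transform_step((x, y), idx).cpu().numpy())
            ys.append(y.cpu().numpy())
    return np.concatenate(zs), np.concatenate(ys)

data = TensorDataset(torch.ones(6, 3), torch.arange(6))
z, labels = extract_features(ToyEmbedder(), DataLoader(data, batch_size=2))
```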

log_metrics(pl_module, scores)[source]

Log all scores and timings returned by sklearn.model_selection.cross_validate into the LightningModule.

on_test_epoch_end(trainer, pl_module)[source]

Called when the test epoch ends.

on_test_epoch_start(trainer, pl_module)[source]

Called when the test epoch begins.

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
on_validation_epoch_end(trainer, pl_module)[source]

Called when the val epoch ends.

probing(trainer, pl_module)[source]

Perform the probing on the given estimator.

This method performs the following steps:

  1. Extracts the features from the training and test dataloaders.

  2. Fits the probe on the training features and labels.

  3. Makes predictions on the test features.

  4. Computes and logs the metrics.

Parameters:
trainer: pl.Trainer

The pytorch-lightning trainer.

pl_module: BaseEstimator

The BaseEstimator module that implements the transform_step.

Raises:
ValueError: If the pl_module does not inherit from BaseEstimator or from TransformerMixin.
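End to end, the probing step amounts to cross-validating the probe on frozen embeddings; a self-contained approximation of the four steps above (the data, probe, and the fixed projection standing in for transform_step are all illustrative):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

rng = np.random.RandomState(0)
X_raw = rng.randn(120, 10)
y = (X_raw[:, 0] + X_raw[:, 1] > 0).astype(int)

# 1) "Embed" the inputs (a fixed random projection stands in for the
#    estimator's transform_step).
W = rng.randn(10, 4)
Z = X_raw @ W

# 2-3) For each CV split, fit the probe on the training embeddings and
#      score its predictions on the held-out embeddings.
scores = cross_validate(LogisticRegression(), Z, y, cv=5)

# 4) Log one entry per fold; with a single score the key is "test_score",
#    and prefix_score prepends a tag such as "logreg_".
prefix = "logreg_"
logged = {f"fold{i}/{prefix}test_score": s
          for i, s in enumerate(scores["test_score"])}
```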