Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.
nidl.callbacks.ModelProbing
- class nidl.callbacks.ModelProbing(train_dataloader, test_dataloader, probe, scoring=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True, prefix_score='')
Bases: Callback
Callback to probe the representation of an embedding estimator on a dataset.
It has the following logic:
1) Embed the input data (training and test) through the estimator using its transform_step method (handles distributed multi-GPU forward passes).
2) Train the probe on the training embedding (handles multi-CPU training).
3) Evaluate the probe on the test embedding and log the scores.
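The three steps above can be sketched outside of any trainer as follows. This is a minimal illustration on synthetic data, not the callback's implementation: the `embed` function is a hypothetical stand-in for the estimator's transform_step, and all variable names are assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)

# Synthetic stand-in data: 300 samples, 20 raw features, binary labels.
X = rng.normal(size=(300, 20))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
X_train, y_train = X[:200], y[:200]
X_test, y_test = X[200:], y[200:]

# Hypothetical frozen "encoder": a fixed random linear projection.
W = rng.normal(size=(20, 8))


def embed(x):
    """Stand-in for the estimator's transform_step (assumption)."""
    return x @ W


# 1) Embed the training and test data.
z_train, z_test = embed(X_train), embed(X_test)

# 2) Train the probe on the training embedding.
probe = LogisticRegression(max_iter=1000).fit(z_train, y_train)

# 3) Evaluate the probe on the test embedding.
test_score = probe.score(z_test, y_test)
print(f"test_score = {test_score:.3f}")
```

In the callback, the value computed in step 3 is logged rather than printed.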
The probing can be performed at the end of training epochs, validation epochs, and/or at the start/end of the test epoch.
The metrics logged depend on the scoring parameter:
If a single score is provided, it logs test_score.
If multiple scores are provided, it logs each score with its name (such as test_accuracy, test_auc).
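The naming rules above, combined with the optional prefix_score argument from the class signature, can be sketched with a small helper. `logged_metric_names` is a hypothetical function written for illustration only; it is not part of nidl, and treating every list, tuple, or dict as "multiple scores" is an assumption.

```python
def logged_metric_names(scoring, prefix_score=""):
    """Return the metric names implied by the naming rules (hypothetical helper)."""
    if scoring is None or isinstance(scoring, str):
        # Single score: logged under the generic name "test_score".
        names = ["score"]
    elif isinstance(scoring, dict):
        # Multiple scores given as a dict: the keys are the metric names.
        names = list(scoring)
    elif isinstance(scoring, (list, tuple)):
        # Multiple scores given as strings: each string is the metric name.
        names = list(scoring)
    else:
        # Callables are left out: their names depend on what they return.
        raise TypeError("this sketch only handles None, str, list, tuple, dict")
    return [f"{prefix_score}test_{name}" for name in names]


single = logged_metric_names("accuracy")
multiple = logged_metric_names(["accuracy", "auc"], prefix_score="logreg_")
print(single)
print(multiple)
```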
Optionally, a prefix_score can be added to the score names when logging, such as ridge_ or logreg_ (giving ridge_test_r2 or logreg_test_accuracy).
- Parameters:
- train_dataloader: torch.utils.data.DataLoader
Training dataloader yielding batches of the form (X, y), which are embedded and then used to train the probe.
- test_dataloader: torch.utils.data.DataLoader
Test dataloader yielding batches of the form (X, y), which are embedded and then used to evaluate the probe.
- probe: sklearn.base.BaseEstimator
The probe model to be trained on the embedding. It must implement fit and predict methods that operate on NumPy arrays.
- scoring: str, callable, list, tuple, or dict, default=None
Strategy to evaluate the performance of the probe on the test set. The scores are logged into the LightningModule during training/validation/test according to the configuration of the callback.
If scoring represents a single score, one can use:
a single string (see String name scorers);
a callable (see Callable scorers) that returns a single value.
None, in which case the probe’s default evaluation criterion is used.
If scoring represents multiple scores, one can use:
a list or tuple of unique strings;
a callable returning a dictionary where the keys are the metric names and the values are the metric scores;
a dictionary with metric names as keys and callables as values.
- every_n_train_epochs: int or None, default=1
Run the probing every n training epochs. Disabled if None.
- every_n_val_epochs: int or None, default=None
Run the probing every n validation epochs. Disabled if None.
- on_test_epoch_start: bool, default=False
Whether to run the probing at the start of the test epoch.
- on_test_epoch_end: bool, default=False
Whether to run the probing at the end of the test epoch.
- prog_bar: bool, default=True
Whether to display the metrics in the progress bar.
- prefix_score: str, default=""
Prefix to add to the score name when logging. This can be useful when using multiple ModelProbing callbacks to distinguish the logged metrics, such as "ridge_" or "logreg_".
Examples
>>> from sklearn.linear_model import LogisticRegression
>>> from nidl.callbacks import ModelProbing
>>> # train_loader and test_loader are existing DataLoader objects yielding (X, y)
>>> callback = ModelProbing(
...     train_dataloader=train_loader,
...     test_dataloader=test_loader,
...     probe=LogisticRegression(),
...     scoring=["accuracy", "balanced_accuracy"],
...     every_n_train_epochs=5,
... )
- __init__(train_dataloader, test_dataloader, probe, scoring=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True, prefix_score='')
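The forms accepted by the scoring argument appear to follow scikit-learn's scorer conventions. The sketch below evaluates an already fitted probe with a string name, with a dictionary of scorers, and with the probe's default criterion (the None case). It only illustrates those conventions on synthetic data; how the callback resolves scoring internally is not shown, and the variable names are assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, get_scorer, make_scorer

rng = np.random.default_rng(0)

# Synthetic embeddings and binary labels (stand-ins for extracted features).
z = rng.normal(size=(200, 8))
y = (z[:, 0] > 0).astype(int)
z_train, y_train = z[:150], y[:150]
z_test, y_test = z[150:], y[150:]

probe = LogisticRegression().fit(z_train, y_train)

# Single score from a string name: a scorer is called as scorer(estimator, X, y).
single = get_scorer("accuracy")(probe, z_test, y_test)

# Multiple scores from a dict mapping metric names to scorer callables.
scoring = {"accuracy": get_scorer("accuracy"), "f1": make_scorer(f1_score)}
multiple = {name: scorer(probe, z_test, y_test) for name, scorer in scoring.items()}

# scoring=None: fall back to the probe's own score method
# (mean accuracy for scikit-learn classifiers).
default = probe.score(z_test, y_test)

print(single, multiple, default)
```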
- static adapt_dataloader_for_ddp(dataloader, trainer)
Wrap user dataloader with DistributedSampler if in DDP mode.
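As a rough illustration of what wrapping a dataloader with a DistributedSampler involves, the sketch below rebuilds a DataLoader so that one rank only sees its own shard. `with_distributed_sampler` is a hypothetical helper, not the nidl method: it takes num_replicas and rank explicitly so that it runs without an initialized process group, whereas the real method receives the trainer.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def with_distributed_sampler(dataloader, num_replicas, rank):
    """Rebuild a DataLoader whose sampler shards the dataset (hypothetical helper)."""
    sampler = DistributedSampler(
        dataloader.dataset,
        num_replicas=num_replicas,  # total number of processes
        rank=rank,                  # index of the current process
        shuffle=False,              # keep a deterministic order for the demo
    )
    return DataLoader(
        dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        drop_last=dataloader.drop_last,
    )


# Ten samples whose label equals their index, to make the sharding visible.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = DataLoader(dataset, batch_size=4)

shard = with_distributed_sampler(loader, num_replicas=2, rank=0)
labels = torch.cat([y for _, y in shard]).tolist()
print(labels)  # rank 0 of 2 gets every other sample: [0, 2, 4, 6, 8]
```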
- extract_features(trainer, pl_module, dataloader)
Extract features from a dataloader with the BaseEstimator.
By default, it uses the transform_step logic applied on each batch to get the embeddings with the labels. The input dataloader should yield batches of the form (X, y) where X is the input data and y is the label.
- Parameters:
- trainer: pl.Trainer
The pytorch-lightning trainer instance.
- pl_module: BaseEstimator
The BaseEstimator module that implements the transform_step method.
- dataloader: torch.utils.data.DataLoader
The dataloader to extract features from. It should yield batches of the form (X, y) where X is the input data and y is the label.
- Returns:
- tuple of (z, y)
Tuple of numpy arrays (z, y) where z are the extracted features and y are the corresponding labels.
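The per-batch extraction described above can be sketched as a plain loop. This is a simplified, single-process illustration: `extract_features_sketch` and the `embed` callable are hypothetical stand-ins for the method and for the estimator's transform_step, and the sketch ignores the trainer and any distributed gathering.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def extract_features_sketch(embed, dataloader):
    """Embed every (X, y) batch and return NumPy arrays (z, y)."""
    zs, ys = [], []
    with torch.no_grad():  # no gradients are needed to extract features
        for X, y in dataloader:
            zs.append(embed(X).cpu().numpy())
            ys.append(y.cpu().numpy())
    return np.concatenate(zs), np.concatenate(ys)


torch.manual_seed(0)
encoder = torch.nn.Linear(16, 4)  # toy encoder standing in for the estimator
dataset = TensorDataset(torch.randn(50, 16), torch.randint(0, 2, (50,)))

z, y = extract_features_sketch(encoder, DataLoader(dataset, batch_size=8))
print(z.shape, y.shape)  # (50, 4) (50,)
```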
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

    import lightning as L
    import torch


    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            # cache of per-step losses, filled during the epoch
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
- probing(trainer, pl_module)
Perform the probing on the given estimator.
This method performs the following steps:
1) Extracts the features from the training and test dataloaders.
2) Fits the probe on the training features and labels.
3) Makes predictions on the test features.
4) Computes and logs the metrics.
- Parameters:
- pl_module: BaseEstimator
The BaseEstimator module that implements the transform_step.
- Raises:
- ValueError: If the pl_module does not inherit from BaseEstimator or from TransformerMixin.