Deep learning for NeuroImaging in Python.
Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.
- class nidl.callbacks.model_probing.RegressionProbingCallback(train_dataloader, test_dataloader, probe, probe_name=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True)
Bases: ModelProbing
Perform regression on top of an embedding model.
Concretely, this callback:
1. Embeds the input data through the estimator.
2. Fits the regression probe on the embedded training data.
3. Logs the main regression metrics on the embedded test data, including:
   - mean absolute error
   - median absolute error
   - root mean squared error
   - mean squared error
   - R² score
   - Pearson’s r
   - explained variance score
If several target variables are regressed at once (multivariate regression), each metric is computed per target and then averaged over targets.
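The three steps can be sketched as follows with NumPy, SciPy and scikit-learn. This is a minimal illustration under stated assumptions, not nidl's implementation: the hypothetical `encoder` function stands in for the estimator's embedding step, the batches are plain NumPy `(X, y)` pairs rather than a `torch.utils.data.DataLoader` (with PyTorch tensors one would typically convert outputs with `.detach().cpu().numpy()`), the helpers `embed_batches`, `regression_metrics` and `probe_once` and their metric keys are invented for illustration, and multi-target metrics are assumed to be averaged uniformly over targets.

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.linear_model import Ridge
from sklearn.metrics import (
    explained_variance_score,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    r2_score,
)


def embed_batches(encoder, batches):
    """Pass each X through the (hypothetical) encoder; stack embeddings and targets."""
    feats, targets = [], []
    for X, y in batches:
        feats.append(np.asarray(encoder(X)))
        targets.append(np.asarray(y))
    return np.concatenate(feats), np.concatenate(targets)


def regression_metrics(y_true, y_pred):
    """Compute every metric per target column, then average uniformly over targets."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = y_true.reshape(len(y_true), -1)  # 1-D targets become a single column
    y_pred = y_pred.reshape(len(y_pred), -1)
    per_target = []
    for j in range(y_true.shape[1]):
        t, p = y_true[:, j], y_pred[:, j]
        mse = mean_squared_error(t, p)
        per_target.append({
            "mae": mean_absolute_error(t, p),
            "median_ae": median_absolute_error(t, p),
            "rmse": np.sqrt(mse),
            "mse": mse,
            "r2": r2_score(t, p),
            "pearson_r": pearsonr(t, p)[0],
            "explained_variance": explained_variance_score(t, p),
        })
    return {k: float(np.mean([m[k] for m in per_target])) for k in per_target[0]}


def probe_once(encoder, train_batches, test_batches, probe):
    """Embed both splits, fit the probe on train embeddings, score it on test."""
    Z_train, y_train = embed_batches(encoder, train_batches)
    Z_test, y_test = embed_batches(encoder, test_batches)
    probe.fit(Z_train, y_train)
    return regression_metrics(y_test, probe.predict(Z_test))


# Usage with a toy frozen linear "encoder" and two regression targets.
rng = np.random.default_rng(0)
W = rng.normal(size=(10, 4))


def encoder(X):
    return X @ W


def make_batches(n_batches, batch_size=32):
    batches = []
    for _ in range(n_batches):
        X = rng.normal(size=(batch_size, 10))
        y = np.stack([X[:, 0], X[:, 1] - X[:, 2]], axis=1)
        batches.append((X, y))
    return batches


metrics = probe_once(encoder, make_batches(8), make_batches(2), Ridge(alpha=1.0))
print(sorted(metrics))
```

Under this uniform-averaging assumption, the averaged `rmse` is the mean of per-target RMSE values, which is generally not the square root of the averaged `mse`.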
- Parameters:
train_dataloader : torch.utils.data.DataLoader
Training dataloader yielding batches of the form (X, y). X is embedded by the estimator, and the resulting embeddings, together with y, are used to train the probe.
test_dataloader : torch.utils.data.DataLoader
Test dataloader yielding batches of the form (X, y). X is embedded by the estimator, and the probe is evaluated on the resulting embeddings against y.
probe : sklearn.base.BaseEstimator
The scikit-learn regressor to be trained on the embedding.
probe_name : str or None, default=None
Name of the probe displayed when logging the results. It will appear as <probe_name>/<metric_name> for each metric. If None, <probe_class_name>/<metric_name> is displayed.
every_n_train_epochs : int or None, default=1
Frequency, in training epochs, at which to run the probing (1 means every training epoch). Disabled if None.
every_n_val_epochs : int or None, default=None
Frequency, in validation epochs, at which to run the probing. Disabled if None.
on_test_epoch_start : bool, default=False
Whether to run the probing at the start of the test epoch.
on_test_epoch_end : bool, default=False
Whether to run the probing at the end of the test epoch.
prog_bar : bool, default=True
Whether to display the metrics in the progress bar.
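Since `probe` only needs to be a scikit-learn regressor, composite estimators such as a `Pipeline` should also qualify. The sketch below uses illustrative, untuned hyperparameters and random data standing in for real embeddings: it standardizes each embedding dimension before a ridge regression whose penalty is chosen by internal cross-validation. It is one reasonable choice, not a default prescribed by nidl.

```python
import numpy as np
from sklearn.base import is_regressor
from sklearn.linear_model import RidgeCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Standardize embedding dimensions, then select the ridge penalty by internal CV.
probe = make_pipeline(StandardScaler(), RidgeCV(alphas=np.logspace(-3, 3, 7)))

# A Pipeline whose last step is a regressor is itself treated as a regressor.
print(is_regressor(probe))  # True

# Illustrative fit on random "embeddings" with two targets.
rng = np.random.default_rng(0)
Z = rng.normal(size=(100, 16))
y = 2.0 * Z[:, :2] + 0.1 * rng.normal(size=(100, 2))
probe.fit(Z, y)
y_hat = probe.predict(Z)
```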
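The naming rule for `probe_name` can be illustrated with a small hypothetical helper, `metric_key`. It assumes the class name is used verbatim, as given by `type(probe).__name__`, and the metric name `"r2"` is a placeholder rather than a key nidl is known to log.

```python
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def metric_key(probe, metric_name, probe_name=None):
    """Hypothetical helper: '<probe_name>/<metric>', else '<class name>/<metric>'."""
    prefix = probe_name if probe_name is not None else type(probe).__name__
    return f"{prefix}/{metric_name}"


print(metric_key(Ridge(), "r2"))                                   # Ridge/r2
print(metric_key(Ridge(), "r2", probe_name="age_probe"))           # age_probe/r2
print(metric_key(make_pipeline(StandardScaler(), Ridge()), "r2"))  # Pipeline/r2
```

Under this assumption, a probe wrapped in a `Pipeline` would be logged under the generic `Pipeline/` prefix, which is one reason to pass an explicit `probe_name`.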
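One plausible reading of `every_n_train_epochs` and `every_n_val_epochs` is that probing fires after every n-th completed epoch of the corresponding kind. The helper `should_probe` and its 0-based epoch-index convention are assumptions for illustration; the exact trigger condition should be checked in the nidl source.

```python
def should_probe(epoch_index, every_n):
    """Hypothetical trigger rule for the every_n_* parameters.

    epoch_index is 0-based; probing runs after every `every_n` completed
    epochs and is disabled entirely when every_n is None.
    """
    if every_n is None:
        return False
    return (epoch_index + 1) % every_n == 0


print([e for e in range(10) if should_probe(e, 3)])     # [2, 5, 8]
print([e for e in range(10) if should_probe(e, None)])  # []
```

With `every_n=1` (the default for training epochs) this rule fires after every epoch.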