Deep learning for NeuroImaging in Python.
Note
This page is reference documentation. It explains only the class signature, not how to use it. Please refer to the gallery for the big picture.
- class nidl.callbacks.multitask_probing.MultitaskModelProbing(train_dataloader: DataLoader, test_dataloader: DataLoader, probes: list[BaseEstimator] | MultiTaskEstimator, probe_names: list[str] | None = None, every_n_train_epochs: int | None = 1, every_n_val_epochs: int | None = None, on_test_epoch_start: bool = False, on_test_epoch_end: bool = False, prog_bar: bool = True)
Bases: ModelProbing
Callback to probe the representation of an embedding estimator on a multi-task dataset.
This callback implements multitask probing on top of an embedding estimator for both classification and regression tasks. The embeddings are computed once and stored in memory, so they are not recomputed for each task. Each probe is then trained and evaluated separately on the stored embeddings, one probe per task.
- Parameters:
train_dataloader : torch.utils.data.DataLoader
Training dataloader yielding batches of the form (X, y). X is embedded and the resulting embeddings are used to train the probes. y should have shape (n_samples, n_tasks), with one output per task (categorical or continuous).
test_dataloader : torch.utils.data.DataLoader
Test dataloader yielding batches of the form (X, y). X is embedded and the resulting embeddings are used to test the probes. y should have shape (n_samples, n_tasks), with one output per task.
probes : list of sklearn.base.BaseEstimator or MultiTaskEstimator
The probes used to evaluate the data embedding on multiple tasks (classification or regression). Each probe is fitted on one task (i.e., one target) and should implement fit and predict.
probe_names : str or list of str or None, default=None
Names of the probes, displayed when logging the results. Each metric is logged as <task_name_i>/<metric_name> for task i. If a list is given, it should have the same length as probes. If None, ["task1", "task2", …] are used.
every_n_train_epochs : int or None, default=1
Number of training epochs after which to run the probing. Disabled if None.
every_n_val_epochs : int or None, default=None
Number of validation epochs after which to run the probing. Disabled if None.
on_test_epoch_start : bool, default=False
Whether to run the probing at the start of the test epoch.
on_test_epoch_end : bool, default=False
Whether to run the probing at the end of the test epoch.
prog_bar : bool, default=True
Whether to display the metrics in the progress bar.
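The probing scheme described for this class (embed once, then fit one probe per target column) can be illustrated outside of the callback. The snippet below is a minimal sketch using NumPy and scikit-learn only; it is not the nidl implementation. The `embed` function, the random projection `W`, and the toy targets are hypothetical stand-ins for the embedding estimator and the dataloaders, and dispatching on `is_classifier` is an assumption made for the illustration.

```python
import numpy as np
from sklearn.base import is_classifier
from sklearn.linear_model import LogisticRegression, Ridge

rng = np.random.default_rng(0)

# Hypothetical stand-in for the embedding estimator: a fixed random projection.
W = rng.normal(size=(6, 4))


def embed(X):
    return np.tanh(X @ W)


def make_targets(X):
    # Toy targets with shape (n_samples, n_tasks): one column per task.
    label = (X[:, 0] > 0).astype(float)  # task 1: binary classification
    score = 2.0 * X[:, 1] + 0.5          # task 2: regression
    return np.column_stack([label, score])


X_train, X_test = rng.normal(size=(80, 6)), rng.normal(size=(20, 6))
y_train, y_test = make_targets(X_train), make_targets(X_test)

# Embed each split once and reuse the stored arrays for every task.
Z_train, Z_test = embed(X_train), embed(X_test)

# One probe per task, in the same order as the target columns.
probes = [LogisticRegression(max_iter=1000), Ridge(alpha=1.0)]

predictions = []
for i, probe in enumerate(probes):
    target = y_train[:, i]
    if is_classifier(probe):
        target = target.astype(int)  # class labels are stored as floats in y
    probe.fit(Z_train, target)
    predictions.append(probe.predict(Z_test))

# Stack per-task predictions back into shape (n_samples, n_tasks).
y_pred = np.column_stack(predictions)
print(y_pred.shape)  # (20, 2)
```

The embedding step runs twice in total (once per split) regardless of the number of probes, which is the saving the class description refers to.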
- log_classification_metrics(pl_module, y_pred, y_true, task_name)
Log the metrics for a classification task.
The main classification metrics reported are:
precision (macro)
recall (macro)
f1-score (weighted and macro)
accuracy (global)
balanced accuracy
- Parameters:
pl_module : nidl.estimators.base.BaseEstimator
The embedding estimator currently evaluated.
y_pred : array-like, shape (n_samples,)
Predicted values for classification.
y_true : array-like, shape (n_samples,)
Ground-truth values.
task_name : str
Name to display when logging the metrics.
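For reference, the classification metrics listed for this method can all be computed with scikit-learn. The sketch below shows one way to do so; the helper `classification_metrics` and the metric key names are hypothetical and only mimic the <task_name>/<metric_name> pattern, so the keys actually logged by nidl may differ.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)


def classification_metrics(y_true, y_pred, task_name):
    """Return the listed classification metrics under hypothetical key names."""
    return {
        f"{task_name}/precision_macro": precision_score(
            y_true, y_pred, average="macro", zero_division=0
        ),
        f"{task_name}/recall_macro": recall_score(
            y_true, y_pred, average="macro", zero_division=0
        ),
        f"{task_name}/f1_weighted": f1_score(y_true, y_pred, average="weighted"),
        f"{task_name}/f1_macro": f1_score(y_true, y_pred, average="macro"),
        f"{task_name}/accuracy": accuracy_score(y_true, y_pred),
        f"{task_name}/balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
    }


# Imbalanced three-class toy example with one misclassified sample.
y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2])

clf_metrics = classification_metrics(y_true, y_pred, task_name="task1")
for key, value in clf_metrics.items():
    print(f"{key}: {value:.3f}")
```

On this toy example the global accuracy is 5/6 while the balanced accuracy is 8/9: balanced accuracy is the mean of per-class recalls, so it equals the macro recall and is not dominated by the largest class.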
- log_metrics(pl_module, y_pred, y_true)
Log the metrics for each task (classification or regression).
- Parameters:
pl_module : nidl.estimators.base.BaseEstimator
The embedding estimator currently evaluated.
y_pred : array-like, shape (n_samples, n_tasks)
Predicted values.
y_true : array-like, shape (n_samples, n_tasks)
Ground-truth values, one column per task.
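Per-task logging over (n_samples, n_tasks) arrays can be pictured as a loop over columns. The following is a rough sketch under stated assumptions, not the library's code: the helper `per_task_metrics` is hypothetical, the choice between classification and regression is made here with scikit-learn's `is_classifier` (the documentation does not say how nidl decides), and a single metric per task is reported for brevity. The default names follow the ["task1", "task2", …] convention documented for probe_names.

```python
import numpy as np
from sklearn.base import is_classifier
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, r2_score


def per_task_metrics(probes, y_pred, y_true, probe_names=None):
    """Compute one metric per task, keyed as '<task_name>/<metric_name>'."""
    y_pred, y_true = np.asarray(y_pred), np.asarray(y_true)
    if probe_names is None:
        probe_names = [f"task{i + 1}" for i in range(len(probes))]
    if len(probe_names) != len(probes):
        raise ValueError("probe_names must have the same length as probes")

    logged = {}
    for i, (probe, name) in enumerate(zip(probes, probe_names)):
        if is_classifier(probe):
            # Assumption: class labels are integer-valued in column i.
            logged[f"{name}/accuracy"] = accuracy_score(
                y_true[:, i].astype(int), y_pred[:, i].astype(int)
            )
        else:
            logged[f"{name}/r2"] = r2_score(y_true[:, i], y_pred[:, i])
    return logged


# Column 0 is a classification task, column 1 a regression task.
y_true = np.array([[0, 1.0], [1, 2.0], [1, 3.0], [0, 4.0]])
y_pred = np.array([[0, 1.0], [1, 2.0], [0, 3.0], [0, 4.0]])

logged = per_task_metrics([LogisticRegression(), Ridge()], y_pred, y_true)
print(logged)
```

Only the probe type is inspected here, so the probes do not need to be fitted for the dispatch to work.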
- log_regression_metrics(pl_module, y_pred, y_true, task_name)
Log the metrics for a regression task.
The main regression metrics reported are:
mean absolute error
median absolute error
root mean squared error
mean squared error
R² score
Pearson’s r
explained variance score
- Parameters:
pl_module : nidl.estimators.base.BaseEstimator
The embedding estimator currently evaluated.
y_pred : array-like, shape (n_samples,)
Predicted values for regression.
y_true : array-like, shape (n_samples,)
Ground-truth values.
task_name : str
Name to display when logging the metrics.
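The regression metrics listed for this method can likewise be reproduced with NumPy and scikit-learn. This is a minimal sketch, assuming hypothetical key names and a hypothetical helper `regression_metrics`; Pearson's r is computed with `np.corrcoef` here, which may not be what nidl uses internally.

```python
import numpy as np
from sklearn.metrics import (
    explained_variance_score,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    r2_score,
)


def regression_metrics(y_true, y_pred, task_name):
    """Return the listed regression metrics under hypothetical key names."""
    mse = mean_squared_error(y_true, y_pred)
    return {
        f"{task_name}/mae": mean_absolute_error(y_true, y_pred),
        f"{task_name}/medae": median_absolute_error(y_true, y_pred),
        f"{task_name}/rmse": float(np.sqrt(mse)),
        f"{task_name}/mse": mse,
        f"{task_name}/r2": r2_score(y_true, y_pred),
        f"{task_name}/pearson_r": float(np.corrcoef(y_true, y_pred)[0, 1]),
        f"{task_name}/explained_variance": explained_variance_score(y_true, y_pred),
    }


# Predictions that are perfectly correlated with the target but biased by +0.5.
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.5, 3.5, 4.5])

reg_metrics = regression_metrics(y_true, y_pred, task_name="task2")
for key, value in reg_metrics.items():
    print(f"{key}: {value:.3f}")
```

With a constant offset of 0.5, Pearson's r and the explained variance score are both 1.0 while R² drops to 0.8, which shows why reporting several of these metrics together is informative: only R² and the error metrics penalize a systematic bias.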