
Deep learning for NeuroImaging in Python.

Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.

class nidl.callbacks.multitask_probing.MultitaskModelProbing(train_dataloader: DataLoader, test_dataloader: DataLoader, probes: list[BaseEstimator] | MultiTaskEstimator, probe_names: list[str] | None = None, every_n_train_epochs: int | None = 1, every_n_val_epochs: int | None = None, on_test_epoch_start: bool = False, on_test_epoch_end: bool = False, prog_bar: bool = True)

Bases: ModelProbing

Callback to probe the representation of an embedding estimator on a multi-task dataset.

This callback implements multitask probing on top of an embedding estimator for both classification and regression tasks. Instead of recomputing the embeddings for every task, it computes them once and stores them in memory. Each probe is then trained and evaluated separately on the stored embeddings, one probe per task.
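The compute-once, probe-per-task strategy described above can be sketched with NumPy and scikit-learn. This is a minimal illustration under stated assumptions, not nidl's implementation: `probe_tasks` and `embed` are hypothetical names, and `embed` stands in for the embedding estimator's forward pass applied to in-memory arrays rather than to a `DataLoader`.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression, Ridge


def probe_tasks(embed, X_train, y_train, X_test, probes):
    """Fit one probe per task on embeddings computed only once per split.

    Hypothetical helper: `embed` maps raw inputs to embedding vectors.
    """
    # Embed each split a single time; every task reuses these arrays.
    Z_train = embed(X_train)
    Z_test = embed(X_test)
    predictions = []
    for task_idx, probe in enumerate(probes):
        # Fit a fresh copy of the probe on this task's target column only.
        model = clone(probe).fit(Z_train, y_train[:, task_idx])
        predictions.append(model.predict(Z_test))
    # Stack per-task predictions into shape (n_samples, n_tasks).
    return np.column_stack(predictions)


# Toy data: a fixed random projection plays the role of a trained encoder.
rng = np.random.default_rng(0)
W = rng.normal(size=(8, 4))


def embed(X):
    return np.tanh(X @ W)


def make_targets(X):
    z = embed(X)
    # Task 1 is binary classification, task 2 is regression.
    return np.column_stack([(z[:, 0] > 0).astype(int), 2.0 * z[:, 1]])


X_train, X_test = rng.normal(size=(100, 8)), rng.normal(size=(40, 8))
y_train = make_targets(X_train)
y_pred = probe_tasks(embed, X_train, y_train, X_test, [LogisticRegression(), Ridge()])
print(y_pred.shape)  # (40, 2)
```

However many probes are passed, `embed` is called exactly twice (once per split), which is the saving the callback aims for.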

Parameters:

train_dataloader : torch.utils.data.DataLoader

Training dataloader yielding batches of the form (X, y), used to compute the embeddings and fit the probes. y should have shape (n_samples, n_tasks), with one target per task (categorical or continuous).

test_dataloader : torch.utils.data.DataLoader

Test dataloader yielding batches of the form (X, y), used to compute the embeddings and evaluate the probes. y should have shape (n_samples, n_tasks), with one target per task.

probes : list of sklearn.base.BaseEstimator or MultiTaskEstimator

The probes used to evaluate the embedding on multiple tasks (classification or regression). Each probe is fitted on a single task (i.e., one target column of y) and must implement fit and predict.

probe_names : list of str or None, default=None

Names of the probes to display when logging the results. Metrics appear as <task_name_i>/<metric_name> for each task i. If a list, it must have the same length as probes. If None, [“task1”, “task2”, …] is used.

every_n_train_epochs : int or None, default=1

Frequency, in training epochs, at which the probing is run (e.g. 1 runs it after every training epoch). Disabled if None.

every_n_val_epochs : int or None, default=None

Frequency, in validation epochs, at which the probing is run. Disabled if None.

on_test_epoch_start : bool, default=False

Whether to run the probing at the start of the test epoch.

on_test_epoch_end : bool, default=False

Whether to run the probing at the end of the test epoch.

prog_bar : bool, default=True

Whether to display the metrics in the progress bar.
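Two of the parameter conventions above can be illustrated in a few lines: default task names task1, task2, … combined into <task_name>/<metric_name> keys, and a periodic trigger for every_n_train_epochs or every_n_val_epochs. Both helpers (`resolve_task_names`, `should_probe`) are hypothetical, the metric name "accuracy" is only illustrative, and the 0-based epoch counting in `should_probe` is an assumption rather than documented nidl behavior.

```python
def resolve_task_names(n_probes, probe_names=None):
    """Return one display name per probe (hypothetical helper)."""
    if probe_names is None:
        # Default names: "task1", "task2", ...
        return [f"task{i + 1}" for i in range(n_probes)]
    if len(probe_names) != n_probes:
        raise ValueError("probe_names must have the same length as probes")
    return list(probe_names)


def should_probe(epoch_index, every_n):
    """Return True if probing runs after this epoch (hypothetical helper).

    Assumes 0-based epoch indices; every_n=None disables the trigger.
    """
    if every_n is None:
        return False
    return (epoch_index + 1) % every_n == 0


names = resolve_task_names(2)
print([f"{name}/accuracy" for name in names])  # ['task1/accuracy', 'task2/accuracy']
print([epoch for epoch in range(6) if should_probe(epoch, 2)])  # [1, 3, 5]
```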

log_classification_metrics(pl_module, y_pred, y_true, task_name)

Log the metrics for a classification task.

The main classification metrics reported are:

  • precision (macro)

  • recall (macro)

  • f1-score (weighted and macro)

  • accuracy (global)

  • balanced accuracy

Parameters:

pl_module : nidl.estimators.base.BaseEstimator

The embedding estimator currently evaluated.

y_pred : array-like, shape (n_samples,)

Predicted values for classification.

y_true : array-like, shape (n_samples,)

Ground-truth values.

task_name : str

Name to display when logging the metrics.
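The classification metrics listed above all have scikit-learn counterparts. The sketch below computes them for one task and prefixes each key with the task name; `classification_metrics` and the metric key names are hypothetical, not the exact names nidl logs, and `zero_division=0` is an assumed choice for classes that are never predicted.

```python
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)


def classification_metrics(y_true, y_pred, task_name):
    """Compute the listed classification metrics, keyed as <task_name>/<metric>."""
    metrics = {
        "precision_macro": precision_score(y_true, y_pred, average="macro", zero_division=0),
        "recall_macro": recall_score(y_true, y_pred, average="macro", zero_division=0),
        "f1_weighted": f1_score(y_true, y_pred, average="weighted", zero_division=0),
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "accuracy": accuracy_score(y_true, y_pred),
        "balanced_accuracy": balanced_accuracy_score(y_true, y_pred),
    }
    # Prefix every metric with the task name, matching <task_name>/<metric_name>.
    return {f"{task_name}/{name}": value for name, value in metrics.items()}


y_true = [0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 2, 2, 2, 1]
logs = classification_metrics(y_true, y_pred, "task1")
print(round(logs["task1/accuracy"], 4))  # 0.6667
```

For this input, balanced accuracy equals macro-averaged recall, since balanced accuracy is the mean per-class recall and every predicted class also appears in y_true.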

log_metrics(pl_module, y_pred, y_true)

Log the metrics for each task (classification or regression).

Parameters:

pl_module : nidl.estimators.base.BaseEstimator

The embedding estimator currently evaluated.

y_pred : array-like, shape (n_samples, n_tasks)

Predicted values.

y_true : array-like, shape (n_samples, n_tasks)

Ground-truth for the tasks.
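log_metrics receives (n_samples, n_tasks) arrays and handles each column as a separate task. How nidl decides between the classification and regression loggers is not stated on this page; the sketch below assumes the choice follows scikit-learn's `is_classifier` applied to each task's probe. `route_tasks` is a hypothetical name, and only one metric per task type is computed to keep it short.

```python
import numpy as np
from sklearn.base import is_classifier
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import accuracy_score, mean_absolute_error


def route_tasks(probes, task_names, y_pred, y_true):
    """Log one metric per task, choosing by probe type (hypothetical helper)."""
    y_pred, y_true = np.asarray(y_pred), np.asarray(y_true)
    logs = {}
    for i, (probe, name) in enumerate(zip(probes, task_names)):
        # Column i holds the predictions and targets of task i.
        if is_classifier(probe):
            logs[f"{name}/accuracy"] = accuracy_score(y_true[:, i], y_pred[:, i])
        else:
            logs[f"{name}/mae"] = mean_absolute_error(y_true[:, i], y_pred[:, i])
    return logs


y_true = [[0, 1.0], [1, 2.0], [1, 3.0]]
y_pred = [[0, 1.5], [1, 2.0], [0, 2.0]]
print(route_tasks([LogisticRegression(), Ridge()], ["task1", "task2"], y_pred, y_true))
```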

log_regression_metrics(pl_module, y_pred, y_true, task_name)

Log the metrics for a regression task.

The main regression metrics reported are:

  • mean absolute error

  • median absolute error

  • root mean squared error

  • mean squared error

  • R² score

  • Pearson’s r

  • explained variance score

Parameters:

pl_module : nidl.estimators.base.BaseEstimator

The embedding estimator currently evaluated.

y_pred : array-like, shape (n_samples,)

Predicted values for regression.

y_true : array-like, shape (n_samples,)

Ground-truth values.

task_name : str

Name to display when logging the metrics.
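The regression metrics listed above can likewise be computed with scikit-learn and SciPy. In this sketch, `regression_metrics` and the metric key names are hypothetical, and RMSE is derived as the square root of MSE; nidl may compute or name these differently.

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import (
    explained_variance_score,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    r2_score,
)


def regression_metrics(y_true, y_pred, task_name):
    """Compute the listed regression metrics, keyed as <task_name>/<metric>."""
    mse = mean_squared_error(y_true, y_pred)
    metrics = {
        "mae": mean_absolute_error(y_true, y_pred),
        "median_ae": median_absolute_error(y_true, y_pred),
        # RMSE derived from MSE, which works across scikit-learn versions.
        "rmse": float(np.sqrt(mse)),
        "mse": mse,
        "r2": r2_score(y_true, y_pred),
        "pearson_r": float(pearsonr(y_true, y_pred)[0]),
        "explained_variance": explained_variance_score(y_true, y_pred),
    }
    # Prefix every metric with the task name, matching <task_name>/<metric_name>.
    return {f"{task_name}/{name}": value for name, value in metrics.items()}


y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.0, 2.5, 4.0])
logs = regression_metrics(y_true, y_pred, "task2")
print(round(logs["task2/r2"], 3))  # 0.9
```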

Examples

Model probing callback of embedding estimators



© 2025, nidl developers