Deep learning for NeuroImaging in Python.
Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.
- class nidl.callbacks.model_probing.ClassificationProbingCallback(train_dataloader, test_dataloader, probe, probe_name=None, every_n_train_epochs=1, every_n_val_epochs=None, on_test_epoch_start=False, on_test_epoch_end=False, prog_bar=True)
Bases: ModelProbing
Perform classification on top of an embedding model.
Concretely, this callback:
1. Embeds the input data with the PyTorch model.
2. Fits the classification probe on the embedded data.
3. Logs the main classification metrics:
   - precision (macro)
   - recall (macro)
   - F1-score (weighted and macro)
   - accuracy (global)
   - balanced accuracy
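The three steps above can be sketched in plain Python. This is a dependency-free illustration under stated assumptions, not the callback's implementation: `embed`, `NearestCentroidProbe`, and `toy_encoder` are hypothetical names. The encoder stands in for the PyTorch model, and the nearest-centroid probe stands in for whatever scikit-learn classifier is passed as `probe` (it only mimics the `fit`/`predict` interface).

```python
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]
Batch = Tuple[Sequence[Sequence[float]], Sequence[int]]


def embed(encoder: Callable[[Sequence[float]], Vector],
          batches: Sequence[Batch]) -> Tuple[List[Vector], List[int]]:
    """Step 1: pass every (X, y) batch through the encoder, keeping the labels."""
    features: List[Vector] = []
    labels: List[int] = []
    for X, y in batches:
        features.extend(encoder(x) for x in X)
        labels.extend(y)
    return features, labels


class NearestCentroidProbe:
    """Hypothetical stand-in for a scikit-learn classifier (fit/predict only)."""

    def fit(self, X: Sequence[Vector], y: Sequence[int]) -> "NearestCentroidProbe":
        # Step 2: average the embeddings of each class into one centroid.
        sums: Dict[int, Vector] = {}
        counts: Dict[int, int] = {}
        for x, label in zip(X, y):
            acc = sums.setdefault(label, [0.0] * len(x))
            for i, value in enumerate(x):
                acc[i] += value
            counts[label] = counts.get(label, 0) + 1
        self.centroids_ = {c: [v / counts[c] for v in s] for c, s in sums.items()}
        return self

    def predict(self, X: Sequence[Vector]) -> List[int]:
        # Assign each embedding to the class with the closest centroid.
        def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
            return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

        return [min(self.centroids_, key=lambda c: sq_dist(x, self.centroids_[c]))
                for x in X]


def toy_encoder(x: Sequence[float]) -> Vector:
    # Hypothetical frozen "model": keeps the first two features, doubled.
    return [2.0 * x[0], 2.0 * x[1]]


train_batches = [([[0.0, 0.1, 5.0], [0.2, 0.0, 3.0]], [0, 0]),
                 ([[1.0, 0.9, 1.0], [0.9, 1.1, 7.0]], [1, 1])]
test_batches = [([[0.1, 0.1, 9.0], [1.0, 1.0, 0.0]], [0, 1])]

X_train, y_train = embed(toy_encoder, train_batches)
X_test, y_test = embed(toy_encoder, test_batches)
probe = NearestCentroidProbe().fit(X_train, y_train)
y_pred = probe.predict(X_test)
print(y_pred)  # [0, 1]
```

Fitting only on the training embeddings and scoring only on the test embeddings mirrors the split between `train_dataloader` and `test_dataloader`. Step 3 then compares `y_pred` with `y_test`.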
Please check the User Guide for more details on the reported classification metrics.
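For reference, the metrics listed above can be written out from per-class counts. The sketch below follows scikit-learn's usual definitions as we understand them: macro averages are unweighted means over classes, the weighted F1-score weights each class by its support (number of true samples), and balanced accuracy is the mean recall over classes present in `y_true`. The function name `classification_metrics` and its dictionary keys are hypothetical. They are not the metric names the callback logs.

```python
from typing import Dict, List, Sequence


def classification_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Hypothetical helper computing the metrics listed above from label lists."""
    classes = sorted(set(y_true) | set(y_pred))
    precisions: List[float] = []
    recalls: List[float] = []
    f1s: List[float] = []
    supports: List[int] = []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        # Undefined ratios (division by zero) are set to 0.0 here.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
        supports.append(tp + fn)  # number of true samples of class c

    n_classes = len(classes)
    n_samples = len(y_true)
    # Balanced accuracy only averages recall over classes that occur in y_true.
    present = [r for r, s in zip(recalls, supports) if s > 0]
    return {
        "precision_macro": sum(precisions) / n_classes,
        "recall_macro": sum(recalls) / n_classes,
        "f1_macro": sum(f1s) / n_classes,
        "f1_weighted": sum(f * s for f, s in zip(f1s, supports)) / n_samples,
        "accuracy": sum(1 for t, p in zip(y_true, y_pred) if t == p) / n_samples,
        "balanced_accuracy": sum(present) / len(present),
    }


metrics = classification_metrics([0, 0, 0, 1, 1, 2], [0, 0, 1, 1, 1, 2])
print({name: round(value, 4) for name, value in metrics.items()})
```

Macro recall and balanced accuracy coincide when every predicted label also occurs in `y_true`. They diverge when the probe predicts a class absent from the test labels, because that class then counts toward the macro average but not toward balanced accuracy.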
- Parameters:
train_dataloader : torch.utils.data.DataLoader
Training dataloader yielding batches in the form (X, y), which are embedded and then used to train the probe.
test_dataloader : torch.utils.data.DataLoader
Test dataloader yielding batches in the form (X, y), which are embedded and then used to evaluate the probe.
probe : sklearn.base.BaseEstimator
The scikit-learn classifier to be trained on the embedding.
probe_name : str or None, default=None
Name of the probe displayed when logging the results. It will appear as <probe_name>/<metric_name> for each metric. If None, <probe_class_name>/<metric_name> is displayed.
every_n_train_epochs : int or None, default=1
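Interval, in training epochs, at which to run the probing: with a value of n, the probing runs once every n training epochs. Disabled if None.
every_n_val_epochs : int or None, default=None
Interval, in validation epochs, at which to run the probing: with a value of n, the probing runs once every n validation epochs. Disabled if None.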
on_test_epoch_start : bool, default=False
Whether to run the probing at the start of the test epoch.
on_test_epoch_end : bool, default=False
Whether to run the probing at the end of the test epoch.
prog_bar : bool, default=True
Whether to display the metrics in the progress bar.
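The naming rule for `probe_name` and the on/off behaviour of `every_n_train_epochs` and `every_n_val_epochs` can be illustrated with two small hypothetical helpers. `metric_key` follows the `<probe_name>/<metric_name>` format documented above. `should_probe` assumes the probing fires whenever the count of completed epochs is a multiple of n; that is our reading of the parameter, not something this page confirms. `DummyProbe` is a hypothetical stand-in for a scikit-learn estimator.

```python
from typing import Optional


def metric_key(probe: object, metric_name: str, probe_name: Optional[str] = None) -> str:
    # <probe_name>/<metric_name>, falling back to the probe's class name when None.
    prefix = probe_name if probe_name is not None else type(probe).__name__
    return f"{prefix}/{metric_name}"


def should_probe(completed_epochs: int, every_n: Optional[int]) -> bool:
    # None disables probing; otherwise (assumption) fire on every n-th completed epoch.
    if every_n is None:
        return False
    return completed_epochs > 0 and completed_epochs % every_n == 0


class DummyProbe:
    """Hypothetical stand-in for a scikit-learn classifier."""


print(metric_key(DummyProbe(), "accuracy"))             # DummyProbe/accuracy
print(metric_key(DummyProbe(), "accuracy", "linear"))   # linear/accuracy
print([e for e in range(1, 11) if should_probe(e, 3)])  # [3, 6, 9]
print(should_probe(5, None))                            # False
```

With the defaults (`every_n_train_epochs=1`, `every_n_val_epochs=None`), this reading probes after every training epoch and never during validation.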
Examples