
Deep learning for NeuroImaging in Python.

Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.

class nidl.callbacks.model_probing.RidgeCVCallback(train_dataloader: DataLoader, test_dataloader: DataLoader, probe_name: str | None = None, alphas: tuple[float] = (0.1, 1.0, 10.0), cv: int = 5, scoring: str = 'r2', **kwargs)

Bases: ModelProbing

Perform Ridge regression on top of an embedding model.

Concretely this callback:

  1. Embeds the input data through the estimator.

  2. Performs k-fold cross-validation (with cv folds) to select the best L2 regularization strength alpha.

  3. Logs the main regression metrics, both per regressor and averaged across regressors, including:

    • mean absolute error

    • median absolute error

    • root mean squared error

    • mean squared error

    • R² score

    • Pearson correlation coefficient (pearsonr)

    • explained variance score

    If multiple regressors are given (multivariate regression, i.e. y has several target columns), metrics are computed for each regressor and then averaged.
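The metric step can be sketched as below. This is a minimal illustration, not nidl's implementation: it assumes the metrics are the scikit-learn and SciPy functions of the same names, and the function name regression_metrics and the dictionary keys ("mae", "r2_0", …) are hypothetical. Each metric is computed column by column and then averaged; per-target keys are only added when y has more than one column.

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import (
    explained_variance_score,
    mean_absolute_error,
    mean_squared_error,
    median_absolute_error,
    r2_score,
)


def regression_metrics(y_true, y_pred):
    """Hypothetical helper: per-target metrics plus their average across targets."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Treat a 1-D target as a single column so both cases share one code path.
    y_true = y_true.reshape(y_true.shape[0], -1)
    y_pred = y_pred.reshape(y_pred.shape[0], -1)
    n_targets = y_true.shape[1]
    metric_fns = {
        "mae": mean_absolute_error,
        "median_ae": median_absolute_error,
        "mse": mean_squared_error,
        "rmse": lambda t, p: np.sqrt(mean_squared_error(t, p)),
        "r2": r2_score,
        "pearsonr": lambda t, p: pearsonr(t, p)[0],  # keep r, drop the p-value
        "explained_variance": explained_variance_score,
    }
    results = {}
    for name, fn in metric_fns.items():
        # One score per target column (regressor).
        per_target = [float(fn(y_true[:, j], y_pred[:, j])) for j in range(n_targets)]
        if n_targets > 1:
            for j, score in enumerate(per_target):
                results[f"{name}_{j}"] = score
        # Average across targets (the single score when there is one target).
        results[name] = float(np.mean(per_target))
    return results
```

For a single target, regression_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]) gives an MAE of 1/3 and no per-target keys.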

Parameters:

train_dataloader : torch.utils.data.DataLoader

Training dataloader yielding batches of the form (X, y). The inputs X are embedded by the model, and the resulting embeddings, together with the labels y, are used to fit the ridge probe.

test_dataloader : torch.utils.data.DataLoader

Test dataloader yielding batches of the form (X, y). The inputs X are embedded by the model, and the resulting embeddings, together with the labels y, are used to evaluate the ridge probe.
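Both dataloaders are consumed the same way: every (X, y) batch is embedded and the results are stacked into one feature matrix and one label array. The sketch below assumes a hypothetical encoder callable that returns NumPy arrays; plain lists of NumPy batches stand in for a torch DataLoader, which is iterated the same way. With a torch model you would typically run the forward pass under torch.no_grad() and convert outputs with .detach().cpu().numpy().

```python
import numpy as np


def embed_batches(encoder, dataloader):
    """Hypothetical helper: embed every (X, y) batch and stack the results."""
    features, targets = [], []
    for X, y in dataloader:
        features.append(np.asarray(encoder(X)))  # (batch_size, n_features)
        targets.append(np.asarray(y))
    return np.concatenate(features, axis=0), np.concatenate(targets, axis=0)


# Usage: three batches of 8 samples, embedded by a random linear projection.
rng = np.random.default_rng(0)
batches = [(rng.normal(size=(8, 10)), rng.normal(size=8)) for _ in range(3)]
projection = rng.normal(size=(10, 4))
Z, y = embed_batches(lambda X: X @ projection, batches)
print(Z.shape, y.shape)  # (24, 4) (24,)
```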

probe_name : str or None, default=None

Name of the probe displayed when logging the results. Each metric is logged as <probe_name>/<metric_name>. If None, only <metric_name> is used.
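The naming rule amounts to an optional prefix. In this sketch, metric_key is a hypothetical helper and "r2" and "age" are illustrative names, not necessarily the exact strings nidl logs.

```python
def metric_key(metric_name, probe_name=None):
    """Hypothetical helper: '<probe_name>/<metric_name>', or '<metric_name>' alone."""
    return metric_name if probe_name is None else f"{probe_name}/{metric_name}"


print(metric_key("r2", "age"))  # age/r2
print(metric_key("r2"))  # r2
```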

alphas : tuple of floats, default=(0.1, 1.0, 10.0)

Alpha values to try during cross-validation. Each alpha is a candidate L2 regularization strength; larger values mean stronger regularization.

cv : int or cross-validation generator, default=5

Number of folds used to cross-validate the regularization strength alpha of the ridge regression, or a cross-validation generator that defines the splits.
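The scoring link below points to scikit-learn, which suggests the probe behaves like scikit-learn's RidgeCV. The sketch assumes that and has not been checked against nidl's source. It shows how alphas and cv interact: each candidate alpha is scored by k-fold cross-validation on the training embeddings, and the best one is stored in alpha_. The synthetic Z and y stand in for real embeddings and labels.

```python
import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import KFold

rng = np.random.default_rng(0)
Z = rng.normal(size=(100, 16))          # stand-in for training embeddings
w = rng.normal(size=16)
y = Z @ w + 0.1 * rng.normal(size=100)  # a single continuous target

# cv as an int: 5-fold CV over the candidate alphas.
probe = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=5, scoring="r2").fit(Z, y)
print(probe.alpha_)  # the alpha with the best mean CV score

# cv as a cross-validation generator, here shuffled K-fold.
splitter = KFold(n_splits=5, shuffle=True, random_state=0)
probe = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=splitter, scoring="r2").fit(Z, y)
print(probe.alpha_)
```

Passing a generator is useful when folds must be shuffled or grouped, for example to keep all scans of one subject in the same fold.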

scoring : str in {"r2", "neg_mean_absolute_error", "neg_mean_squared_error", …}, default="r2"

Which scoring function to use to cross-validate the alpha hyper-parameter. For a complete list of scoring options, check https://scikit-learn.org/1.4/modules/model_evaluation.html#scoring
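As in the previous sketch, this assumes scikit-learn RidgeCV semantics for scoring. Error-based scorers carry a neg_ prefix because cross-validation always maximizes the score, so the best neg_* score is the negated error and is at most 0. The data is synthetic.

```python
import numpy as np
from sklearn.linear_model import RidgeCV

rng = np.random.default_rng(1)
Z = rng.normal(size=(120, 8))                     # stand-in embeddings
y = Z @ rng.normal(size=8) + rng.normal(size=120)  # noisy target

for scoring in ("r2", "neg_mean_absolute_error", "neg_mean_squared_error"):
    probe = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=5, scoring=scoring).fit(Z, y)
    # best_score_ is the mean CV score of the selected alpha under this scorer.
    print(scoring, probe.alpha_, round(float(probe.best_score_), 3))
```

Different scorers can select different alphas. Choose the one that matches the metric you care about most on the test set.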

kwargs : dict

Additional keyword arguments to pass to the ModelProbing constructor (e.g. every_n_train_epochs, every_n_val_epochs, prog_bar, …).

fit(X, y)

Fit the probe on the embeddings and labels of the training data.

log_metrics(pl_module, y_pred, y_true)

Log the metrics given the predictions and the true labels.

predict(X)

Predict the targets of new data X with the fitted probe.
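The three methods fit together as follows: fit on training embeddings, predict on test embeddings, then log metrics through the Lightning module. RidgeProbeSketch is a hypothetical stand-in, not nidl's class. It assumes a RidgeCV backend and a pl_module exposing a Lightning-style log(name, value) method, and it logs only two illustrative metrics under hypothetical names. FakeModule is a test double that records logged values.

```python
import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.metrics import mean_absolute_error, r2_score


class RidgeProbeSketch:
    """Hypothetical stand-in mirroring fit / predict / log_metrics."""

    def __init__(self, alphas=(0.1, 1.0, 10.0), cv=5, scoring="r2", probe_name=None):
        self.probe_name = probe_name
        self.model = RidgeCV(alphas=alphas, cv=cv, scoring=scoring)

    def fit(self, X, y):
        # X: training embeddings, y: training labels.
        self.model.fit(X, y)
        return self

    def predict(self, X):
        return self.model.predict(X)

    def log_metrics(self, pl_module, y_pred, y_true):
        # Assumes pl_module.log(name, value), as on a Lightning module.
        prefix = "" if self.probe_name is None else f"{self.probe_name}/"
        pl_module.log(prefix + "mae", float(mean_absolute_error(y_true, y_pred)))
        pl_module.log(prefix + "r2", float(r2_score(y_true, y_pred)))


class FakeModule:
    """Test double that records what would be logged."""

    def __init__(self):
        self.logged = {}

    def log(self, name, value):
        self.logged[name] = value


rng = np.random.default_rng(0)
Z_train, Z_test = rng.normal(size=(80, 6)), rng.normal(size=(20, 6))
w = rng.normal(size=6)
y_train, y_test = Z_train @ w, Z_test @ w

probe = RidgeProbeSketch(probe_name="age").fit(Z_train, y_train)
module = FakeModule()
probe.log_metrics(module, probe.predict(Z_test), y_test)
print(sorted(module.logged))  # ['age/mae', 'age/r2']
```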


© 2025, nidl developers