Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.
- class nidl.callbacks.model_probing.LogisticRegressionCVCallback(train_dataloader: DataLoader, test_dataloader: DataLoader, probe_name: str | None = None, Cs: int | list[float] = 5, cv: int = 5, max_iter: int = 100, n_jobs: int | None = None, scoring: str = 'balanced_accuracy', linear_solver: str = 'lbfgs', **kwargs)[source]
Bases:
ModelProbing
Performs logistic regression on top of an embedding model.
Concretely this callback:
Embeds the input data through the torch model.
Performs n-fold CV to find the best L2 regularization strength.
Logs the main classification metrics for each class and averaged across classes (weighted by class support and unweighted):
precision
recall
f1-score
support
accuracy (global)
Please check the User Guide for more details on the reported classification metrics. A minimal sketch of these steps follows below.
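These steps map closely onto standard scikit-learn primitives. Below is a minimal, illustrative sketch of the equivalent standalone procedure, assuming the torch model maps a batch X to a 2D feature array and each dataloader batch is an (X, y) pair; the probe_embeddings helper is an illustration, not part of the nidl API:

    import numpy as np
    import torch
    from sklearn.linear_model import LogisticRegressionCV
    from sklearn.metrics import classification_report

    def probe_embeddings(model, train_loader, test_loader, Cs=5, cv=5):
        """Sketch of the callback's logic: embed, cross-validate C, report."""
        def embed(loader):
            feats, labels = [], []
            model.eval()
            with torch.no_grad():
                for X, y in loader:
                    feats.append(model(X).cpu().numpy())
                    labels.append(np.asarray(y))
            return np.concatenate(feats), np.concatenate(labels)

        X_train, y_train = embed(train_loader)
        X_test, y_test = embed(test_loader)

        # n-fold CV over the C grid, matching the callback's defaults.
        clf = LogisticRegressionCV(Cs=Cs, cv=cv, max_iter=100,
                                   scoring="balanced_accuracy", solver="lbfgs")
        clf.fit(X_train, y_train)

        # Per-class and averaged precision/recall/f1/support, plus accuracy.
        print(classification_report(y_test, clf.predict(X_test)))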
- Parameters:
train_dataloader : torch.utils.data.DataLoader
Training dataloader yielding batches in the form (X, y); the inputs are embedded and used to fit the logistic regression probe.
test_dataloader : torch.utils.data.DataLoader
Test dataloader yielding batches in the form (X, y); the inputs are embedded and used to evaluate the logistic regression probe.
probe_name : str or None, default=None
Name of the probe displayed when logging the results. It will appear as <probe_name>/<metric_name> for each metric. If None, only <metric_name> is displayed.
Cs : int or list of floats, default=5
Each of the values in Cs describes the inverse of regularization strength. If Cs is an int, then a grid of Cs values is chosen on a logarithmic scale between 1e-4 and 1e4. Like in support vector machines, smaller values specify stronger regularization.
cv : int or cross-validation generator, default=5
Number of folds used to cross-validate the C regularization strength in the LogisticRegression.
max_iter : int, default=100
Maximum number of iterations taken for the solver to converge.
n_jobs : int or None, default=None
Number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.
scoring : str in {“accuracy”, “balanced_accuracy”, “f1”, …}, default=”balanced_accuracy”
Which scoring function to use to cross-validate the C hyper-parameter. For a complete list of scoring options, check https://scikit-learn.org/1.4/modules/model_evaluation.html#scoring
linear_solver : str in {‘lbfgs’, ‘liblinear’, ‘newton-cg’, ‘newton-cholesky’, ‘sag’, ‘saga’}, default=’lbfgs’
Algorithm to use in the optimization problem.
kwargs : dict
Additional keyword arguments to pass to the ModelProbing constructor (e.g. every_n_train_epochs, every_n_val_epochs, prog_bar, …).
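As a usage sketch, assuming the standard PyTorch Lightning setup suggested by the ModelProbing kwargs (every_n_train_epochs, prog_bar); probe_train_set and probe_test_set are hypothetical datasets of (X, y) pairs:

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from nidl.callbacks.model_probing import LogisticRegressionCVCallback

    probe = LogisticRegressionCVCallback(
        train_dataloader=DataLoader(probe_train_set, batch_size=32),
        test_dataloader=DataLoader(probe_test_set, batch_size=32),
        probe_name="linear_probe",   # metrics logged as linear_probe/<metric_name>
        Cs=10,                       # 10 C values on a log scale in [1e-4, 1e4]
        cv=5,
        scoring="balanced_accuracy",
        every_n_train_epochs=5,      # forwarded to ModelProbing via **kwargs
    )
    trainer = pl.Trainer(max_epochs=100, callbacks=[probe])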
- fit(X, y)[source]
Fit the model according to the given training data and log the classification metrics.
- Parameters:
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Training vector, where n_samples is the number of samples and n_features is the number of features.
y : array-like of shape (n_samples,)
Target vector relative to X.
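fit follows the scikit-learn convention, so a direct call on pre-computed embeddings looks like the sketch below (synthetic data; metric logging presumably requires the callback to be attached to a running trainer with a logger):

    import numpy as np

    # Synthetic embeddings: 200 samples, 64 features, 3 classes.
    X = np.random.randn(200, 64)
    y = np.random.randint(0, 3, size=200)
    probe.fit(X, y)  # fits the probe on (X, y) and logs the classification metrics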