Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.
- class nidl.callbacks.model_probing.KNeighborsRegressorCVCallback(train_dataloader: DataLoader, test_dataloader: DataLoader, probe_name: str | None = None, n_neighbors: tuple[int] = (2, 5, 10), cv: int = 5, n_jobs: int | None = None, scoring: str = 'r2', **kwargs)
Bases: ModelProbing
Perform KNN regression on top of an embedding model (see the usage sketch after the parameter list).
Concretely, this callback:
- Embeds the input data through the torch model.
- Performs n-fold cross-validation to find the best n_neighbors.
- Logs the main regression metrics, by regressor and averaged, including:
  - mean absolute error
  - median absolute error
  - root mean squared error
  - mean squared error
  - R² score
  - Pearson correlation coefficient
  - explained variance score
- Parameters:
train_dataloader : torch.utils.data.DataLoader
Training dataloader yielding batches in the form (X, y) for further embedding and training of the KNN probe.
test_dataloader : torch.utils.data.DataLoader
Test dataloader yielding batches in the form (X, y) for further embedding and test of the KNN probe.
probe_name : str or None, default=None
Name of the probe displayed when logging the results. It will appear as <probe_name>/<metric_name> for each metric. If None, only <metric_name> is displayed.
n_neighbors : tuple of int, default=(2, 5, 10)
Candidate n_neighbors values to try in cross-validation. Each value is the number of neighbors used by the KNN on the training set.
cv : int or cross-validation generator, default=5
How many folds to use for cross-validating the n_neighbors hyper-parameter of the KNN regression.
n_jobs : int or None, default=None
Number of jobs to run in parallel. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors.
scoring : str in {"r2", "neg_mean_absolute_error", "neg_mean_squared_error", ...}, default="r2"
Which scoring function to use to cross-validate the n_neighbors hyper-parameter. For a complete list of scoring options, check https://scikit-learn.org/1.4/modules/model_evaluation.html#scoring
kwargs : dict
Additional keyword arguments to pass to the ModelProbing constructor (e.g. every_n_train_epochs, every_n_val_epochs, prog_bar, …).
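Usage sketch
The snippet below is a minimal, hypothetical example of wiring the callback into a PyTorch Lightning Trainer. The toy dataset, the LightningModule placeholder and the trainer configuration are assumptions for illustration only; the exact integration depends on the ModelProbing base class and on your own model.

```python
# Hypothetical usage sketch: attach the KNN probing callback to a Trainer.
# The data, model and trainer setup below are illustrative placeholders.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl  # or lightning.pytorch, depending on your install

from nidl.callbacks.model_probing import KNeighborsRegressorCVCallback

# Toy (X, y) datasets standing in for neuroimaging features and a continuous target.
train_ds = TensorDataset(torch.randn(200, 64), torch.randn(200))
test_ds = TensorDataset(torch.randn(50, 64), torch.randn(50))

knn_probe = KNeighborsRegressorCVCallback(
    train_dataloader=DataLoader(train_ds, batch_size=32),
    test_dataloader=DataLoader(test_ds, batch_size=32),
    probe_name="knn_probe",      # metrics are logged as knn_probe/<metric_name>
    n_neighbors=(2, 5, 10),      # candidate values cross-validated over `cv` folds
    cv=5,
    scoring="r2",
)

trainer = pl.Trainer(max_epochs=10, callbacks=[knn_probe])
# trainer.fit(model, ...)  # `model` is your embedding LightningModule (not shown here)
```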
- fit(X, y)
Fit the k-nearest neighbors regressor from the training dataset and log the regression metrics.
- Parameters:
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Training data.
y : {array-like, sparse matrix} of shape (n_samples,) or (n_samples, n_outputs)
Target values.
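For intuition, the cross-validation over n_neighbors performed by fit can be sketched as a stand-alone scikit-learn computation. This is an illustrative equivalent under assumed random placeholder data, not the callback's actual implementation, and it only reports two of the logged metrics.

```python
# Illustrative equivalent of the probe's CV: select n_neighbors with
# GridSearchCV, then score the best KNN regressor on held-out data.
# Not the callback's implementation; data are random placeholders.
import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_absolute_error, r2_score

rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(200, 64)), rng.normal(size=200)
X_test, y_test = rng.normal(size=(50, 64)), rng.normal(size=50)

search = GridSearchCV(
    KNeighborsRegressor(),
    param_grid={"n_neighbors": [2, 5, 10]},
    cv=5,
    scoring="r2",
)
search.fit(X_train, y_train)

y_pred = search.best_estimator_.predict(X_test)
print("best n_neighbors:", search.best_params_["n_neighbors"])
print("test MAE:", mean_absolute_error(y_test, y_pred))
print("test R2 :", r2_score(y_test, y_pred))
```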