Deep learning for NeuroImaging in Python.

Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.

class nidl.estimators.base.BaseEstimator(callbacks: list[Callback] | Callback | None = None, check_val_every_n_epoch: int | None = 1, val_check_interval: int | float | None = None, max_epochs: int | None = None, min_epochs: int | None = None, max_steps: int = -1, min_steps: int | None = None, enable_checkpointing: bool | None = None, enable_progress_bar: bool | None = None, enable_model_summary: bool | None = None, accelerator: str | Accelerator = 'auto', strategy: str | Strategy = 'auto', devices: list[int] | str | int = 'auto', num_nodes: int = 1, precision: Literal[64, 32, 16] | Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'] | Literal['64', '32', '16', 'bf16'] | None = None, ignore: Sequence[str] | None = None, random_state: int | None = None, **kwargs)[source]

Bases: LightningModule

Base class for all estimators in the NIDL framework designed for scalability.

Inherits from PyTorch Lightning’s LightningModule. This class provides a common interface for training, validation, and prediction/transformation in a distributed setting (multi-node, multi-GPU), building on the capabilities of the Lightning Trainer.

Basically, this class defines:

  • a fit method.

  • a transform or predict method if the child class inherits from a valid Mixin class.
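
For illustration, here is a minimal sketch of a child class; MyEstimator, its toy model, and the train_dataloader variable are hypothetical, while configure_optimizers comes from the underlying LightningModule API:

>>> import torch
>>> from torch import nn
>>> from nidl.estimators.base import BaseEstimator
>>>
>>> class MyEstimator(BaseEstimator):
>>>     def __init__(self, **kwargs):
>>>         super().__init__(**kwargs)
>>>         self.model = nn.Linear(10, 1)  # toy model for illustration
>>>
>>>     def training_step(self, batch, batch_idx):
>>>         x, y = batch
>>>         loss = nn.functional.mse_loss(self.model(x), y)
>>>         self.log("train_loss", loss)
>>>         return loss
>>>
>>>     def configure_optimizers(self):
>>>         return torch.optim.Adam(self.parameters(), lr=1e-3)
>>>
>>> estimator = MyEstimator(max_epochs=5, accelerator="auto")
>>> estimator.fit(train_dataloader)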

Parameters:

callbacks : list of Callback or Callback, default=None

add a callback or list of callbacks.

check_val_every_n_epoch : int, default=1

perform a validation loop after every N training epochs. If None, validation will be done solely based on the number of training batches, requiring val_check_interval to be an integer value.

val_check_interval : int or float, default=None

how often to check the validation set. Pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or during iteration-based training. Default: 1.0.

max_epochs : int, default=None

stop training once this number of epochs is reached. If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1.

min_epochs : int, default=None

force training for at least this many epochs. Disabled by default.

max_steps : int, default=-1

stop training after this number of steps. If max_steps = -1 and max_epochs = None, will default to max_epochs = 1000. To enable infinite training, set max_epochs to -1.

min_steps : int, default=None

force training for at least this number of steps. Disabled by default.

enable_checkpointing : bool, default=None

if True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in trainer callbacks. Default: True.

enable_progress_bar : bool, default=None

whether to enable the progress bar by default. Default: True.

enable_model_summary : bool, default=None

whether to enable model summarization by default. Default: True.

accelerator : str or Accelerator, default=”auto”

supports passing different accelerator types (“cpu”, “gpu”, “tpu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.

strategy : str or Strategy, default=”auto”

supports different training strategies with aliases as well as custom strategies.

devices : list of int, str or int, default=”auto”

the devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator.

num_nodes : int, default=1

number of GPU nodes for distributed training.

precision : int or str, default=”32-true”

double precision (64, ‘64’ or ‘64-true’), full precision (32, ‘32’ or ‘32-true’), 16bit mixed precision (16, ‘16’, ‘16-mixed’) or bfloat16 mixed precision (‘bf16’, ‘bf16-mixed’). Can be used on CPU, GPU, TPUs, or HPUs.

ignore : list of str, default=None

attributes of the instance (e.g., nn.Module objects) to ignore.

random_state : int, default=None

when shuffling is used, random_state affects the ordering of the indices, which controls the randomness of each batch. Pass an int for reproducible output across multiple function calls.

kwargs : dict

extra parameters passed to the Lightning Trainer.

Notes

Callbacks can help you to tune, monitor, or debug an estimator. For instance, you can check the type of the input batches using the BatchTypingCallback callback.
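
As a hedged illustration, reusing the hypothetical MyEstimator sketched above (the import path of BatchTypingCallback is an assumption and may differ):

>>> from nidl.callbacks import BatchTypingCallback  # import path assumed
>>> estimator = MyEstimator(callbacks=[BatchTypingCallback()], max_epochs=10)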

Attributes

fitted

a boolean that is True if the estimator has been fitted, and is False otherwise.

hparams

contains the estimator hyperparameters.

trainer

the current lightning trainer.

trainer_params

a dictionary with the trainer parameters.

Methods

BaseEstimator.fit()

the fit method.

BaseEstimator.transform()

the transform method for transformer.

BaseEstimator.predict()

the predict method for regression, classification and clustering.

BaseEstimator.training_step()

compute and return the training loss and some additional metrics. TO BE IMPLEMENTED.

BaseEstimator.validation_step()

compute anything of interest like accuracy on a single batch of data from the validation set. TO BE IMPLEMENTED.

BaseEstimator.transform_step()

transform new data. TO BE IMPLEMENTED.

BaseEstimator.predict_step()

make some predictions on new data. TO BE IMPLEMENTED.

BaseEstimator.log()

log a key, value pair.

BaseEstimator.log_dict()

log a dictionary of values at once.

fit(train_dataloader: DataLoader, val_dataloader: DataLoader | None = None)[source]

The fit method.

In the child class you will need to define:

  • a training_step method for defining the training instructions at each step.

  • a validation_step method for defining the validation instructions at each step.

Parameters:

train_dataloader : torch DataLoader

training samples.

val_dataloader : torch DataLoader, default=None

validation samples.

Returns:

self : object

fitted estimator.
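
A usage sketch, assuming hypothetical train_set and val_set torch Datasets and the estimator from the class-level sketch above:

>>> from torch.utils.data import DataLoader
>>> train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
>>> val_loader = DataLoader(val_set, batch_size=32)
>>> estimator = estimator.fit(train_loader, val_dataloader=val_loader)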

log(name: str, value: Metric | Tensor | int | float, prog_bar: bool = False, logger: bool | None = None, on_step: bool | None = None, on_epoch: bool | None = None, reduce_fx: str | Callable = 'mean', enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, batch_size: int | None = None, metric_attribute: str | None = None, rank_zero_only: bool = False) None[source]

Log a key, value pair.

Parameters:

name : str

key to log. Must be identical across all processes if using DDP or any other distributed strategy.

value : object

value to log. Can be a float, Tensor, or a Metric.

prog_bar : bool, default=False

if True logs to the progress bar.

logger : bool, default=None

if True logs to the logger.

on_step : bool, default=None

if True logs at this step. The default value is determined by the hook.

on_epoch : bool, default=None

if True logs epoch accumulated metrics. The default value is determined by the hook.

reduce_fx : str or callable, default=’mean’

reduction function over step values for end of epoch. torch.mean() by default.

enable_graph : bool, default=False

if True, will not auto detach the graph.

sync_dist : bool, default=False

if True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.

sync_dist_group : object, default=None

the DDP group to sync across.

add_dataloader_idx : bool, default=True

if True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, user needs to give unique names for each dataloader to not mix the values.

batch_size : int, default=None

current batch_size. This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.

metric_attribute : str, default=None

to restore the metric state, Lightning requires the reference of the torchmetrics.Metric in your model. This is found automatically if it is a model attribute.

rank_zero_only : bool, default=False

tells Lightning if you are calling self.log from every process (default) or only from rank 0. If True, you won’t be able to use this metric as a monitor in callbacks (e.g., early stopping). Warning: Improper use can lead to deadlocks!

Examples

>>> self.log('train_loss', loss)
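
A slightly fuller sketch combining some of the options above (the val_acc name and acc value are illustrative):

>>> self.log('val_acc', acc, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
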
log_dict(dictionary: Mapping[str, Metric | Tensor | int | float] | MetricCollection, prog_bar: bool = False, logger: bool | None = None, on_step: bool | None = None, on_epoch: bool | None = None, reduce_fx: str | Callable = 'mean', enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, batch_size: int | None = None, rank_zero_only: bool = False) None[source]

Log a dictionary of values at once.

Parameters:

dictionary : dict

key value pairs. Keys must be identical across all processes if using DDP or any other distributed strategy. The values can be a float, Tensor, Metric, or MetricCollection.

prog_bar : bool, default=False

if True logs to the progress bar.

logger : bool, default=None

if True logs to the logger.

on_step : bool, default=None

None auto-logs for training_step but not validation/test_step. The default value is determined by the hook.

on_epoch : bool, default=None

None auto-logs for val/test step but not training_step. The default value is determined by the hook.

reduce_fx : str or callable, default=’mean’

reduction function over step values for end of epoch. torch.mean() by default.

enable_graph : bool, default=False

if True, will not auto detach the graph.

sync_dist : bool, default=False

if True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.

sync_dist_group : object, default=None

the DDP group to sync across.

add_dataloader_idx : bool, default=True

if True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, user needs to give unique names for each dataloader to not mix the values.

batch_size : int, default=None

current batch_size. This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.

rank_zero_only : bool, default=False

tells Lightning if you are calling self.log from every process (default) or only from rank 0. If True, you won’t be able to use this metric as a monitor in callbacks (e.g., early stopping). Warning: Improper use can lead to deadlocks!

Examples

>>> values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
>>> self.log_dict(values)
predict(test_dataloader: DataLoader) Any[source]

The predict method for regression, classification and clustering.

In the child class you will need to define:

  • a predict_step method for defining the predict instructions at each step.

Parameters:

test_dataloader : torch DataLoader

testing samples.

Returns:

out : torch Tensor

returns predicted samples.
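
A usage sketch, assuming a hypothetical test_set Dataset and a fitted child class that implements predict_step:

>>> from torch.utils.data import DataLoader
>>> test_loader = DataLoader(test_set, batch_size=32)
>>> predictions = estimator.predict(test_loader)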

predict_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Any[source]

Step function called during predict(). By default, it calls forward(). Override to add any processing logic.

predict_step() is used to scale inference over multiple devices.

To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or a database after each batch or at the end of each epoch.

The BasePredictionWriter should be used with a spawn-based accelerator, e.g., strategy="ddp_spawn" or training on 8 TPU cores with accelerator="tpu", devices=8, as predictions won’t be returned otherwise.

Parameters:

batch : iterable, normally a DataLoader

the current data.

batch_idx : int

the index of this batch.

dataloader_idx : int, default=0

the index of the dataloader that produced this batch (only if multiple dataloaders are used).

Returns:

out : Any

the predicted output.
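
A sketch of an override in a child class; self.model and the (inputs, targets) batch layout are assumptions, not part of this API:

>>> def predict_step(self, batch, batch_idx, dataloader_idx=0):
>>>     x, _ = batch                 # assumed (inputs, targets) batch layout
>>>     logits = self.model(x)       # self.model is hypothetical
>>>     return logits.argmax(dim=1)  # per-sample class predictions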

training_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Tensor | Mapping[str, Any] | None[source]

Here you compute and return the training loss and some additional metrics, e.g., for the progress bar or logger.

Parameters:

batch : iterable, normally a DataLoader

the current data.

batch_idx : int

the index of this batch.

dataloader_idx : int, default=0

the index of the dataloader that produced this batch (only if multiple dataloaders are used).

Returns:

loss : STEP_OUTPUT

the computed loss:

  • Tensor - the loss tensor.

  • dict - a dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - in automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

To use multiple optimizers, you can switch to manual optimization and control their stepping yourself, as shown in the Examples section below.

Notes

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

Examples

>>> def __init__(self):
>>>     super().__init__()
>>>     self.automatic_optimization = False
>>>
>>>
>>> # Multiple optimizers (e.g.: GANs)
>>> def training_step(self, batch, batch_idx):
>>>     opt1, opt2 = self.optimizers()
>>>
>>>     # do training_step with encoder
>>>     loss1 = ...
>>>     opt1.zero_grad()
>>>     self.manual_backward(loss1)
>>>     opt1.step()
>>>
>>>     # do training_step with decoder
>>>     loss2 = ...
>>>     opt2.zero_grad()
>>>     self.manual_backward(loss2)
>>>     opt2.step()
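
For the common automatic-optimization case, a minimal sketch of the dict return form described above (self.model, the batch layout, and the F alias for torch.nn.functional are assumptions):

>>> import torch.nn.functional as F
>>>
>>> def training_step(self, batch, batch_idx):
>>>     x, y = batch                         # assumed (inputs, targets) batch layout
>>>     loss = F.mse_loss(self.model(x), y)  # self.model is hypothetical
>>>     self.log("train_loss", loss)
>>>     return {"loss": loss}
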
transform(test_dataloader: DataLoader) Any[source]

The transform method for transformer.

In the child class you will need to define:

  • a transform_step method for defining the transform instructions at each step.

Parameters:

test_dataloader : torch DataLoader

testing samples.

Returns:

out : torch Tensor

returns transformed samples.
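
A usage sketch mirroring predict() above, with the same hypothetical test_loader:

>>> embeddings = estimator.transform(test_loader)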

abstract transform_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Any[source]

Define a transform step.

Share the same API as BaseEstimator.predict_step().
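
A sketch of a concrete implementation in a child class; self.encoder and the batch layout are assumptions:

>>> def transform_step(self, batch, batch_idx, dataloader_idx=0):
>>>     x, _ = batch            # assumed (inputs, targets) batch layout
>>>     return self.encoder(x)  # hypothetical encoder returning embeddings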

validation_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Tensor | Mapping[str, Any] | None[source]

Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.

Parameters:

batch : iterable, normally a DataLoader

the current data.

batch_idx : int

the index of this batch.

dataloader_idx : int, default=0

the index of the dataloader that produced this batch (only if multiple dataloaders are used).

Returns:

loss : STEP_OUTPUT

the computed loss:

  • Tensor - the loss tensor.

  • dict - a dictionary which can include any keys, but must include the key 'loss'.

  • None - skip to the next batch.

Notes

When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
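
A sketch of an override logging a validation metric; self.model and the batch layout are assumptions:

>>> def validation_step(self, batch, batch_idx):
>>>     x, y = batch                         # assumed (inputs, targets) batch layout
>>>     preds = self.model(x).argmax(dim=1)  # self.model is hypothetical
>>>     acc = (preds == y).float().mean()
>>>     self.log("val_acc", acc, on_epoch=True, prog_bar=True)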
