Deep learning for NeuroImaging in Python.
Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.
- class nidl.estimators.base.BaseEstimator(callbacks: list[Callback] | Callback | None = None, check_val_every_n_epoch: int | None = 1, val_check_interval: int | float | None = None, max_epochs: int | None = None, min_epochs: int | None = None, max_steps: int = -1, min_steps: int | None = None, enable_checkpointing: bool | None = None, enable_progress_bar: bool | None = None, enable_model_summary: bool | None = None, accelerator: str | Accelerator = 'auto', strategy: str | Strategy = 'auto', devices: list[int] | str | int = 'auto', num_nodes: int = 1, precision: Literal[64, 32, 16] | Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'] | Literal['64', '32', '16', 'bf16'] | None = None, ignore: Sequence[str] | None = None, random_state: int | None = None, **kwargs)[source]¶
Bases: LightningModule
Base class for all estimators in the NIDL framework designed for scalability.
Inherits from PyTorch Lightning's LightningModule. This class provides a common interface for training, validation, and prediction/transformation in a distributed setting (multi-node, multi-GPU), building on the Lightning Trainer's capabilities.
Basically, this class defines:
a fit method.
a transform or predict method if the child class inherits from a valid Mixin class.
- Parameters:
callbacks : list of Callback or Callback, default=None
add a callback or list of callbacks.
check_val_every_n_epoch : int, default=1
perform a validation loop after every N training epochs. If None, validation will be done solely based on the number of training batches, requiring val_check_interval to be an integer value.
val_check_interval : int or float, default=None
how often to check the validation set. Pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or during iteration-based training. Default: 1.0.
max_epochs : int, default=None
stop training once this number of epochs is reached. If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1.
min_epochs : int, default=None
force training for at least this many epochs. Disabled by default.
max_steps : int, default=-1
stop training after this number of steps. If max_steps = -1 and max_epochs = None, will default to max_epochs = 1000. To enable infinite training, set max_epochs to -1.
min_steps : int, default=None
force training for at least this number of steps. Disabled by default.
enable_checkpointing : bool, default=None
if True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in the trainer callbacks. Default: True.
enable_progress_bar : bool, default=None
whether to enable the progress bar by default. Default: True.
enable_model_summary : bool, default=None
whether to enable model summarization by default. Default: True.
accelerator : str or Accelerator, default="auto"
supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto") as well as custom accelerator instances.
strategy : str or Strategy, default="auto"
supports different training strategies with aliases as well as custom strategies.
devices : list of int, str or int, default="auto"
the devices to use. Can be set to a positive number (int or str), a sequence of device indices (list or str), the value -1 to indicate all available devices should be used, or "auto" for automatic selection based on the chosen accelerator.
num_nodes : int, default=1
number of GPU nodes for distributed training.
precision : int or str, default="32-true"
double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'), 16-bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed'). Can be used on CPU, GPU, TPUs, or HPUs.
ignore : list of str, default=None
instance attributes to ignore, e.g. nn.Module attributes.
random_state : int, default=None
when shuffling is used, random_state affects the ordering of the indices, which controls the randomness of each batch. Pass an int for reproducible output across multiple function calls.
kwargs : dict
extra parameters passed to the Lightning Trainer.
Notes
Callbacks can help you tune, monitor or debug an estimator. For instance, you can check the type of the input batches using the BatchTypingCallback callback.
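For illustration, a child estimator (MyEstimator below is a hypothetical subclass, not part of nidl) is typically configured through these trainer parameters at construction time:
>>> estimator = MyEstimator(
>>>     max_epochs=50,
>>>     accelerator="gpu",
>>>     devices=2,
>>>     precision="16-mixed",
>>>     check_val_every_n_epoch=1,
>>>     random_state=42)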
Attributes
fitted
a boolean that is True if the estimator has been fitted, and is False otherwise.
hparams
contains the estimator hyperparameters.
trainer
the current lightning trainer.
trainer_params
a dictionary with the trainer parameters.
Methods
fit
the fit method.
transform
the transform method for transformers.
predict
the predict method for regression, classification and clustering.
training_step
compute and return the training loss and some additional metrics. TO BE IMPLEMENTED.
validation_step
compute anything of interest like accuracy on a single batch of data from the validation set. TO BE IMPLEMENTED.
transform_step
transform new data. TO BE IMPLEMENTED.
predict_step
make some predictions on new data. TO BE IMPLEMENTED.
log
log a key, value pair.
log_dict
log a dictionary of values at once.
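As a rough sketch (the MyRegressor class and its layer are hypothetical, not part of nidl), a child estimator implements the step methods and inherits fit from this base class:
>>> from torch import nn
>>> from nidl.estimators.base import BaseEstimator
>>>
>>> class MyRegressor(BaseEstimator):
>>>     def __init__(self, **kwargs):
>>>         super().__init__(**kwargs)
>>>         self.model = nn.Linear(10, 1)
>>>     def training_step(self, batch, batch_idx):
>>>         x, y = batch
>>>         loss = nn.functional.mse_loss(self.model(x), y)
>>>         self.log("train_loss", loss)
>>>         return loss
>>>     def validation_step(self, batch, batch_idx):
>>>         x, y = batch
>>>         self.log("val_loss", nn.functional.mse_loss(self.model(x), y))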
- fit(train_dataloader: DataLoader, val_dataloader: DataLoader | None = None)[source]¶
The fit method.
In the child class you will need to define:
a training_step method for defining the training instructions at each step.
a validation_step method for defining the validation instructions at each step.
- Parameters:
train_dataloader : torch DataLoader
training samples.
val_dataloader : torch DataLoader, default=None
validation samples.
- Returns:
self : object
fitted estimator.
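A typical call (a sketch assuming the hypothetical MyRegressor above and existing torch Datasets train_dataset and val_dataset) looks like:
>>> from torch.utils.data import DataLoader
>>> train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
>>> val_loader = DataLoader(val_dataset, batch_size=32)
>>> estimator = MyRegressor(max_epochs=10, accelerator="auto")
>>> estimator.fit(train_loader, val_dataloader=val_loader)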
- log(name: str, value: Metric | Tensor | int | float, prog_bar: bool = False, logger: bool | None = None, on_step: bool | None = None, on_epoch: bool | None = None, reduce_fx: str | Callable = 'mean', enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, batch_size: int | None = None, metric_attribute: str | None = None, rank_zero_only: bool = False) None [source]¶
Log a key, value pair.
- Parameters:
name : str
key to log. Must be identical across all processes if using DDP or any other distributed strategy.
value : object
value to log. Can be a float, Tensor, or a Metric.
prog_bar : bool, default=False
if True logs to the progress bar.
logger : bool, default=None
if True logs to the logger.
on_step : bool, default=None
if True logs at this step. The default value is determined by the hook.
on_epoch : bool, default=None
if True logs epoch accumulated metrics. The default value is determined by the hook.
reduce_fx : str or callable, default='mean'
reduction function over step values for end of epoch. torch.mean() by default.
enable_graph : bool, default=False
if True, will not auto detach the graph.
sync_dist : bool, default=False
if True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.
sync_dist_group : object, default=None
the DDP group to sync across.
add_dataloader_idx : bool, default=True
if True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, the user needs to give unique names for each dataloader to avoid mixing values.
batch_size : int, default=None
current batch size. This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.
metric_attribute : str, default=None
to restore the metric state, Lightning requires the reference of the torchmetrics.Metric in your model. This is found automatically if it is a model attribute.
rank_zero_only : bool, default=False
tells Lightning if you are calling self.log from every process (default) or only from rank 0. If True, you won't be able to use this metric as a monitor in callbacks (e.g., early stopping). Warning: improper use can lead to deadlocks!
Examples
>>> self.log('train_loss', loss)
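The options above can be combined; for instance, an epoch-accumulated metric synchronized across devices (the metric name is arbitrary):
>>> self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)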
- log_dict(dictionary: Mapping[str, Metric | Tensor | int | float] | MetricCollection, prog_bar: bool = False, logger: bool | None = None, on_step: bool | None = None, on_epoch: bool | None = None, reduce_fx: str | Callable = 'mean', enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Any | None = None, add_dataloader_idx: bool = True, batch_size: int | None = None, rank_zero_only: bool = False) None [source]¶
Log a dictionary of values at once.
- Parameters:
dictionary : dict
key value pairs. Keys must be identical across all processes if using DDP or any other distributed strategy. The values can be a float, Tensor, Metric, or MetricCollection.
prog_bar : bool, default=False
if True logs to the progress bar.
logger : bool, default=None
if True logs to the logger.
on_step : bool, default=None
None auto-logs for training_step but not validation/test_step. The default value is determined by the hook.
on_epoch : bool, default=None
None auto-logs for val/test step but not training_step. The default value is determined by the hook.
reduce_fx : str or callable, default='mean'
reduction function over step values for end of epoch. torch.mean() by default.
enable_graph : bool, default=False
if True, will not auto detach the graph.
sync_dist : bool, default=False
if True, reduces the metric across devices. Use with care as this may lead to a significant communication overhead.
sync_dist_group : object, default=None
the DDP group to sync across.
add_dataloader_idx : bool, default=True
if True, appends the index of the current dataloader to the name (when using multiple dataloaders). If False, the user needs to give unique names for each dataloader to avoid mixing values.
batch_size : int, default=None
current batch size. This will be directly inferred from the loaded batch, but for some data structures you might need to explicitly provide it.
rank_zero_only : bool, default=False
tells Lightning if you are calling self.log from every process (default) or only from rank 0. If True, you won't be able to use this metric as a monitor in callbacks (e.g., early stopping). Warning: improper use can lead to deadlocks!
Examples
>>> values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
>>> self.log_dict(values)
- predict(test_dataloader: DataLoader) Any [source]¶
The predict method for regression, classification and clustering.
In the child class you will need to define:
a predict_step method for defining the predict instructions at each step.
- Parameters:
test_dataloader : torch DataLoader
testing samples.
- Returns:
out : torch Tensor
returns predicted samples.
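For example, assuming a fitted child estimator that inherits from the relevant prediction Mixin (hypothetical here) and a torch DataLoader over the test set:
>>> from torch.utils.data import DataLoader
>>> test_loader = DataLoader(test_dataset, batch_size=32)
>>> predictions = estimator.predict(test_loader)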
- predict_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Any [source]¶
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multi-devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.
The BasePredictionWriter should be used while using a spawn based accelerator. This happens for the training strategy strategy="ddp_spawn" or training on 8 TPU cores with accelerator="tpu", devices=8, as predictions won't be returned.
- Parameters:
batch : iterable, normally a DataLoader
the current data.
batch_idx : int
the index of this batch.
dataloader_idx : int, default=0
the index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
out : Any
the predicted output.
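A possible override (a sketch assuming the child estimator stores its network in a self.model attribute and batches are (x, y) pairs, both assumptions) applies the forward pass and post-processes the output:
>>> def predict_step(self, batch, batch_idx, dataloader_idx=0):
>>>     x, _ = batch
>>>     logits = self.model(x)
>>>     return logits.argmax(dim=1)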
- training_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Tensor | Mapping[str, Any] | None [source]¶
Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.
- Parameters:
batch : iterable, normally a DataLoader
the current data.
batch_idx : int
the index of this batch.
dataloader_idx : int, default=0
the index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
loss : STEP_OUTPUT
the computed loss:
Tensor - the loss tensor.
dict - a dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
None - in automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping, as shown in the Examples below.
Notes
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
Examples
>>> def __init__(self):
>>>     super().__init__()
>>>     self.automatic_optimization = False
>>>
>>> # Multiple optimizers (e.g.: GANs)
>>> def training_step(self, batch, batch_idx):
>>>     opt1, opt2 = self.optimizers()
>>>
>>>     # do training_step with encoder
>>>     ...
>>>     opt1.step()
>>>     # do training_step with decoder
>>>     ...
>>>     opt2.step()
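For the common single-optimizer case with automatic optimization, a training_step usually just computes, logs and returns the loss (a sketch assuming from torch import nn, a self.model attribute and (x, y) batches):
>>> def training_step(self, batch, batch_idx):
>>>     x, y = batch
>>>     loss = nn.functional.cross_entropy(self.model(x), y)
>>>     self.log("train_loss", loss)
>>>     return loss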
- transform(test_dataloader: DataLoader) Any [source]¶
The transform method for transformers.
In the child class you will need to define:
a transform_step method for defining the transform instructions at each step.
- Parameters:
test_dataloader : torch DataLoader
testing samples.
- Returns:
out : torch Tensor
returns transformed samples.
- abstract transform_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Any [source]¶
Define a transform step.
Shares the same API as BaseEstimator.predict_step().
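A minimal implementation (a sketch assuming the child estimator wraps an encoder network in a self.encoder attribute, an assumption) returns embeddings for each batch:
>>> def transform_step(self, batch, batch_idx, dataloader_idx=0):
>>>     x, _ = batch
>>>     return self.encoder(x)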
- validation_step(batch: Any, batch_idx: int, dataloader_idx: int | None = 0) Tensor | Mapping[str, Any] | None [source]¶
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch : iterable, normally a DataLoader
the current data.
batch_idx : int
the index of this batch.
dataloader_idx : int, default=0
the index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
loss : STEP_OUTPUT
the computed loss:
Tensor - the loss tensor.
dict - a dictionary. Can include any keys, but must include the key 'loss'.
None - skip to the next batch.
Notes
When the validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
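A typical implementation (a sketch assuming from torch import nn, a self.model attribute and (x, y) batches) computes and logs validation metrics:
>>> def validation_step(self, batch, batch_idx):
>>>     x, y = batch
>>>     logits = self.model(x)
>>>     loss = nn.functional.cross_entropy(logits, y)
>>>     acc = (logits.argmax(dim=1) == y).float().mean()
>>>     self.log_dict({"val_loss": loss, "val_acc": acc}, prog_bar=True)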