
Deep learning for NeuroImaging in Python.

Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the gallery for the big picture.

class nidl.estimators.ssl.yaware.YAwareContrastiveLearning(encoder: ~torch.nn.modules.module.Module | type[~torch.nn.modules.module.Module], encoder_kwargs: dict[str, ~typing.Any] | None = None, projection_head: ~torch.nn.modules.module.Module | type[~torch.nn.modules.module.Module] | None = <class 'nidl.estimators.ssl.utils.projection_heads.YAwareProjectionHead'>, projection_head_kwargs: dict[str, ~typing.Any] | None = None, temperature: float = 0.1, kernel: str = 'gaussian', bandwidth: float | list[float] | ~numpy.ndarray | ~nidl.losses.yaware_infonce.KernelMetric = 1.0, optimizer: str | ~torch.optim.optimizer.Optimizer | type[~torch.optim.optimizer.Optimizer] = 'adam', optimizer_kwargs: dict[str, ~typing.Any] | None = None, learning_rate: float = 0.0001, lr_scheduler: ~torch.optim.lr_scheduler.LRScheduler | ~torch.optim.lr_scheduler.ReduceLROnPlateau | type[~torch.optim.lr_scheduler.LRScheduler | ~torch.optim.lr_scheduler.ReduceLROnPlateau] | None = None, lr_scheduler_kwargs: dict[str, ~typing.Any] | None = None, **kwargs: ~typing.Any)

Bases: TransformerMixin, BaseEstimator

y-Aware Contrastive Learning implementation [1]

y-Aware Contrastive Learning is a self-supervised learning framework for learning visual representations with auxiliary variables. It extends contrastive learning by maximizing the agreement between differently augmented views of the same image and between images with similar auxiliary variables, while minimizing the agreement between images with dissimilar auxiliary variables. The framework consists of:

  1. Data Augmentation - Generates two augmented views of an image.

  2. Kernel - Similarity function between auxiliary variables.

  3. Encoder (Backbone Network) - Maps images to feature embeddings (e.g., 3D-ResNet).

  4. Projection Head - Maps features to a latent space for contrastive loss optimization.

  5. Contrastive Loss (y-Aware) - Encourages augmented views of i) the same image and ii) images with close auxiliary variables to be closer while pushing dissimilar ones apart.
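For reference, a simplified form of the y-Aware InfoNCE loss from [1] is written below, restricted to cross-view pairs. It is a sketch of the paper's formulation, not a transcription of nidl's yAwareInfoNCE, whose exact normalization may differ. For a batch of N samples with projected views z_i^(1), z_i^(2), auxiliary variables y_i, kernel K, bandwidth σ, temperature τ and cosine similarity sim:

```latex
w_{ik} = \frac{K\left((y_i - y_k)/\sigma\right)}{\sum_{j=1}^{N} K\left((y_i - y_j)/\sigma\right)},
\qquad
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{N} w_{ik}
\log \frac{\exp\left(\mathrm{sim}(z_i^{(1)}, z_k^{(2)})/\tau\right)}
{\sum_{j=1}^{N}\exp\left(\mathrm{sim}(z_i^{(1)}, z_j^{(2)})/\tau\right)}
```

When the kernel gives weight only to k = i (w_ik = 1 if k = i and 0 otherwise), the loss reduces to the standard InfoNCE loss, where only the two views of the same image form a positive pair.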

Parameters:

encoder : nn.Module or class

Which deep architecture to use for encoding the input. A PyTorch Module is expected. In general, the uninstantiated class should be passed, although instantiated modules will also work.

encoder_kwargs : dict or None, default=None

Options for building the encoder (depends on each architecture). Examples:

  • encoder=torchvision.ops.MLP, encoder_kwargs={“in_channels”: 10, “hidden_channels”: [4, 3, 2]} builds an MLP with input dimension 10, hidden layers of sizes 4 and 3, and output dimension 2 (the last entry of hidden_channels is the output size).

  • encoder=nidl.volume.backbones.resnet3d.resnet18, encoder_kwargs={“n_embedding”: 10} builds a 3D ResNet-18 model with a 10-dimensional output embedding.

Ignored if encoder is instantiated.
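The class-or-instance convention for encoder and encoder_kwargs can be illustrated with a small helper. This is a minimal sketch, not nidl's code: the helper name build_module is hypothetical, and torch.nn.Linear stands in for a real backbone such as the 3D ResNet.

```python
from typing import Any, Optional, Union

from torch import nn


def build_module(
    module: Union[nn.Module, type],
    kwargs: Optional[dict[str, Any]] = None,
) -> nn.Module:
    """Hypothetical helper: build a module from a class, or pass an instance through."""
    if isinstance(module, nn.Module):
        # Already instantiated: the kwargs are ignored.
        return module
    # Uninstantiated class: build it from the keyword arguments.
    return module(**(kwargs or {}))


# Class + kwargs: the encoder is built from the options dictionary.
encoder = build_module(nn.Linear, {"in_features": 10, "out_features": 2})

# Instance: returned unchanged, the kwargs are ignored.
prebuilt = nn.Linear(10, 2)
same = build_module(prebuilt, {"in_features": 99, "out_features": 99})
```

Passing the class lets the estimator build the module itself, which keeps the estimator's parameters (the class and its kwargs) simple to store and inspect.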

projection_head : nn.Module or class or None, default=YAwareProjectionHead

Which projection head to use for the model. If None, no projection head is used and the encoder output is directly used for loss computation. Otherwise, a Module is expected. In general, the uninstantiated class should be passed, although instantiated modules will also work. By default, a 2-layer MLP with ReLU activation, 2048-d hidden units and a 128-d output is used.

projection_head_kwargs : dict or None, default=None

Arguments for building the projection head. By default, input dimension is 2048-d and output dimension is 128-d. These can be changed by passing a dictionary with keys ‘input_dim’ and ‘output_dim’. ‘input_dim’ must be equal to the encoder’s output dimension. Ignored if projection_head is instantiated.
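For illustration, a projection head matching the description above (two linear layers with a ReLU in between, 2048-d hidden units, 128-d output) might look like the sketch below. It is an assumption-based stand-in, not the actual YAwareProjectionHead; in particular, the hidden_dim argument is hypothetical, since only ‘input_dim’ and ‘output_dim’ are documented.

```python
import torch
from torch import nn


def make_projection_head(
    input_dim: int = 2048, hidden_dim: int = 2048, output_dim: int = 128
) -> nn.Module:
    """Hypothetical 2-layer MLP projection head (Linear -> ReLU -> Linear)."""
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, output_dim),
    )


# input_dim must match the encoder's output dimension, e.g. a 512-d encoder.
encoder_features = torch.randn(4, 512)  # batch of 4 encoder outputs
head = make_projection_head(input_dim=512)
z = head(encoder_features)  # shape (4, 128)
```

If input_dim does not match the encoder output, the first linear layer raises a shape error at the first forward pass.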

temperature : float, default=0.1

Temperature value in the y-Aware InfoNCE loss. Small values encourage more uniformly spread embeddings, whereas large values yield more clustered embeddings that are more sensitive to augmentations.
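The effect of the temperature can be seen on the softmax over similarity scores found inside InfoNCE-style losses: dividing by a small temperature sharpens the distribution, so the loss concentrates on the most similar (hardest) samples. The similarity values below are made up for illustration.

```python
import torch

# Cosine similarities between one anchor and four candidates (illustrative values).
similarities = torch.tensor([0.9, 0.5, 0.1, -0.3])

for temperature in (1.0, 0.1, 0.05):
    probs = torch.softmax(similarities / temperature, dim=0)
    # Smaller temperatures put more probability mass on the most similar candidate.
    print(f"tau={temperature}:", [round(p, 3) for p in probs.tolist()])
```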

kernel : {‘gaussian’, ‘epanechnikov’, ‘exponential’, ‘linear’, ‘cosine’}, default=’gaussian’

Kernel used as a similarity function between auxiliary variables.
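The exact kernel definitions used by nidl are not given on this page. As an assumption, the sketch below uses the unnormalized profiles of the same-named kernels in scikit-learn's KernelDensity, evaluated on the scaled distance d = |y_i - y_j| / σ; nidl's normalization or scaling may differ, and the function name kernel_profile is hypothetical.

```python
import math

import torch


def kernel_profile(name: str, d: torch.Tensor) -> torch.Tensor:
    """Unnormalized kernel value as a function of the scaled distance d >= 0."""
    if name == "gaussian":
        return torch.exp(-0.5 * d**2)
    if name == "epanechnikov":
        return torch.clamp(1.0 - d**2, min=0.0)  # zero for d >= 1
    if name == "exponential":
        return torch.exp(-d)
    if name == "linear":
        return torch.clamp(1.0 - d, min=0.0)  # zero for d >= 1
    if name == "cosine":
        return torch.where(d < 1.0, torch.cos(math.pi * d / 2.0), torch.zeros_like(d))
    raise ValueError(f"Unknown kernel: {name!r}")


# Pairwise weights between scalar auxiliary variables (e.g. ages), with sigma = 5.
ages = torch.tensor([20.0, 22.0, 40.0])
d = torch.cdist(ages[:, None], ages[:, None]) / 5.0  # |y_i - y_j| / sigma
weights = kernel_profile("gaussian", d)
```

All five profiles equal 1 at d = 0 and decrease with distance; the epanechnikov, linear and cosine kernels have compact support, so samples farther apart than one bandwidth get zero weight.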

bandwidth : Union[float, int, List[float], array, KernelMetric], default=1.0

The bandwidth (“sigma” in [1]) of the kernel computed between auxiliary variables:

  • If bandwidth is a scalar (int or float), it sets the bandwidth to a diagonal matrix with equal values.

  • If bandwidth is a 1d array, it sets the bandwidth to a diagonal matrix; its length must equal the number of features in y.

  • If bandwidth is a 2d array, it must be of shape (n_features, n_features) where n_features is the number of features in y.

  • If bandwidth is a KernelMetric, its pairwise method is used to compute the similarity matrix between auxiliary variables.
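The scalar, 1d and 2d cases above can be sketched as an expansion to a full bandwidth matrix. How that matrix enters the kernel is not specified here, so the sketch assumes it acts like a covariance matrix H, i.e. the scaled squared distance is (y_i - y_j)ᵀ H⁻¹ (y_i - y_j); under this assumption a scalar bandwidth b corresponds to σ² = b rather than σ = b. The helper names are hypothetical; check nidl's KernelMetric for the convention it actually uses.

```python
from typing import Sequence, Union

import torch


def bandwidth_matrix(
    bandwidth: Union[float, Sequence[float], torch.Tensor], n_features: int
) -> torch.Tensor:
    """Expand a scalar, 1d or 2d bandwidth into an (n_features, n_features) matrix."""
    h = torch.as_tensor(bandwidth, dtype=torch.float64)
    if h.ndim == 0:  # scalar -> diagonal matrix with equal values
        return h * torch.eye(n_features, dtype=torch.float64)
    if h.ndim == 1:  # 1d -> diagonal matrix, one value per feature of y
        if h.numel() != n_features:
            raise ValueError("1d bandwidth must have one value per feature of y")
        return torch.diag(h)
    if h.shape != (n_features, n_features):
        raise ValueError("2d bandwidth must have shape (n_features, n_features)")
    return h


def gaussian_similarity(y, bandwidth) -> torch.Tensor:
    """Pairwise Gaussian kernel, assuming d^2 = (y_i - y_j)^T H^{-1} (y_i - y_j)."""
    y = torch.as_tensor(y, dtype=torch.float64).reshape(len(y), -1)
    h_inv = torch.linalg.inv(bandwidth_matrix(bandwidth, y.shape[1]))
    diff = y[:, None, :] - y[None, :, :]  # (N, N, n_features)
    sq_dist = torch.einsum("ijd,de,ije->ij", diff, h_inv, diff)
    return torch.exp(-0.5 * sq_dist)


# Two auxiliary features per sample (e.g. age and a clinical score).
y = torch.tensor([[25.0, 1.0], [30.0, 2.0], [60.0, 1.5]])
w_scalar = gaussian_similarity(y, 4.0)
w_diag = gaussian_similarity(y, [4.0, 4.0])
w_full = gaussian_similarity(y, 4.0 * torch.eye(2))
```

A 1d or 2d bandwidth is useful when the features of y live on different scales (years versus a score), since each feature can then get its own width.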

optimizer : {‘sgd’, ‘adam’, ‘adamW’} or torch.optim.Optimizer or type, default=’adam’

Optimizer for training the model. Can be:

  • A string:

    • ‘sgd’: Stochastic Gradient Descent (with optional momentum).

    • ‘adam’: Adam, a first-order gradient-based optimizer with adaptive estimates of the first and second moments (default).

    • ‘adamW’: Adam with decoupled weight decay regularization (see “Decoupled Weight Decay Regularization”, Loshchilov and Hutter, ICLR 2019).

  • An instance or subclass of torch.optim.Optimizer.

optimizer_kwargs : dict or None, default=None

Keyword arguments for the optimizer (‘adam’ by default). By default: {‘betas’: (0.9, 0.99), ‘weight_decay’: 5e-05}, where ‘betas’ are the exponential decay rates for the first and second moment estimates.

Ignored if optimizer is instantiated.
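The optimizer options above can be sketched as follows. This is a minimal, assumption-based illustration, not nidl's implementation: the helper build_optimizer and the mapping _OPTIMIZERS are hypothetical, and the sketch applies the documented default kwargs only when the string ‘adam’ is used without explicit kwargs.

```python
from typing import Any, Optional, Union

import torch
from torch import nn

# Hypothetical name -> class mapping for the three documented strings.
_OPTIMIZERS = {"sgd": torch.optim.SGD, "adam": torch.optim.Adam, "adamW": torch.optim.AdamW}


def build_optimizer(
    optimizer: Union[str, torch.optim.Optimizer, type],
    params,
    learning_rate: float = 1e-4,
    optimizer_kwargs: Optional[dict[str, Any]] = None,
) -> torch.optim.Optimizer:
    if isinstance(optimizer, torch.optim.Optimizer):
        return optimizer  # already instantiated: optimizer_kwargs are ignored
    if isinstance(optimizer, str):
        if optimizer_kwargs is None and optimizer == "adam":
            # Documented defaults for the 'adam' optimizer.
            optimizer_kwargs = {"betas": (0.9, 0.99), "weight_decay": 5e-05}
        optimizer = _OPTIMIZERS[optimizer]
    # At this point `optimizer` is a torch.optim.Optimizer subclass.
    return optimizer(params, lr=learning_rate, **(optimizer_kwargs or {}))


model = nn.Linear(8, 2)
adam = build_optimizer("adam", model.parameters())
sgd = build_optimizer(torch.optim.SGD, model.parameters(), 1e-2, {"momentum": 0.9})
```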

learning_rate : float, default=1e-4

Initial learning rate.

lr_scheduler : LRSchedulerPLType or class or None, default=None

Learning rate scheduler to use.

lr_scheduler_kwargs : dict or None, default=None

Additional keyword arguments for the scheduler.

Ignored if lr_scheduler is instantiated.
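A scheduler class plus lr_scheduler_kwargs can be bound to the optimizer and wrapped in the lr_scheduler_config dictionary format documented under configure_optimizers(). The sketch below is hypothetical: the helper name, the epoch interval and the "val_loss" monitor name are assumptions, not nidl's actual choices.

```python
from typing import Any, Optional

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR


def build_lr_scheduler_config(
    optimizer: torch.optim.Optimizer,
    lr_scheduler=None,
    lr_scheduler_kwargs: Optional[dict[str, Any]] = None,
    monitor: str = "val_loss",
) -> dict[str, Any]:
    """Hypothetical helper returning a Lightning-style optimizer/scheduler dict."""
    if lr_scheduler is None:
        return {"optimizer": optimizer}
    if isinstance(lr_scheduler, type):
        # Uninstantiated class: bind it to the optimizer with the extra kwargs.
        lr_scheduler = lr_scheduler(optimizer, **(lr_scheduler_kwargs or {}))
    config = {"scheduler": lr_scheduler, "interval": "epoch", "frequency": 1}
    if isinstance(lr_scheduler, ReduceLROnPlateau):
        # Schedulers conditioned on a metric need a "monitor" key.
        config["monitor"] = monitor
    return {"optimizer": optimizer, "lr_scheduler": config}


optimizer = torch.optim.Adam(nn.Linear(8, 2).parameters(), lr=1e-4)
step_cfg = build_lr_scheduler_config(optimizer, StepLR, {"step_size": 30, "gamma": 0.1})
plateau_cfg = build_lr_scheduler_config(optimizer, ReduceLROnPlateau, {"patience": 5})
```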

**kwargs : dict, optional

Additional keyword arguments for the BaseEstimator class, such as max_epochs, max_steps, num_sanity_val_steps, check_val_every_n_epoch, callbacks, etc.

References

[1] Contrastive Learning with Continuous Proxy Meta-Data for 3D MRI Classification, Dufumier et al., MICCAI 2021

Attributes

encoder

(torch.nn.Module) Deep neural network mapping input data to low-dimensional vectors.

projection_head

(torch.nn.Module) Maps encoder output to latent space for contrastive loss optimization.

loss

(yAwareInfoNCE) The yAwareInfoNCE loss function used for training.

optimizer

(torch.optim.Optimizer) Optimizer used for training.

lr_scheduler

(LRSchedulerPLType or None) Learning rate scheduler used for training.

configure_optimizers()

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Return:

One of the following options:

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config dictionaries).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning.
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

# The ReduceLROnPlateau scheduler requires a monitor
def configure_optimizers(self):
    optimizer = Adam(self.parameters(), lr=1e-3)  # illustrative learning rate
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": ReduceLROnPlateau(optimizer, mode="min"),
            "monitor": "metric_to_track",
            # How often (in epochs or steps) the metric is updated. If "monitor"
            # references validation metrics, "frequency" should be set to a
            # multiple of "trainer.check_val_every_n_epoch".
            "frequency": 1,
        },
    }

# In the case of two optimizers, only one using the ReduceLROnPlateau scheduler
def configure_optimizers(self):
    optimizer1 = Adam(self.encoder.parameters(), lr=1e-3)
    optimizer2 = SGD(self.projection_head.parameters(), lr=1e-2)
    scheduler1 = ReduceLROnPlateau(optimizer1, mode="min")
    scheduler2 = LambdaLR(optimizer2, lr_lambda=lambda epoch: 0.95**epoch)
    return (
        {
            "optimizer": optimizer1,
            "lr_scheduler": {
                "scheduler": scheduler1,
                "monitor": "metric_to_track",
            },
        },
        {"optimizer": optimizer2, "lr_scheduler": scheduler2},
    )

Metrics can be made available for monitoring by logging them with self.log('metric_to_track', metric_val) in your LightningModule.

Note:

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

parse_batch(batch: Any) → tuple[Tensor, Tensor, Tensor | None]

Parses the batch to extract the two views and the auxiliary variable.

Parameters:

batch : Any

Parse a batch input and return V1, V2, and y. The batch can be either:

  • (V1, V2): two views of the same sample.

  • ((V1, V2), y): two views and an auxiliary label or variable.

Returns:

V1 : torch.Tensor

First view of the input.

V2 : torch.Tensor

Second view of the input.

y : Optional[torch.Tensor]

Auxiliary label or variable, if present. Otherwise, None.
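The two documented batch layouts can be told apart by checking whether the first element is itself a pair of views. This is a hypothetical re-implementation for illustration, not nidl's parse_batch; plain strings stand in for tensors so the logic is easy to follow.

```python
from typing import Any, Optional, Tuple


def parse_batch_sketch(batch: Any) -> Tuple[Any, Any, Optional[Any]]:
    """Hypothetical re-implementation of the two documented batch layouts."""
    first, second = batch
    # ((V1, V2), y): the first element is itself a pair of views. PyTorch's
    # default collate function turns tuples into lists, so accept both.
    if isinstance(first, (tuple, list)):
        v1, v2 = first
        return v1, v2, second
    # (V1, V2): two views without an auxiliary variable.
    return first, second, None


print(parse_batch_sketch(("view1", "view2")))
print(parse_batch_sketch([["view1", "view2"], "age"]))
```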

training_step(batch: Any, batch_idx: int)

Performs one training step and computes the training loss.

Parameters:

batch : Any

A batch of data that has been generated from train_dataloader. It can be a pair of torch.Tensor (V1, V2) or a pair ((V1, V2), y) where V1 and V2 are the two views of the same sample and y is the auxiliary variable.

batch_idx : int

The index of the current batch.

Returns:

loss : Tensor

Training loss computed on this batch of data.
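A training step of this kind can be sketched in plain PyTorch: encode both views, project them, and compute a kernel-weighted InfoNCE loss. This is a simplified, assumption-based stand-in, not nidl's training_step or yAwareInfoNCE: it keeps only cross-view pairs, uses a Gaussian kernel with a scalar bandwidth sigma and cosine similarity, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def yaware_infonce_sketch(z1, z2, y, temperature=0.1, sigma=1.0):
    """Cross-view y-aware InfoNCE with a Gaussian kernel (simplified sketch)."""
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    logits = z1 @ z2.T / temperature  # cosine similarities / tau, shape (N, N)
    log_prob = F.log_softmax(logits, dim=1)  # log-prob of pairing view-1 i with view-2 k
    y = y.reshape(len(y), -1).float()
    sq_dist = torch.cdist(y, y).pow(2) / sigma**2  # ||y_i - y_k||^2 / sigma^2
    weights = torch.exp(-0.5 * sq_dist)  # Gaussian kernel w(y_i, y_k)
    weights = weights / weights.sum(dim=1, keepdim=True)  # normalize each row
    return -(weights * log_prob).sum(dim=1).mean()


def training_step_sketch(encoder, projection_head, batch, temperature=0.1, sigma=1.0):
    """One step: encode both views, project them, and compute the loss."""
    (v1, v2), y = batch  # assumes the ((V1, V2), y) layout
    z1 = projection_head(encoder(v1))
    z2 = projection_head(encoder(v2))
    return yaware_infonce_sketch(z1, z2, y, temperature, sigma)


torch.manual_seed(0)
encoder = nn.Linear(16, 8)  # stand-in for a real backbone
head = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
v1 = torch.randn(8, 16)
v2 = v1 + 0.05 * torch.randn(8, 16)  # a second, slightly perturbed "view"
ages = 20.0 + 60.0 * torch.rand(8)  # continuous auxiliary variable
loss = training_step_sketch(encoder, head, ((v1, v2), ages), sigma=5.0)
loss.backward()  # in Lightning, backward() and step() are called for you
```

With a very small bandwidth and distinct values of y, the kernel weights become the identity matrix and the sketch reduces to the standard InfoNCE (cross-entropy over cross-view pairs).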

transform_step(batch: Tensor, batch_idx: int, dataloader_idx: int | None = 0)

Define a transform step.

Shares the same API as BaseEstimator.predict_step().

validation_step(batch: Any, batch_idx: int)

Performs one validation step and computes the validation loss.

Parameters:

batch : Any

A batch of data that has been generated from val_dataloader. It can be a pair of torch.Tensor (V1, V2) or a pair ((V1, V2), y) where V1 and V2 are the two views of the same sample and y is the auxiliary variable.

batch_idx : int

The index of the current batch.


© 2025, nidl developers