Note
This page is reference documentation: it only describes the class signature, not how to use it. Please refer to the user guide for the big picture.
nidl.estimators.ssl.YAwareContrastiveLearning¶
- class nidl.estimators.ssl.YAwareContrastiveLearning(encoder, encoder_kwargs=None, proj_input_dim=2048, proj_hidden_dim=512, proj_output_dim=128, temperature=0.1, kernel='gaussian', bandwidth=1.0, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)[source]¶
Bases:
TransformerMixin, BaseEstimator

y-Aware Contrastive Learning [1].
y-Aware Contrastive Learning is a self-supervised learning framework for learning visual representations with auxiliary variables. It leverages contrastive learning by maximizing the agreement between differently augmented views of images with similar auxiliary variables while minimizing agreement between different images. The framework consists of:
Data Augmentation - Generates two augmented views of an image.
Kernel - Similarity function between auxiliary variables.
Encoder (Backbone Network) - Maps images to feature embeddings (e.g., 3D-ResNet).
Projection Head - Maps features to a latent space for contrastive loss optimization.
Contrastive Loss (y-Aware) - Encourages augmented views of i) the same image and ii) images with close auxiliary variables to be closer while pushing dissimilar ones apart.
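The interaction between the kernel and the contrastive loss can be sketched in a few lines of NumPy. This is an illustrative toy, not the nidl implementation: the function names (`rbf_weights`, `y_aware_infonce`) are hypothetical, and the auxiliary variable is assumed to be 1d (e.g., age). A Gaussian kernel on y turns the usual hard positive/negative split into soft weights inside an InfoNCE-style objective:

```python
import numpy as np

def rbf_weights(y, bandwidth=1.0):
    """Gaussian kernel w_ij = exp(-(y_i - y_j)^2 / (2 * bandwidth)) on a 1d auxiliary variable."""
    d2 = (y[:, None] - y[None, :]) ** 2
    return np.exp(-d2 / (2.0 * bandwidth))

def y_aware_infonce(z1, z2, y, temperature=0.1, bandwidth=1.0):
    """Toy y-Aware InfoNCE: kernel-weighted cross-entropy over view similarities."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)  # L2-normalize projections
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature                     # pairwise cosine similarities
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    w = rbf_weights(y, bandwidth)                        # soft positives from y
    w = w / w.sum(axis=1, keepdims=True)                 # normalize weights per anchor
    return -(w * log_prob).sum(axis=1).mean()

rng = np.random.default_rng(0)
z1 = rng.normal(size=(8, 128))   # projected embeddings of view 1
z2 = rng.normal(size=(8, 128))   # projected embeddings of view 2
y = rng.uniform(20, 80, size=8)  # auxiliary variable, e.g. age
loss = y_aware_infonce(z1, z2, y, temperature=0.1, bandwidth=5.0)
```

With `bandwidth → 0` the kernel weights collapse onto the diagonal and the loss reduces to standard InfoNCE between the two views of each sample.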
- Parameters:
- encoder: nn.Module or class
Which deep architecture to use for encoding the input. A PyTorch torch.nn.Module is expected. In general, the uninstantiated class should be passed, although instantiated modules will also work.
- encoder_kwargs: dict or None, default=None
Options for building the encoder (depends on each architecture). Ignored if encoder is instantiated.
- proj_input_dim: int, default=2048
Projector input dimension. It must match the encoder's output dimension.
- proj_hidden_dim: int, default=512
Projector hidden dimension.
- proj_output_dim: int, default=128
Projector output dimension.
- temperature: float, default=0.1
Temperature in the y-Aware InfoNCE loss. Small values encourage more uniformly spread embeddings, whereas high values yield more clustered embeddings that are more sensitive to augmentations.
- kernel: {‘gaussian’, ‘epanechnikov’, ‘exponential’, ‘linear’, ‘cosine’}, default=”gaussian”
Kernel used as a similarity function between auxiliary variables.
- bandwidth: float, list of float, array or KernelMetric, default=1.0
The method used to calculate the bandwidth (“sigma^2” in [1]) between auxiliary variables:
If bandwidth is a scalar, it sets the bandwidth to a diagonal matrix with equal values.
If bandwidth is a 1d array, it sets the bandwidth to a diagonal matrix and it must be of size equal to the number of features in y.
If bandwidth is a 2d array, it must be of shape (n_features, n_features) where n_features is the number of features in y.
If bandwidth is KernelMetric, it uses the pairwise method to compute the similarity matrix between auxiliary variables.
- optimizer: {‘sgd’, ‘adam’, ‘adamW’}, torch.optim.Optimizer or type, default=”adamW”
Optimizer for training the model. If a string is given, it can be:
‘sgd’: Stochastic Gradient Descent (with optional momentum).
‘adam’: First-order gradient-based optimizer.
‘adamW’ (default): Adam with decoupled weight decay regularization (see “Decoupled Weight Decay Regularization”, Loshchilov and Hutter, ICLR 2019).
- learning_rate: float, default=3e-4
Initial learning rate.
- weight_decay: float, default=5e-4
Weight decay in the optimizer.
- exclude_bias_and_norm_wd: bool, default=True
If True, bias terms and normalization layers are excluded from weight decay during optimization.
- optimizer_kwargs: dict or None, default=None
Extra named arguments for the optimizer.
- lr_scheduler: {“none”, “warmup_cosine”}, LRSchedulerPLType or None, default=”warmup_cosine”
Learning rate scheduler to use.
- lr_scheduler_kwargs: dict or None, default=None
Extra named arguments for the scheduler. By default, it is set to {“warmup_epochs”: 10, “warmup_start_lr”: 1e-6, “min_lr”: 0.0, “interval”: “step”}.
- **kwargs: dict, optional
Additional keyword arguments for the BaseEstimator class, such as max_epochs, max_steps, num_sanity_val_steps, check_val_every_n_epoch, callbacks, etc.
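The three array forms of bandwidth described above can be sketched as choices of the matrix S (“sigma^2” in [1]) in a multivariate Gaussian kernel. This is a NumPy illustration only; the helper below is hypothetical and is not nidl's KernelMetric:

```python
import numpy as np

def gaussian_kernel(y, bandwidth):
    """k(y_i, y_j) = exp(-0.5 * (y_i - y_j)^T S^{-1} (y_i - y_j)), with S built from bandwidth."""
    y = np.atleast_2d(y)
    n_features = y.shape[1]
    S = np.asarray(bandwidth, dtype=float)
    if S.ndim == 0:        # scalar -> diagonal matrix with equal values
        S = np.eye(n_features) * S
    elif S.ndim == 1:      # 1d, length n_features -> diagonal matrix
        S = np.diag(S)
    # 2d -> full (n_features, n_features) bandwidth matrix, used as-is
    diff = y[:, None, :] - y[None, :, :]                       # pairwise differences
    m = np.einsum("ijk,kl,ijl->ij", diff, np.linalg.inv(S), diff)
    return np.exp(-0.5 * m)

y = np.array([[20.0, 1.0], [25.0, 0.0], [60.0, 1.0]])  # e.g. (age, sex) per sample
K1 = gaussian_kernel(y, 25.0)                    # scalar bandwidth
K2 = gaussian_kernel(y, np.array([25.0, 1.0]))   # one bandwidth per feature
K3 = gaussian_kernel(y, np.diag([25.0, 1.0]))    # full bandwidth matrix
```

Here `K2` and `K3` coincide, since a 1d bandwidth is just the diagonal of the corresponding 2d matrix; a full matrix additionally allows correlated auxiliary variables.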
- Attributes:
- encoder: torch.nn.Module
Deep neural network mapping input data to low-dimensional vectors.
- projection_head: torch.nn.Module
Maps encoder output to latent space for contrastive loss optimization.
- loss: yAwareInfoNCE
The yAwareInfoNCE loss function used for training.
- optimizer: torch.optim.Optimizer
Optimizer used for training.
- lr_scheduler: LRSchedulerPLType or None
Learning rate scheduler used for training.
References
[1]Dufumier, B., et al., “Contrastive learning with continuous proxy meta-data for 3D MRI classification.” MICCAI, 2021. https://arxiv.org/abs/2106.08808
- __init__(encoder, encoder_kwargs=None, proj_input_dim=2048, proj_hidden_dim=512, proj_output_dim=128, temperature=0.1, kernel='gaussian', bandwidth=1.0, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)[source]¶
- training_step(batch, batch_idx, dataloader_idx=0)[source]¶
Perform one training step and compute the training loss.
- Parameters:
- batch: Sequence[Any]
A batch of data from the train dataloader. Supported formats are [X1, X2] or ([X1, X2], y), where X1 and X2 are tensors representing two augmented views of the same samples and y is the auxiliary variable (e.g., age).
- batch_idx: int
The index of the current batch (ignored).
- dataloader_idx: int, default=0
The index of the dataloader (ignored).
- Returns:
- outputs: dict
- Dictionary containing:
“loss”: the y-Aware loss computed on this batch;
“z1”: tensor of shape (batch_size, n_features);
“z2”: tensor of shape (batch_size, n_features);
“y”: auxiliary variables.
- transform_step(batch, batch_idx, dataloader_idx=0)[source]¶
Encode the input data into the latent space.
Importantly, we do not apply the projection head here since it is not part of the final model at inference time (only used for training).
- Parameters:
- batch: torch.Tensor
A batch of data that has been generated from test_dataloader. This is given as is to the encoder.
- batch_idx: int
The index of the current batch (ignored).
- dataloader_idx: int, default=0
The index of the dataloader (ignored).
- Returns:
- features: torch.Tensor
The encoded features returned by the encoder.
- validation_step(batch, batch_idx, dataloader_idx=0)[source]¶
Perform one validation step and compute the validation loss.
- Parameters:
- batch: Sequence[Any]
A batch of data from the validation dataloader. Supported formats are [X1, X2] or ([X1, X2], y).
- batch_idx: int
The index of the current batch (ignored).
- dataloader_idx: int, default=0
The index of the dataloader (ignored).
- Returns:
- outputs: dict
- Dictionary containing:
“loss”: the y-Aware loss computed on this batch;
“z1”: tensor of shape (batch_size, n_features);
“z2”: tensor of shape (batch_size, n_features);
“y”: auxiliary variables.
Examples using nidl.estimators.ssl.YAwareContrastiveLearning¶
Weakly Supervised Contrastive Learning with y-Aware