Note

This page is reference documentation: it explains only the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.estimators.ssl.BarlowTwins

class nidl.estimators.ssl.BarlowTwins(encoder, encoder_kwargs=None, proj_input_dim=2048, proj_hidden_dim=2048, proj_output_dim=2048, lambd=0.005, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)[source]

Bases: TransformerMixin, BaseEstimator

Barlow Twins [1].

Barlow Twins is a self-supervised learning model that learns visual representations by i) imposing invariance to data augmentation and ii) reducing the redundancy between output features. Unlike contrastive methods, it does not rely on negative samples. The framework consists of:

  1. Data Augmentation - Generates two augmented views of an image.

  2. Encoder (Backbone Network) - Maps images to feature embeddings (e.g., 3D-ResNet).

  3. Projection Head - Maps features to a latent space for Barlow Twins loss optimization. The projector dimension in Barlow Twins is typically very high (e.g., 8192 or 16384) compared to the feature dimension (e.g., 2048 in ResNet-50); this is a key difference from other SSL methods.

  4. Redundancy reduction loss in addition to a data augmentation invariance loss.
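The objective behind these components can be sketched with a toy NumPy example (illustrative only; the actual implementation uses PyTorch and the library's BarlowTwinsLoss module):

```python
import numpy as np

rng = np.random.default_rng(0)

def barlow_twins_loss(z1, z2, lambd=5e-3):
    """Toy Barlow Twins objective on two batches of projected embeddings."""
    n, d = z1.shape
    # Standardize each feature across the batch.
    z1 = (z1 - z1.mean(0)) / z1.std(0)
    z2 = (z2 - z2.mean(0)) / z2.std(0)
    # Cross-correlation matrix between the two views (d x d).
    c = z1.T @ z2 / n
    # Invariance term: diagonal entries pulled toward 1.
    invariance = ((1.0 - np.diag(c)) ** 2).sum()
    # Redundancy-reduction term: off-diagonal entries pulled toward 0.
    off_diag = c - np.diag(np.diag(c))
    redundancy = (off_diag ** 2).sum()
    return invariance + lambd * redundancy

# Two "augmented views": same underlying samples plus independent noise.
x = rng.normal(size=(64, 16))
z1 = x + 0.1 * rng.normal(size=x.shape)
z2 = x + 0.1 * rng.normal(size=x.shape)

loss = barlow_twins_loss(z1, z2)
```

Identical views make the invariance term vanish, while unrelated views inflate it; `lambd` scales the off-diagonal (redundancy) penalty.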

Parameters:
encoder : nn.Module or class

Which deep architecture to use for encoding the input. A PyTorch Module is expected. In general, the uninstantiated class should be passed, although instantiated modules will also work.

encoder_kwargs : dict or None, default=None

Options for building the encoder (depends on each architecture). Examples:

  • encoder=torchvision.ops.MLP, encoder_kwargs={"in_channels": 10, "hidden_channels": [4, 3, 2]} builds an MLP with 3 hidden layers, input dim 10, output dim 2.

  • encoder=nidl.volume.backbones.resnet3d.resnet18, encoder_kwargs={"n_embedding": 10} builds a ResNet-18 model with a 10-dimensional output.

Ignored if encoder is instantiated.
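A plain-Python sketch of this "class plus kwargs" convention (the ToyEncoder and build_encoder names are hypothetical; the estimator's internal logic may differ):

```python
class ToyEncoder:
    """Stand-in for a real backbone such as a 3D ResNet."""
    def __init__(self, n_embedding=10):
        self.n_embedding = n_embedding

def build_encoder(encoder, encoder_kwargs=None):
    # If an uninstantiated class is passed, instantiate it with the kwargs;
    # if an instance is passed, use it as is and ignore encoder_kwargs.
    if isinstance(encoder, type):
        return encoder(**(encoder_kwargs or {}))
    return encoder

enc_from_class = build_encoder(ToyEncoder, {"n_embedding": 32})
enc_instance = ToyEncoder(8)
enc_passthrough = build_encoder(enc_instance, {"n_embedding": 999})  # kwargs ignored
```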

proj_input_dim : int, default=2048

Projector input dimension. It must be consistent with the encoder's output dimension.

proj_hidden_dim : int, default=2048

Projector hidden dimension. Original value in [1] is 8192, but it can be reduced for computational reasons.

proj_output_dim : int, default=2048

Projector output dimension. Original value in [1] is 8192, but it can be reduced for computational reasons.

lambd : float, default=5e-3

Lambda value in the Barlow Twins loss, trading off the importance of the redundancy-reduction term against the invariance term.

optimizer : {'sgd', 'adam', 'adamW'} or torch.optim.Optimizer or type, default='adamW'

Optimizer for training the model. If a string is given, it can be:

  • ‘sgd’: Stochastic Gradient Descent (with optional momentum).

  • ‘adam’: First-order gradient-based optimizer.

  • ‘adamW’ (default): Adam with decoupled weight decay regularization (see “Decoupled Weight Decay Regularization”, Loshchilov and Hutter, ICLR 2019).

learning_rate : float, default=3e-4

Initial learning rate.

weight_decay : float, default=5e-4

Weight decay in the optimizer.

exclude_bias_and_norm_wd : bool, default=True

Whether to exclude bias terms and normalization-layer parameters from weight decay during optimization.
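Such an exclusion is typically implemented by splitting parameters into two optimizer groups, as in this sketch (the parameter names and the name-matching heuristic are illustrative; nidl's actual grouping rule may differ):

```python
def make_param_groups(named_params, weight_decay, exclude_bias_and_norm_wd=True):
    """Split parameters so bias/normalization parameters get zero weight decay."""
    decay, no_decay = [], []
    for name, param in named_params:
        is_excluded = exclude_bias_and_norm_wd and (
            name.endswith("bias") or "norm" in name.lower()
        )
        (no_decay if is_excluded else decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# Toy named parameters standing in for model.named_parameters().
named = [("conv1.weight", "w0"), ("conv1.bias", "b0"), ("batch_norm1.weight", "g0")]
groups = make_param_groups(named, weight_decay=5e-4)
```

The resulting group list is the format accepted by torch.optim optimizers in place of a flat parameter iterable.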

optimizer_kwargs : dict or None, default=None

Extra named arguments for the optimizer.

lr_scheduler : {'none', 'warmup_cosine'}, LRSchedulerPLType or None, default='warmup_cosine'

Learning rate scheduler to use.

lr_scheduler_kwargs : dict or None, default=None

Extra named arguments for the scheduler. By default, it is set to {"warmup_epochs": 10, "warmup_start_lr": 1e-6, "min_lr": 0.0, "interval": "step"}.

**kwargs : dict, optional

Additional keyword arguments for the BaseEstimator class, such as max_epochs, max_steps, num_sanity_val_steps, check_val_every_n_epoch, callbacks, etc.

Attributes:
encoder : torch.nn.Module

Deep neural network mapping input data to low-dimensional vectors.

projection_head : torch.nn.Module

Projector that maps encoder output to latent space for loss optimization.

loss : BarlowTwinsLoss

The BarlowTwins loss function used for training.

optimizer : torch.optim.Optimizer

Optimizer used for training.

lr_scheduler : LRSchedulerPLType or None

Learning rate scheduler used for training.

References

[1]

Zbontar, J., et al., "Barlow Twins: Self-Supervised Learning via Redundancy Reduction." PMLR, 2021. https://proceedings.mlr.press/v139/zbontar21a

__init__(encoder, encoder_kwargs=None, proj_input_dim=2048, proj_hidden_dim=2048, proj_output_dim=2048, lambd=0.005, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)[source]
configure_optimizers()[source]

Initialize the optimizer and learning rate scheduler in Barlow Twins.

test_step(batch, batch_idx)[source]

Skip the test step.

training_step(batch, batch_idx, dataloader_idx=0)[source]

Perform one training step and compute the training loss.

Parameters:
batch: Sequence[Any]

A batch of data from the train dataloader. Supported formats are [X1, X2] or ([X1, X2], y), where X1 and X2 are tensors representing two augmented views of the same samples.

batch_idx: int

The index of the current batch (ignored).

dataloader_idx: int, default=0

The index of the dataloader (ignored).

Returns:
outputs : dict
Dictionary containing:
  • "loss": the Barlow Twins loss computed on this batch;

  • "z1": tensor of shape (batch_size, n_features);

  • "z2": tensor of shape (batch_size, n_features);

  • "y": the targets, if provided (returned as is).
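The two supported batch formats can be handled by a small unpacking helper, sketched below (the unpack_batch name is hypothetical, not the library's actual code; with real tensors the isinstance check on the first element is unambiguous):

```python
def unpack_batch(batch):
    """Return (x1, x2, y) from either [X1, X2] or ([X1, X2], y)."""
    first = batch[0]
    if isinstance(first, (list, tuple)) and len(first) == 2:
        # Labeled form: ([X1, X2], y).
        (x1, x2), y = batch
        return x1, x2, y
    # Unlabeled form: [X1, X2].
    x1, x2 = batch
    return x1, x2, None

# Tuples of length 3 stand in for view tensors here.
x1 = (1.0, 2.0, 3.0)
x2 = (1.1, 2.1, 3.1)
a1, a2, y_none = unpack_batch([x1, x2])
b1, b2, y_lab = unpack_batch(([x1, x2], "labels"))
```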

transform_step(batch, batch_idx, dataloader_idx=0)[source]

Encode the input data into the latent space.

Importantly, we do not apply the projection head here since it is not part of the final model at inference time (only used for training).

Parameters:
batch: torch.Tensor

A batch of data that has been generated from test_dataloader. This is given as is to the encoder.

batch_idx: int

The index of the current batch (ignored).

dataloader_idx: int, default=0

The index of the dataloader (ignored).

Returns:
features: torch.Tensor

The encoded features returned by the encoder.
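Conceptually, inference keeps only the encoder and drops the projection head, as in this minimal sketch (all function names are hypothetical stand-ins):

```python
def encode(x):
    # Encoder backbone (stand-in): maps inputs to features.
    return [v + 1.0 for v in x]

def project(features):
    # Training-time projection head (stand-in).
    return [2.0 * f for f in features]

def training_forward(x):
    # During training: encoder followed by projector.
    return project(encode(x))

def transform(x):
    # At inference: encoder features only, projector skipped.
    return encode(x)

feats = transform([0.0, 1.0])
```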

validation_step(batch, batch_idx, dataloader_idx=0)[source]

Perform one validation step and compute the validation loss.

Parameters:
batch: Sequence[Any]

A batch of data from the validation dataloader. Supported formats are [X1, X2] or ([X1, X2], y).

batch_idx: int

The index of the current batch (ignored).

dataloader_idx: int, default=0

The index of the dataloader (ignored).

Returns:
outputs : dict
Dictionary containing:
  • "loss": the Barlow Twins loss computed on this batch;

  • "z1": tensor of shape (batch_size, n_features);

  • "z2": tensor of shape (batch_size, n_features);

  • "y": the targets, if provided (returned as is).

Examples using nidl.estimators.ssl.BarlowTwins

Self-Supervised Learning with Barlow Twins
