Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.estimators.ssl.IJEPA

class nidl.estimators.ssl.IJEPA(encoder, dim=3, context_block_scale=(0.85, 1.0), target_block_scale=(0.15, 0.2), aspect_ratio=(0.75, 1.5), num_target_blocks=4, min_keep=4, allow_overlap=False, predictor_embed_dim=384, predictor_depth_pred=6, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, ema_start=0.996, ema_end=1.0, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)

Bases: TransformerMixin, BaseEstimator

Implementation of I-JEPA [1].

Solver that predicts the representations of missing parts of an image from its surrounding context. It uses two encoders, a context encoder and a target encoder, to compute the context and target features, and a third network, the predictor, to predict the target features from the context features.

The target encoder is an Exponential Moving Average (EMA) of the context encoder. It is used during training to avoid representation collapse of the context encoder, and it is discarded at inference. The context encoder is used for downstream tasks.

Parameters:
encoder: nn.Module

Vision Transformer-like encoder module. This encoder should follow the timm interface of VisionTransformer.

dim: {2, 3}, default=3

Input data dimensionality: 3 for 3D volumes, 2 for 2D images.

context_block_scale: (float, float), default=(0.85, 1.0)

Range (min, max) from which the scale of the context block, i.e. the fraction of image patches it covers, is sampled.

target_block_scale: (float, float), default=(0.15, 0.2)

Range (min, max) from which the scale of each target block, i.e. the fraction of image patches it covers, is sampled.

aspect_ratio: (float, float), default=(0.75, 1.5)

Range (min, max) from which the aspect ratio of each target block is sampled.
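The following is a minimal sketch of how a block shape could be drawn from these ranges on a 2D patch grid, assuming the scale is the fraction of patches covered and the aspect ratio is height / width, as described in the I-JEPA paper. The helpers `sample_block_size` and `place_block` are hypothetical and not part of the nidl API; the 3D case and the exact sampling performed by `IJEPA` may differ.

```python
import math
import random


def sample_block_size(grid_h, grid_w, scale_range, aspect_range, rng):
    """Hypothetical helper: sample (height, width) of one block, in patches."""
    scale = rng.uniform(*scale_range)            # fraction of the grid to cover
    num_patches = int(grid_h * grid_w * scale)
    aspect = rng.uniform(*aspect_range)          # assumed to be height / width
    h = int(round(math.sqrt(num_patches * aspect)))
    w = int(round(math.sqrt(num_patches / aspect)))
    # Clip so that the block always fits inside the patch grid.
    return max(1, min(h, grid_h)), max(1, min(w, grid_w))


def place_block(grid_h, grid_w, h, w, rng):
    """Hypothetical helper: pick a random top-left corner, return flat patch indices."""
    top = rng.randrange(grid_h - h + 1)
    left = rng.randrange(grid_w - w + 1)
    return {r * grid_w + c for r in range(top, top + h) for c in range(left, left + w)}


rng = random.Random(0)
# 14 x 14 patch grid, e.g. a 224 x 224 image cut into 16 x 16 patches.
target_h, target_w = sample_block_size(14, 14, (0.15, 0.2), (0.75, 1.5), rng)
target_block = place_block(14, 14, target_h, target_w, rng)
# In this sketch the context block is square (aspect ratio fixed to 1.0, an assumption).
ctx_h, ctx_w = sample_block_size(14, 14, (0.85, 1.0), (1.0, 1.0), rng)
```

Sampling the scale and the aspect ratio independently lets small target blocks vary in shape while the context block stays large enough to carry most of the image.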

num_target_blocks: int, default=4

Number of target blocks to predict.

min_keep: int, default=4

Minimum number of patches to keep in the context/target block.

allow_overlap: bool, default=False

Whether to allow overlap between target and context blocks.
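A minimal sketch of how `allow_overlap` and `min_keep` could interact when building the context, assuming blocks are represented as sets of flat patch indices. With `allow_overlap=False`, the target patches are removed from the context block so that the context never sees what it must predict; if fewer than `min_keep` patches remain, this sketch signals that the blocks should be resampled. `build_context` is a hypothetical helper, not the nidl implementation.

```python
def build_context(context_block, target_blocks, allow_overlap=False, min_keep=4):
    """Hypothetical helper: return sorted context patch indices, or None to resample."""
    context = set(context_block)
    if not allow_overlap:
        # Drop every patch that belongs to a target block.
        for block in target_blocks:
            context -= set(block)
    if len(context) < min_keep:
        return None  # too few visible patches: the caller should resample the blocks
    return sorted(context)


# 4 x 4 patch grid: a full-grid context block and two small target blocks.
context_block = range(16)
target_blocks = [{0, 1, 4, 5}, {10, 11}]
print(build_context(context_block, target_blocks))                      # 10 patches left
print(build_context(context_block, target_blocks, allow_overlap=True))  # all 16 patches
print(build_context(context_block, target_blocks, min_keep=11))         # None
```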

predictor_embed_dim: int, default=384

Embedding dimension of the predictor's hidden layers. It may differ from the encoder output dimension.

predictor_depth_pred: int, default=6

Number of Transformer blocks in the predictor.

optimizer: {‘sgd’, ‘adam’, ‘adamW’} or Optimizer, default=“adamW”

Optimizer for training the model. If a string is given, it can be:

  • ‘sgd’: Stochastic Gradient Descent (with optional momentum).

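  • ‘adam’: Adam, a first-order gradient-based optimizer with adaptive moment estimates (see “Adam: A Method for Stochastic Optimization”, Kingma and Ba, ICLR 2015).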

  • ‘adamW’ (default): Adam with decoupled weight decay regularization (see “Decoupled Weight Decay Regularization”, Loshchilov and Hutter, ICLR 2019).
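The string options above map naturally onto classes in `torch.optim`. The sketch below shows one way to perform that mapping; `make_optimizer` is a hypothetical helper, and how `IJEPA` forwards `learning_rate`, `weight_decay` and `optimizer_kwargs` internally is an assumption here.

```python
import torch
from torch import nn


def make_optimizer(params, name="adamW", lr=3e-4, weight_decay=5e-4, **optimizer_kwargs):
    """Hypothetical helper mapping the documented strings to torch.optim classes."""
    optimizers = {
        "sgd": torch.optim.SGD,      # momentum can be passed via optimizer_kwargs
        "adam": torch.optim.Adam,
        "adamW": torch.optim.AdamW,  # decoupled weight decay
    }
    if name not in optimizers:
        raise ValueError(f"Unknown optimizer {name!r}; expected one of {sorted(optimizers)}")
    return optimizers[name](params, lr=lr, weight_decay=weight_decay, **optimizer_kwargs)


model = nn.Linear(8, 2)
adamw = make_optimizer(model.parameters())
sgd = make_optimizer(model.parameters(), "sgd", lr=0.1, momentum=0.9)
```

Extra keyword arguments such as `momentum` play the role that `optimizer_kwargs` plays in the class signature.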

learning_rate: float, default=3e-4

Initial learning rate.

weight_decay: float, default=5e-4

Weight decay in the optimizer.

exclude_bias_and_norm_wd: bool, default=True

If True, bias terms and the parameters of normalization layers are excluded from weight decay during optimization.
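One common way to implement this option is to split the parameters into two groups and give the non-decayed group `weight_decay=0.0`. The sketch below assumes that biases are the parameters named `bias` and that normalization layers are the listed `torch.nn` norm classes; `split_weight_decay_groups` is hypothetical, and the exact rule used by `IJEPA` may differ (another frequent heuristic excludes every 1-D parameter).

```python
import torch
from torch import nn

# Normalization layers whose parameters are exempt from weight decay (assumed list).
NORM_TYPES = (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def split_weight_decay_groups(model, weight_decay):
    """Hypothetical helper: biases and norm parameters get weight_decay=0.0."""
    decay, no_decay = [], []
    for module in model.modules():
        # recurse=False: each parameter is visited exactly once, by its owning module.
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if name == "bias" or isinstance(module, NORM_TYPES):
                no_decay.append(param)
            else:
                decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2))
groups = split_weight_decay_groups(model, weight_decay=5e-4)
optimizer = torch.optim.AdamW(groups, lr=3e-4)
```

Here the two `Linear` weights are decayed, while the two biases and the `LayerNorm` weight and bias are not.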

ema_start: float, default=0.996

Initial value of the momentum coefficient used in the exponential moving average update of the target (teacher) encoder. The coefficient follows a cosine annealing schedule from ema_start to ema_end over training.

ema_end: float, default=1.0

Final value for the weighting coefficient in the teacher momentum update.
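A minimal sketch of a cosine schedule that moves the coefficient from `ema_start` to `ema_end`, assuming the BYOL-style form m(k) = ema_end - (ema_end - ema_start) * (1 + cos(pi * k / K)) / 2 over K training steps. The exact formula and the unit of K (steps or epochs) used by `IJEPA` are assumptions; `ema_momentum` is a hypothetical helper.

```python
import math


def ema_momentum(step, total_steps, ema_start=0.996, ema_end=1.0):
    """Cosine schedule: ema_start at step 0, ema_end at step total_steps (assumed form)."""
    progress = min(step / max(total_steps, 1), 1.0)
    return ema_end - (ema_end - ema_start) * (1.0 + math.cos(math.pi * progress)) / 2.0


schedule = [ema_momentum(k, 100) for k in range(101)]
```

With the defaults, the coefficient starts at 0.996, reaches 0.998 halfway through training, and ends at 1.0, at which point the target encoder stops changing.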

optimizer_kwargs: dict or None, default=None

Extra named arguments for the optimizer.

lr_scheduler: {“none”, “warmup_cosine”}, LRSchedulerPLType or None, default=“warmup_cosine”

Learning rate scheduler to use.

lr_scheduler_kwargs: dict or None, default=None

Extra named arguments for the scheduler. By default, it is set to {“warmup_epochs”: 10, “warmup_start_lr”: 1e-6, “min_lr”: 0.0, “interval”: “step”}
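With these default keyword arguments, the learning rate presumably warms up linearly from `warmup_start_lr` to `learning_rate` over `warmup_epochs`, then decays along a cosine to `min_lr`, with an update at every step (`"interval": "step"`). The function below is a minimal sketch of that shape expressed in steps; it is an assumption about the schedule, and `warmup_cosine_lr` is a hypothetical helper rather than the nidl implementation.

```python
import math


def warmup_cosine_lr(step, warmup_steps, total_steps, base_lr=3e-4,
                     warmup_start_lr=1e-6, min_lr=0.0):
    """Linear warmup followed by cosine decay (assumed shape of 'warmup_cosine')."""
    if step < warmup_steps:
        # Linear ramp from warmup_start_lr to base_lr.
        return warmup_start_lr + (base_lr - warmup_start_lr) * step / warmup_steps
    # Cosine decay from base_lr down to min_lr over the remaining steps.
    progress = min((step - warmup_steps) / max(total_steps - warmup_steps, 1), 1.0)
    return min_lr + (base_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


# e.g. 10 warmup epochs of 100 steps each, out of 100 epochs in total
lrs = [warmup_cosine_lr(s, warmup_steps=1_000, total_steps=10_000) for s in range(10_001)]
```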

**kwargs: dict, optional

Extra named arguments for the BaseEstimator class, passed to the PyTorch Lightning (PL) Trainer, such as max_epochs, max_steps, num_sanity_val_steps, check_val_every_n_epoch, callbacks, etc. See the PL Trainer API for more details.

Attributes:
context_encoder: torch.nn.Module

Context encoder of the I-JEPA model. It is used at inference time for downstream tasks.

target_encoder: torch.nn.Module

Target encoder of the I-JEPA model. It is an EMA of the context encoder during training and is not used at inference.

predictor: torch.nn.Module

Predictor trained to predict, in the latent space, the representations of the masked parts of an image from its context. It can also be useful at inference time to predict the representations of missing parts of the input.

loss: SmoothL1Loss

Smooth L1 loss used to train the model.
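The loss compares the predicted and target embeddings element-wise. The sketch below uses `torch.nn.functional.smooth_l1_loss`, the functional form of `SmoothL1Loss`, assuming the default `beta=1.0` and mean reduction, on random tensors with the (B * M, L, D) shape documented for `training_step`. The sizes are illustrative, and the target is detached because the target encoder is updated by EMA rather than by backpropagation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, M, L, D = 2, 4, 10, 32  # illustrative sizes, not IJEPA defaults
z_pred = torch.randn(B * M, L, D, requires_grad=True)
# Target embeddings come from the EMA target encoder: no gradient flows into them.
z_target = torch.randn(B * M, L, D).detach()

loss = F.smooth_l1_loss(z_pred, z_target)  # same as nn.SmoothL1Loss()(z_pred, z_target)

# Same value written out: 0.5 * d**2 where |d| < 1, |d| - 0.5 elsewhere (beta = 1).
d = (z_pred - z_target).abs()
manual = torch.where(d < 1.0, 0.5 * d ** 2, d - 0.5).mean()

loss.backward()  # gradients reach z_pred only
```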

References

[1]

Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture, Assran et al., CVPR 2023

__init__(encoder, dim=3, context_block_scale=(0.85, 1.0), target_block_scale=(0.15, 0.2), aspect_ratio=(0.75, 1.5), num_target_blocks=4, min_keep=4, allow_overlap=False, predictor_embed_dim=384, predictor_depth_pred=6, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, ema_start=0.996, ema_end=1.0, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)
configure_optimizers()

Initialize the optimizer and learning rate scheduler.

forward_target(x, target_blocks)
on_train_batch_end(outputs, batch, batch_idx)

Perform the momentum (EMA) update of the target (teacher) encoder.

Parameters:
outputs: dict[str, Any]

The outputs of the training step (ignored).

batch: torch.Tensor or pair of torch.Tensor

A batch of input data (ignored).

batch_idx: int

The index of the current batch (ignored).
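The momentum update blends each target-encoder parameter with the matching context-encoder parameter: theta_target <- m * theta_target + (1 - m) * theta_context, where m is the current EMA coefficient (between `ema_start` and `ema_end`). A minimal sketch, assuming both encoders have identically ordered parameters; `ema_update` is a hypothetical helper, and buffers such as normalization statistics are left untouched here.

```python
import copy

import torch
from torch import nn


@torch.no_grad()
def ema_update(target_encoder, context_encoder, momentum):
    """Hypothetical helper: in-place EMA of the context parameters into the target."""
    for p_t, p_c in zip(target_encoder.parameters(), context_encoder.parameters()):
        p_t.mul_(momentum).add_(p_c.detach(), alpha=1.0 - momentum)


context_encoder = nn.Linear(4, 4)
# The target starts as a frozen copy of the context encoder.
target_encoder = copy.deepcopy(context_encoder).requires_grad_(False)
with torch.no_grad():
    context_encoder.weight.fill_(1.0)
    target_encoder.weight.zero_()
ema_update(target_encoder, context_encoder, momentum=0.996)
# target weight = 0.996 * 0 + 0.004 * 1 = 0.004; identical biases stay identical
```

Running the update under `torch.no_grad()` keeps it out of the autograd graph, which matches the fact that the target encoder is never trained by backpropagation.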

test_step(batch, batch_idx)

Skip the test step.

training_step(batch, batch_idx)

Perform one training step and compute the training loss.

Parameters:
batch: torch.Tensor or (torch.Tensor, torch.Tensor)

A batch of data in the format X or (X, Y), where X is a torch.Tensor with shape (B, C, H, W) (2D images) or (B, C, H, W, D) (3D volumes) representing the input data, and Y contains optional labels (ignored).

batch_idx: int

The index of the current batch (ignored).

Returns:
outputs: dict
Dictionary containing three torch.Tensors:
  • “loss”: training loss computed on this batch of data.

  • “z_pred”: predicted embeddings of the target tokens, with shape (B * M, L, D), where M is the number of target blocks (e.g. 4), L is the number of target tokens predicted per block, and D is the embedding dimension.

  • “z_target”: embeddings to predict, with the same shape as “z_pred”.
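To see where the (B * M, L, D) shape comes from, the sketch below gathers the tokens of M blocks from a (B, N, D) token tensor with `torch.gather` and stacks the blocks along the batch axis. It assumes every block holds the same number L of patches within a batch and that blocks are concatenated one after the other (an assumption about the ordering); `gather_blocks` is hypothetical.

```python
import torch


def gather_blocks(tokens, block_indices):
    """Hypothetical helper.

    tokens: (B, N, D) token embeddings for all N patches.
    block_indices: list of M integer tensors of shape (B, L), one per block.
    Returns a (M * B, L, D) tensor: block 0 for every sample, then block 1, etc.
    """
    gathered = []
    for idx in block_indices:
        idx = idx.unsqueeze(-1).repeat(1, 1, tokens.size(-1))  # (B, L, D)
        gathered.append(torch.gather(tokens, dim=1, index=idx))
    return torch.cat(gathered, dim=0)


B, N, D, M, L = 2, 16, 3, 4, 5
tokens = torch.arange(B * N * D, dtype=torch.float32).reshape(B, N, D)
# Toy indices: block m covers patches m, m + 1, ..., m + L - 1 for every sample.
blocks = [torch.arange(m, m + L).repeat(B, 1) for m in range(M)]
z = gather_blocks(tokens, blocks)  # shape (M * B, L, D) = (8, 5, 3)
```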

transform_step(batch, batch_idx, dataloader_idx=0)

Encode the input data into the latent space.

Importantly, we do not apply the predictor here since it is not part of the final model at inference time (only used for training).

Parameters:
batch: torch.Tensor

A batch of data generated by test_dataloader. It is passed as-is to the context encoder.

batch_idx: int

The index of the current batch (ignored).

dataloader_idx: int, default=0

The index of the dataloader (ignored).

Returns:
features: torch.Tensor

The features returned by the context encoder, averaged over the sequence dimension.
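A minimal sketch of this encoding step, assuming the context encoder returns a (B, N, D) sequence of patch tokens. `ToyPatchEncoder` is a hypothetical stand-in for a Vision Transformer, used only so the example runs without timm, and `encode` is a hypothetical helper; the mean over `dim=1` is the part that mirrors the description above.

```python
import torch
from torch import nn


class ToyPatchEncoder(nn.Module):
    """Hypothetical stand-in for a ViT: one D-dim token per non-overlapping patch."""

    def __init__(self, in_chans=1, embed_dim=16, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # (B, C, H, W) -> (B, D, H/p, W/p) -> (B, N, D)
        return self.proj(x).flatten(2).transpose(1, 2)


@torch.no_grad()
def encode(context_encoder, batch):
    tokens = context_encoder(batch)  # (B, N, D)
    return tokens.mean(dim=1)        # average over the sequence dimension -> (B, D)


encoder = ToyPatchEncoder().eval()
x = torch.randn(2, 1, 28, 28)
features = encode(encoder, x)  # 7 x 7 = 49 tokens pooled into one 16-dim vector each
```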

validation_step(batch, batch_idx)

Perform one validation step and compute the validation loss.

Parameters:
batch: torch.Tensor or (torch.Tensor, torch.Tensor)

A batch of data in the format X or (X, Y), where X is a torch.Tensor with shape (B, C, H, W) (2D images) or (B, C, H, W, D) (3D volumes) representing the input data, and Y contains optional labels (ignored).

batch_idx: int

The index of the current batch (ignored).

Returns:
outputs: dict
Dictionary containing three torch.Tensors:
  • “loss”: validation loss computed on this batch of data.

  • “z_pred”: predicted embeddings of the target tokens, with shape (B * M, L, D), where M is the number of target blocks (e.g. 4), L is the number of target tokens predicted per block, and D is the embedding dimension.

  • “z_target”: embeddings to predict, with the same shape as “z_pred”.

Examples using nidl.estimators.ssl.IJEPA

Self-Supervised Learning with I-JEPA on MedMNIST3D