Note
This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.
nidl.estimators.ssl.IJEPA
- class nidl.estimators.ssl.IJEPA(encoder, dim=3, context_block_scale=(0.85, 1.0), target_block_scale=(0.15, 0.2), aspect_ratio=(0.75, 1.5), num_target_blocks=4, min_keep=4, allow_overlap=False, predictor_embed_dim=384, predictor_depth_pred=6, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, ema_start=0.996, ema_end=1.0, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)
Bases: TransformerMixin, BaseEstimator
Implementation of I-JEPA [1].
Solver that predicts the representations of missing parts of an image from its surrounding context. It uses two encoders (a context encoder and a target encoder) to compute the context and target features, and a third network, the predictor, to predict the target features from the context features.
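The block masking controlled by context_block_scale, target_block_scale, aspect_ratio, num_target_blocks, min_keep and allow_overlap can be illustrated with a minimal pure-Python sketch for the 2D case (dim=2). It assumes the input is split into a grid_h x grid_w grid of patches, that a block's scale is the fraction of patches it covers and that the aspect ratio is height / width, following the I-JEPA paper. The helper names (sample_block_size, sample_block, sample_masks) are hypothetical; this is not nidl's actual implementation.

```python
import math
import random


def sample_block_size(grid_h, grid_w, scale, aspect_ratio, rng):
    """Sample a block size (h, w) in patches on a grid_h x grid_w patch grid.

    `scale` and `aspect_ratio` are (min, max) ranges: scale is the fraction of
    all patches covered by the block, aspect ratio is height / width.
    """
    s = rng.uniform(*scale)
    r = rng.uniform(*aspect_ratio)
    n_patches = s * grid_h * grid_w
    h = int(round(math.sqrt(n_patches * r)))
    w = int(round(math.sqrt(n_patches / r)))
    # Clamp so that the block always fits inside the grid.
    return max(1, min(h, grid_h)), max(1, min(w, grid_w))


def sample_block(grid_h, grid_w, size, rng):
    """Place a block of the given size at a random position; return its patch indices."""
    h, w = size
    top = rng.randint(0, grid_h - h)
    left = rng.randint(0, grid_w - w)
    return {i * grid_w + j for i in range(top, top + h) for j in range(left, left + w)}


def sample_masks(grid_h, grid_w, context_block_scale=(0.85, 1.0),
                 target_block_scale=(0.15, 0.2), aspect_ratio=(0.75, 1.5),
                 num_target_blocks=4, min_keep=4, allow_overlap=False, seed=0):
    """Return (context_indices, [target_indices, ...]) as sorted patch-index lists."""
    rng = random.Random(seed)
    targets = [
        sample_block(grid_h, grid_w,
                     sample_block_size(grid_h, grid_w, target_block_scale, aspect_ratio, rng),
                     rng)
        for _ in range(num_target_blocks)
    ]
    # Context block with unit aspect ratio, as described in the I-JEPA paper.
    ctx_size = sample_block_size(grid_h, grid_w, context_block_scale, (1.0, 1.0), rng)
    context = sample_block(grid_h, grid_w, ctx_size, rng)
    if not allow_overlap:
        # Drop target patches from the context so they cannot simply be copied.
        for target in targets:
            context -= target
    if len(context) < min_keep or any(len(t) < min_keep for t in targets):
        raise ValueError("block smaller than min_keep; resample in practice")
    return sorted(context), [sorted(t) for t in targets]


context, targets = sample_masks(14, 14)  # e.g. a 224x224 image with 16x16 patches
print(len(context), [len(t) for t in targets])
```

With allow_overlap=False the context and target index sets are disjoint, so the predictor must infer the target representations rather than read them from the context.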
The target encoder is an Exponential Moving Average (EMA) of the context encoder. It helps prevent representation collapse of the context encoder during training and is discarded at inference. The context encoder is used for downstream tasks.
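The EMA relation between the two encoders amounts to a per-parameter update, target <- m * target + (1 - m) * context, with a momentum m close to 1; only the context encoder receives gradient updates. The sketch below illustrates the rule on plain Python lists; the function name ema_update is hypothetical and this is not nidl's code.

```python
def ema_update(target_params, context_params, momentum):
    """EMA update in place: target <- momentum * target + (1 - momentum) * context.

    Parameters are plain lists of floats in this sketch. With PyTorch, the same rule
    would be applied to every parameter tensor pair under torch.no_grad(), e.g.
    t.mul_(momentum).add_(c, alpha=1 - momentum).
    """
    for i, (t, c) in enumerate(zip(target_params, context_params)):
        target_params[i] = momentum * t + (1.0 - momentum) * c
    return target_params


target = [0.0, 0.0]
context = [1.0, 2.0]
ema_update(target, context, momentum=0.996)
print(target)  # target moves 0.4% of the way toward the context weights
```

A momentum of 1.0 leaves the target unchanged, which is why annealing toward ema_end=1.0 progressively freezes the target encoder.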
- Parameters:
- encoder: nn.Module
Vision Transformer-like encoder module. This encoder should follow the timm interface of VisionTransformer.
- dim: {2, 3}, default=3
Input data dimensionality: 3 for 3D volumes, 2 for 2D images.
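- context_block_scale: (float, float), default=(0.85, 1.0)
Range (min, max) of the context block scale, i.e. the fraction of the input covered by the context block.
- target_block_scale: (float, float), default=(0.15, 0.2)
Range (min, max) of the target block scale, i.e. the fraction of the input covered by each target block.
- aspect_ratio: (float, float), default=(0.75, 1.5)
Range (min, max) of the aspect ratio of the target blocks.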
- num_target_blocks: int, default=4
Number of target blocks to predict.
- min_keep: int, default=4
Minimum number of patches to keep in the context and target blocks.
- allow_overlap: bool, default=False
Whether to allow overlap between the target blocks and the context block.
- predictor_embed_dim: int, default=384
Dimension of the predictor hidden layers. It can differ from the encoder output dimension.
- predictor_depth_pred: int, default=6
Number of Transformer blocks in the predictor.
- optimizer: {‘sgd’, ‘adam’, ‘adamW’} or Optimizer, default=‘adamW’
Optimizer for training the model. If a string is given, it can be:
‘sgd’: Stochastic Gradient Descent (with optional momentum).
‘adam’: Adam, a first-order gradient-based optimizer with adaptive moment estimates.
‘adamW’ (default): Adam with decoupled weight decay regularization (see “Decoupled Weight Decay Regularization”, Loshchilov and Hutter, ICLR 2019).
- learning_rate: float, default=3e-4
Initial learning rate.
- weight_decay: float, default=5e-4
Weight decay coefficient used by the optimizer.
- exclude_bias_and_norm_wd: bool, default=True
Whether to exclude the bias terms and the normalization-layer parameters from weight decay.
- ema_start: float, default=0.996
Initial value of the momentum coefficient used in the exponential moving average (EMA) update of the target (teacher) encoder. The coefficient is annealed from ema_start to ema_end with a cosine schedule.
- ema_end: float, default=1.0
Final value of the momentum coefficient in the EMA update of the target encoder.
- optimizer_kwargs: dict or None, default=None
Extra named arguments for the optimizer.
- lr_scheduler: {“none”, “warmup_cosine”}, LRSchedulerPLType or None, default=“warmup_cosine”
Learning rate scheduler to use.
- lr_scheduler_kwargs: dict or None, default=None
Extra named arguments for the scheduler. When None, the following values are used: {"warmup_epochs": 10, "warmup_start_lr": 1e-6, "min_lr": 0.0, "interval": "step"}.
- **kwargs: dict, optional
Extra named arguments for the BaseEstimator class, passed to the PyTorch Lightning (PL) Trainer, such as max_epochs, max_steps, num_sanity_val_steps, check_val_every_n_epoch, callbacks, etc. See the PL Trainer API for more details.
- Attributes:
- context_encoder: torch.nn.Module
The context encoder of the I-JEPA model. It is used at inference time for downstream tasks.
- target_encoder: torch.nn.Module
The target encoder of the I-JEPA model. It is an EMA of the context encoder during training and is not used at inference.
- predictor: torch.nn.Module
Predictor network trained to predict, in the latent space, the representations of the masked parts of an image from its context. It can be useful at inference time to predict missing parts of the input.
- loss: SmoothL1Loss
Smooth L1 loss used for training the model.
References
[1] Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture, Assran et al., CVPR 2023
- __init__(encoder, dim=3, context_block_scale=(0.85, 1.0), target_block_scale=(0.15, 0.2), aspect_ratio=(0.75, 1.5), num_target_blocks=4, min_keep=4, allow_overlap=False, predictor_embed_dim=384, predictor_depth_pred=6, optimizer='adamW', learning_rate=0.0003, weight_decay=0.0005, exclude_bias_and_norm_wd=True, ema_start=0.996, ema_end=1.0, optimizer_kwargs=None, lr_scheduler='warmup_cosine', lr_scheduler_kwargs=None, **kwargs)
- on_train_batch_end(outputs, batch, batch_idx)
Performs the momentum (EMA) update of the target (teacher) encoder.
- Parameters:
- outputs: dict[str, Any]
The outputs of the training step (ignored).
- batch: torch.Tensor or pair of torch.Tensor
A batch of input data (ignored).
- batch_idx: int
The index of the current batch (ignored).
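Since the update runs once per training batch, the momentum coefficient can be indexed by the global step. The sketch below assumes a cosine schedule of the form used in BYOL, generalized so that it goes from ema_start at step 0 to ema_end at the last step; the function ema_momentum and the loop are hypothetical illustrations, not nidl's implementation.

```python
import math


def ema_momentum(step, total_steps, ema_start=0.996, ema_end=1.0):
    """Cosine-annealed EMA coefficient: ema_start at step 0, ema_end at total_steps."""
    progress = min(step / max(1, total_steps), 1.0)
    return ema_end - (ema_end - ema_start) * (math.cos(math.pi * progress) + 1.0) / 2.0


# Simulate on_train_batch_end: one momentum update of the target per training batch.
total_steps = 100
target, context = [0.0], [1.0]  # stand-ins for one target / context parameter
momenta = []
for step in range(total_steps):
    m = ema_momentum(step, total_steps)
    momenta.append(m)
    target = [m * t + (1.0 - m) * c for t, c in zip(target, context)]
print(momenta[0], momenta[-1], target)
```

The coefficient increases monotonically, so the target encoder tracks the context encoder closely early on and becomes increasingly stable toward the end of training.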
- training_step(batch, batch_idx)
Performs one training step and computes the training loss.
- Parameters:
- batch: torch.Tensor or (torch.Tensor, torch.Tensor)
A batch of data in the format X or (X, Y), where X is a torch.Tensor with shape (B, C, H, W) (2d images) or (B, C, H, W, D) (3d volumes) representing the input data. Y contains optional labels (ignored).
- batch_idx: int
The index of the current batch (ignored).
- Returns:
- outputs: dict
Dictionary containing three torch.Tensors:
- “loss”: training loss computed on this batch of data.
- “z_pred”: predicted embeddings of the target tokens, with shape (B * M, L, D), where M is the number of target blocks (e.g. 4), L is the number of target tokens predicted per block, and D is the embedding dimension.
- “z_target”: target embeddings to predict, with the same shape as “z_pred”.
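The loss compares z_pred and z_target element-wise. The sketch below reproduces a mean-reduced Smooth L1 loss on nested lists of shape (B * M, L, D), assuming the default beta=1.0; in PyTorch, torch.nn.functional.smooth_l1_loss(z_pred, z_target) computes the same quantity with its defaults. The helpers flatten, smooth_l1_loss and full are hypothetical, and the sizes B, M, L, D are arbitrary example values.

```python
def flatten(nested):
    """Flatten an arbitrarily nested list of numbers into a flat list."""
    if isinstance(nested, list):
        return [v for item in nested for v in flatten(item)]
    return [nested]


def smooth_l1_loss(pred, target, beta=1.0):
    """Mean Smooth L1 loss: 0.5 * d**2 / beta if |d| < beta, else |d| - 0.5 * beta."""
    p, t = flatten(pred), flatten(target)
    if len(p) != len(t):
        raise ValueError("z_pred and z_target must have the same shape")
    total = 0.0
    for a, b in zip(p, t):
        d = abs(a - b)
        total += 0.5 * d * d / beta if d < beta else d - 0.5 * beta
    return total / len(p)


def full(shape, value):
    """Nested list of the given shape filled with `value`."""
    if len(shape) == 1:
        return [value] * shape[0]
    return [full(shape[1:], value) for _ in range(shape[0])]


B, M, L, D = 2, 4, 3, 5  # batch size, target blocks, tokens per block, embedding dim
z_pred = full((B * M, L, D), 0.0)
z_target = full((B * M, L, D), 0.5)
print(len(z_pred), len(z_pred[0]), len(z_pred[0][0]))  # 8 3 5
print(smooth_l1_loss(z_pred, z_target))  # 0.125
```

The quadratic regime for small errors and linear regime for large ones make the loss less sensitive to outlier tokens than a plain L2 loss.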
- transform_step(batch, batch_idx, dataloader_idx=0)
Encode the input data into the latent space.
Importantly, we do not apply the predictor here since it is not part of the final model at inference time (only used for training).
- Parameters:
- batch: torch.Tensor
A batch of data generated by test_dataloader. It is passed as-is to the context encoder.
- batch_idx: int
The index of the current batch (ignored).
- dataloader_idx: int, default=0
The index of the dataloader (ignored).
- Returns:
- features: torch.Tensor
The encoded features returned by the context encoder averaged across the sequence dimension.
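Averaging across the sequence dimension turns a per-token output into one feature vector per sample. The sketch below assumes the context encoder returns a (B, N, D) token sequence (batch, tokens, embedding), the layout of timm-style ViT features; with a PyTorch tensor the same operation is features.mean(dim=1). The helper mean_pool is hypothetical.

```python
def mean_pool(tokens):
    """Average a (B, N, D) nested list of token embeddings over N, giving (B, D)."""
    pooled = []
    for seq in tokens:  # seq: (N, D) token embeddings of one sample
        n = len(seq)
        pooled.append([sum(tok[d] for tok in seq) / n for d in range(len(seq[0]))])
    return pooled


tokens = [
    [[1.0, 2.0], [3.0, 4.0]],    # sample 0: N=2 tokens, D=2
    [[0.0, 0.0], [2.0, -2.0]],   # sample 1
]
print(mean_pool(tokens))  # [[2.0, 3.0], [1.0, -1.0]]
```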
- validation_step(batch, batch_idx)
Performs one validation step and computes the validation loss.
- Parameters:
- batch: torch.Tensor or (torch.Tensor, torch.Tensor)
A batch of data in the format X or (X, Y), where X is a torch.Tensor with shape (B, C, H, W) (2d images) or (B, C, H, W, D) (3d volumes) representing the input data. Y contains optional labels (ignored).
- batch_idx: int
The index of the current batch (ignored).
- Returns:
- outputs: dict
Dictionary containing three torch.Tensors:
- “loss”: validation loss computed on this batch of data.
- “z_pred”: predicted embeddings of the target tokens, with shape (B * M, L, D), where M is the number of target blocks (e.g. 4), L is the number of target tokens predicted per block, and D is the embedding dimension.
- “z_target”: target embeddings to predict, with the same shape as “z_pred”.
Examples using nidl.estimators.ssl.IJEPA
Self-Supervised Learning with I-JEPA on MedMNIST3D