Self-Supervised Learning with I-JEPA on MedMNIST3D

This example demonstrates how to pretrain a 3D vision transformer with I-JEPA [1] on a MedMNIST3D dataset [2] using the nidl library, and how to evaluate the learned representations with a simple linear probe.

I-JEPA: the key idea behind I-JEPA is to learn representations by predicting the latent representations of masked-out image blocks from their surrounding context. This is the main difference from Masked Autoencoders (MAE), which reconstruct the masked-out blocks in pixel space.
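To make the contrast with MAE concrete, here is a minimal NumPy sketch of two ingredients described in the I-JEPA paper, not taken from nidl's code: the loss is an average squared L2 distance between predicted and target token embeddings (not voxels), and the target encoder is an exponential moving average (EMA) of the context encoder. The function names and the momentum value 0.996 are illustrative assumptions; nidl may use a different distance or momentum schedule.

```python
import numpy as np


def latent_prediction_loss(predicted, target):
    """Average squared L2 distance between predicted and target embeddings.

    Both arrays have shape (num_target_tokens, embed_dim) and only contain the
    tokens of the masked-out target blocks. An MAE would instead compare
    reconstructed voxel values with the original image.
    """
    return float(np.mean(np.sum((predicted - target) ** 2, axis=-1)))


def ema_update(target_params, context_params, momentum=0.996):
    """Move each target-encoder weight towards the matching context-encoder weight.

    The target encoder receives no gradients; it only tracks this moving average.
    """
    return [momentum * t + (1.0 - momentum) * c
            for t, c in zip(target_params, context_params)]


rng = np.random.default_rng(0)
target = rng.normal(size=(10, 256))  # toy target-encoder embeddings
predicted = target + 0.1             # toy predictor output, off by 0.1 everywhere
loss = latent_prediction_loss(predicted, target)  # about 0.1**2 * 256 = 2.56
new_target = ema_update([np.zeros(3)], [np.ones(3)], momentum=0.9)
```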

3D adaptation: I-JEPA is designed to be flexible. Although the original implementation only handles 2D images, its extension to 3D volumes is straightforward. There are two key differences from the 2D case: the tokenization uses 3D patches, and the positional embeddings are 3D as well. The masking strategy follows the same random block sampling as in 2D, applied to the 3D grid of patches.
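The 3D tokenization can be illustrated with a short NumPy sketch that cuts a (C, D, H, W) volume into non-overlapping cubic patches and flattens each one into a token. This shows the general technique, not nidl's code: ViT implementations typically realize the same operation with a strided 3D convolution whose kernel size and stride both equal the patch size, and then add 3D positional embeddings. patchify_3d is a hypothetical helper.

```python
import numpy as np


def patchify_3d(volume, patch_size):
    """Split a (C, D, H, W) volume into non-overlapping cubic patches.

    Returns an array of shape (num_patches, C * patch_size**3): one flattened
    row per 3D token, ordered depth-major, then height, then width.
    """
    c, d, h, w = volume.shape
    p = patch_size
    if d % p or h % p or w % p:
        raise ValueError("spatial sizes must be divisible by patch_size")
    x = volume.reshape(c, d // p, p, h // p, p, w // p, p)
    # Put the three grid axes first, then the channel and intra-patch axes.
    x = x.transpose(1, 3, 5, 0, 2, 4, 6)
    return x.reshape((d // p) * (h // p) * (w // p), c * p ** 3)


volume = np.random.default_rng(0).random((1, 64, 64, 64))
tokens = patchify_3d(volume, patch_size=8)  # 8 x 8 x 8 grid -> shape (512, 512)
```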

In this tutorial, we will follow these steps:

  1. Load a MedMNIST3D dataset.

  2. Build a 3D vision transformer encoder.

  3. Train an I-JEPA model, or optionally load pretrained weights from the Hugging Face Hub.

  4. Evaluate the pretrained encoder on the downstream classification task with a logistic regression probe.

In this example we use OrganMNIST3D, one of the 3D datasets distributed by MedMNIST. MedMNIST3D datasets are lightweight 3D medical image classification benchmarks standardized to a common spatial size, which makes them convenient for prototyping self-supervised pipelines.

Setup

This example requires medmnist in addition to nidl. If you want to load pretrained weights from the Hugging Face Hub, you also need huggingface_hub installed in your environment.

from __future__ import annotations

import os
from typing import Callable, Optional

import matplotlib.pyplot as plt
import medmnist
import numpy as np
import torch
from lightning_fabric import seed_everything
from medmnist import INFO
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from nidl.estimators.ssl import IJEPA
from nidl.utils.weights import Weights
from nidl.volume.backbones import VisionTransformer3D
from nidl.volume.transforms.augmentation import RandomResizedCrop
from nidl.volume.transforms.preprocessing import ZNormalization

We define some global parameters that will be used throughout the example.

Training a 3D I-JEPA model can take substantial time depending on your hardware. By default (load_pretrained = True), this example therefore loads a checkpoint from the Hugging Face Hub instead of training. Set load_pretrained = False to train the lightweight configuration defined below locally; the Hugging Face repository settings are then ignored.

Data-related parameters

# Directory where to download MedMNIST data
data_dir = "/tmp/medmnist"
# Directory where to cache optional pretrained weights
model_dir = "/tmp/nidl_example_ijepa_medmnist"
# MedMNIST3D dataset to use
dataset_name = "organmnist3d"
# Spatial size of the 3D volumes (MedMNIST+ provides 28 or 64)
img_size = 64
# Whether to load a pretrained checkpoint from HF or train locally
load_pretrained = True
# Hugging Face Hub repository and checkpoint file holding the pretrained weights
hf_repo_id = "neurospin/nidl_example_ijepa_medmnist"
hf_checkpoint = "nidl_example_ijepa_medmnist.pt"

Reproducibility and training configuration

# What accelerator to use: GPU if available, else CPU
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
# Parameters for the data loaders. Reduce them if you run out of memory.
batch_size = 16
num_workers = 4
# Training configuration
max_epochs = 20
learning_rate = 3e-4
weight_decay = 5e-4
random_seed = 42

seed_everything(random_seed)
rd_generator = np.random.default_rng(seed=random_seed)

Data preparation

We first define a small MedMNIST3D dataset wrapper and the transforms used for self-supervised pretraining and downstream evaluation.

class MedMNIST3DDataset:
    """Simple wrapper around a MedMNIST3D split."""

    def __init__(
        self,
        dataset_name: str,
        root: str,
        split: str,
        transform: Optional[Callable] = None,
        size: int = 64,
        download: bool = True,
    ):
        dataset_name = dataset_name.lower()
        if dataset_name not in INFO:
            raise ValueError(
                f"Unknown MedMNIST dataset '{dataset_name}'. "
                f"Available datasets include: {sorted(INFO.keys())}"
            )
        if "3d" not in dataset_name:
            raise ValueError(
                f"This example is written for MedMNIST3D datasets, got "
                f"'{dataset_name}'."
            )

        info = INFO[dataset_name]
        dataset_cls = getattr(medmnist, info["python_class"])
        os.makedirs(root, exist_ok=True)
        self.dataset = dataset_cls(
            split=split,
            root=root,
            transform=transform,
            download=download,
            size=size,
            as_rgb=False,
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
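        return self.dataset[index]


def to_float_tensor(x):
    """Convert a NumPy volume to a float32 tensor.

    A named function, unlike a lambda, can be pickled when DataLoader
    workers are started with the "spawn" method (e.g. on macOS or Windows).
    """
    return torch.from_numpy(x).float()


train_transform = Compose(
    [
        RandomResizedCrop(img_size, scale=(0.5, 1.0)),
        ZNormalization(),
        to_float_tensor,
    ]
)

eval_transform = Compose(
    [
        ZNormalization(),
        to_float_tensor,
    ]
)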
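ZNormalization rescales intensities to zero mean and unit variance. The NumPy function below is a minimal sketch of that idea, assuming a plain global standardization; nidl's transform may compute its statistics differently, for instance over a foreground mask. z_normalize and its eps parameter are hypothetical.

```python
import numpy as np


def z_normalize(volume, eps=1e-8):
    """Standardize all voxels of a volume to zero mean and unit variance.

    eps guards against division by zero on constant volumes.
    """
    return (volume - volume.mean()) / (volume.std() + eps)


# Toy stand-in for a single-channel volume, assumed shape (1, D, H, W).
volume = np.random.default_rng(0).random((1, 64, 64, 64))
normalized = z_normalize(volume)
```

In train_transform the crop runs before normalization, so the statistics are computed on the cropped volume rather than on the full one.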

We create one dataset for self-supervised training, and labeled datasets for linear probing.

ssl_dataset = MedMNIST3DDataset(
    dataset_name=dataset_name,
    root=data_dir,
    split="train",
    transform=train_transform,
    size=img_size,
)

train_xy_dataset = MedMNIST3DDataset(
    dataset_name=dataset_name,
    root=data_dir,
    split="train",
    transform=eval_transform,
    size=img_size,
)

test_xy_dataset = MedMNIST3DDataset(
    dataset_name=dataset_name,
    root=data_dir,
    split="test",
    transform=eval_transform,
    size=img_size,
)

Finally, we create the data loaders.

train_ssl_loader = DataLoader(
    ssl_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=num_workers,
)
train_xy_loader = DataLoader(
    train_xy_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=num_workers,
)
test_xy_loader = DataLoader(
    test_xy_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=num_workers,
)
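The drop_last flag changes how many batches each loader yields per epoch. The hypothetical helper below sketches the arithmetic for a standard, non-distributed sampler: with drop_last=True, as in train_ssl_loader, the last incomplete batch is discarded so every I-JEPA training batch holds exactly batch_size volumes, whereas the evaluation loaders keep it so that every labeled sample gets a representation.

```python
import math


def num_batches(num_samples, batch_size, drop_last):
    """Batches per epoch for a DataLoader with a standard (non-distributed) sampler."""
    if drop_last:
        return num_samples // batch_size  # incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # final batch may be smaller


# Example with batch_size=16 and a split of 971 samples.
print(num_batches(971, 16, drop_last=True), num_batches(971, 16, drop_last=False))  # 60 61
```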

Model architecture

We use the 3D vision transformer from nidl as the backbone.

encoder = VisionTransformer3D(
    img_size=img_size,
    patch_size=8,
    in_chans=1,
    embed_dim=256,
    depth=6,
    num_heads=8,
    mlp_ratio=4.0,
)
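It is worth checking how these hyperparameters translate into sequence length, because attention cost grows quadratically with the number of tokens. The sketch below does the arithmetic; the parameter count assumes a standard patch embedding (a linear map with bias, or equivalently a Conv3d with kernel size and stride equal to patch_size), which is an assumption about VisionTransformer3D rather than something stated here. vit3d_token_stats is a hypothetical helper.

```python
def vit3d_token_stats(img_size, patch_size, in_chans, embed_dim):
    """Token-grid arithmetic for a ViT applied to cubic volumes."""
    grid = img_size // patch_size            # patches along each axis
    num_tokens = grid ** 3                   # 3D grid: grid x grid x grid
    patch_dim = in_chans * patch_size ** 3   # voxels per token before embedding
    # Assumed linear patch embedding with bias: patch_dim -> embed_dim.
    patch_embed_params = patch_dim * embed_dim + embed_dim
    return {
        "grid": grid,
        "num_tokens": num_tokens,
        "tokens_if_2d": grid ** 2,           # same sizes on a 2D image, for comparison
        "patch_dim": patch_dim,
        "patch_embed_params": patch_embed_params,
        "attention_entries_per_head": num_tokens ** 2,
    }


stats = vit3d_token_stats(img_size=64, patch_size=8, in_chans=1, embed_dim=256)
```

With the configuration above, each volume becomes 512 tokens, eight times more than a 64 × 64 image with the same patch size would produce.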

Training the I-JEPA model

Depending on load_pretrained, we either train this compact I-JEPA model locally or load a checkpoint from the Hugging Face Hub that was trained with the same configuration.

model = IJEPA(
    encoder=encoder,
    dim=3,
    context_block_scale=(0.85, 1.0),
    target_block_scale=(0.15, 0.2),
    aspect_ratio=(0.75, 1.5),
    num_target_blocks=4,
    min_keep=4,
    allow_overlap=False,
    predictor_embed_dim=256,
    predictor_depth_pred=6,
    learning_rate=learning_rate,
    optimizer="adamW",
    weight_decay=weight_decay,
    max_epochs=max_epochs,
    check_val_every_n_epoch=1,
    use_distributed_sampler=False,
    enable_checkpointing=False,
    accelerator=accelerator,
    devices=1,
    random_state=random_seed,
)
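To get a feel for the masking parameters, the sketch below converts block scales into approximate patch counts on the 512-token grid produced by the encoder above. It assumes that, as in the original 2D I-JEPA, a scale is the fraction of the patch grid covered by one block, and that allow_overlap=False removes target patches from the context (min_keep then bounds how few patches a mask may keep). nidl's exact rounding and its use of aspect_ratio for 3D blocks are not shown here, so treat the numbers as rough. patches_per_block is a hypothetical helper.

```python
def patches_per_block(scale_range, num_tokens):
    """Approximate (min, max) number of patches covered by one sampled block."""
    low, high = scale_range
    return round(low * num_tokens), round(high * num_tokens)


num_tokens = (64 // 8) ** 3                           # 512 tokens
context = patches_per_block((0.85, 1.0), num_tokens)  # context_block_scale
target = patches_per_block((0.15, 0.2), num_tokens)   # target_block_scale
num_target_blocks = 4
# Upper bound on patches removed from the context if the 4 targets are disjoint.
max_target_patches = num_target_blocks * target[1]
```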
if not load_pretrained:
    model.fit(train_ssl_loader)
else:
    weights = Weights(
        f"hf-hub:{hf_repo_id}",
        data_dir=model_dir,
        filepath=hf_checkpoint,
    )
    weights.load_pretrained(model)
    model.fitted_ = True

Evaluation with a linear probe

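Once pretraining is complete, we use the frozen encoder to extract representations of the labeled training and test sets (X_train and X_test, with labels y_train and y_test) and fit a logistic regression classifier on top of them.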

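The probing code below expects feature matrices X_train and X_test and label vectors y_train and y_test. One generic way to build them, sketched here in NumPy under our own assumptions, is to run a frozen encoder over each labeled loader and stack the outputs. encode_fn stands for any callable returning one feature vector per volume (for example mean-pooled ViT tokens); it is hypothetical and not nidl's API. In PyTorch you would typically call the encoder in eval mode under torch.no_grad(). Labels are flattened because MedMNIST returns them with a trailing dimension of size 1.

```python
import numpy as np


def mean_pool_tokens(tokens):
    """Average token embeddings (batch, num_tokens, dim) into (batch, dim)."""
    return tokens.mean(axis=1)


def extract_features(loader, encode_fn):
    """Apply a frozen encoder to every (x, y) batch and stack the results.

    encode_fn must return a CPU array-like of shape (batch, feature_dim);
    with PyTorch this usually means calling .detach().cpu().numpy().
    """
    features, labels = [], []
    for x, y in loader:
        features.append(np.asarray(encode_fn(x)))
        labels.append(np.asarray(y).reshape(-1))  # (batch, 1) -> (batch,)
    return np.concatenate(features), np.concatenate(labels)


def toy_encode(x):
    """Shape plumbing only: 8 chunks of 512 values per volume, mean-pooled."""
    return mean_pool_tokens(x.reshape(x.shape[0], 8, -1))


# Toy stand-in loader: 3 batches of 4 single-channel 16^3 volumes.
rng = np.random.default_rng(0)
toy_loader = [(rng.random((4, 1, 16, 16, 16)), np.arange(4).reshape(4, 1))
              for _ in range(3)]
X, y = extract_features(toy_loader, toy_encode)
```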

We train linear probes on increasing fractions of the labeled training set to assess sample efficiency.

estimator = LogisticRegression(
    max_iter=1000, random_state=random_seed, n_jobs=1
)
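# Geometrically spaced subset sizes from about 5% of the training set to all
# of it. Rounding (rather than truncating to int) keeps the last size equal
# to len(X_train) despite floating-point error.
train_sizes = np.unique(
    np.round(
        np.geomspace(max(10, len(X_train) // 20), len(X_train), num=8)
    ).astype(int)
)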
accs = []
for size in train_sizes:
    estimator.fit(X_train[:size], y_train[:size])
    y_pred = estimator.predict(X_test)
    accs.append(accuracy_score(y_test, y_pred))
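Because train_xy_loader uses shuffle=False, X_train[:size] always takes the first size samples in dataset order, so the smallest subsets may not reflect the overall class balance. A possible alternative, sketched below as a hypothetical helper, is to draw one random permutation and use its prefixes: the subsets are random yet nested, so each larger probe sees all the samples of the smaller ones.

```python
import numpy as np


def nested_subset_indices(num_samples, sizes, seed=0):
    """One index array per requested size; each subset contains the previous ones."""
    order = np.random.default_rng(seed).permutation(num_samples)
    return [order[:size] for size in sizes]


subsets = nested_subset_indices(100, [10, 25, 100], seed=42)
```

In the loop above, one would then iterate over zip(train_sizes, nested_subset_indices(len(X_train), train_sizes, seed=random_seed)) and call estimator.fit(X_train[idx], y_train[idx]) instead of slicing.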

We plot the probe's test accuracy against the proportion of labeled training samples (log scale).

plt.plot(train_sizes / len(X_train), accs)
plt.ylim(0, 1)
plt.ylabel("Accuracy")
plt.xlabel("Proportion of labeled training samples")
plt.xscale("log")
plt.text(
    train_sizes[-1] / len(X_train),
    accs[-1],
    f"{accs[-1]:.2f}",
    ha="right",
    va="bottom",
)
plt.text(
    train_sizes[0] / len(X_train),
    accs[0],
    f"{accs[0]:.2f}",
    ha="left",
    va="bottom",
)
plt.show()
[Figure: linear-probe test accuracy vs. proportion of labeled training samples]

This example showed how to pretrain an I-JEPA model on MedMNIST3D with nidl and evaluate its representations with a linear probe. The same pipeline can be applied to other 3D medical imaging datasets, and it can be extended to more complex training setups with logging and callbacks.

Total running time of the script: (2 minutes 4.144 seconds)

Estimated memory usage: 812 MB

Gallery generated by Sphinx-Gallery