Model probing callback of embedding estimators

This notebook shows how to investigate the data representation learned by an embedding estimator (such as SimCLR, y-Aware Contrastive Learning or Barlow Twins) during training, using the notion of “probing”. While the model is being fitted, a standard machine learning model (e.g. a linear model or an SVM), called the probe, is regularly trained and evaluated on the data embedding for a given task. This helps the user understand which concepts the model has learned, i.e. which information can be decoded from its representation.
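To make the idea concrete, here is a minimal linear probe that does not depend on nidl. A frozen encoder maps inputs to embeddings, and a logistic regression is fitted on training embeddings and scored on held-out ones. The encoder here is a fixed random projection followed by tanh, standing in for a trained network. The synthetic two-class data and the frozen_encoder and linear_probe_score helpers are hypothetical and only illustrate the principle.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def frozen_encoder(x, weights):
    """Hypothetical stand-in for a trained encoder: its weights are never updated."""
    return np.tanh(x @ weights)


def linear_probe_score(z_train, y_train, z_test, y_test):
    """Hypothetical helper: fit a linear probe on frozen embeddings, return test accuracy."""
    probe = LogisticRegression(max_iter=500)
    probe.fit(z_train, y_train)
    return probe.score(z_test, y_test)


rng = np.random.default_rng(0)
n_per_class, n_features, n_latent = 200, 20, 8

# Synthetic two-class data: Gaussian blobs centered at -1 and +1 in every dimension.
x = np.vstack(
    [
        rng.normal(-1.0, 1.0, size=(n_per_class, n_features)),
        rng.normal(+1.0, 1.0, size=(n_per_class, n_features)),
    ]
)
y = np.repeat([0, 1], n_per_class)

# Shuffle, embed with the frozen encoder, then split into train/test halves.
perm = rng.permutation(len(y))
x, y = x[perm], y[perm]
weights = rng.normal(size=(n_features, n_latent)) / np.sqrt(n_features)
z = frozen_encoder(x, weights)
half = len(y) // 2
accuracy = linear_probe_score(z[:half], y[:half], z[half:], y[half:])
print(f"Linear probe test accuracy: {accuracy:.2f}")
```

In the notebook, the encoder is the network being trained and the probe is refitted periodically. The resulting score curve therefore tracks how decodable a concept is as training progresses.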

Probing was first introduced by Guillaume Alain and Yoshua Bengio in 2017 [1] to understand the internal behavior of a deep neural network across its layers. It aimed at answering questions such as: what does the intermediate representation of a neural network look like? What information does a given layer contain?

It was later adapted to benchmark self-supervised vision models (such as SimCLR, Barlow Twins, DINO and DINOv2) on classical datasets (ImageNet, CIFAR, …) through linear probing and k-nearest neighbors (kNN) probing of the model’s output representation.
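kNN probing requires no training at all: each test embedding receives the majority label of its k most similar training embeddings. The NumPy sketch below uses cosine similarity with an unweighted majority vote and k=5. These are our own choices. Some benchmark implementations use a similarity-weighted vote instead, so treat this as illustrative. The knn_probe_predict helper and the synthetic clusters are hypothetical.

```python
import numpy as np


def knn_probe_predict(z_train, y_train, z_test, k=5):
    """Hypothetical helper: majority vote among the k most cosine-similar training embeddings."""
    # L2-normalize so that dot products are cosine similarities.
    z_train = z_train / np.linalg.norm(z_train, axis=1, keepdims=True)
    z_test = z_test / np.linalg.norm(z_test, axis=1, keepdims=True)
    similarity = z_test @ z_train.T  # shape (n_test, n_train)
    neighbors = np.argsort(-similarity, axis=1)[:, :k]  # k most similar per test sample
    neighbor_labels = y_train[neighbors]  # shape (n_test, k)
    # Count votes per class; argmax breaks ties toward the smallest label.
    n_classes = int(y_train.max()) + 1
    votes = np.zeros((len(z_test), n_classes), dtype=int)
    rows = np.repeat(np.arange(len(z_test)), k)
    np.add.at(votes, (rows, neighbor_labels.ravel()), 1)
    return votes.argmax(axis=1)


rng = np.random.default_rng(0)
centers = 5.0 * np.eye(8)[:3]  # three classes pointing in different directions


def sample(n_per_class):
    """Draw noisy embeddings around each class center (synthetic data)."""
    z = np.vstack([c + rng.normal(size=(n_per_class, 8)) for c in centers])
    return z, np.repeat(np.arange(3), n_per_class)


z_train, y_train = sample(30)
z_test, y_test = sample(10)
accuracy = np.mean(knn_probe_predict(z_train, y_train, z_test, k=5) == y_test)
print(f"kNN probe test accuracy: {accuracy:.2f}")
```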

Setup

This notebook requires a few packages besides nidl (matplotlib, scikit-learn, TensorBoard and torchvision). Let’s start by importing the libraries used throughout:

import os
import re

import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as func
from sklearn.base import BaseEstimator as sk_BaseEstimator
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression, Ridge
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    make_scorer,
    r2_score,
)
from tensorboard.backend.event_processing import event_accumulator
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.ops import MLP
from torchvision.utils import make_grid

from nidl.callbacks.model_probing import ModelProbing
from nidl.datasets import OpenBHB
from nidl.estimators.ssl import SimCLR, YAwareContrastiveLearning
from nidl.metrics import pearson_r
from nidl.transforms import MultiViewsTransform

We define some global parameters that will be used throughout the notebook:

data_dir = "/tmp/mnist"
batch_size = 128
num_workers = 10
latent_size = 32

Unsupervised Contrastive Learning on MNIST

For illustration purposes on how to use the probing callback, we will focus on the handwritten digits dataset MNIST. It contains 60k training images and 10k test images of size 28x28 pixels. Each image contains a digit from 0 to 9. It is rather small-scale compared to modern datasets like ImageNet but sufficient to illustrate the probing technique. We will train a SimCLR model on these data and probe the learned representation using a logistic regression classifier on the digit classification task. It will show how the data embedding evolves during training to become more linearly separable for each digit class.

We start by loading the MNIST training and test sets, scaled to [-1, 1] with ToTensor followed by Normalize((0.5,), (0.5,)). These labelled datasets are used to fit and evaluate the probe.

scale_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
train_xy_dataset = MNIST(data_dir, download=True, transform=scale_transforms)
test_xy_dataset = MNIST(
    data_dir, download=True, train=False, transform=scale_transforms
)

Dataset and data augmentations for contrastive learning

To perform contrastive learning, we need a set of data augmentations that create multiple views of the same image. The images are grayscale, so color-based augmentations do not apply. We therefore use random resized crops and Gaussian blur. The Gaussian kernel is reduced to 3x3 because MNIST images are only 28x28 pixels.

We create the datasets returning the augmented views for training the SSL models.

ssl_dataset = MNIST(
    data_dir,
    download=True,
    transform=MultiViewsTransform(contrast_transforms, n_views=2),
)
test_ssl_dataset = MNIST(
    data_dir,
    download=True,
    train=False,
    transform=MultiViewsTransform(contrast_transforms, n_views=2),
)

And finally we create the data loaders for training and testing the models.
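The cell that builds the four loaders is not shown on this page either; only its warning output remains below. A minimal sketch with torch.utils.data.DataLoader follows. The make_loaders helper is hypothetical, and the following are our choices:

- shuffling only the training set;
- capping num_workers at os.cpu_count(), which avoids the worker warning;
- enabling pin_memory only when a GPU is available, which avoids the pin_memory warning.

```python
import os

import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(train_dataset, test_dataset, batch_size=128, num_workers=10):
    """Hypothetical helper returning (train_loader, test_loader)."""
    # Never request more worker processes than there are CPUs.
    num_workers = min(num_workers, os.cpu_count() or 1)
    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # pinned memory only helps with a GPU
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)
    return train_loader, test_loader


# Self-contained check on a toy dataset (num_workers=0 keeps it single-process).
toy = TensorDataset(torch.randn(300, 1, 28, 28), torch.randint(0, 10, (300,)))
train_loader, test_loader = make_loaders(toy, toy, batch_size=128, num_workers=0)
print(len(train_loader), len(test_loader))  # 300 samples / 128 per batch -> 3 batches each
```

Within the notebook, the labelled loaders could be built with train_xy_loader, test_xy_loader = make_loaders(train_xy_dataset, test_xy_dataset, batch_size, num_workers). The contrastive loaders train_ssl_loader and test_ssl_loader would be built the same way from ssl_dataset and test_ssl_dataset.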

/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/torch/utils/data/dataloader.py:626: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(

Before training the SimCLR model, let’s visualize a few test images together with two augmented views of each.

def show_images(images, title=None, nrow=8):
    grid = make_grid(images, nrow=nrow, normalize=True, pad_value=1)
    plt.figure(figsize=(10, 5))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    if title:
        plt.title(title)
    plt.axis("off")
    plt.show()


# Original and augmented images
images, labels = next(iter(test_xy_loader))
augmented_views, _ = next(iter(test_ssl_loader))
view1, view2 = augmented_views[0], augmented_views[1]
fig, axes = plt.subplots(2, 3, figsize=(6, 4))
for i in range(2):
    axes[i, 0].imshow(images[i][0].cpu(), cmap="gray")
    axes[i, 0].set_title(f"Original (label={labels[i].item()})")
    axes[i, 0].axis("off")

    axes[i, 1].imshow(view1[i][0].cpu(), cmap="gray")
    axes[i, 1].set_title("Augmented View 1")
    axes[i, 1].axis("off")

    axes[i, 2].imshow(view2[i][0].cpu(), cmap="gray")
    axes[i, 2].set_title("Augmented View 2")
    axes[i, 2].axis("off")

plt.tight_layout()
plt.show()
[Figure: two test digits (labels 7 and 2), each shown as the original image next to two augmented views]
/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/torch/utils/data/dataloader.py:665: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
  warnings.warn(warn_msg)

SimCLR training with classification probing callback

We can now create the probing callback, which trains a logistic regression classifier on the learned representation during SimCLR training. Every 3 epochs (every_n_train_epochs=3), the probe is fitted on the embeddings of the training set (train_xy_loader) and scored on those of the test set (test_xy_loader). The classification metrics (accuracy and weighted F1 score) are logged to TensorBoard by default.
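Conceptually, each probing round does three things:

1. encode both labelled sets with the current (frozen) encoder;
2. fit a fresh copy of the probe on the training embeddings;
3. score it on the test embeddings.

The sketch below mimics one round with scikit-learn's get_scorer. The names should_probe, extract, probe_once and toy_encoder are hypothetical, and this is our simplification rather than nidl's implementation. The schedule in should_probe (probe when the 1-based epoch count is a multiple of every_n_train_epochs) is also an assumption. Under it, 10 epochs with every_n_train_epochs=3 give three probing rounds, after epochs 3, 6 and 9.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer


def should_probe(epoch, every_n_train_epochs):
    """Hypothetical schedule: probe after the 3rd, 6th, 9th... epoch (epoch is 0-based)."""
    return (epoch + 1) % every_n_train_epochs == 0


def extract(encode, batches):
    """Hypothetical helper: embed every (x, y) batch and stack the results."""
    feats, targets = zip(*[(encode(x), y) for x, y in batches])
    return np.concatenate(feats), np.concatenate(targets)


def probe_once(encode, train_batches, test_batches, probe, scoring):
    """Hypothetical helper: fit a fresh probe on train embeddings, score on test ones."""
    z_train, y_train = extract(encode, train_batches)
    z_test, y_test = extract(encode, test_batches)
    fitted = clone(probe).fit(z_train, y_train)
    return {f"test_{name}": get_scorer(name)(fitted, z_test, y_test) for name in scoring}


# Toy usage: 4 batches of 64 samples whose label depends on the first feature.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4))
y = (x[:, 0] > 0).astype(int)
batches = [(x[i:i + 64], y[i:i + 64]) for i in range(0, 256, 64)]


def toy_encoder(xb):
    """Hypothetical stand-in for the frozen network."""
    return 2.0 * xb


probing_epochs = [e for e in range(10) if should_probe(e, 3)]
scores = probe_once(
    toy_encoder, batches[:3], batches[3:],
    LogisticRegression(max_iter=200), ["accuracy", "f1_weighted"],
)
print(probing_epochs, scores)
```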

callback = ModelProbing(
    train_xy_loader,
    test_xy_loader,
    probe=LogisticRegression(max_iter=200),
    scoring=["accuracy", "f1_weighted"],
    every_n_train_epochs=3,
)

Since MNIST images are small, we can use a simple LeNet-like architecture with few parameters as the SimCLR encoder. The output dimension of the encoder is set to 32. This is about 25 times smaller than the input (28 x 28 = 784 pixels) but larger than the number of digit classes (10).

class LeNetEncoder(nn.Module):
    def __init__(self, latent_size=32):
        super().__init__()
        self.latent_size = latent_size
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.pool1 = nn.AvgPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool2 = nn.AvgPool2d(2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, latent_size)

    def forward(self, x):
        x = func.relu(self.conv1(x))
        x = self.pool1(x)
        x = func.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        x = func.relu(self.fc1(x))
        x = func.relu(self.fc2(x))
        return self.fc3(x)


encoder = LeNetEncoder(latent_size)
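The 16 * 5 * 5 input size of fc1 follows from the usual output-size formula for convolution and pooling layers, floor((n + 2p - k) / s) + 1, applied to a 28x28 input. The pure-Python check below uses the hypothetical helpers out_size, conv_params and linear_params. It traces the spatial size through the network and recovers the encoder's parameter count, which matches the 63.6 K reported for the encoder in the model summary below. It also computes the input-to-latent compression ratio.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Hypothetical helper: spatial output size of a convolution or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_params(c_in, c_out, kernel):
    """Hypothetical helper: weights + biases of a square-kernel Conv2d."""
    return c_in * c_out * kernel * kernel + c_out


def linear_params(n_in, n_out):
    """Hypothetical helper: weights + biases of a Linear layer."""
    return n_in * n_out + n_out


n = 28
n = out_size(n, kernel=5, padding=2)  # conv1: 28 -> 28
n = out_size(n, kernel=2, stride=2)   # pool1: 28 -> 14
n = out_size(n, kernel=5)             # conv2: 14 -> 10
n = out_size(n, kernel=2, stride=2)   # pool2: 10 -> 5
flat = 16 * n * n                     # 16 channels x 5 x 5

latent_size = 32
n_params = (
    conv_params(1, 6, 5) + conv_params(6, 16, 5)
    + linear_params(flat, 120) + linear_params(120, 84) + linear_params(84, latent_size)
)
print(n, flat, n_params, 28 * 28 / latent_size)  # 5 400 63576 24.5
```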

We can now create the SimCLR model with the encoder and the probing callback. For the sake of time, we limit training to 10 epochs of 100 batches each (limit_train_batches=100). This is enough to follow how the embedding geometry evolves during training.
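The temperature argument below enters SimCLR's contrastive loss, called InfoNCE or NT-Xent. Take a batch of N pairs of views, giving 2N embeddings. For each embedding, the loss is the cross-entropy of picking out its paired view among the 2N - 1 other embeddings, using cosine similarities divided by the temperature. A NumPy sketch of this standard formulation follows. The info_nce function is a hypothetical helper, and nidl's InfoNCE module may differ in implementation details.

```python
import numpy as np


def info_nce(z1, z2, temperature=0.1):
    """Hypothetical NT-Xent helper for N positive pairs (z1[i], z2[i]); mean over 2N anchors."""
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # dot products become cosine similarities
    logits = z @ z.T / temperature
    np.fill_diagonal(logits, -np.inf)  # an embedding is never its own candidate
    # Index of the positive for each anchor: i <-> i + n.
    positives = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Numerically stable log-softmax over each row.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(2 * n), positives].mean()


rng = np.random.default_rng(0)
z1 = rng.normal(size=(16, 32))
aligned = info_nce(z1, z1 + 0.01 * rng.normal(size=z1.shape))
unrelated = info_nce(z1, rng.normal(size=z1.shape))
print(f"aligned views: {aligned:.3f}, unrelated views: {unrelated:.3f}")
```

A lower temperature sharpens the softmax, so hard negatives (similar embeddings from other images) weigh more in the loss.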

model = SimCLR(
    encoder=encoder,
    limit_train_batches=100,
    proj_input_dim=latent_size,
    proj_hidden_dim=64,
    proj_output_dim=32,
    max_epochs=10,
    temperature=0.1,
    learning_rate=3e-4,
    weight_decay=5e-5,
    enable_checkpointing=False,
    callbacks=callback,  # <-- key part for probing
)
model.fit(train_ssl_loader, test_ssl_loader)
/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/lightning_fabric/utilities/seed.py:44: No seed found, seed set to 0
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name            ┃ Type                 ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ encoder         │ LeNetEncoder         │ 63.6 K │ train │     0 │
│ 1 │ projection_head │ SimCLRProjectionHead │  4.2 K │ train │     0 │
│ 2 │ loss            │ InfoNCE              │      0 │ train │     0 │
└───┴─────────────────┴──────────────────────┴────────┴───────┴───────┘
Trainable params: 67.8 K
Non-trainable params: 0
Total params: 67.8 K
Total estimated model params size (MB): 0
Modules in train mode: 14
Modules in eval mode: 0
Total FLOPs: 0

Epoch 9/9  100/100  0:00:10 • 11.10it/s  v_num: 0.000
loss/train: 0.671  loss/val: 1.008
test_accuracy: 0.899  test_f1_weighted: 0.899

SimCLR(
  (encoder): LeNetEncoder(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (fc1): Linear(in_features=400, out_features=120, bias=True)
    (fc2): Linear(in_features=120, out_features=84, bias=True)
    (fc3): Linear(in_features=84, out_features=32, bias=True)
  )
  (projection_head): SimCLRProjectionHead(
    (layers): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=32, bias=True)
    )
  )
  (loss): InfoNCE(temperature=0.1)
)

Visualization of the classification metrics during training

After training, we can visualize the classification metrics logged by the linear probe with TensorBoard. By default, the logs are stored in the lightning_logs folder. They include the test accuracy (test_accuracy) and the weighted F1 score (test_f1_weighted).

def get_last_log_version(logs_dir="lightning_logs"):
    versions = []
    for d in os.listdir(logs_dir):
        match = re.match(r"version_(\d+)", d)
        if match:
            versions.append(int(match.group(1)))
    return max(versions) if versions else None


log_dir = f"lightning_logs/version_{get_last_log_version()}/"
ea = event_accumulator.EventAccumulator(log_dir)
ea.Reload()
metrics = [
    "test_accuracy",
    "test_f1_weighted",
]
scalars = {m: ea.Scalars(m) for m in metrics}

Once all the metrics are loaded, we plot them as the number of training steps increases:

plt.figure(figsize=(5, 3))
for m, events in scalars.items():
    steps = [e.step for e in events]
    values = [e.value for e in events]
    plt.plot(steps, values, label=m)
plt.xlabel(f"Training steps (batch size={batch_size})")
plt.ylabel("Metric score")
plt.title("Classification metrics during SimCLR training")
plt.legend()
plt.show()
[Figure: Classification metrics during SimCLR training]

Observations: the classification metrics increase during training, showing that the learned representation becomes more and more linearly separable with respect to the digit classes. The test accuracy reaches about 90% (0.899) after 10 epochs. This is quite good for such a simple model trained without supervision for only 1,000 steps (10 epochs of 100 batches).

Probing of y-Aware representation on age and sex prediction

We have previously seen a simple case where a single classification task is monitored during training. We can also monitor a mix of classification and regression tasks at the same time while training an embedding model. This is useful when several target variables should be decoded from the representation. We show how to do this with nidl’s ModelProbing callback on the OpenBHB dataset, monitoring age and sex decoding from brain imaging data. We refer to the example on OpenBHB for more details on this neuroimaging dataset.

We define the relevant global parameters for this example:

data_dir = "/tmp/openBHB"
batch_size = 128
num_workers = 10
latent_size = 32

OpenBHB dataset and data augmentations

We consider the gray matter and CSF volumes in regions of interest of the Neuromorphometrics atlas for each subject in OpenBHB (“vbm_roi” modality). These data are tabular (not images). They are still well suited for contrastive learning and are very light compared to the raw images (a 284-d vector per subject). We start by loading them to train and test the probe. The target variables are age (regression) and sex (classification, encoded as 1 for female and 0 otherwise).

def target_transforms(labels):
    return np.array([labels["age"], labels["sex"] == "female"])


train_xy_dataset = OpenBHB(
    data_dir,
    modality="vbm_roi",
    target=["age", "sex"],
    transforms=lambda x: x.flatten(),
    target_transforms=target_transforms,
    streaming=False,
)
test_xy_dataset = OpenBHB(
    data_dir,
    modality="vbm_roi",
    split="val",
    target=["age", "sex"],
    transforms=lambda x: x.flatten(),
    target_transforms=target_transforms,
    streaming=False,
)

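To perform contrastive learning, we use random masking and Gaussian noise as data augmentations, which are well suited to tabular data. Each feature is set to zero with probability mask_prob = 0.8, and Gaussian noise with standard deviation noise_std = 0.5 is added to about half of the views. We train a y-Aware Contrastive Learning model on these data, using age as the auxiliary variable.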

mask_prob = 0.8
noise_std = 0.5
contrast_transforms = transforms.Compose(
    [
        lambda x: x.flatten(),
        lambda x: (np.random.rand(*x.shape) > mask_prob).astype(np.float32)
        * x,  # random masking
        lambda x: x
        + (
            (np.random.rand() > 0.5) * np.random.randn(*x.shape) * noise_std
        ).astype(np.float32),  # random Gaussian noise
    ]
)

ssl_dataset = OpenBHB(
    data_dir,
    modality="vbm_roi",
    target="age",
    transforms=MultiViewsTransform(contrast_transforms, n_views=2),
)
test_ssl_dataset = OpenBHB(
    data_dir,
    modality="vbm_roi",
    target="age",
    split="val",
    transforms=MultiViewsTransform(contrast_transforms, n_views=2),
)
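In the code above, mask_prob is the probability of zeroing a feature: np.random.rand(...) > 0.8 keeps each feature with probability 0.2. The sketch below rewrites the two augmentations with an explicit, seedable numpy.random.Generator. The augment function is a hypothetical helper, not used by the notebook. The sketch estimates the fraction of features that survive masking, which should be close to 0.2.

```python
import numpy as np


def augment(x, rng, mask_prob=0.8, noise_std=0.5):
    """Hypothetical helper: random feature masking, then Gaussian noise with probability 0.5."""
    x = np.asarray(x, dtype=np.float32).ravel()
    keep = rng.random(x.shape) > mask_prob  # each feature kept with probability 1 - mask_prob
    x = x * keep
    if rng.random() > 0.5:  # add noise to about half of the views
        x = x + rng.normal(0.0, noise_std, size=x.shape).astype(np.float32)
    return x


rng = np.random.default_rng(0)
x = np.ones(284, dtype=np.float32)
# With noise_std=0, non-zero entries are exactly the features that survived masking.
kept = np.mean([np.count_nonzero(augment(x, rng, noise_std=0.0)) / x.size for _ in range(200)])
print(f"fraction of features kept: {kept:.3f}")
```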

As before, we create the data loaders for training and testing the models.


y-Aware CL training with multitask probing callback

Next, we create the multitask probing callback, which trains a ridge regression on age and a logistic regression classifier on sex. As before, probing is performed every 3 epochs (every_n_train_epochs=3), and the metrics are logged to TensorBoard by default.

To do so, we need to create a meta-estimator (compatible with scikit-learn) that wraps the two estimators (ridge and logistic regression) and handles the mixed regression/classification tasks. We provide such a meta-estimator called MultiTaskEstimator below.

class MultiTaskEstimator(sk_BaseEstimator):
    """
    A meta-estimator that wraps a list of sklearn estimators
    for multi-task problems (mixed regression/classification).
    """

    def __init__(self, estimators):
        self.estimators = estimators

    def fit(self, X, y):
        """Fit each estimator on its corresponding column in y."""
        y = np.asarray(y)
        if y.ndim == 1:
            y = y.reshape(-1, 1)
        self.estimators_ = []
        for i, est in enumerate(self.estimators):
            fitted = clone(est).fit(X, y[:, i])
            self.estimators_.append(fitted)
        return self

    def predict(self, X):
        """Predict for each task."""
        preds = [est.predict(X).reshape(-1, 1) for est in self.estimators_]
        return np.hstack(preds)

    def score(self, X, y):
        """Average of each estimator's default score (R^2 for Ridge, accuracy for LogisticRegression)."""
        y = np.asarray(y)
        if y.ndim == 1:
            y = y.reshape(-1, 1)
        scores = [est.score(X, y[:, i]) for i, est in enumerate(self.estimators_)]
        return np.mean(scores)

    def __len__(self):
        return len(self.estimators)

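Then, we define a task-specific scorer that applies a metric to column task_index of the targets and predictions (or to the full arrays if task_index is None):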

def make_task_scorer(metric_fn, task_index, **kwargs):
    """Returns a scorer evaluating on y or y[:, task_index]."""

    def scorer(y_true, y_pred):
        if task_index is None:
            return metric_fn(y_true, y_pred)
        else:
            return metric_fn(y_true[:, task_index], y_pred[:, task_index])

    return make_scorer(scorer, **kwargs)
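Scikit-learn scorers built with make_scorer are called as scorer(estimator, X, y). They call estimator.predict(X) and pass the result to the metric. With MultiTaskEstimator, predictions have one column per task, so each scorer picks its own column. The standalone sketch below illustrates this with a hypothetical FixedTwoTaskPredictor that returns precomputed predictions. It also uses column_scorer, a hypothetical compact copy of the column-selecting scorer. The toy ages and sex labels are made up.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.metrics import accuracy_score, make_scorer, r2_score


class FixedTwoTaskPredictor(BaseEstimator):
    """Hypothetical estimator whose predict() returns a fixed (n_samples, 2) array."""

    def __init__(self, y_pred=None):
        self.y_pred = y_pred

    def fit(self, X, y):
        return self

    def predict(self, X):
        return np.asarray(self.y_pred)


def column_scorer(metric_fn, task_index):
    """Hypothetical helper: scorer applying metric_fn to column task_index only."""
    return make_scorer(
        lambda y_true, y_pred: metric_fn(y_true[:, task_index], y_pred[:, task_index])
    )


# Column 0: age (regression), column 1: sex encoded as 0/1 (classification).
y_true = np.array([[20.0, 0], [35.0, 1], [50.0, 1], [65.0, 0]])
y_pred = np.array([[22.0, 0], [33.0, 1], [52.0, 0], [63.0, 0]])
est = FixedTwoTaskPredictor(y_pred)
X = np.zeros((4, 3))  # ignored by the dummy estimator

age_r2 = column_scorer(r2_score, 0)(est, X, y_true)
sex_acc = column_scorer(accuracy_score, 1)(est, X, y_true)
print(age_r2, sex_acc)
```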

Finally, we create the multitask probing callback with the relevant estimators and scorers for age and sex.

callback = ModelProbing(
    train_xy_loader,
    test_xy_loader,
    probe=MultiTaskEstimator([Ridge(), LogisticRegression(max_iter=200)]),
    scoring={
        "age/r2": make_task_scorer(r2_score, task_index=0),
        "age/pearsonr": make_task_scorer(pearson_r, task_index=0),
        "sex/accuracy": make_task_scorer(accuracy_score, task_index=1),
        "sex/f1": make_task_scorer(f1_score, task_index=1),
    },
    every_n_train_epochs=3,
)

Since we work with tabular data, we can use a simple MLP as encoder for y-Aware Contrastive Learning. The input dimension is 284 and we compress the data to a 32-d latent space.

encoder = MLP(in_channels=284, hidden_channels=[64, latent_size])

We can now create the y-Aware Contrastive Learning model with the MLP encoder and the multitask probing callback. We limit training to 10 epochs for the sake of time. We also use a Gaussian kernel with a small bandwidth (sigma = 4 years, passed as bandwidth=sigma**2) compared to the spread of ages in OpenBHB.
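In y-Aware Contrastive Learning, subjects with similar auxiliary variables (here age) act as soft positives. Their weights are given by a Gaussian kernel on the age difference. The sketch below computes these weights assuming the kernel exp(-(y_i - y_j)^2 / (2 * bandwidth)), where bandwidth = sigma**2 is the kernel variance, as the call below suggests. This parameterization is an assumption; nidl’s YAwareContrastiveLearning documentation gives the exact one. The gaussian_kernel_weights helper and the example ages are hypothetical.

```python
import numpy as np


def gaussian_kernel_weights(ages, bandwidth):
    """Hypothetical helper: weights exp(-(a_i - a_j)^2 / (2 * bandwidth)), bandwidth = sigma**2."""
    ages = np.asarray(ages, dtype=float)
    diff = ages[:, None] - ages[None, :]
    return np.exp(-(diff ** 2) / (2.0 * bandwidth))


sigma = 4
w = gaussian_kernel_weights([30, 34, 42, 70], bandwidth=sigma**2)
print(np.round(w[0], 3))  # weights of the 30-year-old with respect to each subject
```

Under this assumption, a 4-year age gap gives a weight of exp(-0.5) ≈ 0.61, while a 12-year gap already drops to about 0.01. With sigma = 4, only subjects of similar age are therefore treated as soft positives.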

sigma = 4
model = YAwareContrastiveLearning(
    encoder=encoder,
    proj_input_dim=latent_size,
    proj_hidden_dim=2 * latent_size,
    proj_output_dim=latent_size,
    bandwidth=sigma**2,
    max_epochs=10,
    temperature=0.1,
    learning_rate=1e-3,
    enable_checkpointing=False,
    callbacks=callback,  # <-- add callback to monitor the training
)

model.fit(train_ssl_loader, test_ssl_loader)
/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/torch/utils/data/dataloader.py:665: UserWarning: 'pin_memory' argument is set as true but no accelerator is found, then device pinned memory won't be used.
  warnings.warn(warn_msg)
/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:317: The number of training batches (26) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
┏━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┓
┃   ┃ Name            ┃ Type                 ┃ Params ┃ Mode  ┃ FLOPs ┃
┡━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━┩
│ 0 │ encoder         │ MLP                  │ 20.3 K │ train │     0 │
│ 1 │ projection_head │ YAwareProjectionHead │  4.2 K │ train │     0 │
│ 2 │ loss            │ YAwareInfoNCE        │      0 │ train │     0 │
└───┴─────────────────┴──────────────────────┴────────┴───────┴───────┘
Trainable params: 24.5 K
Non-trainable params: 0
Total params: 24.5 K
Total estimated model params size (MB): 0
Modules in train mode: 13
Modules in eval mode: 0
Total FLOPs: 0
/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/torch/utils/data/dataloader.py:626: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(

Extracting features: 21it [00:10,  2.85it/s]


Extracting features: 2it [00:02,  1.13s/it]

Extracting features: 22it [00:10,  2.73it/s]


Extracting features: 1it [00:01,  1.98s/it]

Extracting features: 21it [00:10,  2.72it/s]


Extracting features: 1it [00:02,  2.59s/it]

Extracting features: 23it [00:10,  3.69it/s]


Extracting features: 2it [00:02,  1.07s/it]

/opt/hostedtoolcache/Python/3.12.12/x64/lib/python3.12/site-packages/sklearn/linear_model/_logistic.py:470: ConvergenceWarning: lbfgs failed to converge after 200 iteration(s) (status=1):
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT

Increase the number of iterations to improve the convergence (max_iter=200).
You might also want to scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
Epoch 9/9  ━━━━━━━━━━━━━━━━━ 26/26 0:00:10 • 0:00:00 3.71it/s v_num: 1.000
    loss/train: 7.943          loss/val: 11.054
    test_age/r2: 0.639         test_age/pearsonr: 0.807
    test_sex/accuracy: 0.760   test_sex/f1: 0.747

YAwareContrastiveLearning(
  (encoder): MLP(
    (0): Linear(in_features=284, out_features=64, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=64, out_features=32, bias=True)
    (4): Dropout(p=0.0, inplace=False)
  )
  (projection_head): YAwareProjectionHead(
    (layers): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=32, bias=True)
    )
  )
  (loss): YAwareInfoNCE(
    (sim_metric): PairwiseCosineSimilarity()
  )
)

Visualization of the classification and regression metrics during training

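After training, we can inspect the classification and regression metrics logged by the probing callback. By default, they are written as TensorBoard event files in the lightning_logs folder, so they can be browsed in TensorBoard (e.g. tensorboard --logdir lightning_logs) or, as done here, read back in Python and plotted with matplotlib.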

log_dir = f"lightning_logs/version_{get_last_log_version()}/"

# Load the TensorBoard event file(s) of the last run
ea = event_accumulator.EventAccumulator(log_dir)
ea.Reload()
metrics = [
    "test_age/r2",
    "test_age/pearsonr",
    "test_sex/accuracy",
    "test_sex/f1",
]
# Fetch the logged scalar events (wall_time, step, value) for each metric
scalars = {m: ea.Scalars(m) for m in metrics}
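
Each entry of scalars is a list of TensorBoard scalar events carrying wall_time, step and value fields. To summarize the curves numerically, a small helper can report the last and best value per metric. The sketch below uses a namedtuple stand-in for TensorBoard's event records and toy values (not results from this run) so that it runs without a log directory; with real data, the scalars dict built above can be passed instead. The helper name summarize is hypothetical.

```python
from collections import namedtuple

# Stand-in with the same fields as TensorBoard's ScalarEvent records
ScalarEvent = namedtuple("ScalarEvent", ["wall_time", "step", "value"])


def summarize(scalars):
    """Return {metric: (last_step, last_value, best_value)}.

    Assumes every metric is "higher is better", which holds for R2,
    Pearson r, accuracy and F1.
    """
    summary = {}
    for name, events in scalars.items():
        events = sorted(events, key=lambda e: e.step)
        last = events[-1]
        summary[name] = (last.step, last.value, max(e.value for e in events))
    return summary


# Toy events for one metric (illustrative values only)
toy = {
    "test_age/r2": [
        ScalarEvent(0.0, 10, 0.1),
        ScalarEvent(0.0, 20, 0.3),
        ScalarEvent(0.0, 30, 0.2),
    ]
}
print(summarize(toy))  # {'test_age/r2': (30, 0.2, 0.3)}
```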

Once all the metrics are loaded, we plot them against the training step, with one subplot per task (age regression and sex classification).

def plot_task(ax, task_metrics, title):
    for m in task_metrics:
        steps = [s.step for s in scalars[m]]
        values = [s.value for s in scalars[m]]
        ax.plot(steps, values, label=m.split("/")[1])
    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel("Metric Value")
    ax.legend()
    ax.grid(True)


fig, axes = plt.subplots(1, 2, figsize=(10, 5))
plot_task(axes[0], ["test_age/r2", "test_age/pearsonr"], "Age Regression")
plot_task(axes[1], ["test_sex/accuracy", "test_sex/f1"], "Sex Classification")
plt.tight_layout()
plt.show()
[Figure: probing metrics versus training step. Left panel: Age Regression; right panel: Sex Classification]

Conclusions

In this notebook, we have shown how to use the model probing callbacks available in nidl to monitor how the data representation evolves while training embedding models such as SimCLR and y-Aware Contrastive Learning. We used the ModelProbing callback for both single-task and multi-task probing. At regular intervals during training, these callbacks fit standard machine learning models (e.g. logistic regression, ridge regression, SVM) on the learned representation and log the relevant metrics to TensorBoard. This gives insight into which concepts the model learns and how its representation becomes more suitable for downstream tasks.

Total running time of the script: (7 minutes 27.784 seconds)

Estimated memory usage: 397 MB

Gallery generated by Sphinx-Gallery