nidl.estimators: Available estimators

This module details the public API to use and implement for a nidl-compatible estimator, as well as the estimators available in nidl.

Introduction

An estimator is an object that fits a model based on some training data and is capable of inferring some properties of new data. It can be a classifier, a clustering algorithm, a regressor or a transformer (in the scikit-learn sense). All estimators implement a fit method. Under the hood, every estimator inherits from pytorch_lightning.LightningModule, and thus benefits from all the features of the pytorch_lightning library:

  • Distributed multi-GPU training

  • Logging and visualization

  • Clear organization of the training and evaluation code

  • Automatic checkpointing and early stopping

  • Callback logic

Instantiation

An estimator's __init__ method may accept constants as arguments that determine its behavior (such as hyperparameters and training settings). It should not, however, take the actual training data as an argument: this is left to the fit method.

Fitting

The next thing you will probably want to do is to estimate some parameters in the model. This is implemented in the fit method, and it’s where the training happens.

The fit method takes the following training data as arguments:

Parameters

train_dataloader : torch DataLoader [(n_samples, *)]

val_dataloader : torch DataLoader [(n_samples, *)]

Because estimators are built as LightningModules, the fit logic is organized into the training_step and validation_step methods.
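To make this contract concrete, here is a minimal sketch in plain PyTorch. It assumes a hypothetical ToyEstimator that is a bare torch.nn.Module rather than nidl's BaseEstimator, and it writes out by hand the epoch loop that, in nidl, the PyTorch Lightning Trainer would drive: __init__ only stores constants, and fit consumes the two dataloaders by calling training_step and validation_step on each batch.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyEstimator(nn.Module):
    """Hypothetical estimator (not nidl's BaseEstimator) fitting y = f(x)."""

    def __init__(self, lr=0.1, max_epochs=5):
        super().__init__()
        # Constants only: hyperparameters and training settings, no data.
        self.lr = lr
        self.max_epochs = max_epochs
        self.model = nn.Linear(3, 1)
        self.val_losses = []  # validation losses of the last epoch

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.model(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.model(x), y)

    def fit(self, train_dataloader, val_dataloader=None):
        # Hand-written stand-in for the loop a Lightning Trainer would run.
        optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        for _ in range(self.max_epochs):
            self.train()
            for batch_idx, batch in enumerate(train_dataloader):
                optimizer.zero_grad()
                loss = self.training_step(batch, batch_idx)
                loss.backward()
                optimizer.step()
            if val_dataloader is not None:
                self.eval()
                with torch.no_grad():
                    self.val_losses = [
                        self.validation_step(batch, i).item()
                        for i, batch in enumerate(val_dataloader)
                    ]
        return self


torch.manual_seed(0)
X = torch.randn(64, 3)
y = X.sum(dim=1, keepdim=True)  # target: sum of the three features
train_loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(X, y), batch_size=16)
estimator = ToyEstimator()
initial_loss = nn.functional.mse_loss(estimator.model(X), y).item()
estimator.fit(train_loader, val_loader)
```

Keeping the data out of __init__ means one configured estimator can be refitted on different dataloaders without being rebuilt.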

Estimator types

The proposed types of estimators are transformers, classifiers, regressors, and clustering algorithms.

Transformers inherit from TransformerMixin and implement a transform method. These estimators take the input and transform it in some way. Note that they should never change the number of input samples, and the output of transform should correspond to its input samples, in the same order.
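The sample-count and ordering guarantee can be illustrated with a small sketch. ToyTransformerMixin and ToyEmbedder are hypothetical stand-ins, not nidl's classes: transform runs a transform_step on each batch of a non-shuffled dataloader and concatenates the results, so row i of the output always corresponds to input sample i.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyTransformerMixin:
    """Hypothetical stand-in for a transformer mixin (not nidl's class)."""

    def transform(self, dataloader):
        # One output row per input sample, concatenated in loader order.
        # The loader must not shuffle, or rows would no longer line up.
        self.eval()
        with torch.no_grad():
            outputs = [
                self.transform_step(batch, batch_idx)
                for batch_idx, batch in enumerate(dataloader)
            ]
        return torch.cat(outputs, dim=0)


class ToyEmbedder(ToyTransformerMixin, nn.Module):
    def __init__(self, in_dim=8, out_dim=2):
        super().__init__()
        self.encoder = nn.Linear(in_dim, out_dim)

    def transform_step(self, batch, batch_idx):
        (x,) = batch  # a single-tensor TensorDataset yields 1-element batches
        return self.encoder(x)


torch.manual_seed(0)
X = torch.randn(10, 8)
model = ToyEmbedder()
Z = model.transform(DataLoader(TensorDataset(X), batch_size=4, shuffle=False))
print(Z.shape)  # torch.Size([10, 2])
```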

Regressors inherit from RegressorMixin and implement a predict method returning the values assigned to new samples. In this case, each batch of training data must contain two tensors (inputs and target values).

Classifiers inherit from ClassifierMixin and implement a predict method returning the labels assigned to new samples. In this case, each batch of training data must contain two tensors (inputs and labels).

Clustering estimators inherit from ClusterMixin and implement a predict method returning the labels assigned to new samples. In this case, each batch of training data must contain two tensors.

Because estimators are built as LightningModules, the transform and predict methods are organized into the transform_step and predict_step methods, respectively.
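A sketch of the supervised contract, assuming a hypothetical NearestCentroidClassifier that is not part of nidl and skips gradient training entirely: fit reads (inputs, labels) pairs from each batch and averages the samples of each class, predict_step labels one batch by its closest centroid, and predict concatenates the per-batch results. A regressor would follow the same two-tensor pattern with continuous targets.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class NearestCentroidClassifier:
    """Hypothetical toy classifier illustrating fit / predict_step / predict."""

    def fit(self, train_dataloader, val_dataloader=None):
        sums, counts = {}, {}
        for x, y in train_dataloader:  # each batch provides two tensors
            for label in y.unique().tolist():
                mask = y == label
                sums[label] = sums.get(label, 0) + x[mask].sum(dim=0)
                counts[label] = counts.get(label, 0) + int(mask.sum())
        self.labels_ = torch.tensor(sorted(sums))
        self.centroids_ = torch.stack([sums[k] / counts[k] for k in sorted(sums)])
        return self

    def predict_step(self, batch, batch_idx):
        (x,) = batch
        dists = torch.cdist(x, self.centroids_)  # (batch_size, n_classes)
        return self.labels_[dists.argmin(dim=1)]

    def predict(self, dataloader):
        return torch.cat(
            [self.predict_step(batch, i) for i, batch in enumerate(dataloader)]
        )


torch.manual_seed(0)
X = torch.cat([torch.randn(20, 2) - 3, torch.randn(20, 2) + 3])
y = torch.cat([torch.zeros(20, dtype=torch.long), torch.ones(20, dtype=torch.long)])
clf = NearestCentroidClassifier().fit(
    DataLoader(TensorDataset(X, y), batch_size=8, shuffle=True)
)
new_samples = torch.tensor([[-3.0, -3.0], [3.0, 3.0]])
print(clf.predict(DataLoader(TensorDataset(new_samples), batch_size=2)))
# tensor([0, 1])
```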

Base Classes

Base classes for all nidl estimators.

BaseEstimator([callbacks, ...])

Base class for all estimators in the NIDL framework designed for scalability.

ClassifierMixin()

Mixin class for all classifiers in nidl.

ClusterMixin()

Mixin class for all cluster estimators in nidl.

RegressorMixin()

Mixin class for all regression estimators in nidl.

TransformerMixin()

Mixin class for all transformers in nidl.

Inheritance diagram: BaseEstimator inherits from LightningModule.

Self-Supervised Learning

Self-supervised learning embedding estimators, losses and associated tools.

Embedding estimators

SimCLR(encoder[, encoder_kwargs, ...])

SimCLR [Rfa9275c25f8b-1].

DCL(encoder[, encoder_kwargs, ...])

Decoupled Contrastive Learning [R23fac62c525a-1].

YAwareContrastiveLearning(encoder[, ...])

y-Aware Contrastive Learning [R8ad1d13554f2-1].

BarlowTwins(encoder[, encoder_kwargs, ...])

Barlow Twins [Rf18316f97edd-1].

DINO(encoder[, encoder_kwargs, ...])

DINO [Rf531b9701544-1].

IJEPA(encoder[, dim, context_block_scale, ...])

Implementation of I-JEPA [Re5620117445e-1].

Inheritance diagram: BarlowTwins, DCL, DINO, IJEPA, SimCLR and YAwareContrastiveLearning each inherit from both BaseEstimator and TransformerMixin.
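The general shape of a SimCLR-style embedding estimator can be sketched as follows. ToySimCLR is a hypothetical plain nn.Module, not nidl's SimCLR: training_step encodes two augmented views, projects them, and scores them with a standard InfoNCE (NT-Xent) loss, while transform_step returns encoder features only. The temperature of 0.1 and all layer sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn


def info_nce(z1, z2, temperature=0.1):
    """InfoNCE / NT-Xent over N positive pairs (z1[i], z2[i]), each (N, D)."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # unit-norm rows, (2N, D)
    sim = z @ z.T / temperature  # scaled cosine similarities, (2N, 2N)
    # A sample is never compared with itself: mask the diagonal out.
    sim = sim.masked_fill(torch.eye(2 * n, dtype=torch.bool), float("-inf"))
    # Row i's positive is the other view of the same sample (i + N or i - N).
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(n)])
    return F.cross_entropy(sim, targets)


class ToySimCLR(nn.Module):
    """Hypothetical SimCLR-style embedding estimator (not nidl's SimCLR)."""

    def __init__(self, in_dim=16, emb_dim=8, proj_dim=4):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, emb_dim), nn.ReLU())
        self.projection_head = nn.Sequential(
            nn.Linear(emb_dim, emb_dim), nn.ReLU(), nn.Linear(emb_dim, proj_dim)
        )

    def training_step(self, batch, batch_idx):
        x1, x2 = batch  # two augmented views of the same samples
        z1 = self.projection_head(self.encoder(x1))
        z2 = self.projection_head(self.encoder(x2))
        return info_nce(z1, z2)

    def transform_step(self, batch, batch_idx):
        # batch: a tensor of inputs; the projection head is not used here.
        return self.encoder(batch)


torch.manual_seed(0)
model = ToySimCLR()
x = torch.randn(32, 16)
# Additive noise stands in for real data augmentations in this sketch.
views = (x + 0.1 * torch.randn_like(x), x + 0.1 * torch.randn_like(x))
loss = model.training_step(views, batch_idx=0)
loss.backward()
print(model.transform_step(x, batch_idx=0).shape)  # torch.Size([32, 8])
```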

Losses

InfoNCE([temperature])

Implementation of the InfoNCE loss [Re38fc64e0ed4-1], [Re38fc64e0ed4-2].

DCLLoss([temperature, pos_weight_fn])

Implementation of the Decoupled Contrastive Learning loss [R0cf4714be807-1].

DCLWLoss([sigma, temperature])

Decoupled Contrastive Loss (DCL) with von Mises-Fisher (vMF) weighting.

YAwareInfoNCE([kernel, bandwidth, temperature])

Implementation of the y-Aware InfoNCE loss [Ra2feb9ab43ec-1].

BarlowTwinsLoss([lambd])

Implementation of the Barlow Twins loss [Re83c9b545e4a-1].

DINOLoss([output_dim, warmup_teacher_temp, ...])

Implementation of the DINO loss [Re2e7efabd714-1].

Projection heads

ProjectionHead(blocks)

Base class for all projection and prediction heads in self-supervised estimators.

SimCLRProjectionHead([input_dim, ...])

Projection head used for SimCLR.

YAwareProjectionHead([input_dim, ...])

Projection head used for y-Aware contrastive learning.

BarlowTwinsProjectionHead([input_dim, ...])

Projection head used for Barlow Twins [R86bff2d09119-1].

DINOProjectionHead([input_dim, hidden_dim, ...])

Projection head used in DINO [Re0843dd52443-1].
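The projection head used in the original SimCLR paper is a small MLP with one hidden layer (Linear, ReLU, Linear). It can be sketched as a plain nn.Sequential. make_projection_head is a hypothetical helper, the dimensions are illustrative rather than nidl's defaults, and the way nidl's ProjectionHead assembles its blocks is not reproduced here.

```python
import torch
from torch import nn


def make_projection_head(input_dim, hidden_dim, output_dim):
    """Hypothetical SimCLR-style projection head: Linear -> ReLU -> Linear."""
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, output_dim),
    )


head = make_projection_head(input_dim=512, hidden_dim=512, output_dim=128)
features = torch.randn(4, 512)  # stands in for encoder outputs of a batch of 4
print(head(features).shape)  # torch.Size([4, 128])
```

In SimCLR the head only serves to compute the training loss; the representation used downstream is taken before it.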

Probing estimators

Probing estimators, which are used to evaluate the quality of the representations learned by an embedding estimator.

ModelProbing(embedding_estimator, probe[, ...])

Estimator to probe the representation of an embedding estimator.
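Linear probing, the usual way to evaluate learned representations, can be sketched in plain PyTorch: the encoder is frozen, its features are computed once without gradients, and only a linear classifier is trained on top. fit_linear_probe, probe_accuracy and the synthetic two-cluster data are hypothetical, and this does not reproduce ModelProbing's interface; a random nn.Linear stands in for a pretrained encoder.

```python
import torch
import torch.nn.functional as F
from torch import nn


def fit_linear_probe(encoder, x, y, n_classes, epochs=100, lr=0.05):
    """Train a linear classifier on frozen encoder features (hypothetical helper)."""
    encoder.eval()
    with torch.no_grad():  # frozen encoder: no gradient ever reaches it
        features = encoder(x)
    probe = nn.Linear(features.shape[1], n_classes)
    optimizer = torch.optim.SGD(probe.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        F.cross_entropy(probe(features), y).backward()
        optimizer.step()
    return probe


def probe_accuracy(encoder, probe, x, y):
    with torch.no_grad():
        return (probe(encoder(x)).argmax(dim=1) == y).float().mean().item()


def make_blobs(n):
    labels = torch.randint(0, 2, (n,))
    shift = 3.0 * (2 * labels.float().unsqueeze(1) - 1)  # class means at -3 and +3
    return torch.randn(n, 10) + shift, labels


torch.manual_seed(0)
x_train, y_train = make_blobs(200)
x_test, y_test = make_blobs(100)
encoder = nn.Linear(10, 4)  # stands in for a pretrained embedding network
probe = fit_linear_probe(encoder, x_train, y_train, n_classes=2)
print(probe_accuracy(encoder, probe, x_test, y_test))
```

Restricting the probe to a linear model keeps the score a measure of how linearly separable the frozen embedding is, rather than of the probe's own capacity.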

Autoencoders

Autoencoder estimators and losses.

Embedding estimators

VAE(encoder, decoder, encoder_out_dim, ...)

Variational Auto-Encoder (VAE) [Rfc808ee91d8e-1] [Rfc808ee91d8e-2].

Inheritance diagram: VAE inherits from both BaseEstimator and TransformerMixin.

Losses

BetaVAELoss([beta, default_dist])

Compute the Beta-VAE loss [Rd208eb1f92d3-1].
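The two ingredients of a VAE objective can be sketched as follows, assuming a diagonal Gaussian posterior and a Gaussian decoder (so the reconstruction term is a summed squared error). nidl's default_dist argument suggests other reconstruction distributions are supported, which this sketch does not cover, and beta=4.0 is an arbitrary illustrative value. With beta=1 the expression is the usual negative ELBO.

```python
import torch
import torch.nn.functional as F


def reparameterize(mu, logvar):
    # z = mu + sigma * eps with eps ~ N(0, I): gradients flow to mu and logvar.
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(std)


def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0):
    """Reconstruction + beta * KL(N(mu, exp(logvar)) || N(0, I)), per sample."""
    recon = F.mse_loss(x_recon, x, reduction="sum") / x.shape[0]
    # Closed-form Gaussian KL, summed over latent dimensions.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.shape[0]
    return recon + beta * kl


torch.manual_seed(0)
x = torch.rand(16, 784)  # e.g. a batch of flattened 28x28 images (illustrative)
mu, logvar = torch.zeros(16, 10), torch.zeros(16, 10)  # encoder outputs go here
z = reparameterize(mu, logvar)
x_recon = torch.sigmoid(torch.randn(16, 784))  # decoder output goes here
print(beta_vae_loss(x, x_recon, mu, logvar))
```

Raising beta above 1 weights the KL term more heavily, pulling the posterior towards the N(0, I) prior at the cost of reconstruction quality.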

Dummy estimator

A dummy estimator that does not perform any training. This can be useful for testing purposes.

DummyEmbeddingEstimator([strategy, ...])

A dummy embedding estimator returning an embedding independent of the input data.
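A dummy baseline of this kind might look like the sketch below. DummyEmbedder and its "constant" and "random" strategy names are hypothetical, not the values accepted by nidl's strategy argument. fit does nothing, and transform only uses the batch size, so the embedding carries no information about the input; a probe trained on it gives a chance-level reference for real embeddings.

```python
import torch


class DummyEmbedder:
    """Hypothetical dummy embedding estimator: no training, input-independent output."""

    def __init__(self, embedding_dim=8, strategy="constant", seed=0):
        self.embedding_dim = embedding_dim
        self.strategy = strategy  # hypothetical values: "constant" or "random"
        self.generator = torch.Generator().manual_seed(seed)

    def fit(self, train_dataloader=None, val_dataloader=None):
        return self  # nothing to learn

    def transform(self, x):
        n = x.shape[0]  # only the number of samples is read from the input
        if self.strategy == "constant":
            return torch.zeros(n, self.embedding_dim)
        if self.strategy == "random":
            return torch.randn(n, self.embedding_dim, generator=self.generator)
        raise ValueError(f"unknown strategy: {self.strategy!r}")


dummy = DummyEmbedder(strategy="random").fit()
print(dummy.transform(torch.zeros(5, 3)).shape)  # torch.Size([5, 8])
```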