Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.estimators.ssl.utils.ProjectionHead

class nidl.estimators.ssl.utils.ProjectionHead(blocks)

Bases: torch.nn.Module

Base class for all projection and prediction heads in self-supervised estimators.

Parameters:
blocks : list of tuple (int, int, Optional[nn.Module], Optional[nn.Module])

List of tuples, each describing one block of the projection head MLP. Each tuple has the form (in_features, out_features, batch_norm_layer, non_linearity_layer). Each block applies, in order:

  1. a linear layer mapping in_features to out_features (with a bias only if batch_norm_layer is None)

  2. the batch normalization layer batch_norm_layer (optional)

  3. the non-linearity non_linearity_layer (optional)
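To make the bias rule concrete, the following pure-Python sketch tallies the learnable parameters a block list would create. It assumes the layers behave like torch.nn.Linear and an affine torch.nn.BatchNorm1d (whose running statistics are buffers, not parameters), and that the non-linearity has no learnable parameters (true for nn.ReLU, not for nn.PReLU). The helper count_block_params is hypothetical and not part of nidl; booleans stand in for the optional modules.

```python
# Hypothetical helper (not part of nidl): counts the learnable parameters a
# block list would create, assuming nn.Linear and affine nn.BatchNorm1d
# semantics and a parameter-free non-linearity.
def count_block_params(blocks):
    """blocks: list of (in_features, out_features, has_batch_norm, has_non_linearity)."""
    total = 0
    for in_features, out_features, has_bn, _has_act in blocks:
        total += in_features * out_features  # linear weight matrix
        if has_bn:
            total += 2 * out_features  # batch-norm scale (gamma) and shift (beta)
        else:
            total += out_features  # linear bias, only when no batch norm follows
    return total


example_blocks = [
    (256, 256, True, True),    # Linear(256, 256, bias=False) + BatchNorm1d + ReLU
    (256, 128, False, False),  # Linear(256, 128, bias=True)
]
print(count_block_params(example_blocks))  # 98944
```

Dropping the linear bias before batch normalization loses nothing: the batch-norm shift (beta) already provides a learnable per-feature offset, and the mean subtraction would cancel a bias anyway.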

Attributes:
layers : nn.Sequential

Sequential container holding the modules of all blocks, applied in order.
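The relationship between blocks and layers can be illustrated with a minimal PyTorch sketch. It assumes each block contributes its linear layer, then its batch normalization layer, then its non-linearity, all flattened into one nn.Sequential, and that forward simply applies that container. The class name SketchProjectionHead is hypothetical; this is an illustration under those assumptions, not nidl's actual source.

```python
from typing import List, Optional, Tuple

import torch
from torch import nn


class SketchProjectionHead(nn.Module):
    """Hypothetical re-implementation for illustration only; not nidl's source."""

    def __init__(
        self,
        blocks: List[Tuple[int, int, Optional[nn.Module], Optional[nn.Module]]],
    ):
        super().__init__()
        layers = []
        for in_features, out_features, batch_norm, non_linearity in blocks:
            # The linear layer carries a bias only when no batch norm follows it.
            layers.append(nn.Linear(in_features, out_features, bias=batch_norm is None))
            if batch_norm is not None:
                layers.append(batch_norm)
            if non_linearity is not None:
                layers.append(non_linearity)
        # All blocks are flattened into a single sequential container.
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One forward pass applies every layer of every block in order.
        return self.layers(x)


head = SketchProjectionHead([
    (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
    (256, 128, None, None),
])
out = head(torch.randn(8, 256))  # a batch of 8 feature vectors
print(out.shape)  # torch.Size([8, 128])
```

Flattening every block into one container keeps the module tree shallow, so individual layers can be indexed directly (for example head.layers[0]).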

Examples

>>> import torch.nn as nn
>>> from nidl.estimators.ssl.utils import ProjectionHead
>>> # the following projection head has two blocks:
>>> # the first block uses batch norm and a ReLU non-linearity,
>>> # the second block is a plain linear layer
>>> projection_head = ProjectionHead([
...     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
...     (256, 128, None, None)
... ])
__init__(blocks)

Builds the projection head, i.e. the layers attribute, from blocks.

forward(x)

Computes one forward pass through the projection head.