Note
This page is reference documentation. It only describes the class signature, not how to use it. Please refer to the user guide for the big picture.
nidl.estimators.ssl.utils.ProjectionHead
- class nidl.estimators.ssl.utils.ProjectionHead(blocks)
Bases: Module
Base class for all projection and prediction heads in self-supervised estimators.
- Parameters:
- blocks : list of tuple (int, int, Optional[nn.Module], Optional[nn.Module])
List of tuples, each describing one block of the projection head MLP. Each tuple reads (in_features, out_features, batch_norm_layer, non_linearity_layer). Each block applies:
- a linear layer with in_features and out_features (with bias if batch_norm_layer is None)
- a batch normalization layer as defined by batch_norm_layer (optional)
- a non-linearity as defined by non_linearity_layer (optional)
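The block rules above can be illustrated with a minimal sketch. It is not the nidl implementation: `build_layers` is a hypothetical helper, and it assumes the three layers of a block are applied in the order listed (linear, then batch norm, then non-linearity).

```python
import torch
import torch.nn as nn


def build_layers(blocks):
    """Hypothetical helper: assemble an nn.Sequential from block tuples."""
    layers = []
    for in_features, out_features, batch_norm, non_linearity in blocks:
        # The linear layer gets a bias only when no batch norm follows it,
        # as stated in the parameter description.
        layers.append(
            nn.Linear(in_features, out_features, bias=batch_norm is None)
        )
        if batch_norm is not None:
            layers.append(batch_norm)
        if non_linearity is not None:
            layers.append(non_linearity)
    return nn.Sequential(*layers)


layers = build_layers([
    (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
    (256, 128, None, None),
])
out = layers(torch.randn(8, 256))  # batch of 8 vectors with 256 features
```

Under these assumptions the two blocks expand to four modules (linear, batch norm, ReLU, linear), and only the last linear layer carries a bias term.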
- Attributes:
- layers : nn.Sequential
List of Module to apply.
Examples
>>> import torch.nn as nn
>>> from nidl.estimators.ssl.utils import ProjectionHead
>>> # the following projection head has two blocks
>>> # the first block uses batch norm and a ReLU non-linearity
>>> # the second block is a simple linear layer
>>> projection_head = ProjectionHead([
...     (256, 256, nn.BatchNorm1d(256), nn.ReLU()),
...     (256, 128, None, None)
... ])
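Because the blocks form a sequential MLP, the out_features of one block has to equal the in_features of the next (256 and 256 in the example above). The sketch below checks this before a head is built. `check_block_chain` is a hypothetical helper written for illustration and is not part of nidl.

```python
def check_block_chain(blocks):
    """Hypothetical helper: raise ValueError if consecutive blocks do not chain."""
    for index in range(len(blocks) - 1):
        out_features = blocks[index][1]
        next_in_features = blocks[index + 1][0]
        if out_features != next_in_features:
            raise ValueError(
                f"block {index} outputs {out_features} features, "
                f"but block {index + 1} expects {next_in_features}"
            )
    return True


# Only the two sizes matter for this check; the layer slots are left as None.
ok = check_block_chain([(256, 256, None, None), (256, 128, None, None)])
```

Running such a check first gives a readable error message instead of a shape mismatch at the first forward pass.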