Note

This page is reference documentation: it only describes the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.losses.YAwareInfoNCE

class nidl.losses.YAwareInfoNCE(kernel='gaussian', bandwidth=1.0, temperature=0.1)[source]

Bases: Module

Implementation of the y-Aware InfoNCE loss [1].

Compute the y-Aware InfoNCE loss, which integrates auxiliary information into contrastive learning by weighting sample pairs.

Given a mini-batch of size n, two embeddings z_1=(z_1^i)_{i\in [1..n]} and z_2=(z_2^i)_{i\in [1..n]} representing two views of the same samples, and a weighting matrix W=(w_{i,j})_{i,j\in [1..n]} computed from auxiliary variables y, the loss is:

\mathcal{L}_{NCE}^y = -\frac{1}{n} \sum_{i,j} \frac{w_{i,j}}{\sum_{k=1}^{n} w_{i,k}} \log \frac{\exp(\text{sim}(z_1^{i}, z_2^{j}) / \tau)}{\sum_{k=1}^{n} \exp(\text{sim}(z_1^{i}, z_2^{k}) / \tau)}

where sim is the cosine similarity, \tau is the temperature, and w_{i,j} is computed with a kernel K (e.g., Gaussian) and bandwidth matrix H as:

w_{i,j} = K\left( H^{-\frac{1}{2}} (y_i-y_j) \right)
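The weighting formula above can be sketched in a few lines. This is an illustrative helper, not the library implementation: it assumes a scalar bandwidth (i.e., H is bandwidth times the identity) and the Gaussian kernel K(u) = exp(-u^2 / 2).

```python
import numpy as np

def gaussian_weights(y, bandwidth=1.0):
    """Pairwise weights w_ij = K(H^{-1/2} (y_i - y_j)) for 1-D auxiliary
    variables y, with a Gaussian kernel and scalar bandwidth.

    Hypothetical helper for illustration only.
    """
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    # H^{-1/2} (y_i - y_j) with H = bandwidth * I
    diff = (y - y.T) / np.sqrt(bandwidth)
    # Gaussian kernel K(u) = exp(-u^2 / 2)
    return np.exp(-0.5 * diff ** 2)

# e.g. with ages as auxiliary variables: close ages get weights near 1,
# distant ages get weights near 0
w = gaussian_weights([20.0, 21.0, 60.0], bandwidth=4.0)
```

Samples with similar auxiliary values thus act as additional positives for each anchor, which is the core idea of the y-Aware weighting.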

Parameters:
kernel: str in {‘gaussian’, ‘epanechnikov’, ‘exponential’, ‘linear’, ‘cosine’}, default=’gaussian’

Kernel used to compute the weighting matrix between auxiliary variables. See the PhD thesis of Dufumier (2022), pages 94-95.

bandwidth: Union[float, int, List[float], array, KernelMetric], default=1.0

The method used to calculate the bandwidth (\sigma^2 in [1]) between auxiliary variables:

  • If bandwidth is a scalar (int or float), it sets the bandwidth to a diagonal matrix with equal values.

  • If bandwidth is a 1d array, it sets the bandwidth to a diagonal matrix; its size must equal the number of features in y.

  • If bandwidth is a 2d array, it must be of shape (n_features, n_features) where n_features is the number of features in y.

  • If bandwidth is a KernelMetric, its pairwise method is used to compute the similarity matrix between auxiliary variables.

temperature: float, default=0.1

Temperature used to scale the dot product between embedded vectors.

References

[1] (1,2)

Dufumier, B., et al., “Contrastive learning with continuous proxy meta-data for 3D MRI classification.” MICCAI, 2021. https://arxiv.org/abs/2106.08808

__init__(kernel='gaussian', bandwidth=1.0, temperature=0.1)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(z1, z2, labels=None)[source]
Parameters:
z1: torch.Tensor of shape (batch_size, n_features)

First embedded view.

z2: torch.Tensor of shape (batch_size, n_features)

Second embedded view.

labels: Optional[torch.Tensor] of shape (batch_size, n_labels)

Auxiliary variables associated with the input data. If None, the standard InfoNCE loss is returned.

Returns:
loss: torch.Tensor

The y-Aware InfoNCE loss computed between z1 and z2.
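To make the forward computation concrete, here is a minimal, self-contained sketch of the loss defined above (a Gaussian kernel with scalar bandwidth). It mirrors the formula but is not the nidl implementation; the function name `y_aware_infonce` is hypothetical.

```python
import torch
import torch.nn.functional as F

def y_aware_infonce(z1, z2, y, bandwidth=1.0, temperature=0.1):
    """Sketch of the y-Aware InfoNCE loss (illustration, not the
    library implementation). Assumes a Gaussian kernel and a scalar
    bandwidth.

    z1, z2: (batch_size, n_features) embedded views.
    y: (batch_size, n_labels) auxiliary variables.
    """
    # Cosine similarity matrix sim(z1_i, z2_j), scaled by temperature
    sim = F.cosine_similarity(z1.unsqueeze(1), z2.unsqueeze(0), dim=-1)
    log_softmax = F.log_softmax(sim / temperature, dim=1)
    # Gaussian kernel weights w_ij = exp(-||H^{-1/2}(y_i - y_j)||^2 / 2)
    diff = (y.unsqueeze(1) - y.unsqueeze(0)) / bandwidth ** 0.5
    w = torch.exp(-0.5 * diff.pow(2).sum(dim=-1))
    # Row-normalize: w_ij / sum_k w_ik
    w = w / w.sum(dim=1, keepdim=True)
    # -1/n sum_i sum_j normalized-w_ij * log-softmax term
    return -(w * log_softmax).sum(dim=1).mean()

z1, z2 = torch.randn(4, 8), torch.randn(4, 8)
y = torch.randn(4, 1)
loss = y_aware_infonce(z1, z2, y)
```

When all auxiliary values are equal, the weight matrix is uniform and every pair contributes equally; as the bandwidth shrinks, the weights concentrate on the diagonal and the loss approaches the standard InfoNCE.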