Note

This page is reference documentation. It only explains the class signature, not how to use it. Please refer to the user guide for the big picture.

nidl.losses.DCLLoss

class nidl.losses.DCLLoss(temperature=0.1, pos_weight_fn=None)

Bases: Module

Implementation of the Decoupled Contrastive Learning (DCL) loss [1].

This loss function implements the decoupled contrastive learning loss as described in [1]. It builds on the classic InfoNCE loss but removes the positive pair from the denominator. This eliminates the negative-positive coupling that degrades learning efficiency when the batch size is small.

Given a mini-batch of N samples, we obtain for each sample i two embeddings z_i^{(1)} and z_i^{(2)} of two different augmented views of that sample. For each view k \in \{1,2\}, the DCL loss of the anchor z_i^{(k)} is defined as:

\mathcal{L}_i^{(k)}
= - \operatorname{sim}\big(z_i^{(1)}, z_i^{(2)}\big)/\tau
+ \log
\sum\limits_{l \in \{1,2\},\, j \in [\![1,N]\!]}
\mathbf{1}_{[j \ne i]}\,
\exp\!\big(\operatorname{sim}(z_i^{(k)}, z_j^{(l)})/\tau\big)

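where \operatorname{sim}(z_i^{(k)}, z_j^{(l)}) denotes the cosine similarity between the L2-normalized embeddings z_i^{(k)} and z_j^{(l)}, and \tau > 0 is a temperature parameter controlling the concentration of the distribution. The indicator \mathbf{1}_{[j \ne i]} removes two terms from the denominator: the positive pair (j = i, l \ne k) and the anchor's similarity with itself (j = i, l = k). Excluding the positive pair is what decouples the positive and negative terms.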
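To make the role of the indicator concrete, the sketch below evaluates the formula term by term for a single anchor z_i^{(k)} and compares it with a SimCLR-style InfoNCE term, which keeps the positive pair in the denominator. It is illustrative only and not the nidl implementation; the helper name anchor_losses is hypothetical, and the loop mirrors the sum rather than aiming for speed.

```python
import torch
import torch.nn.functional as F


def anchor_losses(z1, z2, i, k, tau=0.1):
    """Hypothetical helper: return (dcl, infonce) for the anchor z_i^(k)."""
    # L2-normalize so that a dot product equals the cosine similarity.
    views = {1: F.normalize(z1, dim=1), 2: F.normalize(z2, dim=1)}
    n = z1.shape[0]
    anchor = views[k][i]
    # Positive term sim(z_i^(1), z_i^(2)) / tau (identical for k = 1 and k = 2).
    pos = torch.dot(views[1][i], views[2][i]) / tau

    # DCL denominator: all (l, j) with j != i, i.e. only the 2N - 2 negatives.
    neg = sum(
        torch.exp(torch.dot(anchor, views[l][j]) / tau)
        for l in (1, 2)
        for j in range(n)
        if j != i
    )
    dcl = -pos + torch.log(neg)

    # InfoNCE additionally keeps exp(pos) in the denominator (the coupling).
    infonce = -pos + torch.log(neg + torch.exp(pos))
    return dcl, infonce


torch.manual_seed(0)
z1, z2 = torch.randn(4, 8), torch.randn(4, 8)
dcl, infonce = anchor_losses(z1, z2, i=0, k=1)
print(f"DCL: {dcl.item():.4f}  InfoNCE: {infonce.item():.4f}")
```

The two values differ by log(1 + exp(pos) / neg). That term depends on how the positive compares with the negatives, which is the coupling that DCL removes.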

Additionally, a weighting function w can be added to modulate the contribution of the positive pairs' similarity to the loss. The intuition is that when the positive embedding z_i^{(2)} is already close to the anchor z_i^{(1)}, the pair carries less learning signal than when the two embeddings are less similar. The weighted loss is:

\mathcal{L}_i^{(k)}
= - w\big(z_i^{(1)}, z_i^{(2)}\big)\,
\operatorname{sim}\big(z_i^{(1)}, z_i^{(2)}\big)/\tau
+ \log
\sum\limits_{l \in \{1,2\},\, j \in [\![1,N]\!]}
\mathbf{1}_{[j \ne i]}\,
\exp\!\big(\operatorname{sim}(z_i^{(k)}, z_j^{(l)})/\tau\big)
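The weighted loss can be computed for a whole batch at once. The following is a minimal vectorized PyTorch sketch, not the nidl source. The helper name dcl_loss is hypothetical, and three choices are assumptions: averaging \mathcal{L}_i^{(k)} over all i and both views k, passing the already normalized embeddings to pos_weight_fn, and expecting pos_weight_fn to return one weight per sample. With pos_weight_fn=None it reduces to the unweighted loss.

```python
import torch
import torch.nn.functional as F


def dcl_loss(z1, z2, temperature=0.1, pos_weight_fn=None):
    """Hypothetical sketch: (weighted) DCL loss averaged over i and k."""
    n = z1.shape[0]
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Positive term sim(z_i^(1), z_i^(2)) / tau, optionally scaled by w.
    # Assumption: pos_weight_fn receives normalized embeddings, returns shape (N,).
    pos = (z1 * z2).sum(dim=1) / temperature
    if pos_weight_fn is not None:
        pos = pos_weight_fn(z1, z2) * pos

    # Stack all 2N embeddings: row (k - 1) * N + i holds z_i^(k).
    z = torch.cat([z1, z2], dim=0)
    logits = z @ z.T / temperature

    # Indicator 1[j != i]: mask column j == i in both views (self and positive).
    idx = torch.arange(2 * n)
    same_sample = (idx[:, None] % n) == (idx[None, :] % n)
    logits = logits.masked_fill(same_sample, float("-inf"))

    # log of the sum over the remaining 2N - 2 negatives, for every anchor.
    neg = torch.logsumexp(logits, dim=1)

    # L_i^(k) = -pos_i + neg_i^(k); the positive term is shared by k = 1, 2.
    losses = -pos.repeat(2) + neg
    return losses.mean()
```

Filling masked entries with -inf before torch.logsumexp implements the indicator without an explicit loop and keeps the computation numerically stable for small \tau.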

See the class DCLWLoss for an implementation with the negative von Mises-Fisher weighting function proposed in [1].
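The negative von Mises-Fisher weighting can be written as a standalone positive-pair weighting function. This sketch follows our reading of the weighting in [1], w_i = 2 - exp(s_i/\sigma) / mean_j exp(s_j/\sigma), where s_i = \operatorname{sim}(z_i^{(1)}, z_i^{(2)}). It is not the DCLWLoss source; the name vmf_weight and the default \sigma = 0.5 are assumptions.

```python
import torch
import torch.nn.functional as F


def vmf_weight(z1, z2, sigma=0.5):
    """Hypothetical sketch: w_i = 2 - exp(s_i / sigma) / mean_j exp(s_j / sigma)."""
    # s_i: cosine similarity of each positive pair, shape (N,).
    s = (F.normalize(z1, dim=1) * F.normalize(z2, dim=1)).sum(dim=1)
    # N * softmax(s / sigma) equals exp(s_i / sigma) / mean_j exp(s_j / sigma),
    # and softmax avoids overflow in the exponentials.
    return 2 - z1.shape[0] * torch.softmax(s / sigma, dim=0)
```

By construction the weights average to 1 over the batch. Pairs whose similarity is above the batch average get w_i < 1, and harder pairs get w_i > 1. Under the assumptions above, such a function could be passed as DCLLoss(pos_weight_fn=vmf_weight).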

Parameters:
temperature: float, default=0.1

Scale logits by the inverse of the temperature.

pos_weight_fn: Optional[callable], default=None

Weighting function of the positive pairs (w in [1]). It is a callable that takes the two tensors z1 and z2 as inputs and returns the weights w(z1, z2) as a tensor. If None, no weighting is applied and the standard (unweighted) DCL loss is computed.
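The effect of the temperature parameter can be seen on a toy example. Dividing similarities by \tau before the exponential changes how concentrated the resulting softmax distribution is. The similarity values and the helper name similarity_distribution below are hypothetical and chosen only for illustration.

```python
import torch


def similarity_distribution(sims, tau):
    """Hypothetical helper: softmax over similarities scaled by 1 / tau."""
    return torch.softmax(sims / tau, dim=0)


# Cosine similarities of one anchor to three candidates (illustrative values).
sims = torch.tensor([0.9, 0.5, 0.1])
for tau in (1.0, 0.1):
    # A smaller tau puts more mass on the most similar candidate.
    print(tau, similarity_distribution(sims, tau))
```

With the default temperature=0.1, the distribution is much sharper than at \tau = 1, so the loss focuses on the hardest negatives.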

References

[1]

Yeh, Chun-Hsiao, et al. “Decoupled contrastive learning.” European Conference on Computer Vision (ECCV 2022). Cham: Springer Nature Switzerland, 2022. https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136860653.pdf

__init__(temperature=0.1, pos_weight_fn=None)

Initialize the loss with the given temperature and optional positive-pair weighting function.

forward(z1, z2)

Compute the DCL loss between two batches of embedded views.

Parameters:
z1: torch.Tensor of shape (batch_size, n_features)

First embedded view.

z2: torch.Tensor of shape (batch_size, n_features)

Second embedded view.

Returns:
loss: torch.Tensor

The DCL loss computed between z1 and z2.