Note

This page is reference documentation. It describes the class signature only, not how to use it. Refer to the user guide for the big picture.

nidl.losses.DINOLoss

class nidl.losses.DINOLoss(output_dim=4096, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)[source]

Bases: Module

Implementation of the DINO loss [1].

This implementation follows the code published by the authors: https://github.com/facebookresearch/dino

It supports global and local image crops. A linear warmup schedule for the teacher temperature is implemented to stabilize training at the beginning. Centering is applied to the teacher output to avoid model collapse.

Parameters:
output_dim: int, default=4096

Dimension of the model output.

warmup_teacher_temp: float, default=0.04

Initial temperature for the teacher network.

teacher_temp: float, default=0.07

Final temperature for the teacher network.

warmup_teacher_temp_epochs: int, default=30

Number of epochs for the warmup phase of the teacher temperature.

student_temp: float, default=0.1

Temperature for the student network.

center_momentum: float, default=0.9

Momentum of the moving average used to update the center.
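To make the warmup parameters above concrete, here is a minimal sketch of a linear teacher-temperature schedule. The helper name `teacher_temperature` is hypothetical and not part of nidl, and including both endpoints in the interpolation is an assumption modeled on the reference DINO code; the actual implementation may differ in such details.

```python
def teacher_temperature(epoch, warmup_teacher_temp=0.04, teacher_temp=0.07,
                        warmup_teacher_temp_epochs=30):
    """Hypothetical helper (not the nidl API): teacher temperature at `epoch`."""
    # Outside the warmup phase, or when no epoch is given, use the final value.
    if epoch is None or epoch >= warmup_teacher_temp_epochs:
        return teacher_temp
    # A one-epoch warmup has a single point: the initial temperature.
    if warmup_teacher_temp_epochs == 1:
        return warmup_teacher_temp
    # Assumption: linear interpolation with both endpoints included, so the
    # last warmup epoch already reaches the final temperature.
    frac = epoch / (warmup_teacher_temp_epochs - 1)
    return warmup_teacher_temp + frac * (teacher_temp - warmup_teacher_temp)


# Temperatures for the first few epochs with the default parameters.
schedule = [teacher_temperature(e) for e in range(5)]
```

With the defaults, the temperature starts at 0.04 and increases by a constant step each epoch until it settles at 0.07.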

References

[1]

Caron, M., et al., “Emerging Properties in Self-Supervised Vision Transformers.” ICCV, 2021. https://arxiv.org/abs/2104.14294

Examples

>>> from nidl.losses import DINOLoss
>>>
>>> # initialize the loss function for 128-dimensional outputs
>>> loss_fn = DINOLoss(128)
>>>
>>> # generate the views of the images with a random transform
>>> # (`transform`, `images`, `teacher` and `student` are defined by the user;
>>> # the first two views are assumed to be the global ones)
>>> view = transform(images)
>>>
>>> # embed the global views with the teacher and all views with the student
>>> teacher_out = teacher(view[:2])
>>> student_out = student(view)
>>>
>>> # calculate loss
>>> loss = loss_fn(teacher_out, student_out)

__init__(output_dim=4096, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(teacher_out, student_out, epoch=None)[source]

Compute the cross-entropy between the softmax output of the teacher (after centering) and the softmax output of the student.

Parameters:
teacher_out: torch.Tensor of shape (n_globals, batch_size, n_features)

Features from the teacher model. Each entry along the first dimension is one global view of the batch.

student_out: torch.Tensor of shape (n_tot, batch_size, n_features)

Features from the student model. The first entries along the first dimension are the global views of the batch (in the same order as for the teacher); the remaining entries are the local views.

epoch: int or None

The current epoch, used to select the teacher temperature from the warmup schedule. If None, the final temperature teacher_temp is used.

Returns:
loss: torch.Tensor

The average cross-entropy loss.
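The computation described above can be sketched as follows. This is an illustration modeled on the reference DINO code, not the nidl source: the function name `dino_cross_entropy` and the explicit `center` argument are hypothetical, and skipping pairs where teacher and student see the same view is an assumption carried over from that reference code.

```python
import torch
import torch.nn.functional as F


def dino_cross_entropy(teacher_out, student_out, center,
                       teacher_temp=0.07, student_temp=0.1):
    """Hypothetical sketch of the loss described above (not the nidl code)."""
    # Teacher: subtract the center, divide by the temperature, apply softmax.
    # The result is detached so that no gradient reaches the teacher.
    t_probs = F.softmax((teacher_out - center) / teacher_temp, dim=-1).detach()
    # Student: log-softmax with its own temperature.
    s_logp = F.log_softmax(student_out / student_temp, dim=-1)

    total, n_terms = 0.0, 0
    for t_idx, t in enumerate(t_probs):
        for s_idx, s in enumerate(s_logp):
            if s_idx == t_idx:
                continue  # assumption: skip pairs built from the same view
            # Cross-entropy per sample, averaged over the batch.
            total = total + torch.sum(-t * s, dim=-1).mean()
            n_terms += 1
    return total / n_terms


torch.manual_seed(0)
teacher_out = torch.randn(2, 8, 16)  # (n_globals, batch_size, n_features)
student_out = torch.randn(6, 8, 16, requires_grad=True)  # 2 global + 4 local
center = torch.zeros(1, 1, 16)  # placeholder center for the example
loss = dino_cross_entropy(teacher_out, student_out, center)
loss.backward()  # gradients flow to the student output only
```

In this sketch, 2 teacher views and 6 student views give 10 cross-entropy terms, whose mean is returned as a scalar tensor.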

update_center(teacher_out)[source]

Moving average update of the center used for the teacher output.

Parameters:
teacher_out: torch.Tensor of shape (n_views, batch_size, n_features)

Features from the teacher model.
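As an illustration of a moving-average center update, here is a minimal single-process sketch. The helper name `ema_center` and the `(1, 1, n_features)` shape of the center are assumptions made for the example; the actual method updates internal state and may also average across devices in distributed training.

```python
import torch


def ema_center(center, teacher_out, center_momentum=0.9):
    """Hypothetical helper (not the nidl API): moving-average center update."""
    # Mean over the view and batch dimensions, keeping the feature dimension.
    batch_center = teacher_out.mean(dim=(0, 1), keepdim=True)
    # Keep a fraction `center_momentum` of the old center, blend in the rest.
    return center * center_momentum + batch_center * (1 - center_momentum)


center = torch.zeros(1, 1, 4)
teacher_out = torch.ones(2, 3, 4)  # (n_views, batch_size, n_features)
center = ema_center(center, teacher_out)
```

A momentum close to 1 makes the center change slowly, so a single batch has little influence on the value subtracted from the teacher output.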