Note
This page is reference documentation. It only describes the class signature, not how to use it. Please refer to the user guide for the big picture.
nidl.losses.DINOLoss
- class nidl.losses.DINOLoss(output_dim=4096, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)
Bases: Module
Implementation of the DINO loss [1].
This implementation follows the code published by the authors: https://github.com/facebookresearch/dino
It supports global and local image crops. A linear warmup schedule for the teacher temperature is implemented to stabilize training at the beginning. Centering is applied to the teacher output to avoid model collapse.
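The centering step can be illustrated with a minimal sketch. It assumes the center is an exponential moving average of the mean teacher output, as in the authors' reference code; `update_center` is a hypothetical helper written with NumPy for illustration, not part of the nidl API.

```python
import numpy as np


def update_center(center, teacher_out, center_momentum=0.9):
    """Hypothetical helper: EMA update of the teacher center.

    center: array of shape (1, n_features)
    teacher_out: array of shape (n_globals, batch_size, n_features)
    """
    # Average the teacher outputs over all global views and all samples.
    n_features = teacher_out.shape[-1]
    batch_center = teacher_out.reshape(-1, n_features).mean(axis=0, keepdims=True)
    # Blend the previous center with the current batch statistics.
    return center_momentum * center + (1.0 - center_momentum) * batch_center


center = np.zeros((1, 4))
teacher_out = np.ones((2, 3, 4))  # 2 global views, batch of 3, 4 features
new_center = update_center(center, teacher_out)
```

Subtracting this running center from the teacher output before the softmax discourages any single output dimension from dominating.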
- Parameters:
- output_dim: int, default=4096
Dimension of the model output.
- warmup_teacher_temp: float, default=0.04
Initial temperature for the teacher network.
- teacher_temp: float, default=0.07
Final temperature for the teacher network.
- warmup_teacher_temp_epochs: int, default=30
Number of epochs for the warmup phase of the teacher temperature.
- student_temp: float, default=0.1
Temperature for the student network.
- center_momentum: float, default=0.9
Momentum term for the center calculation.
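As a rough illustration of how the three teacher temperature parameters interact, the sketch below interpolates linearly from warmup_teacher_temp to teacher_temp over the warmup epochs and stays constant afterwards. The function `teacher_temperature` is a hypothetical helper, and the endpoint convention (final value reached at the last warmup epoch) is an assumption borrowed from the authors' reference code; the class may compute the schedule differently.

```python
def teacher_temperature(epoch, warmup_teacher_temp=0.04, teacher_temp=0.07,
                        warmup_teacher_temp_epochs=30):
    """Hypothetical helper: teacher temperature for a given epoch."""
    # Without an epoch, or with no usable warmup phase, use the final temperature.
    if epoch is None or warmup_teacher_temp_epochs <= 1:
        return teacher_temp
    # After the warmup phase the temperature stays constant.
    if epoch >= warmup_teacher_temp_epochs:
        return teacher_temp
    # Linear interpolation between the initial and final temperatures.
    frac = epoch / (warmup_teacher_temp_epochs - 1)
    return warmup_teacher_temp + frac * (teacher_temp - warmup_teacher_temp)


schedule = [teacher_temperature(e) for e in range(40)]
```

A lower teacher temperature early in training sharpens the teacher distribution, which is the stabilizing effect the warmup is meant to provide.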
References
[1] Caron, M., et al., “Emerging Properties in Self-Supervised Vision Transformers.” ICCV, 2021. https://arxiv.org/abs/2104.14294
Examples
>>> # initialize loss function
>>> loss_fn = DINOLoss(128)
>>>
>>> # generate a view of the images with a random transform
>>> view = transform(images)
>>>
>>> # embed the view with a student and teacher model
>>> teacher_out = teacher(view[:2])
>>> student_out = student(view)
>>>
>>> # calculate loss
>>> loss = loss_fn(teacher_out, student_out)
- __init__(output_dim=4096, warmup_teacher_temp=0.04, teacher_temp=0.07, warmup_teacher_temp_epochs=30, student_temp=0.1, center_momentum=0.9)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(teacher_out, student_out, epoch=None)
Cross-entropy between softmax outputs of the centered teacher and student.
- Parameters:
- teacher_out: torch.Tensor of shape (n_globals, batch_size, n_features)
Features from the teacher model. Each tensor represents one (global) view of the batch.
- student_out: torch.Tensor of shape (n_tot, batch_size, n_features)
Features from the student model. The first tensors represent global views of the batch (same order as teacher) and the last ones represent local views.
- epoch: int or None
The current epoch used to set the teacher temperature. If None, the default teacher_temp is used.
- Returns:
- loss: torch.Tensor
The average cross-entropy loss.
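The computation described above can be sketched in NumPy as follows. This is an illustrative approximation under stated assumptions, not the class's actual code: it assumes, as in the authors' reference implementation, that each teacher view is compared with every student view except the one of the same index, and that the result is averaged over all such pairs. The names `dino_loss_sketch`, `softmax` and `log_softmax` are hypothetical, and the center is passed in explicitly rather than tracked internally.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over the feature axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def log_softmax(x):
    # Numerically stable log-softmax over the feature axis.
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def dino_loss_sketch(teacher_out, student_out, center,
                     teacher_temp=0.07, student_temp=0.1):
    """Hypothetical sketch; expects at least two views in student_out."""
    # Teacher: center, sharpen with the temperature, then softmax.
    t_probs = softmax((teacher_out - center) / teacher_temp)
    # Student: temperature-scaled log-probabilities.
    s_logp = log_softmax(student_out / student_temp)

    total, n_terms = 0.0, 0
    for t_idx in range(teacher_out.shape[0]):
        for s_idx in range(student_out.shape[0]):
            if s_idx == t_idx:
                continue  # skip pairs built from the same global view
            # Cross-entropy per sample, averaged over the batch.
            ce = -(t_probs[t_idx] * s_logp[s_idx]).sum(axis=-1)
            total += ce.mean()
            n_terms += 1
    return total / n_terms


rng = np.random.default_rng(0)
teacher_out = rng.normal(size=(2, 5, 8))  # (n_globals, batch_size, n_features)
student_out = rng.normal(size=(4, 5, 8))  # (n_tot, batch_size, n_features)
center = np.zeros((1, 8))
loss = dino_loss_sketch(teacher_out, student_out, center)
```

In a real training loop the teacher output would be detached from the computation graph, so gradients flow only through the student.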