Menu

PyTorch toolbox to work with spherical surfaces.

Source code for surfify.models.simclr

import torch
from torch import nn
import torch.nn.functional as func


class SimCLR(nn.Module):
    """ Class implementing the simCLR model.

    Refer for instance to:
    Ting Chen, Simon Kornblith, Mohammad Norouzi, Geoffrey Hinton
    A Simple Framework for Contrastive Learning of Visual Representations

    Parameters
    ----------
    latent_dim: int
        size of the features produced by ``backbone``; it is the input size
        of the first projector layer.
    hidden_layers: sequence of int
        output sizes of the projector's linear layers (any sequence, e.g. a
        list or a tuple). The last entry is the size of the embedding used
        in the contrastive loss. Must contain at least one entry.
    temperature: float
        value dividing the cosine similarities before the softmax.
    backbone: nn.Module
        encoder applied to both views before the projector.
    return_logits: bool, default False
        if set, ``forward`` also returns the cross-view similarity matrix
        and the target indices.

    Raises
    ------
    ValueError
        if ``hidden_layers`` is empty.
    """
    def __init__(self, latent_dim, hidden_layers, temperature, backbone,
                 return_logits=False):
        super().__init__()
        self.backbone = backbone
        # projector: [Linear -> BatchNorm1d -> ReLU] for every layer but the
        # last one, then a final bias-free Linear.
        # Unpacking accepts any sequence, not only a list.
        sizes = [latent_dim, *hidden_layers]
        if len(sizes) < 2:
            raise ValueError(
                "'hidden_layers' must contain at least one layer size.")
        layers = []
        for in_size, out_size in zip(sizes[:-2], sizes[1:-1]):
            layers.extend([
                nn.Linear(in_size, out_size, bias=False),
                nn.BatchNorm1d(out_size),
                nn.ReLU(inplace=True),
            ])
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)
        self.temperature = temperature
        # large penalty subtracted from self-similarities so that they get a
        # (numerically) zero weight in the softmax
        self.INF = 1e8
        self.return_logits = return_logits

    def forward(self, y1, y2):
        """ Compute the contrastive (NT-Xent) loss between two views.

        Parameters
        ----------
        y1, y2: torch.Tensor
            two views of the same batch; sample ``k`` of ``y1`` and sample
            ``k`` of ``y2`` form a positive pair. Both go through
            ``backbone`` then ``projector``.

        Returns
        -------
        loss: torch.Tensor
            sum of the two directional cross-entropy losses (y1 -> y2 and
            y2 -> y1).
        sim_zij: torch.Tensor
            [N, N] cross-view similarities divided by the temperature; only
            returned when ``return_logits`` is set.
        correct_pairs: torch.Tensor
            [N] target indices ``0..N-1``; only returned when
            ``return_logits`` is set.
        """
        z_i = self.projector(self.backbone(y1))
        z_j = self.projector(self.backbone(y2))
        n_samples = len(z_i)
        z_i = func.normalize(z_i, p=2, dim=-1)  # dim [N, D]
        z_j = func.normalize(z_j, p=2, dim=-1)  # dim [N, D]
        # dim [N, N] => within-view similarities: every off-diagonal entry
        # is a negative pair, the diagonal is a self-similarity
        sim_zii = (z_i @ z_i.T) / self.temperature
        sim_zjj = (z_j @ z_j.T) / self.temperature
        # dim [N, N] => the diag contains the correct pairs (i,j)
        # (x transforms via T_i and T_j)
        sim_zij = (z_i @ z_j.T) / self.temperature
        # 'Remove' the diag terms by penalizing them (exp(-inf) = 0); the
        # penalty matrix is shared by both views
        diag_penalty = self.INF * torch.eye(n_samples, device=z_i.device)
        sim_zii = sim_zii - diag_penalty
        sim_zjj = sim_zjj - diag_penalty
        # positive of row k is column k (the diagonal of sim_zij)
        correct_pairs = torch.arange(n_samples, device=z_i.device)
        loss_i = func.cross_entropy(
            torch.cat([sim_zij, sim_zii], dim=1), correct_pairs)
        loss_j = func.cross_entropy(
            torch.cat([sim_zij.T, sim_zjj], dim=1), correct_pairs)
        loss = loss_i + loss_j
        if self.return_logits:
            return loss, sim_zij, correct_pairs
        return loss

Follow us

© 2025, surfify developers