Source code for surfify.losses.vae

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Definition of the Cortical Spherical Variational Auto-Encoder (SVAE) loss.
"""

# Imports
import torch
from torch.nn import functional as func
from torch.distributions import Normal, kl_divergence


def log_likelihood(recon, xs):
    """ Computes the negative log likelihood of the input sample under a
    unit-variance Gaussian centred on the reconstructed sample.

    Parameters
    ----------
    recon: Tensor (N, C, H, W)
        reconstructed images.
    xs: Tensor (N, C, H, W)
        original images.

    Returns
    -------
    log_likelihoods: Tensor (N)
        negative log likelihood for each sample, summed over the non-batch
        dimensions.
    """
    return -Normal(recon, torch.ones_like(recon)).log_prob(xs).sum(
        dim=tuple(range(1, recon.ndim)))
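
Despite its name, the value returned here is a negative log likelihood (the log density of xs under a unit-variance Gaussian centred on recon, negated and summed over all non-batch dimensions), so lower values mean a better reconstruction when it is used as a loss term. A minimal sketch of the expected shapes; the image size below is purely illustrative:

>>> import torch
>>> recon = torch.zeros(4, 1, 192, 192)
>>> xs = torch.zeros(4, 1, 192, 192)
>>> log_likelihood(recon, xs).shape
torch.Size([4])
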
class SphericalVAELoss:
    """ Spherical VAE Loss.
    """
    def __init__(self, beta=9, left_mask=None, right_mask=None,
                 use_mse=True):
        """ Init class.

        Parameters
        ----------
        beta: float, default 9
            weight of the KL divergence term.
        left_mask: Tensor (azimuth, elevation), default None
            left cortical binary mask.
        right_mask: Tensor (azimuth, elevation), default None
            right cortical binary mask.
        use_mse: bool, default True
            if True use the mean squared error as the reconstruction term,
            otherwise use the negative log likelihood.
        """
        super().__init__()
        self.beta = beta
        self.left_mask = left_mask
        self.right_mask = right_mask
        self.layer_outputs = None
        self.use_mse = use_mse

    def __call__(self, left_recon_x, right_recon_x, left_x, right_x):
        """ Compute the loss.
        """
        if self.layer_outputs is None:
            raise ValueError(
                "This loss needs intermediate layers outputs. Please "
                "register an appropriate callback.")
        q = self.layer_outputs["q"]
        # z = self.layer_outputs["z"]
        if self.left_mask is None:
            device = left_x.device
            self.left_mask = torch.ones(
                (left_x.shape[-2], left_x.shape[-1]), dtype=int).to(device)
            self.right_mask = torch.ones(
                (right_x.shape[-2], right_x.shape[-1]), dtype=int).to(device)

        # Reconstruction loss terms
        if self.use_mse:
            left_recon_loss = func.mse_loss(
                left_recon_x * self.left_mask.detach(),
                left_x * self.left_mask.detach(), reduction="mean")
            right_recon_loss = func.mse_loss(
                right_recon_x * self.right_mask.detach(),
                right_x * self.right_mask.detach(), reduction="mean")
        else:
            left_recon_loss = log_likelihood(
                left_recon_x * self.left_mask.detach(),
                left_x * self.left_mask.detach()).mean()
            right_recon_loss = log_likelihood(
                right_recon_x * self.right_mask.detach(),
                right_x * self.right_mask.detach()).mean()

        # Latent loss between the approximate posterior and the prior for z
        kl_div = kl_divergence(q, Normal(0, 1)).mean(dim=0).sum()

        # Total loss: minimising it corresponds to maximising the
        # beta-weighted ELBO
        loss = left_recon_loss + right_recon_loss + self.beta * kl_div

        return loss, {"left_recon_loss": left_recon_loss,
                      "right_recon_loss": right_recon_loss,
                      "kl_div": kl_div, "beta": self.beta}
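
As a rough usage sketch: the loss needs the approximate posterior q(z | x) produced by the encoder, which is normally exposed through a callback registered on the model (hence the layer_outputs attribute). Here it is filled by hand, and the tensor shapes and latent size are illustrative assumptions only:

>>> import torch
>>> from torch.distributions import Normal
>>> criterion = SphericalVAELoss(beta=2)
>>> left_x = torch.rand(2, 1, 64, 128)
>>> right_x = torch.rand(2, 1, 64, 128)
>>> left_recon_x = torch.rand(2, 1, 64, 128)
>>> right_recon_x = torch.rand(2, 1, 64, 128)
>>> # illustrative posterior: in practice a registered callback fills this
>>> criterion.layer_outputs = {
...     "q": Normal(torch.zeros(2, 16), torch.ones(2, 16))}
>>> loss, terms = criterion(left_recon_x, right_recon_x, left_x, right_x)
>>> sorted(terms)
['beta', 'kl_div', 'left_recon_loss', 'right_recon_loss']

Setting beta larger than 1 puts more weight on matching the latent prior relative to the two reconstruction terms, following the beta-VAE formulation.
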
