Source code for surfify.models.unet

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
The spherical UNet architecture.
"""

# Imports
import torch
import torch.nn as nn
import torch.nn.functional as func
from joblib import Memory
from ..utils import number_of_ico_vertices, get_logger, debug_msg
from ..nn import (
    IcoUpConv, IcoMaxIndexUpSample, IcoFixIndexUpSample, IcoUpSample, IcoPool,
    IcoSpMaConv, IcoSpMaConvTranspose)
from .base import SphericalBase


# Global parameters
logger = get_logger()


class GraphicalUNet(nn.Module):
    """ The Graph U-Net model: implements a U-Net like architecture with
    graph pooling and unpooling operations.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalUNet, SphericalGUNet

    References
    ----------
    Hongyang Gao, and Shuiwang Ji, Graph U-Nets, arXiv, 2019.
    """
    def __init__(self, in_channels, out_channels, depth=5,
                 hidden_channels=32, pool_ratios=0.5, sum_res=False,
                 act=func.relu):
        """ Init GraphicalUNet.

        Parameters
        ----------
        in_channels: int
            input features/channels.
        out_channels: int
            output features/channels.
        depth: int, default 5
            number of layers in the UNet.
        hidden_channels: int, default 32
            number of convolutional filters for the convs.
        pool_ratios: float or list of float, default 0.5
            graph pooling ratio for each depth.
        sum_res: bool, default False
            if set to False, will use concatenation for the integration of
            skip connections instead of summation.
        act: torch.nn.functional, default relu
            the nonlinearity to use.
        """
        import torch_geometric.nn as gnn
        from torch_geometric.utils.repeat import repeat
        super().__init__()
        assert depth >= 1
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = act
        self.sum_res = sum_res
        channels = hidden_channels
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(gnn.GCNConv(in_channels, channels,
                                           improved=True))
        for i in range(depth):
            new_channels = channels * 2
            self.pools.append(gnn.TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(gnn.GCNConv(channels, new_channels,
                                               improved=True))
            channels = new_channels
        self.up_convs = torch.nn.ModuleList()
        for _i in range(depth - 1):
            new_channels = channels // 2
            in_channels = channels if sum_res else channels + new_channels
            self.up_convs.append(gnn.GCNConv(in_channels, new_channels,
                                             improved=True))
            channels = new_channels
        new_channels = channels // 2
        in_channels = channels if sum_res else channels + new_channels
        self.up_convs.append(gnn.GCNConv(in_channels, out_channels,
                                         improved=True))
        self.reset_parameters()

    def reset_parameters(self):
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))
        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []
        for i in range(1, self.depth + 1):
            # edge_index, edge_weight = self.augment_adj(
            #     edge_index, edge_weight, x.size(0))
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch)
            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)
            if i < self.depth:
                xs += [x]
                edge_indices += [edge_index]
                edge_weights += [edge_weight]
            perms += [perm]
        for i in range(self.depth):
            j = self.depth - 1 - i
            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]
            # unpool: zero-pad the coarse features back to the finer graph
            up = torch.zeros(res.size(dim=0), x.size(dim=1), dtype=x.dtype,
                             device=x.device)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)
            x = self.up_convs[i](x, edge_index, edge_weight)
            x = self.act(x) if i < self.depth - 1 else x
        return x

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        from torch_sparse import spspmm
        from torch_geometric.utils import (
            add_self_loops, sort_edge_index, remove_self_loops)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight
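
# Usage sketch (illustrative, not part of the library source): running the
# GraphicalUNet on a toy ring graph. This assumes torch_geometric is
# installed and that GraphicalUNet is exported from surfify.models like the
# other models here; the node count and channel sizes are arbitrary.
# >>> import torch
# >>> from surfify.models import GraphicalUNet
# >>> n_nodes = 16
# >>> src = torch.arange(n_nodes)
# >>> edge_index = torch.stack([src, (src + 1) % n_nodes])
# >>> model = GraphicalUNet(in_channels=3, out_channels=2, depth=2,
# ...                       hidden_channels=8)
# >>> out = model(torch.randn(n_nodes, 3), edge_index)
# >>> out.shape  # one prediction per input node
# torch.Size([16, 2])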


class SphericalUNet(SphericalBase):
    """ The Spherical U-Net architecture.

    The architecture is built upon specific spherical surface convolution,
    pooling, and transposed convolution modules. It has an encoder path and
    a decoder path, with a user-defined number of resolution steps. Unlike
    the standard U-Net, all 3x3 convolutions are replaced with RePa or DiNe
    convolutions, 2x2 up-convolutions with surface transposed convolutions
    or surface upsampling, and 2x2 max pooling with surface max/mean
    pooling. In addition to the standard U-Net, a batch normalization layer
    is added before each convolution layer's rectified linear unit (ReLU)
    activation. At the final layer, the 1x1 convolution is replaced by a
    vertex-wise filter. The number of feature channels is doubled after each
    surface pooling layer and halved at each transposed convolution or
    upsampling layer.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalGUNet

    Examples
    --------
    >>> import torch
    >>> from surfify.models import SphericalUNet
    >>> from surfify.utils import icosahedron
    >>> vertices, triangles = icosahedron(order=2)
    >>> model = SphericalUNet(
    ...     in_order=2, in_channels=2, out_channels=4, depth=2,
    ...     start_filts=8, conv_mode="DiNe", dine_size=1, up_mode="interp",
    ...     standard_ico=False)
    >>> x = torch.zeros((10, 2, len(vertices)))
    >>> out = model(x)
    >>> out.shape

    References
    ----------
    Zhao F, et al., Spherical U-Net on Cortical Surfaces: Methods and
    Applications, IPMI, 2019.
    """
    def __init__(self, in_order, in_channels, out_channels, depth=5,
                 start_filts=32, conv_mode="DiNe", dine_size=1, repa_size=5,
                 repa_zoom=5, dynamic_repa_zoom=False, up_mode="interp",
                 standard_ico=False, cachedir=None):
        """ Init SphericalUNet.

        Parameters
        ----------
        in_order: int
            the input icosahedron order.
        in_channels: int
            input features/channels.
        out_channels: int
            output features/channels.
        depth: int, default 5
            number of layers in the UNet.
        start_filts: int, default 32
            number of convolutional filters for the first conv.
        conv_mode: str, default 'DiNe'
            use either 'RePa' - Rectangular Patch convolution method, or
            'DiNe' - 1 ring Direct Neighbor convolution method.
        dine_size: int, default 1
            the size of the spherical convolution filter, i.e. the number
            of neighbor rings to be considered.
        repa_size: int, default 5
            the size of the rectangular grid in the tangent space.
        repa_zoom: int, default 5
            control the rectangular grid spacing in the tangent space by
            applying a multiplicative factor of `1 / repa_zoom`.
        dynamic_repa_zoom: bool, default False
            dynamically adapt the RePa zoom by applying a multiplicative
            factor of `log(order + 1) + 1`.
        up_mode: str, default 'interp'
            type of upsampling: 'transpose' for transpose convolution
            (1 ring), 'interp' for nearest neighbor linear interpolation,
            'maxpad' for max pooling shifted zero padding, and 'zeropad'
            for classical zero padding.
        standard_ico: bool, default False
            optionally use the surfify tessellation.
        cachedir: str, default None
            set this folder to use smart caching speedup.
        """
        logger.debug("SphericalUNet init...")
        super().__init__(
            input_order=in_order, n_layers=depth, conv_mode=conv_mode,
            dine_size=dine_size, repa_size=repa_size, repa_zoom=repa_zoom,
            dynamic_repa_zoom=dynamic_repa_zoom, standard_ico=standard_ico,
            cachedir=cachedir)
        self.memory = Memory(cachedir, verbose=0)
        self.in_order = in_order
        self.depth = depth
        self.in_vertices = number_of_ico_vertices(order=in_order)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.up_mode = up_mode
        self.filts = [in_channels] + [
            start_filts * 2 ** idx for idx in range(depth)]
        logger.debug("- filters: {0}".format(self.filts))
        for idx in range(depth):
            order = self.in_order - idx
            logger.debug(
                "- DownBlock {0}: {1} -> {2} [{3} - {4} - {5}]".format(
                    idx, self.filts[idx], self.filts[idx + 1],
                    self.ico[order].neighbor_indices.shape,
                    (None if idx == 0
                     else self.ico[order + 1].neighbor_indices.shape),
                    (None if idx == 0
                     else self.ico[order + 1].down_indices.shape)))
            block = DownBlock(
                conv_layer=self.sconv,
                in_ch=self.filts[idx],
                out_ch=self.filts[idx + 1],
                conv_neigh_indices=self.ico[order].conv_neighbor_indices,
                down_neigh_indices=(
                    None if idx == 0
                    else self.ico[order + 1].neighbor_indices),
                down_indices=(
                    None if idx == 0 else self.ico[order + 1].down_indices),
                pool_mode=("max" if self.up_mode == "maxpad" else "mean"),
                first=(idx == 0))
            setattr(self, "down{0}".format(idx + 1), block)
        cnt = 1
        for idx in range(depth - 1, 0, -1):
            logger.debug("- UpBlock {0}: {1} -> {2} [{3} - {4}]".format(
                cnt, self.filts[idx + 1], self.filts[idx],
                self.ico[order + 1].neighbor_indices.shape,
                self.ico[order].up_indices.shape))
            block = UpBlock(
                conv_layer=self.sconv,
                in_ch=self.filts[idx + 1],
                out_ch=self.filts[idx],
                conv_neigh_indices=self.ico[order + 1].conv_neighbor_indices,
                neigh_indices=self.ico[order + 1].neighbor_indices,
                up_neigh_indices=self.ico[order].up_indices,
                down_indices=self.ico[order + 1].down_indices,
                up_mode=self.up_mode)
            setattr(self, "up{0}".format(cnt), block)
            order += 1
            cnt += 1
        logger.debug("- FC: {0} -> {1}".format(self.filts[1], out_channels))
        self.fc = nn.Sequential(
            nn.Linear(self.filts[1], out_channels))

    def forward(self, x):
        """ Forward method.
        """
        logger.debug("SphericalUNet...")
        logger.debug(debug_msg("input", x))
        if x.size(2) != self.in_vertices:
            raise RuntimeError("Input data must be projected on an "
                               "order-{0} icosahedron.".format(self.in_order))
        encoder_outs = []
        pooling_outs = []
        for idx in range(1, self.depth + 1):
            down_block = getattr(self, "down{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, down_block))
            x, max_pool_indices = down_block(x)
            encoder_outs.append(x)
            pooling_outs.append(max_pool_indices)
        encoder_outs = encoder_outs[::-1]
        pooling_outs = pooling_outs[::-1]
        for idx in range(1, self.depth):
            up_block = getattr(self, "up{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, up_block))
            x_up = encoder_outs[idx]
            max_pool_indices = pooling_outs[idx - 1]
            x = up_block(x, x_up, max_pool_indices)
        logger.debug("FC...")
        logger.debug(debug_msg("input", x))
        n_samples = len(x)
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * self.in_vertices, self.filts[1])
        x = self.fc(x)
        x = x.view(n_samples, self.in_vertices, self.out_channels)
        x = x.permute(0, 2, 1)
        logger.debug(debug_msg("output", x))
        return x
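
# Shape note (illustrative, not part of the library source): the model
# expects channel-first tensors of shape (batch, in_channels, n_vertices)
# and returns (batch, out_channels, n_vertices). In the docstring example
# above, an order-2 icosahedron has 10 * 4 ** 2 + 2 = 162 vertices, so
# `out.shape` is expected to be torch.Size([10, 4, 162]).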


class DownBlock(nn.Module):
    """ Downsampling block in spherical UNet:
    mean pooling => (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, conv_layer, in_ch, out_ch, conv_neigh_indices,
                 down_neigh_indices, down_indices, pool_mode="mean",
                 first=False):
        """ Init DownBlock.

        Parameters
        ----------
        conv_layer: nn.Module
            the convolutional layer on icosahedron discretized sphere.
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        conv_neigh_indices: array
            conv layer's filters' neighborhood indices at sampling i.
        down_neigh_indices: array
            conv layer's filters' neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i.
        pool_mode: str, default 'mean'
            the pooling mode: 'mean' or 'max'.
        first: bool, default False
            if set, skip the pooling block.
        """
        super().__init__()
        self.first = first
        if not first:
            self.pooling = IcoPool(
                down_neigh_indices, down_indices, pool_mode)
        self.double_conv = nn.Sequential(
            conv_layer(in_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            conv_layer(out_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        """ Forward method.
        """
        logger.debug("- DownBlock")
        logger.debug(debug_msg("input", x))
        max_pool_indices = None
        if not self.first:
            x, max_pool_indices = self.pooling(x)
            logger.debug(debug_msg("pooling", x))
            if max_pool_indices is not None:
                logger.debug(
                    debug_msg("max pooling indices", max_pool_indices))
        x = self.double_conv(x)
        logger.debug(debug_msg("output", x))
        return x, max_pool_indices
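
# Shape note (illustrative, not part of the library source): IcoPool moves
# data from icosahedron order o to order o - 1, and icosahedron vertex
# counts follow V(o) = 10 * 4 ** o + 2, so a DownBlock at order 5 maps
# (batch, in_ch, 10242) to (batch, out_ch, 2562).
# >>> from surfify.utils import number_of_ico_vertices
# >>> [number_of_ico_vertices(order=o) for o in range(6)]
# [12, 42, 162, 642, 2562, 10242]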


class UpBlock(nn.Module):
    """ Define the upsampling block in spherical UNet:
    upconv => (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, conv_layer, in_ch, out_ch, conv_neigh_indices,
                 neigh_indices, up_neigh_indices, down_indices, up_mode):
        """ Init UpBlock.

        Parameters
        ----------
        conv_layer: nn.Module
            the convolutional layer on icosahedron discretized sphere.
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        conv_neigh_indices: tensor, int
            conv layer's filters' neighborhood indices at sampling i.
        neigh_indices: tensor, int
            neighborhood indices at sampling i.
        up_neigh_indices: array
            upsampling neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i.
        up_mode: str, default 'interp'
            type of upsampling: 'transpose' for transpose convolution,
            'interp' for nearest neighbor linear interpolation, 'maxpad'
            for max pooling shifted zero padding, and 'zeropad' for
            classical zero padding.
        """
        super().__init__()
        self.up_mode = up_mode
        if up_mode == "interp":
            self.up = IcoUpSample(in_ch, out_ch, up_neigh_indices)
        elif up_mode == "zeropad":
            self.up = IcoFixIndexUpSample(in_ch, out_ch, up_neigh_indices)
        elif up_mode == "maxpad":
            self.up = IcoMaxIndexUpSample(
                in_ch, out_ch, neigh_indices, down_indices)
        elif up_mode == "transpose":
            self.up = IcoUpConv(
                in_ch, out_ch, neigh_indices, down_indices)
        else:
            raise ValueError("Invalid upsampling method.")
        self.double_conv = nn.Sequential(
            conv_layer(in_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            conv_layer(out_ch, out_ch, conv_neigh_indices),
            nn.BatchNorm1d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x1, x2, max_pool_indices):
        """ Forward method.
        """
        logger.debug("- UpBlock")
        logger.debug(debug_msg("input", x1))
        logger.debug(debug_msg("skip", x2))
        x1 = (self.up(x1, max_pool_indices) if self.up_mode == "maxpad"
              else self.up(x1))
        logger.debug(debug_msg("upsampling", x1))
        x = torch.cat((x1, x2), 1)
        logger.debug(debug_msg("cat", x))
        x = self.double_conv(x)
        logger.debug(debug_msg("output", x))
        return x
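
# Channel note (illustrative, not part of the library source): in the
# SphericalUNet wiring above, in_ch is always 2 * out_ch because the filter
# counts double at each depth. The upsampling layer maps in_ch -> out_ch,
# and concatenating with the out_ch skip tensor restores in_ch channels,
# which is exactly what the first conv of double_conv expects:
# >>> in_ch, out_ch = 64, 32
# >>> out_ch + out_ch == in_ch  # channels after torch.cat((x1, x2), 1)
# True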


class SphericalGUNet(nn.Module):
    """ The Spherical Gridded U-Net architecture.

    The architecture is built upon specific spherical surface convolution,
    pooling, and transposed convolution modules. It has an encoder path and
    a decoder path, with a user-defined number of resolution steps. Unlike
    the standard U-Net, all 3x3 convolutions are replaced with the SpMa
    convolution. In addition to the standard U-Net, a batch normalization
    layer is added before each convolution layer's rectified linear unit
    (ReLU) activation. The number of feature channels is doubled after each
    surface pooling layer and halved at each transposed convolution or
    upsampling layer.

    Notes
    -----
    Debugging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalUNet

    References
    ----------
    Zhao F, et al., Spherical U-Net on Cortical Surfaces: Methods and
    Applications, IPMI, 2019.
    """
    def __init__(self, in_channels, out_channels, input_dim=192, depth=5,
                 start_filts=32):
        """ Init SphericalGUNet.

        Parameters
        ----------
        in_channels: int
            input features/channels.
        out_channels: int
            output features/channels.
        input_dim: int, default 192
            the size of the converted 3-D surface to the 2-D grid.
        depth: int, default 5
            number of layers in the UNet.
        start_filts: int, default 32
            number of convolutional filters for the first conv.
        """
        logger.debug("SphericalGUNet init...")
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.input_dim = input_dim
        self.depth = depth
        self.start_filts = start_filts
        self.filts = [in_channels] + [
            start_filts * 2 ** idx for idx in range(depth)]
        logger.debug("- filters: {0}".format(self.filts))
        for idx in range(depth):
            logger.debug(
                "- DownGBlock {0}: {1} -> {2}".format(
                    idx, self.filts[idx], self.filts[idx + 1]))
            block = DownGBlock(
                in_ch=self.filts[idx],
                out_ch=self.filts[idx + 1],
                first=(idx == 0))
            setattr(self, "down{0}".format(idx + 1), block)
        cnt = 1
        for idx in range(depth - 1, 0, -1):
            logger.debug("- UpGBlock {0}: {1} -> {2}".format(
                cnt, self.filts[idx + 1], self.filts[idx]))
            block = UpGBlock(
                in_ch=self.filts[idx + 1],
                out_ch=self.filts[idx])
            setattr(self, "up{0}".format(cnt), block)
            cnt += 1
        logger.debug("- Conv 1x1 final: {0} -> {1}".format(
            self.filts[1], out_channels))
        self.conv_final = nn.Conv2d(
            self.filts[1], out_channels, kernel_size=1, groups=1, stride=1)

    def forward(self, x):
        """ Forward method.
        """
        logger.debug("SphericalGUNet...")
        logger.debug(debug_msg("input", x))
        encoder_outs = []
        for idx in range(1, self.depth + 1):
            down_block = getattr(self, "down{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, down_block))
            x = down_block(x)
            encoder_outs.append(x)
        encoder_outs = encoder_outs[::-1]
        for idx in range(1, self.depth):
            up_block = getattr(self, "up{0}".format(idx))
            logger.debug("- filter {0}: {1}".format(idx, up_block))
            x_up = encoder_outs[idx]
            x = up_block(x, x_up)
        x = self.conv_final(x)
        logger.debug(debug_msg("output", x))
        return x
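
# Filter progression (illustrative, not part of the library source): with
# in_channels=2, start_filts=32 and depth=5, the per-level channel sizes
# computed in __init__ are
# >>> [2] + [32 * 2 ** idx for idx in range(5)]
# [2, 32, 64, 128, 256, 512]
# so the deepest encoder block outputs 512 feature maps and the final 1x1
# convolution maps 32 channels to out_channels.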


class DownGBlock(nn.Module):
    """ Downsampling block in gridded spherical UNet:
    max pooling => (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, in_ch, out_ch, first=False):
        """ Init DownGBlock.

        Parameters
        ----------
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        first: bool, default False
            if set, skip the pooling block.
        """
        super().__init__()
        self.first = first
        if not first:
            self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)
        self.double_conv = nn.Sequential(
            IcoSpMaConv(in_feats=in_ch, out_feats=out_ch, kernel_size=3,
                        pad=1),
            nn.BatchNorm2d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            IcoSpMaConv(in_feats=out_ch, out_feats=out_ch, kernel_size=3,
                        pad=1),
            nn.BatchNorm2d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        """ Forward method.
        """
        logger.debug("- DownGBlock")
        logger.debug(debug_msg("input", x))
        if not self.first:
            x = self.pooling(x)
            logger.debug(debug_msg("pooling", x))
        x = self.double_conv(x)
        logger.debug(debug_msg("output", x))
        return x
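
# Usage sketch (illustrative, not part of the library source): a DownGBlock
# only needs channel sizes, so it can be probed with a dummy 2-D grid.
# Assuming IcoSpMaConv with kernel_size=3 and pad=1 is size-preserving, the
# first block keeps the grid resolution and later blocks halve it through
# the 2x2 max pooling.
# >>> block = DownGBlock(in_ch=2, out_ch=32, first=True)
# >>> block(torch.zeros(1, 2, 64, 64)).shape  # expected: (1, 32, 64, 64)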


class UpGBlock(nn.Module):
    """ Define the upsampling block in gridded spherical UNet:
    upconv => (conv => BN => LeakyReLU) * 2
    """
    def __init__(self, in_ch, out_ch):
        """ Init UpGBlock.

        Parameters
        ----------
        in_ch: int
            input features/channels.
        out_ch: int
            output features/channels.
        """
        super().__init__()
        self.up = IcoSpMaConvTranspose(
            in_feats=in_ch, out_feats=out_ch, kernel_size=4, stride=2,
            pad=1, zero_pad=3)
        self.double_conv = nn.Sequential(
            IcoSpMaConv(in_feats=in_ch, out_feats=out_ch, kernel_size=3,
                        pad=1),
            nn.BatchNorm2d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            IcoSpMaConv(in_feats=out_ch, out_feats=out_ch, kernel_size=3,
                        pad=1),
            nn.BatchNorm2d(out_ch, momentum=0.15, affine=True,
                           track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x1, x2):
        """ Forward method.
        """
        logger.debug("- UpGBlock")
        logger.debug(debug_msg("input", x1))
        logger.debug(debug_msg("skip", x2))
        x1 = self.up(x1)
        logger.debug(debug_msg("upsampling", x1))
        x = torch.cat((x1, x2), 1)
        logger.debug(debug_msg("cat", x))
        x = self.double_conv(x)
        logger.debug(debug_msg("output", x))
        return x
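
# Channel/size note (illustrative, not part of the library source): the
# IcoSpMaConvTranspose with stride=2 is expected to double the grid
# resolution while halving the channels (in_ch -> out_ch); concatenating
# with the out_ch skip tensor then restores in_ch channels for the first
# conv of double_conv, mirroring the UpBlock arithmetic above.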
