Deep learning for NeuroImaging in Python.
Source code for surfify.models.base
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
A base class for spherical networks that generates the requested
icosahedrons and related information.
"""
# Imports
from collections import namedtuple
import numpy as np
from joblib import Memory
import torch.nn as nn
from ..utils import (
icosahedron, neighbors, downsample, interpolate,
neighbors_rec, get_logger)
from ..nn import IcoDiNeConv, IcoRePaConv, IcoPool
# Global parameters
# Module-level logger obtained from surfify's `get_logger` helper; it is used
# for the debug messages emitted while building the icosahedron information.
logger = get_logger()
[docs]
class SphericalBase(nn.Module):
    """ Spherical network base information.

    Use either RePa - Rectangular Patch convolution method or DiNe - Direct
    Neighbor convolution method.

    Examples
    --------
    >>> from surfify.models import SphericalBase
    >>> ico_info = SphericalBase.build_ico_info(input_order=3, n_layers=2)
    >>> print(ico_info.keys())
    dict_keys([1, 2, 3])
    """
    # Per-order icosahedron record stored by `build_ico_info`:
    # - order: the icosahedron order.
    # - vertices, triangles: as returned by `icosahedron`.
    # - neighbor_indices: neighbors computed with `depth=1`.
    # - down_indices: output of `downsample` towards `order - 1` (None at
    #   the lowest order).
    # - up_indices: output of `interpolate` towards `order + 1` (None at
    #   the highest order).
    # - conv_neighbor_indices: neighbors used by the convolution; for 'RePa'
    #   this is a `(neighbors, weights)` tuple.
    Ico = namedtuple("Ico", ["order", "vertices", "triangles",
                             "neighbor_indices", "down_indices", "up_indices",
                             "conv_neighbor_indices"])

    def __init__(self, input_order, n_layers, conv_mode="DiNe",
                 dine_size=1, repa_size=5, repa_zoom=5,
                 dynamic_repa_zoom=False, standard_ico=False, cachedir=None):
        """ Init class.

        Parameters
        ----------
        input_order: int
            the input icosahedron order.
        n_layers: int
            the number of layers in the network.
        conv_mode: str, default 'DiNe'
            use either 'RePa' - Rectangular Patch convolution method or 'DiNe'
            - 1 ring Direct Neighbor convolution method.
        dine_size: int, default 1
            the size of the spherical convolution filter, i.e. the number of
            neighbor rings to be considered.
        repa_size: int, default 5
            the size of the rectangular grid in the tangent space.
        repa_zoom: int, default 5
            control the rectangular grid spacing in the tangent space by
            applying a multiplicative factor of `1 / repa_zoom`.
        dynamic_repa_zoom: bool, default False
            dynamically adapt the RePa zoom by applying a multiplicative factor
            of `log(order + 1) + 1`.
        standard_ico: bool, default False
            optionally use a standard icosahedron tessellation. The FreeSurfer
            tessellation is used by default.
        cachedir: str, default None
            set this folder to use smart caching speedup.

        Raises
        ------
        ValueError
            raised by `build_ico_info` if `conv_mode` is not supported or if
            `input_order` and `n_layers` lead to a negative icosahedron order.
        """
        super().__init__()
        self.input_order = input_order
        self.n_layers = n_layers
        self.conv_mode = conv_mode
        self.dine_size = dine_size
        self.repa_size = repa_size
        self.repa_zoom = repa_zoom
        self.dynamic_repa_zoom = dynamic_repa_zoom
        self.standard_ico = standard_ico
        self.cachedir = cachedir
        # Any mode other than 'RePa' selects DiNe here; unsupported values are
        # rejected by `build_ico_info` just below.
        if conv_mode == "RePa":
            self.sconv = IcoRePaConv
        else:
            self.sconv = IcoDiNeConv
        # Keyword arguments keep this call robust to signature reordering.
        self.ico = self.build_ico_info(
            input_order=input_order, n_layers=n_layers, conv_mode=conv_mode,
            dine_size=dine_size, repa_size=repa_size, repa_zoom=repa_zoom,
            dynamic_repa_zoom=dynamic_repa_zoom, standard_ico=standard_ico,
            cachedir=cachedir)

    def _safe_forward(self, block, x, act=None, skip_last_act=False):
        """ Perform a safe forward pass on a specific input block.

        The children of `block` are applied one after the other. For
        `IcoPool` children only the first element of the returned value is
        kept.

        Parameters
        ----------
        block: nn.Module
            the block whose children are applied sequentially.
        x: object
            the block input (presumably a torch tensor).
        act: callable, default None
            optional activation applied after each child module.
        skip_last_act: bool, default False
            if set, `act` is not applied after the last child module.

        Returns
        -------
        x: object
            the output of the last child module (possibly activated).
        """
        # Materialize the children once instead of iterating them twice.
        children = list(block.children())
        n_mods = len(children)
        for cnt, mod in enumerate(children):
            if isinstance(mod, IcoPool):
                x = mod(x)[0]
            else:
                x = mod(x)
            if skip_last_act and cnt == (n_mods - 1):
                continue
            if act is not None:
                x = act(x)
        return x

    @classmethod
    def build_ico_info(cls, input_order, n_layers, conv_mode="DiNe",
                       dine_size=1, repa_size=5, repa_zoom=5,
                       dynamic_repa_zoom=False, standard_ico=False,
                       cachedir=None):
        """ Build a dictionary containing icosahedron information at
        each order of interest with the related upsampling and downsampling
        information. This method is useful to speed up processing
        by caching icosahedron information.

        Parameters
        ----------
        input_order: int
            the input icosahedron order.
        n_layers: int
            the number of layers in the network.
        conv_mode: str, default 'DiNe'
            use either 'RePa' - Rectangular Patch convolution method or 'DiNe'
            - 1 ring Direct Neighbor convolution method.
        dine_size: int, default 1
            the size of the spherical convolution filter, i.e. the number of
            neighbor rings to be considered.
        repa_size: int, default 5
            the size of the rectangular grid in the tangent space.
        repa_zoom: int, default 5
            control the rectangular grid spacing in the tangent space by
            applying a multiplicative factor of `1 / repa_zoom`.
        dynamic_repa_zoom: bool, default False
            dynamically adapt the RePa zoom by applying a multiplicative factor
            of `log(order + 1) + 1`.
        standard_ico: bool, default False
            optionally use a standard icosahedron tessellation. The FreeSurfer
            tessellation is used by default.
        cachedir: str, default None
            set this folder to use smart caching speedup.

        Returns
        -------
        ico: dict of Ico
            the icosahedron information keyed by order, for every order from
            `input_order - n_layers` to `input_order` (inclusive).

        Raises
        ------
        ValueError
            if `conv_mode` is neither 'DiNe' nor 'RePa', if `n_layers` is
            negative, or if `input_order - n_layers` is negative.
        """
        # Validate the settings up front so that invalid inputs fail before
        # any (potentially costly) icosahedron is computed.
        if conv_mode not in ("DiNe", "RePa"):
            raise ValueError(
                "Unexpected convolution mode: {0!r}.".format(conv_mode))
        if n_layers < 0:
            raise ValueError(
                "The number of layers must be non-negative, got "
                "{0}.".format(n_layers))
        min_order = input_order - n_layers
        if min_order < 0:
            raise ValueError(
                "Cannot build {0} layer(s) from input order {1}: the lowest "
                "icosahedron order would be negative.".format(
                    n_layers, input_order))

        ico = {}
        memory = Memory(cachedir, verbose=0)
        icosahedron_cached = memory.cache(icosahedron)
        neighbors_cached = memory.cache(neighbors)
        neighbors_rec_cached = memory.cache(neighbors_rec)
        downsample_cached = memory.cache(downsample)
        interpolate_cached = memory.cache(interpolate)

        # Icosahedron, direct neighbors and convolution neighbors per order.
        for order in range(min_order, input_order + 1):
            vertices, triangles = icosahedron_cached(
                order=order, standard_ico=standard_ico)
            logger.debug("- ico {0}: verts {1} - tris {2}".format(
                order, vertices.shape, triangles.shape))
            neighs = neighbors_cached(
                vertices, triangles, depth=1, direct_neighbor=True)
            neighs = np.asarray(list(neighs.values()))
            logger.debug("- neighbors {0}: {1}".format(order, neighs.shape))
            if conv_mode == "DiNe":
                if dine_size == 1:
                    # The 1-ring neighbors were already computed above.
                    conv_neighs = neighs
                else:
                    conv_neighs = neighbors_cached(
                        vertices, triangles, depth=dine_size,
                        direct_neighbor=True)
                    conv_neighs = np.asarray(list(conv_neighs.values()))
                logger.debug("- conv neighbors {0}: {1}".format(
                    order, conv_neighs.shape))
            else:  # "RePa" (conv_mode was validated above)
                if dynamic_repa_zoom:
                    current_zoom = repa_zoom * (np.log(order + 1) + 1)
                else:
                    current_zoom = repa_zoom
                conv_neighs, conv_weights, _ = neighbors_rec_cached(
                    vertices, triangles, size=repa_size,
                    zoom=current_zoom)
                logger.debug("- conv neighbors {0} - {1}: {2} - {3}".format(
                    order, current_zoom, conv_neighs.shape,
                    conv_weights.shape))
                conv_neighs = (conv_neighs, conv_weights)
            ico[order] = cls.Ico(
                order=order, vertices=vertices, triangles=triangles,
                neighbor_indices=neighs, down_indices=None, up_indices=None,
                conv_neighbor_indices=conv_neighs)

        # Downsampling indices from each order to the order just below; the
        # lowest order keeps `down_indices=None`.
        for order in range(input_order, min_order, -1):
            down_indices = downsample_cached(
                ico[order].vertices, ico[order - 1].vertices)
            logger.debug("- down {0}: {1}".format(order, down_indices.shape))
            ico[order] = ico[order]._replace(
                down_indices=down_indices)

        # Upsampling indices from each order to the order just above; the
        # highest order keeps `up_indices=None`.
        for order in range(min_order, input_order):
            up_indices = interpolate_cached(
                ico[order].vertices, ico[order + 1].vertices,
                ico[order + 1].triangles)
            up_indices = np.asarray(list(up_indices.values()))
            logger.debug("- up {0}: {1}".format(order, up_indices.shape))
            ico[order] = ico[order]._replace(
                up_indices=up_indices)
        return ico
Follow us