# Deep learning for NeuroImaging in Python.
# Source code for surfify.nn.modules.
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Module that provides spherical layers.
"""
# Imports
import collections
import torch
import torch.nn as nn
import numpy as np
from ..utils import get_logger, debug_msg
from .functional import circular_pad
# Global parameters
logger = get_logger()
[docs]
class IcoSpMaConv(nn.Module):
""" Define the convolutional layer on icosahedron discretized sphere using
spherical 2-d mapping & circular padding.
Notes
-----
Debuging messages can be displayed by changing the log level using
``setup_logging(level='debug')``.
See Also
--------
IcoDiNeConv, IcoRePaConv
Examples
--------
>>> import torch
>>> from surfify.nn import IcoSpMaConv
>>> module = IcoSpMaConv(
in_feats=8, out_feats=16, kernel_size=3, stride=2, pad=1)
>>> proj_ico_x = torch.zeros((10, 8, 194, 194))
>>> proj_ico_x = module(proj_ico_x)
>>> proj_ico_x.shape
"""
def __init__(self, in_feats, out_feats, kernel_size, stride=1, pad=0):
""" Init IcoSpMaConv.
Parameters
----------
in_feats: int
input features/channels.
out_feats: int
output features/channels.
kernel_size: int or tuple
the convolutional kernel size.
stride: int or tuple, default 1
controls the stride for the cross-correlation.
pad: int or tuple (pad_azimuth, pad_elevation), default 0
the size of the padding.
"""
super().__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.kernel_size = kernel_size
self.stride = stride
self.pad = pad
self.conv = nn.Conv2d(
in_channels=in_feats,
out_channels=out_feats,
kernel_size=kernel_size, stride=stride, padding=0)
[docs]
def forward(self, x):
logger.debug("IcoSpMaConv...")
logger.debug(debug_msg("input", x))
x = circular_pad(x, pad=self.pad)
logger.debug(debug_msg("pad", x))
x = self.conv(x)
logger.debug(debug_msg("conv", x))
return x
class IcoSpMaConvTranspose(nn.Module):
    """ Define the transpose convolution on icosahedron discretized sphere
    using spherical 2-d mapping & circular padding.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoConv, IcoGenericUpConv, IcoUpSample, IcoFixIndexUpSample,
    IcoMaxIndexUpSample

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoSpMaConvTranspose
    >>> module = IcoSpMaConvTranspose(
            in_feats=16, out_feats=8, kernel_size=4, stride=2, zero_pad=3,
            pad=1)
    >>> proj_ico_x = torch.zeros((10, 16, 96, 96))
    >>> proj_ico_x = module(proj_ico_x)
    >>> proj_ico_x.shape
    """
    def __init__(self, in_feats, out_feats, kernel_size, stride=1, pad=0,
                 zero_pad=0, output_shape=None):
        """ Init IcoSpMaConvTranspose.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        kernel_size: int or tuple
            the convolutional kernel size.
        stride: int or tuple, default 1
            controls the stride for the cross-correlation.
        pad: int or tuple (pad_azimuth, pad_elevation), default 0
            the size of the padding.
        zero_pad: int or tuple, default 0
            add a zero padding in both axes before the transpose convolution.
        output_shape: list, default None
            if set, the requested output spatial shape forwarded to the
            transpose convolution (prefixed with the batch size).
        """
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad = pad
        self.zero_pad = zero_pad
        self.output_shape = output_shape
        self.tconv = nn.ConvTranspose2d(
            in_channels=in_feats,
            out_channels=out_feats,
            kernel_size=kernel_size, stride=stride, padding=zero_pad)

    def forward(self, x):
        """ Apply circular padding then the 2-d transpose convolution. """
        logger.debug("IcoSpMaConvTranspose...")
        logger.debug(debug_msg("input", x))
        x = circular_pad(x, pad=self.pad)
        logger.debug(debug_msg("pad", x))
        # NOTE(review): output_shape is list-concatenated with [batch_size];
        # it therefore must be a list (not a tuple) and of a length that
        # ConvTranspose2d accepts for output_size — confirm with callers.
        output_size = ([len(x)] + self.output_shape if self.output_shape
                       else None)
        x = self.tconv(x, output_size=output_size)
        logger.debug(debug_msg("transpose conv", x))
        return x
class IcoRePaConv(nn.Module):
    """ Define the convolutional layer on icosahedron discretized sphere
    using rectagular filter in tangent plane.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoDiNeConv, IcoSpMaConv

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoRePaConv
    >>> from surfify.utils import icosahedron, neighbors_rec
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> neighbors = neighbors_rec(
            ico2_vertices, ico2_triangles, size=5, zoom=5)[:2]
    >>> module = IcoRePaConv(
            in_feats=8, out_feats=8, neighs=neighbors)
    >>> ico2_x = torch.zeros((10, 8, len(ico2_vertices)))
    >>> ico2_x = module(ico2_x)
    >>> ico2_x.shape
    """
    def __init__(self, in_feats, out_feats, neighs):
        """ Init IcoRePaConv.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        neighs: 2-uplet
            neigh_indices: array (N, k, 3) - the neighbors indices.
            neigh_weights: array (N, k, 3) - the neighbors distances.
        """
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.neigh_indices, self.neigh_weights = neighs
        self.n_vertices, self.neigh_size, _ = self.neigh_indices.shape
        # Flatten the (k, 3) neighborhood so each vertex indexes k * 3
        # contributions at once.
        self.neigh_indices = self.neigh_indices.reshape(self.n_vertices, -1)
        self.neigh_weights = torch.from_numpy(
            self.neigh_weights.reshape(self.n_vertices, -1).astype(np.float32))
        self.weight = nn.Linear(self.neigh_size * in_feats, out_feats)

    def forward(self, x):
        """ Weight each vertex's neighborhood, reduce the 3 barycentric
        contributions, then apply the shared linear filter.
        """
        logger.debug("IcoRePaConv...")
        # Fix: compare torch devices directly; Tensor.get_device() returns -1
        # on CPU which cannot be passed back to Tensor.to().
        if self.neigh_weights.device != x.device:
            self.neigh_weights = self.neigh_weights.to(x.device)
        logger.debug(debug_msg("input", x))
        logger.debug(" weight: {0}".format(self.weight))
        logger.debug(" neighbors indices: {0}".format(
            self.neigh_indices.shape))
        logger.debug(" neighbors weights: {0}".format(
            self.neigh_weights.shape))
        n_samples = len(x)
        mat = x[:, :, self.neigh_indices.reshape(-1)].view(
            n_samples, self.in_feats, self.n_vertices, self.neigh_size * 3)
        logger.debug(debug_msg("neighors", mat))
        x = torch.mul(mat, self.neigh_weights).view(
            n_samples, self.in_feats, self.n_vertices, self.neigh_size, 3)
        logger.debug(debug_msg("weighted neighors", x))
        # Sum the 3 interpolation contributions of each neighbor.
        x = torch.sum(x, dim=4)
        logger.debug(debug_msg("sum", x))
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(n_samples * self.n_vertices,
                      self.in_feats * self.neigh_size)
        out = self.weight(x)
        out = out.view(n_samples, self.n_vertices, self.out_feats)
        out = out.permute(0, 2, 1)
        logger.debug(debug_msg("output", out))
        return out
class IcoDiNeConv(nn.Module):
    """ The convolutional layer on icosahedron discretized sphere using
    n-ring filter (based on the Direct Neighbor (DiNe) formulation).

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoRePaConv, IcoSpMaConv

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoDiNeConv
    >>> from surfify.utils import icosahedron, neighbors
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> neighbor_indices = neighbors(
            ico2_vertices, ico2_triangles, depth=1, direct_neighbor=True)
    >>> neighbor_indices = np.asarray(list(neighbor_indices.values()))
    >>> module = IcoDiNeConv(
            in_feats=8, out_feats=8, neigh_indices=neighbor_indices)
    >>> ico2_x = torch.zeros((10, 8, len(ico2_vertices)))
    >>> ico2_x = module(ico2_x)
    >>> ico2_x.shape
    """
    def __init__(self, in_feats, out_feats, neigh_indices, bias=True):
        """ Init IcoDiNeConv.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        neigh_indices: array (N, k)
            conv layer's filters' neighborhood indices, where N is the ico
            number of vertices and k the considered nodes neighbors.
        bias: bool, default True
            the layer will learn / not learn an additive bias.
        """
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.neigh_indices = neigh_indices
        self.n_vertices, self.neigh_size = neigh_indices.shape
        # One shared linear filter over each vertex's gathered neighborhood.
        self.weight = nn.Linear(self.neigh_size * in_feats, out_feats,
                                bias=bias)

    def forward(self, x):
        """ Gather each vertex's neighborhood features and apply the shared
        linear filter.
        """
        logger.debug("IcoDiNeConv...")
        logger.debug(debug_msg("input", x))
        logger.debug(" weight: {0}".format(self.weight))
        logger.debug(" neighbors indices: {0}".format(
            self.neigh_indices.shape))
        mat = x[:, :, self.neigh_indices.reshape(-1)].view(
            len(x), self.in_feats, self.n_vertices, self.neigh_size)
        mat = mat.permute(0, 2, 1, 3)
        mat = mat.reshape(len(x) * self.n_vertices,
                          self.in_feats * self.neigh_size)
        logger.debug(debug_msg("neighors", mat))
        out_features = self.weight(mat)
        out_features = out_features.view(len(x), self.n_vertices,
                                         self.out_feats)
        out_features = out_features.permute(0, 2, 1)
        logger.debug(debug_msg("output", out_features))
        return out_features
class IcoPool(nn.Module):
    """ The pooling layer on icosahedron discretized sphere using
    1-ring filter: can perform a mean or max pooling.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoPool
    >>> from surfify.utils import downsample, icosahedron, neighbors
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> ico3_vertices, ico3_triangles = icosahedron(order=3)
    >>> down_neigh_indices = neighbors(
            ico2_vertices, ico2_triangles, depth=1, direct_neighbor=True)
    >>> down_neigh_indices = np.asarray(list(down_neigh_indices.values()))
    >>> down_indices = downsample(ico3_vertices, ico2_vertices)
    >>> module = IcoPool(
            down_neigh_indices=down_neigh_indices, down_indices=down_indices)
    >>> ico3_x = torch.zeros((10, 4, len(ico3_vertices)))
    >>> ico2_x, _ = module(ico3_x)
    >>> ico2_x.shape, ico3_x.shape
    """
    def __init__(self, down_neigh_indices, down_indices, pooling_type="mean"):
        """ Init IcoPool.

        Parameters
        ----------
        down_neigh_indices: array
            downsampling neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i.
        pooling_type: str, default 'mean'
            the pooling type: 'mean' or 'max'.
        """
        super().__init__()
        self.down_indices = down_indices
        # Keep only the neighborhoods of the vertices retained at the
        # coarser sampling.
        self.down_neigh_indices = down_neigh_indices[down_indices]
        self.n_vertices, self.neigh_size = self.down_neigh_indices.shape
        self.pooling_type = pooling_type

    def forward(self, x):
        """ Pool each retained vertex's neighborhood.

        Returns the pooled tensor and, for max pooling, the argmax indices
        (None for mean pooling).
        """
        logger.debug("IcoPool...")
        logger.debug(debug_msg("input", x))
        # An icosahedron at order i has (V_{i+1} + 6) / 4 vertices.
        n_vertices = int((x.size(2) + 6) / 4)
        assert self.n_vertices == n_vertices
        n_features = x.size(1)
        logger.debug(" down neighbors indices: {0}".format(
            self.down_neigh_indices.shape))
        x = x[:, :, self.down_neigh_indices.reshape(-1)].view(
            len(x), n_features, n_vertices, self.neigh_size)
        logger.debug(debug_msg("neighors", x))
        if self.pooling_type == "mean":
            x = torch.mean(x, dim=-1)
            max_pool_indices = None
        elif self.pooling_type == "max":
            x, max_pool_indices = torch.max(x, dim=-1)
            logger.debug(debug_msg("max pool indices", max_pool_indices))
        else:
            raise RuntimeError("Invalid pooling.")
        logger.debug(debug_msg("pool", x))
        return x, max_pool_indices
class IcoUpConv(nn.Module):
    """ The transposed convolution layer on icosahedron discretized sphere
    using 1-ring filter.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoGenericUpConv, IcoUpSample, IcoFixIndexUpSample, IcoMaxIndexUpSample,
    IcoSpMaConvTranspose

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoUpConv
    >>> from surfify.utils import downsample, icosahedron, neighbors
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> ico3_vertices, ico3_triangles = icosahedron(order=3)
    >>> neighbor_indices = neighbors(
            ico3_vertices, ico3_triangles, depth=1, direct_neighbor=True)
    >>> neighbor_indices = np.asarray(list(neighbor_indices.values()))
    >>> down_indices = downsample(ico3_vertices, ico2_vertices)
    >>> module = IcoUpConv(
            in_feats=8, out_feats=4, up_neigh_indices=neighbor_indices,
            down_indices=down_indices)
    >>> ico2_x = torch.zeros((10, 8, len(ico2_vertices)))
    >>> ico3_x = module(ico2_x)
    >>> ico2_x.shape, ico3_x.shape
    """
    def __init__(self, in_feats, out_feats, up_neigh_indices, down_indices):
        """ Init IcoUpConv.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        up_neigh_indices: array
            upsampling neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i
        """
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.up_neigh_indices = up_neigh_indices
        self.neigh_indices = up_neigh_indices[down_indices]
        self.down_indices = down_indices
        self.n_vertices, self.neigh_size = self.up_neigh_indices.shape
        # Sort the flattened neighbor indices once so forward can scatter
        # contributions back to fine-grid vertices with pure slicing.
        self.flat_neigh_indices = self.neigh_indices.reshape(-1)
        self.argsort_neigh_indices = np.argsort(self.flat_neigh_indices)
        self.sorted_neigh_indices = self.flat_neigh_indices[
            self.argsort_neigh_indices]
        # Every fine-grid vertex must receive at least one contribution.
        assert (np.unique(self.sorted_neigh_indices).tolist() ==
                list(range(self.n_vertices)))
        # Icosahedron structure: the 12 original vertices receive exactly 2
        # contributions (24 entries), the remaining coarse vertices 1, and
        # the newly created vertices 2.
        self.sorted_2occ_12neigh_indices = self.sorted_neigh_indices[:24]
        self._check_occurence(self.sorted_2occ_12neigh_indices, occ=2)
        self.sorted_1occ_neigh_indices = self.sorted_neigh_indices[
            24: len(down_indices) + 12]
        self._check_occurence(self.sorted_1occ_neigh_indices, occ=1)
        self.sorted_2occ_neigh_indices = self.sorted_neigh_indices[
            len(down_indices) + 12:]
        self._check_occurence(self.sorted_2occ_neigh_indices, occ=2)
        self.argsort_2occ_12neigh_indices = self.argsort_neigh_indices[:24]
        self.argsort_1occ_neigh_indices = self.argsort_neigh_indices[
            24: len(down_indices) + 12]
        self.argsort_2occ_neigh_indices = self.argsort_neigh_indices[
            len(down_indices) + 12:]
        self.weight = nn.Linear(in_feats, self.neigh_size * out_feats)

    def _check_occurence(self, data, occ):
        """ Assert that every value in `data` occurs exactly `occ` times. """
        count = collections.Counter(data)
        unique_count = np.unique(list(count.values()))
        assert len(unique_count) == 1
        assert unique_count[0] == occ

    def forward(self, x):
        """ Project each coarse vertex onto its fine-grid neighborhood and
        average multiple contributions per fine vertex.
        """
        logger.debug("IcoUpConv: transpose conv...")
        logger.debug(debug_msg("input", x))
        n_samples, n_feats, n_vertices = x.size()
        logger.debug(" weight: {0}".format(self.weight))
        logger.debug(" neighbors indices: {0}".format(
            self.neigh_indices.shape))
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * n_vertices, n_feats)
        logger.debug(debug_msg("input", x))
        x = self.weight(x)
        logger.debug(debug_msg("weighted input", x))
        x = x.view(n_samples, n_vertices, self.neigh_size, self.out_feats)
        logger.debug(debug_msg("weighted input", x))
        x = x.view(n_samples, n_vertices * self.neigh_size, self.out_feats)
        # Average the two contributions of the 12 original ico vertices.
        x1 = x[:, self.argsort_2occ_12neigh_indices]
        x1 = x1.view(n_samples, 12, 2, self.out_feats)
        logger.debug(debug_msg("12 first 2 occ output", x1))
        # Single-contribution vertices pass through unchanged.
        x2 = x[:, self.argsort_1occ_neigh_indices]
        logger.debug(debug_msg("1 occ output", x2))
        # Average the two contributions of the new (edge midpoint) vertices.
        x3 = x[:, self.argsort_2occ_neigh_indices]
        x3 = x3.view(n_samples, -1, 2, self.out_feats)
        logger.debug(debug_msg("2 occ output", x3))
        x = torch.cat(
            (torch.mean(x1, dim=2), x2, torch.mean(x3, dim=2)), dim=1)
        x = x.permute(0, 2, 1)
        logger.debug(debug_msg("output", x))
        return x
class IcoGenericUpConv(nn.Module):
    """ The transposed convolution layer on icosahedron discretized sphere
    using n-ring filter (slow).

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoUpConv, IcoUpSample, IcoFixIndexUpSample, IcoMaxIndexUpSample,
    IcoSpMaConvTranspose

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoGenericUpConv
    >>> from surfify.utils import downsample, icosahedron, neighbors
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> ico3_vertices, ico3_triangles = icosahedron(order=3)
    >>> neighbor_indices = neighbors(
            ico3_vertices, ico3_triangles, depth=1, direct_neighbor=True)
    >>> neighbor_indices = np.asarray(list(neighbor_indices.values()))
    >>> down_indices = downsample(ico3_vertices, ico2_vertices)
    >>> module = IcoGenericUpConv(
            in_feats=8, out_feats=4, up_neigh_indices=neighbor_indices,
            down_indices=down_indices)
    >>> ico2_x = torch.zeros((10, 8, len(ico2_vertices)))
    >>> ico3_x = module(ico2_x)
    >>> ico2_x.shape, ico3_x.shape
    """
    def __init__(self, in_feats, out_feats, up_neigh_indices, down_indices):
        """ Init IcoGenericUpConv.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        up_neigh_indices: array
            upsampling neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i
        """
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.up_neigh_indices = up_neigh_indices
        self.neigh_indices = up_neigh_indices[down_indices]
        self.down_indices = down_indices
        self.n_vertices, self.neigh_size = self.up_neigh_indices.shape
        self.flat_neigh_indices = self.neigh_indices.reshape(-1)
        self.argsort_neigh_indices = np.argsort(self.flat_neigh_indices)
        self.sorted_neigh_indices = self.flat_neigh_indices[
            self.argsort_neigh_indices]
        # Every fine-grid vertex must receive at least one contribution.
        assert (np.unique(self.sorted_neigh_indices).tolist() ==
                list(range(self.n_vertices)))
        # Per-vertex contribution counts, sorted by vertex index, used by
        # forward to average a variable number of contributions.
        count = collections.Counter(self.sorted_neigh_indices)
        self.count = sorted(count.items(), key=lambda item: item[0])
        self.weight = nn.Linear(in_feats, self.neigh_size * out_feats)

    def _check_occurence(self, data, occ):
        """ Assert that every value in `data` occurs exactly `occ` times. """
        count = collections.Counter(data)
        unique_count = np.unique(list(count.values()))
        assert len(unique_count) == 1
        assert unique_count[0] == occ

    def forward(self, x):
        """ Project each coarse vertex onto its fine-grid neighborhood and
        average all contributions per fine vertex (Python loop: slow).
        """
        logger.debug("IcoGenericUpConv: transpose conv...")
        logger.debug(debug_msg("input", x))
        n_samples, n_feats, n_vertices = x.size()
        logger.debug(" weight: {0}".format(self.weight))
        logger.debug(" neighbors indices: {0}".format(
            self.neigh_indices.shape))
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * n_vertices, n_feats)
        logger.debug(debug_msg("input", x))
        x = self.weight(x)
        logger.debug(debug_msg("weighted input", x))
        x = x.view(n_samples, n_vertices, self.neigh_size, self.out_feats)
        logger.debug(debug_msg("weighted input", x))
        x = x.view(n_samples, n_vertices * self.neigh_size, self.out_feats)
        # Fix: allocate the output on the input's device and with its dtype,
        # otherwise the assignment below fails on GPU / non-float32 inputs.
        out = torch.zeros(n_samples, self.out_feats, self.n_vertices,
                          device=x.device, dtype=x.dtype)
        start = 0
        for idx in range(self.n_vertices):
            _idx, _count = self.count[idx]
            assert _idx == idx
            stop = start + _count
            _x = x[:, self.argsort_neigh_indices[start: stop]]
            out[..., idx] = torch.mean(_x, dim=1)
            start = stop
        logger.debug(debug_msg("output", out))
        return out
class IcoUpSample(nn.Module):
    """ The upsampling layer on icosahedron discretized sphere using
    interpolation.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoFixIndexUpSample, IcoMaxIndexUpSample, IcoUpConv, IcoGenericUpConv,
    IcoSpMaConvTranspose

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoUpSample
    >>> from surfify.utils import interpolate, icosahedron
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> ico3_vertices, ico3_triangles = icosahedron(order=3)
    >>> up_indices = interpolate(
            ico2_vertices, ico3_vertices, ico3_triangles)
    >>> up_indices = np.asarray(list(up_indices.values()))
    >>> module = IcoUpSample(
            in_feats=8, out_feats=4, up_neigh_indices=up_indices)
    >>> ico2_x = torch.zeros((10, 8, len(ico2_vertices)))
    >>> ico3_x = module(ico2_x)
    >>> ico2_x.shape, ico3_x.shape
    """
    def __init__(self, in_feats, out_feats, up_neigh_indices):
        """ Init IcoUpSample.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        up_neigh_indices: array
            upsampling neighborhood indices.
        """
        super().__init__()
        self.up_neigh_indices = up_neigh_indices
        self.n_vertices, self.neigh_size = up_neigh_indices.shape
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.fc = nn.Linear(in_feats, out_feats)

    def forward(self, x):
        """ Interpolate each fine vertex from its coarse neighborhood, then
        map features with a linear layer.
        """
        logger.debug("IcoUpSample: interp...")
        logger.debug(debug_msg("input", x))
        # An icosahedron at order i + 1 has 4 * V_i - 6 vertices.
        n_vertices = x.size(2) * 4 - 6
        assert self.n_vertices == n_vertices
        n_features = x.size(1)
        logger.debug(" up neighbors indices: {0}".format(
            self.up_neigh_indices.shape))
        x = x[:, :, self.up_neigh_indices.reshape(-1)].view(
            len(x), n_features, n_vertices, self.neigh_size)
        logger.debug(debug_msg("neighbors", x))
        x = torch.mean(x, dim=-1)
        logger.debug(debug_msg("interp", x))
        n_samples = len(x)
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * self.n_vertices, self.in_feats)
        x = self.fc(x)
        x = x.view(n_samples, self.n_vertices, self.out_feats)
        x = x.permute(0, 2, 1)
        logger.debug(debug_msg("output", x))
        return x
class IcoFixIndexUpSample(nn.Module):
    """ The upsampling layer on icosahedron discretized sphere using fixed
    zero indices (padding new vertices with 0).

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoUpSample, IcoMaxIndexUpSample, IcoUpConv, IcoGenericUpConv,
    IcoSpMaConvTranspose

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoFixIndexUpSample
    >>> from surfify.utils import interpolate, icosahedron
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> ico3_vertices, ico3_triangles = icosahedron(order=3)
    >>> up_indices = interpolate(
            ico2_vertices, ico3_vertices, ico3_triangles)
    >>> up_indices = np.asarray(list(up_indices.values()))
    >>> module = IcoFixIndexUpSample(
            in_feats=8, out_feats=4, up_neigh_indices=up_indices)
    >>> ico2_x = torch.zeros((10, 8, len(ico2_vertices)))
    >>> ico3_x = module(ico2_x)
    >>> ico2_x.shape, ico3_x.shape
    """
    def __init__(self, in_feats, out_feats, up_neigh_indices):
        """ Init IcoFixIndexUpSample.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        up_neigh_indices: array
            upsampling neighborhood indices.
        """
        super().__init__()
        self.up_neigh_indices = up_neigh_indices
        self.n_vertices, self.neigh_size = up_neigh_indices.shape
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.fc = nn.Linear(in_feats, out_feats)
        # A row with more than one distinct neighbor corresponds to a newly
        # created vertex: those are zero-padded in forward.
        self.new_indices = []
        for idx, row in enumerate(self.up_neigh_indices):
            if len(np.unique(row)) > 1:
                self.new_indices.append(idx)

    def forward(self, x):
        """ Copy coarse vertices, zero-pad the new ones, then map features
        with a linear layer.
        """
        logger.debug("IcoFixIndexUpSample: zero padding...")
        logger.debug(debug_msg("input", x))
        # An icosahedron at order i + 1 has 4 * V_i - 6 vertices.
        n_vertices = x.size(2) * 4 - 6
        assert self.n_vertices == n_vertices
        logger.debug(" up neighbors indices: {0}".format(
            self.up_neigh_indices.shape))
        # Advanced indexing copies, so the in-place zeroing below does not
        # touch the input tensor.
        x = x[:, :, self.up_neigh_indices[:, 0]]
        logger.debug(debug_msg("neighbors", x))
        x[:, :, self.new_indices] = 0
        logger.debug(debug_msg("interp", x))
        n_samples = len(x)
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * self.n_vertices, self.in_feats)
        x = self.fc(x)
        x = x.view(n_samples, self.n_vertices, self.out_feats)
        x = x.permute(0, 2, 1)
        logger.debug(debug_msg("output", x))
        return x
class IcoMaxIndexUpSample(nn.Module):
    """ The upsampling layer on icosahedron discretized sphere using
    max indices.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    IcoUpConv, IcoGenericUpConv, IcoUpSample, IcoFixIndexUpSample,
    IcoSpMaConvTranspose

    Examples
    --------
    >>> import torch
    >>> from surfify.nn import IcoMaxIndexUpSample
    >>> from surfify.utils import downsample, icosahedron, neighbors
    >>> ico2_vertices, ico2_triangles = icosahedron(order=2)
    >>> ico3_vertices, ico3_triangles = icosahedron(order=3)
    >>> neighbor_indices = neighbors(
            ico3_vertices, ico3_triangles, depth=1, direct_neighbor=True)
    >>> neighbor_indices = np.asarray(list(neighbor_indices.values()))
    >>> down_neigh_indices = neighbors(
            ico2_vertices, ico2_triangles, depth=1, direct_neighbor=True)
    >>> down_neigh_indices = np.asarray(list(down_neigh_indices.values()))
    >>> down_indices = downsample(ico3_vertices, ico2_vertices)
    >>> module = IcoPool(
            down_neigh_indices=down_neigh_indices,
            down_indices=down_indices, pooling_type="max")
    >>> ico3_x = torch.zeros((10, 4, len(ico3_vertices)))
    >>> _, max_pool_indices = module(ico3_x)
    >>> module = IcoMaxIndexUpSample(
            in_feats=8, out_feats=4, up_neigh_indices=neighbor_indices,
            down_indices=down_indices)
    >>> ico2_x = torch.zeros((10, 8, len(ico2_vertices)))
    >>> ico3_x = module(ico2_x, max_pool_indices)
    >>> ico2_x.shape, ico3_x.shape
    """
    def __init__(self, in_feats, out_feats, up_neigh_indices, down_indices):
        """ Init IcoMaxIndexUpSample.

        Parameters
        ----------
        in_feats: int
            input features/channels.
        out_feats: int
            output features/channels.
        up_neigh_indices: array
            upsampling neighborhood indices at sampling i + 1.
        down_indices: array
            downsampling indices at sampling i.
        """
        super().__init__()
        self.up_neigh_indices = up_neigh_indices
        self.neigh_indices = up_neigh_indices[down_indices]
        self.down_indices = down_indices
        self.n_vertices, self.neigh_size = up_neigh_indices.shape
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.fc = nn.Linear(in_feats, out_feats)

    def forward(self, x, max_pool_indices):
        """ Scatter the features back to the fine-grid locations selected
        during max pooling, zero elsewhere.
        """
        logger.debug("IcoMaxIndexUpSample: max pooling driven zero padding...")
        logger.debug(debug_msg("input", x))
        logger.debug(" neighbors indices: {0}".format(
            self.neigh_indices.shape))
        logger.debug(" max pool indices: {0}".format(max_pool_indices.shape))
        logger.debug(debug_msg("input", x))
        n_samples, n_feats, n_raw_vertices = x.size()
        x = x.permute(0, 2, 1)
        x = x.reshape(n_samples * n_raw_vertices, self.in_feats)
        x = self.fc(x)
        x = x.view(n_samples, n_raw_vertices, self.out_feats)
        x = x.permute(0, 2, 1)
        logger.debug(debug_msg("fc", x))
        n_samples, n_feats, n_raw_vertices = x.size()
        x = x.reshape(n_samples, -1)
        # Fix: allocate the scatter target on the input's device and with
        # its dtype, otherwise the final assignment fails on GPU inputs.
        y = torch.zeros(n_samples, n_feats, self.n_vertices,
                        device=x.device, dtype=x.dtype)
        vertices_indices = np.zeros((n_samples, n_feats, n_raw_vertices))
        # TODO: how to deal with different channels count
        # NOTE(review): numpy indexing with max_pool_indices assumes it is a
        # CPU tensor — confirm callers do not pass CUDA indices.
        for idx in range(n_raw_vertices):
            vertices_indices[..., idx] = self.neigh_indices[idx][
                max_pool_indices[..., idx]]
        vertices_indices = torch.from_numpy(vertices_indices).long()
        logger.debug(" vertices indices: {0}".format(vertices_indices.shape))
        vertices_indices = vertices_indices.view(n_samples, -1)
        logger.debug(" vertices indices: {0}".format(vertices_indices.shape))
        # Map each flattened (feature, vertex) slot to its feature row.
        feats_indices = np.floor(
            np.linspace(0.0, float(n_feats), num=(n_raw_vertices * n_feats)))
        feats_indices[-1] -= 1
        feats_indices = torch.from_numpy(feats_indices).long()
        logger.debug(" features indices: {0}".format(feats_indices.shape))
        y[:, feats_indices, vertices_indices] = x
        logger.debug(debug_msg("interp", y))
        return y