# Deep learning for NeuroImaging in Python.
# Source code for surfify.models.vgg.
# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2021
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################
"""
Spherical implementation of the torch vision VGG.
"""
# Imports
import torch
import torch.nn as nn
from ..utils import get_logger, debug_msg
from ..nn import IcoPool, IcoSpMaConv
from .base import SphericalBase
# Global parameters
logger = get_logger()
class SphericalVGG(SphericalBase):
    """ Spherical VGG architecture.

    The left and right hemisphere textures follow two separate encoding
    paths until ``fusion_level`` pooling levels have been applied, are then
    concatenated along the channel axis and processed by a shared encoder,
    an adaptive average pooling and a 3-layer classification MLP.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalGVGG

    Examples
    --------
    >>> import torch
    >>> from surfify.utils import icosahedron
    >>> from surfify.models import SphericalVGG11
    >>> verts, tris = icosahedron(order=6)
    >>> x = torch.zeros((1, 2, len(verts)))
    >>> model = SphericalVGG11(
    >>>     input_channels=2, n_classes=10, input_order=6,
    >>>     conv_mode="DiNe", dine_size=1, hidden_dim=512,
    >>>     fusion_level=2, init_weights=True, standard_ico=False)
    >>> print(model)
    >>> out = model(x, x)
    >>> print(out.shape)
    """
    def __init__(self, input_channels, cfg, n_classes, input_order=5,
                 conv_mode="DiNe", dine_size=1, repa_size=5, repa_zoom=5,
                 dynamic_repa_zoom=False, hidden_dim=4096, batch_norm=False,
                 fusion_level=1, init_weights=True, standard_ico=False,
                 cachedir=None):
        """ Init class.

        Parameters
        ----------
        input_channels: int
            the number of input channels.
        cfg: list
            the definition of layers where 'M' stands for max pooling.
        n_classes: int
            the number of class in the classification problem.
        input_order: int, default 5
            the input icosahedron order.
        conv_mode: str, default 'DiNe'
            use either 'RePa' - Rectangular Patch convolution method or 'DiNe'
            - 1 ring Direct Neighbor convolution method.
        dine_size: int, default 1
            the size of the spherical convolution filter, ie. the number of
            neighbor rings to be considered.
        repa_size: int, default 5
            the size of the rectangular grid in the tangent space.
        repa_zoom: int, default 5
            control the rectangular grid spacing in the tangent space by
            applying a multiplicative factor of `1 / repa_zoom`.
        dynamic_repa_zoom: bool, default False
            dynamically adapt the RePa zoom by applying a multiplicative factor
            of `log(order + 1) + 1`.
        hidden_dim: int, default 4096
            the 2-layer classification MLP number of hidden dims.
        batch_norm: bool, default False
            wether or not to use batch normalization after a convolution
            layer.
        fusion_level: int, default 1
            at which max pooling level left and right hemisphere data
            are concatenated.
        init_weights: bool, default True
            initialize network weights.
        standard_ico: bool, default False
            optionaly use surfify tesselation.
        cachedir: str, default None
            set this folder to use smart caching speedup.
        """
        logger.debug("SphericalVGG init...")
        # One pooling level is created per 'M' entry of the configuration.
        super().__init__(
            input_order=input_order, n_layers=cfg.count("M"),
            conv_mode=conv_mode, dine_size=dine_size, repa_size=repa_size,
            repa_zoom=repa_zoom, dynamic_repa_zoom=dynamic_repa_zoom,
            standard_ico=standard_ico, cachedir=cachedir)
        self.input_channels = input_channels
        self.cfg = cfg
        self.n_classes = n_classes
        self.batch_norm = batch_norm
        self.n_modules = len(cfg)
        if fusion_level > self.n_layers or fusion_level <= 0:
            raise ValueError("Impossible to use input fusion level with "
                             "'{0}' layers.".format(self.n_layers))
        self.fusion_level = fusion_level
        # Number of channels produced by the last convolution block
        # (cfg ends with 'M', hence cfg[-2]).
        self.final_flt = int(cfg[-2])
        self.top_flatten_dim = len(
            self.ico[self.input_order - self.n_layers + 1].vertices)
        # Classifier input size: final channels x adaptive pooling output.
        self.top_final = self.final_flt * 7
        self._make_encoder()
        self.avgpool = nn.AdaptiveAvgPool1d(7)
        self.classifier = nn.Sequential(
            nn.Linear(self.top_final, hidden_dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_dim, n_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, left_x, right_x):
        """ Forward method.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input right cortical texture.

        Returns
        -------
        out: torch.Tensor
            the prediction.
        """
        logger.debug("SphericalVGG forward pass")
        logger.debug(debug_msg("left cortical", left_x))
        logger.debug(debug_msg("right cortical", right_x))
        # Hemisphere-specific paths until the fusion level.
        left_x = self._safe_forward(self.enc_left_conv, left_x)
        right_x = self._safe_forward(self.enc_right_conv, right_x)
        # Fuse hemispheres along the channel axis.
        x = torch.cat((left_x, right_x), dim=1)
        logger.debug(debug_msg("lh/rh path", x))
        x = self._safe_forward(self.enc_w_conv, x)
        logger.debug(debug_msg("features", x))
        x = self.avgpool(x)
        logger.debug(debug_msg("avg pooling", x))
        x = torch.flatten(x, 1)
        logger.debug(debug_msg("flat", x))
        x = self.classifier(x)
        logger.debug(debug_msg("classifier", x))
        return x

    def _initialize_weights(self):
        """ Init model weights.

        Convolutions use Kaiming initialization, batch norm layers start
        as the identity, and linear layers use a small normal init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_encoder(self):
        """ Method to create the encoding layers.

        Builds three sequential paths: left and right hemisphere encoders
        (each producing half of the configured channels) that are active
        until the fusion level, then a shared encoder for the concatenated
        features.
        """
        input_channels = self.input_channels
        order = self.input_order
        self.enc_left_conv = nn.Sequential()
        self.enc_right_conv = nn.Sequential()
        self.enc_w_conv = nn.Sequential()
        multi_path = True
        layer_idx = 0
        for idx in range(self.n_modules):
            if self.cfg[idx] == "M":
                order -= 1
                layer_idx += 1
                if layer_idx == self.fusion_level:
                    # From here on the two hemisphere paths are merged:
                    # channel count doubles due to the concatenation.
                    multi_path = False
                    input_channels *= 2
                pooling = IcoPool(
                    down_neigh_indices=self.ico[order + 1].neighbor_indices,
                    down_indices=self.ico[order + 1].down_indices,
                    pooling_type="max")
                if multi_path:
                    self.enc_left_conv.add_module(
                        "pooling_{0}".format(idx), pooling)
                    self.enc_right_conv.add_module(
                        "pooling_{0}".format(idx), pooling)
                else:
                    self.enc_w_conv.add_module(
                        "pooling_{0}".format(idx), pooling)
            elif multi_path:
                # Each hemisphere path carries half of the configured
                # channels so their concatenation matches cfg.
                lconv = self.sconv(
                    input_channels, (self.cfg[idx] // 2),
                    self.ico[order].conv_neighbor_indices)
                self.enc_left_conv.add_module("l_conv_{0}".format(idx), lconv)
                if self.batch_norm:
                    lbn = nn.BatchNorm1d(self.cfg[idx] // 2)
                    self.enc_left_conv.add_module("l_bn_{0}".format(idx), lbn)
                lrelu = nn.ReLU(inplace=True)
                self.enc_left_conv.add_module("l_relu_{0}".format(idx), lrelu)
                rconv = self.sconv(
                    input_channels, (self.cfg[idx] // 2),
                    self.ico[order].conv_neighbor_indices)
                self.enc_right_conv.add_module("r_conv_{0}".format(idx), rconv)
                if self.batch_norm:
                    rbn = nn.BatchNorm1d(self.cfg[idx] // 2)
                    self.enc_right_conv.add_module("r_bn_{0}".format(idx), rbn)
                rrelu = nn.ReLU(inplace=True)
                self.enc_right_conv.add_module("r_relu_{0}".format(idx), rrelu)
                input_channels = self.cfg[idx] // 2
            else:
                conv = self.sconv(
                    input_channels, self.cfg[idx],
                    self.ico[order].conv_neighbor_indices)
                self.enc_w_conv.add_module("conv_{0}".format(idx), conv)
                if self.batch_norm:
                    bn = nn.BatchNorm1d(self.cfg[idx])
                    self.enc_w_conv.add_module("bn_{0}".format(idx), bn)
                relu = nn.ReLU(inplace=True)
                self.enc_w_conv.add_module("relu_{0}".format(idx), relu)
                input_channels = self.cfg[idx]
class SphericalGVGG(nn.Module):
    """ Spherical Grided VGG architecture.

    Works on cortical textures converted to a regular 2-D grid. The left
    and right hemisphere grids follow two separate encoding paths until
    ``fusion_level`` pooling levels have been applied, are then
    concatenated along the channel axis and processed by a shared encoder,
    an adaptive average pooling and a 3-layer classification MLP.

    Notes
    -----
    Debuging messages can be displayed by changing the log level using
    ``setup_logging(level='debug')``.

    See Also
    --------
    SphericalVGG

    Examples
    --------
    >>> import torch
    >>> from surfify.models import SphericalGVGG11
    >>> x = torch.zeros((1, 2, 192, 192))
    >>> model = SphericalGVGG11(
    >>>     input_channels=2, n_classes=10, input_dim=194, hidden_dim=512,
    >>>     fusion_level=2, init_weights=True)
    >>> print(model)
    >>> out = model(x, x)
    >>> print(out.shape)
    """
    def __init__(self, input_channels, cfg, n_classes, input_dim=194,
                 hidden_dim=4096, batch_norm=False, fusion_level=1,
                 init_weights=True):
        """ Init class.

        Parameters
        ----------
        input_channels: int
            the number of input channels.
        cfg: list
            the definition of layers where 'M' stands for max pooling.
        n_classes: int
            the number of class in the classification problem.
        input_dim: int, default 194
            the size of the converted 3-D surface to the 2-D grid.
        hidden_dim: int, default 4096
            the 2-layer classification MLP number of hidden dims.
        batch_norm: bool, default False
            wether or not to use batch normalization after a convolution
            layer.
        fusion_level: int, default 1
            at which max pooling level left and right hemisphere data
            are concatenated.
        init_weights: bool, default True
            initialize network weights.
        """
        logger.debug("SphericalGVGG init...")
        super().__init__()
        self.input_channels = input_channels
        self.cfg = cfg
        self.n_classes = n_classes
        self.input_dim = input_dim
        self.batch_norm = batch_norm
        self.n_modules = len(cfg)
        # One pooling level per 'M' entry of the configuration.
        self.n_layers = cfg.count("M")
        if fusion_level > self.n_layers or fusion_level <= 0:
            raise ValueError("Impossible to use input fusion level with "
                             "'{0}' layers.".format(self.n_layers))
        self.fusion_level = fusion_level
        # Number of channels produced by the last convolution block
        # (cfg ends with 'M', hence cfg[-2]).
        self.final_flt = int(cfg[-2])
        # Spatial size after all 2x max pooling steps.
        self.top_flatten_dim = int(self.input_dim / (2 ** self.n_layers))
        # Classifier input size: final channels x 7x7 adaptive pooling.
        self.top_final = self.final_flt * 7 ** 2
        self._make_encoder()
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(self.top_final, hidden_dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(hidden_dim, n_classes)
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, left_x, right_x):
        """ Forward method.

        Parameters
        ----------
        left_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input left cortical texture.
        right_x: Tensor (samples, <input_channels>, azimuth, elevation)
            input right cortical texture.

        Returns
        -------
        out: torch.Tensor
            the prediction.
        """
        logger.debug("SphericalGVGG forward pass")
        logger.debug(debug_msg("left cortical", left_x))
        logger.debug(debug_msg("right cortical", right_x))
        # Hemisphere-specific paths, fused along the channel axis.
        x = torch.cat(
            (self.enc_left_conv(left_x), self.enc_right_conv(right_x)), dim=1)
        logger.debug(debug_msg("lh/rh path", x))
        x = self.enc_w_conv(x)
        logger.debug(debug_msg("features", x))
        x = self.avgpool(x)
        logger.debug(debug_msg("avg pooling", x))
        x = torch.flatten(x, 1)
        logger.debug(debug_msg("flat", x))
        x = self.classifier(x)
        logger.debug(debug_msg("classifier", x))
        return x

    def _initialize_weights(self):
        """ Init model weights.

        Convolutions use Kaiming initialization, batch norm layers start
        as the identity, and linear layers use a small normal init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_encoder(self):
        """ Method to create the encoding layers.

        Builds three sequential paths: left and right hemisphere encoders
        (each producing half of the configured channels) that are active
        until the fusion level, then a shared encoder for the concatenated
        features.
        """
        input_channels = self.input_channels
        self.enc_left_conv = nn.Sequential()
        self.enc_right_conv = nn.Sequential()
        self.enc_w_conv = nn.Sequential()
        multi_path = True
        layer_idx = 0
        for idx in range(self.n_modules):
            if self.cfg[idx] == "M":
                layer_idx += 1
                if layer_idx == self.fusion_level:
                    # From here on the two hemisphere paths are merged:
                    # channel count doubles due to the concatenation.
                    multi_path = False
                    input_channels *= 2
                pooling = nn.MaxPool2d(kernel_size=2, stride=2)
                if multi_path:
                    self.enc_left_conv.add_module(
                        "pooling_{0}".format(idx), pooling)
                    self.enc_right_conv.add_module(
                        "pooling_{0}".format(idx), pooling)
                else:
                    self.enc_w_conv.add_module(
                        "pooling_{0}".format(idx), pooling)
            elif multi_path:
                # Each hemisphere path carries half of the configured
                # channels so their concatenation matches cfg.
                lconv = IcoSpMaConv(
                    in_feats=input_channels,
                    out_feats=(self.cfg[idx] // 2),
                    kernel_size=3, pad=1)
                self.enc_left_conv.add_module("l_conv_{0}".format(idx), lconv)
                if self.batch_norm:
                    lbn = nn.BatchNorm2d(self.cfg[idx] // 2)
                    self.enc_left_conv.add_module("l_bn_{0}".format(idx), lbn)
                lrelu = nn.ReLU(inplace=True)
                self.enc_left_conv.add_module("l_relu_{0}".format(idx), lrelu)
                rconv = IcoSpMaConv(
                    in_feats=input_channels,
                    out_feats=(self.cfg[idx] // 2),
                    kernel_size=3, pad=1)
                self.enc_right_conv.add_module("r_conv_{0}".format(idx), rconv)
                if self.batch_norm:
                    rbn = nn.BatchNorm2d(self.cfg[idx] // 2)
                    self.enc_right_conv.add_module("r_bn_{0}".format(idx), rbn)
                rrelu = nn.ReLU(inplace=True)
                self.enc_right_conv.add_module("r_relu_{0}".format(idx), rrelu)
                input_channels = self.cfg[idx] // 2
            else:
                conv = IcoSpMaConv(
                    in_feats=input_channels,
                    out_feats=self.cfg[idx],
                    kernel_size=3, pad=1)
                self.enc_w_conv.add_module("conv_{0}".format(idx), conv)
                if self.batch_norm:
                    bn = nn.BatchNorm2d(self.cfg[idx])
                    self.enc_w_conv.add_module("bn_{0}".format(idx), bn)
                relu = nn.ReLU(inplace=True)
                self.enc_w_conv.add_module("relu_{0}".format(idx), relu)
                input_channels = self.cfg[idx]
def class_factory(klass_name, klass_params, destination_module_globals):
    """ Dynamically define a class.

    In order to make the class publicly accessible, we assign the result of
    the function to a variable dynamically using globals().

    Four classes are registered for each configuration:
    ``Spherical<klass_name>``, ``SphericalG<klass_name>`` and their
    batch-normalized counterparts with a ``BN`` suffix.

    Parameters
    ----------
    klass_name: str
        the class name that will be created.
    klass_params: dict
        the class specific parameters.
    destination_module_globals: dict
        the module namespace (usually ``globals()``) in which the generated
        classes are registered.
    """
    class SphericalVGGBase(SphericalVGG):
        cfg = None
        batch_norm = False

        def __init__(self, input_channels, n_classes, input_order=5,
                     conv_mode="DiNe", dine_size=1, repa_size=5, repa_zoom=5,
                     dynamic_repa_zoom=False, hidden_dim=4096,
                     fusion_level=1, init_weights=True, standard_ico=False,
                     cachedir=None):
            if self.cfg is None:
                raise ValueError("Please specify a configuration first.")
            SphericalVGG.__init__(
                self,
                cfg=self.cfg,
                batch_norm=self.batch_norm,
                input_channels=input_channels,
                n_classes=n_classes,
                input_order=input_order,
                conv_mode=conv_mode,
                dine_size=dine_size,
                repa_size=repa_size,
                repa_zoom=repa_zoom,
                dynamic_repa_zoom=dynamic_repa_zoom,
                hidden_dim=hidden_dim,
                fusion_level=fusion_level,
                init_weights=init_weights,
                standard_ico=standard_ico,
                cachedir=cachedir)

    class SphericalGVGGBase(SphericalGVGG):
        cfg = None
        batch_norm = False

        def __init__(self, input_channels, n_classes, input_dim=194,
                     hidden_dim=4096, fusion_level=1, init_weights=True):
            if self.cfg is None:
                raise ValueError("Please specify a configuration first.")
            SphericalGVGG.__init__(
                self,
                cfg=self.cfg,
                batch_norm=self.batch_norm,
                input_channels=input_channels,
                n_classes=n_classes,
                input_dim=input_dim,
                hidden_dim=hidden_dim,
                fusion_level=fusion_level,
                init_weights=init_weights)

    # Work on a copy so the caller's configuration dict is not mutated
    # (the previous in-place update polluted the shared CFGS entries).
    klass_params = dict(klass_params)
    klass_params.update({
        "__module__": destination_module_globals["__name__"],
        "_id": destination_module_globals["__name__"] + "." + klass_name
    })
    _klass_name = "Spherical" + klass_name
    destination_module_globals[_klass_name] = type(
        _klass_name, (SphericalVGGBase, ), klass_params)
    _klass_name = "SphericalG" + klass_name
    destination_module_globals[_klass_name] = type(
        _klass_name, (SphericalGVGGBase, ), klass_params)
    # Batch-normalized variants of the same configuration.
    klass_params["batch_norm"] = True
    _klass_name = "Spherical" + klass_name + "BN"
    destination_module_globals[_klass_name] = type(
        _klass_name, (SphericalVGGBase, ), klass_params)
    _klass_name = "SphericalG" + klass_name + "BN"
    destination_module_globals[_klass_name] = type(
        _klass_name, (SphericalGVGGBase, ), klass_params)
# VGG layer configurations: integers give the number of output channels of
# a convolution block, 'M' marks a 2x max pooling (downsampling) step.
CFGS = {
    name: {"cfg": layers} for name, layers in (
        ("VGG11", [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512,
                   512, "M"]),
        ("VGG13", [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512,
                   "M", 512, 512, "M"]),
        ("VGG16", [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512,
                   512, 512, "M", 512, 512, 512, "M"]),
        ("VGG19", [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
                   512, 512, 512, 512, "M", 512, 512, 512, 512, "M"]),
    )
}

# Register the 'Spherical<name>', 'SphericalG<name>' and associated 'BN'
# classes in this module's namespace.
destination_module_globals = globals()
for klass_name, klass_params in CFGS.items():
    class_factory(klass_name, klass_params, destination_module_globals)