Menu

Deep learning for NeuroImaging in Python.

Source code for surfify.augmentation.utils

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2023
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
A module with common augmentation utility functions.
"""

# Import
import abc
import numbers
import numpy as np
from collections import namedtuple


class RandomAugmentation:
    """ Apply an augmentation with random parameters defined in intervals.

    Any attribute assigned a ``RandomAugmentation.Interval`` value is
    recorded in ``self.intervals`` and replaced by a random draw from that
    interval. Before each call, all recorded parameters are redrawn, unless
    ``writable`` is False, in which case the current values are reused.
    """
    Interval = namedtuple("Interval", ["low", "high", "dtype"])

    def __init__(self):
        """ Init class.
        """
        # Maps attribute name -> Interval for every randomized parameter.
        self.intervals = {}
        # When False, _randomize is a no-op and the current parameter values
        # are kept (used to share parameters across channels).
        self.writable = True

    def _randomize(self):
        """ Redraw every random parameter, unless the instance is frozen
        (``writable`` is False).
        """
        if self.writable:
            for param, bound in self.intervals.items():
                setattr(self, param, self._rand(bound))

    def _rand(self, bound):
        """ Generate a new random value.

        Parameters
        ----------
        bound: RandomAugmentation.Interval
            the interval to sample from.

        Returns
        -------
        value: int or float
            an integer uniformly drawn in [low, high] when ``dtype`` is int,
            a float uniformly drawn in [low, high) when ``dtype`` is float.

        Raises
        ------
        ValueError
            if the interval dtype is neither int nor float.
        """
        if bound.dtype == int:
            # np.random.randint excludes its upper bound: +1 makes `high`
            # reachable.
            return np.random.randint(bound.low, bound.high + 1)
        elif bound.dtype == float:
            return np.random.uniform(bound.low, bound.high)
        else:
            raise ValueError(f"'{bound.dtype}' dtype not supported.")

    def __setattr__(self, name, value):
        """ Store intervals.

        Assigning an Interval registers it under ``name`` and stores a first
        random draw as the attribute value; any other value is stored as is.
        """
        if isinstance(value, RandomAugmentation.Interval):
            self.intervals[name] = value
            value = self._rand(value)
        super().__setattr__(name, value)

    def __call__(self, data, *args, **kwargs):
        """ Applies the augmentation to the data.

        Parameters
        ----------
        data: array (N, )
            input data/texture.
        inplace: bool, default False
            if True, the input data is passed to ``run`` without being
            copied; otherwise a copy is augmented (pass as a kwargs, it is
            also forwarded to ``run``).

        Returns
        -------
        data: arr (N, )
            augmented input data.
        """
        self._randomize()
        # Copy unless the caller explicitly requested in-place processing.
        if not kwargs.get("inplace", False):
            data = data.copy()
        return self.run(data, *args, **kwargs)

    # NOTE: the class does not use abc.ABCMeta, so this decorator is not
    # enforced at instantiation; subclasses are expected to override run.
    @abc.abstractmethod
    def run(self, data, *args, **kwargs):
        return
def interval(bound, dtype=float):
    """ Build a random-parameter interval.

    Parameters
    ----------
    bound: 2-uplet or number
        either an explicit (low, high) pair, or a non-negative number b that
        is expanded into the symmetric pair (-b, b).
    dtype: object, default float
        data type of the values drawn from the interval: float, int, ...

    Returns
    -------
    interval: RandomAugmentation.Interval
        a (low, high, dtype) named tuple.

    Raises
    ------
    ValueError
        if a scalar bound is negative, if the bound does not hold exactly two
        values, or if low is greater than high.
    """
    is_scalar = isinstance(bound, numbers.Number)
    if is_scalar and bound < 0:
        raise ValueError("Specified interval value must be positive.")
    pair = (-bound, bound) if is_scalar else bound
    if len(pair) != 2:
        raise ValueError("Interval must be specified with 2 values.")
    low, high = pair
    if low > high:
        raise ValueError("Wrong interval boundaries.")
    return RandomAugmentation.Interval(low=low, high=high, dtype=dtype)
class BaseTransformer:
    """ Registry holding an ordered sequence of transformations.

    Subclasses decide how the registered transformations are applied by
    implementing ``__call__``.
    """
    # One registered entry: the transformation itself, the probability of
    # applying it, and whether its random parameters are redrawn per channel.
    Transform = namedtuple(
        "Transform", ["transform", "probability", "randomize_per_channel"])

    def __init__(self):
        """ Init class.
        """
        self.transforms = list()

    def register(self, transform, probability=1, randomize_per_channel=True):
        """ Append a transformation to the registry.

        Parameters
        ----------
        transform: RandomAugmentation instance
            a transformation.
        probability: float, default 1
            the transform is applied with the specified probability.
        randomize_per_channel: bool, default True
            whether the random parameters of the transformation are redrawn
            for every channel.
        """
        entry = self.Transform(transform, probability, randomize_per_channel)
        self.transforms.append(entry)

    # NOTE: the class does not use abc.ABCMeta, so this decorator is not
    # enforced; the base implementation simply returns None.
    @abc.abstractmethod
    def __call__(self, data, *args, **kwargs):
        """ Apply the registered transformations (defined by subclasses).
        """
class Transformer(BaseTransformer):
    """ Register a sequence of transformations and apply them, in order,
    to some data.
    """

    def __call__(self, data, *args, **kwargs):
        """ Run every registered transformation on the data.

        Parameters
        ----------
        data: array (N, ) or (n_channels, N)
            the input data.

        Returns
        -------
        _data: array (N, ) or (n_channels, N)
            the transformed input data.
        """
        chain = self.transforms
        return apply_chained_transforms(data, chain, *args, **kwargs)
def apply_chained_transforms(data, transforms, *args, **kwargs):
    """ Function to apply a series of transforms to some data.

    Channels are processed independently. For every channel, each transform
    is applied with its own probability. A transform registered with
    ``randomize_per_channel=False`` has its ``writable`` flag cleared after
    it is applied, so later channels reuse the same random parameters. The
    flags of all transforms are set back to True on exit, even if a
    transform raises.

    Parameters
    ----------
    data: array (N, ) or (n_channels, N)
        the input data.
    transforms: list of BaseTransformer.Transform
        list of transforms to apply.

    Returns
    -------
    _data: array (N, ) or (n_channels, N)
        the transformed input data, with the same number of dimensions as
        ``data`` (singleton axes are preserved).
    """
    ndim = data.ndim
    assert ndim in (1, 2)
    _data = data.copy()
    if ndim == 1:
        # Treat a single vector as a one-channel array.
        _data = _data[np.newaxis]
    all_c_data = []
    try:
        for _c_data in _data:
            for trf in transforms:
                if np.random.rand() < trf.probability:
                    _c_data = trf.transform(_c_data, *args, **kwargs)
                    if not trf.randomize_per_channel:
                        # Freeze the parameters for the remaining channels.
                        trf.transform.writable = False
            all_c_data.append(_c_data)
    finally:
        # Always unfreeze, otherwise an exception would leave transforms
        # frozen for every later call.
        for trf in transforms:
            trf.transform.writable = True
    if not all_c_data:
        # No channel to process: return the (empty) copy with its shape.
        return _data
    _data = np.array(all_c_data)
    # Undo only the axis added above; squeeze() would also drop genuine
    # singleton axes such as n_channels == 1 or N == 1.
    return _data[0] if ndim == 1 else _data
def multichannel_augmentation(augmentation, randomize_per_channel=True):
    """ Decorator to transform an augmentation to a multichannel one.

    Parameters
    ----------
    augmentation: RandomAugmentation class
        the augmentation class.
    randomize_per_channel: bool, default True
        optionally randomizes the augmentation parameters for each channel;
        when False, the parameters drawn for the first channel are reused for
        all the others.

    Returns
    -------
    MultiChannelAugmentation: child class of augmentation
        augmentation applicable to multi channel data.
    """
    class MultiChannelAugmentation(augmentation):
        def __call__(self, data, *args, **kwargs):
            """ Apply the augmentation to every channel of the data.

            Parameters
            ----------
            data: array (N, ) or (n_channels, N)
                the input data.

            Returns
            -------
            _data: array (N, ) or (n_channels, N)
                the transformed input data, with the same number of
                dimensions as ``data`` (singleton axes are preserved).
            """
            ndim = data.ndim
            assert ndim in (1, 2)
            _data = data.copy()
            if ndim == 1:
                # Treat a single vector as a one-channel array.
                _data = _data[np.newaxis]
            all_c_data = []
            # Remember the caller's setting so that an outer freeze (e.g.
            # from apply_chained_transforms) is not undone on exit.
            previous_writable = getattr(self, "writable", True)
            try:
                for _c_data in _data:
                    _c_data = super().__call__(_c_data, *args, **kwargs)
                    if not randomize_per_channel:
                        # Freeze the parameters for the remaining channels.
                        self.writable = False
                    all_c_data.append(_c_data)
            finally:
                self.writable = previous_writable
            if not all_c_data:
                # No channel to process: return the (empty) copy.
                return _data
            _data = np.array(all_c_data)
            # Undo only the axis added above; squeeze() would also drop
            # genuine singleton axes such as n_channels == 1 or N == 1.
            return _data[0] if ndim == 1 else _data

    return MultiChannelAugmentation
def listify(data):
    """ Ensure that the input is a list or tuple.

    Lists and tuples are returned unchanged (same object); any other value
    is wrapped in a one-element list.

    Parameters
    ----------
    data: list or array
        the input data.

    Returns
    -------
    out: list or tuple
        the listified input data.
    """
    already_sequence = isinstance(data, (list, tuple))
    return data if already_sequence else [data]

Follow us

© 2025, nidl developers