Source code for nidl.datasets.base

##########################################################################
# NSAp - Copyright (C) CEA, 2025
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################


""" Base class to generate datasets.
"""

import abc
import errno
import glob
import hashlib
import os
import warnings

import numpy as np
import pandas as pd
from torch.utils.data import Dataset


class BaseDataset(Dataset, metaclass=abc.ABCMeta):
    """ Base neuroimaging dataset.

    Notes
    -----
    A 'participants.tsv' file containing subject information (including
    the requested targets) is expected at the root.
    A '<split>.tsv' file containing the subjects to include is expected
    at the root.

    Parameters
    ----------
    root: str
        the location where the data are stored.
    patterns: str or list of str
        the relative locations of your data.
    channels: str or list of str, default=None
        the name of the channels.
    split: str, default 'train'
        define the split to be considered.
    targets: str or list of str, default=None
        the dataset will also return these tabular data.
    target_mapping: dict, default None
        optionally, define a dictionary specifying different replacement
        values for different existing values. See the pandas
        DataFrame.replace documentation for more information.
    transforms: callable, default None
        a function that can be called to augment the input images.
    mask: str, default None
        optionally, mask the input data using this numpy array.
    withdraw_subjects: list of str, default None
        optionally, provide a list of subjects to remove from the dataset.

    Raises
    ------
    FileNotFoundError
        If the mandatory input files are not found.
    KeyError
        If the mandatory keys are not found.
    UserWarning
        If missing data are found.
    """
    def __init__(self, root, patterns, channels, split="train",
                 targets=None, target_mapping=None, transforms=None,
                 mask=None, withdraw_subjects=None):
        # Sanity checks
        if not isinstance(patterns, (list, tuple)):
            patterns = [patterns]
        if not isinstance(channels, (list, tuple)):
            channels = [channels]
        if targets is not None and not isinstance(targets, (list, tuple)):
            targets = [targets]
        participant_file = os.path.join(root, "participants.tsv")
        split_file = os.path.join(root, f"{split}.tsv")
        for path in (participant_file, split_file):
            if not os.path.isfile(path):
                raise FileNotFoundError(
                    errno.ENOENT, os.strerror(errno.ENOENT), path)

        # Parameters
        self.root = root
        self.patterns = patterns
        self.channels = channels
        self.n_modalities = len(self.channels)
        self.targets = targets
        self.target_mapping = target_mapping or {}
        self.split = split
        self.transforms = transforms
        self.mask = (np.load(mask) if mask is not None else None)

        # Load subjects
        self.info_df = pd.read_csv(participant_file, sep="\t")
        if "participant_id" not in self.info_df:
            raise KeyError(
                "A 'participant_id' column is mandatory in the "
                "participants file.")
        self.info_df = self.info_df.astype({"participant_id": "str"})
        self.split_df = pd.read_csv(split_file, sep="\t")
        if "participant_id" not in self.split_df:
            raise KeyError(
                "A 'participant_id' column is mandatory in the split file.")
        self.split_df = self.split_df[["participant_id"]]
        self.split_df = self.split_df.astype({"participant_id": "str"})
        if withdraw_subjects is not None:
            self.split_df = self.split_df[
                ~self.split_df["participant_id"].isin(withdraw_subjects)]
        self._df = pd.merge(self.split_df, self.info_df,
                            on="participant_id")

        # Keep only useful information / sanitize
        if targets is not None:
            for key in targets:
                if key not in self._df:
                    raise KeyError(
                        f"A '{key}' column is mandatory in the "
                        "participants file.")
        self._df = self._df[["participant_id"] + (targets or [])]
        _missing_data = self._df[self._df.isnull().any(axis=1)]
        if len(_missing_data) > 0:
            warnings.warn(f"Missing data in {split}!", UserWarning,
                          stacklevel=2)
        self._df.replace(self.target_mapping, inplace=True)
        self._targets = (
            self._df[targets].values if targets is not None else None)

    def __repr__(self):
        return (f"{self.__class__.__name__}<split='{self.split}',"
                f"modalities={self.n_modalities},targets={self.targets}>")

    def __len__(self):
        return len(self._df)
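

# Usage sketch (assumptions, not part of the nidl API): the dataset
# path, the 'sex' column, and `TabularOnlyDataset` are hypothetical. It
# shows how `target_mapping` recodes values read from 'participants.tsv'
# via pandas DataFrame.replace.
def _example_base_dataset(root="/path/to/dataset"):
    """ Sketch: recode the 'sex' target to integers while loading a split.
    """
    class TabularOnlyDataset(BaseDataset):
        def __getitem__(self, idx):
            # No imaging data in this sketch: return only the targets.
            return self._targets[idx]

    return TabularOnlyDataset(
        root, patterns="t1w.npy", channels="T1w", split="train",
        targets="sex", target_mapping={"sex": {"M": 0, "F": 1}})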


class BaseNumpyDataset(BaseDataset):
    """ Neuroimaging dataset that uses numpy arrays and memory mapping.

    Notes
    -----
    A 'participants.tsv' file containing subject information (including
    the requested targets) is expected at the root.
    A '<split>.tsv' file containing the subjects to include is expected
    at the root.

    Parameters
    ----------
    root: str
        the location where the data are stored.
    patterns: str or list of str
        the relative locations (no path name matching allowed in the
        specified patterns) of the numpy arrays to be loaded.
    channels: str or list of str, default=None
        the name of the channels.
    split: str, default 'train'
        define the split to be considered.
    targets: str or list of str, default=None
        the dataset will also return these tabular data.
    target_mapping: dict, default None
        optionally, define a dictionary specifying different replacement
        values for different existing values. See the pandas
        DataFrame.replace documentation for more information.
    transforms: callable, default None
        a function that can be called to augment the input images.
    mask: str, default None
        optionally, mask the input data using this numpy array.
    withdraw_subjects: list of str, default None
        optionally, provide a list of subjects to remove from the dataset.

    Raises
    ------
    FileNotFoundError
        If the mandatory input files are not found.
    KeyError
        If the mandatory keys are not found.
    UserWarning
        If missing data are found.
    """
    def __init__(self, root, patterns, channels, split="train",
                 targets=None, target_mapping=None, transforms=None,
                 mask=None, withdraw_subjects=None):
        super().__init__(
            root, patterns, channels, split=split, targets=targets,
            target_mapping=target_mapping, transforms=transforms,
            mask=mask, withdraw_subjects=withdraw_subjects)
        # Memory-map each array so items are read lazily from disk.
        self._data = [np.load(os.path.join(root, name), mmap_mode="r")
                      for name in patterns]

    def get_data(self, idx):
        """ Proper data indexing. """
        subject = self._df.iloc[idx].participant_id
        # Map the split row back to the row index in 'participants.tsv',
        # which matches the order of the memory-mapped arrays.
        data_idx = self.info_df.loc[
            self.info_df.participant_id == subject].index.item()
        return ([arr[data_idx] for arr in self._data],
                (self._targets[idx] if self._targets is not None else None))

    @abc.abstractmethod
    def __getitem__(self, idx):
        """ Get an item of the dataset: this method must be implemented
        in the derived class. """
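

# Usage sketch (assumptions, not part of the nidl API): a 't1w.npy'
# array and an 'age' column exist under `root`. `DemoNumpyDataset` is
# illustrative; it shows how `__getitem__` is typically built on top of
# `get_data`, with optional masking and transforms.
def _example_numpy_dataset(root="/path/to/dataset"):
    """ Sketch: subclass `BaseNumpyDataset` and implement `__getitem__`.
    """
    class DemoNumpyDataset(BaseNumpyDataset):
        def __getitem__(self, idx):
            data, target = self.get_data(idx)
            if self.mask is not None:
                data = [arr * self.mask for arr in data]
            if self.transforms is not None:
                data = [self.transforms(arr) for arr in data]
            return data, target

    return DemoNumpyDataset(
        root, patterns="t1w.npy", channels="T1w", split="train",
        targets="age")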


class BaseImageDataset(BaseDataset):
    """ Scalable neuroimaging dataset that uses files.

    Notes
    -----
    A 'participants.tsv' file containing subject information (including
    the requested targets) is expected at the root.
    A '<split>.tsv' file containing the subjects to include is expected
    at the root.

    The general idea is not to copy all your data into the root folder
    but rather to use a single symlink per project (if you are working
    with aggregated data). To enforce reproducibility, you can check
    that the content of each file is persistent using the `get_checksum`
    method.

    Parameters
    ----------
    root: str
        the location where the data are stored.
    patterns: str or list of str
        the relative locations of the images to be loaded.
    channels: str or list of str, default=None
        the name of the channels.
    subject_in_patterns: int or list of int
        the folder level where the subject identifiers can be retrieved.
    split: str, default 'train'
        define the split to be considered.
    targets: str or list of str, default=None
        the dataset will also return these tabular data.
    target_mapping: dict, default None
        optionally, define a dictionary specifying different replacement
        values for different existing values. See the pandas
        DataFrame.replace documentation for more information.
    transforms: callable, default None
        a function that can be called to augment the input images.
    mask: str, default None
        optionally, mask the input data using this numpy array.
    withdraw_subjects: list of str, default None
        optionally, provide a list of subjects to remove from the dataset.

    Raises
    ------
    FileNotFoundError
        If the mandatory input files are not found.
    KeyError
        If the mandatory keys are not found.
    UserWarning
        If missing data are found.
    """
    def __init__(self, root, patterns, channels, subject_in_patterns,
                 split="train", targets=None, target_mapping=None,
                 transforms=None, mask=None, withdraw_subjects=None):
        super().__init__(
            root, patterns, channels, split=split, targets=targets,
            target_mapping=target_mapping, transforms=transforms,
            mask=mask, withdraw_subjects=withdraw_subjects)
        if not isinstance(subject_in_patterns, (list, tuple)):
            subject_in_patterns = [subject_in_patterns] * len(patterns)
        assert len(patterns) == len(subject_in_patterns)
        self.subject_in_patterns = subject_in_patterns
        self._data = {}
        for idx, pattern in enumerate(patterns):
            _regex = os.path.join(root, pattern)
            _sidx = subject_in_patterns[idx]
            # Map each discovered file to its sanitized subject identifier.
            _files = {
                self.sanitize_subject(path.split(os.sep)[_sidx]): path
                for path in glob.glob(_regex)}
            self._data[f"{self.channels[idx]}"] = [
                _files.get(subject)
                for subject in self._df["participant_id"]]
        self._data = pd.DataFrame.from_dict(self._data)
        _missing_data = self._data[self._data.isnull().any(axis=1)]
        if len(_missing_data) > 0:
            warnings.warn(f"Missing file data in {split}!", UserWarning,
                          stacklevel=2)
        self._data = self._data.values

    def sanitize_subject(self, subject):
        """ Strip the 'sub-' prefix and any '_' suffix from a subject ID.
        """
        return subject.replace("sub-", "").split("_")[0]

    def get_checksum(self, path):
        """ Hash the content of a file. """
        # Read in binary mode: hashlib digests operate on bytes.
        with open(path, "rb") as of:
            checksum = hashlib.sha1(of.read()).hexdigest()
        return checksum

    def get_data(self, idx):
        """ Proper data indexing. """
        return self._data[idx], (self._targets[idx]
                                 if self._targets is not None else None)

    @abc.abstractmethod
    def __getitem__(self, idx):
        """ Get an item of the dataset: this method must be implemented
        in the derived class. """
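

# Usage sketch (assumptions, not part of the nidl API): the matched files
# are NIfTI volumes readable with nibabel, stored in a BIDS-like
# 'sub-*/anat/' hierarchy, and 'participants.tsv' has an 'age' column.
# `NiftiDataset` and the glob pattern are illustrative only.
def _example_image_dataset(root="/path/to/dataset"):
    """ Sketch: subclass `BaseImageDataset` and lazily load volumes. """
    import nibabel

    class NiftiDataset(BaseImageDataset):
        def __getitem__(self, idx):
            paths, target = self.get_data(idx)
            # `get_checksum` can be called on each path to pin the file
            # contents for reproducibility.
            data = [nibabel.load(path).get_fdata() for path in paths]
            if self.transforms is not None:
                data = [self.transforms(arr) for arr in data]
            return data, target

    # Level -3 of each path holds the 'sub-<id>' folder in this layout.
    return NiftiDataset(
        root, patterns="sub-*/anat/sub-*_T1w.nii.gz", channels="T1w",
        subject_in_patterns=-3, split="train", targets="age")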
