Source code for surfify.datasets._generic

# -*- coding: utf-8 -*-
##########################################################################
# NSAp - Copyright (C) CEA, 2025
# Distributed under the terms of the CeCILL-B license, as published by
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
# for details.
##########################################################################

"""
Create a generic dataset to import cortical data.
"""

# Imports
import os
import glob
import errno
import nibabel
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from surfify.utils import icosahedron, downsample, patch_tri


class GenericSurfDataset(Dataset):
    """ A scalable neuroimaging dataset.

    Parameters
    ----------
    root: str
        the location where the data are stored.
    patterns: str or list of str
        the regex(es) used to retrieve the images of interest, or any data
        that can be loaded with nibabel.load.
    subject_in_patterns: int or list of int
        the folder level at which the subject identifiers can be retrieved.
    ico_order: int
        the input data ico order.
    targets: str or list of str
        the dataset will also return these tabular data. A
        'participants.tsv' file containing subject information (including
        the requested targets) is expected at the root.
    target_mapping: dict, default None
        optionally, define a dictionary specifying different replacement
        values for different existing values. See the pandas
        DataFrame.replace documentation for more information.
    split: str, default 'train'
        define the split to be considered. A '<split>.tsv' file containing
        the subjects to include is expected at the root.
    transforms: callable or list of callable, default None
        a function that can be called to augment the input images.
    mask: array, default None
        optionally, mask the input image.
    contrastive: bool, default False
        optionally, create a contrastive dataset that will return a pair
        of augmented images.
    patch: bool, default False
        optionally, return triangular patches.
    n_max: int, default None
        optionally, keep only a subset of subjects (for debugging
        purposes).
    withdraw_subjects: list of str, default None
        optionally, provide a list of subjects to remove from the dataset.
    target_ico_order: int, default None
        the desired ico order (data will be downsampled to this
        resolution).
    size: int, default 3
        the patch size.
    """
    def __init__(self, root, patterns, subject_in_patterns, ico_order,
                 targets, target_mapping=None, split="train",
                 transforms=None, mask=None, contrastive=False,
                 patch=False, n_max=None, withdraw_subjects=None,
                 target_ico_order=None, size=3):
        # Sanity
        if not isinstance(patterns, (list, tuple)):
            patterns = [patterns]
        if not isinstance(subject_in_patterns, (list, tuple)):
            subject_in_patterns = [subject_in_patterns]
        if not isinstance(targets, (list, tuple)):
            targets = [targets]
        if not isinstance(transforms, (list, tuple)):
            transforms = [transforms] * len(patterns)
        assert len(patterns) == len(transforms)
        assert len(patterns) == len(subject_in_patterns)
        participant_file = os.path.join(root, "participants.tsv")
        split_file = os.path.join(root, f"{split}.tsv")
        for path in (participant_file, split_file):
            if not os.path.isfile(path):
                raise FileNotFoundError(
                    errno.ENOENT, os.strerror(errno.ENOENT), path)

        # Parameters
        self.root = root
        self.patterns = patterns
        self.n_modalities = len(patterns)
        self.ico_order = ico_order
        self.targets = targets
        self.target_mapping = target_mapping
        self.split = split
        self.transforms = transforms
        self.mask = mask
        self.contrastive = contrastive
        self.patch = patch
        self.target_ico_order = target_ico_order
        self.size = size

        # Load subjects / data location
        self.info_df = pd.read_csv(participant_file, sep="\t")
        self.info_df = self.info_df.astype({"participant_id": "str"})
        self.split_df = pd.read_csv(split_file, sep="\t")
        self.split_df = self.split_df[["participant_id"]]
        self.split_df = self.split_df.astype({"participant_id": "str"})
        if withdraw_subjects is not None:
            self.split_df = self.split_df[
                ~self.split_df["participant_id"].isin(withdraw_subjects)]
        self._df = pd.merge(self.split_df, self.info_df,
                            on="participant_id")
        self.mod_names = []
        for idx, pattern in enumerate(patterns):
            # Map sanitized subject identifiers to the file found for this
            # modality, then align the paths with the split subjects.
            _regex = os.path.join(root, pattern)
            _sidx = subject_in_patterns[idx]
            _files = dict(
                (self.sanitize_subject(path.split(os.sep)[_sidx]), path)
                for path in glob.glob(_regex))
            self._df[f"data{idx}"] = [
                _files.get(subject)
                for subject in self._df["participant_id"]]
            self.mod_names.append(f"data{idx}")

        # Keep only useful information / sanitize
        self._df = self._df[["participant_id"] + self.mod_names + targets]
        _missing_data = self._df[self._df.isnull().any(axis=1)]
        if len(_missing_data) > 0:
            print(_missing_data)
            print(_missing_data.participant_id.values.tolist())
            raise ValueError(f"Missing data in {split}!")
        self._df = self._df[self.mod_names + targets]
        # Only remap target values when an explicit mapping is provided.
        if target_mapping is not None:
            self._df = self._df.replace(target_mapping)
        if n_max is not None and len(self._df) > n_max:
            self._df = self._df.head(n_max)
        self.data = self._df[self.mod_names].values
        self.target = self._df[targets].values

        # Cache some parameters
        if target_ico_order is not None:
            ico_verts, ico_tris = icosahedron(order=ico_order)
            target_ico_verts, target_ico_tris = icosahedron(
                order=target_ico_order)
            self.down_indices = downsample(ico_verts, target_ico_verts)
        else:
            self.down_indices = None
        if self.patch:
            self.patch_indices = patch_tri(
                order=target_ico_order, size=size, direct_neighbor=True,
                n_jobs=-1)

    def sanitize_subject(self, subject):
        return subject.replace("sub-", "").split("_")[0]

    def __repr__(self):
        return (f"{self.__class__.__name__}<split='{self.split}',"
                f"modalities={self.n_modalities},targets={self.targets},"
                f"contrastive={self.contrastive},patch={self.patch}>")

    def __getitem__(self, idx):
        """ Get an item of the dataset.

        Parameters
        ----------
        idx: int
            the item location in the dataset.

        Returns
        -------
        Xi: array (<n_patches>, 2|1, n_vertices)
            the returned texture data (3d for patched data, 2d otherwise).
        yi: array (n_aux, )
            the returned auxiliary variables.
        """
        data = []
        for path, trf in zip(self.data[idx], self.transforms):
            arr = nibabel.load(path).get_fdata().astype(
                np.float32).squeeze()
            if self.mask is not None:
                arr[np.where(self.mask == 0)] = 0
            if self.down_indices is not None:
                arr = arr[self.down_indices]
            arr = np.expand_dims(arr, axis=0)
            if self.contrastive:
                # Contrastive mode needs a transform to build the two views.
                assert trf is not None
                arr = np.stack((trf(arr), trf(arr)), axis=0)
            elif trf is not None:
                arr = trf(arr)
            if self.patch:
                arr = np.asarray([
                    arr[:, indices] for indices in self.patch_indices])
            data.append(arr)
        return *data, *self.target[idx]

    def __len__(self):
        return len(self._df)
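For context, a minimal usage sketch follows. The root path, glob pattern, subject folder level, and target column are hypothetical placeholders that must match your own layout, and the import assumes GenericSurfDataset is re-exported from surfify.datasets (the module path above suggests so, but check your installed package).

from torch.utils.data import DataLoader
from surfify.datasets import GenericSurfDataset

# Root is expected to contain participants.tsv and train.tsv; the pattern
# and target column below are illustrative only.
dataset = GenericSurfDataset(
    root="/path/to/dataset",
    patterns="derivatives/*/lh_thickness_ico5.gii",
    subject_in_patterns=-2,   # subject id read from the parent folder name
    ico_order=5,
    targets="age",            # a column of participants.tsv
    split="train")
loader = DataLoader(dataset, batch_size=8, shuffle=True)
X, y = next(iter(loader))
# With one modality and one target, X has shape (8, 1, 10242) at ico
# order 5 and y has shape (8, ).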
